
Update kmeans training code & vits

Lengyue 2 years ago
parent
commit bd79e89c8a

+ 23 - 6
fish_speech/callbacks/grad_norm.py

@@ -1,4 +1,4 @@
-from typing import Union
+from typing import Optional, Union
 
 import lightning.pytorch as pl
 import torch
@@ -32,6 +32,9 @@ def grad_norm(
         parameters = [parameters]
 
     grads = [p.grad for p in parameters if p.grad is not None]
+    if len(grads) == 0:
+        return None
+
     first_device = grads[0].device
     grouped_grads: dict[
         tuple[torch.device, torch.dtype], list[list[Tensor]]
@@ -54,15 +57,22 @@ class GradNormMonitor(Callback):
     Callback that computes the gradient norm of the model parameters.
     """
 
-    def __init__(self, norm_type: float = 2.0, logging_interval: str = "step") -> None:
+    def __init__(
+        self,
+        norm_type: float = 2.0,
+        logging_interval: str = "step",
+        sub_module: Optional[str] = None,
+    ) -> None:
         """
         Args:
             norm_type (float): type of the used p-norm.
             logging_interval (str): "step" or "epoch".
         """
         super().__init__()
+
         self.norm_type = norm_type
         self.logging_interval = logging_interval
+        self.sub_module = sub_module
 
     def on_after_backward(self, trainer: Trainer, model: LightningModule) -> None:
         """
@@ -73,13 +83,20 @@ class GradNormMonitor(Callback):
             model (LightningModule): The current lightningModule
         """
 
-        grad_norm_val = grad_norm(model.parameters(), self.norm_type)
+        lightning_model = model
 
-        model_name = model.__class__.__name__.lower()
+        path = ""
+        if self.sub_module is not None:
+            model = getattr(model, self.sub_module)
+            path = f"/{self.sub_module}"
+
+        grad_norm_val = grad_norm(model.parameters(), self.norm_type)
+        if grad_norm_val is None:
+            return
 
         on_step = self.logging_interval == "step"
-        model.log(
-            f"train/{model_name}/grad_norm",
+        lightning_model.log(
+            f"train{path}/grad_norm",
             grad_norm_val,
             on_step=on_step,
             on_epoch=not on_step,

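The callback now optionally scopes the norm to a named sub-module and skips logging when no parameter has a gradient yet. A minimal usage sketch, assuming `GradNormMonitor` is imported from `fish_speech.callbacks.grad_norm` (matching the file path above):

```python
import lightning.pytorch as pl

from fish_speech.callbacks.grad_norm import GradNormMonitor

# Track only the generator's gradients; the metric is logged under
# "train/generator/grad_norm" (just "train/grad_norm" when sub_module is None).
monitor = GradNormMonitor(norm_type=2.0, logging_interval="step", sub_module="generator")
trainer = pl.Trainer(callbacks=[monitor])
```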
+ 11 - 3
fish_speech/configs/hubert_vq.yaml

@@ -7,6 +7,7 @@ project: hubert_vq
 # Lightning Trainer
 trainer:
   accelerator: gpu
+  devices: 4
   strategy:
     _target_: lightning.pytorch.strategies.DDPStrategy
     static_graph: true
@@ -22,20 +23,23 @@ train_dataset:
   _target_: fish_speech.datasets.vqgan.VQGANDataset
   filelist: data/vq_train_filelist.txt
   sample_rate: ${sample_rate}
+  hop_length: ${hop_length}
+  slice_frames: 32
 
 val_dataset:
   _target_: fish_speech.datasets.vqgan.VQGANDataset
   filelist: data/vq_val_filelist.txt
   sample_rate: ${sample_rate}
+  hop_length: ${hop_length}
+  slice_frames: null
 
 data:
   _target_: fish_speech.datasets.vqgan.VQGANDataModule
   train_dataset: ${train_dataset}
   val_dataset: ${val_dataset}
   num_workers: 4
-  batch_size: 8
+  batch_size: 32
   val_batch_size: 4
-  hop_length: ${hop_length}
 
 # Model Configuration
 model:
@@ -94,10 +98,14 @@ model:
     lr_lambda:
       _target_: fish_speech.scheduler.get_cosine_schedule_with_warmup_lr_lambda
       _partial_: true
-      num_warmup_steps: 2000
+      num_warmup_steps: 0
       num_training_steps: ${trainer.max_steps}
       final_lr_ratio: 0.05
 
+callbacks:
+  grad_norm_monitor:
+    sub_module: generator
+
 # Resume from rcell's checkpoint
 ckpt_path: results/hubert-vq-pretrain/rcell/ckpt_23000_pl.pth
 resume_weights_only: true

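For reference, a sketch of how Hydra resolves the updated `train_dataset` block into a dataset object; the `sample_rate` and `hop_length` values are assumptions standing in for the `${...}` interpolations defined elsewhere in the config:

```python
from hydra.utils import instantiate
from omegaconf import OmegaConf

# Mirrors the YAML above; 32000 / 640 stand in for ${sample_rate} / ${hop_length}.
cfg = OmegaConf.create(
    {
        "_target_": "fish_speech.datasets.vqgan.VQGANDataset",
        "filelist": "data/vq_train_filelist.txt",
        "sample_rate": 32000,
        "hop_length": 640,
        "slice_frames": 32,  # each item spans 32 * 640 = 20480 samples (~0.64 s at 32 kHz)
    }
)
train_dataset = instantiate(cfg)  # requires the filelist (and .npy features) on disk
```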
+ 31 - 10
fish_speech/datasets/vqgan.py

@@ -8,12 +8,18 @@ import torch
 from lightning import LightningDataModule
 from torch.utils.data import DataLoader, Dataset
 
+from fish_speech.utils import RankedLogger
+
+logger = RankedLogger(__name__, rank_zero_only=False)
+
 
 class VQGANDataset(Dataset):
     def __init__(
         self,
         filelist: str,
         sample_rate: int = 32000,
+        hop_length: int = 640,
+        slice_frames: Optional[int] = None,
     ):
         super().__init__()
 
@@ -22,36 +28,53 @@ class VQGANDataset(Dataset):
 
         self.files = [root / line.strip() for line in filelist.read_text().splitlines()]
         self.sample_rate = sample_rate
+        self.hop_length = hop_length
+        self.slice_frames = slice_frames
 
     def __len__(self):
         return len(self.files)
 
-    def __getitem__(self, idx):
+    def get_item(self, idx):
         file = self.files[idx]
 
         audio, _ = librosa.load(file, sr=self.sample_rate, mono=True)
         features = np.load(file.with_suffix(".npy"))  # (T, 1024)
 
+        if len(audio) % self.hop_length != 0:
+            audio = np.pad(audio, (0, self.hop_length - (len(audio) % self.hop_length)))
+
+        # Slice audio and features
+        if self.slice_frames is not None and features.shape[0] > self.slice_frames:
+            start = np.random.randint(0, features.shape[0] - self.slice_frames)
+            features = features[start : start + self.slice_frames]
+            audio = audio[
+                start * self.hop_length : (start + self.slice_frames) * self.hop_length
+            ]
+
         return {
             "audio": torch.from_numpy(audio),
             "features": torch.from_numpy(features),
         }
 
+    def __getitem__(self, idx):
+        try:
+            return self.get_item(idx)
+        except Exception as e:
+            logger.error(f"Error loading {self.files[idx]}: {e}")
+            return None
+
 
 @dataclass
 class VQGANCollator:
-    hop_length: int = 640
-
     def __call__(self, batch):
+        batch = [x for x in batch if x is not None]
+
         audio_lengths = torch.tensor([len(x["audio"]) for x in batch])
         feature_lengths = torch.tensor([len(x["features"]) for x in batch])
 
         audio_maxlen = audio_lengths.max()
         feature_maxlen = feature_lengths.max()
 
-        if audio_maxlen % self.hop_length != 0:
-            audio_maxlen += self.hop_length - (audio_maxlen % self.hop_length)
-
         audios, features = [], []
         for x in batch:
             audios.append(
@@ -77,7 +100,6 @@ class VQGANDataModule(LightningDataModule):
         train_dataset: VQGANDataset,
         val_dataset: VQGANDataset,
         batch_size: int = 32,
-        hop_length: int = 640,
         num_workers: int = 4,
         val_batch_size: Optional[int] = None,
     ):
@@ -87,14 +109,13 @@ class VQGANDataModule(LightningDataModule):
         self.val_dataset = val_dataset
         self.batch_size = batch_size
         self.val_batch_size = val_batch_size or batch_size
-        self.hop_length = hop_length
         self.num_workers = num_workers
 
     def train_dataloader(self):
         return DataLoader(
             self.train_dataset,
             batch_size=self.batch_size,
-            collate_fn=VQGANCollator(self.hop_length),
+            collate_fn=VQGANCollator(),
             num_workers=self.num_workers,
             shuffle=True,
         )
@@ -103,7 +124,7 @@ class VQGANDataModule(LightningDataModule):
         return DataLoader(
             self.val_dataset,
             batch_size=self.batch_size,
-            collate_fn=VQGANCollator(self.hop_length),
+            collate_fn=VQGANCollator(),
             num_workers=self.num_workers,
         )
 

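A small numeric sketch of the new pad-then-slice behaviour in `get_item` (synthetic arrays, no audio I/O):

```python
import numpy as np

hop_length, slice_frames = 640, 32
audio = np.zeros(21_000, dtype=np.float32)          # length not a multiple of hop_length
features = np.zeros((33, 1024), dtype=np.float32)   # (T, 1024) HuBERT features

# Pad audio so its length is a multiple of hop_length (21000 -> 21120).
if len(audio) % hop_length != 0:
    audio = np.pad(audio, (0, hop_length - len(audio) % hop_length))

# Randomly cut a window of slice_frames frames plus the matching audio span.
if slice_frames is not None and features.shape[0] > slice_frames:
    start = np.random.randint(0, features.shape[0] - slice_frames)
    features = features[start : start + slice_frames]
    audio = audio[start * hop_length : (start + slice_frames) * hop_length]

assert features.shape[0] == slice_frames
assert len(audio) == slice_frames * hop_length
```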
+ 50 - 22
fish_speech/models/vqgan/lit_module.py

@@ -123,10 +123,14 @@ class VQGAN(L.LightningModule):
         audios, audio_lengths = batch["audios"], batch["audio_lengths"]
         features, feature_lengths = batch["features"], batch["feature_lengths"]
 
+        audios = audios.float()
+        features = features.float()
+
         with torch.no_grad():
             gt_mels = self.mel_transform(audios).transpose(1, 2)
             key_padding_mask = sequence_mask(feature_lengths)
             mels_key_padding_mask = sequence_mask(audio_lengths // self.hop_length)
+            audio_masks = sequence_mask(audio_lengths)[:, None]
 
             assert abs(gt_mels.shape[1] - mels_key_padding_mask.shape[1]) <= 1
             gt_mel_length = min(gt_mels.shape[1], mels_key_padding_mask.shape[1])
@@ -138,6 +142,29 @@ class VQGAN(L.LightningModule):
             features = features[:, :gt_feature_length]
             key_padding_mask = key_padding_mask[:, :gt_feature_length]
 
+        audios = audios[:, None, :]
+
+        # # Get slice of audio
+        # if audios.shape[-1] > self.segment_size:
+        #     start = torch.randint(
+        #         0, audios.shape[-1] - self.segment_size, (1,), device=audios.device
+        #     ).item()
+        #     start = start // self.hop_length * self.hop_length
+
+        #     audios = audios[:, :, start : start + self.segment_size]
+        #     audio_masks = sequence_mask(audio_lengths)[
+        #         :, None, start : start + self.segment_size
+        #     ]
+
+        #     mel_start = start // self.hop_length
+        #     mel_size = self.segment_size // self.hop_length
+        #     gt_mels = gt_mels[:, mel_start : mel_start + mel_size]
+        #     mels_key_padding_mask = mels_key_padding_mask[
+        #         :, mel_start : mel_start + mel_size
+        #     ]
+
+        #     features = features[:, :, mel_start : mel_start + mel_size]
+
         # Generator
         encoded = self.encoder(
             x=features,
@@ -147,30 +174,15 @@ class VQGAN(L.LightningModule):
         )
 
         features = encoded.features
-        audios = audios[:, None, :]
-
-        # Get slice of audio
-        if audios.shape[-1] > self.segment_size:
-            start = torch.randint(
-                0, audios.shape[-1] - self.segment_size, (1,), device=audios.device
-            ).item()
-            start = start // self.hop_length * self.hop_length
-
-            audios = audios[:, :, start : start + self.segment_size]
-            audio_masks = sequence_mask(audio_lengths)[
-                :, None, start : start + self.segment_size
-            ]
+        # features = self.naive_proj(features.transpose(1, 2))
 
-            mel_start = start // self.hop_length
-            mel_size = self.segment_size // self.hop_length
-            gt_mels = gt_mels[:, mel_start : mel_start + mel_size]
-            mels_key_padding_mask = mels_key_padding_mask[
-                :, mel_start : mel_start + mel_size
-            ]
+        fake_audios = self.generator(features)
 
-            features = features[:, :, mel_start : mel_start + mel_size]
+        min_audio_length = min(audios.shape[-1], fake_audios.shape[-1])
+        audios = audios[:, :, :min_audio_length]
+        fake_audios = fake_audios[:, :, :min_audio_length]
+        audio_masks = audio_masks[:, :, :min_audio_length]
 
-        fake_audios = self.generator(features)
         audio = torch.masked_fill(audios, audio_masks, 0.0)
         fake_audios = torch.masked_fill(fake_audios, audio_masks, 0.0)
         assert fake_audios.shape == audio.shape
@@ -201,6 +213,12 @@ class VQGAN(L.LightningModule):
         y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = self.discriminator(audios, fake_audios)
         fake_mels = self.mel_transform(fake_audios.squeeze(1)).transpose(1, 2)
 
+        # Min mel length
+        min_mel_length = min(gt_mels.shape[1], fake_mels.shape[1])
+        gt_mels = gt_mels[:, :min_mel_length]
+        fake_mels = fake_mels[:, :min_mel_length]
+        mels_key_padding_mask = mels_key_padding_mask[:, :min_mel_length]
+
         # Fill mel mask
         fake_mels = torch.masked_fill(fake_mels, mels_key_padding_mask[:, :, None], 0.0)
         gt_mels = torch.masked_fill(gt_mels, mels_key_padding_mask[:, :, None], 0.0)
@@ -210,7 +228,7 @@ class VQGAN(L.LightningModule):
             loss_adv, _ = self.generator_loss(y_d_hat_g)
             loss_fm = self.feature_loss(fmap_r, fmap_g)
 
-            loss_gen_all = loss_fm * 45 + loss_mel + loss_adv + encoded.loss
+            loss_gen_all = loss_mel * 45 + loss_fm + loss_adv + encoded.loss
 
         self.log(
             "train/generator/loss",
@@ -274,6 +292,9 @@ class VQGAN(L.LightningModule):
         audios, audio_lengths = batch["audios"], batch["audio_lengths"]
         features, feature_lengths = batch["features"], batch["feature_lengths"]
 
+        audios = audios.float()
+        features = features.float()
+
         with torch.no_grad():
             gt_mels = self.mel_transform(audios).transpose(1, 2)
             key_padding_mask = sequence_mask(feature_lengths)
@@ -298,6 +319,8 @@ class VQGAN(L.LightningModule):
             mels_key_padding_mask=mels_key_padding_mask,
         )
 
+        # features = self.naive_proj(features.transpose(1, 2))
+
         features = encoded.features
         audios = audios[:, None, :]
 
@@ -313,6 +336,11 @@ class VQGAN(L.LightningModule):
         assert fake_audios.shape == audio.shape
 
         fake_mels = self.mel_transform(fake_audios.squeeze(1)).transpose(1, 2)
+        min_mel_length = min(gt_mels.shape[1], fake_mels.shape[1])
+        gt_mels = gt_mels[:, :min_mel_length]
+        fake_mels = fake_mels[:, :min_mel_length]
+        mels_key_padding_mask = mels_key_padding_mask[:, :min_mel_length]
+
         gt_mels = torch.masked_fill(gt_mels, mels_key_padding_mask[:, :, None], 0.0)
         fake_mels = torch.masked_fill(fake_mels, mels_key_padding_mask[:, :, None], 0.0)
 

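Since the random segment slicing is gone, generator output and ground truth can differ in length by a few samples; the new code trims both (and the masks) to the shorter length before masking, e.g.:

```python
import torch

audios = torch.randn(2, 1, 20_480)        # (B, 1, T) ground-truth audio
fake_audios = torch.randn(2, 1, 20_352)   # generator output, slightly shorter
audio_masks = torch.zeros(2, 1, 20_480, dtype=torch.bool)  # True on padded samples

min_audio_length = min(audios.shape[-1], fake_audios.shape[-1])
audios = audios[:, :, :min_audio_length]
fake_audios = fake_audios[:, :, :min_audio_length]
audio_masks = audio_masks[:, :, :min_audio_length]

audios = torch.masked_fill(audios, audio_masks, 0.0)
fake_audios = torch.masked_fill(fake_audios, audio_masks, 0.0)
assert fake_audios.shape == audios.shape
# The same min-length trick is applied to gt_mels / fake_mels before the mel loss,
# which is now weighted as loss_mel * 45 + loss_fm + loss_adv + encoded.loss.
```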
+ 211 - 67
fish_speech/models/vqgan/modules.py

@@ -8,7 +8,12 @@ from torch.nn import Conv1d, Conv2d, ConvTranspose1d
 from torch.nn import functional as F
 from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm
 
-from fish_speech.models.vqgan.utils import convert_pad_shape, get_padding, init_weights
+from fish_speech.models.vqgan.utils import (
+    convert_pad_shape,
+    fused_add_tanh_sigmoid_multiply,
+    get_padding,
+    init_weights,
+)
 
 LRELU_SLOPE = 0.1
 
@@ -16,19 +21,18 @@ LRELU_SLOPE = 0.1
 @dataclass
 class VQEncoderOutput:
     loss: torch.Tensor
-    features: torch.Tensor
+    mean: torch.Tensor
+    logs: torch.Tensor
 
 
 class VQEncoder(nn.Module):
     def __init__(
         self,
         in_channels: int = 1024,
-        channels: int = 192,
-        num_mels: int = 128,
+        channels: int = 384,
+        out_channels: int = 192,
         num_heads: int = 2,
-        num_feature_layers: int = 2,
-        num_speaker_layers: int = 4,
-        num_mixin_layers: int = 4,
+        num_layers: int = 8,
         input_downsample: bool = True,
         code_book_size: int = 2048,
         freeze_vq: bool = False,
@@ -38,7 +42,7 @@ class VQEncoder(nn.Module):
         # Feature Encoder
         down_sample = 2 if input_downsample else 1
 
-        self.vq_in = nn.Conv1d(
+        self.in_proj = nn.Conv1d(
             in_channels, in_channels, kernel_size=down_sample, stride=down_sample
         )
         self.vq = VectorQuantization(
@@ -49,38 +53,14 @@ class VQEncoder(nn.Module):
             kmeans_iters=50,
         )
 
-        self.feature_in = nn.Linear(in_channels, channels)
-        self.feature_blocks = nn.ModuleList(
-            [
-                TransformerBlock(
-                    channels,
-                    num_heads,
-                    window_size=4,
-                    window_heads_share=True,
-                    proximal_init=True,
-                    proximal_bias=False,
-                    use_relative_attn=True,
-                )
-                for _ in range(num_feature_layers)
-            ]
+        # Init weights of in_proj to mimic the effect of avg pooling
+        torch.nn.init.normal_(
+            self.in_proj.weight, mean=1 / (down_sample * in_channels), std=0.01
         )
+        self.in_proj.bias.data.zero_()
 
-        # Speaker Encoder
-        self.speaker_query = nn.Parameter(torch.randn(1, 1, channels))
-        self.speaker_in = nn.Linear(num_mels, channels)
-        self.speaker_blocks = nn.ModuleList(
-            [
-                TransformerBlock(
-                    channels,
-                    num_heads,
-                    use_relative_attn=False,
-                )
-                for _ in range(num_speaker_layers)
-            ]
-        )
-
-        # Final Mixer
-        self.mixer_blocks = nn.ModuleList(
+        self.feature_in = nn.Linear(in_channels, channels)
+        self.blocks = nn.ModuleList(
             [
                 TransformerBlock(
                     channels,
@@ -91,10 +71,12 @@ class VQEncoder(nn.Module):
                     proximal_bias=False,
                     use_relative_attn=True,
                 )
-                for _ in range(num_mixin_layers)
+                for _ in range(num_layers)
             ]
         )
 
+        self.out_proj = nn.Linear(channels, out_channels * 2)
+
         self.input_downsample = input_downsample
 
         if freeze_vq:
@@ -104,22 +86,15 @@ class VQEncoder(nn.Module):
             for p in self.vq_in.parameters():
                 p.requires_grad = False
 
-    def forward(
-        self, x, mels, key_padding_mask=None, mels_key_padding_mask=None
-    ) -> VQEncoderOutput:
+    def forward(self, x, key_padding_mask=None) -> VQEncoderOutput:
         # x: (batch, seq_len, channels)
-        # mels: (batch, seq_len, 128)
 
         assert key_padding_mask.size(1) == x.size(
             1
-        ), f"key_padding_mask shape {key_padding_mask.size()} does not match features shape {features.size()}"
-
-        assert mels_key_padding_mask.size(1) == mels.size(
-            1
-        ), f"mels_key_padding_mask shape {mels_key_padding_mask.size()} does not match mels shape {mels.size()}"
+        ), f"key_padding_mask shape {key_padding_mask.size()} does not match features shape {x.size()}"
 
         # Encode Features
-        features = self.vq_in(x.transpose(1, 2))
+        features = self.in_proj(x.transpose(1, 2))
         features, _, loss = self.vq(features)
         features = features.transpose(1, 2)
 
@@ -136,35 +111,204 @@ class VQEncoder(nn.Module):
             key_padding_mask = key_padding_mask[:, :min_len]
 
         features = self.feature_in(features)
-        for block in self.feature_blocks:
+        for block in self.blocks:
             features = block(features, key_padding_mask=key_padding_mask)
 
-        # Encode Speaker
-        speaker = self.speaker_in(mels)
-        speaker = torch.cat(
-            [self.speaker_query.expand(speaker.shape[0], -1, -1), speaker], dim=1
+        stats = self.out_proj(features).transpose(1, 2)
+        stats = torch.masked_fill(stats, key_padding_mask.unsqueeze(1), 0)
+        mean, logs = torch.chunk(stats, 2, dim=1)
+
+        return VQEncoderOutput(
+            loss=loss,
+            mean=mean,
+            logs=logs,
+        )
+
+
+class WaveNet(nn.Module):
+    def __init__(
+        self,
+        hidden_channels,
+        kernel_size,
+        dilation_rate,
+        n_layers,
+        gin_channels=0,
+        p_dropout=0,
+    ):
+        super(WaveNet, self).__init__()
+        assert kernel_size % 2 == 1
+        self.hidden_channels = hidden_channels
+        self.kernel_size = (kernel_size,)
+        self.dilation_rate = dilation_rate
+        self.n_layers = n_layers
+        self.gin_channels = gin_channels
+        self.p_dropout = p_dropout
+
+        self.in_layers = nn.ModuleList()
+        self.res_skip_layers = nn.ModuleList()
+        self.drop = nn.Dropout(p_dropout)
+
+        if gin_channels != 0:
+            self.cond_layer = weight_norm(
+                nn.Conv1d(gin_channels, 2 * hidden_channels * n_layers, 1)
+            )
+
+        for i in range(n_layers):
+            dilation = dilation_rate**i
+            padding = int((kernel_size * dilation - dilation) / 2)
+            in_layer = weight_norm(
+                nn.Conv1d(
+                    hidden_channels,
+                    2 * hidden_channels,
+                    kernel_size,
+                    dilation=dilation,
+                    padding=padding,
+                )
+            )
+            self.in_layers.append(in_layer)
+
+            # last one is not necessary
+            if i < n_layers - 1:
+                res_skip_channels = 2 * hidden_channels
+            else:
+                res_skip_channels = hidden_channels
+
+            res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
+            res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name="weight")
+            self.res_skip_layers.append(res_skip_layer)
+
+    def forward(self, x, x_mask, g=None):
+        output = torch.zeros_like(x)
+        n_channels_tensor = torch.IntTensor([self.hidden_channels])
+
+        if g is not None:
+            g = self.cond_layer(g)
+
+        for i in range(self.n_layers):
+            x_in = self.in_layers[i](x)
+            if g is not None:
+                cond_offset = i * 2 * self.hidden_channels
+                g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :]
+            else:
+                g_l = torch.zeros_like(x_in)
+
+            acts = fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor)
+            acts = self.drop(acts)
+
+            res_skip_acts = self.res_skip_layers[i](acts)
+            if i < self.n_layers - 1:
+                res_acts = res_skip_acts[:, : self.hidden_channels, :]
+                x = (x + res_acts) * x_mask
+                output = output + res_skip_acts[:, self.hidden_channels :, :]
+            else:
+                output = output + res_skip_acts
+
+        return output * x_mask
+
+    def remove_weight_norm(self):
+        if self.gin_channels != 0:
+            torch.nn.utils.remove_weight_norm(self.cond_layer)
+        for l in self.in_layers:
+            torch.nn.utils.remove_weight_norm(l)
+        for l in self.res_skip_layers:
+            torch.nn.utils.remove_weight_norm(l)
+
+
+@dataclass
+class PosteriorEncoderOutput:
+    z: torch.Tensor
+    mean: torch.Tensor
+    logs: torch.Tensor
+
+
+class PosteriorEncoder(nn.Module):
+    def __init__(
+        self,
+        in_channels: int = 1024,
+        out_channels: int = 192,
+        hidden_channels: int = 192,
+        kernel_size: int = 5,
+        dilation_rate: int = 1,
+        n_layers: int = 16,
+        gin_channels: int = 512,
+    ):
+        super().__init__()
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.hidden_channels = hidden_channels
+        self.kernel_size = kernel_size
+        self.dilation_rate = dilation_rate
+        self.n_layers = n_layers
+        self.gin_channels = gin_channels
+
+        self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
+        self.enc = WaveNet(
+            hidden_channels,
+            kernel_size,
+            dilation_rate,
+            n_layers,
+            gin_channels=gin_channels,
+        )
+        self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
+
+    def forward(self, x, x_mask, g=None):
+        g = g.detach()
+        x = self.pre(x) * x_mask
+        x = self.enc(x, x_mask, g=g)
+        stats = self.proj(x) * x_mask
+        m, logs = torch.split(stats, self.out_channels, dim=1)
+        z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
+
+        return PosteriorEncoderOutput(
+            z=z,
+            mean=m,
+            logs=logs,
         )
+
+
+class SpeakerEncoder(nn.Module):
+    def __init__(
+        self,
+        in_channels: int = 128,
+        channels: int = 192,
+        out_channels: int = 512,
+        num_heads: int = 2,
+        num_layers: int = 4,
+    ) -> None:
+        super().__init__()
+
+        self.query = nn.Parameter(torch.randn(1, 1, channels))
+        self.in_proj = nn.Linear(in_channels, channels)
+        self.blocks = nn.ModuleList(
+            [
+                TransformerBlock(
+                    channels,
+                    num_heads,
+                    use_relative_attn=False,
+                )
+                for _ in range(num_layers)
+            ]
+        )
+        self.out_proj = nn.Linear(channels, out_channels)
+
+    def forward(self, mels, mels_key_padding_mask=None):
+        x = self.in_proj(mels)
+        x = torch.cat([self.query.expand(x.shape[0], -1, -1), x], dim=1)
+
         mels_key_padding_mask = torch.cat(
             [
-                torch.ones(
-                    speaker.shape[0], 1, dtype=torch.bool, device=speaker.device
-                ),
+                torch.ones(x.shape[0], 1, dtype=torch.bool, device=x.device),
                 mels_key_padding_mask,
             ],
             dim=1,
         )
-        for block in self.speaker_blocks:
-            speaker = block(speaker, key_padding_mask=mels_key_padding_mask)
+        for block in self.blocks:
+            x = block(x, key_padding_mask=mels_key_padding_mask)
 
-        # Mix
-        x = features + speaker[:, :1]
-        for block in self.mixer_blocks:
-            x = block(x, key_padding_mask=key_padding_mask)
+        x = x[:, :1]
+        x = self.out_proj(x)
 
-        return VQEncoderOutput(
-            loss=loss,
-            features=x.transpose(1, 2),
-        )
+        return x.transpose(1, 2)
 
 
 class TransformerBlock(nn.Module):

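The encoders now emit Gaussian frame statistics in the VITS style: `VQEncoder` returns `(mean, logs)` and `PosteriorEncoder` additionally samples `z = mean + eps * exp(logs)`. A toy sketch of that reparameterisation (shapes are illustrative):

```python
import torch

batch, channels, frames = 2, 192, 100
mean = torch.zeros(batch, channels, frames)
logs = torch.zeros(batch, channels, frames)   # log standard deviation per frame
x_mask = torch.ones(batch, 1, frames)         # 1.0 on valid frames, 0.0 on padding

# Reparameterised sample, zeroed on padded frames (as in PosteriorEncoder.forward).
z = (mean + torch.randn_like(mean) * torch.exp(logs)) * x_mask
assert z.shape == (batch, channels, frames)
```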
+ 11 - 0
fish_speech/models/vqgan/utils.py

@@ -50,3 +50,14 @@ def plot_mel(data, titles=None):
         axes[i][0].set_anchor("W")
 
     return fig
+
+
+@torch.jit.script
+def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
+    n_channels_int = n_channels[0]
+    in_act = input_a + input_b
+    t_act = torch.tanh(in_act[:, :n_channels_int, :])
+    s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
+    acts = t_act * s_act
+
+    return acts

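A quick shape check of the fused gated activation added above, as used by the new WaveNet blocks (the first half of the channels goes through tanh, the second half through sigmoid):

```python
import torch

from fish_speech.models.vqgan.utils import fused_add_tanh_sigmoid_multiply

hidden = 192
x_in = torch.randn(2, 2 * hidden, 50)   # in_layer output: tanh half + sigmoid half
g_l = torch.zeros_like(x_in)            # conditioning term (zeros when g is None)

acts = fused_add_tanh_sigmoid_multiply(x_in, g_l, torch.IntTensor([hidden]))
assert acts.shape == (2, hidden, 50)
```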
+ 15 - 11
tools/vqgan/calculate_kmeans_init.py

@@ -27,11 +27,15 @@ class KMeansDataset(Dataset):
 
     def __getitem__(self, idx):
         file = self.files[idx]
-        feature = np.load(file.with_suffix(".npy"))
+        try:
+            feature = np.load(file.with_suffix(".npy"))
+        except Exception as e:
+            return None
         return torch.from_numpy(feature).float()
 
     @staticmethod
     def collate_fn(features):
+        features = [feature for feature in features if feature is not None]
         features = torch.concat(features, dim=0)
         return features
 
@@ -40,7 +44,7 @@ class KMeansDataset(Dataset):
 @click.option(
     "--filelist",
     type=click.Path(exists=True, path_type=Path),
-    default="data/test.filelist",
+    default="data/vq_train_filelist.txt",
 )
 @click.option("--output", type=click.Path(path_type=Path), default="kmeans.pt")
 @click.option("--num-clusters", type=int, default=2048)
@@ -55,7 +59,7 @@ def main(filelist: Path, output: Path, num_clusters: int, epochs: int):
     )
 
     means = None
-    for _ in tqdm(range(epochs), desc="Epochs", position=0):
+    for epoch in tqdm(range(epochs), desc="Epochs", position=0):
         total_bins = torch.zeros(1, num_clusters, dtype=torch.int64, device="cuda")
 
         for samples in tqdm(loader, desc="Batches", position=1):
@@ -86,14 +90,14 @@ def main(filelist: Path, output: Path, num_clusters: int, epochs: int):
 
             total_bins += bins
 
-    torch.save(
-        {
-            "centroids": means,
-            "bins": bins,
-        },
-        output,
-    )
-    print(f"Saved to {output}")
+        torch.save(
+            {
+                "centroids": means,
+                "bins": bins,
+            },
+            output,
+        )
+        print(f"Finished epoch {epoch}, total bins: {total_bins}")
 
 
 if __name__ == "__main__":

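The script now saves `kmeans.pt` at the end of every epoch rather than only once after all epochs, so a partial run can still be inspected. A minimal read-back sketch (actual shapes depend on `--num-clusters` and the feature dimension):

```python
import torch

ckpt = torch.load("kmeans.pt", map_location="cpu")
centroids = ckpt["centroids"]  # cluster centers used to seed the VQ codebook
bins = ckpt["bins"]            # per-cluster assignment counts from the last pass
print(centroids.shape, bins.shape)
```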
+ 31 - 0
tools/vqgan/migrate_from_vits.py

@@ -45,6 +45,37 @@ def main(cfg: DictConfig):
     model.discriminator.load_state_dict(discriminator_weights, strict=True)
     logger.info("Discriminator weights restored.")
 
+    # Restore kmeans
+    logger.info("Reset vq projection layer to mimic avg pooling")
+    torch.nn.init.normal_(
+        model.encoder.in_proj.weight,
+        mean=1
+        / (
+            model.encoder.in_proj.weight.shape[0]
+            * model.encoder.in_proj.weight.shape[-1]
+        ),
+        std=1e-2,
+    )
+    model.encoder.in_proj.bias.data.zero_()
+
+    kmeans_ckpt = "results/hubert-vq-pretrain/kmeans.pt"
+    kmeans_ckpt = torch.load(kmeans_ckpt, map_location="cpu")
+
+    centroids = kmeans_ckpt["centroids"][0]
+    bins = kmeans_ckpt["bins"][0]
+    logger.info(
+        f"Restoring kmeans centroids with shape {centroids.shape} and bins {bins.shape}"
+    )
+
+    state_dict = {
+        "_codebook.inited": torch.Tensor([True]),
+        "_codebook.cluster_size": bins,
+        "_codebook.embed": centroids,
+        "_codebook.embed_avg": centroids.clone(),
+    }
+
+    model.encoder.vq.load_state_dict(state_dict, strict=True)
+
     torch.save(model.state_dict(), cfg.ckpt_path)
     logger.info("Done")
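
For context, a toy sketch of the re-initialised `encoder.in_proj`: a strided Conv1d (kernel_size == stride == 2) whose weights are centred at 1 / (2 * in_channels) with zero bias, so it roughly mimics 2x average pooling over the HuBERT features while remaining trainable. Values below are illustrative, not taken from a real checkpoint:

```python
import torch
from torch import nn

in_channels, down_sample = 1024, 2
in_proj = nn.Conv1d(in_channels, in_channels, kernel_size=down_sample, stride=down_sample)

torch.nn.init.normal_(in_proj.weight, mean=1 / (down_sample * in_channels), std=1e-2)
in_proj.bias.data.zero_()

features = torch.randn(1, in_channels, 100)             # (B, C, T) HuBERT features
assert in_proj(features).shape == (1, in_channels, 50)  # time axis halved
```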