Просмотр исходного кода

Update diffusion config & code

Lengyue 2 лет назад
Родитель
Сommit
69164b19d2

+ 3 - 3
fish_speech/configs/hubert_vq_diffusion.yaml

@@ -24,14 +24,14 @@ win_length: 2048
 # Dataset Configuration
 train_dataset:
   _target_: fish_speech.datasets.vqgan.VQGANDataset
-  filelist: data/vq_train_filelist.txt
+  filelist: data/filelist.split.train
   sample_rate: ${sample_rate}
   hop_length: ${hop_length}
   slice_frames: 512
 
 val_dataset:
   _target_: fish_speech.datasets.vqgan.VQGANDataset
-  filelist: data/vq_val_filelist.txt
+  filelist: data/filelist.split.valid
   sample_rate: ${sample_rate}
   hop_length: ${hop_length}
 
@@ -51,7 +51,7 @@ model:
 
   text_encoder:
     _target_: fish_speech.models.vqgan.modules.encoders.TextEncoder
-    in_channels: 1024
+    in_channels: 2048
     out_channels: 128
     hidden_channels: 192
     hidden_channels_ffn: 768

+ 1 - 1
fish_speech/datasets/vqgan.py

@@ -94,7 +94,7 @@ class VQGANCollator:
             )
             features.append(
                 torch.nn.functional.pad(
-                    x["features"], (0, 0, 0, feature_maxlen - len(x["features"]))
+                    x["features"], (0, feature_maxlen - len(x["features"]))
                 )
             )
 

+ 25 - 20
fish_speech/models/vq_diffusion/lit_module.py

@@ -11,6 +11,7 @@ from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger
 from matplotlib import pyplot as plt
 from torch import nn
 from tqdm import tqdm
+from transformers import HubertModel
 
 from fish_speech.models.vq_diffusion.convnext_1d import ConvNext1DModel
 from fish_speech.models.vqgan.modules.encoders import (
@@ -86,7 +87,7 @@ class VQDiffusion(L.LightningModule):
         features, feature_lengths = batch["features"], batch["feature_lengths"]
 
         audios = audios.float()
-        features = features.float().mT
+        # features = features.float().mT
         audios = audios[:, None, :]
 
         with torch.no_grad():
@@ -95,18 +96,20 @@ class VQDiffusion(L.LightningModule):
         mel_lengths = audio_lengths // self.hop_length
 
         feature_masks = torch.unsqueeze(
-            sequence_mask(feature_lengths, features.shape[2]), 1
+            sequence_mask(feature_lengths, features.shape[1]), 1
         ).to(gt_mels.dtype)
         mel_masks = torch.unsqueeze(sequence_mask(mel_lengths, gt_mels.shape[2]), 1).to(
             gt_mels.dtype
         )
 
         speaker_features = self.speaker_encoder(gt_mels, mel_masks)
-        vq_features, vq_loss = self.vq_encoder(features, feature_masks)
+        # vq_features, vq_loss = self.vq_encoder(features, feature_masks)
 
         # vq_features is 50 hz, need to convert to true mel size
-        vq_features = F.interpolate(vq_features, size=gt_mels.shape[2], mode="nearest")
-        text_features = self.text_encoder(vq_features, mel_masks, g=speaker_features)
+        text_features = self.text_encoder(features, feature_masks, g=speaker_features)
+        text_features = F.interpolate(
+            text_features, size=gt_mels.shape[2], mode="nearest"
+        )
 
         # Sample noise that we'll add to the images
         normalized_gt_mels = self.normalize_mels(gt_mels)
@@ -144,41 +147,43 @@ class VQDiffusion(L.LightningModule):
             sync_dist=True,
         )
 
-        self.log(
-            "train/vq_loss",
-            vq_loss,
-            on_step=True,
-            on_epoch=False,
-            prog_bar=True,
-            logger=True,
-            sync_dist=True,
-        )
+        # self.log(
+        #     "train/vq_loss",
+        #     vq_loss,
+        #     on_step=True,
+        #     on_epoch=False,
+        #     prog_bar=True,
+        #     logger=True,
+        #     sync_dist=True,
+        # )
 
-        return noise_loss + vq_loss
+        return noise_loss  # + vq_loss
 
     def validation_step(self, batch: Any, batch_idx: int):
         audios, audio_lengths = batch["audios"], batch["audio_lengths"]
         features, feature_lengths = batch["features"], batch["feature_lengths"]
 
         audios = audios.float()
-        features = features.float().mT
+        # features = features.float().mT
         audios = audios[:, None, :]
         gt_mels = self.mel_transform(audios)
         mel_lengths = audio_lengths // self.hop_length
 
         feature_masks = torch.unsqueeze(
-            sequence_mask(feature_lengths, features.shape[2]), 1
+            sequence_mask(feature_lengths, features.shape[1]), 1
         ).to(gt_mels.dtype)
         mel_masks = torch.unsqueeze(sequence_mask(mel_lengths, gt_mels.shape[2]), 1).to(
             gt_mels.dtype
         )
 
         speaker_features = self.speaker_encoder(gt_mels, mel_masks)
-        vq_features, _ = self.vq_encoder(features, feature_masks)
+        # vq_features, vq_loss = self.vq_encoder(features, feature_masks)
 
         # vq_features is 50 hz, need to convert to true mel size
-        vq_features = F.interpolate(vq_features, size=gt_mels.shape[2], mode="nearest")
-        text_features = self.text_encoder(vq_features, mel_masks, g=speaker_features)
+        text_features = self.text_encoder(features, feature_masks, g=speaker_features)
+        text_features = F.interpolate(
+            text_features, size=gt_mels.shape[2], mode="nearest"
+        )
 
         # Begin sampling
         sampled_mels = torch.randn_like(gt_mels)

+ 13 - 11
fish_speech/models/vqgan/lit_module.py

@@ -103,7 +103,7 @@ class VQGAN(L.LightningModule):
             (z_q, z_p),
             (m_p, logs_p),
             (m_q, logs_q),
-            vq_loss,
+            # vq_loss,
         ) = self.generator(features, feature_lengths, gt_mels)
 
         y_hat_mel = self.mel_transform(y_hat.squeeze(1))
@@ -155,7 +155,9 @@ class VQGAN(L.LightningModule):
             beta = self.global_step % 1000
             beta = min(beta, 500) / 500 * 0.1 + 1e-6
 
-            loss_gen_all = loss_mel * 45 + loss_fm + loss_adv + loss_kl * beta + vq_loss
+            loss_gen_all = (
+                loss_mel * 45 + loss_fm + loss_adv + loss_kl * beta
+            )  # + vq_loss
 
         self.log(
             "train/generator/loss",
@@ -202,15 +204,15 @@ class VQGAN(L.LightningModule):
             logger=True,
             sync_dist=True,
         )
-        self.log(
-            "train/generator/loss_vq",
-            vq_loss,
-            on_step=True,
-            on_epoch=False,
-            prog_bar=False,
-            logger=True,
-            sync_dist=True,
-        )
+        # self.log(
+        #     "train/generator/loss_vq",
+        #     vq_loss,
+        #     on_step=True,
+        #     on_epoch=False,
+        #     prog_bar=False,
+        #     logger=True,
+        #     sync_dist=True,
+        # )
 
         optim_g.zero_grad()
         self.manual_backward(loss_gen_all)

+ 25 - 11
fish_speech/models/vqgan/modules/encoders.py

@@ -6,7 +6,10 @@ import torch.nn.functional as F
 from vector_quantize_pytorch import VectorQuantize
 
 from fish_speech.models.vqgan.modules.modules import WN
-from fish_speech.models.vqgan.modules.transformer import RelativePositionTransformer
+from fish_speech.models.vqgan.modules.transformer import (
+    MultiHeadAttention,
+    RelativePositionTransformer,
+)
 from fish_speech.models.vqgan.utils import sequence_mask
 
 
@@ -43,7 +46,8 @@ class TextEncoder(nn.Module):
         self.out_channels = out_channels
         self.hidden_channels = hidden_channels
 
-        self.proj_in = nn.Conv1d(in_channels, hidden_channels, 1)
+        # self.proj_in = nn.Conv1d(in_channels, hidden_channels, 1)
+        self.proj_in = nn.Embedding(in_channels, hidden_channels)
 
         self.encoder = RelativePositionTransformer(
             in_channels=hidden_channels,
@@ -75,7 +79,7 @@ class TextEncoder(nn.Module):
             - x: :math:`[B, T]`
             - x_length: :math:`[B]`
         """
-        x = self.proj_in(x) * x_mask
+        x = self.proj_in(x).mT * x_mask
         x = self.encoder(x, x_mask, g=g)
         x = self.proj_out(x) * x_mask
 
@@ -162,13 +166,18 @@ class SpeakerEncoder(nn.Module):
 
         self.in_proj = nn.Sequential(
             nn.Conv1d(in_channels, hidden_channels, 1),
-            nn.SiLU(),
+            nn.Mish(),
+            nn.Dropout(p_dropout),
             nn.Conv1d(hidden_channels, hidden_channels, 5, padding=2),
-            nn.SiLU(),
+            nn.Mish(),
+            nn.Dropout(p_dropout),
             nn.Conv1d(hidden_channels, hidden_channels, 5, padding=2),
-            nn.SiLU(),
+            nn.Mish(),
             nn.Dropout(p_dropout),
         )
+        self.out_proj = nn.Conv1d(hidden_channels, out_channels, 1)
+        self.apply(self._init_weights)
+
         self.encoder = RelativePositionTransformer(
             in_channels=hidden_channels,
             out_channels=hidden_channels,
@@ -176,11 +185,15 @@ class SpeakerEncoder(nn.Module):
             hidden_channels_ffn=hidden_channels,
             n_heads=num_heads,
             n_layers=num_layers,
-            kernel_size=5,
+            kernel_size=1,
             dropout=p_dropout,
-            window_size=4,
+            window_size=None,  # No windowing
         )
-        self.out_proj = nn.Linear(hidden_channels, out_channels)
+
+    def _init_weights(self, m):
+        if isinstance(m, (nn.Conv1d, nn.Linear)):
+            nn.init.normal_(m.weight, mean=0, std=0.02)
+            nn.init.zeros_(m.bias)
 
     def forward(self, mels, mel_masks: torch.Tensor):
         """
@@ -194,8 +207,9 @@ class SpeakerEncoder(nn.Module):
 
         # Avg Pooling
         x = x * mel_masks
-        x = torch.sum(x, dim=2) / torch.sum(mel_masks, dim=2)
-        x = self.out_proj(x)[..., None]
+        x = self.out_proj(x)
+        x = torch.sum(x, dim=-1) / torch.sum(mel_masks, dim=-1)
+        x = x[..., None]
 
         return x
 

+ 6 - 3
fish_speech/models/vqgan/modules/models.py

@@ -121,7 +121,10 @@ class SynthesizerTrn(nn.Module):
         x_masks = torch.unsqueeze(sequence_mask(x_lengths, x.shape[2]), 1).to(x.dtype)
 
         g = self.enc_spk(specs, spec_masks)
-        x, vq_loss = self.vq(x, x_masks)
+
+        # with torch.no_grad():
+        #     x, _ = self.vq(x, x_masks)
+        #     vq_loss = 0
 
         _, m_p, logs_p, _, _ = self.enc_p(x, x_masks, g=g)
         z_q, m_q, logs_q, _ = self.enc_q(specs, spec_masks, g=g)
@@ -138,7 +141,7 @@ class SynthesizerTrn(nn.Module):
             (z_q, z_p),
             (m_p, logs_p),
             (m_q, logs_q),
-            vq_loss,
+            # vq_loss,
         )
 
     def infer(self, x, x_lengths, specs, max_len=None, noise_scale=0.35):
@@ -148,7 +151,7 @@ class SynthesizerTrn(nn.Module):
         )
         x_masks = torch.unsqueeze(sequence_mask(x_lengths, x.shape[2]), 1).to(x.dtype)
         g = self.enc_spk(specs, spec_masks)
-        x, vq_loss = self.vq(x, x_masks)
+        # x, vq_loss = self.vq(x, x_masks)
         z_p, m_p, logs_p, h_text, _ = self.enc_p(
             x, x_masks, g=g, noise_scale=noise_scale
         )

+ 22 - 12
tools/vqgan/calculate_hubert_features.py

@@ -44,7 +44,7 @@ def get_hubert_model():
     return model
 
 
-def process_batch(files: list[Path]):
+def process_batch(files: list[Path], kmeans_centers: torch.Tensor) -> float:
     model = get_hubert_model()
 
     wavs = []
@@ -81,25 +81,32 @@ def process_batch(files: list[Path]):
 
     # Calculate lengths
     with torch.no_grad():
-        outputs = model(wavs, attention_mask=attention_mask)
+        outputs = model(wavs, attention_mask=attention_mask).last_hidden_state
+
+        # Find closest centroids
+        kmeans_centers = kmeans_centers.to(dtype=outputs.dtype, device=outputs.device)
+        distances = torch.cdist(outputs, kmeans_centers)
+        outputs = torch.min(distances, dim=-1)
+        avg_distance = torch.mean(outputs.values)
 
     # Save to disk
-    outputs = outputs.last_hidden_state.cpu().numpy()
+    outputs = outputs.indices.cpu().numpy()
 
     for file, length, feature, wav in zip(files, feature_lengths, outputs, wavs):
         feature = feature[:length]
 
-        # (T, 1024)
+        # (T,)
         with open(file.with_suffix(".npy"), "wb") as f:
             np.save(f, feature)
 
-    return total_time
+    return total_time, avg_distance
 
 
 @click.command()
 @click.argument("folder")
 @click.option("--num-workers", default=1)
-def main(folder: str, num_workers: int):
+@click.option("--kmeans", default="results/hubert-vq-pretrain/kmeans.pt")
+def main(folder: str, num_workers: int, kmeans: str):
     if num_workers > 1 and WORLD_SIZE != num_workers:
         assert WORLD_SIZE == 1, "You should either use SLURM or this launcher, not both"
 
@@ -140,16 +147,22 @@ def main(folder: str, num_workers: int):
     files = files[RANK::WORLD_SIZE]
     logger.info(f"Processing {len(files)}/{total_files} files")
 
+    # Load kmeans
+    kmeans_centers = torch.load(kmeans)["centroids"]
+
     # Batch size 64
     total_time = 0
     begin_time = time.time()
     processed_files = 0
+    total_distance = 0
 
     for n_batch, idx in enumerate(range(0, len(files), 32)):
         batch = files[idx : idx + 32]
-        batch_time = process_batch(batch)
+        batch_time, avg_distance = process_batch(batch, kmeans_centers)
+
         total_time += batch_time
         processed_files += len(batch)
+        total_distance += avg_distance
 
         if (n_batch + 1) % 10 == 0:
             eta = (
@@ -158,13 +171,10 @@ def main(folder: str, num_workers: int):
                 * (len(files) - processed_files)
             )
             logger.info(
-                f"Processed {processed_files} files, {total_time / 3600:.2f} hours of audio, ETA: {timedelta(seconds=round(eta))}s"
+                f"Processed {processed_files} files, {total_time / 3600:.2f} hours of audio, "
+                + f"err {total_distance/(n_batch+1):.2f}, ETA: {timedelta(seconds=round(eta))}s"
             )
 
-        # Stop after 1000 hours
-        if total_time * WORLD_SIZE > 3600 * 1000:
-            break
-
     logger.info(
         f"Finished processing {len(files)} files, {total_time / 3600:.2f} hours of audio"
     )