Source tree

Update diffusion config & code

Lengyue — 2 years ago
parent
commit
69164b19d2

+ 3 - 3
fish_speech/configs/hubert_vq_diffusion.yaml

@@ -24,14 +24,14 @@ win_length: 2048
 # Dataset Configuration
 # Dataset Configuration
 train_dataset:
 train_dataset:
   _target_: fish_speech.datasets.vqgan.VQGANDataset
   _target_: fish_speech.datasets.vqgan.VQGANDataset
-  filelist: data/vq_train_filelist.txt
+  filelist: data/filelist.split.train
   sample_rate: ${sample_rate}
   sample_rate: ${sample_rate}
   hop_length: ${hop_length}
   hop_length: ${hop_length}
   slice_frames: 512
   slice_frames: 512
 
 
 val_dataset:
 val_dataset:
   _target_: fish_speech.datasets.vqgan.VQGANDataset
   _target_: fish_speech.datasets.vqgan.VQGANDataset
-  filelist: data/vq_val_filelist.txt
+  filelist: data/filelist.split.valid
   sample_rate: ${sample_rate}
   sample_rate: ${sample_rate}
   hop_length: ${hop_length}
   hop_length: ${hop_length}
 
 
@@ -51,7 +51,7 @@ model:
 
 
   text_encoder:
   text_encoder:
     _target_: fish_speech.models.vqgan.modules.encoders.TextEncoder
     _target_: fish_speech.models.vqgan.modules.encoders.TextEncoder
-    in_channels: 1024
+    in_channels: 2048
     out_channels: 128
     out_channels: 128
     hidden_channels: 192
     hidden_channels: 192
     hidden_channels_ffn: 768
     hidden_channels_ffn: 768

+ 1 - 1
fish_speech/datasets/vqgan.py

@@ -94,7 +94,7 @@ class VQGANCollator:
             )
             )
             features.append(
             features.append(
                 torch.nn.functional.pad(
                 torch.nn.functional.pad(
-                    x["features"], (0, 0, 0, feature_maxlen - len(x["features"]))
+                    x["features"], (0, feature_maxlen - len(x["features"]))
                 )
                 )
             )
             )
 
 

+ 25 - 20
fish_speech/models/vq_diffusion/lit_module.py

@@ -11,6 +11,7 @@ from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger
 from matplotlib import pyplot as plt
 from matplotlib import pyplot as plt
 from torch import nn
 from torch import nn
 from tqdm import tqdm
 from tqdm import tqdm
+from transformers import HubertModel
 
 
 from fish_speech.models.vq_diffusion.convnext_1d import ConvNext1DModel
 from fish_speech.models.vq_diffusion.convnext_1d import ConvNext1DModel
 from fish_speech.models.vqgan.modules.encoders import (
 from fish_speech.models.vqgan.modules.encoders import (
@@ -86,7 +87,7 @@ class VQDiffusion(L.LightningModule):
         features, feature_lengths = batch["features"], batch["feature_lengths"]
         features, feature_lengths = batch["features"], batch["feature_lengths"]
 
 
         audios = audios.float()
         audios = audios.float()
-        features = features.float().mT
+        # features = features.float().mT
         audios = audios[:, None, :]
         audios = audios[:, None, :]
 
 
         with torch.no_grad():
         with torch.no_grad():
@@ -95,18 +96,20 @@ class VQDiffusion(L.LightningModule):
         mel_lengths = audio_lengths // self.hop_length
         mel_lengths = audio_lengths // self.hop_length
 
 
         feature_masks = torch.unsqueeze(
         feature_masks = torch.unsqueeze(
-            sequence_mask(feature_lengths, features.shape[2]), 1
+            sequence_mask(feature_lengths, features.shape[1]), 1
         ).to(gt_mels.dtype)
         ).to(gt_mels.dtype)
         mel_masks = torch.unsqueeze(sequence_mask(mel_lengths, gt_mels.shape[2]), 1).to(
         mel_masks = torch.unsqueeze(sequence_mask(mel_lengths, gt_mels.shape[2]), 1).to(
             gt_mels.dtype
             gt_mels.dtype
         )
         )
 
 
         speaker_features = self.speaker_encoder(gt_mels, mel_masks)
         speaker_features = self.speaker_encoder(gt_mels, mel_masks)
-        vq_features, vq_loss = self.vq_encoder(features, feature_masks)
+        # vq_features, vq_loss = self.vq_encoder(features, feature_masks)
 
 
         # vq_features is 50 hz, need to convert to true mel size
         # vq_features is 50 hz, need to convert to true mel size
-        vq_features = F.interpolate(vq_features, size=gt_mels.shape[2], mode="nearest")
-        text_features = self.text_encoder(vq_features, mel_masks, g=speaker_features)
+        text_features = self.text_encoder(features, feature_masks, g=speaker_features)
+        text_features = F.interpolate(
+            text_features, size=gt_mels.shape[2], mode="nearest"
+        )
 
 
         # Sample noise that we'll add to the images
         # Sample noise that we'll add to the images
         normalized_gt_mels = self.normalize_mels(gt_mels)
         normalized_gt_mels = self.normalize_mels(gt_mels)
@@ -144,41 +147,43 @@ class VQDiffusion(L.LightningModule):
             sync_dist=True,
             sync_dist=True,
         )
         )
 
 
-        self.log(
-            "train/vq_loss",
-            vq_loss,
-            on_step=True,
-            on_epoch=False,
-            prog_bar=True,
-            logger=True,
-            sync_dist=True,
-        )
+        # self.log(
+        #     "train/vq_loss",
+        #     vq_loss,
+        #     on_step=True,
+        #     on_epoch=False,
+        #     prog_bar=True,
+        #     logger=True,
+        #     sync_dist=True,
+        # )
 
 
-        return noise_loss + vq_loss
+        return noise_loss  # + vq_loss
 
 
     def validation_step(self, batch: Any, batch_idx: int):
     def validation_step(self, batch: Any, batch_idx: int):
         audios, audio_lengths = batch["audios"], batch["audio_lengths"]
         audios, audio_lengths = batch["audios"], batch["audio_lengths"]
         features, feature_lengths = batch["features"], batch["feature_lengths"]
         features, feature_lengths = batch["features"], batch["feature_lengths"]
 
 
         audios = audios.float()
         audios = audios.float()
-        features = features.float().mT
+        # features = features.float().mT
         audios = audios[:, None, :]
         audios = audios[:, None, :]
         gt_mels = self.mel_transform(audios)
         gt_mels = self.mel_transform(audios)
         mel_lengths = audio_lengths // self.hop_length
         mel_lengths = audio_lengths // self.hop_length
 
 
         feature_masks = torch.unsqueeze(
         feature_masks = torch.unsqueeze(
-            sequence_mask(feature_lengths, features.shape[2]), 1
+            sequence_mask(feature_lengths, features.shape[1]), 1
         ).to(gt_mels.dtype)
         ).to(gt_mels.dtype)
         mel_masks = torch.unsqueeze(sequence_mask(mel_lengths, gt_mels.shape[2]), 1).to(
         mel_masks = torch.unsqueeze(sequence_mask(mel_lengths, gt_mels.shape[2]), 1).to(
             gt_mels.dtype
             gt_mels.dtype
         )
         )
 
 
         speaker_features = self.speaker_encoder(gt_mels, mel_masks)
         speaker_features = self.speaker_encoder(gt_mels, mel_masks)
-        vq_features, _ = self.vq_encoder(features, feature_masks)
+        # vq_features, vq_loss = self.vq_encoder(features, feature_masks)
 
 
         # vq_features is 50 hz, need to convert to true mel size
         # vq_features is 50 hz, need to convert to true mel size
-        vq_features = F.interpolate(vq_features, size=gt_mels.shape[2], mode="nearest")
-        text_features = self.text_encoder(vq_features, mel_masks, g=speaker_features)
+        text_features = self.text_encoder(features, feature_masks, g=speaker_features)
+        text_features = F.interpolate(
+            text_features, size=gt_mels.shape[2], mode="nearest"
+        )
 
 
         # Begin sampling
         # Begin sampling
         sampled_mels = torch.randn_like(gt_mels)
         sampled_mels = torch.randn_like(gt_mels)

+ 13 - 11
fish_speech/models/vqgan/lit_module.py

@@ -103,7 +103,7 @@ class VQGAN(L.LightningModule):
             (z_q, z_p),
             (z_q, z_p),
             (m_p, logs_p),
             (m_p, logs_p),
             (m_q, logs_q),
             (m_q, logs_q),
-            vq_loss,
+            # vq_loss,
         ) = self.generator(features, feature_lengths, gt_mels)
         ) = self.generator(features, feature_lengths, gt_mels)
 
 
         y_hat_mel = self.mel_transform(y_hat.squeeze(1))
         y_hat_mel = self.mel_transform(y_hat.squeeze(1))
@@ -155,7 +155,9 @@ class VQGAN(L.LightningModule):
             beta = self.global_step % 1000
             beta = self.global_step % 1000
             beta = min(beta, 500) / 500 * 0.1 + 1e-6
             beta = min(beta, 500) / 500 * 0.1 + 1e-6
 
 
-            loss_gen_all = loss_mel * 45 + loss_fm + loss_adv + loss_kl * beta + vq_loss
+            loss_gen_all = (
+                loss_mel * 45 + loss_fm + loss_adv + loss_kl * beta
+            )  # + vq_loss
 
 
         self.log(
         self.log(
             "train/generator/loss",
             "train/generator/loss",
@@ -202,15 +204,15 @@ class VQGAN(L.LightningModule):
             logger=True,
             logger=True,
             sync_dist=True,
             sync_dist=True,
         )
         )
-        self.log(
-            "train/generator/loss_vq",
-            vq_loss,
-            on_step=True,
-            on_epoch=False,
-            prog_bar=False,
-            logger=True,
-            sync_dist=True,
-        )
+        # self.log(
+        #     "train/generator/loss_vq",
+        #     vq_loss,
+        #     on_step=True,
+        #     on_epoch=False,
+        #     prog_bar=False,
+        #     logger=True,
+        #     sync_dist=True,
+        # )
 
 
         optim_g.zero_grad()
         optim_g.zero_grad()
         self.manual_backward(loss_gen_all)
         self.manual_backward(loss_gen_all)

+ 25 - 11
fish_speech/models/vqgan/modules/encoders.py

@@ -6,7 +6,10 @@ import torch.nn.functional as F
 from vector_quantize_pytorch import VectorQuantize
 from vector_quantize_pytorch import VectorQuantize
 
 
 from fish_speech.models.vqgan.modules.modules import WN
 from fish_speech.models.vqgan.modules.modules import WN
-from fish_speech.models.vqgan.modules.transformer import RelativePositionTransformer
+from fish_speech.models.vqgan.modules.transformer import (
+    MultiHeadAttention,
+    RelativePositionTransformer,
+)
 from fish_speech.models.vqgan.utils import sequence_mask
 from fish_speech.models.vqgan.utils import sequence_mask
 
 
 
 
@@ -43,7 +46,8 @@ class TextEncoder(nn.Module):
         self.out_channels = out_channels
         self.out_channels = out_channels
         self.hidden_channels = hidden_channels
         self.hidden_channels = hidden_channels
 
 
-        self.proj_in = nn.Conv1d(in_channels, hidden_channels, 1)
+        # self.proj_in = nn.Conv1d(in_channels, hidden_channels, 1)
+        self.proj_in = nn.Embedding(in_channels, hidden_channels)
 
 
         self.encoder = RelativePositionTransformer(
         self.encoder = RelativePositionTransformer(
             in_channels=hidden_channels,
             in_channels=hidden_channels,
@@ -75,7 +79,7 @@ class TextEncoder(nn.Module):
             - x: :math:`[B, T]`
             - x: :math:`[B, T]`
             - x_length: :math:`[B]`
             - x_length: :math:`[B]`
         """
         """
-        x = self.proj_in(x) * x_mask
+        x = self.proj_in(x).mT * x_mask
         x = self.encoder(x, x_mask, g=g)
         x = self.encoder(x, x_mask, g=g)
         x = self.proj_out(x) * x_mask
         x = self.proj_out(x) * x_mask
 
 
@@ -162,13 +166,18 @@ class SpeakerEncoder(nn.Module):
 
 
         self.in_proj = nn.Sequential(
         self.in_proj = nn.Sequential(
             nn.Conv1d(in_channels, hidden_channels, 1),
             nn.Conv1d(in_channels, hidden_channels, 1),
-            nn.SiLU(),
+            nn.Mish(),
+            nn.Dropout(p_dropout),
             nn.Conv1d(hidden_channels, hidden_channels, 5, padding=2),
             nn.Conv1d(hidden_channels, hidden_channels, 5, padding=2),
-            nn.SiLU(),
+            nn.Mish(),
+            nn.Dropout(p_dropout),
             nn.Conv1d(hidden_channels, hidden_channels, 5, padding=2),
             nn.Conv1d(hidden_channels, hidden_channels, 5, padding=2),
-            nn.SiLU(),
+            nn.Mish(),
             nn.Dropout(p_dropout),
             nn.Dropout(p_dropout),
         )
         )
+        self.out_proj = nn.Conv1d(hidden_channels, out_channels, 1)
+        self.apply(self._init_weights)
+
         self.encoder = RelativePositionTransformer(
         self.encoder = RelativePositionTransformer(
             in_channels=hidden_channels,
             in_channels=hidden_channels,
             out_channels=hidden_channels,
             out_channels=hidden_channels,
@@ -176,11 +185,15 @@ class SpeakerEncoder(nn.Module):
             hidden_channels_ffn=hidden_channels,
             hidden_channels_ffn=hidden_channels,
             n_heads=num_heads,
             n_heads=num_heads,
             n_layers=num_layers,
             n_layers=num_layers,
-            kernel_size=5,
+            kernel_size=1,
             dropout=p_dropout,
             dropout=p_dropout,
-            window_size=4,
+            window_size=None,  # No windowing
         )
         )
-        self.out_proj = nn.Linear(hidden_channels, out_channels)
+
+    def _init_weights(self, m):
+        if isinstance(m, (nn.Conv1d, nn.Linear)):
+            nn.init.normal_(m.weight, mean=0, std=0.02)
+            nn.init.zeros_(m.bias)
 
 
     def forward(self, mels, mel_masks: torch.Tensor):
     def forward(self, mels, mel_masks: torch.Tensor):
         """
         """
@@ -194,8 +207,9 @@ class SpeakerEncoder(nn.Module):
 
 
         # Avg Pooling
         # Avg Pooling
         x = x * mel_masks
         x = x * mel_masks
-        x = torch.sum(x, dim=2) / torch.sum(mel_masks, dim=2)
-        x = self.out_proj(x)[..., None]
+        x = self.out_proj(x)
+        x = torch.sum(x, dim=-1) / torch.sum(mel_masks, dim=-1)
+        x = x[..., None]
 
 
         return x
         return x
 
 

+ 6 - 3
fish_speech/models/vqgan/modules/models.py

@@ -121,7 +121,10 @@ class SynthesizerTrn(nn.Module):
         x_masks = torch.unsqueeze(sequence_mask(x_lengths, x.shape[2]), 1).to(x.dtype)
         x_masks = torch.unsqueeze(sequence_mask(x_lengths, x.shape[2]), 1).to(x.dtype)
 
 
         g = self.enc_spk(specs, spec_masks)
         g = self.enc_spk(specs, spec_masks)
-        x, vq_loss = self.vq(x, x_masks)
+
+        # with torch.no_grad():
+        #     x, _ = self.vq(x, x_masks)
+        #     vq_loss = 0
 
 
         _, m_p, logs_p, _, _ = self.enc_p(x, x_masks, g=g)
         _, m_p, logs_p, _, _ = self.enc_p(x, x_masks, g=g)
         z_q, m_q, logs_q, _ = self.enc_q(specs, spec_masks, g=g)
         z_q, m_q, logs_q, _ = self.enc_q(specs, spec_masks, g=g)
@@ -138,7 +141,7 @@ class SynthesizerTrn(nn.Module):
             (z_q, z_p),
             (z_q, z_p),
             (m_p, logs_p),
             (m_p, logs_p),
             (m_q, logs_q),
             (m_q, logs_q),
-            vq_loss,
+            # vq_loss,
         )
         )
 
 
     def infer(self, x, x_lengths, specs, max_len=None, noise_scale=0.35):
     def infer(self, x, x_lengths, specs, max_len=None, noise_scale=0.35):
@@ -148,7 +151,7 @@ class SynthesizerTrn(nn.Module):
         )
         )
         x_masks = torch.unsqueeze(sequence_mask(x_lengths, x.shape[2]), 1).to(x.dtype)
         x_masks = torch.unsqueeze(sequence_mask(x_lengths, x.shape[2]), 1).to(x.dtype)
         g = self.enc_spk(specs, spec_masks)
         g = self.enc_spk(specs, spec_masks)
-        x, vq_loss = self.vq(x, x_masks)
+        # x, vq_loss = self.vq(x, x_masks)
         z_p, m_p, logs_p, h_text, _ = self.enc_p(
         z_p, m_p, logs_p, h_text, _ = self.enc_p(
             x, x_masks, g=g, noise_scale=noise_scale
             x, x_masks, g=g, noise_scale=noise_scale
         )
         )

+ 22 - 12
tools/vqgan/calculate_hubert_features.py

@@ -44,7 +44,7 @@ def get_hubert_model():
     return model
     return model
 
 
 
 
-def process_batch(files: list[Path]):
+def process_batch(files: list[Path], kmeans_centers: torch.Tensor) -> float:
     model = get_hubert_model()
     model = get_hubert_model()
 
 
     wavs = []
     wavs = []
@@ -81,25 +81,32 @@ def process_batch(files: list[Path]):
 
 
     # Calculate lengths
     # Calculate lengths
     with torch.no_grad():
     with torch.no_grad():
-        outputs = model(wavs, attention_mask=attention_mask)
+        outputs = model(wavs, attention_mask=attention_mask).last_hidden_state
+
+        # Find closest centroids
+        kmeans_centers = kmeans_centers.to(dtype=outputs.dtype, device=outputs.device)
+        distances = torch.cdist(outputs, kmeans_centers)
+        outputs = torch.min(distances, dim=-1)
+        avg_distance = torch.mean(outputs.values)
 
 
     # Save to disk
     # Save to disk
-    outputs = outputs.last_hidden_state.cpu().numpy()
+    outputs = outputs.indices.cpu().numpy()
 
 
     for file, length, feature, wav in zip(files, feature_lengths, outputs, wavs):
     for file, length, feature, wav in zip(files, feature_lengths, outputs, wavs):
         feature = feature[:length]
         feature = feature[:length]
 
 
-        # (T, 1024)
+        # (T,)
         with open(file.with_suffix(".npy"), "wb") as f:
         with open(file.with_suffix(".npy"), "wb") as f:
             np.save(f, feature)
             np.save(f, feature)
 
 
-    return total_time
+    return total_time, avg_distance
 
 
 
 
 @click.command()
 @click.command()
 @click.argument("folder")
 @click.argument("folder")
 @click.option("--num-workers", default=1)
 @click.option("--num-workers", default=1)
-def main(folder: str, num_workers: int):
+@click.option("--kmeans", default="results/hubert-vq-pretrain/kmeans.pt")
+def main(folder: str, num_workers: int, kmeans: str):
     if num_workers > 1 and WORLD_SIZE != num_workers:
     if num_workers > 1 and WORLD_SIZE != num_workers:
         assert WORLD_SIZE == 1, "You should either use SLURM or this launcher, not both"
         assert WORLD_SIZE == 1, "You should either use SLURM or this launcher, not both"
 
 
@@ -140,16 +147,22 @@ def main(folder: str, num_workers: int):
     files = files[RANK::WORLD_SIZE]
     files = files[RANK::WORLD_SIZE]
     logger.info(f"Processing {len(files)}/{total_files} files")
     logger.info(f"Processing {len(files)}/{total_files} files")
 
 
+    # Load kmeans
+    kmeans_centers = torch.load(kmeans)["centroids"]
+
     # Batch size 64
     # Batch size 64
     total_time = 0
     total_time = 0
     begin_time = time.time()
     begin_time = time.time()
     processed_files = 0
     processed_files = 0
+    total_distance = 0
 
 
     for n_batch, idx in enumerate(range(0, len(files), 32)):
     for n_batch, idx in enumerate(range(0, len(files), 32)):
         batch = files[idx : idx + 32]
         batch = files[idx : idx + 32]
-        batch_time = process_batch(batch)
+        batch_time, avg_distance = process_batch(batch, kmeans_centers)
+
         total_time += batch_time
         total_time += batch_time
         processed_files += len(batch)
         processed_files += len(batch)
+        total_distance += avg_distance
 
 
         if (n_batch + 1) % 10 == 0:
         if (n_batch + 1) % 10 == 0:
             eta = (
             eta = (
@@ -158,13 +171,10 @@ def main(folder: str, num_workers: int):
                 * (len(files) - processed_files)
                 * (len(files) - processed_files)
             )
             )
             logger.info(
             logger.info(
-                f"Processed {processed_files} files, {total_time / 3600:.2f} hours of audio, ETA: {timedelta(seconds=round(eta))}s"
+                f"Processed {processed_files} files, {total_time / 3600:.2f} hours of audio, "
+                + f"err {total_distance/(n_batch+1):.2f}, ETA: {timedelta(seconds=round(eta))}s"
             )
             )
 
 
-        # Stop after 1000 hours
-        if total_time * WORLD_SIZE > 3600 * 1000:
-            break
-
     logger.info(
     logger.info(
         f"Finished processing {len(files)} files, {total_time / 3600:.2f} hours of audio"
         f"Finished processing {len(files)} files, {total_time / 3600:.2f} hours of audio"
     )
     )