
Fix bug & finish implementing the VQGAN data module

Lengyue 2 years ago
parent
commit
72a5fbbfa2

+ 14 - 18
fish_speech/configs/hubert_vq.yaml

@@ -12,17 +12,13 @@ trainer:
   precision: 32
   max_steps: 1_000_000
 
-# Dataset Configuration
-tokenizer:
-  _target_: transformers.AutoTokenizer.from_pretrained
-  pretrained_model_name_or_path: fishaudio/speech-lm-300m
-  revision: text-pretrain-10k
+sample_rate: 32000
 
 # Dataset Configuration
 train_dataset:
-  - _target_: fish_speech.datasets.text.TextDataset
-    repo: fishaudio/cn-hubert-25hz-vq
-    prefix: 'data/train'
+  _target_: fish_speech.datasets.vqgan.VQGANDataset
+  filelist: data/test.filelist
+  sample_rate: ${sample_rate}
 
 val_dataset:
   _target_: fish_speech.datasets.text.TextDataset
@@ -42,9 +38,10 @@ model:
   _target_: fish_speech.models.vqgan.VQGAN
 
   encoder:
-    _target_: fish_speech.models.modules.VQEncoder
+    _target_: fish_speech.models.vqgan.modules.VQEncoder
     in_channels: 1024
     channels: 192
+    num_mels: 128
     num_heads: 2
     num_feature_layers: 2
     num_speaker_layers: 4
@@ -54,7 +51,7 @@ model:
     freeze_vq: false
 
   generator:
-    _target_: fish_speech.models.modules.Generator
+    _target_: fish_speech.models.vqgan.modules.Generator
     initial_channel: 192
     resblock: "1"
     resblock_kernel_sizes: [3, 7, 11]
@@ -67,11 +64,11 @@ model:
     upsample_kernel_sizes: [16, 16, 8, 2, 2]
 
   discriminator:
-    _target_: fish_speech.models.modules.EnsembleDiscriminator
+    _target_: fish_speech.models.vqgan.modules.EnsembleDiscriminator
 
   mel_transform:
-    _target_: fish_speech.models.spectrogram.LogMelSpectrogram
-    sample_rate: 32000
+    _target_: fish_speech.models.vqgan.spectrogram.LogMelSpectrogram
+    sample_rate: ${sample_rate}
     n_fft: 2048
     hop_length: 640
     win_length: 2048
@@ -93,8 +90,7 @@ model:
       num_warmup_steps: 2000
       num_training_steps: ${trainer.max_steps}
       final_lr_ratio: 0.05
-  
-  # Restore from old checkpoint
-  generator_ckpt: results/hubert-vq-pretrain/rcell/G_23000.pth
-  discriminator_ckpt: results/hubert-vq-pretrain/rcell/D_23000.pth
-  kmeans_ckpt: results/hubert-vq-pretrain/rcell/kmeans_23000.pth
+
+# Resume from rcell's checkpoint
+ckpt_path: results/hubert-vq-pretrain/rcell/ckpt_23000_pl.pth
+resume_weights_only: true
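
Note: this config relies on Hydra-style `_target_` instantiation and OmegaConf interpolation. The new top-level `sample_rate: 32000` is referenced as `${sample_rate}` by both the train dataset and the mel transform, so the rate is defined once. A minimal sketch of how these keys resolve, assuming the project loads this file with Hydra/OmegaConf (the training entrypoint is not part of this diff):

```python
from hydra.utils import instantiate
from omegaconf import OmegaConf

cfg = OmegaConf.load("fish_speech/configs/hubert_vq.yaml")

# ${sample_rate} interpolations resolve against the top-level key
assert cfg.train_dataset.sample_rate == 32000
assert cfg.model.mel_transform.sample_rate == 32000

# Recursively builds VQGANDataset(filelist="data/test.filelist", sample_rate=32000)
train_dataset = instantiate(cfg.train_dataset)
```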

+ 74 - 107
fish_speech/datasets/vqgan.py

@@ -2,9 +2,10 @@ from dataclasses import dataclass
 from pathlib import Path
 
 import librosa
+import numpy as np
 import torch
+from lightning import LightningDataModule
-from torch.utils.data import Dataset
+from torch.utils.data import DataLoader, Dataset
-from transformers import WhisperProcessor
 
 
 class VQGANDataset(Dataset):
@@ -19,6 +20,7 @@ class VQGANDataset(Dataset):
         root = filelist.parent
 
         self.files = [root / line.strip() for line in filelist.read_text().splitlines()]
+        self.sample_rate = sample_rate
 
     def __len__(self):
         return len(self.files)
@@ -26,129 +28,94 @@ class VQGANDataset(Dataset):
     def __getitem__(self, idx):
         file = self.files[idx]
 
+        audio, _ = librosa.load(file, sr=self.sample_rate, mono=True)
+        features = np.load(file.with_suffix(".npy"))  # (T, 1024)
+
+        return {
+            "audio": torch.from_numpy(audio),
+            "features": torch.from_numpy(features),
+        }
+
 
 @dataclass
-class WhisperVQCollator:
-    processor = WhisperProcessor.from_pretrained("openai/whisper-medium")
+class VQGANCollator:
+    hop_length: int = 640
 
     def __call__(self, batch):
-        # -> {"input_values": ..., "input_features": ..., "input_ids": ..., "decoder_attention_mask": ...}
-        max_values_length = max([x["input_values"].shape[-1] for x in batch])
-        max_ids_length = max([x["input_ids"].shape[-1] for x in batch])
-
-        input_values = []
-        decoder_attention_mask = []
-        decoder_input_ids = []
-        input_features = torch.stack([x["input_features"] for x in batch])
-        encoder_attention_mask = torch.stack([x["mel_mask"] for x in batch])
-
-        for data in batch:
-            values_length = data["input_values"].shape[-1]
-            x = torch.nn.functional.pad(
-                data["input_values"], (0, max_values_length - values_length)
-            )
-            input_values.append(x)
+        audio_lengths = torch.tensor([len(x["audio"]) for x in batch])
+        feature_lengths = torch.tensor([len(x["features"]) for x in batch])
 
-            ids_length = data["input_ids"].shape[-1]
-            ids = torch.nn.functional.pad(
-                data["input_ids"],
-                (0, max_ids_length - ids_length),
-                value=self.processor.tokenizer.pad_token_id,
-            )
-            decoder_input_ids.append(ids)
+        audio_maxlen = audio_lengths.max()
+        feature_maxlen = feature_lengths.max()
 
-            x = torch.zeros(max_ids_length, dtype=torch.float)
-            x[:ids_length] = 1
-            decoder_attention_mask.append(x)
+        if audio_maxlen % self.hop_length != 0:
+            audio_maxlen += self.hop_length - (audio_maxlen % self.hop_length)
 
-        decoder_input_ids = torch.stack(decoder_input_ids)
-        decoder_attention_mask = torch.stack(decoder_attention_mask)
-        labels = decoder_input_ids.clone()
-        labels[decoder_attention_mask == 0] = -100
-        labels[:, :4] = -100  # BOS, LANG, TRANSCRIBE, NOTIMESTAMPS
+        audios, features = [], []
+        for x in batch:
+            audios.append(
+                torch.nn.functional.pad(x["audio"], (0, audio_maxlen - len(x["audio"])))
+            )
+            features.append(
+                torch.nn.functional.pad(
+                    x["features"], (0, 0, 0, feature_maxlen - len(x["features"]))
+                )
+            )
 
         return {
-            "input_values": torch.stack(input_values),
-            "input_features": input_features,
-            "encoder_attention_mask": encoder_attention_mask,
-            "decoder_input_ids": decoder_input_ids[:, :-1],
-            "decoder_attention_mask": decoder_attention_mask[:, :-1],
-            "labels": labels[:, 1:],
+            "audios": torch.stack(audios),
+            "features": torch.stack(features),
+            "audio_lengths": audio_lengths,
+            "feature_lengths": feature_lengths,
         }
 
 
-if __name__ == "__main__":
-    import soundfile as sf
-    from torch.utils.data import DataLoader
-    from transformers import GenerationConfig
-
-    from fish_speech.models.whisper_vq import WhisperVQ
-    from fish_speech.modules.flash_whisper import FlashWhisperForConditionalGeneration
-
-    dataset = WhisperVQDataset("filelists/whisper-vq.test.filelist")
-    dataloader = DataLoader(
-        dataset, batch_size=4, shuffle=True, collate_fn=WhisperVQCollator()
-    )
-    # whisper = FlashWhisperForConditionalGeneration.from_pretrained(
-    #     "openai/whisper-medium"
-    # )
-    # whisper.eval()
-    our_whisper = WhisperVQ()
-    whisper = our_whisper.whisper
-    our_whisper.eval()
-
-    state_dict = torch.load(
-        "results/whisper-vq/checkpoints/step_10000.ckpt", map_location="cpu"
-    )["model"]
-    our_whisper.load_state_dict(state_dict, strict=True)
-    # whisper.cuda()
-
-    for batch in dataloader:
-        # batch = {k: v.cuda() for k, v in batch.items()}
-        print({k: v.shape for k, v in batch.items()})
+class VQGANDataModule(LightningDataModule):
+    def __init__(
+        self,
+        train_dataset: VQGANDataset,
+        val_dataset: VQGANDataset,
+        batch_size: int = 32,
+        hop_length: int = 640,
+        num_workers: int = 4,
+    ):
+        super().__init__()
 
-        outputs = whisper.generate(
-            inputs=batch["input_features"],
-            max_length=448,
-            do_sample=False,
+        self.train_dataset = train_dataset
+        self.val_dataset = val_dataset
+        self.batch_size = batch_size
+        self.hop_length = hop_length
+        self.num_workers = num_workers
+
+    def train_dataloader(self):
+        return DataLoader(
+            self.train_dataset,
+            batch_size=self.batch_size,
+            collate_fn=VQGANCollator(self.hop_length),
+            num_workers=self.num_workers,
+            shuffle=True,
         )
 
-        print(outputs, batch["decoder_input_ids"])
-        transcriptions = dataset.processor.batch_decode(
-            outputs, skip_special_tokens=True
+    def val_dataloader(self):
+        return DataLoader(
+            self.val_dataset,
+            batch_size=self.batch_size,
+            collate_fn=VQGANCollator(self.hop_length),
+            num_workers=self.num_workers,
         )
 
-        print(
-            transcriptions,
-            dataset.processor.batch_decode(batch["labels"], skip_special_tokens=True),
-        )
-        sf.write("test.wav", batch["input_values"][0].cpu().numpy(), 16000)
-
-        # Calculate loss
-        # encoder_outputs = whisper.model.encoder(
-        #     batch["input_features"],
-        # )
-        encoder_outputs = our_whisper.decode(
-            our_whisper.encode(
-                batch["input_features"],
-            )[0]
-        )
 
-        decoder_outputs = whisper.generate(
-            # decoder_input_ids=batch["decoder_input_ids"],
-            # decoder_attention_mask=batch["decoder_attention_mask"],
-            # labels=batch["labels"],
-            # generation_config=GenerationConfig(
-            #     encoder_outputs=(encoder_outputs,)
-            # ),
-            encoder_outputs,
-            max_length=448,
-            do_sample=False,
-            # forced_decoder_ids=batch["decoder_input_ids"][:, :4]
-            forced_decoder_ids=dataset.processor.get_decoder_prompt_ids(
-                language="english", task="transcribe"
-            ),
-        )
+if __name__ == "__main__":
+    from torch.utils.data import DataLoader
 
-        print("Our transcript:", dataset.processor.batch_decode(decoder_outputs))
+    dataset = VQGANDataset("data/LibriTTS_R/vq_train_filelist.txt")
+    dataloader = DataLoader(
+        dataset, batch_size=4, shuffle=False, collate_fn=VQGANCollator()
+    )
+
+    for batch in dataloader:
+        print(batch["audios"].shape)
+        print(batch["features"].shape)
+        print(batch["audio_lengths"])
+        print(batch["feature_lengths"])
         break
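
Note on the collator's rounding: at 32 kHz with `hop_length=640`, the mel spectrogram runs at 32000 / 640 = 50 frames/s, the same rate as HuBERT features extracted at 16 kHz with a 320-sample stride, so padding the audio up to a hop multiple keeps one feature frame per mel frame. A quick sanity check (a sketch with illustrative shapes, not real data):

```python
import torch
from fish_speech.datasets.vqgan import VQGANCollator

# Illustrative batch: 1.00 s and 0.77 s of 32 kHz audio with 50 Hz features
batch = [
    {"audio": torch.zeros(32000), "features": torch.zeros(50, 1024)},
    {"audio": torch.zeros(24640), "features": torch.zeros(39, 1024)},
]

out = VQGANCollator(hop_length=640)(batch)
assert out["audios"].shape[1] % 640 == 0                          # padded to a hop multiple
assert out["audios"].shape[1] // 640 == out["features"].shape[1]  # 50 mel frames <-> 50 features
```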

+ 1 - 0
fish_speech/models/vqgan/lit_module.py

@@ -26,6 +26,7 @@ class VQGAN(L.LightningModule):
 
         # Generator and discriminators
         # Compile generator so that snake can save memory
+        self.encoder = encoder
         self.generator = generator
         self.discriminator = discriminator
         self.mel_transform = mel_transform
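
The one-line fix above matters because `nn.Module.__setattr__` is what registers a submodule: without `self.encoder = encoder`, the encoder's weights are invisible to `self.parameters()`, and therefore to the optimizer and the Lightning checkpoint. A self-contained demonstration:

```python
import torch
from torch import nn

class Registered(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(4, 4)  # attribute assignment registers the submodule

class Unregistered(nn.Module):
    def __init__(self):
        super().__init__()
        self._stash = {"encoder": nn.Linear(4, 4)}  # a plain dict bypasses registration

print(sum(p.numel() for p in Registered().parameters()))    # 20 (4*4 weight + 4 bias)
print(sum(p.numel() for p in Unregistered().parameters()))  # 0 - never trained or saved
```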

+ 15 - 16
fish_speech/models/vqgan/modules.py

@@ -2,11 +2,11 @@ import math
 from dataclasses import dataclass
 
 import torch
+from encodec.quantization.core_vq import VectorQuantization
 from torch import nn
 from torch.nn import Conv1d, Conv2d, ConvTranspose1d
 from torch.nn import functional as F
 from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm
-from vector_quantize_pytorch import VectorQuantize
 
 from fish_speech.models.vqgan.utils import convert_pad_shape, get_padding, init_weights
 
@@ -24,6 +24,7 @@ class VQEncoder(nn.Module):
         self,
         in_channels: int = 1024,
         channels: int = 192,
+        num_mels: int = 128,
         num_heads: int = 2,
         num_feature_layers: int = 2,
         num_speaker_layers: int = 4,
@@ -38,10 +39,12 @@ class VQEncoder(nn.Module):
         down_sample = 2 if input_downsample else 1
 
         self.vq_in = nn.Linear(in_channels * down_sample, in_channels)
-        self.vq = VectorQuantize(
+        self.vq = VectorQuantization(
             dim=in_channels,
             codebook_size=code_book_size,
             threshold_ema_dead_code=2,
+            kmeans_init=True,
+            kmeans_iters=50,
         )
 
         self.feature_in = nn.Linear(in_channels, channels)
@@ -62,7 +65,7 @@ class VQEncoder(nn.Module):
 
         # Speaker Encoder
         self.speaker_query = nn.Parameter(torch.randn(1, 1, channels))
-        self.speaker_in = nn.Linear(in_channels * down_sample, channels)
+        self.speaker_in = nn.Linear(num_mels, channels)
         self.speaker_blocks = nn.ModuleList(
             [
                 TransformerBlock(
@@ -99,8 +102,9 @@ class VQEncoder(nn.Module):
             for p in self.vq_in.parameters():
                 p.requires_grad = False
 
-    def forward(self, x, key_padding_mask=None):
-        # (batch, seq_len, channels)
+    def forward(self, x, mels, key_padding_mask=None):
+        # x: (batch, seq_len, channels)
+        # mels: (batch, seq_len, 128)
 
         if self.input_downsample and key_padding_mask is not None:
             key_padding_mask = key_padding_mask[:, ::2]
@@ -119,6 +123,11 @@ class VQEncoder(nn.Module):
 
         features, _, loss = self.vq(features, mask=~key_padding_mask)
 
+        if self.input_downsample:
+            features = F.interpolate(
+                features.transpose(1, 2), scale_factor=2
+            ).transpose(1, 2)
+
         features = self.feature_in(features)
         for block in self.feature_blocks:
             features = block(features, key_padding_mask=key_padding_mask)
@@ -129,7 +138,7 @@ class VQEncoder(nn.Module):
             [self.speaker_query.expand(speaker.shape[0], -1, -1), speaker], dim=1
         )
         for block in self.speaker_blocks:
-            speaker = block(speaker)
+            speaker = block(speaker, key_padding_mask=key_padding_mask)
 
         # Mix
         x = features + speaker[:, :1]
@@ -794,13 +803,3 @@ class EnsembleDiscriminator(nn.Module):
             fmap_gs.append(fmap_g)
 
         return y_d_rs, y_d_gs, fmap_rs, fmap_gs
-
-
-if __name__ == "__main__":
-    vq = VQEncoder()
-    x = torch.randn(1, 90, 1024)
-    key_padding_mask = torch.zeros(1, 90).bool()
-    key_padding_mask[:, 67:] = True
-
-    output = vq(x, key_padding_mask=key_padding_mask)
-    print(output)
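
The quantizer swap replaces `vector_quantize_pytorch.VectorQuantize` with encodec's `VectorQuantization`, mainly to gain k-means codebook initialization (`kmeans_init=True, kmeans_iters=50`). A sketch of the swapped-in API, assuming the upstream `encodec` package: its `forward` takes `(batch, dim, time)` and accepts no `mask` argument, so the `self.vq(features, mask=~key_padding_mask)` call site above presumably relies on a patched copy.

```python
import torch
from encodec.quantization.core_vq import VectorQuantization

vq = VectorQuantization(
    dim=1024,                   # matches VQEncoder's in_channels
    codebook_size=1024,         # illustrative; the real code_book_size comes from config
    kmeans_init=True,           # init codebook via k-means on the first batch
    kmeans_iters=50,
    threshold_ema_dead_code=2,  # revive codes whose EMA usage falls below 2
)

x = torch.randn(2, 1024, 90)             # upstream expects (batch, dim, time)
quantized, indices, commit_loss = vq(x)
print(quantized.shape, indices.shape)    # (2, 1024, 90), (2, 90)
```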

+ 4 - 4
tools/vqgan/calculate_hubert_features.py

@@ -35,7 +35,7 @@ logger.add(sys.stderr, format=logger_format)
 
 @lru_cache(maxsize=1)
 def get_hubert_model():
-    model = HubertModel.from_pretrained("TencentGameMate/chinese-hubert-large")
+    model = HubertModel.from_pretrained("TencentGameMate/chinese-hubert-base")
     model = model.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
     model = model.half()
     model.eval()
@@ -58,7 +58,7 @@ def process_batch(files: list[Path]):
         wav = torchaudio.functional.resample(wav.cuda(), sr, 16000)[0]
 
         if len(wav) > sr * 60:
-            continue
+            wav = wav[: sr * 60]
 
         wavs.append(wav)
         total_time += len(wav) / sr
@@ -73,8 +73,8 @@ def process_batch(files: list[Path]):
 
     for i, wav in enumerate(wavs):
         attention_mask[i, len(wav) :] = 0
+        feature_lengths.append(int(len(wav) / 320))
         wavs[i] = torch.nn.functional.pad(wav, (0, max_length - len(wav)), "constant")
-        feature_lengths.append(int(len(wav) / sr * 50))
 
     wavs = torch.stack(wavs, dim=0).half()
     attention_mask = attention_mask.cuda()
@@ -86,7 +86,7 @@ def process_batch(files: list[Path]):
     # Save to disk
     outputs = outputs.last_hidden_state.cpu().numpy()
 
-    for file, length, feature in zip(files, feature_lengths, outputs):
+    for file, length, feature, wav in zip(files, feature_lengths, outputs, wavs):
         feature = feature[:length]
 
         # (T, 1024)
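
The `feature_lengths` fix: by this point `wav` has already been resampled to 16 kHz, and HuBERT's convolutional front end strides 320 samples (16000 / 320 = 50 frames/s), so the frame count is approximately `len(wav) / 320`. The old expression divided the 16 kHz length by the original `sr`, undercounting frames for anything not recorded at 16 kHz. Worked example:

```python
# 3 s clip originally at 44.1 kHz, already resampled to 16 kHz at this point
wav_len_16k = 3 * 16000            # 48000 samples

int(wav_len_16k / 44100 * 50)      # old: 54 frames  (divides by the wrong rate)
int(wav_len_16k / 320)             # new: 150 frames (320-sample stride at 16 kHz)
```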

+ 6 - 2
tools/vqgan/create_train_split.py

@@ -10,10 +10,14 @@ from fish_speech.utils.file import AUDIO_EXTENSIONS, list_files
 @click.command()
 @click.argument("root", type=click.Path(exists=True, path_type=Path))
 def main(root):
-    files = list_files(root, AUDIO_EXTENSIONS, recursive=True, show_progress=True)
+    files = list_files(root, AUDIO_EXTENSIONS, recursive=True)
     print(f"Found {len(files)} files")
 
-    files = [str(file) for file in tqdm(files) if file.with_suffix(".npy").exists()]
+    files = [
+        str(file.relative_to(root))
+        for file in tqdm(files)
+        if file.with_suffix(".npy").exists()
+    ]
     print(f"Found {len(files)} files with features")
 
     Random(42).shuffle(files)
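
The switch to `file.relative_to(root)` matches how `VQGANDataset` consumes the filelist: it resolves each line against the filelist's parent directory, so absolute paths would break as soon as the data directory moves. A sketch of the expected layout (file names are hypothetical):

```python
from pathlib import Path

# Hypothetical layout after feature extraction:
#   data/LibriTTS_R/vq_train_filelist.txt   <- written by this script
#   data/LibriTTS_R/speaker1/utt1.wav
#   data/LibriTTS_R/speaker1/utt1.npy       <- HuBERT features for utt1.wav
# The filelist holds relative paths, one per line, e.g. "speaker1/utt1.wav".

filelist = Path("data/LibriTTS_R/vq_train_filelist.txt")
root = filelist.parent  # same resolution rule as VQGANDataset.__init__
files = [root / line.strip() for line in filelist.read_text().splitlines()]
# -> data/LibriTTS_R/speaker1/utt1.wav, with utt1.npy alongside it
```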