
new vqgan config

Lengyue, 2 years ago
parent
commit
eff5909b00

+ 134 - 0
fish_speech/configs/vqgan_single.yaml

@@ -0,0 +1,134 @@
+defaults:
+  - base
+  - _self_
+
+project: vqgan_single
+
+# Lightning Trainer
+trainer:
+  accelerator: gpu
+  devices: 4
+  strategy: ddp_find_unused_parameters_true
+  precision: 32
+  max_steps: 1_000_000
+  val_check_interval: 5000
+
+sample_rate: 22050
+hop_length: 256
+num_mels: 80
+n_fft: 1024
+win_length: 1024
+segment_size: 256
+
+# Dataset Configuration
+train_dataset:
+  _target_: fish_speech.datasets.vqgan.VQGANDataset
+  filelist: data/filelist.split.train
+  sample_rate: ${sample_rate}
+  hop_length: ${hop_length}
+  slice_frames: ${segment_size}
+
+val_dataset:
+  _target_: fish_speech.datasets.vqgan.VQGANDataset
+  filelist: data/filelist.split.valid
+  sample_rate: ${sample_rate}
+  hop_length: ${hop_length}
+
+data:
+  _target_: fish_speech.datasets.vqgan.VQGANDataModule
+  train_dataset: ${train_dataset}
+  val_dataset: ${val_dataset}
+  num_workers: 4
+  batch_size: 32
+  val_batch_size: 4
+
+# Model Configuration
+model:
+  _target_: fish_speech.models.vqgan.VQGAN
+  sample_rate: ${sample_rate}
+  hop_length: ${hop_length}
+  segment_size: 8192
+  freeze_hifigan: false
+
+  downsample:
+    _target_: fish_speech.models.vqgan.modules.encoders.ConvDownSampler
+    dims: ["${num_mels}", 512, 256]
+    kernel_sizes: [3, 3]
+    strides: [2, 2]
+
+  mel_encoder:
+    _target_: fish_speech.models.vqgan.modules.modules.WN
+    hidden_channels: 256
+    kernel_size: 3
+    dilation_rate: 2
+    n_layers: 6
+
+  vq_encoder:
+    _target_: fish_speech.models.vqgan.modules.encoders.VQEncoder
+    in_channels: 256
+    vq_channels: 256
+    codebook_size: 4096
+    codebook_groups: 1
+    downsample: 1
+
+  speaker_encoder:
+    _target_: fish_speech.models.vqgan.modules.encoders.SpeakerEncoder
+    in_channels: ${num_mels}
+    hidden_channels: 256
+    out_channels: 512
+    num_layers: 6
+
+  decoder:
+    _target_: fish_speech.models.vqgan.modules.modules.WN
+    hidden_channels: 256
+    out_channels: ${num_mels}
+    kernel_size: 3
+    dilation_rate: 2
+    n_layers: 6
+    gin_channels: 512
+
+  generator:
+    _target_: fish_speech.models.vqgan.modules.decoder.Generator
+    initial_channel: ${num_mels}
+    resblock: "1"
+    resblock_kernel_sizes: [3, 7, 11]
+    resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
+    upsample_rates: [8, 8, 2, 2]
+    upsample_initial_channel: 512
+    upsample_kernel_sizes: [16, 16, 4, 4]
+    ckpt_path: "checkpoints/hifigan-v1-universal-22050/g_02500000"
+
+  discriminator:
+    _target_: fish_speech.models.vqgan.modules.discriminator.EnsembleDiscriminator
+    ckpt_path: checkpoints/hifigan-v1-universal-22050/do_02500000
+
+  mel_transform:
+    _target_: fish_speech.models.vqgan.spectrogram.LogMelSpectrogram
+    sample_rate: ${sample_rate}
+    n_fft: ${n_fft}
+    hop_length: ${hop_length}
+    win_length: ${win_length}
+    n_mels: ${num_mels}
+    f_min: 0
+    f_max: 8000
+
+  optimizer:
+    _target_: torch.optim.AdamW
+    _partial_: true
+    lr: 2e-4
+    betas: [0.8, 0.99]
+    eps: 1e-5
+
+  lr_scheduler:
+    _target_: torch.optim.lr_scheduler.ExponentialLR
+    _partial_: true
+    gamma: 0.999999  # Estimated based on the LibriTTS dataset
+
+callbacks:
+  grad_norm_monitor:
+    sub_module: 
+      - generator
+      - discriminator
+      - mel_encoder
+      - vq_encoder
+      - decoder
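
Taken together, the audio front-end values above pin down the model's temporal budget: mels are extracted every hop_length samples, and the ConvDownSampler halves the time axis twice before quantization. A back-of-the-envelope check of the resulting frame and code rates (plain Python, not part of the commit):

    sample_rate = 22050
    hop_length = 256
    downsample_strides = [2, 2]          # ConvDownSampler strides above

    mel_rate = sample_rate / hop_length  # ~86.1 mel frames per second
    total_stride = 1
    for s in downsample_strides:
        total_stride *= s                # 4x temporal reduction before the VQ layer
    code_rate = mel_rate / total_stride  # ~21.5 VQ codes per second

    print(f"{mel_rate:.1f} mel frames/s -> {code_rate:.1f} codes/s")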

+ 134 - 0
fish_speech/configs/vqgan_single_2x.yaml

@@ -0,0 +1,134 @@
+defaults:
+  - base
+  - _self_
+
+project: vqgan_single_2x
+
+# Lightning Trainer
+trainer:
+  accelerator: gpu
+  devices: 4
+  strategy: ddp_find_unused_parameters_true
+  precision: 32
+  max_steps: 1_000_000
+  val_check_interval: 5000
+
+sample_rate: 22050
+hop_length: 256
+num_mels: 80
+n_fft: 1024
+win_length: 1024
+segment_size: 256
+
+# Dataset Configuration
+train_dataset:
+  _target_: fish_speech.datasets.vqgan.VQGANDataset
+  filelist: data/filelist.split.train
+  sample_rate: ${sample_rate}
+  hop_length: ${hop_length}
+  slice_frames: ${segment_size}
+
+val_dataset:
+  _target_: fish_speech.datasets.vqgan.VQGANDataset
+  filelist: data/filelist.split.valid
+  sample_rate: ${sample_rate}
+  hop_length: ${hop_length}
+
+data:
+  _target_: fish_speech.datasets.vqgan.VQGANDataModule
+  train_dataset: ${train_dataset}
+  val_dataset: ${val_dataset}
+  num_workers: 4
+  batch_size: 32
+  val_batch_size: 4
+
+# Model Configuration
+model:
+  _target_: fish_speech.models.vqgan.VQGAN
+  sample_rate: ${sample_rate}
+  hop_length: ${hop_length}
+  segment_size: 8192
+  freeze_hifigan: false
+
+  downsample:
+    _target_: fish_speech.models.vqgan.modules.encoders.ConvDownSampler
+    dims: ["${num_mels}", 512, 384]
+    kernel_sizes: [3, 3]
+    strides: [2, 2]
+
+  mel_encoder:
+    _target_: fish_speech.models.vqgan.modules.modules.WN
+    hidden_channels: 384
+    kernel_size: 3
+    dilation_rate: 2
+    n_layers: 12
+
+  vq_encoder:
+    _target_: fish_speech.models.vqgan.modules.encoders.VQEncoder
+    in_channels: 384
+    vq_channels: 384
+    codebook_size: 4096
+    codebook_groups: 1
+    downsample: 1
+
+  speaker_encoder:
+    _target_: fish_speech.models.vqgan.modules.encoders.SpeakerEncoder
+    in_channels: ${num_mels}
+    hidden_channels: 384
+    out_channels: 512
+    num_layers: 12
+
+  decoder:
+    _target_: fish_speech.models.vqgan.modules.modules.WN
+    hidden_channels: 384
+    out_channels: ${num_mels}
+    kernel_size: 3
+    dilation_rate: 2
+    n_layers: 12
+    gin_channels: 512
+
+  generator:
+    _target_: fish_speech.models.vqgan.modules.decoder.Generator
+    initial_channel: ${num_mels}
+    resblock: "1"
+    resblock_kernel_sizes: [3, 7, 11]
+    resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
+    upsample_rates: [8, 8, 2, 2]
+    upsample_initial_channel: 512
+    upsample_kernel_sizes: [16, 16, 4, 4]
+    ckpt_path: "checkpoints/hifigan-v1-universal-22050/g_02500000"
+
+  discriminator:
+    _target_: fish_speech.models.vqgan.modules.discriminator.EnsembleDiscriminator
+    ckpt_path: checkpoints/hifigan-v1-universal-22050/do_02500000
+
+  mel_transform:
+    _target_: fish_speech.models.vqgan.spectrogram.LogMelSpectrogram
+    sample_rate: ${sample_rate}
+    n_fft: ${n_fft}
+    hop_length: ${hop_length}
+    win_length: ${win_length}
+    n_mels: ${num_mels}
+    f_min: 0
+    f_max: 8000
+
+  optimizer:
+    _target_: torch.optim.AdamW
+    _partial_: true
+    lr: 2e-4
+    betas: [0.8, 0.99]
+    eps: 1e-5
+
+  lr_scheduler:
+    _target_: torch.optim.lr_scheduler.ExponentialLR
+    _partial_: true
+    gamma: 0.999999  # Estimated based on the LibriTTS dataset
+
+callbacks:
+  grad_norm_monitor:
+    sub_module: 
+      - generator
+      - discriminator
+      - mel_encoder
+      - vq_encoder
+      - decoder
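
Relative to vqgan_single.yaml, this variant only widens and deepens the mel/VQ path (384 channels and 12 WN layers instead of 256 and 6); the front-end, generator, and optimizer settings are unchanged. Either file can be loaded the way the tools in this commit do, via Hydra compose plus instantiate. A minimal sketch, assuming it is run from the repository root and that a trained checkpoint exists at the placeholder path:

    import torch
    from hydra import compose, initialize
    from hydra.utils import instantiate

    with initialize(version_base="1.3", config_path="fish_speech/configs"):
        cfg = compose(config_name="vqgan_single_2x")  # or "vqgan_single"

    model = instantiate(cfg.model)
    # Placeholder checkpoint path; substitute whatever checkpoint was actually trained.
    ckpt = torch.load("results/vqgan_single_2x/checkpoints/step_XXXXXX.ckpt",
                      map_location="cpu")
    model.load_state_dict(ckpt["state_dict"], strict=True)
    model.eval()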

+ 14 - 2
fish_speech/models/vqgan/lit_module.py

@@ -49,6 +49,7 @@ class VQGAN(L.LightningModule):
         sample_rate: int = 32000,
         freeze_hifigan: bool = False,
         freeze_vq: bool = False,
+        speaker_encoder: SpeakerEncoder = None,
     ):
         super().__init__()
 
@@ -60,6 +61,7 @@ class VQGAN(L.LightningModule):
         self.downsample = downsample
         self.vq_encoder = vq_encoder
         self.mel_encoder = mel_encoder
+        self.speaker_encoder = speaker_encoder
         self.decoder = decoder
         self.generator = generator
         self.discriminator = discriminator
@@ -168,7 +170,12 @@ class VQGAN(L.LightningModule):
         )
 
         # Sample mels
-        decoded_mels = self.decoder(text_features, mel_masks)
+        speaker_features = (
+            self.speaker_encoder(gt_mels, mel_masks)
+            if self.speaker_encoder is not None
+            else None
+        )
+        decoded_mels = self.decoder(text_features, mel_masks, g=speaker_features)
         fake_audios = self.generator(decoded_mels)
 
         y_hat_mels = self.mel_transform(fake_audios.squeeze(1))
@@ -316,7 +323,12 @@ class VQGAN(L.LightningModule):
         )
 
         # Sample mels
-        decoded_mels = self.decoder(text_features, mel_masks)
+        speaker_features = (
+            self.speaker_encoder(gt_mels, mel_masks)
+            if self.speaker_encoder is not None
+            else None
+        )
+        decoded_mels = self.decoder(text_features, mel_masks, g=speaker_features)
         fake_audios = self.generator(decoded_mels)
 
         fake_mels = self.mel_transform(fake_audios.squeeze(1))
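
The new speaker_encoder hook is standard global conditioning: a per-utterance embedding is computed from the ground-truth mels and handed to the WN decoder as g, and the module degrades gracefully to the old behaviour when the encoder is left unset (the default in VQGAN.__init__). A toy illustration of the control flow with stand-in modules and shapes; the real SpeakerEncoder and WN live in fish_speech and differ in detail:

    import torch
    import torch.nn as nn

    class ToySpeakerEncoder(nn.Module):
        """Stand-in: masked temporal mean -> one (B, gin_channels, 1) vector per clip."""
        def __init__(self, n_mels=80, gin_channels=512):
            super().__init__()
            self.proj = nn.Conv1d(n_mels, gin_channels, 1)

        def forward(self, mels, masks):
            x = self.proj(mels) * masks
            return x.sum(dim=2, keepdim=True) / masks.sum(dim=2, keepdim=True).clamp(min=1)

    speaker_encoder = ToySpeakerEncoder()       # may also be None (no conditioning)
    gt_mels = torch.randn(2, 80, 100)           # (batch, num_mels, frames)
    mel_masks = torch.ones(2, 1, 100)

    speaker_features = (
        speaker_encoder(gt_mels, mel_masks) if speaker_encoder is not None else None
    )
    # ...then: decoded_mels = decoder(text_features, mel_masks, g=speaker_features)
    print(speaker_features.shape)               # torch.Size([2, 512, 1])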

+ 10 - 15
fish_speech/models/vqgan/modules/encoders.py

@@ -221,36 +221,26 @@ class SpeakerEncoder(nn.Module):
         in_channels: int = 128,
         hidden_channels: int = 192,
         out_channels: int = 512,
-        num_heads: int = 2,
         num_layers: int = 4,
-        p_dropout: float = 0.0,
     ) -> None:
         super().__init__()
 
         self.in_proj = nn.Sequential(
             nn.Conv1d(in_channels, hidden_channels, 1),
             nn.Mish(),
-            nn.Dropout(p_dropout),
             nn.Conv1d(hidden_channels, hidden_channels, 5, padding=2),
             nn.Mish(),
-            nn.Dropout(p_dropout),
             nn.Conv1d(hidden_channels, hidden_channels, 5, padding=2),
             nn.Mish(),
-            nn.Dropout(p_dropout),
         )
         self.out_proj = nn.Conv1d(hidden_channels, out_channels, 1)
         self.apply(self._init_weights)
 
-        self.encoder = RelativePositionTransformer(
-            in_channels=hidden_channels,
-            out_channels=hidden_channels,
-            hidden_channels=hidden_channels,
-            hidden_channels_ffn=hidden_channels,
-            n_heads=num_heads,
+        self.encoder = WN(
+            hidden_channels,
+            kernel_size=3,
+            dilation_rate=1,
             n_layers=num_layers,
-            kernel_size=1,
-            dropout=p_dropout,
-            window_size=None,  # No windowing
         )
 
     def _init_weights(self, m):
@@ -338,7 +328,12 @@ class VQEncoder(nn.Module):
         return x, indices, loss
 
     def decode(self, indices):
-        q = self.vq.get_output_from_indices(indices).mT
+        q = self.vq.get_output_from_indices(indices)
+
+        if q.shape[1] != indices.shape[1]:
+            q = q.view(q.shape[0], indices.shape[1], -1)
+        q = q.mT
+
         x = self.conv_out(q)
 
         return x
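
The decode() change is a shape guard: when the quantizer hands back more rows than there are code indices (my reading: a grouped codebook flattened into the time axis), the rows are folded back so each time step carries the full feature width before the channels-first transpose. A shape-only sketch with dummy tensors:

    import torch

    batch, time, dim, groups = 2, 50, 256, 2
    indices = torch.zeros(batch, time, dtype=torch.long)

    q = torch.randn(batch, time * groups, dim // groups)  # flattened grouped output
    if q.shape[1] != indices.shape[1]:
        q = q.view(q.shape[0], indices.shape[1], -1)       # fold back to (B, T, dim)
    q = q.mT                                               # (B, dim, T), channels-first for Conv1d
    print(q.shape)                                         # torch.Size([2, 256, 50])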

+ 40 - 16
tools/infer_vq.py

@@ -20,11 +20,11 @@ OmegaConf.register_new_resolver("eval", eval)
 @torch.autocast(device_type="cuda", enabled=True)
 def main():
     with initialize(version_base="1.3", config_path="../fish_speech/configs"):
-        cfg = compose(config_name="vqgan")
+        cfg = compose(config_name="vqgan_single_2x")
 
     model: LightningModule = instantiate(cfg.model)
     state_dict = torch.load(
-        "checkpoints/vqgan/step_000380000.ckpt",
+        "results/vqgan_single_2x/checkpoints/step_000160000.ckpt",
         map_location=model.device,
     )["state_dict"]
     model.load_state_dict(state_dict, strict=True)
@@ -33,7 +33,11 @@ def main():
     logger.info("Restored model from checkpoint")
 
     # Load audio
-    audio = librosa.load("test.wav", sr=model.sampling_rate, mono=True)[0]
+    audio = librosa.load(
+        "data/StarRail/Chinese/停云/chapter2_1_tingyun_142.wav",
+        sr=model.sampling_rate,
+        mono=True,
+    )[0]
     audios = torch.from_numpy(audio).to(model.device)[None, None, :]
     logger.info(
         f"Loaded audio with {audios.shape[2] / model.sampling_rate:.2f} seconds"
@@ -69,29 +73,49 @@ def main():
     print(indices.shape)
 
     # Restore
-    indices = np.load("codes_0.npy")
-    indices = torch.from_numpy(indices).to(model.device).long()
-    indices = indices.unsqueeze(1).unsqueeze(-1)
-    mel_lengths = indices.shape[2] * (
-        model.downsample.total_strides if model.downsample is not None else 1
+    # indices = np.load("codes_0.npy")
+    # indices = torch.from_numpy(indices).to(model.device).long()
+    # indices = indices.unsqueeze(1).unsqueeze(-1)
+    # mel_lengths = indices.shape[2] * (
+    #     model.downsample.total_strides if model.downsample is not None else 1
+    # )
+    # mel_lengths = torch.tensor([mel_lengths], device=model.device, dtype=torch.long)
+    # mel_masks = torch.ones(
+    #     (1, 1, mel_lengths), device=model.device, dtype=torch.float32
+    # )
+
+    # print(mel_lengths)
+
+    # Reference speaker
+    ref_audio = librosa.load(
+        "data/StarRail/Chinese/符玄/chapter2_8_fuxuan_104.wav",
+        sr=model.sampling_rate,
+        mono=True,
+    )[0]
+    ref_audios = torch.from_numpy(ref_audio).to(model.device)[None, None, :]
+    ref_audio_lengths = torch.tensor(
+        [ref_audios.shape[2]], device=model.device, dtype=torch.long
     )
-    mel_lengths = torch.tensor([mel_lengths], device=model.device, dtype=torch.long)
-    mel_masks = torch.ones(
-        (1, 1, mel_lengths), device=model.device, dtype=torch.float32
-    )
-
-    print(mel_lengths)
+    ref_mels = model.mel_transform(ref_audios, sample_rate=model.sampling_rate)
+    ref_mel_lengths = ref_audio_lengths // model.hop_length
+    ref_mel_masks = torch.unsqueeze(
+        sequence_mask(ref_mel_lengths, ref_mels.shape[2]), 1
+    ).to(gt_mels.dtype)
+    speaker_features = model.speaker_encoder(ref_mels, ref_mel_masks)
+    # speaker_features = model.speaker_encoder(gt_mels, mel_masks)
 
+    print("indices", indices.shape)
     text_features = model.vq_encoder.decode(indices)
+
     logger.info(
         f"VQ Encoded, indices: {indices.shape} equivalent to "
-        + f"{1/(mel_lengths[0] * model.hop_length / model.sampling_rate / indices.shape[2]):.2f} Hz"
+        + f"{1/(mel_lengths[0] * model.hop_length / model.sampling_rate / indices.shape[1]):.2f} Hz"
     )
 
     text_features = F.interpolate(text_features, size=mel_lengths[0], mode="nearest")
 
     # Sample mels
-    decoded_mels = model.decoder(text_features, mel_masks)
+    decoded_mels = model.decoder(text_features, mel_masks, g=speaker_features)
     fake_audios = model.generator(decoded_mels)
 
     # Save audio
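
The rewritten inference path swaps the old self-reconstruction for cross-speaker synthesis: codes come from one clip, while the speaker embedding comes from a reference clip run through mel_transform, a length mask, and the new speaker_encoder. The masking step leans on fish_speech's sequence_mask; a minimal stand-in with the semantics this snippet assumes (True for valid frames, False for padding):

    import torch

    def toy_sequence_mask(lengths: torch.Tensor, max_len: int) -> torch.Tensor:
        # (B,) lengths -> (B, max_len) boolean mask, True where the frame is valid
        steps = torch.arange(max_len, device=lengths.device)
        return steps.unsqueeze(0) < lengths.unsqueeze(1)

    ref_mel_lengths = torch.tensor([180, 120])
    mask = toy_sequence_mask(ref_mel_lengths, 200).unsqueeze(1).float()  # (B, 1, T)
    print(mask.shape, mask[1, 0, 119].item(), mask[1, 0, 120].item())    # 1.0 then 0.0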

+ 36 - 14
tools/vqgan/extract_vq.py

@@ -42,27 +42,27 @@ logger.add(sys.stderr, format=logger_format)
 
 
 @lru_cache(maxsize=1)
-def get_model():
+def get_model(
+    config_name: str = "vqgan",
+    checkpoint_path: str = "checkpoints/vqgan/step_000380000.ckpt",
+):
     with initialize(version_base="1.3", config_path="../../fish_speech/configs"):
-        cfg = compose(config_name="vqgan")
+        cfg = compose(config_name=config_name)
 
     model: LightningModule = instantiate(cfg.model)
     state_dict = torch.load(
-        "checkpoints/vqgan/step_000380000.ckpt",
+        checkpoint_path,
         map_location=model.device,
     )["state_dict"]
     model.load_state_dict(state_dict, strict=True)
     model.eval()
     model.cuda()
-    logger.info("Restored model from checkpoint")
 
     logger.info(f"Loaded model")
     return model
 
 
-def process_batch(files: list[Path]) -> float:
-    model = get_model()
-
+def process_batch(files: list[Path], model) -> float:
     wavs = []
     audio_lengths = []
     max_length = total_time = 0
@@ -105,10 +105,19 @@ def process_batch(files: list[Path]) -> float:
             sequence_mask(feature_lengths, features.shape[2]), 1
         ).to(gt_mels.dtype)
 
-        # vq_features is 50 hz, need to convert to true mel size
         text_features = model.mel_encoder(features, feature_masks)
         _, indices, _ = model.vq_encoder(text_features, feature_masks)
-        indices = indices.squeeze(-1)
+
+        if indices.ndim == 4:
+            # Grouped vq
+            assert indices.shape[-1] == 1, f"Residual vq is not supported"
+            indices = indices.squeeze(-1)
+        elif indices.ndim == 2:
+            # Single vq
+            indices = indices.unsqueeze(0)
+        else:
+            raise ValueError(f"Invalid indices shape {indices.shape}")
+
         indices = rearrange(indices, "c b t -> b c t")
 
     # Save to disk
@@ -127,7 +136,19 @@ def process_batch(files: list[Path]) -> float:
 @click.command()
 @click.argument("folder")
 @click.option("--num-workers", default=1)
-def main(folder: str, num_workers: int):
+@click.option("--config-name", default="vqgan")
+@click.option(
+    "--checkpoint-path",
+    default="checkpoints/vqgan/step_000380000.ckpt",
+)
+@click.option("--batch-size", default=64)
+def main(
+    folder: str,
+    num_workers: int,
+    config_name: str,
+    checkpoint_path: str,
+    batch_size: int,
+):
     if num_workers > 1 and WORLD_SIZE != num_workers:
         assert WORLD_SIZE == 1, "You should either use SLURM or this launcher, not both"
 
@@ -168,14 +189,15 @@ def main(folder: str, num_workers: int):
     files = files[RANK::WORLD_SIZE]
     logger.info(f"Processing {len(files)}/{total_files} files")
 
-    # Batch size 64
+    # Batch processing
     total_time = 0
     begin_time = time.time()
     processed_files = 0
+    model = get_model(config_name, checkpoint_path)
 
-    for n_batch, idx in enumerate(range(0, len(files), 32)):
-        batch = files[idx : idx + 32]
-        batch_time = process_batch(batch)
+    for n_batch, idx in enumerate(range(0, len(files), batch_size)):
+        batch = files[idx : idx + batch_size]
+        batch_time = process_batch(batch, model)
 
         total_time += batch_time
         processed_files += len(batch)
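
With the new Click options, the extractor no longer hard-codes the old multi-codebook checkpoint; for the configs added in this commit one might run something like: python tools/vqgan/extract_vq.py data/my_dataset --config-name vqgan_single_2x --checkpoint-path results/vqgan_single_2x/checkpoints/step_000160000.ckpt --batch-size 64 (the checkpoint path mirrors the one used in tools/infer_vq.py and is only an example). The indices branch above accepts two layouts; a toy check of how each is normalised to (batch, codebooks, time), with the layouts being my reading of that branch:

    import torch
    from einops import rearrange

    grouped = torch.zeros(2, 4, 50, 1, dtype=torch.long)  # (codebooks, batch, time, residual=1)
    single = torch.zeros(4, 50, dtype=torch.long)          # (batch, time)

    for indices in (grouped, single):
        if indices.ndim == 4:
            indices = indices.squeeze(-1)      # drop the residual axis
        elif indices.ndim == 2:
            indices = indices.unsqueeze(0)     # add a codebook axis of size 1
        indices = rearrange(indices, "c b t -> b c t")
        print(indices.shape)                   # torch.Size([4, 2, 50]) then torch.Size([4, 1, 50])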