Lengyue 2 лет назад
Родитель
Commit
471d8e1da8

+ 21 - 15
fish_speech/configs/vqgan_pretrain.yaml

@@ -2,12 +2,12 @@ defaults:
   - base
   - _self_
 
-project: vq-group-fsq-8x1024-wn-20x768-cond
+project: vq-gan-group-fsq-8x1024-wn-20x768-cond
 
 # Lightning Trainer
 trainer:
   accelerator: gpu
-  devices: 1
+  devices: auto
   precision: bf16-mixed
   max_steps: 1_000_000
   val_check_interval: 5000
@@ -21,15 +21,22 @@ win_length: 2048
 
 # Dataset Configuration
 train_dataset:
-  _target_: fish_speech.datasets.vqgan.VQGANDataset
-  filelist: /***REMOVED***/workspace/diffusion-test/data/HiFi-TTS/vq_train_filelist.txt
-  sample_rate: ${sample_rate}
-  hop_length: ${hop_length}
-  slice_frames: 512
+  _target_: torch.utils.data.ConcatDataset
+  datasets:
+    - _target_: fish_speech.datasets.vqgan.VQGANDataset
+      filelist: data/libri-light-pack/filelist.txt
+      sample_rate: ${sample_rate}
+      hop_length: ${hop_length}
+      slice_frames: 512
+    - _target_: fish_speech.datasets.vqgan.VQGANDataset
+      filelist: data/sft/vq_train_filelist.txt
+      sample_rate: ${sample_rate}
+      hop_length: ${hop_length}
+      slice_frames: 512
 
 val_dataset:
   _target_: fish_speech.datasets.vqgan.VQGANDataset
-  filelist: /***REMOVED***/workspace/diffusion-test/data/HiFi-TTS/vq_val_filelist.txt
+  filelist: data/sft/vq_val_filelist.txt
   sample_rate: ${sample_rate}
   hop_length: ${hop_length}
 
@@ -38,23 +45,19 @@ data:
   train_dataset: ${train_dataset}
   val_dataset: ${val_dataset}
   num_workers: 4
-  batch_size: 64
-  val_batch_size: 64
+  batch_size: 32
+  val_batch_size: 32
 
 # Model Configuration
 model:
   _target_: fish_speech.models.vqgan.VQGAN
 
   sampling_rate: ${sample_rate}
-  weight_reflow: 1.0
+  weight_adv: 0.5
   weight_vq: 1.0
   weight_mel: 1.0
   freeze_encoder: false
 
-  # Reflow configs
-  reflow_inference_steps: 10
-  reflow_inference_start_t: 0.0
-
   encoder:
     _target_: fish_speech.models.vqgan.modules.wavenet.WaveNet
     input_channels: ${num_mels}
@@ -76,6 +79,9 @@ model:
     residual_layers: 20
     dilation_cycle: 4
     condition_channels: 768
+  
+  discriminator:
+    _target_: fish_speech.models.vqgan.modules.discriminator.Discriminator
 
   vocoder:
     _target_: fish_speech.models.vqgan.modules.firefly.FireflyBase

+ 17 - 18
fish_speech/models/vqgan/lit_module.py

@@ -11,6 +11,7 @@ from torch import nn
 
 from fish_speech.models.vqgan.modules.wavenet import WaveNet
 from fish_speech.models.vqgan.utils import plot_mel, sequence_mask
+from fish_speech.models.vqgan.modules.discriminator import Discriminator
 
 
 class VQGAN(L.LightningModule):
@@ -21,16 +22,14 @@ class VQGAN(L.LightningModule):
         encoder: WaveNet,
         quantizer: nn.Module,
         decoder: WaveNet,
-        # reflow: nn.Module,
+        discriminator: Discriminator,
         vocoder: nn.Module,
         mel_transform: nn.Module,
-        weight_reflow: float = 1.0,
+        weight_adv: float = 1.0,
         weight_vq: float = 1.0,
         weight_mel: float = 1.0,
         sampling_rate: int = 44100,
         freeze_encoder: bool = False,
-        reflow_inference_steps: int = 10,
-        reflow_inference_start_t: float = 0.5,
     ):
         super().__init__()
 
@@ -43,7 +42,7 @@ class VQGAN(L.LightningModule):
         self.quantizer = quantizer
         self.decoder = decoder
         self.vocoder = vocoder
-        # self.reflow = reflow
+        self.discriminator = discriminator
         self.mel_transform = mel_transform
 
         # Freeze vocoder
@@ -51,7 +50,7 @@ class VQGAN(L.LightningModule):
             param.requires_grad = False
 
         # Loss weights
-        self.weight_reflow = weight_reflow
+        self.weight_adv = weight_adv
         self.weight_vq = weight_vq
         self.weight_mel = weight_mel
 
@@ -59,8 +58,6 @@ class VQGAN(L.LightningModule):
         self.spec_min = -12
         self.spec_max = 3
         self.sampling_rate = sampling_rate
-        self.reflow_inference_steps = reflow_inference_steps
-        self.reflow_inference_start_t = reflow_inference_start_t
 
         # Disable strict loading
         self.strict_loading = False
@@ -72,6 +69,8 @@ class VQGAN(L.LightningModule):
 
             for param in self.quantizer.parameters():
                 param.requires_grad = False
+            
+        self.automatic_optimization = False
 
     def on_save_checkpoint(self, checkpoint):
         # Do not save vocoder
@@ -81,16 +80,16 @@ class VQGAN(L.LightningModule):
                 state_dict.pop(name)
 
     def configure_optimizers(self):
-        optimizer = self.optimizer_builder(self.parameters())
-        lr_scheduler = self.lr_scheduler_builder(optimizer)
-
-        return {
-            "optimizer": optimizer,
-            "lr_scheduler": {
-                "scheduler": lr_scheduler,
-                "interval": "step",
-            },
-        }
+        # optimizer = self.optimizer_builder(self.parameters())
+        # lr_scheduler = self.lr_scheduler_builder(optimizer)
+
+        # return {
+        #     "optimizer": optimizer,
+        #     "lr_scheduler": {
+        #         "scheduler": lr_scheduler,
+        #         "interval": "step",
+        #     },
+        # }
 
     def norm_spec(self, x):
         return (x - self.spec_min) / (self.spec_max - self.spec_min) * 2 - 1

+ 37 - 0
fish_speech/models/vqgan/modules/discriminator.py

@@ -0,0 +1,37 @@
+import torch
+from torch import nn
+from torch.nn.utils.parametrizations import weight_norm
+
+class Discriminator(nn.Module):
+    def __init__(self):
+        super().__init__()
+
+        blocks = []
+        convs = [
+            (1, 64, (3, 9), 1, (1, 4)),
+            (64, 128, (3, 9), (1, 2), (1, 4)),
+            (128, 256, (3, 9), (1, 2), (1, 4)),
+            (256, 512, (3, 9), (1, 2), (1, 4)),
+            (512, 1024, (3, 3), 1, (1, 1)),
+            (1024, 1, (3, 3), 1, (1, 1)),
+        ]
+
+        for idx, (in_channels, out_channels, kernel_size, stride, padding) in enumerate(convs):
+            blocks.append(weight_norm(nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)))
+
+            if idx != len(convs) - 1:
+                blocks.append(nn.SiLU(inplace=True))
+        
+        self.blocks = nn.Sequential(*blocks)
+
+    def forward(self, x):
+        return self.blocks(x[:, None])[:, 0]
+
+
+if __name__ == "__main__":
+    model = Discriminator()
+    print(sum(p.numel() for p in model.parameters()) / 1_000_000)
+    x = torch.randn(1, 128, 1024)
+    y = model(x)
+    print(y.shape)
+    print(y)