Просмотр исходного кода

Optimize pretrain & finetune recipe, apply better logger

Lengyue 1 год назад
Родитель
Commit
1d4b7256b3

+ 3 - 2
docs/zh/index.md

@@ -21,8 +21,8 @@
 conda create -n fish-speech python=3.10
 conda activate fish-speech
 
-# 安装 pytorch nightly 版本
-pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu121
+# 安装 pytorch 版本
+pip3 install torch torchvision torchaudio
 
 # 安装 fish-speech
 pip3 install -e .
@@ -30,6 +30,7 @@ pip3 install -e .
 
 ## 更新日志
 
+- 2024/04/22: 完成了 Fish-Speech 1.0 版本, 大幅修改了 VQGAN 和 LLAMA 模型.
 - 2023/12/28: 添加了 `lora` 微调支持.
 - 2023/12/27: 添加了 `gradient checkpointing`, `causual sampling` 和 `flash-attn` 支持.
 - 2023/12/19: 更新了 Webui 和 HTTP API.

+ 14 - 8
fish_speech/configs/text2semantic_finetune.yaml

@@ -1,39 +1,46 @@
 defaults:
   - base
-  - model@model.model: dual_ar_8_codebook_small
+  - model@model.model: dual_ar_2_codebook_small
   - _self_
 
-project: text2semantic_400m_finetune_spk
-max_length: 4096
-ckpt_path: checkpoints/text2semantic-400m-v0.2-4k.pth
+project: text2semantic_finetune_dual_ar
+max_length: 2048
+ckpt_path: checkpoints/text2semantic-medium-v1-2k.pth
 resume_weights_only: true
 
 # Lightning Trainer
 trainer:
-  accumulate_grad_batches: 2
+  accumulate_grad_batches: 1
   gradient_clip_val: 1.0
   gradient_clip_algorithm: 'norm'
   max_steps: 1000
   precision: bf16-true
   limit_val_batches: 10
+  val_check_interval: 100
 
 # Dataset Configuration
 tokenizer:
   _target_: transformers.AutoTokenizer.from_pretrained
-  pretrained_model_name_or_path: fishaudio/speech-lm-v1
+  pretrained_model_name_or_path: fishaudio/fish-speech-1
 
 # Dataset Configuration
 train_dataset:
   _target_: fish_speech.datasets.text.AutoAugTextDataset
+  proto_files:
+    - data/protos
   tokenizer: ${tokenizer}
   max_length: ${max_length}
   num_codebooks: ${model.model.config.num_codebooks}
+  use_speaker: false
 
 val_dataset:
   _target_: fish_speech.datasets.text.AutoAugTextDataset
+  proto_files:
+    - data/protos
   tokenizer: ${tokenizer}
   max_length: ${max_length}
   num_codebooks: ${model.model.config.num_codebooks}
+  use_speaker: false
 
 data:
   _target_: fish_speech.datasets.text.TextDataModule
@@ -65,9 +72,8 @@ model:
       _partial_: true
       num_warmup_steps: 100
       num_training_steps: ${trainer.max_steps}
-      final_lr_ratio: 0.1
 
 # Callbacks
 callbacks:
   model_checkpoint:
-    every_n_train_steps: 200
+    every_n_train_steps: 100

+ 2 - 1
fish_speech/configs/text2semantic_finetune_lora.yaml

@@ -2,7 +2,8 @@ defaults:
   - text2semantic_finetune
   - _self_
 
-project: text2semantic_400m_finetune_lora
+project: text2semantic_finetune_dual_ar_lora
+
 # Model Configuration
 model:
   save_lora_only: true

+ 1 - 1
fish_speech/configs/text2semantic_pretrain.yaml

@@ -1,6 +1,6 @@
 defaults:
   - base
-  - model@model.model: dual_ar_8_codebook_small
+  - model@model.model: dual_ar_2_codebook_small
   - _self_
 
 project: text2semantic_pretrain_dual_ar_debug

+ 135 - 0
fish_speech/configs/vqgan_finetune.yaml

@@ -0,0 +1,135 @@
+defaults:
+  - base
+  - _self_
+
+project: vq-gan-finetune
+ckpt_path: checkpoints/vq-gan-group-fsq-2x1024.pth
+resume_weights_only: true
+
+# Lightning Trainer
+trainer:
+  accelerator: gpu
+  devices: auto
+  precision: bf16-mixed
+  max_steps: 100_000
+  val_check_interval: 5000
+  strategy: ddp_find_unused_parameters_true
+
+sample_rate: 44100
+hop_length: 512
+num_mels: 128
+n_fft: 2048
+win_length: 2048
+freeze_encoder: true
+
+# Dataset Configuration
+train_dataset:
+  _target_: fish_speech.datasets.vqgan.VQGANDataset
+  filelist: data/filelist.train.txt
+  sample_rate: ${sample_rate}
+  hop_length: ${hop_length}
+  slice_frames: 512
+
+val_dataset:
+  _target_: fish_speech.datasets.vqgan.VQGANDataset
+  filelist: data/filelist.val.txt
+  sample_rate: ${sample_rate}
+  hop_length: ${hop_length}
+
+data:
+  _target_: fish_speech.datasets.vqgan.VQGANDataModule
+  train_dataset: ${train_dataset}
+  val_dataset: ${val_dataset}
+  num_workers: 4
+  batch_size: 16
+  val_batch_size: 16
+
+# Model Configuration
+model:
+  _target_: fish_speech.models.vqgan.VQGAN
+
+  sampling_rate: ${sample_rate}
+  weight_adv: 0.2
+  weight_vq: 1.0
+  weight_mel: 1.0
+  freeze_encoder: false
+
+  encoder:
+    _target_: fish_speech.models.vqgan.modules.wavenet.WaveNet
+    input_channels: ${num_mels}
+    residual_channels: 768
+    residual_layers: 20
+    dilation_cycle: 4
+  
+  quantizer:
+    _target_: fish_speech.models.vqgan.modules.fsq.DownsampleFiniteScalarQuantize
+    input_dim: 768
+    n_codebooks: 1
+    n_groups: 2
+    levels: [8, 5, 5, 5]
+
+  decoder:
+    _target_: fish_speech.models.vqgan.modules.wavenet.WaveNet
+    output_channels: ${num_mels}
+    residual_channels: 768
+    residual_layers: 20
+    dilation_cycle: 4
+    condition_channels: 768
+  
+  discriminator:
+    _target_: fish_speech.models.vqgan.modules.discriminator.Discriminator
+
+  vocoder:
+    _target_: fish_speech.models.vqgan.modules.firefly.FireflyBase
+    ckpt_path: null # You may download the pretrained vocoder and set the path here
+
+  encode_mel_transform:
+    _target_: fish_speech.models.vqgan.spectrogram.LogMelSpectrogram
+    sample_rate: ${sample_rate}
+    n_fft: ${n_fft}
+    hop_length: ${hop_length}
+    win_length: ${win_length}
+    n_mels: ${num_mels}
+    f_min: 0.0
+    f_max: 8000.0
+
+  gt_mel_transform:
+    _target_: fish_speech.models.vqgan.spectrogram.LogMelSpectrogram
+    sample_rate: ${sample_rate}
+    n_fft: ${n_fft}
+    hop_length: ${hop_length}
+    win_length: ${win_length}
+    n_mels: ${num_mels}
+
+  optimizer:
+    _target_: torch.optim.AdamW
+    _partial_: true
+    lr: 4e-5
+    betas: [0.8, 0.99]
+    eps: 1e-5
+    weight_decay: 0.01
+
+  lr_scheduler:
+    _target_: torch.optim.lr_scheduler.LambdaLR
+    _partial_: true
+    lr_lambda:
+      _target_: fish_speech.scheduler.get_cosine_schedule_with_warmup_lr_lambda
+      _partial_: true
+      num_warmup_steps: 100
+      num_training_steps: ${trainer.max_steps}
+      final_lr_ratio: 0
+
+callbacks:
+  model_summary:
+    _target_: lightning.pytorch.callbacks.ModelSummary
+    max_depth: 1
+
+  model_checkpoint:
+    every_n_train_steps: ${trainer.val_check_interval}
+
+  grad_norm_monitor:
+    sub_module: 
+      - encoder
+      - decoder
+      - quantizer
+      - discriminator

+ 1 - 1
fish_speech/configs/vqgan_pretrain.yaml

@@ -2,7 +2,7 @@ defaults:
   - base
   - _self_
 
-project: vq-gan-group-fsq-8x1024-wn-20x768-cond
+project: vq-gan-pretrain
 
 # Lightning Trainer
 trainer:

+ 3 - 0
fish_speech/datasets/vqgan.py

@@ -68,6 +68,9 @@ class VQGANDataset(Dataset):
         try:
             return self.get_item(idx)
         except Exception as e:
+            import traceback
+
+            traceback.print_exc()
             logger.error(f"Error loading {self.files[idx]}: {e}")
             return None
 

+ 34 - 18
fish_speech/models/text2semantic/lit_module.py

@@ -68,6 +68,20 @@ class TextToSemantic(L.LightningModule):
                 ]
             )
 
+        if hasattr(self.model, "fast_layers"):
+            # Dual-AR model
+            linears.extend([(self.model, "fast_output")])
+
+            for layer in self.model.fast_layers:
+                linears.extend([(layer.attention, "wqkv"), (layer.attention, "wo")])
+                linears.extend(
+                    [
+                        (layer.feed_forward, "w1"),
+                        (layer.feed_forward, "w2"),
+                        (layer.feed_forward, "w3"),
+                    ]
+                )
+
         for module, layer in linears:
             updated_linear = lora.Linear(
                 in_features=getattr(module, layer).in_features,
@@ -162,6 +176,8 @@ class TextToSemantic(L.LightningModule):
             return (per_token_logps * loss_mask).sum(-1)
 
     def _step(self, batch, batch_idx, stage: str):
+        is_train = stage == "train"
+
         # Do positive and negative samples in the same batch to speed up training
         labels = batch["labels"]
         outputs = self.model(
@@ -224,8 +240,8 @@ class TextToSemantic(L.LightningModule):
             self.log(
                 f"{stage}/dpo_loss",
                 dpo_loss,
-                on_step=True,
-                on_epoch=False,
+                on_step=is_train,
+                on_epoch=not is_train,
                 prog_bar=False,
                 logger=True,
             )
@@ -233,8 +249,8 @@ class TextToSemantic(L.LightningModule):
             self.log(
                 f"{stage}/chosen_rewards",
                 chosen_rewards,
-                on_step=True,
-                on_epoch=False,
+                on_step=is_train,
+                on_epoch=not is_train,
                 prog_bar=False,
                 logger=True,
             )
@@ -242,8 +258,8 @@ class TextToSemantic(L.LightningModule):
             self.log(
                 f"{stage}/rejected_rewards",
                 rejected_rewards,
-                on_step=True,
-                on_epoch=False,
+                on_step=is_train,
+                on_epoch=not is_train,
                 prog_bar=False,
                 logger=True,
             )
@@ -251,8 +267,8 @@ class TextToSemantic(L.LightningModule):
             self.log(
                 f"{stage}/reward_accuracy",
                 reward_accuracy,
-                on_step=True,
-                on_epoch=False,
+                on_step=is_train,
+                on_epoch=not is_train,
                 prog_bar=False,
                 logger=True,
             )
@@ -260,8 +276,8 @@ class TextToSemantic(L.LightningModule):
         self.log(
             f"{stage}/loss",
             loss,
-            on_step=True,
-            on_epoch=False,
+            on_step=is_train,
+            on_epoch=not is_train,
             prog_bar=True,
             logger=True,
         )
@@ -269,8 +285,8 @@ class TextToSemantic(L.LightningModule):
         self.log(
             f"{stage}/base_loss",
             base_loss,
-            on_step=True,
-            on_epoch=False,
+            on_step=is_train,
+            on_epoch=not is_train,
             prog_bar=False,
             logger=True,
         )
@@ -278,8 +294,8 @@ class TextToSemantic(L.LightningModule):
         self.log(
             f"{stage}/semantic_loss",
             semantic_loss,
-            on_step=True,
-            on_epoch=False,
+            on_step=is_train,
+            on_epoch=not is_train,
             prog_bar=False,
             logger=True,
         )
@@ -289,8 +305,8 @@ class TextToSemantic(L.LightningModule):
         self.log(
             f"{stage}/top_5_accuracy",
             accuracy,
-            on_step=True,
-            on_epoch=False,
+            on_step=is_train,
+            on_epoch=not is_train,
             prog_bar=True,
             logger=True,
         )
@@ -304,8 +320,8 @@ class TextToSemantic(L.LightningModule):
             self.log(
                 f"{stage}/top_5_accuracy_in",
                 accuracy,
-                on_step=True,
-                on_epoch=False,
+                on_step=is_train,
+                on_epoch=not is_train,
                 prog_bar=True,
                 logger=True,
             )

+ 4 - 1
pyproject.toml

@@ -30,7 +30,10 @@ dependencies = [
     "loguru>=0.6.0",
     "loralib>=0.1.2",
     "natsort>=8.4.0",
-    "pyrootutils>=1.0.4"
+    "pyrootutils>=1.0.4",
+    "vector_quantize_pytorch>=1.14.7",
+    "samplerate>=0.2.1",
+    "resampy>=0.4.3",
 ]
 
 [project.optional-dependencies]