Parcourir la source

Update finetune document

Lengyue il y a 2 ans
Parent
commit
d3c0dee39c

+ 63 - 1
docs/zh/finetune.md

@@ -2,7 +2,63 @@
 
 
 显然, 当你打开这个页面的时候, 你已经对预训练模型 few-shot 的效果不算满意. 你想要微调一个模型, 使得它在你的数据集上表现更好.  
 显然, 当你打开这个页面的时候, 你已经对预训练模型 few-shot 的效果不算满意. 你想要微调一个模型, 使得它在你的数据集上表现更好.  
 
 
-`Fish Speech` 由两个模块组成: `VQGAN` 和 `LLAMA`. 目前, 我们只支持微调 `LLAMA` 模型.
+`Fish Speech` 由两个模块组成: `VQGAN` 和 `LLAMA`. 
+
+!!! info 
+    你应该先进行如下测试来判断你是否需要微调 `VQGAN`:
+    ```bash
+    python tools/vqgan/inference.py -i test.wav
+    ```
+    该测试会生成一个 `fake.wav` 文件, 如果该文件的音色和说话人的音色不同, 或者质量不高, 你需要微调 `VQGAN`.
+
+    相应的, 你可以参考 [推理](../inference/) 来运行 `generate.py`, 判断韵律是否满意, 如果不满意, 则需要微调 `LLAMA`.
+
+## VQGAN 微调
+### 1. 准备数据集
+
+```
+.
+├── SPK1
+│   ├── 21.15-26.44.lab
+│   ├── 21.15-26.44.mp3
+│   ├── 27.51-29.98.lab
+│   ├── 27.51-29.98.mp3
+│   ├── 30.1-32.71.lab
+│   └── 30.1-32.71.mp3
+└── SPK2
+    ├── 38.79-40.85.lab
+    └── 38.79-40.85.mp3
+```
+
+你需要将数据集转为以上格式, 并放到 `data/demo` 下, 音频后缀可以为 `.mp3`, `.wav` 或 `.flac`, 标注文件后缀可以为 `.lab` 或 `.txt`.
+
+### 2. 分割训练集和验证集
+
+```bash
+python tools/vqgan/create_train_split.py data/demo
+```
+
+该命令会在 `data/demo` 目录下创建 `data/demo/vq_train_filelist.txt` 和 `data/demo/vq_val_filelist.txt` 文件, 分别用于训练和验证.
+
+### 3. 启动训练
+
+```bash
+python fish_speech/train.py --config-name vqgan_finetune
+```
+
+!!! note
+    你可以通过修改 `fish_speech/configs/vqgan_finetune.yaml` 来修改训练参数, 但大部分情况下, 你不需要这么做.
+
+### 4. 测试音频
+    
+```bash
+python tools/vqgan/inference.py -i test.wav --checkpoint-path results/vqgan_finetune/checkpoints/step_000010000.ckpt
+```
+
+你可以查看 `fake.wav` 来判断微调效果.
+
+!!! note
+    你也可以尝试其他的 checkpoint, 我们建议你使用最早的满足你要求的 checkpoint, 他们通常在 OOD 上表现更好.
 
 
 ## LLAMA 微调
 ## LLAMA 微调
 ### 1. 准备数据集
 ### 1. 准备数据集
@@ -26,6 +82,12 @@
 !!! note
 !!! note
     你可以通过修改 `fish_speech/configs/data/finetune.yaml` 来修改数据集路径, 以及混合数据集.
     你可以通过修改 `fish_speech/configs/data/finetune.yaml` 来修改数据集路径, 以及混合数据集.
 
 
+!!! warning
+    建议先对数据集进行响度匹配, 你可以使用 [fish-audio-preprocess](https://github.com/fishaudio/audio-preprocess) 来完成这一步骤. 
+    ```bash
+    fap loudness-norm demo-raw demo --clean
+    ```
+
 ### 2. 批量提取语义 token
 ### 2. 批量提取语义 token
 
 
 确保你已经下载了 vqgan 权重, 如果没有, 请运行以下命令:
 确保你已经下载了 vqgan 权重, 如果没有, 请运行以下命令:

+ 3 - 3
fish_speech/configs/text2semantic_finetune_spk.yaml

@@ -63,7 +63,7 @@ model:
   optimizer:
   optimizer:
     _target_: torch.optim.AdamW
     _target_: torch.optim.AdamW
     _partial_: true
     _partial_: true
-    lr: 1e-4
+    lr: 1e-5
     weight_decay: 0.1
     weight_decay: 0.1
     betas: [0.9, 0.95]
     betas: [0.9, 0.95]
     eps: 1e-5
     eps: 1e-5
@@ -74,11 +74,11 @@ model:
     lr_lambda:
     lr_lambda:
       _target_: fish_speech.scheduler.get_cosine_schedule_with_warmup_lr_lambda
       _target_: fish_speech.scheduler.get_cosine_schedule_with_warmup_lr_lambda
       _partial_: true
       _partial_: true
-      num_warmup_steps: 1000
+      num_warmup_steps: 100
       num_training_steps: ${trainer.max_steps}
       num_training_steps: ${trainer.max_steps}
       final_lr_ratio: 0.1
       final_lr_ratio: 0.1
 
 
 # Callbacks
 # Callbacks
 callbacks:
 callbacks:
   model_checkpoint:
   model_checkpoint:
-    every_n_train_steps: 1000
+    every_n_train_steps: 200

+ 2 - 1
tools/vqgan/inference.py

@@ -14,6 +14,7 @@ from loguru import logger
 from omegaconf import OmegaConf
 from omegaconf import OmegaConf
 
 
 from fish_speech.models.vqgan.utils import sequence_mask
 from fish_speech.models.vqgan.utils import sequence_mask
+from fish_speech.utils.file import AUDIO_EXTENSIONS
 
 
 # register eval resolver
 # register eval resolver
 OmegaConf.register_new_resolver("eval", eval)
 OmegaConf.register_new_resolver("eval", eval)
@@ -51,7 +52,7 @@ def main(input_path, output_path, config_name, checkpoint_path):
     model.cuda()
     model.cuda()
     logger.info("Restored model from checkpoint")
     logger.info("Restored model from checkpoint")
 
 
-    if input_path.suffix == ".wav":
+    if input_path.suffix in AUDIO_EXTENSIONS:
         logger.info(f"Processing in-place reconstruction of {input_path}")
         logger.info(f"Processing in-place reconstruction of {input_path}")
         # Load audio
         # Load audio
         audio, _ = librosa.load(
         audio, _ = librosa.load(