
Fix freeze codebook & update finetune guide

Lengyue 2 years ago
parent
commit
9d34988cec

+ 120 - 0
docs/zh/finetune.md

@@ -0,0 +1,120 @@
+# Fine-tuning
+
+Evidently, since you have opened this page, you are not fully satisfied with the few-shot performance of the pretrained model. You want to fine-tune a model so that it performs better on your dataset.
+
+`Fish Speech` consists of two modules: `VQGAN` and `LLAMA`. Currently, we only support fine-tuning the `LLAMA` model.
+
+## Fine-tuning LLAMA
+### 1. Prepare the dataset
+
+```
+.
+├── SPK1
+│   ├── 21.15-26.44.lab
+│   ├── 21.15-26.44.mp3
+│   ├── 27.51-29.98.lab
+│   ├── 27.51-29.98.mp3
+│   ├── 30.1-32.71.lab
+│   └── 30.1-32.71.mp3
+└── SPK2
+    ├── 38.79-40.85.lab
+    └── 38.79-40.85.mp3
+```
+
+You need to convert your dataset to the format above and place it under `data/demo`. Audio files may have the `.mp3`, `.wav`, or `.flac` extension, and annotation files may have the `.lab` or `.txt` extension.
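+
+Each annotation file is expected to hold the plain-text transcript of the audio clip with the same basename. For example, `21.15-26.44.lab` might contain (hypothetical content):
+
+```
+Hello, welcome to our channel.
+```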
+
+!!! note
+    You can edit `fish_speech/configs/data/finetune.yaml` to change the dataset path and to mix in additional datasets.
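+
+A config that mixes in a second dataset might look like this (a sketch based on the schema of `fish_speech/configs/data/finetune.yaml`; the `data/other` path and `MyOther` name are hypothetical):
+
+```yaml
+datasets:
+  - root: data/demo
+    source: MyFinetune
+    languages: [ZH, EN]
+    extension: .lab
+    # 1 means the parent folder of each file is used as the group (speaker) name
+    group_parent_level: 1
+  - root: data/other
+    source: MyOther
+    languages: [ZH]
+    extension: .lab
+    group_parent_level: 1
+```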
+
+### 2. Batch-extract semantic tokens
+
+Make sure you have downloaded the vqgan weights. If not, run the following command:
+
+```bash
+huggingface-cli download fishaudio/speech-lm-v1 vqgan-v1.pth --local-dir checkpoints
+```
+
+Then run the following command to extract the semantic tokens:
+
+```bash
+python tools/vqgan/extract_vq.py data/demo \
+    --num-workers 1 --batch-size 16 \
+    --config-name "vqgan_pretrain" \
+    --checkpoint-path "checkpoints/vqgan-v1.pth"
+```
+
+!!! note
+    You can increase `--num-workers` and `--batch-size` to speed up extraction, but make sure not to exceed your GPU memory limit.
+
+This command creates `.npy` files in the `data/demo` directory, as shown below:
+
+```
+.
+├── SPK1
+│   ├── 21.15-26.44.lab
+│   ├── 21.15-26.44.mp3
+│   ├── 21.15-26.44.npy
+│   ├── 27.51-29.98.lab
+│   ├── 27.51-29.98.mp3
+│   ├── 27.51-29.98.npy
+│   ├── 30.1-32.71.lab
+│   ├── 30.1-32.71.mp3
+│   └── 30.1-32.71.npy
+└── SPK2
+    ├── 38.79-40.85.lab
+    ├── 38.79-40.85.mp3
+    └── 38.79-40.85.npy
+```
+
+### 3. Pack the dataset into protobuf
+
+```bash
+python tools/llama/build_dataset.py \
+    --config "fish_speech/configs/data/finetune.yaml" \
+    --output "data/quantized-dataset-ft.protos"
+```
+
+After the command finishes, you should see the `quantized-dataset-ft.protos` file in the `data` directory.
+
+### 4. Start the Rust data server
+
+Since loading and shuffling the dataset is slow and memory-intensive, we use a Rust server to load and shuffle the data. The server is based on GRPC and can be installed as follows:
+
+```bash
+cd data_server
+cargo build --release
+```
+
+Once the build finishes, you can start the server with:
+
+```bash
+export RUST_LOG=info # optional, for debugging
+data_server/target/release/data_server \
+    --files "data/quantized-dataset-ft.protos" 
+```
+
+!!! note
+    You can pass multiple `--files` arguments to load multiple datasets.
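+
+For example, a server loading two datasets at once (the second path is hypothetical):
+
+```bash
+data_server/target/release/data_server \
+    --files "data/quantized-dataset-ft.protos" \
+    --files "data/another-dataset.protos"
+```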
+
+### 5. Finally, start fine-tuning
+
+Similarly, make sure you have downloaded the `LLAMA` weights. If not, run the following command:
+
+```bash
+huggingface-cli download fishaudio/speech-lm-v1 text2semantic-400m-v0.2-4k.pth --local-dir checkpoints
+```
+
+Finally, run the following command to start fine-tuning:
+```bash
+python fish_speech/train.py --config-name text2semantic_finetune_spk
+```
+
+!!! note
+    You can edit `fish_speech/configs/text2semantic_finetune_spk.yaml` to adjust training parameters such as `batch_size` and `gradient_accumulation_steps` to fit your GPU memory.
+
+After training, refer to the inference section to generate speech.
+
+
+!!! info
+    With the default configuration, the model essentially learns only the speaker's pronunciation patterns, not the timbre, so you still need a prompt to keep the timbre stable.
+    If you want it to learn the timbre, increase the number of training steps, but this may lead to overfitting.

+ 0 - 8
docs/zh/index.md

@@ -29,14 +29,6 @@ pip3 install ninja && MAX_JOBS=4 pip3 install flash-attn --no-build-isolation
 pip3 install -e .
 ```
 
-## Rust data server
-Since loading and shuffling the dataset is slow and memory-intensive, we use a Rust server to load and shuffle the data. The server is based on GRPC and can be installed as follows:
-
-```bash
-cd data_server
-cargo build --release
-```
-
 ## Changelog
 
 - 2023/12/17: Updated the `text2semantic` model to support phoneme-free mode.

+ 2 - 4
docs/zh/inference.md

@@ -15,10 +15,8 @@
 Download the required `vqgan` and `text2semantic` models from our huggingface repository.
     
 ```bash
-wget https://huggingface.co/fishaudio/speech-lm-v1/raw/main/vqgan-v1.pth \
-    -O "checkpoints/vqgan-v1.pth"
-wget https://huggingface.co/fishaudio/speech-lm-v1/blob/main/text2semantic-400m-v0.2-4k.pth \
-    -O "checkpoints/text2semantic-400m-v0.2-4k.pth"
+huggingface-cli download fishaudio/speech-lm-v1 vqgan-v1.pth --local-dir checkpoints
+huggingface-cli download fishaudio/speech-lm-v1 text2semantic-400m-v0.2-4k.pth --local-dir checkpoints
 ```
 
 ### 1. Generate a prompt from speech:

+ 8 - 0
fish_speech/configs/data/finetune.yaml

@@ -0,0 +1,8 @@
+datasets:
+  - root: data/demo
+    source: MyFinetune
+    languages: [ZH, EN]
+    extension: .lab
+    # This controls the grouping of the dataset (i.e. speaker)
+    # 1 means we use the parent folder of the file as the group name
+    group_parent_level: 1

+ 43 - 0
fish_speech/configs/data/pretrain.yaml

@@ -0,0 +1,43 @@
+datasets:
+  - root: data/StarRail/Chinese
+    source: StarRail
+    languages: [ZH, EN]
+    extension: .lab
+    # This controls the grouping of the dataset (i.e. speaker)
+    # 1 means we use the parent folder of the file as the group name
+    group_parent_level: 1 
+  - root: data/StarRail/English
+    source: StarRail
+    languages: [EN]
+    extension: .lab
+    group_parent_level: 1
+  - root: data/StarRail/Japanese
+    source: StarRail
+    languages: [JP, EN]
+    extension: .lab
+    group_parent_level: 1
+  - root: data/Genshin/Chinese
+    source: Genshin
+    languages: [ZH, EN]
+    extension: .lab
+    group_parent_level: 1
+  - root: data/Genshin/English
+    source: Genshin
+    languages: [EN]
+    extension: .lab
+    group_parent_level: 1
+  - root: data/Genshin/Japanese
+    source: Genshin
+    languages: [JP, EN]
+    extension: .lab
+    group_parent_level: 1
+  - root: data/LibriTTS_R
+    source: LibriTTS_R
+    languages: [EN]
+    extension: .normalized.txt
+    group_parent_level: 2
+  - root: data/WenetSpeech
+    source: WenetSpeech
+    languages: [ZH, EN]
+    extension: .txt
+    group_parent_level: 1

+ 2 - 2
fish_speech/configs/text2semantic_finetune_spk.yaml

@@ -4,7 +4,7 @@ defaults:
 
 project: text2semantic_400m_finetune_spk
 max_length: 4096
-ckpt_path: results/text2semantic_400m_finetune/checkpoints/step_000010000.ckpt
+ckpt_path: checkpoints/text2semantic-400m-v0.2-4k.pth
 resume_weights_only: true
 
 # Lightning Trainer
@@ -74,7 +74,7 @@ model:
     lr_lambda:
       _target_: fish_speech.scheduler.get_cosine_schedule_with_warmup_lr_lambda
       _partial_: true
-      num_warmup_steps: 100
+      num_warmup_steps: 1000
       num_training_steps: ${trainer.max_steps}
       final_lr_ratio: 0.1
 

+ 1 - 1
fish_speech/models/vqgan/modules/encoders.py

@@ -317,7 +317,7 @@ class VQEncoder(nn.Module):
             x_mask = F.pad(x_mask, (0, self.downsample - x_len % self.downsample))
 
         x = self.conv_in(x)
-        q, indices, loss = self.vq(x.mT, freeze_codebook=freeze_codebook)
+        q, indices, loss = self.vq(x.mT)
         q = q.mT
 
         if self.codebook_groups > 1:

+ 25 - 19
tools/llama/build_dataset.py

@@ -2,7 +2,9 @@ import re
 from collections import defaultdict
 from multiprocessing import Pool
 
+import click
 import numpy as np
+import yaml
 from loguru import logger
 from tqdm import tqdm
 
@@ -11,22 +13,20 @@ from fish_speech.datasets.protos.text_data_stream import pack_pb_stream
 from fish_speech.text import g2p
 from fish_speech.utils.file import AUDIO_EXTENSIONS, list_files
 
-# Define datasets
-DATASETS = [
-    # (root, name, languages, extension, group parent level)
-    ("data/StarRail/Chinese", "StarRail", ["ZH", "EN"], ".lab", 1),
-    ("data/StarRail/English", "StarRail", ["EN"], ".lab", 1),
-    ("data/StarRail/Japanese", "StarRail", ["JP", "EN"], ".lab", 1),
-    ("data/Genshin/Chinese", "Genshin", ["ZH", "EN"], ".lab", 1),
-    ("data/Genshin/English", "Genshin", ["EN"], ".lab", 1),
-    ("data/Genshin/Japanese", "Genshin", ["JP", "EN"], ".lab", 1),
-    ("data/LibriTTS_R", "LibriTTS_R", ["EN"], ".normalized.txt", 2),
-    ("data/WenetSpeech", "WenetSpeech", ["ZH", "EN"], ".txt", 1),
-]
-
-
-def task_generator():
-    for root, source, languages, extension, parent_level in DATASETS:
+
+def task_generator(config):
+    with open(config, "r") as f:
+        config = yaml.load(f, Loader=yaml.FullLoader)
+
+    for row in config["datasets"]:
+        root, source, languages, extension, parent_level = (
+            row["root"],
+            row["source"],
+            row["languages"],
+            row["extension"],
+            row["group_parent_level"],
+        )
+
         # Load the files
         files = list_files(root, AUDIO_EXTENSIONS, recursive=True)
 
@@ -55,6 +55,7 @@ def run_task(task):
         np_file = file.with_suffix(".npy")
         txt_file = file.with_suffix(extension)
         if np_file.exists() is False or txt_file.exists() is False:
+            logger.warning(f"Can't find {np_file} or {txt_file}")
             continue
 
         with open(txt_file, "r") as f:
@@ -94,10 +95,15 @@ def run_task(task):
     )
 
 
-def main():
-    dataset_fp = open("data/quantized-dataset-1208.protos", "wb")
+@click.command()
+@click.option(
+    "--config", type=click.Path(), default="fish_speech/configs/data/finetune.yaml"
+)
+@click.option("--output", type=click.Path(), default="data/quantized-dataset-ft.protos")
+def main(config, output):
+    dataset_fp = open(output, "wb")
     with Pool(16) as p:
-        for result in tqdm(p.imap_unordered(run_task, task_generator())):
+        for result in tqdm(p.imap_unordered(run_task, task_generator(config))):
             dataset_fp.write(result)
 
     dataset_fp.close()

+ 6 - 3
tools/vqgan/extract_vq.py

@@ -53,7 +53,10 @@ def get_model(
     state_dict = torch.load(
         checkpoint_path,
         map_location=model.device,
-    )["state_dict"]
+    )
+    if "state_dict" in state_dict:
+        state_dict = state_dict["state_dict"]
+
     model.load_state_dict(state_dict, strict=True)
     model.eval()
     model.cuda()
@@ -136,10 +139,10 @@ def process_batch(files: list[Path], model) -> float:
 @click.command()
 @click.argument("folder")
 @click.option("--num-workers", default=1)
-@click.option("--config-name", default="vqgan")
+@click.option("--config-name", default="vqgan_pretrain")
 @click.option(
     "--checkpoint-path",
-    default="checkpoints/vqgan/step_000380000.ckpt",
+    default="checkpoints/vqgan-v1.pth",
 )
 @click.option("--batch-size", default=64)
 def main(