
Support lora tuning & update document

Lengyue · 1 year ago
commit 467b7beb56

+ 18 - 3
docs/en/finetune.md

@@ -157,9 +157,6 @@ python fish_speech/train.py --config-name text2semantic_finetune \
     model@model.model=dual_ar_2_codebook_large
 ```
 
-!!! info
-    If you want to use lora, please use `--config-name text2semantic_finetune_lora` to start fine-tuning (still under development).
-
 !!! note
     You can modify the training parameters such as `batch_size`, `gradient_accumulation_steps`, etc. to fit your GPU memory by modifying `fish_speech/configs/text2semantic_finetune.yaml`.
 
@@ -171,3 +168,21 @@ After training is complete, you can refer to the [inference](inference.md) secti
 !!! info
     By default, the model will only learn the speaker's speech patterns and not the timbre. You still need to use prompts to ensure timbre stability.
     If you want to learn the timbre, you can increase the number of training steps, but this may lead to overfitting.
+
+#### Fine-tuning with LoRA
+
+!!! note
+    LoRA can reduce the risk of overfitting, but it may also lead to underfitting on large datasets.
+
+If you want to use LoRA, please add the following parameter: `+lora@model.lora_config=r_8_alpha_16`. 
+
+After training, you need to convert the LoRA weights to regular weights before performing inference.
+
+```bash
+python tools/llama/merge_lora.py \
+    --llama-config dual_ar_2_codebook_large \
+    --lora-config r_8_alpha_16 \
+    --llama-weight checkpoints/text2semantic-large-v1-4k.pth \
+    --lora-weight results/text2semantic-finetune-medium-lora/checkpoints/step_000000200.ckpt \
+    --output checkpoints/merged.ckpt
+```
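The paragraph above says to add `+lora@model.lora_config=r_8_alpha_16` but does not show the full invocation. As a sketch (assuming the same training entry point and model override shown earlier in this file; exact flags may differ in your checkout), the combined command might look like:

```bash
python fish_speech/train.py --config-name text2semantic_finetune \
    model@model.model=dual_ar_2_codebook_large \
    +lora@model.lora_config=r_8_alpha_16
```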

+ 1 - 2
docs/en/inference.md

@@ -53,8 +53,7 @@ This command will create a `codes_N` file in the working directory, where N is a
     For GPUs that do not support bf16, you may need to use the `--half` parameter.
 
 !!! warning
-    If you are using your own fine-tuned model, please be sure to carry the `--speaker` parameter to ensure the stability of pronunciation.  
-    If you are using lora, please use `--config-name text2semantic_finetune_lora` to load the model.
+    If you are using your own fine-tuned model, please be sure to carry the `--speaker` parameter to ensure the stability of pronunciation.
 
 ### 3. Generate vocals from semantic tokens:
 ```bash

+ 17 - 3
docs/zh/finetune.md

@@ -168,9 +168,6 @@ python fish_speech/train.py --config-name text2semantic_finetune \
     model@model.model=dual_ar_2_codebook_large
 ```
 
-!!! note
-    If you want to use LoRA, please use `--config-name text2semantic_finetune_lora` to start fine-tuning (still under development).
-
 !!! note
     You can modify training parameters such as `batch_size`, `gradient_accumulation_steps`, etc. by editing `fish_speech/configs/text2semantic_finetune.yaml` to fit your GPU memory.
 
@@ -182,3 +179,20 @@ python fish_speech/train.py --config-name text2semantic_finetune \
 !!! info
     By default, the model will only learn the speaker's speech patterns and not the timbre. You still need to use prompts to ensure timbre stability.  
     If you want to learn the timbre, you can increase the number of training steps, but this may lead to overfitting.
+
+#### Fine-tuning with LoRA
+!!! note
+    LoRA can reduce the risk of overfitting, but it may also lead to underfitting on large datasets.
+
+If you want to use LoRA, please add the following parameter: `+lora@model.lora_config=r_8_alpha_16`.
+
+After training, you need to convert the LoRA weights to regular weights before performing inference.
+
+```bash
+python tools/llama/merge_lora.py \
+    --llama-config dual_ar_2_codebook_large \
+    --lora-config r_8_alpha_16 \
+    --llama-weight checkpoints/text2semantic-large-v1-4k.pth \
+    --lora-weight results/text2semantic-finetune-medium-lora/checkpoints/step_000000200.ckpt \
+    --output checkpoints/merged.ckpt
+```

+ 1 - 2
docs/zh/inference.md

@@ -58,8 +58,7 @@ python tools/llama/generate.py \
     For GPUs that do not support bf16, you may need to use the `--half` parameter.
 
 !!! warning
-    If you are using your own fine-tuned model, please be sure to carry the `--speaker` parameter to ensure the stability of pronunciation.  
-    If you are using LoRA, please use `--config-name text2semantic_finetune_lora` to load the model.
+    If you are using your own fine-tuned model, please be sure to carry the `--speaker` parameter to ensure the stability of pronunciation.
 
### 3. Generate vocals from semantic tokens:
 ```bash

+ 3 - 0
fish_speech/configs/lora/r_8_alpha_16.yaml

@@ -0,0 +1,3 @@
+_target_: fish_speech.models.text2semantic.lora_utils.LoraConfig
+r: 8
+lora_alpha: 16

+ 4 - 2
fish_speech/configs/text2semantic_finetune.yaml

@@ -31,7 +31,8 @@ train_dataset:
   tokenizer: ${tokenizer}
   max_length: ${max_length}
   num_codebooks: ${model.model.config.num_codebooks}
-  use_speaker: false
+  use_speaker: 0.5
+  interactive_prob: 0.7
 
 val_dataset:
   _target_: fish_speech.datasets.text.AutoAugTextDataset
@@ -40,7 +41,8 @@ val_dataset:
   tokenizer: ${tokenizer}
   max_length: ${max_length}
   num_codebooks: ${model.model.config.num_codebooks}
-  use_speaker: false
+  use_speaker: 0.5
+  interactive_prob: 0.7
 
 data:
   _target_: fish_speech.datasets.text.TextDataModule
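`use_speaker` changes here from a boolean to `0.5`, and `interactive_prob: 0.7` is added; the names suggest both floats are now per-sample probabilities. Assuming that interpretation (not verified against `AutoAugTextDataset` itself), the sampling logic might look like:

```python
import random

def sample_flags(use_speaker=0.5, interactive_prob=0.7, rng=None):
    """Draw per-sample booleans from the configured probabilities.
    Assumed semantics; the real dataset code may differ."""
    rng = rng or random.Random()
    return {
        "use_speaker": rng.random() < use_speaker,
        "interactive": rng.random() < interactive_prob,
    }

# Over many samples, roughly half include speaker conditioning.
rng = random.Random(0)
flags = [sample_flags(rng=rng) for _ in range(10_000)]
speaker_rate = sum(f["use_speaker"] for f in flags) / len(flags)
print(round(speaker_rate, 2))
```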

+ 0 - 13
fish_speech/configs/text2semantic_finetune_lora.yaml

@@ -1,13 +0,0 @@
-defaults:
-  - text2semantic_finetune
-  - _self_
-
-project: text2semantic_finetune_dual_ar_lora
-
-# Model Configuration
-model:
-  save_lora_only: true
-  lora_config:
-    _target_: fish_speech.models.text2semantic.lit_module.LoraConfig
-    r: 8
-    lora_alpha: 16

+ 0 - 3
fish_speech/configs/text2semantic_sft.yaml

@@ -17,9 +17,6 @@ trainer:
   precision: bf16-true
   limit_val_batches: 10
   val_check_interval: 500
-  strategy:
-    _target_: lightning.pytorch.strategies.DDPStrategy
-    process_group_backend: nccl  # This should be override when training on windows
 
 # Dataset Configuration
 tokenizer:

+ 8 - 65
fish_speech/models/text2semantic/lit_module.py

@@ -1,25 +1,17 @@
-from dataclasses import dataclass
 from typing import Any, Optional
 
 import lightning as L
-import loralib as lora
 import torch
 import torch.nn.functional as F
 from lightning.pytorch.utilities.types import OptimizerLRScheduler
 
 import fish_speech.utils as utils
 from fish_speech.models.text2semantic.llama import NaiveTransformer
+from fish_speech.models.text2semantic.lora_utils import LoraConfig, setup_lora
 
 log = utils.RankedLogger(__name__, rank_zero_only=True)
 
 
-@dataclass
-class LoraConfig:
-    r: int
-    lora_alpha: float
-    lora_dropout: float = 0.0
-
-
 class TextToSemantic(L.LightningModule):
     def __init__(
         self,
@@ -27,7 +19,6 @@ class TextToSemantic(L.LightningModule):
         optimizer: Any,
         lr_scheduler: Any,
         lora_config: Optional[LoraConfig] = None,
-        save_lora_only: bool = False,
         use_dpo: bool = False,
         dpo_beta: float = 0.2,
     ):
@@ -37,70 +28,17 @@ class TextToSemantic(L.LightningModule):
         self.optimizer_builder = optimizer
         self.lr_scheduler_builder = lr_scheduler
         self.lora_config = lora_config
-        self.save_lora_only = save_lora_only
         self.use_dpo = use_dpo  # We don't support reference model yet
         self.dpo_beta = dpo_beta
 
         if self.lora_config is not None:
-            self.setup_lora()
-
-    def setup_lora(self):
-        # Replace the embedding layer with a LoRA layer
-        self.model.embeddings = lora.Embedding(
-            num_embeddings=self.model.embeddings.num_embeddings,
-            embedding_dim=self.model.embeddings.embedding_dim,
-            padding_idx=self.model.embeddings.padding_idx,
-            r=self.lora_config.r,
-            lora_alpha=self.lora_config.lora_alpha,
-        )
-
-        # Replace output layer with a LoRA layer
-        linears = [(self.model, "output")]
-
-        # Replace all linear layers with LoRA layers
-        for layer in self.model.layers:
-            linears.extend([(layer.attention, "wqkv"), (layer.attention, "wo")])
-            linears.extend(
-                [
-                    (layer.feed_forward, "w1"),
-                    (layer.feed_forward, "w2"),
-                    (layer.feed_forward, "w3"),
-                ]
-            )
-
-        if hasattr(self.model, "fast_layers"):
-            # Dual-AR model
-            linears.extend([(self.model, "fast_output")])
-
-            for layer in self.model.fast_layers:
-                linears.extend([(layer.attention, "wqkv"), (layer.attention, "wo")])
-                linears.extend(
-                    [
-                        (layer.feed_forward, "w1"),
-                        (layer.feed_forward, "w2"),
-                        (layer.feed_forward, "w3"),
-                    ]
-                )
-
-        for module, layer in linears:
-            updated_linear = lora.Linear(
-                in_features=getattr(module, layer).in_features,
-                out_features=getattr(module, layer).out_features,
-                bias=getattr(module, layer).bias,
-                r=self.lora_config.r,
-                lora_alpha=self.lora_config.lora_alpha,
-                lora_dropout=self.lora_config.lora_dropout,
-            )
-            setattr(module, layer, updated_linear)
-
-        # Mark only the LoRA layers as trainable
-        lora.mark_only_lora_as_trainable(self.model, bias="lora_only")
+            setup_lora(self.model, self.lora_config)
 
     def forward(self, x):
         return self.model(x)
 
     def on_save_checkpoint(self, checkpoint):
-        if self.lora_config is None or self.save_lora_only is False:
+        if self.lora_config is None:
             return
 
         # Save only LoRA parameters
@@ -178,6 +116,11 @@ class TextToSemantic(L.LightningModule):
     def _step(self, batch, batch_idx, stage: str):
         is_train = stage == "train"
 
+        if is_train:
+            # Key step to make LoRA work:
+            # otherwise the parameters stay merged, which leads to incorrect gradients
+            self.model.train()
+
         # Do positive and negative samples in the same batch to speed up training
         labels = batch["labels"]
         outputs = self.model(

+ 84 - 0
fish_speech/models/text2semantic/lora_utils.py

@@ -0,0 +1,84 @@
+from dataclasses import dataclass
+
+import loralib as lora
+
+
+@dataclass
+class LoraConfig:
+    r: int
+    lora_alpha: float
+    lora_dropout: float = 0.0
+
+
+def setup_lora(model, lora_config):
+    # Replace the embedding layer with a LoRA layer
+    model.embeddings = lora.Embedding(
+        num_embeddings=model.embeddings.num_embeddings,
+        embedding_dim=model.embeddings.embedding_dim,
+        padding_idx=model.embeddings.padding_idx,
+        r=lora_config.r,
+        lora_alpha=lora_config.lora_alpha,
+    )
+
+    # Replace output layer with a LoRA layer
+    linears = [(model, "output")]
+
+    # Replace all linear layers with LoRA layers
+    for layer in model.layers:
+        linears.extend([(layer.attention, "wqkv"), (layer.attention, "wo")])
+        linears.extend(
+            [
+                (layer.feed_forward, "w1"),
+                (layer.feed_forward, "w2"),
+                (layer.feed_forward, "w3"),
+            ]
+        )
+
+    if hasattr(model, "fast_layers"):
+        model.fast_embeddings = lora.Embedding(
+            num_embeddings=model.fast_embeddings.num_embeddings,
+            embedding_dim=model.fast_embeddings.embedding_dim,
+            padding_idx=model.fast_embeddings.padding_idx,
+            r=lora_config.r,
+            lora_alpha=lora_config.lora_alpha,
+        )
+
+        # Dual-AR model
+        linears.append((model, "fast_output"))
+
+        for layer in model.fast_layers:
+            linears.extend([(layer.attention, "wqkv"), (layer.attention, "wo")])
+            linears.extend(
+                [
+                    (layer.feed_forward, "w1"),
+                    (layer.feed_forward, "w2"),
+                    (layer.feed_forward, "w3"),
+                ]
+            )
+
+    for module, layer in linears:
+        updated_linear = lora.Linear(
+            in_features=getattr(module, layer).in_features,
+            out_features=getattr(module, layer).out_features,
+            bias=getattr(module, layer).bias,
+            r=lora_config.r,
+            lora_alpha=lora_config.lora_alpha,
+            lora_dropout=lora_config.lora_dropout,
+        )
+        setattr(module, layer, updated_linear)
+
+    # Mark only the LoRA layers as trainable
+    lora.mark_only_lora_as_trainable(model, bias="none")
+
+
+def get_merged_state_dict(model):
+    # Switching to eval mode makes loralib merge the LoRA weights into the base weights
+    model.eval()
+
+    # Then we need to remove the LoRA parameters from the state dict
+    state_dict = model.state_dict()
+    for name in list(state_dict.keys()):
+        if "lora" in name:
+            state_dict.pop(name)
+
+    return state_dict
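`get_merged_state_dict` relies on a loralib behavior: putting a LoRA layer in eval mode folds the low-rank update into the base weight, so the plain state dict already contains merged weights and the `lora_*` tensors can simply be dropped. A simplified stand-in (plain Python scalars, not loralib) illustrating that train/eval merge dance:

```python
class ToyLoraLinear:
    """Simplified mimic of loralib's merge-on-eval behavior, with
    scalar 'weights' standing in for tensors."""

    def __init__(self, w, a, b, scale):
        self.w, self.a, self.b, self.scale = w, a, b, scale
        self.merged = False

    def train(self):
        if self.merged:  # un-merge so gradients flow only to a/b
            self.w -= self.scale * self.b * self.a
            self.merged = False

    def eval(self):
        if not self.merged:  # fold the low-rank update into w
            self.w += self.scale * self.b * self.a
            self.merged = True

    def state_dict(self):
        return {"weight": self.w, "lora_a": self.a, "lora_b": self.b}

layer = ToyLoraLinear(w=1.0, a=2.0, b=3.0, scale=2.0)
layer.eval()  # merge, as get_merged_state_dict does via model.eval()
sd = {k: v for k, v in layer.state_dict().items() if "lora" not in k}
print(sd)  # -> {'weight': 13.0}
```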

+ 79 - 0
tools/llama/merge_lora.py

@@ -0,0 +1,79 @@
+import click
+import hydra
+import torch
+from hydra import compose, initialize
+from hydra.utils import instantiate
+from loguru import logger
+
+from fish_speech.models.text2semantic.lora_utils import (
+    get_merged_state_dict,
+    setup_lora,
+)
+
+
+@click.command()
+@click.option("--llama-config", type=str, default="dual_ar_2_codebook_large")
+@click.option("--lora-config", type=str, default="r_8_alpha_16")
+@click.option(
+    "--llama-weight", type=str, default="checkpoints/text2semantic-large-v1-4k.pth"
+)
+@click.option("--lora-weight", type=str, required=True)
+@click.option("--output", type=str, required=True)
+def merge(llama_config, lora_config, llama_weight, lora_weight, output):
+    logger.info(
+        f"Merging {llama_weight} and {lora_weight} into {output} with configs {llama_config} and {lora_config}"
+    )
+
+    hydra.core.global_hydra.GlobalHydra.instance().clear()
+    with initialize(version_base="1.3", config_path="../../fish_speech/configs/model"):
+        # The max_seq_len here doesn't matter.
+        cfg = compose(config_name=llama_config, overrides=["config.max_seq_len=2048"])
+
+    llama_model = instantiate(cfg)
+    logger.info(f"Loaded llama model with config {llama_config}")
+
+    hydra.core.global_hydra.GlobalHydra.instance().clear()
+    with initialize(version_base="1.3", config_path="../../fish_speech/configs/lora"):
+        cfg = compose(config_name=lora_config)
+
+    lora_config = instantiate(cfg)
+    logger.info(f"Loaded lora model with config {lora_config}")
+
+    setup_lora(llama_model, lora_config)
+    logger.info(f"Merged model setup complete")
+
+    llama_state_dict = torch.load(llama_weight, map_location="cpu")
+    lora_state_dict = torch.load(lora_weight, map_location="cpu")
+
+    if "state_dict" in llama_state_dict:
+        llama_state_dict = llama_state_dict["state_dict"]
+
+    if "state_dict" in lora_state_dict:
+        lora_state_dict = lora_state_dict["state_dict"]
+
+    # Remove the "model." prefix added by the Lightning wrapper
+    llama_state_dict = {
+        k.replace("model.", ""): v
+        for k, v in llama_state_dict.items()
+        if k.startswith("model.")
+    }
+    lora_state_dict = {
+        k.replace("model.", ""): v
+        for k, v in lora_state_dict.items()
+        if k.startswith("model.")
+    }
+
+    logger.info(f"Found {len(llama_state_dict)} keys in llama model")
+    logger.info(f"Found {len(lora_state_dict)} keys in lora model")
+
+    merged_state_dict = llama_state_dict | lora_state_dict
+    llama_model.load_state_dict(merged_state_dict, strict=True)
+    logger.info(f"Merged model loaded")
+
+    state_dict = get_merged_state_dict(llama_model)
+    torch.save(state_dict, output)
+    logger.info(f"Merged model saved to {output}")
+
+
+if __name__ == "__main__":
+    merge()
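The merge at `llama_state_dict | lora_state_dict` uses the dict union operator (Python 3.9+), in which keys from the right-hand operand win on conflicts. In this commit the LoRA checkpoint holds only `lora_*` tensors (see `on_save_checkpoint`), so the union effectively adds the adapter weights alongside the base ones. A quick illustration with hypothetical keys:

```python
# Hypothetical state-dict keys, illustrating dict-union precedence.
base = {"output.weight": "base", "norm.weight": "base"}
lora = {"output.lora_A": "lora", "output.weight": "lora"}

merged = base | lora  # right-hand side wins on key collisions (Python 3.9+)
print(merged)
# -> {'output.weight': 'lora', 'norm.weight': 'base', 'output.lora_A': 'lora'}
```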