Sfoglia il codice sorgente

Use SFT medium as base model

Lengyue 1 anno fa
parent
commit
e787d0182f

+ 2 - 2
docs/en/finetune.md

@@ -148,7 +148,7 @@ After the command finishes executing, you should see the `quantized-dataset-ft.p
 Similarly, make sure you have downloaded the `LLAMA` weights. If not, run the following command:
 
 ```bash
-huggingface-cli download fishaudio/fish-speech-1 text2semantic-sft-large-v1-4k.pth --local-dir checkpoints
+huggingface-cli download fishaudio/fish-speech-1 text2semantic-sft-medium-v1-4k.pth --local-dir checkpoints
 ```
 
 Finally, you can start the fine-tuning by running the following command:
@@ -182,7 +182,7 @@ After training, you need to convert the LoRA weights to regular weights before p
 python tools/llama/merge_lora.py \
     --llama-config dual_ar_2_codebook_large \
     --lora-config r_8_alpha_16 \
-    --llama-weight checkpoints/text2semantic-sft-large-v1-4k.pth \
+    --llama-weight checkpoints/text2semantic-sft-medium-v1-4k.pth \
     --lora-weight results/text2semantic-finetune-medium-lora/checkpoints/step_000000200.ckpt \
     --output checkpoints/merged.ckpt
 ```

+ 4 - 4
docs/en/inference.md

@@ -16,7 +16,7 @@ Download the required `vqgan` and `text2semantic` models from our Hugging Face r
     
 ```bash
 huggingface-cli download fishaudio/fish-speech-1 vq-gan-group-fsq-2x1024.pth --local-dir checkpoints
-huggingface-cli download fishaudio/fish-speech-1 text2semantic-sft-large-v1-4k.pth --local-dir checkpoints
+huggingface-cli download fishaudio/fish-speech-1 text2semantic-sft-medium-v1-4k.pth --local-dir checkpoints
 ```
 
 ### 1. Generate prompt from voice:
@@ -38,7 +38,7 @@ python tools/llama/generate.py \
     --prompt-text "Your reference text" \
     --prompt-tokens "fake.npy" \
     --config-name dual_ar_2_codebook_large \
-    --checkpoint-path "checkpoints/text2semantic-sft-large-v1-4k.pth" \
+    --checkpoint-path "checkpoints/text2semantic-sft-medium-v1-4k.pth" \
     --num-samples 2 \
     --compile
 ```
@@ -69,7 +69,7 @@ We provide a HTTP API for inference. You can use the following command to start
 ```bash
 python -m tools.api \
     --listen 0.0.0.0:8000 \
-    --llama-checkpoint-path "checkpoints/text2semantic-sft-large-v1-4k.pth" \
+    --llama-checkpoint-path "checkpoints/text2semantic-sft-medium-v1-4k.pth" \
     --llama-config-name dual_ar_2_codebook_large \
     --vqgan-checkpoint-path "checkpoints/vq-gan-group-fsq-2x1024.pth"
 ```
@@ -82,7 +82,7 @@ You can start the WebUI using the following command:
 
 ```bash
 python -m tools.webui \
-    --llama-checkpoint-path "checkpoints/text2semantic-sft-large-v1-4k.pth" \
+    --llama-checkpoint-path "checkpoints/text2semantic-sft-medium-v1-4k.pth" \
     --llama-config-name dual_ar_2_codebook_large \
     --vqgan-checkpoint-path "checkpoints/vq-gan-group-fsq-2x1024.pth"
 ```

+ 3 - 3
docs/zh/finetune.md

@@ -152,13 +152,13 @@ python tools/llama/build_dataset.py \
 同样的, 请确保你已经下载了 `LLAMA` 权重, 如果没有, 请运行以下命令:
 
 ```bash
-huggingface-cli download fishaudio/fish-speech-1 text2semantic-sft-large-v1-4k.pth --local-dir checkpoints
+huggingface-cli download fishaudio/fish-speech-1 text2semantic-sft-medium-v1-4k.pth --local-dir checkpoints
 ```
 
 对于中国大陆用户, 可使用 mirror 下载.
 
 ```bash
-HF_ENDPOINT=https://hf-mirror.com huggingface-cli download fishaudio/fish-speech-1 text2semantic-sft-large-v1-4k.pth --local-dir checkpoints
+HF_ENDPOINT=https://hf-mirror.com huggingface-cli download fishaudio/fish-speech-1 text2semantic-sft-medium-v1-4k.pth --local-dir checkpoints
 ```
 
 最后, 你可以运行以下命令来启动微调:
@@ -192,7 +192,7 @@ python fish_speech/train.py --config-name text2semantic_finetune \
 python tools/llama/merge_lora.py \
     --llama-config dual_ar_2_codebook_large \
     --lora-config r_8_alpha_16 \
-    --llama-weight checkpoints/text2semantic-sft-large-v1-4k.pth \
+    --llama-weight checkpoints/text2semantic-sft-medium-v1-4k.pth \
     --lora-weight results/text2semantic-finetune-medium-lora/checkpoints/step_000000200.ckpt \
     --output checkpoints/merged.ckpt
 ```

+ 5 - 5
docs/zh/inference.md

@@ -16,12 +16,12 @@
     
 ```bash
 huggingface-cli download fishaudio/fish-speech-1 vq-gan-group-fsq-2x1024.pth --local-dir checkpoints
-huggingface-cli download fishaudio/fish-speech-1 text2semantic-sft-large-v1-4k.pth --local-dir checkpoints
+huggingface-cli download fishaudio/fish-speech-1 text2semantic-sft-medium-v1-4k.pth --local-dir checkpoints
 ```
 对于中国大陆用户,可使用mirror下载。
 ```bash
 HF_ENDPOINT=https://hf-mirror.com huggingface-cli download fishaudio/fish-speech-1 vq-gan-group-fsq-2x1024.pth --local-dir checkpoints
-HF_ENDPOINT=https://hf-mirror.com huggingface-cli download fishaudio/fish-speech-1 text2semantic-sft-large-v1-4k.pth --local-dir checkpoints
+HF_ENDPOINT=https://hf-mirror.com huggingface-cli download fishaudio/fish-speech-1 text2semantic-sft-medium-v1-4k.pth --local-dir checkpoints
 ```
 
 ### 1. 从语音生成 prompt: 
@@ -43,7 +43,7 @@ python tools/llama/generate.py \
     --prompt-text "你的参考文本" \
     --prompt-tokens "fake.npy" \
     --config-name dual_ar_2_codebook_large \
-    --checkpoint-path "checkpoints/text2semantic-sft-large-v1-4k.pth" \
+    --checkpoint-path "checkpoints/text2semantic-sft-medium-v1-4k.pth" \
     --num-samples 2 \
     --compile
 ```
@@ -74,7 +74,7 @@ python tools/vqgan/inference.py \
 ```bash
 python -m tools.api \
     --listen 0.0.0.0:8000 \
-    --llama-checkpoint-path "checkpoints/text2semantic-sft-large-v1-4k.pth" \
+    --llama-checkpoint-path "checkpoints/text2semantic-sft-medium-v1-4k.pth" \
     --llama-config-name dual_ar_2_codebook_large \
     --vqgan-checkpoint-path "checkpoints/vq-gan-group-fsq-2x1024.pth"
 
@@ -90,7 +90,7 @@ HF_ENDPOINT=https://hf-mirror.com python -m ...
 
 ```bash
 python -m tools.webui \
-    --llama-checkpoint-path "checkpoints/text2semantic-sft-large-v1-4k.pth" \
+    --llama-checkpoint-path "checkpoints/text2semantic-sft-medium-v1-4k.pth" \
     --llama-config-name dual_ar_2_codebook_large \
     --vqgan-checkpoint-path "checkpoints/vq-gan-group-fsq-2x1024.pth"
 ```

+ 1 - 1
fish_speech/configs/text2semantic_finetune.yaml

@@ -5,7 +5,7 @@ defaults:
 
 project: text2semantic_finetune_dual_ar
 max_length: 2048
-ckpt_path: checkpoints/text2semantic-sft-large-v1-4k.pth
+ckpt_path: checkpoints/text2semantic-sft-medium-v1-4k.pth
 resume_weights_only: true
 
 # Lightning Trainer

+ 1 - 1
fish_speech/webui/manage.py

@@ -470,7 +470,7 @@ def train_process(
         ckpt_path = (
             "text2semantic-pretrain-medium-2k-v1.pth"
             if llama_base_config == "dual_ar_2_codebook_medium"
-            else "text2semantic-sft-large-v1-4k.pth"
+            else "text2semantic-sft-medium-v1-4k.pth"
         )
 
         latest = list(

+ 0 - 1
pyproject.toml

@@ -34,7 +34,6 @@ dependencies = [
     "vector_quantize_pytorch>=1.14.7",
     "samplerate>=0.2.1",
     "resampy>=0.4.3",
-    "spaces>=0.26.1",
     "einx[torch]==0.2.2"
 ]
 

+ 1 - 1
tools/api.py

@@ -225,7 +225,7 @@ def parse_args():
     parser.add_argument(
         "--llama-checkpoint-path",
         type=str,
-        default="checkpoints/text2semantic-sft-large-v1-4k.pth",
+        default="checkpoints/text2semantic-sft-medium-v1-4k.pth",
     )
     parser.add_argument(
         "--llama-config-name", type=str, default="dual_ar_2_codebook_large"

+ 1 - 1
tools/llama/merge_lora.py

@@ -15,7 +15,7 @@ from fish_speech.models.text2semantic.lora_utils import (
 @click.option("--llama-config", type=str, default="dual_ar_2_codebook_large")
 @click.option("--lora-config", type=str, default="r_8_alpha_16")
 @click.option(
-    "--llama-weight", type=str, default="checkpoints/text2semantic-sft-large-v1-4k.pth"
+    "--llama-weight", type=str, default="checkpoints/text2semantic-sft-medium-v1-4k.pth"
 )
 @click.option("--lora-weight", type=str, required=True)
 @click.option("--output", type=str, required=True)

+ 1 - 22
tools/webui.py

@@ -40,21 +40,6 @@ HEADER_MD = f"""# Fish Speech
 TEXTBOX_PLACEHOLDER = i18n("Put your text here.")
 SPACE_IMPORTED = False
 
-try:
-    import spaces
-
-    GPU_DECORATOR = spaces.GPU
-    SPACE_IMPORTED = True
-except ImportError:
-
-    def GPU_DECORATOR(func):
-        @wraps(func)
-        def wrapper(*args, **kwargs):
-            return func(*args, **kwargs)
-
-        wrapper.original = func  # ref
-        return wrapper
-
 
 def build_html_error_message(error):
     return f"""
@@ -65,7 +50,6 @@ def build_html_error_message(error):
     """
 
 
-@GPU_DECORATOR
 @torch.inference_mode()
 def inference(
     text,
@@ -173,11 +157,6 @@ def inference(
 
 inference_stream = partial(inference, streaming=True)
 
-if not SPACE_IMPORTED:
-    logger.info("‘spaces’ not imported, use original")
-    inference = inference.original
-    inference_stream = partial(inference, streaming=True)
-
 
 def wav_chunk_header(sample_rate=44100, bit_depth=16, channels=1):
     buffer = io.BytesIO()
@@ -343,7 +322,7 @@ def parse_args():
     parser.add_argument(
         "--llama-checkpoint-path",
         type=Path,
-        default="checkpoints/text2semantic-sft-large-v1-4k.pth",
+        default="checkpoints/text2semantic-sft-medium-v1-4k.pth",
     )
     parser.add_argument(
         "--llama-config-name", type=str, default="dual_ar_2_codebook_large"