Browse Source

v1.5 (#696)

* fix e2e_webui

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Agent: Streaming audio

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix text streaming

* [feature]: add tiktoken tokenizer for v1.5

* v1.5 vq

* update docs

* [feature]: add agent inference

* [feature]: add decoder for API agent inference

* [fix]: apply lengyue's fix for inference bugs

* [fix]: fix inference errors with prompt audio

* [fix]: remove some unused tokens

* [fix]: fix some prompt bugs

* [fix]: stop the generated audio from speaking the system prompt aloud

* remove unused

* revert splitter

* remove unused

* remove unused ignore

* remove root conversation

* fix llama

* disable visualization

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: whaledolphin <whaledolphin666@gmail.com>
Co-authored-by: PoTaTo <1228427403@qq.com>
Co-authored-by: Whale and Dolphin <70465000+Whale-Dolphin@users.noreply.github.com>
spicysama 1 year ago
parent
commit
b951de3b72

+ 2 - 2
.pre-commit-config.yaml

@@ -20,6 +20,6 @@ repos:
       - id: check-yaml
       - id: check-json
       - id: mixed-line-ending
-        args: ['--fix=lf']
+        args: ["--fix=lf"]
       - id: check-added-large-files
-        args: ['--maxkb=5000']
+        args: ["--maxkb=5000"]

+ 5 - 5
docs/en/finetune.md

@@ -39,7 +39,7 @@ You need to convert your dataset into the above format and place it under `data`
 Make sure you have downloaded the VQGAN weights. If not, run the following command:
 
 ```bash
-huggingface-cli download fishaudio/fish-speech-1.4 --local-dir checkpoints/fish-speech-1.4
+huggingface-cli download fishaudio/fish-speech-1.5 --local-dir checkpoints/fish-speech-1.5
 ```
 
 You can then run the following command to extract semantic tokens:
@@ -48,7 +48,7 @@ You can then run the following command to extract semantic tokens:
 python tools/vqgan/extract_vq.py data \
     --num-workers 1 --batch-size 16 \
     --config-name "firefly_gan_vq" \
-    --checkpoint-path "checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth"
+    --checkpoint-path "checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth"
 ```
 
 !!! note
@@ -92,7 +92,7 @@ After the command finishes executing, you should see the `quantized-dataset-ft.p
 Similarly, make sure you have downloaded the `LLAMA` weights. If not, run the following command:
 
 ```bash
-huggingface-cli download fishaudio/fish-speech-1.4 --local-dir checkpoints/fish-speech-1.4
+huggingface-cli download fishaudio/fish-speech-1.5 --local-dir checkpoints/fish-speech-1.5
 ```
 
 Finally, you can start the fine-tuning by running the following command:
@@ -120,9 +120,9 @@ After training, you need to convert the LoRA weights to regular weights before p
 ```bash
 python tools/llama/merge_lora.py \
 	--lora-config r_8_alpha_16 \
-	--base-weight checkpoints/fish-speech-1.4 \
+	--base-weight checkpoints/fish-speech-1.5 \
 	--lora-weight results/$project/checkpoints/step_000000010.ckpt \
-	--output checkpoints/fish-speech-1.4-yth-lora/
+	--output checkpoints/fish-speech-1.5-yth-lora/
 ```
 !!! note
     You may also try other checkpoints. We suggest using the earliest checkpoint that meets your requirements, as they often perform better on out-of-distribution (OOD) data.

+ 1 - 1
docs/en/index.md

@@ -179,7 +179,7 @@ pip install -e .[stable]
     Make sure you are in the terminal inside the docker container, then download the required `vqgan` and `llama` models from our huggingface repository.
 
     ```bash
-    huggingface-cli download fishaudio/fish-speech-1.4 --local-dir checkpoints/fish-speech-1.4
+    huggingface-cli download fishaudio/fish-speech-1.5 --local-dir checkpoints/fish-speech-1.5
     ```
 
 4. Configure environment variables and access WebUI

+ 8 - 8
docs/en/inference.md

@@ -15,7 +15,7 @@ Inference support command line, HTTP API and web UI.
 Download the required `vqgan` and `llama` models from our Hugging Face repository.
 
 ```bash
-huggingface-cli download fishaudio/fish-speech-1.4 --local-dir checkpoints/fish-speech-1.4
+huggingface-cli download fishaudio/fish-speech-1.5 --local-dir checkpoints/fish-speech-1.5
 ```
 
 ### 1. Generate prompt from voice:
@@ -26,7 +26,7 @@ huggingface-cli download fishaudio/fish-speech-1.4 --local-dir checkpoints/fish-
 ```bash
 python tools/vqgan/inference.py \
     -i "paimon.wav" \
-    --checkpoint-path "checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth"
+    --checkpoint-path "checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth"
 ```
 
 You should get a `fake.npy` file.
@@ -38,7 +38,7 @@ python tools/llama/generate.py \
     --text "The text you want to convert" \
     --prompt-text "Your reference text" \
     --prompt-tokens "fake.npy" \
-    --checkpoint-path "checkpoints/fish-speech-1.4" \
+    --checkpoint-path "checkpoints/fish-speech-1.5" \
     --num-samples 2 \
     --compile
 ```
@@ -59,7 +59,7 @@ This command will create a `codes_N` file in the working directory, where N is a
 ```bash
 python tools/vqgan/inference.py \
     -i "codes_0.npy" \
-    --checkpoint-path "checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth"
+    --checkpoint-path "checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth"
 ```
 
 ## HTTP API Inference
@@ -69,8 +69,8 @@ We provide a HTTP API for inference. You can use the following command to start
 ```bash
 python -m tools.api \
     --listen 0.0.0.0:8080 \
-    --llama-checkpoint-path "checkpoints/fish-speech-1.4" \
-    --decoder-checkpoint-path "checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" \
+    --llama-checkpoint-path "checkpoints/fish-speech-1.5" \
+    --decoder-checkpoint-path "checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" \
     --decoder-config-name firefly_gan_vq
 ```
 
@@ -120,8 +120,8 @@ You can start the WebUI using the following command:
 
 ```bash
 python -m tools.webui \
-    --llama-checkpoint-path "checkpoints/fish-speech-1.4" \
-    --decoder-checkpoint-path "checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" \
+    --llama-checkpoint-path "checkpoints/fish-speech-1.5" \
+    --decoder-checkpoint-path "checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" \
     --decoder-config-name firefly_gan_vq
 ```
 > If you want to speed up inference, you can add the `--compile` parameter.

+ 5 - 5
docs/ja/finetune.md

@@ -39,7 +39,7 @@
 VQGANの重みをダウンロードしたことを確認してください。まだダウンロードしていない場合は、次のコマンドを実行してください。
 
 ```bash
-huggingface-cli download fishaudio/fish-speech-1.4 --local-dir checkpoints/fish-speech-1.4
+huggingface-cli download fishaudio/fish-speech-1.5 --local-dir checkpoints/fish-speech-1.5
 ```
 
 次に、次のコマンドを実行してセマンティックトークンを抽出できます。
@@ -48,7 +48,7 @@ huggingface-cli download fishaudio/fish-speech-1.4 --local-dir checkpoints/fish-
 python tools/vqgan/extract_vq.py data \
     --num-workers 1 --batch-size 16 \
     --config-name "firefly_gan_vq" \
-    --checkpoint-path "checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth"
+    --checkpoint-path "checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth"
 ```
 
 !!! note
@@ -92,7 +92,7 @@ python tools/llama/build_dataset.py \
 同様に、`LLAMA`の重みをダウンロードしたことを確認してください。まだダウンロードしていない場合は、次のコマンドを実行してください。
 
 ```bash
-huggingface-cli download fishaudio/fish-speech-1.4 --local-dir checkpoints/fish-speech-1.4
+huggingface-cli download fishaudio/fish-speech-1.5 --local-dir checkpoints/fish-speech-1.5
 ```
 
 最後に、次のコマンドを実行して微調整を開始できます。
@@ -120,9 +120,9 @@ python fish_speech/train.py --config-name text2semantic_finetune \
 ```bash
 python tools/llama/merge_lora.py \
 	--lora-config r_8_alpha_16 \
-	--base-weight checkpoints/fish-speech-1.4 \
+	--base-weight checkpoints/fish-speech-1.5 \
 	--lora-weight results/$project/checkpoints/step_000000010.ckpt \
-	--output checkpoints/fish-speech-1.4-yth-lora/
+	--output checkpoints/fish-speech-1.5-yth-lora/
 ```
 !!! note
     他のチェックポイントを試すこともできます。要件を満たす最も早いチェックポイントを使用することをお勧めします。これらは通常、分布外(OOD)データでより良いパフォーマンスを発揮します。

+ 1 - 1
docs/ja/index.md

@@ -178,7 +178,7 @@ pip install -e .[stable]
     Docker コンテナ内のターミナルにいることを確認し、huggingface リポジトリから必要な `vqgan` と `llama` モデルをダウンロードします。
 
     ```bash
-    huggingface-cli download fishaudio/fish-speech-1.4 --local-dir checkpoints/fish-speech-1.4
+    huggingface-cli download fishaudio/fish-speech-1.5 --local-dir checkpoints/fish-speech-1.5
     ```
 
 4. 環境変数の設定と WebUI へのアクセス

+ 8 - 8
docs/ja/inference.md

@@ -15,7 +15,7 @@
 必要な`vqgan`および`llama`モデルを Hugging Face リポジトリからダウンロードします。
 
 ```bash
-huggingface-cli download fishaudio/fish-speech-1.4 --local-dir checkpoints/fish-speech-1.4
+huggingface-cli download fishaudio/fish-speech-1.5 --local-dir checkpoints/fish-speech-1.5
 ```
 
 ### 1. 音声からプロンプトを生成する:
@@ -26,7 +26,7 @@ huggingface-cli download fishaudio/fish-speech-1.4 --local-dir checkpoints/fish-
 ```bash
 python tools/vqgan/inference.py \
     -i "paimon.wav" \
-    --checkpoint-path "checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth"
+    --checkpoint-path "checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth"
 ```
 
 `fake.npy`ファイルが生成されるはずです。
@@ -38,7 +38,7 @@ python tools/llama/generate.py \
     --text "変換したいテキスト" \
     --prompt-text "参照テキスト" \
     --prompt-tokens "fake.npy" \
-    --checkpoint-path "checkpoints/fish-speech-1.4" \
+    --checkpoint-path "checkpoints/fish-speech-1.5" \
     --num-samples 2 \
     --compile
 ```
@@ -59,7 +59,7 @@ python tools/llama/generate.py \
 ```bash
 python tools/vqgan/inference.py \
     -i "codes_0.npy" \
-    --checkpoint-path "checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth"
+    --checkpoint-path "checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth"
 ```
 
 ## HTTP API 推論
@@ -69,8 +69,8 @@ python tools/vqgan/inference.py \
 ```bash
 python -m tools.api \
     --listen 0.0.0.0:8080 \
-    --llama-checkpoint-path "checkpoints/fish-speech-1.4" \
-    --decoder-checkpoint-path "checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" \
+    --llama-checkpoint-path "checkpoints/fish-speech-1.5" \
+    --decoder-checkpoint-path "checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" \
     --decoder-config-name firefly_gan_vq
 ```
 
@@ -99,8 +99,8 @@ python -m tools.post_api \
 
 ```bash
 python -m tools.webui \
-    --llama-checkpoint-path "checkpoints/fish-speech-1.4" \
-    --decoder-checkpoint-path "checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" \
+    --llama-checkpoint-path "checkpoints/fish-speech-1.5" \
+    --decoder-checkpoint-path "checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" \
     --decoder-config-name firefly_gan_vq
 ```
 > 推論を高速化したい場合は、`--compile` パラメータを追加できます。

+ 5 - 5
docs/ko/finetune.md

@@ -38,7 +38,7 @@
 VQGAN 가중치를 다운로드했는지 확인하세요. 다운로드하지 않았다면 아래 명령어를 실행하세요:
 
 ```bash
-huggingface-cli download fishaudio/fish-speech-1.4 --local-dir checkpoints/fish-speech-1.4
+huggingface-cli download fishaudio/fish-speech-1.5 --local-dir checkpoints/fish-speech-1.5
 ```
 
 이후 시맨틱 토큰을 추출하기 위해 아래 명령어를 실행하세요:
@@ -47,7 +47,7 @@ huggingface-cli download fishaudio/fish-speech-1.4 --local-dir checkpoints/fish-
 python tools/vqgan/extract_vq.py data \
     --num-workers 1 --batch-size 16 \
     --config-name "firefly_gan_vq" \
-    --checkpoint-path "checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth"
+    --checkpoint-path "checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth"
 ```
 
 !!! note
@@ -91,7 +91,7 @@ python tools/llama/build_dataset.py \
 마찬가지로, `LLAMA` 가중치를 다운로드했는지 확인하세요. 다운로드하지 않았다면 아래 명령어를 실행하세요:
 
 ```bash
-huggingface-cli download fishaudio/fish-speech-1.4 --local-dir checkpoints/fish-speech-1.4
+huggingface-cli download fishaudio/fish-speech-1.5 --local-dir checkpoints/fish-speech-1.5
 ```
 
 마지막으로, 아래 명령어를 실행하여 파인튜닝을 시작할 수 있습니다:
@@ -119,9 +119,9 @@ python fish_speech/train.py --config-name text2semantic_finetune \
 ```bash
 python tools/llama/merge_lora.py \
 	--lora-config r_8_alpha_16 \
-	--base-weight checkpoints/fish-speech-1.4 \
+	--base-weight checkpoints/fish-speech-1.5 \
 	--lora-weight results/$project/checkpoints/step_000000010.ckpt \
-	--output checkpoints/fish-speech-1.4-yth-lora/
+	--output checkpoints/fish-speech-1.5-yth-lora/
 ```
 
 !!! note

+ 1 - 1
docs/ko/index.md

@@ -179,7 +179,7 @@ pip install -e .[stable]
     Docker 컨테이너 내부의 터미널에서 아래 명령어를 사용하여 필요한 `vqgan` 및 `llama` 모델을 Huggingface 리포지토리에서 다운로드합니다.
 
     ```bash
-    huggingface-cli download fishaudio/fish-speech-1.4 --local-dir checkpoints/fish-speech-1.4
+    huggingface-cli download fishaudio/fish-speech-1.5 --local-dir checkpoints/fish-speech-1.5
     ```
 
 4. 환경 변수 설정 및 WebUI 접근

+ 8 - 8
docs/ko/inference.md

@@ -15,7 +15,7 @@
 필요한 `vqgan` 및 `llama` 모델을 Hugging Face 리포지토리에서 다운로드하세요.
 
 ```bash
-huggingface-cli download fishaudio/fish-speech-1.4 --local-dir checkpoints/fish-speech-1.4
+huggingface-cli download fishaudio/fish-speech-1.5 --local-dir checkpoints/fish-speech-1.5
 ```
 
 ### 1. 음성에서 프롬프트 생성:
@@ -26,7 +26,7 @@ huggingface-cli download fishaudio/fish-speech-1.4 --local-dir checkpoints/fish-
 ```bash
 python tools/vqgan/inference.py \
     -i "paimon.wav" \
-    --checkpoint-path "checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth"
+    --checkpoint-path "checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth"
 ```
 
 이 명령을 실행하면 `fake.npy` 파일을 얻게 됩니다.
@@ -38,7 +38,7 @@ python tools/llama/generate.py \
     --text "변환할 텍스트" \
     --prompt-text "참고할 텍스트" \
     --prompt-tokens "fake.npy" \
-    --checkpoint-path "checkpoints/fish-speech-1.4" \
+    --checkpoint-path "checkpoints/fish-speech-1.5" \
     --num-samples 2 \
     --compile
 ```
@@ -59,7 +59,7 @@ python tools/llama/generate.py \
 ```bash
 python tools/vqgan/inference.py \
     -i "codes_0.npy" \
-    --checkpoint-path "checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth"
+    --checkpoint-path "checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth"
 ```
 
 ## HTTP API 추론
@@ -69,8 +69,8 @@ python tools/vqgan/inference.py \
 ```bash
 python -m tools.api \
     --listen 0.0.0.0:8080 \
-    --llama-checkpoint-path "checkpoints/fish-speech-1.4" \
-    --decoder-checkpoint-path "checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" \
+    --llama-checkpoint-path "checkpoints/fish-speech-1.5" \
+    --decoder-checkpoint-path "checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" \
     --decoder-config-name firefly_gan_vq
 ```
 
@@ -118,8 +118,8 @@ python -m tools.post_api \
 
 ```bash
 python -m tools.webui \
-    --llama-checkpoint-path "checkpoints/fish-speech-1.4" \
-    --decoder-checkpoint-path "checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" \
+    --llama-checkpoint-path "checkpoints/fish-speech-1.5" \
+    --decoder-checkpoint-path "checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" \
     --decoder-config-name firefly_gan_vq
 ```
 

+ 5 - 5
docs/pt/finetune.md

@@ -39,7 +39,7 @@ Você precisa converter seu conjunto de dados para o formato acima e colocá-lo
 Certifique-se de ter baixado os pesos do VQGAN. Se não, execute o seguinte comando:
 
 ```bash
-huggingface-cli download fishaudio/fish-speech-1.4 --local-dir checkpoints/fish-speech-1.4
+huggingface-cli download fishaudio/fish-speech-1.5 --local-dir checkpoints/fish-speech-1.5
 ```
 
 Em seguida, você pode executar o seguinte comando para extrair os tokens semânticos:
@@ -48,7 +48,7 @@ Em seguida, você pode executar o seguinte comando para extrair os tokens semân
 python tools/vqgan/extract_vq.py data \
     --num-workers 1 --batch-size 16 \
     --config-name "firefly_gan_vq" \
-    --checkpoint-path "checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth"
+    --checkpoint-path "checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth"
 ```
 
 !!! note
@@ -92,7 +92,7 @@ Após executar o comando, você deverá ver o arquivo `quantized-dataset-ft.prot
 Da mesma forma, certifique-se de ter baixado os pesos do `LLAMA`. Se não, execute o seguinte comando:
 
 ```bash
-huggingface-cli download fishaudio/fish-speech-1.4 --local-dir checkpoints/fish-speech-1.4
+huggingface-cli download fishaudio/fish-speech-1.5 --local-dir checkpoints/fish-speech-1.5
 ```
 
 E então, execute o seguinte comando para iniciar o ajuste fino:
@@ -120,9 +120,9 @@ Após o treinamento, é preciso converter os pesos do LoRA em pesos regulares an
 ```bash
 python tools/llama/merge_lora.py \
     --lora-config r_8_alpha_16 \
-    --base-weight checkpoints/fish-speech-1.4 \
+    --base-weight checkpoints/fish-speech-1.5 \
     --lora-weight results/$project/checkpoints/step_000000010.ckpt \
-    --output checkpoints/fish-speech-1.4-yth-lora/
+    --output checkpoints/fish-speech-1.5-yth-lora/
 ```
 !!! note
     É possível também tentar outros checkpoints. Sugerimos usar o checkpoint que melhor atenda aos seus requisitos, pois eles geralmente têm um desempenho melhor em dados fora da distribuição (OOD).

+ 1 - 1
docs/pt/index.md

@@ -175,7 +175,7 @@ pip install -e .[stable]
     Certifique-se de estar no terminal do contêiner Docker e, em seguida, baixe os modelos necessários `vqgan` e `llama` do nosso repositório HuggingFace.
 
     ```bash
-    huggingface-cli download fishaudio/fish-speech-1.4 --local-dir checkpoints/fish-speech-1.4
+    huggingface-cli download fishaudio/fish-speech-1.5 --local-dir checkpoints/fish-speech-1.5
     ```
 
 4. Configure as variáveis de ambiente e acesse a WebUI

+ 7 - 7
docs/pt/inference.md

@@ -15,7 +15,7 @@ Suporte para inferência por linha de comando, API HTTP e interface web (WebUI).
 Baixe os modelos `vqgan` e `llama` necessários do nosso repositório Hugging Face.
 
 ```bash
-huggingface-cli download fishaudio/fish-speech-1.4 --local-dir checkpoints/fish-speech-1.4
+huggingface-cli download fishaudio/fish-speech-1.5 --local-dir checkpoints/fish-speech-1.5
 ```
 
 ### 1. Gerar prompt a partir da voz:
@@ -26,7 +26,7 @@ huggingface-cli download fishaudio/fish-speech-1.4 --local-dir checkpoints/fish-
 ```bash
 python tools/vqgan/inference.py \
     -i "paimon.wav" \
-    --checkpoint-path "checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth"
+    --checkpoint-path "checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth"
 ```
 
 Você deverá obter um arquivo `fake.npy`.
@@ -38,7 +38,7 @@ python tools/llama/generate.py \
     --text "O texto que você deseja converter" \
     --prompt-text "Seu texto de referência" \
     --prompt-tokens "fake.npy" \
-    --checkpoint-path "checkpoints/fish-speech-1.4" \
+    --checkpoint-path "checkpoints/fish-speech-1.5" \
     --num-samples 2 \
     --compile
 ```
@@ -59,7 +59,7 @@ Este comando criará um arquivo `codes_N` no diretório de trabalho, onde N é u
 ```bash
 python tools/vqgan/inference.py \
     -i "codes_0.npy" \
-    --checkpoint-path "checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth"
+    --checkpoint-path "checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth"
 ```
 
 ## Inferência por API HTTP
@@ -69,7 +69,7 @@ Fornecemos uma API HTTP para inferência. O seguinte comando pode ser usado para
 ```bash
 python -m tools.api \
     --listen 0.0.0.0:8080 \
-    --llama-checkpoint-path "checkpoints/fish-speech-1.4" \
+    --llama-checkpoint-path "checkpoints/fish-speech-1.5" \
     --decoder-checkpoint-path "checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" \
     --decoder-config-name firefly_gan_vq
 ```
@@ -99,8 +99,8 @@ Para iniciar a WebUI de Inferência execute o seguinte comando:
 
 ```bash
 python -m tools.webui \
-    --llama-checkpoint-path "checkpoints/fish-speech-1.4" \
-    --decoder-checkpoint-path "checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" \
+    --llama-checkpoint-path "checkpoints/fish-speech-1.5" \
+    --decoder-checkpoint-path "checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" \
     --decoder-config-name firefly_gan_vq
 ```
 > Para acelerar a inferência, adicione o parâmetro `--compile`.

+ 7 - 7
docs/zh/finetune.md

@@ -37,13 +37,13 @@
 确保你已经下载了 vqgan 权重, 如果没有, 请运行以下命令:
 
 ```bash
-huggingface-cli download fishaudio/fish-speech-1.4 --local-dir checkpoints/fish-speech-1.4
+huggingface-cli download fishaudio/fish-speech-1.5 --local-dir checkpoints/fish-speech-1.5
 ```
 
 对于中国大陆用户, 可使用 mirror 下载.
 
 ```bash
-HF_ENDPOINT=https://hf-mirror.com huggingface-cli download fishaudio/fish-speech-1.4 --local-dir checkpoints/fish-speech-1.4
+HF_ENDPOINT=https://hf-mirror.com huggingface-cli download fishaudio/fish-speech-1.5 --local-dir checkpoints/fish-speech-1.5
 ```
 
 随后可运行以下命令来提取语义 token:
@@ -52,7 +52,7 @@ HF_ENDPOINT=https://hf-mirror.com huggingface-cli download fishaudio/fish-speech
 python tools/vqgan/extract_vq.py data \
     --num-workers 1 --batch-size 16 \
     --config-name "firefly_gan_vq" \
-    --checkpoint-path "checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth"
+    --checkpoint-path "checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth"
 ```
 
 !!! note
@@ -96,13 +96,13 @@ python tools/llama/build_dataset.py \
 同样的, 请确保你已经下载了 `LLAMA` 权重, 如果没有, 请运行以下命令:
 
 ```bash
-huggingface-cli download fishaudio/fish-speech-1.4 --local-dir checkpoints/fish-speech-1.4
+huggingface-cli download fishaudio/fish-speech-1.5 --local-dir checkpoints/fish-speech-1.5
 ```
 
 对于中国大陆用户, 可使用 mirror 下载.
 
 ```bash
-HF_ENDPOINT=https://hf-mirror.com huggingface-cli download fishaudio/fish-speech-1.4 --local-dir checkpoints/fish-speech-1.4
+HF_ENDPOINT=https://hf-mirror.com huggingface-cli download fishaudio/fish-speech-1.5 --local-dir checkpoints/fish-speech-1.5
 ```
 
 最后, 你可以运行以下命令来启动微调:
@@ -130,9 +130,9 @@ python fish_speech/train.py --config-name text2semantic_finetune \
 ```bash
 python tools/llama/merge_lora.py \
 	--lora-config r_8_alpha_16 \
-	--base-weight checkpoints/fish-speech-1.4 \
+	--base-weight checkpoints/fish-speech-1.5 \
 	--lora-weight results/$project/checkpoints/step_000000010.ckpt \
-	--output checkpoints/fish-speech-1.4-yth-lora/
+	--output checkpoints/fish-speech-1.5-yth-lora/
 ```
 
 !!! note

+ 2 - 2
docs/zh/index.md

@@ -176,13 +176,13 @@ pip install -e .[stable]
     确保您在 docker 容器内的终端,然后再从我们的 huggingface 仓库下载所需的 `vqgan` 和 `llama` 模型。
 
     ```bash
-    huggingface-cli download fishaudio/fish-speech-1.4 --local-dir checkpoints/fish-speech-1.4
+    huggingface-cli download fishaudio/fish-speech-1.5 --local-dir checkpoints/fish-speech-1.5
     ```
 
     对于中国大陆用户,可以通过镜像站下载。
 
     ```bash
-    HF_ENDPOINT=https://hf-mirror.com huggingface-cli download fishaudio/fish-speech-1.4 --local-dir checkpoints/fish-speech-1.4
+    HF_ENDPOINT=https://hf-mirror.com huggingface-cli download fishaudio/fish-speech-1.5 --local-dir checkpoints/fish-speech-1.5
     ```
 
 4. 配置环境变量,访问 WebUI

+ 9 - 9
docs/zh/inference.md

@@ -15,13 +15,13 @@
 从我们的 huggingface 仓库下载所需的 `vqgan` 和 `llama` 模型。
 
 ```bash
-huggingface-cli download fishaudio/fish-speech-1.4 --local-dir checkpoints/fish-speech-1.4
+huggingface-cli download fishaudio/fish-speech-1.5 --local-dir checkpoints/fish-speech-1.5
 ```
 
 对于中国大陆用户,可使用 mirror 下载。
 
 ```bash
-HF_ENDPOINT=https://hf-mirror.com huggingface-cli download fishaudio/fish-speech-1.4 --local-dir checkpoints/fish-speech-1.4
+HF_ENDPOINT=https://hf-mirror.com huggingface-cli download fishaudio/fish-speech-1.5 --local-dir checkpoints/fish-speech-1.5
 ```
 
 ### 1. 从语音生成 prompt:
@@ -32,7 +32,7 @@ HF_ENDPOINT=https://hf-mirror.com huggingface-cli download fishaudio/fish-speech
 ```bash
 python tools/vqgan/inference.py \
     -i "paimon.wav" \
-    --checkpoint-path "checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth"
+    --checkpoint-path "checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth"
 ```
 
 你应该能得到一个 `fake.npy` 文件.
@@ -44,7 +44,7 @@ python tools/llama/generate.py \
     --text "要转换的文本" \
     --prompt-text "你的参考文本" \
     --prompt-tokens "fake.npy" \
-    --checkpoint-path "checkpoints/fish-speech-1.4" \
+    --checkpoint-path "checkpoints/fish-speech-1.5" \
     --num-samples 2 \
     --compile
 ```
@@ -65,7 +65,7 @@ python tools/llama/generate.py \
 ```bash
 python tools/vqgan/inference.py \
     -i "codes_0.npy" \
-    --checkpoint-path "checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth"
+    --checkpoint-path "checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth"
 ```
 
 ## HTTP API 推理
@@ -75,8 +75,8 @@ python tools/vqgan/inference.py \
 ```bash
 python -m tools.api \
     --listen 0.0.0.0:8080 \
-    --llama-checkpoint-path "checkpoints/fish-speech-1.4" \
-    --decoder-checkpoint-path "checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" \
+    --llama-checkpoint-path "checkpoints/fish-speech-1.5" \
+    --decoder-checkpoint-path "checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" \
     --decoder-config-name firefly_gan_vq
 ```
 > 如果你想要加速推理,可以加上`--compile`参数。
@@ -128,8 +128,8 @@ python -m tools.post_api \
 
 ```bash
 python -m tools.webui \
-    --llama-checkpoint-path "checkpoints/fish-speech-1.4" \
-    --decoder-checkpoint-path "checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" \
+    --llama-checkpoint-path "checkpoints/fish-speech-1.5" \
+    --decoder-checkpoint-path "checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" \
     --decoder-config-name firefly_gan_vq
 ```
 > 如果你想要加速推理,可以加上`--compile`参数。

+ 94 - 83
fish_speech/conversation.py

@@ -2,41 +2,10 @@ from dataclasses import dataclass, field
 from typing import Literal
 
 import torch
-from transformers import AutoTokenizer, PretrainedConfig, PreTrainedTokenizerFast
-
-IM_START_TOKEN = "<|im_start|>"
-IM_END_TOKEN = "<|im_end|>"
-SEMANTIC_TOKEN = "<|semantic|>"
-MEL_TOKEN = "<|mel|>"
-PHONEME_START_TOKEN = "<|phoneme_start|>"
-PHONEME_END_TOKEN = "<|phoneme_end|>"
-ALL_SPECIAL_TOKENS = [
-    IM_START_TOKEN,
-    IM_END_TOKEN,
-    SEMANTIC_TOKEN,
-    MEL_TOKEN,
-    PHONEME_START_TOKEN,
-    PHONEME_END_TOKEN,
-]
-
-CODEBOOK_PAD_TOKEN_ID = 0
-
-
-class FishTokenizerConfig(PretrainedConfig):
-    share_codebook_embeddings: bool = True
-    codebook_size: int = 1024
-    num_codebooks: int = 8
 
+from .tokenizer import MODALITY_TOKENS, FishTokenizer
 
-class FishTokenizerFast(PreTrainedTokenizerFast):
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.share_codebook_embeddings = kwargs.pop("share_codebook_embeddings", True)
-        self.codebook_size = kwargs.pop("codebook_size", 1024)
-        self.num_codebooks = kwargs.pop("num_codebooks", 8)
-
-
-AutoTokenizer.register(FishTokenizerConfig, fast_tokenizer_class=FishTokenizerFast)
+CODEBOOK_PAD_TOKEN_ID = 0
 
 
 @dataclass(kw_only=True)
@@ -54,77 +23,72 @@ class TextPart(BasePart):
     text: str
 
 
-@dataclass(kw_only=True)
-class MelPart(BasePart):
-    mels: torch.Tensor
-
-
 @dataclass(kw_only=True)
 class EncodedMessage:
     tokens: torch.Tensor
     labels: torch.Tensor
+    vq_mask_tokens: torch.Tensor | None = None
+    vq_mask_labels: torch.Tensor | None = None
     vq_parts: list[torch.Tensor]
-    mel_parts: list[torch.Tensor]
     vq_require_losses: torch.Tensor | None = None
 
 
 @dataclass(kw_only=True)
 class Message:
     role: Literal["system", "user", "assistant"]
-    parts: list[VQPart | TextPart | MelPart] = field(default_factory=list)
+    parts: list[VQPart | TextPart] = field(default_factory=list)
     add_im_start: bool = True
     add_im_end: bool = True
     cal_loss: bool = False
+    modality: Literal["text", "voice", "interleave"] | None = None
 
     # By default, ignore the loss of the auto-generated im_start token
     ignore_im_start_loss: bool = True
 
     def encode(
         self: "Message",
-        tokenizer: AutoTokenizer,
+        tokenizer: FishTokenizer,
     ) -> EncodedMessage:
         all_tokens = []
         all_labels = []
 
         # Multi-modal tokens
         vq_parts = []
-        mel_parts = []
-
-        semantic_id, mel_id = tokenizer.convert_tokens_to_ids(
-            [SEMANTIC_TOKEN, MEL_TOKEN]
-        )
+        vq_masks = []
 
         parts = self.parts.copy()
         if self.add_im_start:
-            parts.insert(0, TextPart(text=f"<|im_start|>{self.role}\n"))
+            modality_token = MODALITY_TOKENS[self.modality] if self.modality else ""
+            parts.insert(0, TextPart(text=f"<|im_start|>{self.role}\n{modality_token}"))
 
         if self.add_im_end:
             parts.append(TextPart(text="<|im_end|>"))
 
         for part in parts:
             if isinstance(part, TextPart):
-                tokens = tokenizer.encode(
-                    part.text,
-                    add_special_tokens=False,
-                    truncation=False,
-                    return_tensors="pt",
-                ).int()[0]
+                tokens = torch.tensor(
+                    tokenizer.encode(part.text),
+                    dtype=torch.int,
+                )
             elif isinstance(part, VQPart):
-                tokens = torch.zeros(part.codes.shape[1], dtype=torch.int) + semantic_id
-                codes = part.codes.clone() + 1
-
-                if getattr(tokenizer, "share_codebook_embeddings", True) is False:
-                    for i in range(len(codes)):
-                        codes[i] += tokenizer.codebook_size * i
-
-                vq_parts.append(codes)
-            elif isinstance(part, MelPart):
-                tokens = torch.zeros(part.mels.shape[1], dtype=torch.int) + mel_id
-                mel_parts.append(part.mels)
+                curr_codes = part.codes.clone()
+                tokens = torch.tensor(
+                    [
+                        tokenizer.semantic_id_to_token_id[i.item()]
+                        for i in curr_codes[0].int()
+                    ],
+                    dtype=torch.int,
+                )
+                vq_parts.append(curr_codes)
             else:
                 raise ValueError(f"Unsupported part type: {type(part)}")
 
             all_tokens.append(tokens)
+            if isinstance(part, VQPart):
+                vq_masks.append(torch.ones_like(tokens, dtype=torch.bool))
+            else:
+                vq_masks.append(torch.zeros_like(tokens, dtype=torch.bool))
+
             if self.cal_loss:
                 all_labels.append(tokens.clone())
             else:
@@ -132,7 +96,9 @@ class Message:
 
         tokens = torch.cat(all_tokens, dim=0)
         labels = torch.cat(all_labels, dim=0)
-        assert tokens.shape == labels.shape
+        vq_masks = torch.cat(vq_masks, dim=0)
+
+        assert tokens.shape == labels.shape == vq_masks.shape
 
         if self.ignore_im_start_loss and self.add_im_start:
             labels[: len(all_tokens[0])] = -100
@@ -141,7 +107,8 @@ class Message:
             tokens=tokens,
             labels=labels,
             vq_parts=vq_parts,
-            mel_parts=mel_parts,
+            vq_mask_tokens=vq_masks,
+            vq_mask_labels=vq_masks,
         )
 
 
@@ -149,17 +116,23 @@ class Message:
 class Conversation:
     messages: list[Message]
 
+    def __init__(self: "Conversation", messages: list[Message] | None = None):
+        self.messages = messages or []
+
     def encode(
         self: "Conversation",
-        tokenizer: AutoTokenizer,
+        tokenizer: FishTokenizer,
         add_shift: bool = True,
+        ignore_loss_tokens: list[str] = [],
     ) -> EncodedMessage:
         # Build the input_ids and labels
         tokens = []
         labels = []
         vq_parts = []
-        mel_parts = []
+        vq_mask_tokens = []
+        vq_mask_labels = []
         vq_require_losses = []
+        ignore_loss_token_ids = [tokenizer.get_token_id(i) for i in ignore_loss_tokens]
 
         for message in self.messages:
             encoded = message.encode(
@@ -168,16 +141,25 @@ class Conversation:
             tokens.append(encoded.tokens)
             labels.append(encoded.labels)
             vq_parts.extend(encoded.vq_parts)
-            mel_parts.extend(encoded.mel_parts)
+            vq_mask_tokens.append(encoded.vq_mask_tokens)
+            vq_mask_labels.append(encoded.vq_mask_labels)
             vq_require_losses.extend([message.cal_loss] * len(encoded.vq_parts))
 
         tokens = torch.cat(tokens, dim=0)
         labels = torch.cat(labels, dim=0)
+        vq_mask_tokens = torch.cat(vq_mask_tokens, dim=0)
+        vq_mask_labels = torch.cat(vq_mask_labels, dim=0)
         vq_require_losses = torch.tensor(vq_require_losses, dtype=torch.bool)
 
         if add_shift:
             tokens = tokens[:-1]
             labels = labels[1:]
+            vq_mask_tokens = vq_mask_tokens[:-1]
+            vq_mask_labels = vq_mask_labels[1:]
+
+        for i in ignore_loss_token_ids:
+            assert i != -100 and i is not None
+            labels[labels == i] = -100
 
         assert tokens.dtype in [
             torch.int,
@@ -188,15 +170,18 @@ class Conversation:
             tokens=tokens,
             labels=labels,
             vq_parts=vq_parts,
-            mel_parts=mel_parts,
+            vq_mask_tokens=vq_mask_tokens,
+            vq_mask_labels=vq_mask_labels,
             vq_require_losses=vq_require_losses,
         )
 
     def encode_for_inference(
         self: "Conversation",
-        tokenizer: AutoTokenizer,
+        tokenizer: FishTokenizer,
         num_codebooks: int,
     ) -> EncodedMessage:
+        # self.visualize(tokenizer)
+
         encoded = self.encode(tokenizer, add_shift=False)
         tokens = encoded.tokens
         values = torch.zeros((num_codebooks + 1, len(tokens)), dtype=torch.int)
@@ -205,24 +190,47 @@ class Conversation:
         if encoded.vq_parts is None or len(encoded.vq_parts) == 0:
             return values
 
-        semantic_id, mel_id = tokenizer.convert_tokens_to_ids(
-            [SEMANTIC_TOKEN, MEL_TOKEN]
-        )
         vq_parts = encoded.vq_parts
+        vq_parts = [part.to(values.device) for part in vq_parts]
         vq_parts = torch.cat(vq_parts, dim=1)
-        values[1:, tokens == semantic_id] = vq_parts
+        values[0, encoded.vq_mask_tokens] = vq_parts[0] + tokenizer.semantic_begin_id
+        values[1:, encoded.vq_mask_tokens] = vq_parts
+
         return values
 
-    def visualize(self: "Conversation", tokenizer: AutoTokenizer):
-        encoded = self.encode(tokenizer, add_shift=False)
+    def visualize(
+        self: "Conversation",
+        tokenizer: FishTokenizer,
+        ignore_loss_tokens: list[str] = [],
+    ):
+        encoded = self.encode(
+            tokenizer, add_shift=False, ignore_loss_tokens=ignore_loss_tokens
+        )
 
-        print_in_blue = lambda x: print("\033[94m" + x + "\033[0m", end="")
-        print_in_green = lambda x: print("\033[92m" + x + "\033[0m", end="")
+        # Colors for alternating tokens
+        colors = {
+            "blue": "\033[94m",  # Light blue
+            "cyan": "\033[96m",  # Cyan
+            "green": "\033[92m",  # Light green
+            "dark_green": "\033[32m",  # Dark green
+        }
+        blue_idx = 0
+        green_idx = 0
+
+        def print_in_blue(x):
+            nonlocal blue_idx
+            color = colors["blue"] if blue_idx % 2 == 0 else colors["cyan"]
+            print(f"{color}{x}\033[0m", end="")
+            blue_idx += 1
+
+        def print_in_green(x):
+            nonlocal green_idx
+            color = colors["green"] if green_idx % 2 == 0 else colors["dark_green"]
+            print(f"{color}{x}\033[0m", end="")
+            green_idx += 1
 
         for tok, lab in zip(encoded.tokens, encoded.labels):
-            val = tokenizer.decode(tok, skip_special_tokens=False)
-            if val == "\n":
-                val = "\\n\n"
+            val = tokenizer.decode([tok])
 
             if lab == -100:
                 print_in_green(val)
@@ -231,6 +239,9 @@ class Conversation:
 
         print()
 
+    def append(self: "Conversation", message: Message):
+        self.messages.append(message)
+
 
 if __name__ == "__main__":
     message0 = Message(
@@ -248,7 +259,7 @@ if __name__ == "__main__":
         cal_loss=True,
     )
     conversation = Conversation([message0, message1])
-    tokenizer = AutoTokenizer.from_pretrained("checkpoints/Qwen2-1.5B-Instruct")
+    tokenizer = FishTokenizer.from_pretrained("checkpoints/Qwen2-1.5B-Instruct")
     conversation.visualize(tokenizer)
 
     encoded = conversation.encode(tokenizer)
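
The `add_shift` flag in `Conversation.encode` above does standard next-token alignment, and the new VQ masks are shifted the same way so they stay in step with inputs and labels. A minimal sketch of that alignment (toy values, not real token ids):

```python
import torch

# Toy sketch of the add_shift alignment in Conversation.encode:
# position t of the input predicts position t+1, and the VQ masks
# are shifted identically so they stay aligned with inputs/labels.
tokens = torch.tensor([10, 11, 12, 13])
labels = tokens.clone()
vq_mask = torch.tensor([False, True, True, False])

inputs, targets = tokens[:-1], labels[1:]
vq_mask_tokens, vq_mask_labels = vq_mask[:-1], vq_mask[1:]

assert inputs.tolist() == [10, 11, 12]
assert targets.tolist() == [11, 12, 13]
assert vq_mask_tokens.tolist() == [False, True, True]
assert vq_mask_labels.tolist() == [True, True, False]
```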

+ 63 - 20
fish_speech/models/text2semantic/llama.py

@@ -16,7 +16,7 @@ from torch.nn.attention import SDPBackend, sdpa_kernel
 from torch.utils.checkpoint import checkpoint
 from transformers import AutoTokenizer
 
-from fish_speech.conversation import SEMANTIC_TOKEN
+from fish_speech.tokenizer import SEMANTIC_TOKENS, FishTokenizer
 from fish_speech.utils import RankedLogger
 
 from .lora import LoraConfig, setup_lora
@@ -61,6 +61,7 @@ class BaseModelArgs:
     # Dummy vars
     is_reward_model: bool = False
     share_codebook_embeddings: bool = True
+    scale_codebook_embeddings: bool = False
 
     def __post_init__(self):
         if self.n_local_heads == -1:
@@ -164,13 +165,17 @@ class BaseTransformerForwardResult:
 
 class BaseTransformer(nn.Module):
     def __init__(
-        self, config: BaseModelArgs, tokenizer: AutoTokenizer, init_weights: bool = True
+        self,
+        config: BaseModelArgs,
+        tokenizer: FishTokenizer | AutoTokenizer,
+        init_weights: bool = True,
     ) -> None:
         super().__init__()
         self.config = config
         self.tokenizer = tokenizer
-
-        self.semantic_token_id = tokenizer.convert_tokens_to_ids(SEMANTIC_TOKEN)
+        self.semantic_token_ids = [
+            tokenizer.get_token_id(SEMANTIC_TOKEN) for SEMANTIC_TOKEN in SEMANTIC_TOKENS
+        ]
 
         # Slow transformer
         self.embeddings = nn.Embedding(
@@ -245,8 +250,10 @@ class BaseTransformer(nn.Module):
         vocab_embeds = [self.embeddings(x[:, 0])]
         for i in range(self.config.num_codebooks):
             emb = self.codebook_embeddings(x[:, i + 1] + i * self.config.codebook_size)
-            emb[x[:, 0] != self.semantic_token_id] = 0
-            vocab_embeds.append(emb)
+            semantic_token_ids_tensor = torch.tensor(
+                self.semantic_token_ids, device=x.device
+            )
+            emb[~torch.isin(x[:, 0], semantic_token_ids_tensor)] = 0
+            vocab_embeds.append(emb)
 
         x = torch.stack(vocab_embeds, dim=3)
         x = x.sum(dim=3)
@@ -294,20 +301,45 @@ class BaseTransformer(nn.Module):
 
     def forward_generate(
         self,
-        x: Tensor,
+        inp: Tensor,
         input_pos: Optional[Tensor] = None,
+        vq_masks: Optional[Tensor] = None,  # unused; recomputed from inp below
         return_all: bool = False,
     ) -> BaseTransformerForwardResult:
         # This is used for generation, optimized for torch compile
-        assert (
-            self.max_seq_len != -1 and self.max_batch_size != -1
-        ), "Please call setup_caches before forward_generate"
+        # assert (
+        #     self.max_seq_len != -1 and self.max_batch_size != -1
+        # ), "Please call setup_caches before forward_generate"
 
-        x = self.embed(x)
+        embeds = []
+        for i in range(self.config.num_codebooks):
+            if self.config.share_codebook_embeddings:
+                _tokens = inp[:, i + 1] + i * self.config.codebook_size
+            else:
+                _tokens = inp[:, i + 1]
 
-        mask = self.causal_mask[
-            None, None, input_pos, : self.max_seq_len
-        ]  # (B, N, Q, K)
+            emb = self.codebook_embeddings(_tokens)
+            embeds.append(emb)
+
+        vq_embeds_sum = torch.stack(embeds, dim=1).sum(dim=1)
+        # if self.config.use_codebook_mlp:
+        #     vq_embeds_sum = vq_embeds_sum / self.config.num_codebooks
+        #     vq_embeds_sum = self.codebook_mlp(vq_embeds_sum)
+
+        vq_masks = (inp[:, 0] >= self.tokenizer.semantic_begin_id) & (
+            inp[:, 0] <= self.tokenizer.semantic_end_id
+        )
+
+        vq_embeds_sum[~vq_masks] = 0
+        x = self.embeddings(inp[:, 0]) + vq_embeds_sum
+
+        if input_pos is None:
+            input_pos = torch.arange(inp.shape[-1], device=x.device)
+            max_seq_len = inp.shape[-1]
+        else:
+            max_seq_len = self.max_seq_len
+
+        mask = self.causal_mask[None, None, input_pos, :max_seq_len]  # (B, N, Q, K)
         freqs_cis = self.freqs_cis[input_pos]
 
         for layer in self.layers:
@@ -320,7 +352,9 @@ class BaseTransformer(nn.Module):
         # We got slow_out here
         slow_out = self.norm(x)
 
-        if self.config.tie_word_embeddings:
+        if self.config.is_reward_model:
+            token_logits = self.score_output(slow_out)
+        elif self.config.tie_word_embeddings:
             token_logits = F.linear(slow_out, self.embeddings.weight)
         else:
             token_logits = self.output(slow_out)
@@ -348,6 +382,7 @@ class BaseTransformer(nn.Module):
         max_length: int | None = None,
         lora_config: LoraConfig | None = None,
         rope_base: int | None = None,
+        is_agent: bool = False,
     ) -> "BaseTransformer":
         config = BaseModelArgs.from_pretrained(str(path))
         if max_length is not None:
@@ -366,7 +401,12 @@ class BaseTransformer(nn.Module):
             case _:
                 raise ValueError(f"Unknown model type: {config.model_type}")
 
-        tokenizer = AutoTokenizer.from_pretrained(str(path))
+        if is_agent:
+            tokenizer = AutoTokenizer.from_pretrained(str(path))
+        else:
+            tokenizer_path = str(path) + "/tokenizer.tiktoken"
+            tokenizer = FishTokenizer(tokenizer_path)
+
         log.info(f"Loading model from {path}, config: {config}")
         model = model_cls(config, tokenizer=tokenizer)
 
@@ -452,7 +492,7 @@ class BaseTransformer(nn.Module):
 
 
 class NaiveTransformer(BaseTransformer):
-    def __init__(self, config: NaiveModelArgs, tokenizer: AutoTokenizer) -> None:
+    def __init__(self, config: NaiveModelArgs, tokenizer: FishTokenizer) -> None:
         super().__init__(config, init_weights=False, tokenizer=tokenizer)
 
         self.codebook_norm = RMSNorm(config.dim, eps=config.norm_eps)
@@ -498,7 +538,7 @@ class NaiveTransformer(BaseTransformer):
 
 
 class DualARTransformer(BaseTransformer):
-    def __init__(self, config: NaiveModelArgs, tokenizer: AutoTokenizer) -> None:
+    def __init__(self, config: NaiveModelArgs, tokenizer: FishTokenizer) -> None:
         super().__init__(config, init_weights=False, tokenizer=tokenizer)
 
         # Project to fast dim if needed
@@ -654,9 +694,12 @@ class DualARTransformer(BaseTransformer):
         return codebook_logits
 
     def forward_generate(
-        self, x: Tensor, input_pos: Optional[Tensor] = None
+        self,
+        x: Tensor,
+        input_pos: Optional[Tensor] = None,
+        vq_masks: Optional[Tensor] = None,
     ) -> TransformerForwardResult:
-        x = super().forward_generate(x, input_pos)
+        x = super().forward_generate(x, input_pos, vq_masks)
         x.hidden_states = self.fast_project_in(x.hidden_states)
         return x
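
The reworked `forward_generate` above sums one embedding per codebook, then zeroes that sum wherever the main token falls outside the contiguous semantic-id range. A toy-sized sketch of the masking step (all sizes and id ranges here are assumed for illustration):

```python
import torch

# Hypothetical small sizes; the real model uses its config values.
num_codebooks, codebook_size, dim = 2, 8, 4
semantic_begin_id, semantic_end_id = 100, 107  # assumed contiguous range

codebook_emb = torch.nn.Embedding(num_codebooks * codebook_size, dim)

# Row 0: main token ids; rows 1..n: one row of VQ codes per codebook.
inp = torch.tensor([[101, 50, 103],
                    [3, 0, 5],
                    [1, 0, 7]])

# Shared codebook table: codebook i is offset by i * codebook_size.
embeds = [codebook_emb(inp[i + 1] + i * codebook_size) for i in range(num_codebooks)]
vq_sum = torch.stack(embeds, dim=0).sum(dim=0)  # (seq, dim)

# Zero the VQ contribution at positions whose main token is not semantic.
vq_mask = (inp[0] >= semantic_begin_id) & (inp[0] <= semantic_end_id)
vq_sum[~vq_mask] = 0
```

Position 1 carries token id 50, outside the semantic range, so its codebook embedding sum is zeroed and only the text embedding contributes there.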
 

+ 1 - 26
fish_speech/text/clean.py

@@ -1,33 +1,8 @@
 import re
 
 SYMBOLS_MAPPING = {
-    "\n": "",
-    "…": ".",
-    "“": "'",
-    "”": "'",
     "‘": "'",
     "’": "'",
-    "【": "",
-    "】": "",
-    "[": "",
-    "]": "",
-    "(": "",
-    ")": "",
-    "(": "",
-    ")": "",
-    "・": "",
-    "·": "",
-    "「": "'",
-    "」": "'",
-    "《": "'",
-    "》": "'",
-    "—": "",
-    "~": "",
-    "~": "",
-    ":": ",",
-    ";": ",",
-    ";": ",",
-    ":": ",",
 }
 
 REPLACE_SYMBOL_REGEX = re.compile(
@@ -57,6 +32,6 @@ def clean_text(text):
     text = EMOJI_REGEX.sub(r"", text)
 
-    # Remove continuous periods (...) and commas (,,,)
-    text = re.sub(r"[.,]{2,}", lambda m: m.group()[0], text)
+    # Remove continuous commas (,,,)
+    text = re.sub(r"[,]{2,}", lambda m: m.group()[0], text)
 
     return text
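
After this trim, `clean_text` only normalizes curly apostrophes and collapses comma runs (emoji stripping stays as before). A standalone sketch of the remaining symbol handling:

```python
import re

# Condensed sketch of the trimmed clean_text behavior: map curly
# apostrophes to straight quotes, then collapse runs of commas.
SYMBOLS_MAPPING = {"\u2018": "'", "\u2019": "'"}  # left/right single quotes
REPLACE_SYMBOL_REGEX = re.compile("|".join(re.escape(p) for p in SYMBOLS_MAPPING))


def clean_text(text: str) -> str:
    text = text.strip()
    text = REPLACE_SYMBOL_REGEX.sub(lambda x: SYMBOLS_MAPPING[x.group()], text)
    text = re.sub(r"[,]{2,}", lambda m: m.group()[0], text)
    return text


assert clean_text("it\u2019s fine,,,really") == "it's fine,really"
```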

+ 1 - 1
fish_speech/text/spliter.py

@@ -4,7 +4,7 @@ import string
 from fish_speech.text.clean import clean_text
 
 
-def utf_8_len(text):
+def utf_8_len(text: str):
     return len(text.encode("utf-8"))
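
`utf_8_len` counts bytes rather than characters, which is what the splitter budgets against. For example:

```python
def utf_8_len(text: str) -> int:
    """Length of text in UTF-8 bytes (what the splitter budgets against)."""
    return len(text.encode("utf-8"))


assert utf_8_len("abc") == 3    # ASCII: 1 byte per char
assert utf_8_len("你好") == 6    # CJK: 3 bytes per char
assert utf_8_len("héllo") == 6  # é encodes to 2 bytes
```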
 
 

+ 152 - 0
fish_speech/tokenizer.py

@@ -0,0 +1,152 @@
+import base64
+import json
+import logging
+from pathlib import Path
+
+import tiktoken
+
+logger = logging.getLogger(__name__)
+
+# This is a modified version of the default pattern from GPT-4o that better handles punctuation.
+FISH_TIKTOKEN_PATTERN = "|".join(
+    [
+        r"(?i:'s|'t|'re|'ve|'m|'ll|'d)",
+        r"\p{P}",
+        r"[^\r\n\p{L}\p{N}]?\p{L}+",
+        r"\p{N}",
+        r" ?[^\s\p{L}\p{N}]+[\r\n]*",
+        r"\s*[\r\n]+",
+        r"\s+(?!\S)",
+        r"\s+",
+    ]
+)
+TIKTOKEN_MAX_ENCODE_CHARS = 400_000
+
+BOS_TOKEN = "<|begin_of_text|>"
+EOS_TOKEN = "<|end_of_text|>"
+PAD_TOKEN = "<|pad|>"
+IM_START_TOKEN = "<|im_start|>"
+IM_END_TOKEN = "<|im_end|>"
+
+MODALITY_TEXT_TOKEN = "<|text|>"
+MODALITY_VOICE_TOKEN = "<|voice|>"
+MODALITY_INTERLEAVE_TOKEN = "<|interleave|>"
+MODALITY_TOKENS = {
+    "text": MODALITY_TEXT_TOKEN,
+    "voice": MODALITY_VOICE_TOKEN,
+    "interleave": MODALITY_INTERLEAVE_TOKEN,
+}
+
+PLACEHOLDER_TOKEN = [f"<|placeholder:{i}|>" for i in range(4)]
+
+SEMANTIC_TOKEN_TEMPLATE = "<|semantic:{i}|>"
+SEMANTIC_TOKENS = [SEMANTIC_TOKEN_TEMPLATE.format(i=i) for i in range(1024)]
+
+# Warning: when you add a new special token, you should only add it to the end of the list.
+ALL_SPECIAL_TOKENS = [
+    BOS_TOKEN,
+    EOS_TOKEN,
+    PAD_TOKEN,
+    IM_START_TOKEN,
+    IM_END_TOKEN,
+    PLACEHOLDER_TOKEN[0],
+    PLACEHOLDER_TOKEN[1],
+    PLACEHOLDER_TOKEN[2],
+    PLACEHOLDER_TOKEN[3],
+    MODALITY_TEXT_TOKEN,
+    MODALITY_VOICE_TOKEN,
+    MODALITY_INTERLEAVE_TOKEN,
+    *SEMANTIC_TOKENS,
+]
+
+
+class FishTokenizer:
+    def __init__(self, model_path: str) -> None:
+        mergeable_ranks = self.load_tiktoken_bpe(model_path)
+        special_token_begin = len(mergeable_ranks)
+        self.all_special_tokens_with_ids = {
+            token: special_token_begin + i for i, token in enumerate(ALL_SPECIAL_TOKENS)
+        }
+        self.semantic_id_to_token_id = {
+            i: self.all_special_tokens_with_ids[token]
+            for i, token in enumerate(SEMANTIC_TOKENS)
+        }
+        self.semantic_begin_id = self.all_special_tokens_with_ids[SEMANTIC_TOKENS[0]]
+        self.semantic_end_id = self.all_special_tokens_with_ids[SEMANTIC_TOKENS[-1]]
+
+        self.tkt_model = tiktoken.core.Encoding(
+            name=Path(model_path).stem,
+            pat_str=FISH_TIKTOKEN_PATTERN,
+            mergeable_ranks=mergeable_ranks,
+            special_tokens=self.all_special_tokens_with_ids,
+        )
+
+    @staticmethod
+    def load_tiktoken_bpe(tiktoken_bpe_file: str) -> dict[bytes, int]:
+        data = {}
+        for line in open(tiktoken_bpe_file).read().splitlines():
+            if not line:
+                continue
+            token, rank = line.split()
+            data[base64.b64decode(token)] = int(rank)
+        return data
+
+    def get_token_id(self, token: str) -> int:
+        return self.all_special_tokens_with_ids[token]
+
+    def encode(self, s: str, allowed_special: bool | set[str] = True) -> list[int]:
+        assert isinstance(s, str)
+
+        subs = []
+        for i in range(0, len(s), TIKTOKEN_MAX_ENCODE_CHARS):
+            subs.append(s[i : i + TIKTOKEN_MAX_ENCODE_CHARS])
+
+        if allowed_special is True:
+            allowed_special = self.tkt_model.special_tokens_set
+        elif allowed_special is False:
+            allowed_special = set()
+
+        return sum(
+            self.tkt_model.encode_batch(
+                subs, allowed_special=allowed_special, disallowed_special=set()
+            ),
+            start=[],
+        )
+
+    def decode(self, tokens: list[int]) -> str:
+        return self.tkt_model.decode(tokens)
+
+    def save_pretrained(self, path: str):
+        path = Path(path)
+        path.mkdir(parents=True, exist_ok=True)
+
+        with open(path / "tokenizer.tiktoken", "w") as f:
+            for token, rank in self.tkt_model._mergeable_ranks.items():
+                f.write(f"{base64.b64encode(token).decode()} {rank}\n")
+
+        with open(path / "special_tokens.json", "w") as f:
+            json.dump(
+                self.all_special_tokens_with_ids,
+                f,
+                indent=2,
+                ensure_ascii=False,
+            )
+
+    @staticmethod
+    def from_pretrained(path: str):
+        return FishTokenizer(Path(path) / "tokenizer.tiktoken")
+
+
+if __name__ == "__main__":
+    tokenizer = FishTokenizer("data/mpacks/v1.4-pretrain/tokenizer.all.tiktoken")
+    tokenizer.save_pretrained("checkpoints/fish-speech-0.5B")
+    tokenizer = FishTokenizer.from_pretrained("checkpoints/fish-speech-0.5B")
+
+    print(
+        [
+            tokenizer.decode([i])
+            for i in tokenizer.encode(f"{BOS_TOKEN}你好,世界!{EOS_TOKEN}")
+        ]
+    )
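
The ordering constraint on `ALL_SPECIAL_TOKENS` matters because ids are assigned sequentially after the BPE ranks, which makes the 1024 semantic tokens a contiguous block. A sketch of the resulting layout (the BPE vocab size below is a made-up placeholder, not the real checkpoint's):

```python
# Sketch of the id layout FishTokenizer builds: BPE ranks first,
# then the special tokens in ALL_SPECIAL_TOKENS order.
bpe_vocab_size = 32000  # hypothetical; the real size comes from the rank file

specials = ["<|begin_of_text|>", "<|end_of_text|>", "<|pad|>",
            "<|im_start|>", "<|im_end|>"]
specials += [f"<|placeholder:{i}|>" for i in range(4)]
specials += ["<|text|>", "<|voice|>", "<|interleave|>"]
specials += [f"<|semantic:{i}|>" for i in range(1024)]

ids = {tok: bpe_vocab_size + i for i, tok in enumerate(specials)}
semantic_begin_id = ids["<|semantic:0|>"]
semantic_end_id = ids["<|semantic:1023|>"]

# VQ code k maps to id semantic_begin_id + k, so code/id conversion
# is plain arithmetic and "is semantic" is a single range check.
assert semantic_end_id - semantic_begin_id == 1023
assert ids["<|semantic:5|>"] == semantic_begin_id + 5
```

The contiguity is what `forward_generate` and the API server rely on below when they test `semantic_begin_id <= token <= semantic_end_id` instead of comparing against one `<|semantic|>` id.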

+ 2 - 0
pyproject.toml

@@ -44,6 +44,8 @@ dependencies = [
     "opencc-python-reimplemented==0.1.7",
     "silero-vad",
     "ormsgpack",
+    "tiktoken>=0.8.0",
+    "pydantic==2.9.2",
 ]
 
 [project.optional-dependencies]

+ 26 - 21
tools/api.py

@@ -1,4 +1,5 @@
 import io
+import json
 import os
 import queue
 import re
@@ -32,7 +33,6 @@ from kui.asgi import (
 )
 from kui.asgi.routing import MultimethodRoutes
 from loguru import logger
-from transformers import AutoTokenizer
 
 pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
 import struct
@@ -43,12 +43,14 @@ from cachetools import LRUCache, cached
 from funasr import AutoModel
 from silero_vad import get_speech_timestamps, load_silero_vad
 
-from fish_speech.conversation import IM_END_TOKEN, SEMANTIC_TOKEN
 from fish_speech.models.text2semantic.llama import BaseModelArgs
 
 # from fish_speech.models.vqgan.lit_module import VQGAN
 from fish_speech.models.vqgan.modules.firefly import FireflyArchitecture
 from fish_speech.text.chn_text_norm.text import Text as ChnNormedText
+
+# from fish_speech.conversation import IM_END_TOKEN, SEMANTIC_TOKEN
+from fish_speech.tokenizer import IM_END_TOKEN, FishTokenizer
 from fish_speech.utils import autocast_exclude_mps, set_seed
 from tools.file import AUDIO_EXTENSIONS, audio_to_bytes, list_files, read_ref_text
 from tools.llama.generate import (
@@ -381,14 +383,13 @@ from fish_speech.conversation import Conversation, Message
 
 def execute_request(
     input_queue: queue.Queue,
-    tokenizer: AutoTokenizer,
+    tokenizer: FishTokenizer,
     config: BaseModelArgs,
     request: ServeRequest,
     device: str = "cuda:0",
 ):
-    semantic_id, im_end_id = tokenizer.convert_tokens_to_ids(
-        [SEMANTIC_TOKEN, IM_END_TOKEN]
-    )
+
+    im_end_id = tokenizer.get_token_id(IM_END_TOKEN)
     messages = []
     for message in request.messages:
         messages.append(message.to_conversation_message())
@@ -397,7 +398,13 @@ def execute_request(
     # assert messages[-1].role == "user", "The last message must be from the user"
 
     if messages[-1].role == "user":
-        messages.append(Message(role="assistant", parts=[], add_im_end=False))
+        messages.append(
+            Message(role="assistant", parts=[], add_im_end=False, modality="voice")
+        )
+    elif messages[-1].role == "raw":
+        messages[-1].add_im_start = False
+        messages[-1].add_im_end = False
+        messages[-1].modality = "voice"
     else:
         assert (
             messages[-1].role == "assistant"
@@ -405,6 +412,8 @@ def execute_request(
         messages[-1].add_im_end = False
 
     conv = Conversation(messages=messages)
+
+    # conv.visualize(tokenizer)
     prompt = conv.encode_for_inference(
         tokenizer=tokenizer, num_codebooks=config.num_codebooks
     ).to(device)
@@ -422,7 +431,6 @@ def execute_request(
         "prompt": prompt,
         "max_new_tokens": request.max_new_tokens,
         "im_end_id": im_end_id,
-        "semantic_id": semantic_id,
         "temperature": request.temperature,
         "top_p": request.top_p,
         "repetition_penalty": request.repetition_penalty,
@@ -478,10 +486,13 @@ def execute_request(
                     )
                 continue
 
-            if tokens[0] == semantic_id and request.streaming:
+            is_semantic = (
+                tokenizer.semantic_begin_id <= tokens[0] <= tokenizer.semantic_end_id
+            )
+            if is_semantic and request.streaming:
                 yield from send_reset_buffer(sample_id)
                 # Streaming vq
-                _tokens = tokens[1:].clone() - 1
+                _tokens = tokens[1:].clone()
 
                 if config.share_codebook_embeddings is False:
                     for i in range(len(_tokens)):
@@ -494,13 +505,13 @@ def execute_request(
                 continue
 
             # Not streaming vq
-            if tokens[0] == semantic_id:
+            if is_semantic:
                 yield from send_reset_buffer(sample_id)
                 # None streaming vq
                 if len(parts[sample_id]) == 0 or not isinstance(
                     parts[sample_id][-1], ServeVQPart
                 ):
-                    _tokens = tokens[1:].clone() - 1
+                    _tokens = tokens[1:].clone()
 
                     if config.share_codebook_embeddings is False:
                         for i in range(len(_tokens)):
@@ -509,14 +520,14 @@ def execute_request(
                     parts[sample_id].append(ServeVQPart(codes=_tokens.tolist()))
                 else:
                     for codebook_id, value in enumerate(tokens[1:, :]):
-                        val = value.item() - 1
+                        val = value.item()
                         if config.share_codebook_embeddings is False:
                             val -= config.codebook_size * codebook_id
 
                         parts[sample_id][-1].codes[codebook_id].append(val)
                 continue
 
-            if tokens[0] != semantic_id:
+            if not is_semantic:
                 # Stream text decode is not supported now
                 decode_buffer[sample_id].append(tokens[0, 0])
 
@@ -776,7 +787,6 @@ async def api_health():
     """
     Health check
     """
-
     return JSONResponse({"status": "ok"})
 
 
@@ -871,11 +881,6 @@ def initialize_app(app: Kui):
     args = parse_args()  # args same as ones in other processes
     args.precision = torch.half if args.half else torch.bfloat16
 
-    # Check if CUDA is available
-    if not torch.cuda.is_available():
-        logger.info("CUDA is not available, running on CPU.")
-        args.device = "cpu"
-
     if args.load_asr_model:
         logger.info(f"Loading ASR model...")
         asr_model = load_asr_model(device=args.device)
@@ -922,7 +927,7 @@ def initialize_app(app: Kui):
                     max_new_tokens=0,
                     chunk_length=200,
                     top_p=0.7,
-                    repetition_penalty=1.2,
+                    repetition_penalty=1.5,
                     temperature=0.7,
                     emotion=None,
                     format="wav",
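
With v1.5 every VQ code has its own special token, so the per-step dispatch in `execute_request` detects semantic steps by an id range rather than a single `<|semantic|>` id, and the old `- 1` offset on codebook rows is gone. A sketch of that routing decision (ids below are assumed, not the real checkpoint's):

```python
# Hypothetical id range for the 1024 semantic tokens.
semantic_begin_id, semantic_end_id = 32012, 33035


def route(first_token_id: int) -> str:
    """Decide how one generation step is handled, as in execute_request."""
    if semantic_begin_id <= first_token_id <= semantic_end_id:
        return "vq"    # codebook rows are accumulated / streamed to the vocoder
    return "text"      # everything else is buffered for text decoding


assert route(semantic_begin_id) == "vq"
assert route(semantic_end_id) == "vq"
assert route(151) == "text"
```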

+ 117 - 89
tools/llama/generate.py

@@ -17,9 +17,16 @@ from loguru import logger
 from tqdm import tqdm
 from transformers import AutoTokenizer
 
-from fish_speech.conversation import CODEBOOK_PAD_TOKEN_ID
+from fish_speech.conversation import (
+    CODEBOOK_PAD_TOKEN_ID,
+    Conversation,
+    Message,
+    TextPart,
+    VQPart,
+)
 from fish_speech.models.text2semantic.llama import BaseModelArgs
 from fish_speech.text import clean_text, split_text
+from fish_speech.tokenizer import IM_END_TOKEN, FishTokenizer
 
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
 torch._inductor.config.coordinate_descent_tuning = True
@@ -145,8 +152,8 @@ def decode_one_token_ar_agent(
     model: DualARTransformer,
     x: torch.Tensor,
     input_pos: torch.Tensor,
+    semantic_ids: list,
     previous_tokens: torch.Tensor = None,
-    semantic_id: int = 32003,
     **sampling_kwargs,
 ) -> torch.Tensor:
     # print(x, input_pos)
@@ -190,19 +197,13 @@ def decode_one_token_ar_agent(
         codebooks.append(a)
 
     codebooks = torch.stack(codebooks, dim=1)
+    semantic_ids_tensor = torch.tensor(semantic_ids, device=codebooks.device)
     codebooks[:, 1:, :] = torch.masked_fill(
-        codebooks[:, 1:, :], codebooks[:, :1, :] != semantic_id, CODEBOOK_PAD_TOKEN_ID
+        codebooks[:, 1:, :],
+        ~torch.isin(codebooks[:, :1, :], semantic_ids_tensor),
+        CODEBOOK_PAD_TOKEN_ID,
     )
 
-    # for i in range(codebooks.size(1) - 1):
-    #     codebooks[:, i + 1, :] = torch.masked_fill(
-    #         codebooks[:, i + 1, :],
-    #         codebooks[:, :1, :] != semantic_id,
-    #         CODEBOOK_PAD_TOKEN_ID + i * 1024,
-    #     )
-
-    # print(codebooks)
-
     return codebooks
 
 
@@ -210,8 +211,8 @@ def decode_one_token_naive_agent(
     model: NaiveTransformer,
     x: torch.Tensor,
     input_pos: torch.Tensor,
+    semantic_ids: list,
     previous_tokens: torch.Tensor = None,
-    semantic_id: int = 32003,
     **sampling_kwargs,
 ) -> torch.Tensor:
     x = model.forward_generate(x, input_pos)
@@ -236,8 +237,11 @@ def decode_one_token_naive_agent(
         )
 
     codebooks = torch.stack(codebooks, dim=1)
+    semantic_ids_tensor = torch.tensor(semantic_ids, device=codebooks.device)
     codebooks[:, 1:, :] = torch.masked_fill(
-        codebooks[:, 1:, :], codebooks[:, :1, :] != semantic_id, CODEBOOK_PAD_TOKEN_ID
+        codebooks[:, 1:, :],
+        ~torch.isin(codebooks[:, :1, :], semantic_ids_tensor),
+        CODEBOOK_PAD_TOKEN_ID,
     )
 
     return codebooks
@@ -247,8 +251,8 @@ def decode_one_token_ar(
     model: DualARTransformer,
     x: torch.Tensor,
     input_pos: torch.Tensor,
+    semantic_ids: list,
     previous_tokens: torch.Tensor = None,
-    semantic_id: int = 0,
     **sampling_kwargs,
 ) -> torch.Tensor:
     x = model.forward_generate(x, input_pos)
@@ -261,21 +265,32 @@ def decode_one_token_ar(
     codebooks = [
         sample(
             x.logits,
-            previous_tokens=None,  # Disable repetition penalty for the token codebook
+            previous_tokens=(
+                previous_tokens[0] if previous_tokens is not None else None
+            ),  # Disable repetition penalty for the token codebook
             **sampling_kwargs_main,
         )[0]
     ]
 
-    x = x.hidden_states
+    hidden_states = x.hidden_states
 
     # Cleanup the cache
     for layer in model.fast_layers:
         layer.attention.kv_cache.k_cache.fill_(0)
         layer.attention.kv_cache.v_cache.fill_(0)
 
-    for codebook_idx in range(model.config.num_codebooks):
-        input_pos = torch.tensor([codebook_idx], device=x.device, dtype=torch.long)
-        logits = model.forward_generate_fast(x, input_pos)
+    input_pos = torch.tensor([0], device=hidden_states.device, dtype=torch.long)
+    model.forward_generate_fast(hidden_states, input_pos)
+    a = codebooks[0] - model.tokenizer.semantic_begin_id
+    a[a < 0] = 0
+    hidden_states = model.fast_embeddings(a)
+    codebooks.append(a)
+
+    for codebook_idx in range(1, model.config.num_codebooks):
+        input_pos = torch.tensor(
+            [codebook_idx], device=hidden_states.device, dtype=torch.long
+        )
+        logits = model.forward_generate_fast(hidden_states, input_pos)
         a = sample(
             logits,
             previous_tokens=(
@@ -285,14 +300,16 @@ def decode_one_token_ar(
             ),
             **sampling_kwargs,
         )[0]
-        x = model.fast_embeddings(a)
+        hidden_states = model.fast_embeddings(a)
         codebooks.append(a)
 
     codebooks = torch.stack(codebooks, dim=0)
-    codebooks[1:, :] = torch.masked_fill(
-        codebooks[1:, :], codebooks[:1, :] != semantic_id, CODEBOOK_PAD_TOKEN_ID
-    )
+    # semantic_ids_tensor = torch.tensor(semantic_ids, device=codebooks.device)
+    # codebooks[1:, :] = torch.masked_fill(
+    #     codebooks[1:, :], ~torch.isin(codebooks[:1, :], semantic_ids_tensor), CODEBOOK_PAD_TOKEN_ID
+    # )
 
+    # print(codebooks)
     return codebooks
 
 
@@ -337,9 +354,8 @@ def decode_n_tokens(
     cur_token: torch.Tensor,
     input_pos: torch.Tensor,
     num_new_tokens: int,
-    im_end_id: int = 4,
+    semantic_ids: list,
     decode_one_token=decode_one_token_naive,
-    semantic_id: int = 0,
     **sampling_kwargs,
 ):
     previous_tokens = torch.zeros(
@@ -368,7 +384,7 @@ def decode_n_tokens(
                 x=cur_token,
                 input_pos=input_pos,
                 previous_tokens=window,
-                semantic_id=semantic_id,
+                semantic_ids=semantic_ids,
                 **sampling_kwargs,
             )
 
@@ -378,7 +394,7 @@ def decode_n_tokens(
             model.config.num_codebooks + 1, -1
         )
 
-        if cur_token[0, 0, -1] == im_end_id:
+        if cur_token[0, 0, -1] == model.tokenizer.get_token_id(IM_END_TOKEN):
             break
 
     return previous_tokens[:, : i + 1]
@@ -391,7 +407,6 @@ def generate(
     model: NaiveTransformer,
     prompt: torch.Tensor,
     max_new_tokens: int,
-    im_end_id: int = 4,
     decode_one_token=decode_one_token_naive,
     **sampling_kwargs,
 ) -> torch.Tensor:
@@ -401,7 +416,10 @@ def generate(
 
     # create an empty tensor of the expected final shape and fill in the current tokens
     T = prompt.size(1)
-    semantic_id = model.tokenizer.convert_tokens_to_ids("<|semantic|>")
+    # semantic_id = model.tokenizer.convert_tokens_to_ids("<|semantic|>")
+    semantic_ids = [
+        model.tokenizer.get_token_id(f"<|semantic:{i}|>") for i in range(1024)
+    ]
 
     if max_new_tokens:
         if T + max_new_tokens > model.config.max_seq_len:
@@ -435,7 +453,7 @@ def generate(
         model,
         prompt.view(1, codebook_dim, -1),
         input_pos,
-        semantic_id=semantic_id,
+        semantic_ids=semantic_ids,
         **sampling_kwargs,
     )
     seq[:, T : T + 1] = next_token
@@ -446,9 +464,8 @@ def generate(
         next_token.view(1, codebook_dim, -1),
         input_pos,
         max_new_tokens - 1,
-        im_end_id=im_end_id,
         decode_one_token=decode_one_token,
-        semantic_id=semantic_id,
+        semantic_ids=semantic_ids,
         **sampling_kwargs,
     )
     # x = torch.cat(generated_tokens, dim=1)
@@ -463,8 +480,8 @@ def decode_n_tokens_agent(
     cur_token: torch.Tensor,
     input_pos: torch.Tensor,
     num_new_tokens: int,
+    semantic_ids: list,
     im_end_id: int = 4,
-    semantic_id: int = 32003,
     decode_one_token=decode_one_token_naive_agent,
     early_stop_threshold: float = 0.6,
     **sampling_kwargs,
@@ -495,7 +512,7 @@ def decode_n_tokens_agent(
                 x=cur_token,
                 input_pos=input_pos,
                 previous_tokens=window,
-                semantic_id=semantic_id,
+                semantic_ids=semantic_ids,
                 **sampling_kwargs,
             )
 
@@ -529,8 +546,8 @@ def generate_agent(
     model: BaseTransformer,
     prompt: torch.Tensor,
     max_new_tokens: int,
+    semantic_ids: list,
     im_end_id: int = 4,
-    semantic_id: int = 32003,
     decode_one_token=decode_one_token_naive_agent,
     num_samples: int = 1,
     early_stop_threshold: float = 0.6,
@@ -574,7 +591,7 @@ def generate_agent(
         model,
         prompt,
         input_pos,
-        semantic_id=semantic_id,
+        semantic_ids=semantic_ids,
         **sampling_kwargs,
     ).view(num_samples, codebook_dim, -1)
     yield next_token.cpu()
@@ -587,7 +604,7 @@ def generate_agent(
         input_pos,
         max_new_tokens - 1,
         im_end_id=im_end_id,
-        semantic_id=semantic_id,
+        semantic_ids=semantic_ids,
         decode_one_token=decode_one_token,
         early_stop_threshold=early_stop_threshold,
         **sampling_kwargs,
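`decode_n_tokens_agent` and `generate_agent` keep an `early_stop_threshold: float = 0.6` parameter alongside the new `semantic_ids` argument. A plausible reading is that, with several parallel samples, generation halts once a large-enough fraction has emitted `<|im_end|>`; the rule below is a sketch of that interpretation only — the repo's exact condition may differ:

```python
# Assumed early-stop rule: stop all samples once the finished fraction
# reaches the threshold. Illustrative only; not taken from the repo.
def should_early_stop(finished: list[bool], threshold: float = 0.6) -> bool:
    if not finished:
        return False
    return sum(finished) / len(finished) >= threshold

# With 3 samples and the default 0.6 threshold, 2 finished samples suffice.
two_of_three = should_early_stop([True, True, False])
one_of_three = should_early_stop([True, False, False])
```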
@@ -602,65 +619,63 @@ def encode_tokens(
     num_codebooks=4,
 ):
     string = clean_text(string)
-    string = f"<|im_start|>user\n{string}<|im_end|><|im_start|>assistant\n"
 
-    new_tokens = tokenizer.encode(
-        string,
-        add_special_tokens=False,
-        max_length=10**6,
-        truncation=False,
+    messages = []
+    messages.append(
+        Message(
+            role="user",
+            parts=[TextPart(text=string)],
+            cal_loss=False,
+        )
     )
-    tokens = torch.tensor([new_tokens], dtype=torch.int, device=device)
 
-    # Codebooks
-    zeros = (
-        torch.ones((num_codebooks, tokens.size(1)), dtype=torch.int, device=device)
-        * CODEBOOK_PAD_TOKEN_ID
-    )
-    prompt = torch.cat((tokens, zeros), dim=0)
+    if prompt_tokens is not None:
+        if prompt_tokens.ndim == 3:
+            assert (
+                prompt_tokens.shape[0] == 1
+            ), "3D prompt tokens should have shape (1, num_codebooks, seq_len)"
+            prompt_tokens = prompt_tokens[0]
 
-    if prompt_tokens is None:
-        return prompt
+        assert prompt_tokens.ndim == 2, "Prompt tokens should be 2D tensor"
 
-    # Get prompt tokens
-    if prompt_tokens.ndim == 3:
-        assert (
-            prompt_tokens.shape[0] == 1
-        ), f"3 dim prompt tokens should have shape (1, num_codebooks, seq_len)"
-        prompt_tokens = prompt_tokens[0]
+        if prompt_tokens.shape[0] > num_codebooks:
+            logger.warning(
+                f"Prompt tokens shape {prompt_tokens.shape} is larger than num_codebooks {num_codebooks}, getting first {num_codebooks} codebooks"
+            )
+            prompt_tokens = prompt_tokens[:num_codebooks]
 
-    assert prompt_tokens.ndim == 2
-    data = prompt_tokens + 1
+        vq_part = VQPart(codes=prompt_tokens.to(device))
 
-    if prompt_tokens.shape[0] > num_codebooks:
-        logger.warning(
-            f"Prompt tokens shape {prompt_tokens.shape} is larger than num_codebooks {num_codebooks}, getting first {num_codebooks} codebooks"
+        messages.append(
+            Message(
+                role="assistant",
+                parts=[TextPart(text="<|voice|>"), vq_part],
+                cal_loss=False,
+            )
+        )
+    else:
+        messages.append(
+            Message(
+                role="assistant",
+                parts=[TextPart(text="<|voice|>")],
+                cal_loss=False,
+                add_im_end=False,
+            )
         )
-        data = data[:num_codebooks]
-
-    # Add pad token for each codebook
-    data = torch.cat(
-        (data, torch.zeros((data.size(0), 1), dtype=torch.int, device=device)),
-        dim=1,
-    )
 
-    # Since 1.0, we use <|semantic|>
-    s0_token_id = tokenizer.convert_tokens_to_ids("<|semantic|>")
-    end_token_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
-    main_token_ids = (
-        torch.ones((1, data.size(1)), dtype=torch.int, device=device) * s0_token_id
+    conversation = Conversation(messages=messages)
+    # conversation.visualize(tokenizer)
+    encoded = conversation.encode_for_inference(
+        tokenizer=tokenizer,
+        num_codebooks=num_codebooks,
     )
-    main_token_ids[0, -1] = end_token_id
-
-    data = torch.cat((main_token_ids, data), dim=0)
-    prompt = torch.cat((prompt, data), dim=1)
 
-    return prompt
+    return encoded.to(device)
 
 
 def load_model(checkpoint_path, device, precision, compile=False, is_agent=False):
     model: Union[NaiveTransformer, DualARTransformer] = BaseTransformer.from_pretrained(
-        checkpoint_path, load_weights=True
+        checkpoint_path, load_weights=True, is_agent=is_agent
     )
 
     model = model.to(device=device, dtype=precision)
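The `encode_tokens` rewrite above drops the hand-assembled `<|im_start|>…` string and codebook padding in favor of `Message`/`Conversation` objects whose `encode_for_inference` lays out a `(1 + num_codebooks, seq_len)` grid: text ids on the main row with pad values below, VQ codes on the codebook rows. The sketch below mimics that layout with plain lists and simplified stand-ins for the real classes (which live in fish-speech's conversation module); the character-code "tokenization" and `-1` semantic marker are illustrative assumptions:

```python
from dataclasses import dataclass

# Simplified stand-ins for the TextPart/VQPart/Message classes in the diff.
@dataclass
class TextPart:
    text: str

@dataclass
class VQPart:
    codes: list  # (num_codebooks, seq_len) nested lists

@dataclass
class Message:
    role: str
    parts: list
    cal_loss: bool = False
    add_im_end: bool = True

CODEBOOK_PAD = 0

def encode_for_inference(messages, num_codebooks=4):
    """Toy layout: row 0 carries text ids (character codes here), rows
    1..num_codebooks carry VQ codes, padded wherever the input is text."""
    main, books = [], [[] for _ in range(num_codebooks)]
    for msg in messages:
        for part in msg.parts:
            if isinstance(part, TextPart):
                ids = [ord(c) for c in part.text]
                main.extend(ids)
                for row in books:
                    row.extend([CODEBOOK_PAD] * len(ids))
            else:  # VQPart: mark semantic positions on row 0, codes below
                seq_len = len(part.codes[0])
                main.extend([-1] * seq_len)
                for row, codes in zip(books, part.codes):
                    row.extend(codes)
    return [main] + books

messages = [
    Message(role="user", parts=[TextPart(text="hi")]),
    Message(role="assistant",
            parts=[TextPart(text="<|voice|>"), VQPart(codes=[[7, 8], [9, 10]])]),
]
encoded = encode_for_inference(messages, num_codebooks=2)
```

Every row of `encoded` has the same length, and the reference-audio codes land aligned under the semantic positions of the main row — which is the shape the new `encode_tokens` hands to `generate`.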
@@ -729,11 +744,26 @@ def generate_long(
 
     model_size = sum(p.numel() for p in model.parameters() if p.requires_grad)
     tokenizer = model.tokenizer
-    im_end_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
+    im_end_id = tokenizer.get_token_id("<|im_end|>")
 
     encoded = []
     texts = split_text(text, chunk_length) if iterative_prompt else [text]
-    encoded_prompts = []
+    encoded_prompts = [
+        Conversation(
+            messages=[
+                Message(
+                    role="system",
+                    parts=[TextPart(text="Speak out the provided text.")],
+                    cal_loss=False,
+                )
+            ]
+        )
+        .encode_for_inference(
+            tokenizer=tokenizer,
+            num_codebooks=model.config.num_codebooks,
+        )
+        .to(device)
+    ]
 
     if use_prompt:
         for idx, (t, c) in enumerate(zip(prompt_text, prompt_tokens)):
@@ -812,7 +842,6 @@ def generate_long(
                 model=model,
                 prompt=cat_encoded,
                 max_new_tokens=max_new_tokens,
-                im_end_id=im_end_id,
                 decode_one_token=decode_one_token,
                 temperature=temperature,
                 top_p=top_p,
@@ -842,12 +871,11 @@ def generate_long(
                 )
 
             # Put the generated tokens
-            # since there is <im_end> and <eos> tokens, we remove last 2 tokens
-            codes = y[1:, prompt_length:-1].clone()
-            codes = codes - 1
+            # since there is <im_end>, we remove last token
+            codes = y[1:, prompt_length + 1 :].clone()
             assert (codes >= 0).all(), f"Negative code found"
 
-            decoded = y[:, prompt_length:-1].clone()
+            decoded = y[:, prompt_length:].clone()
             # But for global encoding, we should keep the <im_end> token
 
             global_encoded.append(decoded)
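The extraction hunk above changes both the slice and the value shift: the old code took `y[1:, prompt_length:-1]` and subtracted 1 (codebook values were stored shifted by one so 0 stayed free as the pad id), while the v1.5 code takes `y[1:, prompt_length + 1:]` with no shift. A toy illustration of the two slicings on nested lists — the values and layout are illustrative, not the real token grid:

```python
# Toy stand-in for the generated tensor y: row 0 = main token ids,
# rows 1.. = codebook rows. Values are illustrative only.
y = [
    [101, 102, 103, 104, 105, 106],  # main row
    [  0,   0,  11,  12,  13,   0],  # codebook row 1
    [  0,   0,  21,  22,  23,   0],  # codebook row 2
]
prompt_length = 2

# Old (pre-v1.5) extraction: drop the trailing stop-token slot and undo
# the +1 shift that kept 0 free as the codebook pad id.
old_codes = [[v - 1 for v in row[prompt_length:-1]] for row in y[1:]]

# New (v1.5) extraction per the diff: skip one leading position instead,
# keep values as-is.
new_codes = [row[prompt_length + 1:] for row in y[1:]]
```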

+ 4 - 1
tools/schema.py

@@ -64,11 +64,14 @@ class ServeASRResponse(BaseModel):
 
 
 class ServeMessage(BaseModel):
-    role: Literal["system", "assistant", "user"]
+    role: Literal["system", "assistant", "user", "raw"]
     parts: list[ServeVQPart | ServeTextPart]
 
     def to_conversation_message(self):
         new_message = Message(role=self.role, parts=[])
+        if self.role == "assistant":
+            new_message.modality = "voice"
+
         for part in self.parts:
             if isinstance(part, ServeTextPart):
                 new_message.parts.append(TextPart(text=part.text))
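The `tools/schema.py` hunk widens the allowed roles to include `"raw"` and forces assistant turns into the voice modality when converting to a conversation `Message`. The sketch below reproduces that dispatch with plain dataclasses instead of the real pydantic models; the `modality` field's type and default are assumptions inferred from the diff:

```python
from dataclasses import dataclass, field
from typing import Optional

@dataclass
class Message:
    role: str
    parts: list = field(default_factory=list)
    modality: Optional[str] = None  # assumed text/voice flag on Message

@dataclass
class ServeMessage:
    role: str  # now "system" | "assistant" | "user" | "raw", per the diff
    parts: list = field(default_factory=list)

    def to_conversation_message(self) -> Message:
        new_message = Message(role=self.role, parts=[])
        # Mirrors the diff: assistant turns are always voice-modality.
        if self.role == "assistant":
            new_message.modality = "voice"
        new_message.parts = list(self.parts)
        return new_message

assistant_msg = ServeMessage(role="assistant", parts=["hello"]).to_conversation_message()
user_msg = ServeMessage(role="user", parts=["hi"]).to_conversation_message()
```

Only assistant messages pick up `modality="voice"`; system, user, and raw messages pass through with the modality left unset.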