فهرست منبع

Make WebUI and API code cleaner (+ 1.5 fixes) (#703)

* rename webui.py to run_webui.py

* remove unused imports

* remove unsued code

* move inference code and fix all warnings

* move web app code

* make code easier to read

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* remove unused function

* remove msgpack_api.py

* rename API files

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* finish updating the doc with the new file names

* finish updating the doc with the new file names

* fix CPU use in the API

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* refactor WebUIinference in a class with submodules

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* re-enable streaming in webui inference code

* generalize inference code in webui

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* minor fix

* make a unique inference engine class

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* minor fix

* cleaning code

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* implement new structure of the API (not working)

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* refactor API

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* minor fixes

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* reimplement chat endpoint

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Picus303 1 سال پیش
والد
کامیت
62eae262c2
45فایلهای تغییر یافته به همراه1959 افزوده شده و 1697 حذف شده
  1. 1 1
      .github/ISSUE_TEMPLATE/bug_report.yml
  2. 1 1
      docs/en/index.md
  3. 5 5
      docs/en/inference.md
  4. 1 1
      docs/en/start_agent.md
  5. 1 1
      docs/ja/index.md
  6. 4 4
      docs/ja/inference.md
  7. 1 1
      docs/ja/start_agent.md
  8. 1 1
      docs/ko/index.md
  9. 5 5
      docs/ko/inference.md
  10. 1 1
      docs/ko/start_agent.md
  11. 1 1
      docs/pt/index.md
  12. 4 4
      docs/pt/inference.md
  13. 1 1
      docs/pt/start_agent.md
  14. 1 1
      docs/zh/index.md
  15. 5 5
      docs/zh/inference.md
  16. 1 1
      docs/zh/start_agent.md
  17. 1 1
      entrypoint.sh
  18. 1 1
      fish_speech/webui/manage.py
  19. 1 1
      inference.ipynb
  20. 1 1
      start.bat
  21. 0 951
      tools/api.py
  22. 6 12
      tools/api_client.py
  23. 98 0
      tools/api_server.py
  24. 2 2
      tools/fish_e2e.py
  25. 193 0
      tools/inference_engine/__init__.py
  26. 128 0
      tools/inference_engine/reference_loader.py
  27. 42 0
      tools/inference_engine/utils.py
  28. 57 0
      tools/inference_engine/vq_manager.py
  29. 0 95
      tools/msgpack_api.py
  30. 101 0
      tools/run_webui.py
  31. 9 29
      tools/schema.py
  32. 57 0
      tools/server/agent/__init__.py
  33. 119 0
      tools/server/agent/generate.py
  34. 122 0
      tools/server/agent/generation_utils.py
  35. 72 0
      tools/server/agent/pre_generation_utils.py
  36. 75 0
      tools/server/api_utils.py
  37. 27 0
      tools/server/exception_handler.py
  38. 41 0
      tools/server/inference.py
  39. 119 0
      tools/server/model_manager.py
  40. 129 0
      tools/server/model_utils.py
  41. 246 0
      tools/server/views.py
  42. 0 570
      tools/webui.py
  43. 173 0
      tools/webui/__init__.py
  44. 91 0
      tools/webui/inference.py
  45. 14 0
      tools/webui/variables.py

+ 1 - 1
.github/ISSUE_TEMPLATE/bug_report.yml

@@ -45,7 +45,7 @@ body:
       description: |
         Include detailed steps, screenshots, and logs. Use the correct markdown syntax for code blocks.
       placeholder: |
-        1. Run the command `python -m tools.post_api -t "xxxxx"`
+        1. Run the command `python -m tools.api_client -t "xxxxx"`
         2. Observe the console output error: `ModuleNotFoundError: No module named 'pyaudio'` (with screenshots or logs will be better)
     validations:
       required: true

+ 1 - 1
docs/en/index.md

@@ -185,7 +185,7 @@ pip install -e .[stable]
 4. Configure environment variables and access WebUI
 
     In the terminal inside the docker container, enter `export GRADIO_SERVER_NAME="0.0.0.0"` to allow external access to the gradio service inside docker.
-    Then in the terminal inside the docker container, enter `python tools/webui.py` to start the WebUI service.
+    Then in the terminal inside the docker container, enter `python tools/run_webui.py` to start the WebUI service.
 
     If you're using WSL or MacOS, visit [http://localhost:7860](http://localhost:7860) to open the WebUI interface.
 

+ 5 - 5
docs/en/inference.md

@@ -67,7 +67,7 @@ python tools/vqgan/inference.py \
 We provide a HTTP API for inference. You can use the following command to start the server:
 
 ```bash
-python -m tools.api \
+python -m tools.api_server \
     --listen 0.0.0.0:8080 \
     --llama-checkpoint-path "checkpoints/fish-speech-1.5" \
     --decoder-checkpoint-path "checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" \
@@ -78,10 +78,10 @@ python -m tools.api \
 
 After that, you can view and test the API at http://127.0.0.1:8080/.
 
-Below is an example of sending a request using `tools/post_api.py`.
+Below is an example of sending a request using `tools/api_client.py`.
 
 ```bash
-python -m tools.post_api \
+python -m tools.api_client \
     --text "Text to be input" \
     --reference_audio "Path to reference audio" \
     --reference_text "Text content of the reference audio" \
@@ -93,7 +93,7 @@ The above command indicates synthesizing the desired audio according to the refe
 The following example demonstrates that you can use **multiple** reference audio paths and reference audio texts at once. Separate them with spaces in the command.
 
 ```bash
-python -m tools.post_api \
+python -m tools.api_client \
     --text "Text to input" \
     --reference_audio "reference audio path1" "reference audio path2" \
     --reference_text "reference audio text1" "reference audio text2"\
@@ -109,7 +109,7 @@ The currently supported reference audio has a maximum total duration of 90 secon
 
 
 !!! info 
-    To learn more about available parameters, you can use the command `python -m tools.post_api -h`
+    To learn more about available parameters, you can use the command `python -m tools.api_client -h`
 
 ## GUI Inference 
 [Download client](https://github.com/AnyaCoder/fish-speech-gui/releases)

+ 1 - 1
docs/en/start_agent.md

@@ -44,7 +44,7 @@ pip install -e .[stable]
 To build fish-agent, please use the command below under the main folder:
 
 ```bash
-python -m tools.api --llama-checkpoint-path checkpoints/fish-agent-v0.1-3b/ --mode agent --compile
+python -m tools.api_server --llama-checkpoint-path checkpoints/fish-agent-v0.1-3b/ --mode agent --compile
 ```
 
 The `--compile` args only support Python < 3.12 , which will greatly speed up the token generation.

+ 1 - 1
docs/ja/index.md

@@ -184,7 +184,7 @@ pip install -e .[stable]
 4. 環境変数の設定と WebUI へのアクセス
 
     Docker コンテナ内のターミナルで、`export GRADIO_SERVER_NAME="0.0.0.0"` と入力して、外部から Docker 内の gradio サービスにアクセスできるようにします。
-    次に、Docker コンテナ内のターミナルで `python tools/webui.py` と入力して WebUI サービスを起動します。
+    次に、Docker コンテナ内のターミナルで `python tools/run_webui.py` と入力して WebUI サービスを起動します。
 
     WSL または MacOS の場合は、[http://localhost:7860](http://localhost:7860) にアクセスして WebUI インターフェースを開くことができます。
 

+ 4 - 4
docs/ja/inference.md

@@ -67,7 +67,7 @@ python tools/vqgan/inference.py \
 推論のための HTTP API を提供しています。次のコマンドを使用してサーバーを起動できます:
 
 ```bash
-python -m tools.api \
+python -m tools.api_server \
     --listen 0.0.0.0:8080 \
     --llama-checkpoint-path "checkpoints/fish-speech-1.5" \
     --decoder-checkpoint-path "checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" \
@@ -78,10 +78,10 @@ python -m tools.api \
 
 その後、`http://127.0.0.1:8080/`で API を表示およびテストできます。
 
-以下は、`tools/post_api.py` を使用してリクエストを送信する例です。
+以下は、`tools/api_client.py` を使用してリクエストを送信する例です。
 
 ```bash
-python -m tools.post_api \
+python -m tools.api_client \
     --text "入力するテキスト" \
     --reference_audio "参照音声へのパス" \
     --reference_text "参照音声テキスト" \
@@ -91,7 +91,7 @@ python -m tools.post_api \
 上記のコマンドは、参照音声の情報に基づいて必要な音声を合成し、ストリーミング方式で返すことを示しています。
 
 !!! info
-    使用可能なパラメータの詳細については、コマンド` python -m tools.post_api -h `を使用してください
+    使用可能なパラメータの詳細については、コマンド` python -m tools.api_client -h `を使用してください
 
 ## WebUI 推論
 

+ 1 - 1
docs/ja/start_agent.md

@@ -47,7 +47,7 @@ pip install -e .[stable]
 fish-agentを構築するには、メインフォルダで以下のコマンドを使用してください:
 
 ```bash
-python -m tools.api --llama-checkpoint-path checkpoints/fish-agent-v0.1-3b/ --mode agent --compile
+python -m tools.api_server --llama-checkpoint-path checkpoints/fish-agent-v0.1-3b/ --mode agent --compile
 ```
 
 `--compile`引数はPython < 3.12でのみサポートされており、トークン生成を大幅に高速化します。

+ 1 - 1
docs/ko/index.md

@@ -185,7 +185,7 @@ pip install -e .[stable]
 4. 환경 변수 설정 및 WebUI 접근
 
     Docker 컨테이너 내부의 터미널에서 `export GRADIO_SERVER_NAME="0.0.0.0"`를 입력하여 Docker 내부에서 Gradio 서비스에 외부 접근을 허용합니다.
-    이후, 터미널에서 `python tools/webui.py` 명령어를 입력하여 WebUI 서비스를 시작합니다.
+    이후, 터미널에서 `python tools/run_webui.py` 명령어를 입력하여 WebUI 서비스를 시작합니다.
 
     WSL 또는 macOS를 사용하는 경우 [http://localhost:7860](http://localhost:7860)에서 WebUI 인터페이스를 열 수 있습니다.
 

+ 5 - 5
docs/ko/inference.md

@@ -67,7 +67,7 @@ python tools/vqgan/inference.py \
 추론을 위한 HTTP API를 제공하고 있습니다. 아래의 명령어로 서버를 시작할 수 있습니다:
 
 ```bash
-python -m tools.api \
+python -m tools.api_server \
     --listen 0.0.0.0:8080 \
     --llama-checkpoint-path "checkpoints/fish-speech-1.5" \
     --decoder-checkpoint-path "checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" \
@@ -78,10 +78,10 @@ python -m tools.api \
 
 이후, http://127.0.0.1:8080/ 에서 API를 확인하고 테스트할 수 있습니다.
 
-아래는 `tools/post_api.py`를 사용하여 요청을 보내는 예시입니다.
+아래는 `tools/api_client.py`를 사용하여 요청을 보내는 예시입니다.
 
 ```bash
-python -m tools.post_api \
+python -m tools.api_client \
     --text "입력할 텍스트" \
     --reference_audio "참고 음성 경로" \
     --reference_text "참고 음성의 텍스트 내용" \
@@ -93,7 +93,7 @@ python -m tools.post_api \
 다음 예시는 여러 개의 참고 음성 경로와 텍스트를 한꺼번에 사용할 수 있음을 보여줍니다. 명령에서 공백으로 구분하여 입력합니다.
 
 ```bash
-python -m tools.post_api \
+python -m tools.api_client \
     --text "입력할 텍스트" \
     --reference_audio "참고 음성 경로1" "참고 음성 경로2" \
     --reference_text "참고 음성 텍스트1" "참고 음성 텍스트2"\
@@ -107,7 +107,7 @@ python -m tools.post_api \
 `--reference_audio`와 `--reference_text` 대신에 `--reference_id`(하나만 사용 가능)를 사용할 수 있습니다. 프로젝트 루트 디렉토리에 `references/<your reference_id>` 폴더를 만들어 해당 음성과 주석 텍스트를 넣어야 합니다. 참고 음성은 최대 90초까지 지원됩니다.
 
 !!! info 
-    제공되는 파라미터는 `python -m tools.post_api -h`를 사용하여 확인할 수 있습니다.
+    제공되는 파라미터는 `python -m tools.api_client -h`를 사용하여 확인할 수 있습니다.
 
 ## GUI 추론 
 [클라이언트 다운로드](https://github.com/AnyaCoder/fish-speech-gui/releases)

+ 1 - 1
docs/ko/start_agent.md

@@ -47,7 +47,7 @@ pip install -e .[stable]
 fish-agent를 구축하려면 메인 폴더에서 아래 명령어를 사용하세요:
 
 ```bash
-python -m tools.api --llama-checkpoint-path checkpoints/fish-agent-v0.1-3b/ --mode agent --compile
+python -m tools.api_server --llama-checkpoint-path checkpoints/fish-agent-v0.1-3b/ --mode agent --compile
 ```
 
 `--compile` 인자는 Python < 3.12에서만 지원되며, 토큰 생성 속도를 크게 향상시킵니다.

+ 1 - 1
docs/pt/index.md

@@ -181,7 +181,7 @@ pip install -e .[stable]
 4. Configure as variáveis de ambiente e acesse a WebUI
 
     No terminal do contêiner Docker, digite `export GRADIO_SERVER_NAME="0.0.0.0"` para permitir o acesso externo ao serviço gradio dentro do Docker.
-    Em seguida, no terminal do contêiner Docker, digite `python tools/webui.py` para iniciar o serviço WebUI.
+    Em seguida, no terminal do contêiner Docker, digite `python tools/run_webui.py` para iniciar o serviço WebUI.
 
     Se estiver usando WSL ou MacOS, acesse [http://localhost:7860](http://localhost:7860) para abrir a interface WebUI.
 

+ 4 - 4
docs/pt/inference.md

@@ -67,7 +67,7 @@ python tools/vqgan/inference.py \
 Fornecemos uma API HTTP para inferência. O seguinte comando pode ser usado para iniciar o servidor:
 
 ```bash
-python -m tools.api \
+python -m tools.api_server \
     --listen 0.0.0.0:8080 \
     --llama-checkpoint-path "checkpoints/fish-speech-1.5" \
     --decoder-checkpoint-path "checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" \
@@ -78,10 +78,10 @@ python -m tools.api \
 
 Depois disso, é possível visualizar e testar a API em http://127.0.0.1:8080/.
 
-Abaixo está um exemplo de envio de uma solicitação usando `tools/post_api.py`.
+Abaixo está um exemplo de envio de uma solicitação usando `tools/api_client.py`.
 
 ```bash
-python -m tools.post_api \
+python -m tools.api_client \
     --text "Texto a ser inserido" \
     --reference_audio "Caminho para o áudio de referência" \
     --reference_text "Conteúdo de texto do áudio de referência" \
@@ -91,7 +91,7 @@ python -m tools.post_api \
 O comando acima indica a síntese do áudio desejada de acordo com as informações do áudio de referência e a retorna em modo de streaming.
 
 !!! info
-    Para aprender mais sobre parâmetros disponíveis, você pode usar o comando `python -m tools.post_api -h`
+    Para aprender mais sobre parâmetros disponíveis, você pode usar o comando `python -m tools.api_client -h`
 
 ## Inferência por WebUI
 

+ 1 - 1
docs/pt/start_agent.md

@@ -47,7 +47,7 @@ pip install -e .[stable]
 Para construir o fish-agent, use o comando abaixo na pasta principal:
 
 ```bash
-python -m tools.api --llama-checkpoint-path checkpoints/fish-agent-v0.1-3b/ --mode agent --compile
+python -m tools.api_server --llama-checkpoint-path checkpoints/fish-agent-v0.1-3b/ --mode agent --compile
 ```
 
 O argumento `--compile` só suporta Python < 3.12, o que aumentará muito a velocidade de geração de tokens.

+ 1 - 1
docs/zh/index.md

@@ -188,7 +188,7 @@ pip install -e .[stable]
 4. 配置环境变量,访问 WebUI
 
     在 docker 容器内的终端,输入 `export GRADIO_SERVER_NAME="0.0.0.0"` ,从而让外部可以访问 docker 内的 gradio 服务。
-    接着在 docker 容器内的终端,输入 `python tools/webui.py` 即可开启 WebUI 服务。
+    接着在 docker 容器内的终端,输入 `python tools/run_webui.py` 即可开启 WebUI 服务。
 
     如果是 WSL 或者是 MacOS ,访问 [http://localhost:7860](http://localhost:7860) 即可打开 WebUI 界面。
 

+ 5 - 5
docs/zh/inference.md

@@ -73,7 +73,7 @@ python tools/vqgan/inference.py \
 运行以下命令来启动 HTTP 服务:
 
 ```bash
-python -m tools.api \
+python -m tools.api_server \
     --listen 0.0.0.0:8080 \
     --llama-checkpoint-path "checkpoints/fish-speech-1.5" \
     --decoder-checkpoint-path "checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" \
@@ -88,10 +88,10 @@ HF_ENDPOINT=https://hf-mirror.com python -m ...(同上)
 
 随后, 你可以在 `http://127.0.0.1:8080/` 中查看并测试 API.
 
-下面是使用`tools/post_api.py`发送请求的示例。
+下面是使用`tools/api_client.py`发送请求的示例。
 
 ```bash
-python -m tools.post_api \
+python -m tools.api_client \
     --text "要输入的文本" \
     --reference_audio "参考音频路径" \
     --reference_text "参考音频的文本内容" \
@@ -102,7 +102,7 @@ python -m tools.post_api \
 
 下面的示例展示了, 可以一次使用**多个** `参考音频路径` 和 `参考音频的文本内容`。在命令里用空格隔开即可。
 ```bash
-python -m tools.post_api \
+python -m tools.api_client \
     --text "要输入的文本" \
     --reference_audio "参考音频路径1" "参考音频路径2" \
     --reference_text "参考音频的文本内容1" "参考音频的文本内容2"\
@@ -117,7 +117,7 @@ python -m tools.post_api \
 里面放上任意对音频与标注文本。 目前支持的参考音频最多加起来总时长90s。
 
 !!! info
-    要了解有关可用参数的更多信息,可以使用命令`python -m tools.post_api -h`
+    要了解有关可用参数的更多信息,可以使用命令`python -m tools.api_client -h`
 
 ## GUI 推理 
 [下载客户端](https://github.com/AnyaCoder/fish-speech-gui/releases)

+ 1 - 1
docs/zh/start_agent.md

@@ -49,7 +49,7 @@ pip install -e .[stable]
 你需要使用以下指令来构建 fish-agent
 
 ```bash
-python -m tools.api --llama-checkpoint-path checkpoints/fish-agent-v0.1-3b/ --mode agent --compile
+python -m tools.api_server --llama-checkpoint-path checkpoints/fish-agent-v0.1-3b/ --mode agent --compile
 ```
 
 `--compile`只能在小于 3.12 版本的 Python 使用,这个功能可以极大程度上提高生成速度。

+ 1 - 1
entrypoint.sh

@@ -7,4 +7,4 @@ if [ "${CUDA_ENABLED}" != "true" ]; then
     DEVICE="--device cpu"
 fi
 
-exec python tools/webui.py ${DEVICE}
+exec python tools/run_webui.py ${DEVICE}

+ 1 - 1
fish_speech/webui/manage.py

@@ -176,7 +176,7 @@ def change_infer(
         p_infer = subprocess.Popen(
             [
                 PYTHON,
-                "tools/webui.py",
+                "tools/run_webui.py",
                 "--decoder-checkpoint-path",
                 infer_decoder_model,
                 "--decoder-config-name",

+ 1 - 1
inference.ipynb

@@ -83,7 +83,7 @@
    },
    "outputs": [],
    "source": [
-    "!python tools/webui.py \\\n",
+    "!python tools/run_webui.py \\\n",
     "    --llama-checkpoint-path checkpoints/fish-speech-1.4 \\\n",
     "    --decoder-checkpoint-path checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth \\\n",
     "    # --compile"

+ 1 - 1
start.bat

@@ -82,7 +82,7 @@ if not "!flags!"=="" set "flags=!flags:~1!"
 echo Debug: flags = !flags!
 
 if "!mode!"=="api" (
-    %PYTHON_CMD% -m tools.api !flags!
+    %PYTHON_CMD% -m tools.api_server !flags!
 ) else if "!mode!"=="infer" (
     %PYTHON_CMD% -m tools.webui !flags!
 )

+ 0 - 951
tools/api.py

@@ -1,951 +0,0 @@
-import io
-import json
-import os
-import queue
-import re
-import time
-import traceback
-import wave
-from argparse import ArgumentParser
-from http import HTTPStatus
-from pathlib import Path
-from typing import Annotated, Any
-
-import librosa
-import numpy as np
-import ormsgpack
-import pyrootutils
-import soundfile as sf
-import torch
-import torchaudio
-from baize.datastructures import ContentType
-from kui.asgi import (
-    Body,
-    FactoryClass,
-    HTTPException,
-    HttpRequest,
-    HttpView,
-    JSONResponse,
-    Kui,
-    OpenAPI,
-    StreamResponse,
-    request,
-)
-from kui.asgi.routing import MultimethodRoutes
-from loguru import logger
-
-pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
-import struct
-from threading import Lock
-
-import httpx
-from cachetools import LRUCache, cached
-from funasr import AutoModel
-from silero_vad import get_speech_timestamps, load_silero_vad
-
-from fish_speech.models.text2semantic.llama import BaseModelArgs
-
-# from fish_speech.models.vqgan.lit_module import VQGAN
-from fish_speech.models.vqgan.modules.firefly import FireflyArchitecture
-from fish_speech.text.chn_text_norm.text import Text as ChnNormedText
-
-# from fish_speech.conversation import IM_END_TOKEN, SEMANTIC_TOKEN
-from fish_speech.tokenizer import IM_END_TOKEN, FishTokenizer
-from fish_speech.utils import autocast_exclude_mps, set_seed
-from tools.file import AUDIO_EXTENSIONS, audio_to_bytes, list_files, read_ref_text
-from tools.llama.generate import (
-    GenerateRequest,
-    GenerateResponse,
-    WrappedGenerateResponse,
-    launch_thread_safe_queue,
-    launch_thread_safe_queue_agent,
-)
-from tools.schema import (
-    GLOBAL_NUM_SAMPLES,
-    ASRPackRequest,
-    ServeASRRequest,
-    ServeASRResponse,
-    ServeASRSegment,
-    ServeAudioPart,
-    ServeForwardMessage,
-    ServeMessage,
-    ServeRequest,
-    ServeResponse,
-    ServeStreamDelta,
-    ServeStreamResponse,
-    ServeTextPart,
-    ServeTimedASRResponse,
-    ServeTTSRequest,
-    ServeVQGANDecodeRequest,
-    ServeVQGANDecodeResponse,
-    ServeVQGANEncodeRequest,
-    ServeVQGANEncodeResponse,
-    ServeVQPart,
-)
-from tools.vqgan.inference import load_model as load_decoder_model
-
-global_lock = Lock()
-
-# Whether to disable keepalive (which is helpful if the server is in the same cluster)
-DISABLE_KEEPALIVE = os.getenv("DISABLE_KEEPALIVE", "false").lower() == "true"
-async_client = httpx.AsyncClient(
-    timeout=120, limits=httpx.Limits(keepalive_expiry=0 if DISABLE_KEEPALIVE else None)
-)
-backends = torchaudio.list_audio_backends()
-
-if "ffmpeg" in backends:
-    backend = "ffmpeg"
-else:
-    backend = "soundfile"
-
-
-def wav_chunk_header(sample_rate=44100, bit_depth=16, channels=1):
-    buffer = io.BytesIO()
-
-    with wave.open(buffer, "wb") as wav_file:
-        wav_file.setnchannels(channels)
-        wav_file.setsampwidth(bit_depth // 8)
-        wav_file.setframerate(sample_rate)
-
-    wav_header_bytes = buffer.getvalue()
-    buffer.close()
-    return wav_header_bytes
-
-
-# Define utils for web server
-async def http_execption_handler(exc: HTTPException):
-    return JSONResponse(
-        dict(
-            statusCode=exc.status_code,
-            message=exc.content,
-            error=HTTPStatus(exc.status_code).phrase,
-        ),
-        exc.status_code,
-        exc.headers,
-    )
-
-
-async def other_exception_handler(exc: "Exception"):
-    traceback.print_exc()
-
-    status = HTTPStatus.INTERNAL_SERVER_ERROR
-    return JSONResponse(
-        dict(statusCode=status, message=str(exc), error=status.phrase),
-        status,
-    )
-
-
-def load_audio(reference_audio, sr):
-    if len(reference_audio) > 255 or not Path(reference_audio).exists():
-        audio_data = reference_audio
-        reference_audio = io.BytesIO(audio_data)
-
-    waveform, original_sr = torchaudio.load(reference_audio, backend=backend)
-
-    if waveform.shape[0] > 1:
-        waveform = torch.mean(waveform, dim=0, keepdim=True)
-
-    if original_sr != sr:
-        resampler = torchaudio.transforms.Resample(orig_freq=original_sr, new_freq=sr)
-        waveform = resampler(waveform)
-
-    audio = waveform.squeeze().numpy()
-    return audio
-
-
-def encode_reference(*, decoder_model, reference_audio, enable_reference_audio):
-    if enable_reference_audio and reference_audio is not None:
-        # Load audios, and prepare basic info here
-        reference_audio_content = load_audio(
-            reference_audio, decoder_model.spec_transform.sample_rate
-        )
-
-        audios = torch.from_numpy(reference_audio_content).to(decoder_model.device)[
-            None, None, :
-        ]
-        audio_lengths = torch.tensor(
-            [audios.shape[2]], device=decoder_model.device, dtype=torch.long
-        )
-        logger.info(
-            f"Loaded audio with {audios.shape[2] / decoder_model.spec_transform.sample_rate:.2f} seconds"
-        )
-
-        # VQ Encoder
-        if isinstance(decoder_model, FireflyArchitecture):
-            prompt_tokens = decoder_model.encode(audios, audio_lengths)[0][0]
-
-        logger.info(f"Encoded prompt: {prompt_tokens.shape}")
-    else:
-        prompt_tokens = None
-        logger.info("No reference audio provided")
-
-    return prompt_tokens
-
-
-def decode_vq_tokens(
-    *,
-    decoder_model,
-    codes,
-):
-    feature_lengths = torch.tensor([codes.shape[1]], device=decoder_model.device)
-    logger.info(f"VQ features: {codes.shape}")
-
-    if isinstance(decoder_model, FireflyArchitecture):
-        # VQGAN Inference
-        return decoder_model.decode(
-            indices=codes[None],
-            feature_lengths=feature_lengths,
-        )[0].squeeze()
-
-    raise ValueError(f"Unknown model type: {type(decoder_model)}")
-
-
-routes = MultimethodRoutes(base_class=HttpView)
-
-
-def get_content_type(audio_format):
-    if audio_format == "wav":
-        return "audio/wav"
-    elif audio_format == "flac":
-        return "audio/flac"
-    elif audio_format == "mp3":
-        return "audio/mpeg"
-    else:
-        return "application/octet-stream"
-
-
-@torch.no_grad()
-@torch.autocast(device_type="cuda", dtype=torch.half)
-def batch_encode(model, audios: list[bytes | torch.Tensor]):
-    audios = [
-        (
-            torch.from_numpy(
-                librosa.load(io.BytesIO(audio), sr=model.spec_transform.sample_rate)[0]
-            )[None]
-            if isinstance(audio, bytes)
-            else audio
-        )
-        for audio in audios
-    ]
-
-    # if any(audio.shape[-1] > model.spec_transform.sample_rate * 120 for audio in audios):
-    #     raise ValueError("Single audio length is too long (>120s)")
-
-    max_length = max(audio.shape[-1] for audio in audios)
-    print(f"Encode max length: {max_length / model.spec_transform.sample_rate:.2f}s")
-
-    lengths = torch.tensor([audio.shape[-1] for audio in audios], device=model.device)
-    max_length = lengths.max().item()
-    padded = torch.stack(
-        [
-            torch.nn.functional.pad(audio, (0, max_length - audio.shape[-1]))
-            for audio in audios
-        ]
-    ).to(model.device)
-
-    features, feature_lengths = model.encode(padded, audio_lengths=lengths)
-    features, feature_lengths = features.cpu(), feature_lengths.cpu()
-
-    return [feature[..., :length] for feature, length in zip(features, feature_lengths)]
-
-
-@cached(
-    cache=LRUCache(maxsize=10000),
-    key=lambda model, audios: (model.device, tuple(audios)),
-)
-def cached_vqgan_batch_encode(model, audios: list[bytes]):
-    return batch_encode(model, audios)
-
-
-@routes.http.post("/v1/vqgan/encode")
-def api_vqgan_encode(payload: Annotated[ServeVQGANEncodeRequest, Body(exclusive=True)]):
-
-    start_time = time.time()
-    tokens = cached_vqgan_batch_encode(decoder_model, payload.audios)
-    logger.info(f"[EXEC] VQGAN encode time: {(time.time() - start_time) * 1000:.2f}ms")
-
-    return ormsgpack.packb(
-        ServeVQGANEncodeResponse(tokens=[i.tolist() for i in tokens]),
-        option=ormsgpack.OPT_SERIALIZE_PYDANTIC,
-    )
-
-
-@torch.no_grad()
-@torch.autocast(device_type="cuda", dtype=torch.half)
-def vqgan_decode(model, features):
-    lengths = torch.tensor(
-        [feature.shape[-1] for feature in features], device=model.device
-    )
-    max_length = lengths.max().item()
-    padded = torch.stack(
-        [
-            torch.nn.functional.pad(feature, (0, max_length - feature.shape[-1]))
-            for feature in features
-        ]
-    ).to(model.device)
-
-    # If bs too large, we do micro batch decode
-    audios, audio_lengths = [], []
-    for i in range(0, padded.shape[0], 8):
-        audio, audio_length = model.decode(
-            padded[i : i + 8], feature_lengths=lengths[i : i + 8]
-        )
-        audios.append(audio)
-        audio_lengths.append(audio_length)
-    audios = torch.cat(audios, dim=0)
-    audio_lengths = torch.cat(audio_lengths, dim=0)
-    audios, audio_lengths = audios.cpu(), audio_lengths.cpu()
-
-    return [audio[..., :length].numpy() for audio, length in zip(audios, audio_lengths)]
-
-
-@routes.http.post("/v1/vqgan/decode")
-def api_vqgan_decode(payload: Annotated[ServeVQGANDecodeRequest, Body(exclusive=True)]):
-    tokens = [torch.tensor(token, dtype=torch.int) for token in payload.tokens]
-    start_time = time.time()
-    audios = vqgan_decode(decoder_model, tokens)
-    logger.info(f"[EXEC] VQGAN decode time: {(time.time() - start_time) * 1000:.2f}ms")
-    audios = [audio.astype(np.float16).tobytes() for audio in audios]
-    return ormsgpack.packb(
-        ServeVQGANDecodeResponse(audios=audios), option=ormsgpack.OPT_SERIALIZE_PYDANTIC
-    )
-
-
-@torch.no_grad()
-def batch_asr(model, audios, sr, language="auto"):
-    resampled_audios = []
-    for audio in audios:
-        audio = torchaudio.functional.resample(audio, sr, 16000)
-        assert audio.ndim == 1
-        resampled_audios.append(audio)
-
-    with global_lock:
-        res = model.generate(
-            input=resampled_audios,
-            batch_size=len(resampled_audios),
-            language=language,
-            use_itn=True,
-        )
-
-    results = []
-    for r, audio in zip(res, audios):
-        text = r["text"]
-        text = re.sub(r"<\|.*?\|>", "", text)
-        duration = len(audio) / sr * 1000
-        huge_gap = False
-
-        if "timestamp" in r and len(r["timestamp"]) > 2:
-            for timestamp_a, timestamp_b in zip(
-                r["timestamp"][:-1], r["timestamp"][1:]
-            ):
-                # If there is a gap of more than 5 seconds, we consider it as a huge gap
-                if timestamp_b[0] - timestamp_a[1] > 5000:
-                    huge_gap = True
-                    break
-
-            # Doesn't make sense to have a huge gap at the end
-            if duration - r["timestamp"][-1][1] > 3000:
-                huge_gap = True
-
-        results.append(
-            {
-                "text": text,
-                "duration": duration,
-                "huge_gap": huge_gap,
-            }
-        )
-
-    return results
-
-
-@routes.http.post("/v1/asr")
-def api_invoke_asr(payload: Annotated[ServeASRRequest, Body(exclusive=True)]):
-    start_time = time.time()
-    audios = [np.frombuffer(audio, dtype=np.float16) for audio in payload.audios]
-    audios = [torch.from_numpy(audio).float() for audio in audios]
-
-    if any(audios.shape[-1] >= 30 * payload.sample_rate for audios in audios):
-        raise HTTPException(status_code=400, detail="Audio length is too long")
-
-    transcriptions = batch_asr(
-        asr_model, audios=audios, sr=payload.sample_rate, language=payload.language
-    )
-    logger.info(f"[EXEC] ASR time: {(time.time() - start_time) * 1000:.2f}ms")
-
-    return ormsgpack.packb(
-        ServeASRResponse(transcriptions=transcriptions),
-        option=ormsgpack.OPT_SERIALIZE_PYDANTIC,
-    )
-
-
-from fish_speech.conversation import Conversation, Message
-
-
-def execute_request(
-    input_queue: queue.Queue,
-    tokenizer: FishTokenizer,
-    config: BaseModelArgs,
-    request: ServeRequest,
-    device: str = "cuda:0",
-):
-
-    im_end_id = tokenizer.get_token_id(IM_END_TOKEN)
-    messages = []
-    for message in request.messages:
-        messages.append(message.to_conversation_message())
-
-    assert len(messages) >= 1, "At least one message is required"
-    # assert messages[-1].role == "user", "The last message must be from the user"
-
-    if messages[-1].role == "user":
-        messages.append(
-            Message(role="assistant", parts=[], add_im_end=False, modality="voice")
-        )
-    elif messages[-1].role == "raw":
-        messages[-1].add_im_start = False
-        messages[-1].add_im_end = False
-        messages[-1].modality = "voice"
-    else:
-        assert (
-            messages[-1].role == "assistant"
-        ), "The last message must be from the assistant"
-        messages[-1].add_im_end = False
-
-    conv = Conversation(messages=messages)
-
-    # conv.visualize(tokenizer)
-    prompt = conv.encode_for_inference(
-        tokenizer=tokenizer, num_codebooks=config.num_codebooks
-    ).to(device)
-
-    if request.streaming:
-        for i in range(request.num_samples):
-            yield ServeStreamResponse(
-                sample_id=i,
-                delta=ServeStreamDelta(
-                    role="assistant",
-                ),
-            )
-
-    req = {
-        "prompt": prompt,
-        "max_new_tokens": request.max_new_tokens,
-        "im_end_id": im_end_id,
-        "temperature": request.temperature,
-        "top_p": request.top_p,
-        "repetition_penalty": request.repetition_penalty,
-        "num_samples": request.num_samples,
-        "early_stop_threshold": request.early_stop_threshold,
-    }
-
-    start = time.time()
-    response_queue = queue.Queue()
-    input_queue.put(GenerateRequest(req, response_queue))
-
-    # Decoding
-    decode_buffer = [[] for _ in range(request.num_samples)]
-    parts = [[] for _ in range(request.num_samples)]
-
-    def send_reset_buffer(sample_id):
-        nonlocal decode_buffer
-        if len(decode_buffer[sample_id]) == 0:
-            return
-
-        decoded = tokenizer.decode(decode_buffer[sample_id])
-        part = ServeTextPart(text=decoded)
-
-        if request.streaming:
-            yield ServeStreamResponse(delta=ServeStreamDelta(part=part))
-        else:
-            parts[sample_id].append(part)
-
-        decode_buffer[sample_id] = []
-
-    # Decode process
-    finished = [False for _ in range(request.num_samples)]
-    stats = {}
-    idx = 0
-    while True:
-        response = response_queue.get()
-
-        if response in ["stop", "error"]:
-            break
-
-        for sample_id, tokens in enumerate(response):
-            if finished[sample_id]:
-                continue
-
-            if tokens[0] == im_end_id:
-                finished[sample_id] = True
-                if request.streaming:
-                    yield from send_reset_buffer(sample_id)
-                    yield ServeStreamResponse(
-                        sample_id=sample_id,
-                        finish_reason="stop",
-                        stats=stats,
-                    )
-                continue
-
-            is_semantic = (
-                tokenizer.semantic_begin_id <= tokens[0] <= tokenizer.semantic_end_id
-            )
-            if is_semantic and request.streaming:
-                yield from send_reset_buffer(sample_id)
-                # Streaming vq
-                _tokens = tokens[1:].clone()
-
-                if config.share_codebook_embeddings is False:
-                    for i in range(len(_tokens)):
-                        _tokens[i] -= config.codebook_size * i
-
-                yield ServeStreamResponse(
-                    sample_id=sample_id,
-                    delta=ServeStreamDelta(part=ServeVQPart(codes=_tokens.tolist())),
-                )
-                continue
-
-            # Not streaming vq
-            if is_semantic:
-                yield from send_reset_buffer(sample_id)
-                # None streaming vq
-                if len(parts[sample_id]) == 0 or not isinstance(
-                    parts[sample_id][-1], ServeVQPart
-                ):
-                    _tokens = tokens[1:].clone()
-
-                    if config.share_codebook_embeddings is False:
-                        for i in range(len(_tokens)):
-                            _tokens[i] -= config.codebook_size * i
-
-                    parts[sample_id].append(ServeVQPart(codes=_tokens.tolist()))
-                else:
-                    for codebook_id, value in enumerate(tokens[1:, :]):
-                        val = value.item()
-                        if config.share_codebook_embeddings is False:
-                            val -= config.codebook_size * codebook_id
-
-                        parts[sample_id][-1].codes[codebook_id].append(val)
-                continue
-
-            if not is_semantic:
-                # Stream text decode is not supported now
-                decode_buffer[sample_id].append(tokens[0, 0])
-
-        if idx == 0:
-            stats["time_to_first_token"] = (time.time() - start) * 1000
-
-        idx += 1
-
-    for sample_id in range(request.num_samples):
-        yield from send_reset_buffer(sample_id)
-
-    stats["total_time"] = (time.time() - start) * 1000
-    stats["total_tokens"] = idx
-
-    if request.streaming:
-        for sample_id in range(request.num_samples):
-            if finished[sample_id]:
-                continue
-            yield ServeStreamResponse(
-                finish_reason=response, stats=stats, sample_id=sample_id
-            )
-        return
-
-    yield ServeResponse(
-        messages=[
-            ServeMessage(role="assistant", parts=parts[i])
-            for i in range(request.num_samples)
-        ],
-        finish_reason=response,
-        stats=stats,
-    )
-
-
-@routes.http.post("/v1/chat")
-def api_invoke_chat(
-    req: Annotated[ServeRequest, Body(exclusive=True)],
-):
-    """
-    Invoke model and generate audio
-    """
-
-    # This makes torch compile happy
-    assert (
-        req.num_samples == GLOBAL_NUM_SAMPLES
-    ), f"num_samples must be {GLOBAL_NUM_SAMPLES}"
-
-    content_type = request.headers.get("Content-Type", "application/json")
-    json_mode = "application/json" in content_type
-
-    async def wrapped_generator():
-        generator = execute_request(llama_queue, tokenizer, config, req, args.device)
-
-        for i in generator:
-            if json_mode:
-                body = i.model_dump_json().encode("utf-8")
-                yield b"data: " + body + b"\n\n"
-            else:
-                body = ormsgpack.packb(i, option=ormsgpack.OPT_SERIALIZE_PYDANTIC)
-                yield struct.pack("I", len(body)) + body
-
-    # Naive mode
-    if req.streaming is False:
-        result = next(execute_request(llama_queue, tokenizer, config, req, args.device))
-
-        if json_mode:
-            return JSONResponse(result.model_dump())
-        else:
-            return ormsgpack.packb(result, option=ormsgpack.OPT_SERIALIZE_PYDANTIC)
-
-    return StreamResponse(
-        iterable=wrapped_generator(), content_type="text/event-stream"
-    )
-
-
-@torch.inference_mode()
-def inference(req: ServeTTSRequest):
-
-    idstr: str | None = req.reference_id
-    if idstr is not None:
-        ref_folder = Path("references") / idstr
-        ref_folder.mkdir(parents=True, exist_ok=True)
-        ref_audios = list_files(
-            ref_folder, AUDIO_EXTENSIONS, recursive=True, sort=False
-        )
-
-        if req.use_memory_cache == "never" or (
-            req.use_memory_cache == "on-demand" and len(prompt_tokens) == 0
-        ):
-            prompt_tokens = [
-                encode_reference(
-                    decoder_model=decoder_model,
-                    reference_audio=audio_to_bytes(str(ref_audio)),
-                    enable_reference_audio=True,
-                )
-                for ref_audio in ref_audios
-            ]
-            prompt_texts = [
-                read_ref_text(str(ref_audio.with_suffix(".lab")))
-                for ref_audio in ref_audios
-            ]
-        else:
-            logger.info("Use same references")
-
-    else:
-        # Parse reference audio aka prompt
-        refs = req.references
-
-        if req.use_memory_cache == "never" or (
-            req.use_memory_cache == "on-demand" and len(prompt_tokens) == 0
-        ):
-            prompt_tokens = [
-                encode_reference(
-                    decoder_model=decoder_model,
-                    reference_audio=ref.audio,
-                    enable_reference_audio=True,
-                )
-                for ref in refs
-            ]
-            prompt_texts = [ref.text for ref in refs]
-        else:
-            logger.info("Use same references")
-
-    if req.seed is not None:
-        set_seed(req.seed)
-        logger.warning(f"set seed: {req.seed}")
-
-    # LLAMA Inference
-    request = dict(
-        device=decoder_model.device,
-        max_new_tokens=req.max_new_tokens,
-        text=(
-            req.text
-            if not req.normalize
-            else ChnNormedText(raw_text=req.text).normalize()
-        ),
-        top_p=req.top_p,
-        repetition_penalty=req.repetition_penalty,
-        temperature=req.temperature,
-        compile=args.compile,
-        iterative_prompt=req.chunk_length > 0,
-        chunk_length=req.chunk_length,
-        max_length=4096,
-        prompt_tokens=prompt_tokens,
-        prompt_text=prompt_texts,
-    )
-
-    response_queue = queue.Queue()
-    llama_queue.put(
-        GenerateRequest(
-            request=request,
-            response_queue=response_queue,
-        )
-    )
-
-    if req.streaming:
-        yield wav_chunk_header()
-
-    segments = []
-    while True:
-        result: WrappedGenerateResponse = response_queue.get()
-        if result.status == "error":
-            raise result.response
-            break
-
-        result: GenerateResponse = result.response
-        if result.action == "next":
-            break
-
-        with autocast_exclude_mps(
-            device_type=decoder_model.device.type, dtype=args.precision
-        ):
-            fake_audios = decode_vq_tokens(
-                decoder_model=decoder_model,
-                codes=result.codes,
-            )
-
-        fake_audios = fake_audios.float().cpu().numpy()
-
-        if req.streaming:
-            yield (fake_audios * 32768).astype(np.int16).tobytes()
-        else:
-            segments.append(fake_audios)
-
-    if req.streaming:
-        return
-
-    if len(segments) == 0:
-        raise HTTPException(
-            HTTPStatus.INTERNAL_SERVER_ERROR,
-            content="No audio generated, please check the input text.",
-        )
-
-    fake_audios = np.concatenate(segments, axis=0)
-    yield fake_audios
-
-
-async def inference_async(req: ServeTTSRequest):
-    for chunk in inference(req):
-        yield chunk
-
-
-async def buffer_to_async_generator(buffer):
-    yield buffer
-
-
-@routes.http.post("/v1/tts")
-async def api_invoke_model(
-    req: Annotated[ServeTTSRequest, Body(exclusive=True)],
-):
-    """
-    Invoke model and generate audio
-    """
-
-    if args.max_text_length > 0 and len(req.text) > args.max_text_length:
-        raise HTTPException(
-            HTTPStatus.BAD_REQUEST,
-            content=f"Text is too long, max length is {args.max_text_length}",
-        )
-
-    if req.streaming and req.format != "wav":
-        raise HTTPException(
-            HTTPStatus.BAD_REQUEST,
-            content="Streaming only supports WAV format",
-        )
-
-    if req.streaming:
-        return StreamResponse(
-            iterable=inference_async(req),
-            headers={
-                "Content-Disposition": f"attachment; filename=audio.{req.format}",
-            },
-            content_type=get_content_type(req.format),
-        )
-    else:
-        fake_audios = next(inference(req))
-        buffer = io.BytesIO()
-        sf.write(
-            buffer,
-            fake_audios,
-            decoder_model.spec_transform.sample_rate,
-            format=req.format,
-        )
-
-        return StreamResponse(
-            iterable=buffer_to_async_generator(buffer.getvalue()),
-            headers={
-                "Content-Disposition": f"attachment; filename=audio.{req.format}",
-            },
-            content_type=get_content_type(req.format),
-        )
-
-
-@routes.http.post("/v1/health")
-async def api_health():
-    """
-    Health check
-    """
-    return JSONResponse({"status": "ok"})
-
-
-def parse_args():
-    parser = ArgumentParser()
-    parser.add_argument("--mode", type=str, choices=["agent", "tts"], default="tts")
-    parser.add_argument("--load-asr-model", action="store_true")
-    parser.add_argument(
-        "--llama-checkpoint-path",
-        type=str,
-        default="checkpoints/fish-speech-1.4",
-    )
-    parser.add_argument(
-        "--decoder-checkpoint-path",
-        type=str,
-        default="checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
-    )
-    parser.add_argument("--decoder-config-name", type=str, default="firefly_gan_vq")
-    parser.add_argument("--device", type=str, default="cuda")
-    parser.add_argument("--half", action="store_true")
-    parser.add_argument("--compile", action="store_true")
-    parser.add_argument("--max-text-length", type=int, default=0)
-    parser.add_argument("--listen", type=str, default="127.0.0.1:8080")
-    parser.add_argument("--workers", type=int, default=1)
-
-    return parser.parse_args()
-
-
-# Define Kui app
-openapi = OpenAPI(
-    {
-        "title": "Fish Speech API",
-        "version": "1.4.2",
-    },
-).routes
-
-
-class MsgPackRequest(HttpRequest):
-    async def data(
-        self,
-    ) -> Annotated[
-        Any, ContentType("application/msgpack"), ContentType("application/json")
-    ]:
-        if self.content_type == "application/msgpack":
-            return ormsgpack.unpackb(await self.body)
-
-        elif self.content_type == "application/json":
-            return await self.json
-
-        raise HTTPException(
-            HTTPStatus.UNSUPPORTED_MEDIA_TYPE,
-            headers={"Accept": "application/msgpack, application/json"},
-        )
-
-
-app = Kui(
-    routes=routes + openapi[1:],  # Remove the default route
-    exception_handlers={
-        HTTPException: http_execption_handler,
-        Exception: other_exception_handler,
-    },
-    factory_class=FactoryClass(http=MsgPackRequest),
-    cors_config={},
-)
-
-
-def load_asr_model(*, device="cuda", hub="ms"):
-    return AutoModel(
-        model="iic/SenseVoiceSmall",
-        device=device,
-        disable_pbar=True,
-        hub=hub,
-    )
-
-
-# Each worker process created by Uvicorn has its own memory space,
-# meaning that models and variables are not shared between processes.
-# Therefore, any global variables (like `llama_queue` or `decoder_model`)
-# will not be shared across workers.
-
-
-# Multi-threading for deep learning can cause issues, such as inconsistent
-# outputs if multiple threads access the same buffers simultaneously.
-# Instead, it's better to use multiprocessing or independent models per thread.
-@app.on_startup
-def initialize_app(app: Kui):
-
-    global args, llama_queue, tokenizer, config, decoder_model, vad_model, asr_model, prompt_tokens, prompt_texts
-
-    prompt_tokens, prompt_texts = [], []
-
-    args = parse_args()  # args same as ones in other processes
-    args.precision = torch.half if args.half else torch.bfloat16
-
-    if args.load_asr_model:
-        logger.info(f"Loading ASR model...")
-        asr_model = load_asr_model(device=args.device)
-
-    logger.info("Loading Llama model...")
-
-    if args.mode == "tts":
-        llama_queue = launch_thread_safe_queue(
-            checkpoint_path=args.llama_checkpoint_path,
-            device=args.device,
-            precision=args.precision,
-            compile=args.compile,
-        )
-    else:
-        llama_queue, tokenizer, config = launch_thread_safe_queue_agent(
-            checkpoint_path=args.llama_checkpoint_path,
-            device=args.device,
-            precision=args.precision,
-            compile=args.compile,
-        )
-
-    logger.info("Llama model loaded, loading VQ-GAN model...")
-
-    decoder_model = load_decoder_model(
-        config_name=args.decoder_config_name,
-        checkpoint_path=args.decoder_checkpoint_path,
-        device=args.device,
-    )
-
-    logger.info("VQ-GAN model loaded, warming up...")
-
-    vad_model = load_silero_vad()
-
-    logger.info("VAD model loaded, warming up...")
-
-    if args.mode == "tts":
-        # Dry run to ensure models work and avoid first-time latency
-        list(
-            inference(
-                ServeTTSRequest(
-                    text="Hello world.",
-                    references=[],
-                    reference_id=None,
-                    max_new_tokens=0,
-                    chunk_length=200,
-                    top_p=0.7,
-                    repetition_penalty=1.5,
-                    temperature=0.7,
-                    emotion=None,
-                    format="wav",
-                )
-            )
-        )
-
-    logger.info(f"Warming up done, starting server at http://{args.listen}")
-
-
-if __name__ == "__main__":
-
-    import uvicorn
-
-    args = parse_args()
-    host, port = args.listen.split(":")
-    uvicorn.run(
-        "tools.api:app",
-        host=host,
-        port=int(port),
-        workers=args.workers,
-        log_level="info",
-    )

+ 6 - 12
tools/post_api.py → tools/api_client.py

@@ -69,10 +69,6 @@ def parse_args():
     parser.add_argument(
         "--format", type=str, choices=["wav", "mp3", "flac"], default="wav"
     )
-    parser.add_argument(
-        "--mp3_bitrate", type=int, choices=[64, 128, 192], default=64, help="kHz"
-    )
-    parser.add_argument("--opus_bitrate", type=int, default=-1000)
     parser.add_argument(
         "--latency",
         type=str,
@@ -112,11 +108,9 @@ def parse_args():
     parser.add_argument(
         "--use_memory_cache",
         type=str,
-        default="never",
-        choices=["on-demand", "never"],
-        help="Cache encoded references codes in memory.\n"
-        "If `on-demand`, the server will use cached encodings\n "
-        "instead of encoding reference audio again.",
+        default="off",
+        choices=["on", "off"],
+        help="Cache encoded references codes in memory.\n",
     )
     parser.add_argument(
         "--seed",
@@ -154,14 +148,14 @@ if __name__ == "__main__":
     data = {
         "text": args.text,
         "references": [
-            ServeReferenceAudio(audio=ref_audio, text=ref_text)
+            ServeReferenceAudio(
+                audio=ref_audio if ref_audio is not None else b"", text=ref_text
+            )
             for ref_text, ref_audio in zip(ref_texts, byte_audios)
         ],
         "reference_id": idstr,
         "normalize": args.normalize,
         "format": args.format,
-        "mp3_bitrate": args.mp3_bitrate,
-        "opus_bitrate": args.opus_bitrate,
         "max_new_tokens": args.max_new_tokens,
         "chunk_length": args.chunk_length,
         "top_p": args.top_p,

+ 98 - 0
tools/api_server.py

@@ -0,0 +1,98 @@
+from threading import Lock
+
+import pyrootutils
+import uvicorn
+from kui.asgi import FactoryClass, HTTPException, HttpRoute, Kui, OpenAPI, Routes
+from loguru import logger
+
+pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
+
+from tools.server.api_utils import MsgPackRequest, parse_args
+from tools.server.exception_handler import ExceptionHandler
+from tools.server.model_manager import ModelManager
+from tools.server.views import (
+    ASRView,
+    ChatView,
+    HealthView,
+    TTSView,
+    VQGANDecodeView,
+    VQGANEncodeView,
+)
+
+
+class API(ExceptionHandler):
+    def __init__(self):
+        self.args = parse_args()
+        self.routes = [
+            ("/v1/health", HealthView),
+            ("/v1/vqgan/encode", VQGANEncodeView),
+            ("/v1/vqgan/decode", VQGANDecodeView),
+            ("/v1/asr", ASRView),
+            ("/v1/tts", TTSView),
+            ("/v1/chat", ChatView),
+        ]
+        self.routes = Routes([HttpRoute(path, view) for path, view in self.routes])
+
+        self.openapi = OpenAPI(
+            {
+                "title": "Fish Speech API",
+                "version": "1.5.0",
+            },
+        ).routes
+
+        # Initialize the app
+        self.app = Kui(
+            routes=self.routes + self.openapi[1:],  # Remove the default route
+            exception_handlers={
+                HTTPException: self.http_exception_handler,
+                Exception: self.other_exception_handler,
+            },
+            factory_class=FactoryClass(http=MsgPackRequest),
+            cors_config={},
+        )
+
+        # Add the state variables
+        self.app.state.lock = Lock()
+        self.app.state.device = self.args.device
+        self.app.state.max_text_length = self.args.max_text_length
+
+        # Associate the app with the model manager
+        self.app.on_startup(self.initialize_app)
+
+    async def initialize_app(self, app: Kui):
+        # Make the ModelManager available to the views
+        app.state.model_manager = ModelManager(
+            mode=self.args.mode,
+            device=self.args.device,
+            half=self.args.half,
+            compile=self.args.compile,
+            asr_enabled=self.args.load_asr_model,
+            llama_checkpoint_path=self.args.llama_checkpoint_path,
+            decoder_checkpoint_path=self.args.decoder_checkpoint_path,
+            decoder_config_name=self.args.decoder_config_name,
+        )
+
+        logger.info(f"Startup done, listening server at http://{self.args.listen}")
+
+
+# Each worker process created by Uvicorn has its own memory space,
+# meaning that models and variables are not shared between processes.
+# Therefore, any variables (like `llama_queue` or `decoder_model`)
+# will not be shared across workers.
+
+# Multi-threading for deep learning can cause issues, such as inconsistent
+# outputs if multiple threads access the same buffers simultaneously.
+# Instead, it's better to use multiprocessing or independent models per thread.
+
+if __name__ == "__main__":
+
+    api = API()
+    host, port = api.args.listen.split(":")
+
+    uvicorn.run(
+        api.app,
+        host=host,
+        port=int(port),
+        workers=api.args.workers,
+        log_level="info",
+    )

+ 2 - 2
tools/fish_e2e.py

@@ -14,8 +14,8 @@ import ormsgpack
 import soundfile as sf
 
 from .schema import (
+    ServeChatRequest,
     ServeMessage,
-    ServeRequest,
     ServeTextPart,
     ServeVQGANDecodeRequest,
     ServeVQGANEncodeRequest,
@@ -163,7 +163,7 @@ class FishE2EAgent:
         else:
             user_codes = None
 
-        request = ServeRequest(
+        request = ServeChatRequest(
             messages=prev_messages
             + (
                 [

+ 193 - 0
tools/inference_engine/__init__.py

@@ -0,0 +1,193 @@
+import gc
+import queue
+from typing import Generator
+
+import numpy as np
+import torch
+from loguru import logger
+
+from fish_speech.models.vqgan.modules.firefly import FireflyArchitecture
+from fish_speech.text.chn_text_norm.text import Text as ChnNormedText
+from fish_speech.utils import autocast_exclude_mps, set_seed
+from tools.inference_engine.reference_loader import ReferenceLoader
+from tools.inference_engine.utils import InferenceResult, wav_chunk_header
+from tools.inference_engine.vq_manager import VQManager
+from tools.llama.generate import (
+    GenerateRequest,
+    GenerateResponse,
+    WrappedGenerateResponse,
+)
+from tools.schema import ServeTTSRequest
+
+
+class TTSInferenceEngine(ReferenceLoader, VQManager):
+
+    def __init__(
+        self,
+        llama_queue: queue.Queue,
+        decoder_model: FireflyArchitecture,
+        precision: torch.dtype,
+        compile: bool,
+    ) -> None:
+
+        super().__init__()
+
+        self.llama_queue = llama_queue
+        self.decoder_model = decoder_model
+        self.precision = precision
+        self.compile = compile
+
+    @torch.inference_mode()
+    def inference(self, req: ServeTTSRequest) -> Generator[InferenceResult, None, None]:
+        """
+        Main inference function:
+        - Loads the reference audio and text.
+        - Calls the LLAMA model for inference.
+        - Decodes the VQ tokens to audio.
+        """
+
+        ref_id: str | None = req.reference_id
+        prompt_tokens, prompt_texts = [], []
+        # Load the reference audio and text based on id or hash
+        if ref_id is not None:
+            prompt_tokens, prompt_texts = self.load_by_id(ref_id, req.use_memory_cache)
+
+        elif req.references:
+            prompt_tokens, prompt_texts = self.load_by_hash(
+                req.references, req.use_memory_cache
+            )
+
+        # Set the random seed if provided
+        if req.seed is not None:
+            set_seed(req.seed)
+            logger.warning(f"set seed: {req.seed}")
+
+        # Get the symbolic tokens from the LLAMA model
+        response_queue = self.send_Llama_request(req, prompt_tokens, prompt_texts)
+
+        # Get the sample rate from the decoder model
+        sample_rate = self.decoder_model.spec_transform.sample_rate
+
+        # If streaming, send the header
+        if req.streaming:
+            yield InferenceResult(
+                code="header",
+                audio=(sample_rate, wav_chunk_header(sample_rate=sample_rate)),
+                error=None,
+            )
+
+        segments = []
+
+        while True:
+            # Get the response from the LLAMA model
+            wrapped_result: WrappedGenerateResponse = response_queue.get()
+            if wrapped_result.status == "error":
+                yield InferenceResult(
+                    code="error",
+                    audio=None,
+                    error=(
+                        wrapped_result.response
+                        if isinstance(wrapped_result.response, Exception)
+                        else Exception("Unknown error")
+                    ),
+                )
+                break
+
+            # Check the response type
+            if not isinstance(wrapped_result.response, GenerateResponse):
+                raise TypeError(
+                    "Expected GenerateResponse, got {type(wrapped_result.response).__name__}"
+                )
+
+            result: GenerateResponse = wrapped_result.response
+            if result.action != "next":
+                segment = self.get_audio_segment(result)
+
+                if req.streaming:  # Used only by the API server
+                    yield InferenceResult(
+                        code="segment",
+                        audio=(sample_rate, segment),
+                        error=None,
+                    )
+                else:
+                    segments.append(segment)
+            else:
+                break
+
+        # Clean up the memory
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+            gc.collect()
+
+        # Edge case: no audio generated
+        if len(segments) == 0:
+            yield InferenceResult(
+                code="error",
+                audio=None,
+                error=RuntimeError("No audio generated, please check the input text."),
+            )
+        else:
+            # Streaming or not, return the final audio
+            audio = np.concatenate(segments, axis=0)
+            yield InferenceResult(
+                code="final",
+                audio=(sample_rate, audio),
+                error=None,
+            )
+
+        return None
+
+    def send_Llama_request(
+        self, req: ServeTTSRequest, prompt_tokens: list, prompt_texts: list
+    ) -> queue.Queue:
+        """
+        Send a request to the LLAMA model to generate the symbolic tokens.
+        """
+
+        # Prepare the request
+        request = dict(
+            device=self.decoder_model.device,
+            max_new_tokens=req.max_new_tokens,
+            text=(
+                req.text
+                if not req.normalize
+                else ChnNormedText(raw_text=req.text).normalize()
+            ),
+            top_p=req.top_p,
+            repetition_penalty=req.repetition_penalty,
+            temperature=req.temperature,
+            compile=self.compile,
+            iterative_prompt=req.chunk_length > 0,
+            chunk_length=req.chunk_length,
+            max_length=4096,
+            prompt_tokens=prompt_tokens,
+            prompt_text=prompt_texts,
+        )
+
+        # Create a queue to get the response
+        response_queue = queue.Queue()
+
+        # Send the request to the LLAMA model
+        self.llama_queue.put(
+            GenerateRequest(
+                request=request,
+                response_queue=response_queue,
+            )
+        )
+
+        return response_queue
+
+    def get_audio_segment(self, result: GenerateResponse) -> np.ndarray:
+        """
+        Decode the VQ tokens to audio.
+        """
+
+        # Don't use autocast on MPS devices
+        with autocast_exclude_mps(
+            device_type=self.decoder_model.device.type, dtype=self.precision
+        ):
+            # Decode the symbolic tokens to audio
+            segment = self.decode_vq_tokens(codes=result.codes)
+
+        # Convert the audio to numpy
+        return segment.float().cpu().numpy()

+ 128 - 0
tools/inference_engine/reference_loader.py

@@ -0,0 +1,128 @@
+import io
+from hashlib import sha256
+from pathlib import Path
+from typing import Callable, Literal, Tuple
+
+import torch
+import torchaudio
+from loguru import logger
+
+from fish_speech.models.vqgan.modules.firefly import FireflyArchitecture
+from tools.file import AUDIO_EXTENSIONS, audio_to_bytes, list_files, read_ref_text
+from tools.schema import ServeReferenceAudio
+
+
+class ReferenceLoader:
+
+    def __init__(self) -> None:
+        """
+        Component of the TTSInferenceEngine class.
+        Loads and manages the cache for the reference audio and text.
+        """
+        self.ref_by_id: dict = {}
+        self.ref_by_hash: dict = {}
+
+        # Make Pylance happy (attribut/method not defined...)
+        self.decoder_model: FireflyArchitecture
+        self.encode_reference: Callable
+
+        # Define the torchaudio backend
+        backends = torchaudio.list_audio_backends()
+        if "ffmpeg" in backends:
+            self.backend = "ffmpeg"
+        else:
+            self.backend = "soundfile"
+
+    def load_by_id(
+        self,
+        id: str,
+        use_cache: Literal["on", "off"],
+    ) -> Tuple:
+
+        # Load the references audio and text by id
+        ref_folder = Path("references") / id
+        ref_folder.mkdir(parents=True, exist_ok=True)
+        ref_audios = list_files(
+            ref_folder, AUDIO_EXTENSIONS, recursive=True, sort=False
+        )
+
+        if use_cache == "off" or id not in self.ref_by_id:
+            # If the references are not already loaded, encode them
+            prompt_tokens = [
+                self.encode_reference(
+                    decoder_model=self.decoder_model,
+                    reference_audio=audio_to_bytes(str(ref_audio)),
+                    enable_reference_audio=True,
+                )
+                for ref_audio in ref_audios
+            ]
+            prompt_texts = [
+                read_ref_text(str(ref_audio.with_suffix(".lab")))
+                for ref_audio in ref_audios
+            ]
+            self.ref_by_id[id] = (prompt_tokens, prompt_texts)
+
+        else:
+            # Reuse already encoded references
+            logger.info("Use same references")
+            prompt_tokens, prompt_texts = self.ref_by_id[id]
+
+        return prompt_tokens, prompt_texts
+
+    def load_by_hash(
+        self,
+        references: list[ServeReferenceAudio],
+        use_cache: Literal["on", "off"],
+    ) -> Tuple:
+
+        # Load the references audio and text by hash
+        audio_hashes = [sha256(ref.audio).hexdigest() for ref in references]
+
+        cache_used = False
+        prompt_tokens, prompt_texts = [], []
+        for i, ref in enumerate(references):
+            if use_cache == "off" or audio_hashes[i] not in self.ref_by_hash:
+                # If the references are not already loaded, encode them
+                prompt_tokens.append(
+                    self.encode_reference(
+                        decoder_model=self.decoder_model,
+                        reference_audio=ref.audio,
+                        enable_reference_audio=True,
+                    )
+                )
+                prompt_texts.append(ref.text)
+                self.ref_by_hash[audio_hashes[i]] = (prompt_tokens, prompt_texts)
+
+            else:
+                # Reuse already encoded references
+                prompt_text, prompt_token = self.ref_by_hash[audio_hashes[i]]
+                prompt_texts.append(prompt_text)
+                prompt_tokens.append(prompt_token)
+                cache_used = True
+
+        if cache_used:
+            logger.info("Use same references")
+
+        return prompt_tokens, prompt_texts
+
+    def load_audio(self, reference_audio, sr):
+        """
+        Load the audio data from a file or bytes.
+        """
+        if len(reference_audio) > 255 or not Path(reference_audio).exists():
+            audio_data = reference_audio
+            reference_audio = io.BytesIO(audio_data)
+
+        waveform, original_sr = torchaudio.load(reference_audio, backend=self.backend)
+
+        if waveform.shape[0] > 1:
+            waveform = torch.mean(waveform, dim=0, keepdim=True)
+
+        if original_sr != sr:
+            resampler = torchaudio.transforms.Resample(
+                orig_freq=original_sr, new_freq=sr
+            )
+            waveform = resampler(waveform)
+
+        audio = waveform.squeeze().numpy()
+        return audio

+ 42 - 0
tools/inference_engine/utils.py

@@ -0,0 +1,42 @@
+import io
+import wave
+from dataclasses import dataclass
+from typing import Literal, Optional, Tuple
+
+import numpy as np
+
+from fish_speech.text.chn_text_norm.text import Text as ChnNormedText
+
+
+@dataclass
+class InferenceResult:
+    code: Literal["header", "segment", "error", "final"]
+    audio: Optional[Tuple[int, np.ndarray]]
+    error: Optional[Exception]
+
+
+def normalize_text(user_input: str, use_normalization: bool) -> str:
+    """Normalize user input text if needed."""
+    if use_normalization:
+        return ChnNormedText(raw_text=user_input).normalize()
+    else:
+        return user_input
+
+
+def wav_chunk_header(
+    sample_rate: int = 44100, bit_depth: int = 16, channels: int = 1
+) -> np.ndarray:
+    buffer = io.BytesIO()
+
+    with wave.open(buffer, "wb") as wav_file:
+        wav_file.setnchannels(channels)
+        wav_file.setsampwidth(bit_depth // 8)
+        wav_file.setframerate(sample_rate)
+
+    wav_header_bytes = buffer.getvalue()
+    buffer.close()
+
+    # Convert to numpy array
+    wav_header = np.frombuffer(wav_header_bytes, dtype=np.uint8)
+
+    return wav_header

+ 57 - 0
tools/inference_engine/vq_manager.py

@@ -0,0 +1,57 @@
+from typing import Callable
+
+import torch
+from loguru import logger
+
+from fish_speech.models.vqgan.modules.firefly import FireflyArchitecture
+
+
+class VQManager:
+
+    def __init__(self):
+        # Make Pylance happy (attribut/method not defined...)
+        self.decoder_model: FireflyArchitecture
+        self.load_audio: Callable
+
+    def decode_vq_tokens(self, codes):
+        feature_lengths = torch.tensor(
+            [codes.shape[1]], device=self.decoder_model.device
+        )
+        logger.info(f"VQ features: {codes.shape}")
+
+        if isinstance(self.decoder_model, FireflyArchitecture):
+            return self.decoder_model.decode(
+                indices=codes[None],
+                feature_lengths=feature_lengths,
+            )[0].squeeze()
+
+        raise ValueError(f"Unknown model type: {type(self.decoder_model)}")
+
+    def encode_reference(self, reference_audio, enable_reference_audio):
+        if enable_reference_audio and reference_audio is not None:
+            # Load audios, and prepare basic info here
+            reference_audio_content = self.load_audio(
+                reference_audio, self.decoder_model.spec_transform.sample_rate
+            )
+
+            audios = torch.from_numpy(reference_audio_content).to(
+                self.decoder_model.device
+            )[None, None, :]
+            audio_lengths = torch.tensor(
+                [audios.shape[2]], device=self.decoder_model.device, dtype=torch.long
+            )
+            logger.info(
+                f"Loaded audio with {audios.shape[2] / self.decoder_model.spec_transform.sample_rate:.2f} seconds"
+            )
+
+            # VQ Encoder
+            if isinstance(self.decoder_model, FireflyArchitecture):
+                prompt_tokens = self.decoder_model.encode(audios, audio_lengths)[0][0]
+                logger.info(f"Encoded prompt: {prompt_tokens.shape}")
+            else:
+                raise ValueError(f"Unknown model type: {type(self.decoder_model)}")
+        else:
+            prompt_tokens = None
+            logger.info("No reference audio provided")
+
+        return prompt_tokens

+ 0 - 95
tools/msgpack_api.py

@@ -1,95 +0,0 @@
-import os
-from argparse import ArgumentParser
-from pathlib import Path
-
-import httpx
-import ormsgpack
-
-from tools.schema import ServeReferenceAudio, ServeTTSRequest
-
-api_key = os.environ.get("FISH_API_KEY", "YOUR_API_KEY")
-
-
-def audio_request():
-    # priority: ref_id > references
-    request = ServeTTSRequest(
-        text="你说的对, 但是原神是一款由米哈游自主研发的开放世界手游.",
-        # reference_id="114514",
-        references=[
-            ServeReferenceAudio(
-                audio=open("lengyue.wav", "rb").read(),
-                text=open("lengyue.lab", "r", encoding="utf-8").read(),
-            )
-        ],
-        streaming=True,
-    )
-
-    api_key = os.environ.get("FISH_API_KEY", "YOUR_API_KEY")
-
-    with (
-        httpx.Client() as client,
-        open("hello.wav", "wb") as f,
-    ):
-        with client.stream(
-            "POST",
-            "http://127.0.0.1:8080/v1/tts",
-            content=ormsgpack.packb(request, option=ormsgpack.OPT_SERIALIZE_PYDANTIC),
-            headers={
-                "authorization": f"Bearer {api_key}",
-                "content-type": "application/msgpack",
-            },
-            timeout=None,
-        ) as response:
-            for chunk in response.iter_bytes():
-                f.write(chunk)
-
-
-def asr_request(audio_path: Path):
-
-    # Read the audio file
-    with open(
-        str(audio_path),
-        "rb",
-    ) as audio_file:
-        audio_data = audio_file.read()
-
-    # Prepare the request data
-    request_data = {
-        "audio": audio_data,
-        "language": "en",  # Optional: specify the language
-        "ignore_timestamps": False,  # Optional: set to True to ignore precise timestamps
-    }
-
-    # Send the request
-    with httpx.Client() as client:
-        response = client.post(
-            "https://api.fish.audio/v1/asr",
-            headers={
-                "Authorization": f"Bearer {api_key}",
-                "Content-Type": "application/msgpack",
-            },
-            content=ormsgpack.packb(request_data),
-        )
-
-    # Parse the response
-    result = response.json()
-
-    print(f"Transcribed text: {result['text']}")
-    print(f"Audio duration: {result['duration']} seconds")
-
-    for segment in result["segments"]:
-        print(f"Segment: {segment['text']}")
-        print(f"Start time: {segment['start']}, End time: {segment['end']}")
-
-
-def parse_args():
-    parser = ArgumentParser()
-    parser.add_argument("--audio_path", type=Path, default="audio/ref/trump.mp3")
-
-    return parser.parse_args()
-
-
-if __name__ == "__main__":
-    args = parse_args()
-
-    asr_request(args.audio_path)

+ 101 - 0
tools/run_webui.py

@@ -0,0 +1,101 @@
+import os
+from argparse import ArgumentParser
+from pathlib import Path
+
+import pyrootutils
+import torch
+from loguru import logger
+
+pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
+
+from tools.inference_engine import TTSInferenceEngine
+from tools.llama.generate import launch_thread_safe_queue
+from tools.schema import ServeTTSRequest
+from tools.vqgan.inference import load_model as load_decoder_model
+from tools.webui import build_app
+from tools.webui.inference import get_inference_wrapper
+
+# Make einx happy
+os.environ["EINX_FILTER_TRACEBACK"] = "false"
+
+
+def parse_args():
+    parser = ArgumentParser()
+    parser.add_argument(
+        "--llama-checkpoint-path",
+        type=Path,
+        default="checkpoints/fish-speech-1.5",
+    )
+    parser.add_argument(
+        "--decoder-checkpoint-path",
+        type=Path,
+        default="checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
+    )
+    parser.add_argument("--decoder-config-name", type=str, default="firefly_gan_vq")
+    parser.add_argument("--device", type=str, default="cuda")
+    parser.add_argument("--half", action="store_true")
+    parser.add_argument("--compile", action="store_true")
+    parser.add_argument("--max-gradio-length", type=int, default=0)
+    parser.add_argument("--theme", type=str, default="light")
+
+    return parser.parse_args()
+
+
+if __name__ == "__main__":
+    args = parse_args()
+    args.precision = torch.half if args.half else torch.bfloat16
+
+    # Check if CUDA is available
+    if not torch.cuda.is_available():
+        logger.info("CUDA is not available, running on CPU.")
+        args.device = "cpu"
+
+    logger.info("Loading Llama model...")
+    llama_queue = launch_thread_safe_queue(
+        checkpoint_path=args.llama_checkpoint_path,
+        device=args.device,
+        precision=args.precision,
+        compile=args.compile,
+    )
+
+    logger.info("Loading VQ-GAN model...")
+    decoder_model = load_decoder_model(
+        config_name=args.decoder_config_name,
+        checkpoint_path=args.decoder_checkpoint_path,
+        device=args.device,
+    )
+
+    logger.info("Decoder model loaded, warming up...")
+
+    # Create the inference engine
+    inference_engine = TTSInferenceEngine(
+        llama_queue=llama_queue,
+        decoder_model=decoder_model,
+        compile=args.compile,
+        precision=args.precision,
+    )
+
+    # Dry run to check if the model is loaded correctly and avoid the first-time latency
+    list(
+        inference_engine.inference(
+            ServeTTSRequest(
+                text="Hello world.",
+                references=[],
+                reference_id=None,
+                max_new_tokens=0,
+                chunk_length=200,
+                top_p=0.7,
+                repetition_penalty=1.5,
+                temperature=0.7,
+                format="wav",
+            )
+        )
+    )
+
+    logger.info("Warming up done, launching the web UI...")
+
+    # Get the inference function with the immutable arguments
+    inference_fct = get_inference_wrapper(inference_engine)
+
+    app = build_app(inference_fct, args.theme)
+    app.launch(show_api=True)

+ 9 - 29
tools/schema.py

@@ -1,16 +1,14 @@
 import os
 import queue
 from dataclasses import dataclass
-from typing import Annotated, Literal, Optional
+from typing import Annotated, Literal
 
 import torch
-from pydantic import AfterValidator, BaseModel, Field, confloat, conint, conlist
+from pydantic import BaseModel, Field, conint, conlist
 from pydantic.functional_validators import SkipValidation
 
 from fish_speech.conversation import Message, TextPart, VQPart
 
-GLOBAL_NUM_SAMPLES = int(os.getenv("GLOBAL_NUM_SAMPLES", 1))
-
 
 class ServeVQPart(BaseModel):
     type: Literal["vq"] = "vq"
@@ -64,7 +62,7 @@ class ServeASRResponse(BaseModel):
 
 
 class ServeMessage(BaseModel):
-    role: Literal["system", "assistant", "user", "raw"]
+    role: Literal["system", "assistant", "user"]
     parts: list[ServeVQPart | ServeTextPart]
 
     def to_conversation_message(self):
@@ -85,7 +83,7 @@ class ServeMessage(BaseModel):
         return new_message
 
 
-class ServeRequest(BaseModel):
+class ServeChatRequest(BaseModel):
     messages: Annotated[list[ServeMessage], conlist(ServeMessage, min_length=1)]
     max_new_tokens: int = 1024
     top_p: float = 0.7
@@ -114,11 +112,6 @@ class ServeVQGANDecodeResponse(BaseModel):
     audios: list[bytes]
 
 
-class ServeReferenceAudio(BaseModel):
-    audio: bytes
-    text: str
-
-
 class ServeForwardMessage(BaseModel):
     role: str
     content: str
@@ -150,24 +143,11 @@ class ServeReferenceAudio(BaseModel):
         return f"ServeReferenceAudio(text={self.text!r}, audio_size={len(self.audio)})"
 
 
-class ServeChatRequestV1(BaseModel):
-    model: str = "llama3-8b"
-    messages: list[ServeForwardMessage] = []
-    audio: bytes | None = None
-    temperature: float = 1.0
-    top_p: float = 1.0
-    max_tokens: int = 256
-    voice: str = "jessica"
-    tts_audio_format: Literal["mp3", "pcm", "opus"] = "mp3"
-    tts_audio_bitrate: Literal[16, 24, 32, 48, 64, 96, 128, 192] = 128
-
-
 class ServeTTSRequest(BaseModel):
     text: str
     chunk_length: Annotated[int, conint(ge=100, le=300, strict=True)] = 200
     # Audio format
     format: Literal["wav", "pcm", "mp3"] = "wav"
-    mp3_bitrate: Literal[64, 128, 192] = 128
     # References audios for in-context learning
     references: list[ServeReferenceAudio] = []
     # Reference id
@@ -175,16 +155,16 @@ class ServeTTSRequest(BaseModel):
     # Just pass 7f92f8afb8ec43bf81429cc1c9199cb1
     reference_id: str | None = None
     seed: int | None = None
-    use_memory_cache: Literal["on-demand", "never"] = "never"
+    use_memory_cache: Literal["on", "off"] = "off"
     # Normalize text for en & zh, this increase stability for numbers
     normalize: bool = True
-    mp3_bitrate: Optional[int] = 64
-    opus_bitrate: Optional[int] = -1000
-    # Balance mode will reduce latency to 300ms, but may decrease stability
-    latency: Literal["normal", "balanced"] = "normal"
     # not usually used below
     streaming: bool = False
     max_new_tokens: int = 1024
     top_p: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7
     repetition_penalty: Annotated[float, Field(ge=0.9, le=2.0, strict=True)] = 1.2
     temperature: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7
+
+    class Config:
+        # Allow arbitrary types for pytorch related types
+        arbitrary_types_allowed = True

+ 57 - 0
tools/server/agent/__init__.py

@@ -0,0 +1,57 @@
+import struct
+from functools import partial
+
+import ormsgpack
+
+from tools.server.agent.generate import generate_responses
+from tools.server.agent.pre_generation_utils import prepare_messages
+
+
+def execute_request(input_queue, tokenizer, config, request, device):
+    """
+    This function prepares the conversation, encodes the request,
+    sends the generation request, and handles decoding/streaming.
+    It returns a response generator (ServeResponse or ServeStreamResponse).
+    """
+    prompt, im_end_id = prepare_messages(request, tokenizer, config)
+    yield from generate_responses(
+        input_queue, tokenizer, config, request, prompt, im_end_id, device
+    )
+
+
+def response_generator(req, llama_queue, tokenizer, config, device):
+    """
+    Non-streaming response wrapper for the chat endpoint.
+    Only returns the final result.
+    """
+    generator = execute_request(llama_queue, tokenizer, config, req, device)
+    return next(generator)
+
+
+async def streaming_generator(req, llama_queue, tokenizer, config, device, json_mode):
+    """
+    Streaming response wrapper for the chat endpoint.
+    Returns the response in chunks.
+    """
+    generator = execute_request(llama_queue, tokenizer, config, req, device)
+    for i in generator:
+        if json_mode:
+            body = i.model_dump_json().encode("utf-8")
+            yield b"data: " + body + b"\n\n"
+        else:
+            body = ormsgpack.packb(i, option=ormsgpack.OPT_SERIALIZE_PYDANTIC)
+            yield struct.pack("I", len(body)) + body
+
+
+def get_response_generator(
+    llama_queue, tokenizer, config, req, device, json_mode
+) -> partial:
+    """
+    Get the correct response generator based on the request.
+    """
+    if not req.streaming:
+        return partial(response_generator, req, llama_queue, tokenizer, config, device)
+    else:
+        return partial(
+            streaming_generator, req, llama_queue, tokenizer, config, device, json_mode
+        )

+ 119 - 0
tools/server/agent/generate.py

@@ -0,0 +1,119 @@
+import time
+
+from tools.schema import ServeMessage, ServeResponse, ServeStreamResponse
+from tools.server.agent.generation_utils import (
+    initialize_decode_buffers,
+    process_response_tokens,
+    send_reset_buffer,
+)
+from tools.server.agent.pre_generation_utils import (
+    create_generation_request,
+    send_generation_request,
+)
+
+
+def generate_responses(
+    input_queue, tokenizer, config, request, prompt, im_end_id, device
+):
+    """
+    Main generation function that handles the conversation, encodes the request,
+    sends the generation request, and handles decoding/streaming.
+    It returns a response generator (ServeResponse or ServeStreamResponse).
+    """
+    stats = {}
+    start = time.time()
+    stats["start_time"] = start
+    stats["tokens_count"] = 0
+
+    # Prepare and send the generation request
+    req = create_generation_request(prompt, request, im_end_id, device)
+    response_queue = send_generation_request(input_queue, req)
+    decode_buffer, parts, finished = initialize_decode_buffers(request.num_samples)
+
+    while True:
+        response = response_queue.get()
+
+        # Handle abnormal finish or error
+        if response in ["stop", "error"]:
+            finish_reason = response
+            break
+
+        # Process the response tokens
+        is_first_token = stats["tokens_count"] == 0
+        responses = process_response_tokens(
+            response,
+            tokenizer,
+            config,
+            request,
+            decode_buffer,
+            parts,
+            finished,
+            im_end_id,
+            stats,
+            start,
+            is_first_token,
+        )
+
+        # Yield the responses if streaming
+        if request.streaming and responses:
+            for r in responses:
+                yield r
+
+        stats["tokens_count"] += 1
+
+        # Check if all samples are finished
+        if all(finished):
+            finish_reason = "stop"
+            break
+
+    # Finalize the response
+    final_responses = finalize_response(
+        request, finished, decode_buffer, tokenizer, parts, stats, finish_reason
+    )
+    for fr in final_responses:
+        yield fr
+
+
+def finalize_response(
+    request, finished, decode_buffer, tokenizer, parts, stats, finish_reason
+):
+    """
+    Finalize the response by sending the remaining text buffers.
+    """
+    responses = []
+
+    # Send the remaining text buffers
+    for sample_id in range(request.num_samples):
+        responses.extend(
+            send_reset_buffer(sample_id, decode_buffer, tokenizer, parts, request)
+        )
+
+    # Calculate the final stats
+    stats["total_time"] = (time.time() - stats["start_time"]) * 1000
+    stats["total_tokens"] = stats["tokens_count"]
+
+    # If streaming, send the final chunks for each sample
+    if request.streaming:
+        for sample_id in range(request.num_samples):
+            if finished[sample_id]:
+                continue
+            responses.append(
+                ServeStreamResponse(
+                    finish_reason=finish_reason, stats=stats, sample_id=sample_id
+                )
+            )
+    else:
+        # If not streaming, send the full messages for each sample
+        full_messages = [
+            ServeMessage(role="assistant", parts=parts[i])
+            for i in range(request.num_samples)
+        ]
+        responses.append(
+            ServeResponse(
+                messages=full_messages,
+                finish_reason=finish_reason,
+                stats=stats,
+            )
+        )
+
+    return responses

+ 122 - 0
tools/server/agent/generation_utils.py

@@ -0,0 +1,122 @@
+import time
+
+from tools.schema import (
+    ServeStreamDelta,
+    ServeStreamResponse,
+    ServeTextPart,
+    ServeVQPart,
+)
+
+
+def initialize_decode_buffers(num_samples):
+    """Initialise the decode buffers for each sample."""
+    decode_buffer = [[] for _ in range(num_samples)]
+    parts = [[] for _ in range(num_samples)]
+    finished = [False for _ in range(num_samples)]
+    return decode_buffer, parts, finished
+
+
+def send_reset_buffer(sample_id, decode_buffer, tokenizer, parts, request):
+    """Send the remaining text buffer for a sample."""
+    if len(decode_buffer[sample_id]) == 0:
+        return []
+
+    decoded = tokenizer.decode(decode_buffer[sample_id])
+    part = ServeTextPart(text=decoded)
+
+    responses = []
+    if request.streaming:
+        responses.append(ServeStreamResponse(delta=ServeStreamDelta(part=part)))
+    else:
+        parts[sample_id].append(part)
+
+    decode_buffer[sample_id] = []
+    return responses
+
+
+def handle_semantic_tokens(tokens, config, sample_id, parts, request):
+    """Handle the semantic tokens returned by the model."""
+    responses = []
+    _tokens = tokens[1:].clone()
+
+    if not config.share_codebook_embeddings:
+        for i in range(len(_tokens)):
+            _tokens[i] -= config.codebook_size * i
+
+    # If streaming, send the VQ parts directly
+    if request.streaming:
+        responses.append(
+            ServeStreamResponse(
+                sample_id=sample_id,
+                delta=ServeStreamDelta(part=ServeVQPart(codes=_tokens.tolist())),
+            )
+        )
+    else:
+        # If not streaming, accumulate the VQ parts
+        if not parts[sample_id] or not isinstance(parts[sample_id][-1], ServeVQPart):
+            parts[sample_id].append(ServeVQPart(codes=_tokens.tolist()))
+        else:
+            # Accumulate the codes
+            for codebook_id, value in enumerate(_tokens):
+                parts[sample_id][-1].codes[codebook_id].append(value.item())
+
+    return responses
+
+
+def process_response_tokens(
+    response,
+    tokenizer,
+    config,
+    request,
+    decode_buffer,
+    parts,
+    finished,
+    im_end_id,
+    stats,
+    start,
+    is_first_token,
+):
+    """Process the response tokens returned by the model."""
+    responses = []
+    for sample_id, tokens in enumerate(response):
+        if finished[sample_id]:
+            continue
+
+        # End of the conversation
+        if tokens[0] == im_end_id:
+            finished[sample_id] = True
+            # Send the remaining text buffer
+            responses.extend(
+                send_reset_buffer(sample_id, decode_buffer, tokenizer, parts, request)
+            )
+            if request.streaming:
+                responses.append(
+                    ServeStreamResponse(
+                        sample_id=sample_id,
+                        finish_reason="stop",
+                        stats=stats,
+                    )
+                )
+            continue
+
+        # Check if the token is semantic
+        is_semantic = (
+            tokenizer.semantic_begin_id <= tokens[0] <= tokenizer.semantic_end_id
+        )
+
+        if is_semantic:
+            # Before the semantic tokens, send the remaining text buffer
+            responses.extend(
+                send_reset_buffer(sample_id, decode_buffer, tokenizer, parts, request)
+            )
+            responses.extend(
+                handle_semantic_tokens(tokens, config, sample_id, parts, request)
+            )
+        else:
+            # Accumulate the text tokens (not implemented?)
+            decode_buffer[sample_id].append(tokens[0, 0])
+
+    if is_first_token:
+        stats["time_to_first_token"] = (time.time() - start) * 1000
+
+    return responses

+ 72 - 0
tools/server/agent/pre_generation_utils.py

@@ -0,0 +1,72 @@
+import queue
+
+from fish_speech.conversation import Conversation, Message
+from fish_speech.tokenizer import IM_END_TOKEN
+from tools.llama.generate import GenerateRequest
+
+
+def prepare_messages(request, tokenizer, config):
+    """
+    Reorganise the provided list of messages into a conversation.
+    Encode the conversation for inference.
+    """
+    # Convert the messages to ConversationMessage objects
+    messages = [msg.to_conversation_message() for msg in request.messages]
+
+    if len(messages) < 1:
+        raise ValueError("At least one message is required")
+
+    # Check the last message to determine the next step
+    last_role = messages[-1].role
+    match last_role:
+        case "user":
+            # The last message is from the user, ask the assistant to respond with a new message
+            messages.append(
+                Message(role="assistant", parts=[], add_im_end=False, modality="voice")
+            )
+        case "raw":
+            # The last message is raw text, ask the assistant to complete it
+            messages[-1].add_im_start = False
+            messages[-1].add_im_end = False
+            messages[-1].modality = "voice"
+        case "assistant":
+            # The last message is from the assistant, ask the assistant to continue
+            messages[-1].add_im_end = False
+        case _:
+            # We expect it to be assistant if not user or raw
+            raise ValueError("The last message must be from the assistant, user or raw")
+
+    # Create a conversation object and encode it for inference
+    conv = Conversation(messages=messages)
+    prompt = conv.encode_for_inference(
+        tokenizer=tokenizer, num_codebooks=config.num_codebooks
+    )
+    im_end_id = tokenizer.get_token_id(IM_END_TOKEN)
+
+    return prompt, im_end_id
+
+
+def create_generation_request(prompt, request, im_end_id, device):
+    """
+    Convert the request into a dictionary that can be sent to the model for generation.
+    """
+    req = {
+        "prompt": prompt.to(device),
+        "max_new_tokens": request.max_new_tokens,
+        "im_end_id": im_end_id,
+        "temperature": request.temperature,
+        "top_p": request.top_p,
+        "repetition_penalty": request.repetition_penalty,
+        "num_samples": request.num_samples,
+        "early_stop_threshold": request.early_stop_threshold,
+    }
+    return req
+
+
+def send_generation_request(input_queue, req):
+    """
+    Send the generation request to the model and return a queue to get the response.
+    """
+    response_queue = queue.Queue()
+    input_queue.put(GenerateRequest(req, response_queue))
+    return response_queue

+ 75 - 0
tools/server/api_utils.py

@@ -0,0 +1,75 @@
+from argparse import ArgumentParser
+from http import HTTPStatus
+from typing import Annotated, Any
+
+import ormsgpack
+from baize.datastructures import ContentType
+from kui.asgi import HTTPException, HttpRequest
+
+from tools.inference_engine import TTSInferenceEngine
+from tools.schema import ServeTTSRequest
+from tools.server.inference import inference_wrapper as inference
+
+
+def parse_args():
+    parser = ArgumentParser()
+    parser.add_argument("--mode", type=str, choices=["agent", "tts"], default="tts")
+    parser.add_argument("--load-asr-model", action="store_true")
+    parser.add_argument(
+        "--llama-checkpoint-path",
+        type=str,
+        default="checkpoints/fish-speech-1.5",
+    )
+    parser.add_argument(
+        "--decoder-checkpoint-path",
+        type=str,
+        default="checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
+    )
+    parser.add_argument("--decoder-config-name", type=str, default="firefly_gan_vq")
+    parser.add_argument("--device", type=str, default="cuda")
+    parser.add_argument("--half", action="store_true")
+    parser.add_argument("--compile", action="store_true")
+    parser.add_argument("--max-text-length", type=int, default=0)
+    parser.add_argument("--listen", type=str, default="127.0.0.1:8080")
+    parser.add_argument("--workers", type=int, default=1)
+
+    return parser.parse_args()
+
+
+class MsgPackRequest(HttpRequest):
+    async def data(
+        self,
+    ) -> Annotated[
+        Any, ContentType("application/msgpack"), ContentType("application/json")
+    ]:
+        if self.content_type == "application/msgpack":
+            return ormsgpack.unpackb(await self.body)
+
+        elif self.content_type == "application/json":
+            return await self.json
+
+        raise HTTPException(
+            HTTPStatus.UNSUPPORTED_MEDIA_TYPE,
+            headers={"Accept": "application/msgpack, application/json"},
+        )
+
+
+async def inference_async(req: ServeTTSRequest, engine: TTSInferenceEngine):
+    for chunk in inference(req, engine):
+        if isinstance(chunk, bytes):
+            yield chunk
+
+
+async def buffer_to_async_generator(buffer):
+    yield buffer
+
+
+def get_content_type(audio_format):
+    if audio_format == "wav":
+        return "audio/wav"
+    elif audio_format == "flac":
+        return "audio/flac"
+    elif audio_format == "mp3":
+        return "audio/mpeg"
+    else:
+        return "application/octet-stream"

+ 27 - 0
tools/server/exception_handler.py

@@ -0,0 +1,27 @@
+import traceback
+from http import HTTPStatus
+
+from kui.asgi import HTTPException, JSONResponse
+
+
+class ExceptionHandler:
+
+    async def http_exception_handler(self, exc: HTTPException):
+        return JSONResponse(
+            dict(
+                statusCode=exc.status_code,
+                message=exc.content,
+                error=HTTPStatus(exc.status_code).phrase,
+            ),
+            exc.status_code,
+            exc.headers,
+        )
+
+    async def other_exception_handler(self, exc: Exception):
+        traceback.print_exc()
+
+        status = HTTPStatus.INTERNAL_SERVER_ERROR
+        return JSONResponse(
+            dict(statusCode=status, message=str(exc), error=status.phrase),
+            status,
+        )

+ 41 - 0
tools/server/inference.py

@@ -0,0 +1,41 @@
+from http import HTTPStatus
+
+import numpy as np
+from kui.asgi import HTTPException
+
+from tools.inference_engine import TTSInferenceEngine
+from tools.schema import ServeTTSRequest
+
+AMPLITUDE = 32768  # Needs an explaination
+
+
+def inference_wrapper(req: ServeTTSRequest, engine: TTSInferenceEngine):
+    """
+    Wrapper for the inference function.
+    Used in the API server.
+    """
+    for result in engine.inference(req):
+        match result.code:
+            case "header":
+                if isinstance(result.audio, tuple):
+                    yield result.audio[1]
+
+            case "error":
+                raise HTTPException(
+                    HTTPStatus.INTERNAL_SERVER_ERROR,
+                    content=str(result.error),
+                )
+
+            case "segment":
+                if isinstance(result.audio, tuple):
+                    yield (result.audio[1] * AMPLITUDE).astype(np.int16).tobytes()
+
+            case "final":
+                if isinstance(result.audio, tuple):
+                    yield result.audio[1]
+                return None  # Stop the generator
+
+    raise HTTPException(
+        HTTPStatus.INTERNAL_SERVER_ERROR,
+        content="No audio generated, please check the input text.",
+    )

+ 119 - 0
tools/server/model_manager.py

@@ -0,0 +1,119 @@
+import torch
+from funasr import AutoModel
+from loguru import logger
+
+from tools.inference_engine import TTSInferenceEngine
+from tools.llama.generate import (
+    launch_thread_safe_queue,
+    launch_thread_safe_queue_agent,
+)
+from tools.schema import ServeTTSRequest
+from tools.server.inference import inference_wrapper as inference
+from tools.vqgan.inference import load_model as load_decoder_model
+
+ASR_MODEL_NAME = "iic/SenseVoiceSmall"
+
+
+class ModelManager:
+    def __init__(
+        self,
+        mode: str,
+        device: str,
+        half: bool,
+        compile: bool,
+        asr_enabled: bool,
+        llama_checkpoint_path: str,
+        decoder_checkpoint_path: str,
+        decoder_config_name: str,
+    ) -> None:
+
+        self.mode = mode
+        self.device = device
+        self.half = half
+        self.compile = compile
+
+        self.precision = torch.half if half else torch.bfloat16
+
+        # Check if CUDA is available
+        if not torch.cuda.is_available():
+            self.device = "cpu"
+            logger.info("CUDA is not available, running on CPU.")
+
+        # Load the ASR model if enabled
+        if asr_enabled:
+            self.load_asr_model(self.device)
+
+        # Load the TTS models
+        self.load_llama_model(
+            llama_checkpoint_path, self.device, self.precision, self.compile, self.mode
+        )
+        self.load_decoder_model(
+            decoder_config_name, decoder_checkpoint_path, self.device
+        )
+        self.tts_inference_engine = TTSInferenceEngine(
+            llama_queue=self.llama_queue,
+            decoder_model=self.decoder_model,
+            precision=self.precision,
+            compile=self.compile,
+        )
+
+        # Warm up the models
+        if self.mode == "tts":
+            self.warm_up(self.tts_inference_engine)
+
+    def load_asr_model(self, device, hub="ms") -> None:
+        self.asr_model = AutoModel(
+            model=ASR_MODEL_NAME,
+            device=device,
+            disable_pbar=True,
+            hub=hub,
+        )
+        logger.info("ASR model loaded.")
+
+    def load_llama_model(
+        self, checkpoint_path, device, precision, compile, mode
+    ) -> None:
+
+        if mode == "tts":
+            self.llama_queue = launch_thread_safe_queue(
+                checkpoint_path=checkpoint_path,
+                device=device,
+                precision=precision,
+                compile=compile,
+            )
+        elif mode == "agent":
+            self.llama_queue, self.tokenizer, self.config = (
+                launch_thread_safe_queue_agent(
+                    checkpoint_path=checkpoint_path,
+                    device=device,
+                    precision=precision,
+                    compile=compile,
+                )
+            )
+        else:
+            raise ValueError(f"Invalid mode: {mode}")
+
+        logger.info("LLAMA model loaded.")
+
+    def load_decoder_model(self, config_name, checkpoint_path, device) -> None:
+        self.decoder_model = load_decoder_model(
+            config_name=config_name,
+            checkpoint_path=checkpoint_path,
+            device=device,
+        )
+        logger.info("Decoder model loaded.")
+
+    def warm_up(self, tts_inference_engine) -> None:
+        request = ServeTTSRequest(
+            text="Hello world.",
+            references=[],
+            reference_id=None,
+            max_new_tokens=0,
+            chunk_length=200,
+            top_p=0.7,
+            repetition_penalty=1.5,
+            temperature=0.7,
+            format="wav",
+        )
+        list(inference(request, tts_inference_engine))
+        logger.info("Models warmed up.")

+ 129 - 0
tools/server/model_utils.py

@@ -0,0 +1,129 @@
+import io
+import re
+
+import librosa
+import torch
+import torchaudio
+from cachetools import LRUCache, cached
+
+CACHE_MAXSIZE = 10000
+MICRO_BATCH_SIZE = 8
+ASR_SAMPLE_RATE = 16000
+HUGE_GAP_THRESHOLD = 4000
+
+
+@torch.no_grad()
+@torch.autocast(device_type="cuda", dtype=torch.half)
+def batch_encode(model, audios_list: list[bytes]):
+    audios: list[torch.Tensor] = [
+        (
+            torch.from_numpy(
+                librosa.load(io.BytesIO(audio), sr=model.spec_transform.sample_rate)[0]
+            )[None]
+            if isinstance(audio, bytes)
+            else audio
+        )
+        for audio in audios_list
+    ]
+
+    lengths = torch.tensor([audio.shape[-1] for audio in audios], device=model.device)
+    max_length = lengths.max().item()
+
+    print(f"Encode max length: {max_length / model.spec_transform.sample_rate:.2f}s")
+
+    padded = torch.stack(
+        [
+            torch.nn.functional.pad(audio, (0, int(max_length - audio.shape[-1])))
+            for audio in audios
+        ]
+    ).to(model.device)
+
+    features, feature_lengths = model.encode(padded, audio_lengths=lengths)
+    features, feature_lengths = features.cpu(), feature_lengths.cpu()
+
+    return [feature[..., :length] for feature, length in zip(features, feature_lengths)]
+
+
+@cached(
+    cache=LRUCache(maxsize=CACHE_MAXSIZE),
+    key=lambda model, audios: (model.device, tuple(audios)),
+)
+def cached_vqgan_batch_encode(model, audios: list[bytes]):
+    return batch_encode(model, audios)
+
+
+@torch.no_grad()
+@torch.autocast(device_type="cuda", dtype=torch.half)
+def vqgan_decode(model, features):
+    lengths = torch.tensor(
+        [feature.shape[-1] for feature in features], device=model.device
+    )
+    max_length = lengths.max().item()
+    padded = torch.stack(
+        [
+            torch.nn.functional.pad(feature, (0, max_length - feature.shape[-1]))
+            for feature in features
+        ]
+    ).to(model.device)
+
+    # If bs too large, we do micro batch decode
+    audios, audio_lengths = [], []
+    for i in range(0, padded.shape[0], MICRO_BATCH_SIZE):
+        audio, audio_length = model.decode(
+            padded[i : i + MICRO_BATCH_SIZE],
+            feature_lengths=lengths[i : i + MICRO_BATCH_SIZE],
+        )
+        audios.append(audio)
+        audio_lengths.append(audio_length)
+    audios = torch.cat(audios, dim=0)
+    audio_lengths = torch.cat(audio_lengths, dim=0)
+    audios, audio_lengths = audios.cpu(), audio_lengths.cpu()
+
+    return [audio[..., :length].numpy() for audio, length in zip(audios, audio_lengths)]
+
+
+@torch.no_grad()
+def batch_asr(model, lock, audios, sr, language="auto"):
+    resampled_audios = []
+    for audio in audios:
+        audio = torchaudio.functional.resample(audio, sr, ASR_SAMPLE_RATE)
+        assert audio.ndim == 1
+        resampled_audios.append(audio)
+
+    with lock:
+        res = model.generate(
+            input=resampled_audios,
+            batch_size=len(resampled_audios),
+            language=language,
+            use_itn=True,
+        )
+
+    results = []
+    for r, audio in zip(res, audios):
+        text = r["text"]
+        text = re.sub(r"<\|.*?\|>", "", text)
+        duration = len(audio) / sr * 1000
+        huge_gap = False
+
+        if "timestamp" in r and len(r["timestamp"]) > 2:
+            for timestamp_a, timestamp_b in zip(
+                r["timestamp"][:-1], r["timestamp"][1:]
+            ):
+                # If there is a gap of more than 4 seconds, we consider it as a huge gap
+                if timestamp_b[0] - timestamp_a[1] > HUGE_GAP_THRESHOLD:
+                    huge_gap = True
+                    break
+
+            # Doesn't make sense to have a huge gap at the end
+            if duration - r["timestamp"][-1][1] > HUGE_GAP_THRESHOLD:
+                huge_gap = True
+
+        results.append(
+            {
+                "text": text,
+                "duration": duration,
+                "huge_gap": huge_gap,
+            }
+        )
+
+    return results

+ 246 - 0
tools/server/views.py

@@ -0,0 +1,246 @@
+import io
+import os
+import time
+from http import HTTPStatus
+
+import numpy as np
+import ormsgpack
+import soundfile as sf
+import torch
+from kui.asgi import HTTPException, HttpView, JSONResponse, StreamResponse, request
+from loguru import logger
+
+from tools.schema import (
+    ServeASRRequest,
+    ServeASRResponse,
+    ServeChatRequest,
+    ServeTTSRequest,
+    ServeVQGANDecodeRequest,
+    ServeVQGANDecodeResponse,
+    ServeVQGANEncodeRequest,
+    ServeVQGANEncodeResponse,
+)
+from tools.server.agent import get_response_generator
+from tools.server.api_utils import (
+    buffer_to_async_generator,
+    get_content_type,
+    inference_async,
+)
+from tools.server.inference import inference_wrapper as inference
+from tools.server.model_manager import ModelManager
+from tools.server.model_utils import batch_asr, cached_vqgan_batch_encode, vqgan_decode
+
+MAX_NUM_SAMPLES = int(os.getenv("NUM_SAMPLES", 1))
+
+
+class HealthView(HttpView):
+    """
+    Return the health status of the server.
+    """
+
+    @classmethod
+    async def post(cls):
+        return JSONResponse({"status": "ok"})
+
+
+class VQGANEncodeView(HttpView):
+    """
+    Encode the audio into symbolic tokens.
+    """
+
+    @classmethod
+    async def post(cls):
+        # Decode the request
+        payload = await request.data()
+        req = ServeVQGANEncodeRequest(**payload)
+
+        # Get the model from the app
+        model_manager: ModelManager = request.app.state.model_manager
+        decoder_model = model_manager.decoder_model
+
+        # Encode the audio
+        start_time = time.time()
+        tokens = cached_vqgan_batch_encode(decoder_model, req.audios)
+        logger.info(
+            f"[EXEC] VQGAN encode time: {(time.time() - start_time) * 1000:.2f}ms"
+        )
+
+        # Return the response
+        return ormsgpack.packb(
+            ServeVQGANEncodeResponse(tokens=[i.tolist() for i in tokens]),
+            option=ormsgpack.OPT_SERIALIZE_PYDANTIC,
+        )
+
+
+class VQGANDecodeView(HttpView):
+    """
+    Decode the symbolic tokens into audio.
+    """
+
+    @classmethod
+    async def post(cls):
+        # Decode the request
+        payload = await request.data()
+        req = ServeVQGANDecodeRequest(**payload)
+
+        # Get the model from the app
+        model_manager: ModelManager = request.app.state.model_manager
+        decoder_model = model_manager.decoder_model
+
+        # Decode the audio
+        tokens = [torch.tensor(token, dtype=torch.int) for token in req.tokens]
+        start_time = time.time()
+        audios = vqgan_decode(decoder_model, tokens)
+        logger.info(
+            f"[EXEC] VQGAN decode time: {(time.time() - start_time) * 1000:.2f}ms"
+        )
+        audios = [audio.astype(np.float16).tobytes() for audio in audios]
+
+        # Return the response
+        return ormsgpack.packb(
+            ServeVQGANDecodeResponse(audios=audios),
+            option=ormsgpack.OPT_SERIALIZE_PYDANTIC,
+        )
+
+
+class ASRView(HttpView):
+    """
+    Perform automatic speech recognition on the audio.
+    """
+
+    @classmethod
+    async def post(cls):
+        # Decode the request
+        payload = await request.data()
+        req = ServeASRRequest(**payload)
+
+        # Get the model from the app
+        model_manager: ModelManager = request.app.state.model_manager
+        asr_model = model_manager.asr_model
+        lock = request.app.state.lock
+
+        # Perform ASR
+        start_time = time.time()
+        audios = [np.frombuffer(audio, dtype=np.float16) for audio in req.audios]
+        audios = [torch.from_numpy(audio).float() for audio in audios]
+
+        if any(audios.shape[-1] >= 30 * req.sample_rate for audios in audios):
+            raise HTTPException(status_code=400, content="Audio length is too long")
+
+        transcriptions = batch_asr(
+            asr_model, lock, audios=audios, sr=req.sample_rate, language=req.language
+        )
+        logger.info(f"[EXEC] ASR time: {(time.time() - start_time) * 1000:.2f}ms")
+
+        # Return the response
+        return ormsgpack.packb(
+            ServeASRResponse(transcriptions=transcriptions),
+            option=ormsgpack.OPT_SERIALIZE_PYDANTIC,
+        )
+
+
+class TTSView(HttpView):
+    """
+    Perform text-to-speech on the input text.
+    """
+
+    @classmethod
+    async def post(cls):
+        # Decode the request
+        payload = await request.data()
+        req = ServeTTSRequest(**payload)
+
+        # Get the model from the app
+        app_state = request.app.state
+        model_manager: ModelManager = app_state.model_manager
+        engine = model_manager.tts_inference_engine
+        sample_rate = engine.decoder_model.spec_transform.sample_rate
+
+        # Check if the text is too long
+        if app_state.max_text_length > 0 and len(req.text) > app_state.max_text_length:
+            raise HTTPException(
+                HTTPStatus.BAD_REQUEST,
+                content=f"Text is too long, max length is {app_state.max_text_length}",
+            )
+
+        # Check if streaming is enabled
+        if req.streaming and req.format != "wav":
+            raise HTTPException(
+                HTTPStatus.BAD_REQUEST,
+                content="Streaming only supports WAV format",
+            )
+
+        # Perform TTS
+        if req.streaming:
+            return StreamResponse(
+                iterable=inference_async(req, engine),
+                headers={
+                    "Content-Disposition": f"attachment; filename=audio.{req.format}",
+                },
+                content_type=get_content_type(req.format),
+            )
+        else:
+            fake_audios = next(inference(req, engine))
+            buffer = io.BytesIO()
+            sf.write(
+                buffer,
+                fake_audios,
+                sample_rate,
+                format=req.format,
+            )
+
+            return StreamResponse(
+                iterable=buffer_to_async_generator(buffer.getvalue()),
+                headers={
+                    "Content-Disposition": f"attachment; filename=audio.{req.format}",
+                },
+                content_type=get_content_type(req.format),
+            )
+
+
+class ChatView(HttpView):
+    """
+    Perform chatbot inference on the input text.
+    """
+
+    @classmethod
+    async def post(cls):
+        # Decode the request
+        payload = await request.data()
+        req = ServeChatRequest(**payload)
+
+        # Check that the number of samples requested is correct
+        if req.num_samples < 1 or req.num_samples > MAX_NUM_SAMPLES:
+            raise HTTPException(
+                HTTPStatus.BAD_REQUEST,
+                content=f"Number of samples must be between 1 and {MAX_NUM_SAMPLES}",
+            )
+
+        # Get the type of content provided
+        content_type = request.headers.get("Content-Type", "application/json")
+        json_mode = "application/json" in content_type
+
+        # Get the models from the app
+        model_manager: ModelManager = request.app.state.model_manager
+        llama_queue = model_manager.llama_queue
+        tokenizer = model_manager.tokenizer
+        config = model_manager.config
+
+        device = request.app.state.device
+
+        # Get the response generators
+        response_generator = get_response_generator(
+            llama_queue, tokenizer, config, req, device, json_mode
+        )
+
+        # Return the response in the correct format
+        if req.streaming is False:
+            result = response_generator()
+            if json_mode:
+                return JSONResponse(result.model_dump())
+            else:
+                return ormsgpack.packb(result, option=ormsgpack.OPT_SERIALIZE_PYDANTIC)
+
+        return StreamResponse(
+            iterable=response_generator(), content_type="text/event-stream"
+        )

+ 0 - 570
tools/webui.py

@@ -1,570 +0,0 @@
-import gc
-import html
-import io
-import os
-import queue
-import wave
-from argparse import ArgumentParser
-from functools import partial
-from pathlib import Path
-
-import gradio as gr
-import librosa
-import numpy as np
-import pyrootutils
-import torch
-from loguru import logger
-from transformers import AutoTokenizer
-
-pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
-
-
-from fish_speech.i18n import i18n
-from fish_speech.text.chn_text_norm.text import Text as ChnNormedText
-from fish_speech.utils import autocast_exclude_mps, set_seed
-from tools.api import decode_vq_tokens, encode_reference
-from tools.file import AUDIO_EXTENSIONS, audio_to_bytes, list_files, read_ref_text
-from tools.llama.generate import (
-    GenerateRequest,
-    GenerateResponse,
-    WrappedGenerateResponse,
-    launch_thread_safe_queue,
-)
-from tools.schema import (
-    GLOBAL_NUM_SAMPLES,
-    ASRPackRequest,
-    ServeASRRequest,
-    ServeASRResponse,
-    ServeASRSegment,
-    ServeAudioPart,
-    ServeForwardMessage,
-    ServeMessage,
-    ServeReferenceAudio,
-    ServeRequest,
-    ServeResponse,
-    ServeStreamDelta,
-    ServeStreamResponse,
-    ServeTextPart,
-    ServeTimedASRResponse,
-    ServeTTSRequest,
-    ServeVQGANDecodeRequest,
-    ServeVQGANDecodeResponse,
-    ServeVQGANEncodeRequest,
-    ServeVQGANEncodeResponse,
-    ServeVQPart,
-)
-from tools.vqgan.inference import load_model as load_decoder_model
-
-# Make einx happy
-os.environ["EINX_FILTER_TRACEBACK"] = "false"
-
-
-HEADER_MD = f"""# Fish Speech
-
-{i18n("A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).")}  
-
-{i18n("You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1.5).")}  
-
-{i18n("Related code and weights are released under CC BY-NC-SA 4.0 License.")}  
-
-{i18n("We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.")}  
-"""
-
-TEXTBOX_PLACEHOLDER = i18n("Put your text here.")
-SPACE_IMPORTED = False
-
-
-def build_html_error_message(error):
-    return f"""
-    <div style="color: red; 
-    font-weight: bold;">
-        {html.escape(str(error))}
-    </div>
-    """
-
-
-@torch.inference_mode()
-def inference(req: ServeTTSRequest):
-
-    idstr: str | None = req.reference_id
-    prompt_tokens, prompt_texts = [], []
-    if idstr is not None:
-        ref_folder = Path("references") / idstr
-        ref_folder.mkdir(parents=True, exist_ok=True)
-        ref_audios = list_files(
-            ref_folder, AUDIO_EXTENSIONS, recursive=True, sort=False
-        )
-
-        if req.use_memory_cache == "never" or (
-            req.use_memory_cache == "on-demand" and len(prompt_tokens) == 0
-        ):
-            prompt_tokens = [
-                encode_reference(
-                    decoder_model=decoder_model,
-                    reference_audio=audio_to_bytes(str(ref_audio)),
-                    enable_reference_audio=True,
-                )
-                for ref_audio in ref_audios
-            ]
-            prompt_texts = [
-                read_ref_text(str(ref_audio.with_suffix(".lab")))
-                for ref_audio in ref_audios
-            ]
-        else:
-            logger.info("Use same references")
-
-    else:
-        # Parse reference audio aka prompt
-        refs = req.references
-
-        if req.use_memory_cache == "never" or (
-            req.use_memory_cache == "on-demand" and len(prompt_tokens) == 0
-        ):
-            prompt_tokens = [
-                encode_reference(
-                    decoder_model=decoder_model,
-                    reference_audio=ref.audio,
-                    enable_reference_audio=True,
-                )
-                for ref in refs
-            ]
-            prompt_texts = [ref.text for ref in refs]
-        else:
-            logger.info("Use same references")
-
-    if req.seed is not None:
-        set_seed(req.seed)
-        logger.warning(f"set seed: {req.seed}")
-
-    # LLAMA Inference
-    request = dict(
-        device=decoder_model.device,
-        max_new_tokens=req.max_new_tokens,
-        text=(
-            req.text
-            if not req.normalize
-            else ChnNormedText(raw_text=req.text).normalize()
-        ),
-        top_p=req.top_p,
-        repetition_penalty=req.repetition_penalty,
-        temperature=req.temperature,
-        compile=args.compile,
-        iterative_prompt=req.chunk_length > 0,
-        chunk_length=req.chunk_length,
-        max_length=4096,
-        prompt_tokens=prompt_tokens,
-        prompt_text=prompt_texts,
-    )
-
-    response_queue = queue.Queue()
-    llama_queue.put(
-        GenerateRequest(
-            request=request,
-            response_queue=response_queue,
-        )
-    )
-
-    segments = []
-
-    while True:
-        result: WrappedGenerateResponse = response_queue.get()
-        if result.status == "error":
-            yield None, None, build_html_error_message(result.response)
-            break
-
-        result: GenerateResponse = result.response
-        if result.action == "next":
-            break
-
-        with autocast_exclude_mps(
-            device_type=decoder_model.device.type, dtype=args.precision
-        ):
-            fake_audios = decode_vq_tokens(
-                decoder_model=decoder_model,
-                codes=result.codes,
-            )
-
-        fake_audios = fake_audios.float().cpu().numpy()
-        segments.append(fake_audios)
-
-    if len(segments) == 0:
-        return (
-            None,
-            None,
-            build_html_error_message(
-                i18n("No audio generated, please check the input text.")
-            ),
-        )
-
-    # No matter streaming or not, we need to return the final audio
-    audio = np.concatenate(segments, axis=0)
-    yield None, (decoder_model.spec_transform.sample_rate, audio), None
-
-    if torch.cuda.is_available():
-        torch.cuda.empty_cache()
-        gc.collect()
-
-
-n_audios = 4
-
-global_audio_list = []
-global_error_list = []
-
-
-def inference_wrapper(
-    text,
-    enable_reference_audio,
-    reference_audio,
-    reference_text,
-    max_new_tokens,
-    chunk_length,
-    top_p,
-    repetition_penalty,
-    temperature,
-    seed,
-    batch_infer_num,
-):
-    audios = []
-    errors = []
-
-    for _ in range(batch_infer_num):
-        result = inference(
-            text,
-            enable_reference_audio,
-            reference_audio,
-            reference_text,
-            max_new_tokens,
-            chunk_length,
-            top_p,
-            repetition_penalty,
-            temperature,
-            seed,
-        )
-
-        _, audio_data, error_message = next(result)
-
-        audios.append(
-            gr.Audio(value=audio_data if audio_data else None, visible=True),
-        )
-        errors.append(
-            gr.HTML(value=error_message if error_message else None, visible=True),
-        )
-
-    for _ in range(batch_infer_num, n_audios):
-        audios.append(
-            gr.Audio(value=None, visible=False),
-        )
-        errors.append(
-            gr.HTML(value=None, visible=False),
-        )
-
-    return None, *audios, *errors
-
-
-def wav_chunk_header(sample_rate=44100, bit_depth=16, channels=1):
-    buffer = io.BytesIO()
-
-    with wave.open(buffer, "wb") as wav_file:
-        wav_file.setnchannels(channels)
-        wav_file.setsampwidth(bit_depth // 8)
-        wav_file.setframerate(sample_rate)
-
-    wav_header_bytes = buffer.getvalue()
-    buffer.close()
-    return wav_header_bytes
-
-
-def normalize_text(user_input, use_normalization):
-    if use_normalization:
-        return ChnNormedText(raw_text=user_input).normalize()
-    else:
-        return user_input
-
-
-def update_examples():
-    examples_dir = Path("references")
-    examples_dir.mkdir(parents=True, exist_ok=True)
-    example_audios = list_files(examples_dir, AUDIO_EXTENSIONS, recursive=True)
-    return gr.Dropdown(choices=example_audios + [""])
-
-
-def build_app():
-    with gr.Blocks(theme=gr.themes.Base()) as app:
-        gr.Markdown(HEADER_MD)
-
-        # Use light theme by default
-        app.load(
-            None,
-            None,
-            js="() => {const params = new URLSearchParams(window.location.search);if (!params.has('__theme')) {params.set('__theme', '%s');window.location.search = params.toString();}}"
-            % args.theme,
-        )
-
-        # Inference
-        with gr.Row():
-            with gr.Column(scale=3):
-                text = gr.Textbox(
-                    label=i18n("Input Text"), placeholder=TEXTBOX_PLACEHOLDER, lines=10
-                )
-                refined_text = gr.Textbox(
-                    label=i18n("Realtime Transform Text"),
-                    placeholder=i18n(
-                        "Normalization Result Preview (Currently Only Chinese)"
-                    ),
-                    lines=5,
-                    interactive=False,
-                )
-
-                with gr.Row():
-                    normalize = gr.Checkbox(
-                        label=i18n("Text Normalization"),
-                        value=False,
-                    )
-
-                with gr.Row():
-                    with gr.Column():
-                        with gr.Tab(label=i18n("Advanced Config")):
-                            with gr.Row():
-                                chunk_length = gr.Slider(
-                                    label=i18n("Iterative Prompt Length, 0 means off"),
-                                    minimum=0,
-                                    maximum=300,
-                                    value=200,
-                                    step=8,
-                                )
-
-                                max_new_tokens = gr.Slider(
-                                    label=i18n(
-                                        "Maximum tokens per batch, 0 means no limit"
-                                    ),
-                                    minimum=0,
-                                    maximum=2048,
-                                    value=0,
-                                    step=8,
-                                )
-
-                            with gr.Row():
-                                top_p = gr.Slider(
-                                    label="Top-P",
-                                    minimum=0.6,
-                                    maximum=0.9,
-                                    value=0.7,
-                                    step=0.01,
-                                )
-
-                                repetition_penalty = gr.Slider(
-                                    label=i18n("Repetition Penalty"),
-                                    minimum=1,
-                                    maximum=1.5,
-                                    value=1.2,
-                                    step=0.01,
-                                )
-
-                            with gr.Row():
-                                temperature = gr.Slider(
-                                    label="Temperature",
-                                    minimum=0.6,
-                                    maximum=0.9,
-                                    value=0.7,
-                                    step=0.01,
-                                )
-                                seed = gr.Number(
-                                    label="Seed",
-                                    info="0 means randomized inference, otherwise deterministic",
-                                    value=0,
-                                )
-
-                        with gr.Tab(label=i18n("Reference Audio")):
-                            with gr.Row():
-                                gr.Markdown(
-                                    i18n(
-                                        "5 to 10 seconds of reference audio, useful for specifying speaker."
-                                    )
-                                )
-                            with gr.Row():
-                                reference_id = gr.Textbox(
-                                    label=i18n("Reference ID"),
-                                    placeholder="Leave empty to use uploaded references",
-                                )
-
-                            with gr.Row():
-                                use_memory_cache = gr.Radio(
-                                    label=i18n("Use Memory Cache"),
-                                    choices=["never", "on-demand", "always"],
-                                    value="on-demand",
-                                )
-
-                            with gr.Row():
-                                reference_audio = gr.Audio(
-                                    label=i18n("Reference Audio"),
-                                    type="filepath",
-                                )
-                            with gr.Row():
-                                reference_text = gr.Textbox(
-                                    label=i18n("Reference Text"),
-                                    lines=1,
-                                    placeholder="在一无所知中,梦里的一天结束了,一个新的「轮回」便会开始。",
-                                    value="",
-                                )
-
-            with gr.Column(scale=3):
-                with gr.Row():
-                    error = gr.HTML(
-                        label=i18n("Error Message"),
-                        visible=True,
-                    )
-                with gr.Row():
-                    audio = gr.Audio(
-                        label=i18n("Generated Audio"),
-                        type="numpy",
-                        interactive=False,
-                        visible=True,
-                    )
-
-                with gr.Row():
-                    with gr.Column(scale=3):
-                        generate = gr.Button(
-                            value="\U0001F3A7 " + i18n("Generate"), variant="primary"
-                        )
-
-        text.input(fn=normalize_text, inputs=[text, normalize], outputs=[refined_text])
-
-        def inference_wrapper(
-            text,
-            normalize,
-            reference_id,
-            reference_audio,
-            reference_text,
-            max_new_tokens,
-            chunk_length,
-            top_p,
-            repetition_penalty,
-            temperature,
-            seed,
-            use_memory_cache,
-        ):
-            references = []
-            if reference_audio:
-                # 将文件路径转换为字节
-                with open(reference_audio, "rb") as audio_file:
-                    audio_bytes = audio_file.read()
-                references = [
-                    ServeReferenceAudio(audio=audio_bytes, text=reference_text)
-                ]
-
-            req = ServeTTSRequest(
-                text=text,
-                normalize=normalize,
-                reference_id=reference_id if reference_id else None,
-                references=references,
-                max_new_tokens=max_new_tokens,
-                chunk_length=chunk_length,
-                top_p=top_p,
-                repetition_penalty=repetition_penalty,
-                temperature=temperature,
-                seed=int(seed) if seed else None,
-                use_memory_cache=use_memory_cache,
-            )
-
-            for result in inference(req):
-                if result[2]:  # Error message
-                    return None, result[2]
-                elif result[1]:  # Audio data
-                    return result[1], None
-
-            return None, i18n("No audio generated")
-
-        # Submit
-        generate.click(
-            inference_wrapper,
-            [
-                refined_text,
-                normalize,
-                reference_id,
-                reference_audio,
-                reference_text,
-                max_new_tokens,
-                chunk_length,
-                top_p,
-                repetition_penalty,
-                temperature,
-                seed,
-                use_memory_cache,
-            ],
-            [audio, error],
-            concurrency_limit=1,
-        )
-
-    return app
-
-
-def parse_args():
-    parser = ArgumentParser()
-    parser.add_argument(
-        "--llama-checkpoint-path",
-        type=Path,
-        default="checkpoints/fish-speech-1.5",
-    )
-    parser.add_argument(
-        "--decoder-checkpoint-path",
-        type=Path,
-        default="checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
-    )
-    parser.add_argument("--decoder-config-name", type=str, default="firefly_gan_vq")
-    parser.add_argument("--device", type=str, default="cuda")
-    parser.add_argument("--half", action="store_true")
-    parser.add_argument("--compile", action="store_true")
-    parser.add_argument("--max-gradio-length", type=int, default=0)
-    parser.add_argument("--theme", type=str, default="light")
-
-    return parser.parse_args()
-
-
-if __name__ == "__main__":
-    args = parse_args()
-    args.precision = torch.half if args.half else torch.bfloat16
-
-    # Check if CUDA is available
-    if not torch.cuda.is_available():
-        logger.info("CUDA is not available, running on CPU.")
-        args.device = "cpu"
-
-    logger.info("Loading Llama model...")
-    llama_queue = launch_thread_safe_queue(
-        checkpoint_path=args.llama_checkpoint_path,
-        device=args.device,
-        precision=args.precision,
-        compile=args.compile,
-    )
-    logger.info("Llama model loaded, loading VQ-GAN model...")
-
-    decoder_model = load_decoder_model(
-        config_name=args.decoder_config_name,
-        checkpoint_path=args.decoder_checkpoint_path,
-        device=args.device,
-    )
-
-    logger.info("Decoder model loaded, warming up...")
-
-    # Dry run to check if the model is loaded correctly and avoid the first-time latency
-    list(
-        inference(
-            ServeTTSRequest(
-                text="Hello world.",
-                references=[],
-                reference_id=None,
-                max_new_tokens=0,
-                chunk_length=200,
-                top_p=0.7,
-                repetition_penalty=1.5,
-                temperature=0.7,
-                emotion=None,
-                format="wav",
-            )
-        )
-    )
-
-    logger.info("Warming up done, launching the web UI...")
-
-    app = build_app()
-    app.launch(show_api=True)

+ 173 - 0
tools/webui/__init__.py

@@ -0,0 +1,173 @@
+from typing import Callable
+
+import gradio as gr
+
+from fish_speech.i18n import i18n
+from tools.inference_engine.utils import normalize_text
+from tools.webui.variables import HEADER_MD, TEXTBOX_PLACEHOLDER
+
+
+def build_app(inference_fct: Callable, theme: str = "light") -> gr.Blocks:
+    with gr.Blocks(theme=gr.themes.Base()) as app:
+        gr.Markdown(HEADER_MD)
+
+        # Use light theme by default
+        app.load(
+            None,
+            None,
+            js="() => {const params = new URLSearchParams(window.location.search);if (!params.has('__theme')) {params.set('__theme', '%s');window.location.search = params.toString();}}"
+            % theme,
+        )
+
+        # Inference
+        with gr.Row():
+            with gr.Column(scale=3):
+                text = gr.Textbox(
+                    label=i18n("Input Text"), placeholder=TEXTBOX_PLACEHOLDER, lines=10
+                )
+                refined_text = gr.Textbox(
+                    label=i18n("Realtime Transform Text"),
+                    placeholder=i18n(
+                        "Normalization Result Preview (Currently Only Chinese)"
+                    ),
+                    lines=5,
+                    interactive=False,
+                )
+
+                with gr.Row():
+                    normalize = gr.Checkbox(
+                        label=i18n("Text Normalization"),
+                        value=False,
+                    )
+
+                with gr.Row():
+                    with gr.Column():
+                        with gr.Tab(label=i18n("Advanced Config")):
+                            with gr.Row():
+                                chunk_length = gr.Slider(
+                                    label=i18n("Iterative Prompt Length, 0 means off"),
+                                    minimum=0,
+                                    maximum=300,
+                                    value=200,
+                                    step=8,
+                                )
+
+                                max_new_tokens = gr.Slider(
+                                    label=i18n(
+                                        "Maximum tokens per batch, 0 means no limit"
+                                    ),
+                                    minimum=0,
+                                    maximum=2048,
+                                    value=0,
+                                    step=8,
+                                )
+
+                            with gr.Row():
+                                top_p = gr.Slider(
+                                    label="Top-P",
+                                    minimum=0.6,
+                                    maximum=0.9,
+                                    value=0.7,
+                                    step=0.01,
+                                )
+
+                                repetition_penalty = gr.Slider(
+                                    label=i18n("Repetition Penalty"),
+                                    minimum=1,
+                                    maximum=1.5,
+                                    value=1.2,
+                                    step=0.01,
+                                )
+
+                            with gr.Row():
+                                temperature = gr.Slider(
+                                    label="Temperature",
+                                    minimum=0.6,
+                                    maximum=0.9,
+                                    value=0.7,
+                                    step=0.01,
+                                )
+                                seed = gr.Number(
+                                    label="Seed",
+                                    info="0 means randomized inference, otherwise deterministic",
+                                    value=0,
+                                )
+
+                        with gr.Tab(label=i18n("Reference Audio")):
+                            with gr.Row():
+                                gr.Markdown(
+                                    i18n(
+                                        "5 to 10 seconds of reference audio, useful for specifying speaker."
+                                    )
+                                )
+                            with gr.Row():
+                                reference_id = gr.Textbox(
+                                    label=i18n("Reference ID"),
+                                    placeholder="Leave empty to use uploaded references",
+                                )
+
+                            with gr.Row():
+                                use_memory_cache = gr.Radio(
+                                    label=i18n("Use Memory Cache"),
+                                    choices=["on", "off"],
+                                    value="on",
+                                )
+
+                            with gr.Row():
+                                reference_audio = gr.Audio(
+                                    label=i18n("Reference Audio"),
+                                    type="filepath",
+                                )
+                            with gr.Row():
+                                reference_text = gr.Textbox(
+                                    label=i18n("Reference Text"),
+                                    lines=1,
+                                    placeholder="在一无所知中,梦里的一天结束了,一个新的「轮回」便会开始。",
+                                    value="",
+                                )
+
+            with gr.Column(scale=3):
+                with gr.Row():
+                    error = gr.HTML(
+                        label=i18n("Error Message"),
+                        visible=True,
+                    )
+                with gr.Row():
+                    audio = gr.Audio(
+                        label=i18n("Generated Audio"),
+                        type="numpy",
+                        interactive=False,
+                        visible=True,
+                    )
+
+                with gr.Row():
+                    with gr.Column(scale=3):
+                        generate = gr.Button(
+                            value="\U0001F3A7 " + i18n("Generate"),
+                            variant="primary",
+                        )
+
+        text.input(fn=normalize_text, inputs=[text, normalize], outputs=[refined_text])
+
+        # Submit
+        generate.click(
+            inference_fct,
+            [
+                refined_text,
+                normalize,
+                reference_id,
+                reference_audio,
+                reference_text,
+                max_new_tokens,
+                chunk_length,
+                top_p,
+                repetition_penalty,
+                temperature,
+                seed,
+                use_memory_cache,
+            ],
+            [audio, error],
+            concurrency_limit=1,
+        )
+
+    return app

+ 91 - 0
tools/webui/inference.py

@@ -0,0 +1,91 @@
+import html
+from functools import partial
+from typing import Any, Callable
+
+from fish_speech.i18n import i18n
+from tools.schema import ServeReferenceAudio, ServeTTSRequest
+
+
+def inference_wrapper(
+    text,
+    normalize,
+    reference_id,
+    reference_audio,
+    reference_text,
+    max_new_tokens,
+    chunk_length,
+    top_p,
+    repetition_penalty,
+    temperature,
+    seed,
+    use_memory_cache,
+    engine,
+):
+    """
+    Wrapper for the inference function.
+    Used in the Gradio interface.
+    """
+
+    if reference_audio:
+        references = get_reference_audio(reference_audio, reference_text)
+    else:
+        references = []
+
+    req = ServeTTSRequest(
+        text=text,
+        normalize=normalize,
+        reference_id=reference_id if reference_id else None,
+        references=references,
+        max_new_tokens=max_new_tokens,
+        chunk_length=chunk_length,
+        top_p=top_p,
+        repetition_penalty=repetition_penalty,
+        temperature=temperature,
+        seed=int(seed) if seed else None,
+        use_memory_cache=use_memory_cache,
+    )
+
+    for result in engine.inference(req):
+        match result.code:
+            case "final":
+                return result.audio, None
+            case "error":
+                return None, build_html_error_message(i18n(result.error))
+            case _:
+                pass
+
+    return None, i18n("No audio generated")
+
+
+def get_reference_audio(reference_audio: str, reference_text: str) -> list:
+    """
+    Get the reference audio bytes.
+    """
+
+    with open(reference_audio, "rb") as audio_file:
+        audio_bytes = audio_file.read()
+
+    return [ServeReferenceAudio(audio=audio_bytes, text=reference_text)]
+
+
+def build_html_error_message(error: Any) -> str:
+
+    error = error if isinstance(error, Exception) else Exception("Unknown error")
+
+    return f"""
+    <div style="color: red; 
+    font-weight: bold;">
+        {html.escape(str(error))}
+    </div>
+    """
+
+
+def get_inference_wrapper(engine) -> Callable:
+    """
+    Get the inference function with the immutable arguments.
+    """
+
+    return partial(
+        inference_wrapper,
+        engine=engine,
+    )

+ 14 - 0
tools/webui/variables.py

@@ -0,0 +1,14 @@
+from fish_speech.i18n import i18n
+
+HEADER_MD = f"""# Fish Speech
+
+{i18n("A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).")}  
+
+{i18n("You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1.5).")}  
+
+{i18n("Related code and weights are released under CC BY-NC-SA 4.0 License.")}  
+
+{i18n("We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.")}  
+"""
+
+TEXTBOX_PLACEHOLDER = i18n("Put your text here.")