Browse Source

feat: enable more workers in `api.py` (#621)

* Readmes, deps, api workers

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix speed loss after compiling

* revert log

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
spicysama 1 year ago
parent
commit
f15d9f23a9
13 changed files with 113 additions and 70 deletions
  1. 9 13
      README.ja.md
  2. 4 3
      README.md
  3. 3 5
      README.pt-BR.md
  4. 4 5
      README.zh.md
  5. 3 3
      docs/en/index.md
  6. 3 3
      docs/ja/index.md
  7. 3 3
      docs/pt/index.md
  8. 3 3
      docs/zh/index.md
  9. 1 1
      install_env.bat
  10. 3 2
      pyproject.toml
  11. 69 29
      tools/api.py
  12. 1 0
      tools/commons.py
  13. 7 0
      tools/post_api.py

+ 9 - 13
README.ja.md

@@ -1,4 +1,3 @@
-
 <div align="center">
 <h1>Fish Speech</h1>
 
@@ -15,7 +14,7 @@
 <br>
 
 <div align="center">
-    <img src="https://counter.seku.su/cmoe?name=fish-speech&theme=asoul" /><br>
+    <img src="https://count.getloli.com/get/@fish-speech?theme=asoul" /><br>
 </div>
 <br>
 
@@ -31,28 +30,25 @@
     </a>
 </div>
 
-このコードベースとすべてのモデルは、CC-BY-NC-SA-4.0ライセンスの下でリリースされています。詳細については、[LICENSE](LICENSE)を参照してください。
+このコードベースとすべてのモデルは、CC-BY-NC-SA-4.0 ライセンスの下でリリースされています。詳細については、[LICENSE](LICENSE)を参照してください。
 
 ---
 
 ## 機能
 
-1. **ゼロショット & フューショット TTS**:10〜30秒の音声サンプルを入力して、高品質のTTS出力を生成します。**詳細は [音声クローンのベストプラクティス](https://docs.fish.audio/text-to-speech/voice-clone-best-practices) を参照してください。**
+1. **ゼロショット & フューショット TTS**:10〜30 秒の音声サンプルを入力して、高品質の TTS 出力を生成します。**詳細は [音声クローンのベストプラクティス](https://docs.fish.audio/text-to-speech/voice-clone-best-practices) を参照してください。**
 2. **多言語 & クロスリンガル対応**:多言語テキストを入力ボックスにコピーペーストするだけで、言語を気にする必要はありません。現在、英語、日本語、韓国語、中国語、フランス語、ドイツ語、アラビア語、スペイン語に対応しています。
-3. **音素依存なし**:このモデルは強力な汎化能力を持ち、TTSに音素を必要としません。あらゆる言語スクリプトに対応可能です。
-4. **高精度**:5分間の英語テキストに対し、CER(文字誤り率)とWER(単語誤り率)は約2%の精度を達成します。
-5. **高速**:fish-techアクセラレーションにより、Nvidia RTX 4060ラップトップではリアルタイムファクターが約1:5、Nvidia RTX 4090では約1:15です。
-6. **WebUI 推論**:使いやすいGradioベースのWebユーザーインターフェースを搭載し、Chrome、Firefox、Edgeなどのブラウザに対応しています。
-7. **GUI 推論**:PyQt6のグラフィカルインターフェースを提供し、APIサーバーとシームレスに連携します。Linux、Windows、macOSに対応しています。[GUIを見る](https://github.com/AnyaCoder/fish-speech-gui)。
-8. **デプロイしやすい**:Linux、Windows、macOSにネイティブ対応した推論サーバーを簡単にセットアップでき、速度の低下を最小限に抑えます。
-
-
+3. **音素依存なし**:このモデルは強力な汎化能力を持ち、TTS に音素を必要としません。あらゆる言語スクリプトに対応可能です。
+4. **高精度**:5 分間の英語テキストに対し、CER(文字誤り率)と WER(単語誤り率)は約 2%の精度を達成します。
+5. **高速**:fish-tech アクセラレーションにより、Nvidia RTX 4060 ラップトップではリアルタイムファクターが約 1:5、Nvidia RTX 4090 では約 1:15 です。
+6. **WebUI 推論**:使いやすい Gradio ベースの Web ユーザーインターフェースを搭載し、Chrome、Firefox、Edge などのブラウザに対応しています。
+7. **GUI 推論**:PyQt6 のグラフィカルインターフェースを提供し、API サーバーとシームレスに連携します。Linux、Windows、macOS に対応しています。[GUI を見る](https://github.com/AnyaCoder/fish-speech-gui)。
+8. **デプロイしやすい**:Linux、Windows、macOS にネイティブ対応した推論サーバーを簡単にセットアップでき、速度の低下を最小限に抑えます。
 
 ## 免責事項
 
 コードベースの違法な使用については一切責任を負いません。DMCA(デジタルミレニアム著作権法)およびその他の関連法については、地域の法律を参照してください。
 
-
 ## オンラインデモ
 
 [Fish Audio](https://fish.audio)

+ 4 - 3
README.md

@@ -1,4 +1,3 @@
-
 <div align="center">
 <h1>Fish Speech</h1>
 
@@ -15,8 +14,9 @@
 <br>
 
 <div align="center">
-    <img src="https://counter.seku.su/cmoe?name=fish-speech&theme=asoul" /><br>
+    <img src="https://count.getloli.com/get/@fish-speech?theme=asoul" /><br>
 </div>
+
 <br>
 
 <div align="center">
@@ -31,7 +31,7 @@
     </a>
 </div>
 
-This codebase and all models are released under CC-BY-NC-SA-4.0 License. Please refer to [LICENSE](LICENSE) for more details. 
+This codebase and all models are released under CC-BY-NC-SA-4.0 License. Please refer to [LICENSE](LICENSE) for more details.
 
 ---
 
@@ -54,6 +54,7 @@ This codebase and all models are released under CC-BY-NC-SA-4.0 License. Please
 8. **Deploy-Friendly:** Easily set up an inference server with native support for Linux, Windows and MacOS, minimizing speed loss.
 
 ## Disclaimer
+
 We do not hold any responsibility for any illegal usage of the codebase. Please refer to your local laws about DMCA and other related laws.
 
 ## Online Demo

+ 3 - 5
README.pt-BR.md

@@ -1,4 +1,3 @@
-
 <div align="center">
 <h1>Fish Speech</h1>
 
@@ -15,8 +14,9 @@
 <br>
 
 <div align="center">
-    <img src="https://counter.seku.su/cmoe?name=fish-speech&theme=asoul" /><br>
+    <img src="https://count.getloli.com/get/@fish-speech?theme=asoul" /><br>
 </div>
+
 <br>
 
 <div align="center">
@@ -34,6 +34,7 @@
 Este código-fonte e os modelos são publicados sob a licença CC-BY-NC-SA-4.0. Consulte [LICENSE](LICENSE) para mais detalhes.
 
 ---
+
 ## Funcionalidades
 
 1. **TTS Zero-shot & Few-shot**: Insira uma amostra vocal de 10 a 30 segundos para gerar saída de TTS de alta qualidade. **Para diretrizes detalhadas, veja [Melhores Práticas para Clonagem de Voz](https://docs.fish.audio/text-to-speech/voice-clone-best-practices).**
@@ -52,13 +53,10 @@ Este código-fonte e os modelos são publicados sob a licença CC-BY-NC-SA-4.0.
 
 8. **Fácil de Implantar**: Configura facilmente um servidor de inferência com suporte nativo para Linux, Windows e macOS, minimizando a perda de velocidade.
 
-   
-
 ## Isenção de Responsabilidade
 
 Não nos responsabilizamos por qualquer uso ilegal do código-fonte. Consulte as leis locais sobre DMCA (Digital Millennium Copyright Act) e outras leis relevantes em sua região.
 
-
 ## Demonstração Online
 
 [Fish Audio](https://fish.audio)

+ 4 - 5
README.zh.md

@@ -1,4 +1,3 @@
-
 <div align="center">
 <h1>Fish Speech</h1>
 
@@ -15,8 +14,9 @@
 <br>
 
 <div align="center">
-    <img src="https://counter.seku.su/cmoe?name=fish-speech&theme=asoul" /><br>
+    <img src="https://count.getloli.com/get/@fish-speech?theme=asoul" /><br>
 </div>
+
 <br>
 
 <div align="center">
@@ -30,13 +30,14 @@
         <img alt="Huggingface" src="https://img.shields.io/badge/🤗%20-space%20demo-yellow"/>
     </a>
     <br>
-    
+
 
 </div>
 
 此代码库及模型根据 CC-BY-NC-SA-4.0 许可证发布。请参阅 [LICENSE](LICENSE) 了解更多细节.
 
 ---
+
 ## 特性
 
 1. **零样本 & 小样本 TTS**:输入 10 到 30 秒的声音样本即可生成高质量的 TTS 输出。**详见 [语音克隆最佳实践指南](https://docs.fish.audio/text-to-speech/voice-clone-best-practices)。**
@@ -48,12 +49,10 @@
 7. **GUI 推理**:提供 PyQt6 图形界面,与 API 服务器无缝协作。支持 Linux、Windows 和 macOS。[查看 GUI](https://github.com/AnyaCoder/fish-speech-gui)。
 8. **易于部署**:轻松设置推理服务器,原生支持 Linux、Windows 和 macOS,最大程度减少速度损失。
 
-
 ## 免责声明
 
 我们不对代码库的任何非法使用承担任何责任. 请参阅您当地关于 DMCA (数字千年法案) 和其他相关法律法规.
 
-
 ## 在线 DEMO
 
 [Fish Audio](https://fish.audio)

+ 3 - 3
docs/en/index.md

@@ -35,7 +35,7 @@ conda create -n fish-speech python=3.10
 conda activate fish-speech
 
 # Install pytorch
-pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
+pip3 install torch==2.4.1 torchvision==0.19.1 torchaudio==2.4.1 --index-url https://download.pytorch.org/whl/cu121
 
 # Install fish-speech
 pip3 install -e .
@@ -100,7 +100,7 @@ conda create -n fish-speech python=3.10
 conda activate fish-speech
 
 # Install pytorch
-pip3 install torch torchvision torchaudio
+pip3 install torch==2.4.1 torchvision==0.19.1 torchaudio==2.4.1
 
 # Install fish-speech
 pip3 install -e .[stable]
@@ -122,7 +122,7 @@ Please refer to [this PR](https://github.com/fishaudio/fish-speech/pull/461#issu
 conda create -n fish-speech python=3.10
 conda activate fish-speech
 # install pytorch
-pip install torch torchvision torchaudio
+pip install torch==2.4.1 torchvision==0.19.1 torchaudio==2.4.1
 # install fish-speech
 pip install -e .[stable]
 ```

+ 3 - 3
docs/ja/index.md

@@ -35,7 +35,7 @@ conda create -n fish-speech python=3.10
 conda activate fish-speech
 
 # PyTorchをインストール
-pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
+pip3 install torch==2.4.1 torchvision==0.19.1 torchaudio==2.4.1 --index-url https://download.pytorch.org/whl/cu121
 
 # fish-speechをインストール
 pip3 install -e .
@@ -98,7 +98,7 @@ conda create -n fish-speech python=3.10
 conda activate fish-speech
 
 # pytorchをインストールします。
-pip3 install torch torchvision torchaudio
+pip3 install torch==2.4.1 torchvision==0.19.1 torchaudio==2.4.1
 
 # fish-speechをインストールします。
 pip3 install -e .[stable]
@@ -120,7 +120,7 @@ apt install libsox-dev ffmpeg
 conda create -n fish-speech python=3.10
 conda activate fish-speech
 # install pytorch
-pip install torch torchvision torchaudio
+pip install torch==2.4.1 torchvision==0.19.1 torchaudio==2.4.1
 # install fish-speech
 pip install -e .[stable]
 ```

+ 3 - 3
docs/pt/index.md

@@ -35,7 +35,7 @@ conda create -n fish-speech python=3.10
 conda activate fish-speech
 
 # Instale o pytorch
-pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
+pip3 install torch==2.4.1 torchvision==0.19.1 torchaudio==2.4.1 --index-url https://download.pytorch.org/whl/cu121
 
 # Instale o fish-speech
 pip3 install -e .
@@ -96,7 +96,7 @@ conda create -n fish-speech python=3.10
 conda activate fish-speech
 
 # Instale o pytorch
-pip3 install torch torchvision torchaudio
+pip3 install torch==2.4.1 torchvision==0.19.1 torchaudio==2.4.1
 
 # Instale o fish-speech
 pip3 install -e .[stable]
@@ -118,7 +118,7 @@ Para uma comparação das velocidades de inferência, consulte [este PR](https:/
 conda create -n fish-speech python=3.10
 conda activate fish-speech
 # install pytorch
-pip install torch torchvision torchaudio
+pip install torch==2.4.1 torchvision==0.19.1 torchaudio==2.4.1
 # install fish-speech
 pip install -e .[stable]
 ```

+ 3 - 3
docs/zh/index.md

@@ -35,7 +35,7 @@ conda create -n fish-speech python=3.10
 conda activate fish-speech
 
 # 安装 pytorch
-pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
+pip3 install torch==2.4.1 torchvision==0.19.1 torchaudio==2.4.1 --index-url https://download.pytorch.org/whl/cu121
 
 # 安装 fish-speech
 pip3 install -e .
@@ -95,7 +95,7 @@ conda create -n fish-speech python=3.10
 conda activate fish-speech
 
 # 安装 pytorch
-pip3 install torch torchvision torchaudio
+pip3 install torch==2.4.1 torchvision==0.19.1 torchaudio==2.4.1
 
 # 安装 fish-speech
 pip3 install -e .[stable]
@@ -117,7 +117,7 @@ apt install libsox-dev ffmpeg
 conda create -n fish-speech python=3.10
 conda activate fish-speech
 # install pytorch
-pip install torch torchvision torchaudio
+pip install torch==2.4.1 torchvision==0.19.1 torchaudio==2.4.1
 # install fish-speech
 pip install -e .[stable]
 ```

+ 1 - 1
install_env.bat

@@ -133,7 +133,7 @@ if "%USE_MIRROR%"=="true" (
 echo "HF_ENDPOINT: !HF_ENDPOINT!"
 echo "NO_PROXY: !no_proxy!"
 
-%PIP_CMD% install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
+%PIP_CMD% install torch==2.4.1 torchvision==0.19.1 torchaudio==2.4.1 --index-url https://download.pytorch.org/whl/cu121
 
 %PIP_CMD% install -e . --upgrade-strategy only-if-needed
 

+ 3 - 2
pyproject.toml

@@ -23,7 +23,7 @@ dependencies = [
     "einops>=0.7.0",
     "librosa>=0.10.1",
     "rich>=13.5.3",
-    "gradio>=4.0.0",
+    "gradio<5.0.0",
     "wandb>=0.15.11",
     "grpcio>=1.58.0",
     "kui>=1.6.0",
@@ -37,6 +37,7 @@ dependencies = [
     "einx[torch]==0.2.2",
     "zstandard>=0.22.0",
     "pydub",
+    "pyaudio",
     "faster_whisper",
     "modelscope==1.17.1",
     "funasr==1.1.5",
@@ -47,7 +48,7 @@ dependencies = [
 
 [project.optional-dependencies]
 stable = [
-    "torch>=2.3.1",
+    "torch<=2.4.1",
     "torchaudio",
 ]
 

+ 69 - 29
tools/api.py

@@ -1,4 +1,5 @@
 import io
+import os
 import queue
 import sys
 import traceback
@@ -88,7 +89,8 @@ def load_audio(reference_audio, sr):
         reference_audio = io.BytesIO(audio_data)
 
     waveform, original_sr = torchaudio.load(
-        reference_audio, backend="ffmpeg" if sys.platform == "linux" else "soundfile"
+        reference_audio,
+        backend="soundfile",  # not every linux release supports 'sox' or 'ffmpeg'
     )
 
     if waveform.shape[0] > 1:
@@ -166,6 +168,8 @@ def get_content_type(audio_format):
 @torch.inference_mode()
 def inference(req: ServeTTSRequest):
 
+    global prompt_tokens, prompt_texts
+
     idstr: str | None = req.reference_id
     if idstr is not None:
         ref_folder = Path("references") / idstr
@@ -173,33 +177,43 @@ def inference(req: ServeTTSRequest):
         ref_audios = list_files(
             ref_folder, AUDIO_EXTENSIONS, recursive=True, sort=False
         )
-        prompt_tokens = [
-            encode_reference(
-                decoder_model=decoder_model,
-                reference_audio=audio_to_bytes(str(ref_audio)),
-                enable_reference_audio=True,
-            )
-            for ref_audio in ref_audios
-        ]
-        prompt_texts = [
-            read_ref_text(str(ref_audio.with_suffix(".lab")))
-            for ref_audio in ref_audios
-        ]
+
+        if req.use_memory_cache == "never" or (
+            req.use_memory_cache == "on-demand" and len(prompt_tokens) == 0
+        ):
+            prompt_tokens = [
+                encode_reference(
+                    decoder_model=decoder_model,
+                    reference_audio=audio_to_bytes(str(ref_audio)),
+                    enable_reference_audio=True,
+                )
+                for ref_audio in ref_audios
+            ]
+            prompt_texts = [
+                read_ref_text(str(ref_audio.with_suffix(".lab")))
+                for ref_audio in ref_audios
+            ]
+        else:
+            logger.info("Use same references")
 
     else:
         # Parse reference audio aka prompt
         refs = req.references
-        if refs is None:
-            refs = []
-        prompt_tokens = [
-            encode_reference(
-                decoder_model=decoder_model,
-                reference_audio=ref.audio,
-                enable_reference_audio=True,
-            )
-            for ref in refs
-        ]
-        prompt_texts = [ref.text for ref in refs]
+
+        if req.use_memory_cache == "never" or (
+            req.use_memory_cache == "on-demand" and len(prompt_tokens) == 0
+        ):
+            prompt_tokens = [
+                encode_reference(
+                    decoder_model=decoder_model,
+                    reference_audio=ref.audio,
+                    enable_reference_audio=True,
+                )
+                for ref in refs
+            ]
+            prompt_texts = [ref.text for ref in refs]
+        else:
+            logger.info("Use same references")
 
     # LLAMA Inference
     request = dict(
@@ -397,11 +411,23 @@ app = Kui(
 )
 
 
-if __name__ == "__main__":
+# Each worker process created by Uvicorn has its own memory space,
+# meaning that models and variables are not shared between processes.
+# Therefore, any global variables (like `llama_queue` or `decoder_model`)
+# will not be shared across workers.
 
-    import uvicorn
 
-    args = parse_args()
+# Multi-threading for deep learning can cause issues, such as inconsistent
+# outputs if multiple threads access the same buffers simultaneously.
+# Instead, it's better to use multiprocessing or independent models per thread.
+@app.on_startup
+def initialize_app(app: Kui):
+
+    global args, llama_queue, decoder_model, prompt_tokens, prompt_texts
+
+    prompt_tokens, prompt_texts = [], []
+
+    args = parse_args()  # args same as ones in other processes
     args.precision = torch.half if args.half else torch.bfloat16
 
     logger.info("Loading Llama model...")
@@ -411,6 +437,7 @@ if __name__ == "__main__":
         precision=args.precision,
         compile=args.compile,
     )
+
     logger.info("Llama model loaded, loading VQ-GAN model...")
 
     decoder_model = load_decoder_model(
@@ -421,7 +448,7 @@ if __name__ == "__main__":
 
     logger.info("VQ-GAN model loaded, warming up...")
 
-    # Dry run to check if the model is loaded correctly and avoid the first-time latency
+    # Dry run to ensure models work and avoid first-time latency
     list(
         inference(
             ServeTTSRequest(
@@ -440,5 +467,18 @@ if __name__ == "__main__":
     )
 
     logger.info(f"Warming up done, starting server at http://{args.listen}")
+
+
+if __name__ == "__main__":
+
+    import uvicorn
+
+    args = parse_args()
     host, port = args.listen.split(":")
-    uvicorn.run(app, host=host, port=int(port), workers=args.workers, log_level="info")
+    uvicorn.run(
+        "tools.api:app",
+        host=host,
+        port=int(port),
+        workers=args.workers,
+        log_level="info",
+    )

+ 1 - 0
tools/commons.py

@@ -20,6 +20,7 @@ class ServeTTSRequest(BaseModel):
     # For example, if you want use https://fish.audio/m/7f92f8afb8ec43bf81429cc1c9199cb1/
     # Just pass 7f92f8afb8ec43bf81429cc1c9199cb1
     reference_id: str | None = None
+    use_memory_cache: Literal["on-demand", "never"] = "never"
     # Normalize text for en & zh, this increase stability for numbers
     normalize: bool = True
     mp3_bitrate: Optional[int] = 64

+ 7 - 0
tools/post_api.py

@@ -103,6 +103,12 @@ def parse_args():
         "--channels", type=int, default=1, help="Number of audio channels"
     )
     parser.add_argument("--rate", type=int, default=44100, help="Sample rate for audio")
+    parser.add_argument(
+        "--use_memory_cache",
+        type=str,
+        default="never",
+        help="Cache encoded references codes in memory",
+    )
 
     return parser.parse_args()
 
@@ -148,6 +154,7 @@ if __name__ == "__main__":
         "speaker": args.speaker,
         "emotion": args.emotion,
         "streaming": args.streaming,
+        "use_memory_cache": args.use_memory_cache,
     }
 
     pydantic_data = ServeTTSRequest(**data)