
Update docs etc. (#524)

* fully support ormsgpack

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* dependency

* torch==2.4.1 windows compilable

* Update docs

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* remove unused code

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* remove autorerank

* api usage

* back slash

* fix docs

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
spicysama 1 year ago
parent
commit
9cb84a6da5
11 files changed, with 125 additions and 549 deletions
  1. docs/en/index.md (+65 -60)
  2. docs/en/inference.md (+11 -40)
  3. docs/zh/index.md (+18 -6)
  4. docs/zh/inference.md (+10 -40)
  5. fish_speech/train.py (+2 -0)
  6. fish_speech/webui/manage.py (+5 -3)
  7. install_env.bat (+6 -88)
  8. tools/api.py (+0 -1)
  9. tools/auto_rerank.py (+0 -159)
  10. tools/llama/generate.py (+5 -15)
  11. tools/webui.py (+3 -137)

+ 65 - 60
docs/en/index.md

@@ -27,66 +27,70 @@
 
 ## Windows Setup
 
-Windows professional users may consider WSL2 or Docker to run the codebase.
-
-Non-professional Windows users can consider the following methods to run the codebase without a Linux environment (with model compilation capabilities aka `torch.compile`):
-
-<ol>
-   <li>Unzip the project package.</li>
-   <li>Click <code>install_env.bat</code> to install the environment.
-      <ul>
-            <li>You can decide whether to use a mirror site for downloads by editing the <code>USE_MIRROR</code> item in <code>install_env.bat</code>.</li>
-            <li><code>USE_MIRROR=false</code> downloads the latest stable version of <code>torch</code> from the original site. <code>USE_MIRROR=true</code> downloads the latest version of <code>torch</code> from a mirror site. The default is <code>true</code>.</li>
-            <li>You can decide whether to enable the compiled environment download by editing the <code>INSTALL_TYPE</code> item in <code>install_env.bat</code>.</li>
-            <li><code>INSTALL_TYPE=preview</code> downloads the preview version with the compiled environment. <code>INSTALL_TYPE=stable</code> downloads the stable version without the compiled environment.</li>
-      </ul>
-   </li>
-   <li>If step 2 has <code>USE_MIRROR=preview</code>, execute this step (optional, for activating the compiled model environment):
-      <ol>
-            <li>Download the LLVM compiler using the following links:
-               <ul>
-                  <li><a href="https://huggingface.co/fishaudio/fish-speech-1/resolve/main/LLVM-17.0.6-win64.exe?download=true">LLVM-17.0.6 (original site download)</a></li>
-                  <li><a href="https://hf-mirror.com/fishaudio/fish-speech-1/resolve/main/LLVM-17.0.6-win64.exe?download=true">LLVM-17.0.6 (mirror site download)</a></li>
-                  <li>After downloading <code>LLVM-17.0.6-win64.exe</code>, double-click to install it, choose an appropriate installation location, and most importantly, check <code>Add Path to Current User</code> to add to the environment variables.</li>
-                  <li>Confirm the installation is complete.</li>
-               </ul>
-            </li>
-            <li>Download and install the Microsoft Visual C++ Redistributable package to resolve potential .dll missing issues.
-               <ul>
-                  <li><a href="https://aka.ms/vs/17/release/vc_redist.x64.exe">MSVC++ 14.40.33810.0 Download</a></li>
-               </ul>
-            </li>
-            <li>Download and install Visual Studio Community Edition to obtain MSVC++ build tools, resolving LLVM header file dependencies.
-               <ul>
-                  <li><a href="https://visualstudio.microsoft.com/zh-hans/downloads/">Visual Studio Download</a></li>
-                  <li>After installing Visual Studio Installer, download Visual Studio Community 2022.</li>
-                  <li>Click the <code>Modify</code> button as shown below, find the <code>Desktop development with C++</code> option, and check it for download.</li>
-                  <p align="center">
-                     <img src="../assets/figs/VS_1.jpg" width="75%">
-                  </p>
-               </ul>
-            </li>
-            <li>Install <a href="https://developer.nvidia.com/cuda-12-1-0-download-archive?target_os=Windows&target_arch=x86_64">CUDA Toolkit 12</a></li>
-      </ol>
-   </li>
-   <li>Double-click <code>start.bat</code> to enter the Fish-Speech training inference configuration WebUI page.
-      <ul>
-            <li>(Optional) Want to go directly to the inference page? Edit the <code>API_FLAGS.txt</code> in the project root directory and modify the first three lines as follows:
-               <pre><code>--infer
-# --api
-# --listen ...
-...</code></pre>
-            </li>
-            <li>(Optional) Want to start the API server? Edit the <code>API_FLAGS.txt</code> in the project root directory and modify the first three lines as follows:
-               <pre><code># --infer
---api
---listen ...
-...</code></pre>
-            </li>
-      </ul>
-   </li>
-   <li>(Optional) Double-click <code>run_cmd.bat</code> to enter the conda/python command line environment of this project.</li>
-</ol>
+Professional Windows users may consider using WSL2 or Docker to run the codebase.
+
+```bash
+# Create a Python 3.10 virtual environment; you can also use virtualenv
+conda create -n fish-speech python=3.10
+conda activate fish-speech
+
+# Install pytorch
+pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
+
+# Install fish-speech
+pip3 install -e .
+
+# (Enable acceleration) Install triton-windows
+pip install https://github.com/AnyaCoder/fish-speech/releases/download/v0.1.0/triton_windows-0.1.0-py3-none-any.whl
+```
+
+Non-professional Windows users can use the following basic method to run the project without a Linux environment (with model compilation support, i.e., `torch.compile`):
+
+1. Extract the project package.
+2. Click `install_env.bat` to install the environment.
+3. If you want to enable compilation acceleration, follow this step:
+    1. Download the LLVM compiler from the following links:
+        - [LLVM-17.0.6 (Official Site Download)](https://huggingface.co/fishaudio/fish-speech-1/resolve/main/LLVM-17.0.6-win64.exe?download=true)
+        - [LLVM-17.0.6 (Mirror Site Download)](https://hf-mirror.com/fishaudio/fish-speech-1/resolve/main/LLVM-17.0.6-win64.exe?download=true)
+        - After downloading `LLVM-17.0.6-win64.exe`, double-click to install, select an appropriate installation location, and most importantly, check the `Add Path to Current User` option to add the environment variable.
+        - Confirm that the installation is complete.
+    2. Download and install the Microsoft Visual C++ Redistributable to solve potential .dll missing issues:
+        - [MSVC++ 14.40.33810.0 Download](https://aka.ms/vs/17/release/vc_redist.x64.exe)
+    3. Download and install Visual Studio Community Edition to get MSVC++ build tools and resolve LLVM's header file dependencies:
+        - [Visual Studio Download](https://visualstudio.microsoft.com/zh-hans/downloads/)
+        - After installing Visual Studio Installer, download Visual Studio Community 2022.
+        - As shown below, click the `Modify` button and find the `Desktop development with C++` option to select and download.
+    4. Download and install [CUDA Toolkit 12.x](https://developer.nvidia.com/cuda-12-1-0-download-archive?target_os=Windows&target_arch=x86_64)
+4. Double-click `start.bat` to open the training/inference WebUI management interface. If needed, you can modify `API_FLAGS` as shown in the prompts below.
+
+!!! info "Optional"
+
+    Want to start the inference WebUI?
+
+    Edit the `API_FLAGS.txt` file in the project root directory and modify the first three lines as follows:
+
+    ```
+    --infer
+    # --api
+    # --listen ...
+    ...
+    ```
+
+!!! info "Optional"
+
+    Want to start the API server?
+
+    Edit the `API_FLAGS.txt` file in the project root directory and modify the first three lines as follows:
+
+    ```
+    # --infer
+    --api
+    --listen ...
+    ...
+    ```
+
+!!! info "Optional"
+
+    Double-click `run_cmd.bat` to enter the conda/python command line environment of this project.
 
 ## Linux Setup
 
@@ -107,6 +111,7 @@ apt install libsox-dev
 
 ## Changelog
 
+- 2024/09/10: Updated Fish-Speech to 1.4 version, increased the dataset size, and changed the quantizer's n_groups from 4 to 8.
 - 2024/07/02: Updated Fish-Speech to 1.2 version, remove VITS Decoder, and greatly enhanced zero-shot ability.
 - 2024/05/10: Updated Fish-Speech to 1.1 version, implement VITS decoder to reduce WER and improve timbre similarity.
 - 2024/04/22: Finished Fish-Speech 1.0 version, significantly modified VQGAN and LLAMA models.

+ 11 - 40
docs/en/inference.md

@@ -90,51 +90,22 @@ python -m tools.post_api \
 
 The above command indicates synthesizing the desired audio according to the reference audio information and returning it in a streaming manner.
 
-If you need to randomly select reference audio based on `{SPEAKER}` and `{EMOTION}`, configure it according to the following steps:
-
-### 1. Create a `ref_data` folder in the root directory of the project.
-
-### 2. Create a directory structure similar to the following within the `ref_data` folder.
-
-```
-.
-├── SPEAKER1
-│    ├──EMOTION1
-│    │    ├── 21.15-26.44.lab
-│    │    ├── 21.15-26.44.wav
-│    │    ├── 27.51-29.98.lab
-│    │    ├── 27.51-29.98.wav
-│    │    ├── 30.1-32.71.lab
-│    │    └── 30.1-32.71.flac
-│    └──EMOTION2
-│         ├── 30.1-32.71.lab
-│         └── 30.1-32.71.mp3
-└── SPEAKER2
-    └─── EMOTION3
-          ├── 30.1-32.71.lab
-          └── 30.1-32.71.mp3
-```
-
-That is, first place `{SPEAKER}` folders in `ref_data`, then place `{EMOTION}` folders under each speaker, and place any number of `audio-text pairs` under each emotion folder.
-
-### 3. Enter the following command in the virtual environment
-
-```bash
-python tools/gen_ref.py
-
-```
-
-### 4. Call the API.
+The following example shows that you can use **multiple** reference audio paths and reference texts at once; separate them with spaces in the command.
 
 ```bash
 python -m tools.post_api \
-    --text "Text to be input" \
-    --speaker "${SPEAKER1}" \
-    --emotion "${EMOTION1}" \
-    --streaming True
+    --text "Text to input" \
+    --reference_audio "reference audio path1" "reference audio path2" \
+    --reference_text "reference audio text1" "reference audio text2" \
+    --streaming False \
+    --output "generated" \
+    --format "mp3"
 ```
 
-The above example is for testing purposes only.
+The above command synthesizes the desired `MP3` format audio based on the information from multiple reference audios and saves it as `generated.mp3` in the current directory.
+
+## GUI Inference
+
+[Download the client](https://github.com/AnyaCoder/fish-speech-gui/releases/tag/v0.1.0)
 
 ## WebUI Inference
 

+ 18 - 6
docs/zh/index.md

@@ -29,15 +29,26 @@
 
 Windows 专业用户可以考虑 WSL2 或 docker 来运行代码库。
 
+```bash
+# 创建一个 python 3.10 虚拟环境, 你也可以用 virtualenv
+conda create -n fish-speech python=3.10
+conda activate fish-speech
+
+# 安装 pytorch
+pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
+
+# 安装 fish-speech
+pip3 install -e .
+
+# (开启编译加速) 安装 triton-windows
+pip install https://github.com/AnyaCoder/fish-speech/releases/download/v0.1.0/triton_windows-0.1.0-py3-none-any.whl
+```
+
 Windows 非专业用户可考虑以下为免 Linux 环境的基础运行方法(附带模型编译功能,即 `torch.compile`):
 
 1. 解压项目压缩包。
 2. 点击 `install_env.bat` 安装环境。
-    - 可以通过编辑 `install_env.bat` 的 `USE_MIRROR` 项来决定是否使用镜像站下载。
-    - `USE_MIRROR=false` 使用原始站下载最新稳定版 `torch` 环境。`USE_MIRROR=true` 为从镜像站下载最新 `torch` 环境。默认为 `true`。
-    - 可以通过编辑 `install_env.bat` 的 `INSTALL_TYPE` 项来决定是否启用可编译环境下载。
-    - `INSTALL_TYPE=preview` 下载开发版编译环境。`INSTALL_TYPE=stable` 下载稳定版不带编译环境。
-3. 若第 2 步 `INSTALL_TYPE=preview` 则执行这一步(可跳过,此步为激活编译模型环境)
+3. 若需要开启编译加速则执行这一步:
     1. 使用如下链接下载 LLVM 编译器。
         - [LLVM-17.0.6(原站站点下载)](https://huggingface.co/fishaudio/fish-speech-1/resolve/main/LLVM-17.0.6-win64.exe?download=true)
         - [LLVM-17.0.6(镜像站点下载)](https://hf-mirror.com/fishaudio/fish-speech-1/resolve/main/LLVM-17.0.6-win64.exe?download=true)
@@ -49,7 +60,7 @@ Windows 非专业用户可考虑以下为免 Linux 环境的基础运行方法
         - [Visual Studio 下载](https://visualstudio.microsoft.com/zh-hans/downloads/)
         - 安装好 Visual Studio Installer 之后,下载 Visual Studio Community 2022
         - 如下图点击`修改`按钮,找到`使用C++的桌面开发`项,勾选下载
-    4. 下载安装 [CUDA Toolkit 12](https://developer.nvidia.com/cuda-12-1-0-download-archive?target_os=Windows&target_arch=x86_64)
+    4. 下载安装 [CUDA Toolkit 12.x](https://developer.nvidia.com/cuda-12-1-0-download-archive?target_os=Windows&target_arch=x86_64)
 4. 双击 `start.bat` 打开训练推理 WebUI 管理界面. 如有需要,可照下列提示修改`API_FLAGS`.
 
 !!! info "可选"
@@ -158,6 +169,7 @@ apt install libsox-dev
 
 ## 更新日志
 
+- 2024/09/10: 更新了 Fish-Speech 到 1.4, 增加了数据集大小, quantizer n_groups 4 -> 8.
 - 2024/07/02: 更新了 Fish-Speech 到 1.2 版本,移除 VITS Decoder,同时极大幅度提升 zero-shot 能力.
 - 2024/05/10: 更新了 Fish-Speech 到 1.1 版本,引入了 VITS Decoder 来降低口胡和提高音色相似度.
 - 2024/04/22: 完成了 Fish-Speech 1.0 版本, 大幅修改了 VQGAN 和 LLAMA 模型.

+ 10 - 40
docs/zh/inference.md

@@ -100,52 +100,22 @@ python -m tools.post_api \
 
 上面的命令表示按照参考音频的信息,合成所需的音频并流式返回.
 
-如果需要通过`{说话人}`和`{情绪}`随机选择参考音频,那么就根据下列步骤配置:
-
-### 1. 在项目根目录创建`ref_data`文件夹.
-
-### 2. 在`ref_data`文件夹内创建类似如下结构的目录.
-
-```
-.
-├── SPEAKER1
-│    ├──EMOTION1
-│    │    ├── 21.15-26.44.lab
-│    │    ├── 21.15-26.44.wav
-│    │    ├── 27.51-29.98.lab
-│    │    ├── 27.51-29.98.wav
-│    │    ├── 30.1-32.71.lab
-│    │    └── 30.1-32.71.flac
-│    └──EMOTION2
-│         ├── 30.1-32.71.lab
-│         └── 30.1-32.71.mp3
-└── SPEAKER2
-    └─── EMOTION3
-          ├── 30.1-32.71.lab
-          └── 30.1-32.71.mp3
-```
-
-也就是`ref_data`里先放`{说话人}`文件夹, 每个说话人下再放`{情绪}`文件夹, 每个情绪文件夹下放任意个`音频-文本对`。
-
-### 3. 在虚拟环境里输入
-
-```bash
-python tools/gen_ref.py
-```
-
-生成参考目录.
-
-### 4. 调用 api.
+下面的示例展示了, 可以一次使用**多个** `参考音频路径` 和 `参考音频的文本内容`。在命令里用空格隔开即可。 
 
 ```bash
 python -m tools.post_api \
     --text "要输入的文本" \
-    --speaker "说话人1" \
-    --emotion "情绪1" \
-    --streaming True
+    --reference_audio "参考音频路径1" "参考音频路径2" \
+    --reference_text "参考音频的文本内容1" "参考音频的文本内容2" \
+    --streaming False \
+    --output "generated" \
+    --format "mp3"
 ```
 
-以上示例仅供测试.
+上面的命令表示按照多个参考音频的信息,合成所需的`MP3`格式音频,并保存为当前目录的`generated.mp3`文件。
+
+## GUI 推理
+
+[下载客户端](https://github.com/AnyaCoder/fish-speech-gui/releases/tag/v0.1.0)
 
 ## WebUI 推理
 

+ 2 - 0
fish_speech/train.py

@@ -1,4 +1,6 @@
 import os
+
+os.environ["USE_LIBUV"] = "0"
 import sys
 from typing import Optional
 
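The two-line change above works around a known PyTorch 2.4 issue on Windows, where `torch.distributed` defaults to a libuv-based TCPStore that the Windows wheels may not support. The override only takes effect if it is in the environment before the distributed store is created, which is why it sits above every other import. A minimal sketch of the pattern (`libuv_disabled` is our illustrative helper, not code from the repo):

```python
import os

# Must be set before torch (or anything that imports torch.distributed)
# is imported, so the rendezvous store never selects the libuv backend.
os.environ["USE_LIBUV"] = "0"


def libuv_disabled() -> bool:
    # torch.distributed consults this variable when building its TCPStore;
    # "0" falls back to the classic (non-libuv) implementation.
    return os.environ.get("USE_LIBUV") == "0"
```

The same guard is repeated at the top of `fish_speech/webui/manage.py` below, since either file can be the process entry point.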

+ 5 - 3
fish_speech/webui/manage.py

@@ -1,9 +1,11 @@
 from __future__ import annotations
 
+import os
+
+os.environ["USE_LIBUV"] = "0"
 import datetime
 import html
 import json
-import os
 import platform
 import shutil
 import signal
@@ -862,7 +864,7 @@ with gr.Blocks(
                                     minimum=1,
                                     maximum=32,
                                     step=1,
-                                    value=4,
+                                    value=2,
                                 )
                                 llama_data_max_length_slider = gr.Slider(
                                     label=i18n("Maximum Length per Sample"),
@@ -870,7 +872,7 @@ with gr.Blocks(
                                     minimum=1024,
                                     maximum=4096,
                                     step=128,
-                                    value=1024,
+                                    value=2048,
                                 )
                             with gr.Row(equal_height=False):
                                 llama_precision_dropdown = gr.Dropdown(

+ 6 - 88
install_env.bat

@@ -2,9 +2,7 @@
 chcp 65001
 
 set USE_MIRROR=true
-set INSTALL_TYPE=preview
 echo "USE_MIRROR: %USE_MIRROR%"
-echo "INSTALL_TYPE: %INSTALL_TYPE%"
 setlocal enabledelayedexpansion
 
 cd /D "%~dp0"
@@ -125,12 +123,6 @@ if errorlevel 1 (
     echo successfully create env.
 )
 
-set "packages=torch torchvision torchaudio fish-speech"
-
-if "%INSTALL_TYPE%"=="preview" (
-    set "packages=!packages! triton_windows"
-)
-
 set "HF_ENDPOINT=https://huggingface.co"
 set "no_proxy="
 if "%USE_MIRROR%"=="true" (
@@ -141,27 +133,14 @@ if "%USE_MIRROR%"=="true" (
 echo "HF_ENDPOINT: !HF_ENDPOINT!"
 echo "NO_PROXY: !no_proxy!"
 
-set "install_packages="
+%PIP_CMD% install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
 
-for %%p in (%packages%) do (
-    %PIP_CMD% show %%p >nul 2>&1
-    if errorlevel 1 (
-        set "install_packages=!install_packages! %%p"
-    )
-)
+%PIP_CMD% install -e . --upgrade-strategy only-if-needed
+
+call :download_and_install "triton_windows-0.1.0-py3-none-any.whl" ^
+        "%HF_ENDPOINT%/datasets/SpicyqSama007/windows_compile/resolve/main/triton_windows-0.1.0-py3-none-any.whl?download=true" ^
+        "2cc998638180f37cf5025ab65e48c7f629aa5a369176cfa32177d2bd9aa26a0a"
 
-if not "%install_packages%"=="" (
-    echo.
-    echo Installing: %install_packages%
-    
-    for %%p in (%install_packages%) do (
-        if "%INSTALL_TYPE%"=="preview" (
-            call :install_preview %%p
-        ) else (
-            call :install_stable %%p
-        )
-    )
-)
 
 endlocal
 echo "Environment Check: Success."
@@ -169,67 +148,6 @@ pause
 
 goto :EOF
 
-:install_preview
-setlocal
-
-if "%1"=="torch" (
-    call :download_and_install "torch-2.4.0.dev20240427+cu121-cp310-cp310-win_amd64.whl" ^
-        "%HF_ENDPOINT%/datasets/SpicyqSama007/windows_compile/resolve/main/torch-2.4.0.dev20240427_cu121-cp310-cp310-win_amd64.whl?download=true" ^
-        "b091308f4cb74e63d0323afd67c92f2279d9e488d8cbf467bcc7b939bcd74e0b"
-
-) else if "%1"=="torchvision" (
-    call :download_and_install "torchvision-0.19.0.dev20240428+cu121-cp310-cp310-win_amd64.whl" ^
-        "%HF_ENDPOINT%/datasets/SpicyqSama007/windows_compile/resolve/main/torchvision-0.19.0.dev20240428_cu121-cp310-cp310-win_amd64.whl?download=true" ^
-        "7e46d0a89534013f001563d15e80f9eb431089571720c51f2cc595feeb01d785"
-
-) else if "%1"=="torchaudio" (
-    call :download_and_install "torchaudio-2.2.0.dev20240427+cu121-cp310-cp310-win_amd64.whl" ^
-        "%HF_ENDPOINT%/datasets/SpicyqSama007/windows_compile/resolve/main/torchaudio-2.2.0.dev20240427_cu121-cp310-cp310-win_amd64.whl?download=true" ^
-        "abafb4bc82cbc6f58f18e1b95191bc1884c28e404781082db2eb540b4fae8a5d"
-
-) else if "%1"=="fish-speech" (
-    %PIP_CMD% install -e . --upgrade-strategy only-if-needed
-
-) else if "%1"=="triton_windows" (
-    call :download_and_install "triton_windows-0.1.0-py3-none-any.whl" ^
-        "%HF_ENDPOINT%/datasets/SpicyqSama007/windows_compile/resolve/main/triton_windows-0.1.0-py3-none-any.whl?download=true" ^
-        "2cc998638180f37cf5025ab65e48c7f629aa5a369176cfa32177d2bd9aa26a0a"
-)
-
-endlocal
-goto :EOF
-
-:install_stable
-if "%USE_MIRROR%"=="true" (
-    if "%1"=="torch" (
-        %PIP_CMD% install torch==2.3.1 --index-url https://mirror.sjtu.edu.cn/pytorch-wheels/cu121 --no-warn-script-location
-
-    ) else if "%1"=="torchvision" (
-        %PIP_CMD% install torchvision==0.18.1 --index-url https://mirror.sjtu.edu.cn/pytorch-wheels/cu121 --no-warn-script-location
-
-    ) else if "%1"=="torchaudio" (
-        %PIP_CMD% install torchaudio==2.3.1 --index-url https://mirror.sjtu.edu.cn/pytorch-wheels/cu121 --no-warn-script-location
-
-    ) else if "%1"=="fish-speech" (
-        %PIP_CMD% install -e . -i https://pypi.tuna.tsinghua.edu.cn/simple
-    )
-
-) else (
-    if "%1"=="torch" (
-        %PIP_CMD% install torch==2.3.1 --index-url https://download.pytorch.org/whl/cu121 --no-warn-script-location
-
-    ) else if "%1"=="torchvision" (
-        %PIP_CMD% install torchvision==0.18.1 --index-url https://download.pytorch.org/whl/cu121 --no-warn-script-location
-
-    ) else if "%1"=="torchaudio" (
-        %PIP_CMD% install torchaudio==2.3.1 --index-url https://download.pytorch.org/whl/cu121 --no-warn-script-location
-
-    ) else if "%1"=="fish-speech" (
-        %PIP_CMD% install -e .
-    )
-)
-
-goto :EOF
 
 :download_and_install
 setlocal
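The surviving `:download_and_install` helper pairs each wheel URL with a pinned SHA-256 digest. The verification half of that pattern looks roughly like the following Python sketch (`verify_sha256` and `PINNED` are hypothetical names, not code from the repo):

```python
import hashlib


def verify_sha256(data: bytes, expected_hex: str) -> bool:
    """Compare an artifact's SHA-256 digest against the pinned value."""
    return hashlib.sha256(data).hexdigest() == expected_hex.lower()


# Stand-in artifact: the well-known SHA-256 digest of the ASCII bytes b"hello".
PINNED = "2cf24dba5fb0a30e26e83b2ac5b9e29e1b161e5c1fa7425e73043362938b9824"
```

Pinning a digest next to the URL means a corrupted or tampered download can be rejected before it is ever handed to pip.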

+ 0 - 1
tools/api.py

@@ -359,7 +359,6 @@ def parse_args():
     parser.add_argument("--max-text-length", type=int, default=0)
     parser.add_argument("--listen", type=str, default="127.0.0.1:8080")
     parser.add_argument("--workers", type=int, default=1)
-    parser.add_argument("--use-auto-rerank", type=bool, default=True)
 
     return parser.parse_args()
 

+ 0 - 159
tools/auto_rerank.py

@@ -1,159 +0,0 @@
-import os
-
-os.environ["MODELSCOPE_CACHE"] = ".cache/"
-
-import string
-import time
-from threading import Lock
-
-import librosa
-import numpy as np
-import opencc
-import torch
-from faster_whisper import WhisperModel
-
-t2s_converter = opencc.OpenCC("t2s")
-
-
-def load_model(*, device="cuda"):
-    model = WhisperModel(
-        "medium",
-        device=device,
-        compute_type="float16",
-        download_root="faster_whisper",
-    )
-    print("faster_whisper loaded!")
-    return model
-
-
-@torch.no_grad()
-def batch_asr_internal(model: WhisperModel, audios, sr):
-    resampled_audios = []
-    for audio in audios:
-
-        if isinstance(audio, np.ndarray):
-            audio = torch.from_numpy(audio).float()
-
-        if audio.dim() > 1:
-            audio = audio.squeeze()
-
-        assert audio.dim() == 1
-        audio_np = audio.numpy()
-        resampled_audio = librosa.resample(audio_np, orig_sr=sr, target_sr=16000)
-        resampled_audios.append(resampled_audio)
-
-    trans_results = []
-
-    for resampled_audio in resampled_audios:
-        segments, info = model.transcribe(
-            resampled_audio,
-            language=None,
-            beam_size=5,
-            initial_prompt="Punctuation is needed in any language.",
-        )
-        trans_results.append(list(segments))
-
-    results = []
-    for trans_res, audio in zip(trans_results, audios):
-
-        duration = len(audio) / sr * 1000
-        huge_gap = False
-        max_gap = 0.0
-
-        text = None
-        last_tr = None
-
-        for tr in trans_res:
-            delta = tr.text.strip()
-            if tr.id > 1:
-                max_gap = max(tr.start - last_tr.end, max_gap)
-                text += delta
-            else:
-                text = delta
-
-            last_tr = tr
-            if max_gap > 3.0:
-                huge_gap = True
-                break
-
-        sim_text = t2s_converter.convert(text)
-        results.append(
-            {
-                "text": sim_text,
-                "duration": duration,
-                "huge_gap": huge_gap,
-            }
-        )
-
-    return results
-
-
-global_lock = Lock()
-
-
-def batch_asr(model, audios, sr):
-    return batch_asr_internal(model, audios, sr)
-
-
-def is_chinese(text):
-    return True
-
-
-def calculate_wer(text1, text2, debug=False):
-    chars1 = remove_punctuation(text1)
-    chars2 = remove_punctuation(text2)
-
-    m, n = len(chars1), len(chars2)
-
-    if m > n:
-        chars1, chars2 = chars2, chars1
-        m, n = n, m
-
-    prev = list(range(m + 1))  # row 0 distance: [0, 1, 2, ...]
-    curr = [0] * (m + 1)
-
-    for j in range(1, n + 1):
-        curr[0] = j
-        for i in range(1, m + 1):
-            if chars1[i - 1] == chars2[j - 1]:
-                curr[i] = prev[i - 1]
-            else:
-                curr[i] = min(prev[i], curr[i - 1], prev[i - 1]) + 1
-        prev, curr = curr, prev
-
-    edits = prev[m]
-    tot = max(len(chars1), len(chars2))
-    wer = edits / tot
-
-    if debug:
-        print("            gt:   ", chars1)
-        print("          pred:   ", chars2)
-        print(" edits/tot = wer: ", edits, "/", tot, "=", wer)
-
-    return wer
-
-
-def remove_punctuation(text):
-    chinese_punctuation = (
-        " \n\t”“!?。。"#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、、〃《》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—"
-        '‛""„‟…‧﹏'
-    )
-    all_punctuation = string.punctuation + chinese_punctuation
-    translator = str.maketrans("", "", all_punctuation)
-    text_without_punctuation = text.translate(translator)
-    return text_without_punctuation
-
-
-if __name__ == "__main__":
-    model = load_model()
-    audios = [
-        librosa.load("44100.wav", sr=44100)[0],
-        librosa.load("lengyue.wav", sr=44100)[0],
-    ]
-    print(np.array(audios[0]))
-    print(batch_asr(model, audios, 44100))
-
-    start_time = time.time()
-    for _ in range(10):
-        print(batch_asr(model, audios, 44100))
-    print("Time taken:", time.time() - start_time)
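For reference, the deleted `calculate_wer` was a two-row Levenshtein edit distance computed over characters (so effectively a character error rate), normalized by the longer string's length. The core algorithm, reconstructed as a standalone sketch:

```python
def edit_distance(a: str, b: str) -> int:
    # Two-row dynamic programming: O(len(a) * len(b)) time,
    # O(min(len(a), len(b))) memory.
    if len(a) > len(b):
        a, b = b, a
    m, n = len(a), len(b)
    prev = list(range(m + 1))  # row 0 distance: [0, 1, 2, ...]
    curr = [0] * (m + 1)
    for j in range(1, n + 1):
        curr[0] = j
        for i in range(1, m + 1):
            if a[i - 1] == b[j - 1]:
                curr[i] = prev[i - 1]
            else:
                curr[i] = min(prev[i], curr[i - 1], prev[i - 1]) + 1
        prev, curr = curr, prev
    return prev[m]


def char_error_rate(ref: str, hyp: str) -> float:
    # Edits divided by the length of the longer string, as in the removed code.
    return edit_distance(ref, hyp) / max(len(ref), len(hyp), 1)
```

The WebUI used this score against a faster-whisper transcription to rerank candidate generations; the commit removes that whole pipeline.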

+ 5 - 15
tools/llama/generate.py

@@ -237,25 +237,11 @@ def generate(
     # create an empty tensor of the expected final shape and fill in the current tokens
     T = prompt.size(1)
 
-    if max_new_tokens:
-        if T + max_new_tokens > model.config.max_seq_len:
-            max_new_tokens = model.config.max_seq_len - T
-            logger.info(f"Truncating max_new_tokens to {max_new_tokens}")
-
-        T_new = T + max_new_tokens
-    else:
-        T_new = model.config.max_seq_len
-        max_new_tokens = T_new - T
-
     device, dtype = prompt.device, prompt.dtype
-    with torch.device(device):
-        model.setup_caches(
-            max_batch_size=1, max_seq_len=T_new, dtype=next(model.parameters()).dtype
-        )
 
     codebook_dim = 1 + model.config.num_codebooks
     # create an empty tensor of the expected final shape and fill in the current tokens
-    empty = torch.empty((codebook_dim, T_new), dtype=dtype, device=device)
+    empty = torch.empty((codebook_dim, max_new_tokens), dtype=dtype, device=device)
     empty[:, :T] = prompt
     seq = empty
     input_pos = torch.arange(0, T, device=device)
@@ -575,6 +561,10 @@ def launch_thread_safe_queue(
         model, decode_one_token = load_model(
             checkpoint_path, device, precision, compile=compile
         )
+        with torch.device(device):
+            model.setup_caches(
+                max_batch_size=1, max_seq_len=2048, dtype=next(model.parameters()).dtype
+            )
         init_event.set()
 
         while True:
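The net effect of this hunk is that cache setup moves out of the per-request `generate` path: the worker thread now sizes the KV caches once at model load (`max_seq_len=2048`), and `generate` only fills a pre-sized token buffer. The allocate-once pattern, reduced to a toy sketch with made-up names (no torch, no real model):

```python
class Worker:
    """Sketch of allocate-once: the buffer is created at load time and
    reused by every request instead of being re-allocated per call."""

    def __init__(self, max_seq_len: int = 2048):
        self.max_seq_len = max_seq_len
        self.buffer = [0] * max_seq_len  # stands in for the KV cache / seq tensor

    def generate(self, prompt: list, max_new_tokens: int) -> list:
        # Mirrors the removed truncation check: a request must fit the cache.
        assert len(prompt) + max_new_tokens <= self.max_seq_len
        self.buffer[: len(prompt)] = prompt  # fill in the current tokens
        n = len(prompt)
        for _ in range(max_new_tokens):
            self.buffer[n] = self.buffer[n - 1] + 1  # dummy "next token"
            n += 1
        return self.buffer[:n]
```

Paying the allocation cost once at startup keeps request latency stable, at the price of a fixed upper bound on sequence length.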

+ 3 - 137
tools/webui.py

@@ -23,7 +23,6 @@ from fish_speech.i18n import i18n
 from fish_speech.text.chn_text_norm.text import Text as ChnNormedText
 from fish_speech.utils import autocast_exclude_mps
 from tools.api import decode_vq_tokens, encode_reference
-from tools.auto_rerank import batch_asr, calculate_wer, is_chinese, load_model
 from tools.llama.generate import (
     GenerateRequest,
     GenerateResponse,
@@ -160,66 +159,6 @@ def inference(
         gc.collect()
 
 
-def inference_with_auto_rerank(
-    text,
-    enable_reference_audio,
-    reference_audio,
-    reference_text,
-    max_new_tokens,
-    chunk_length,
-    top_p,
-    repetition_penalty,
-    temperature,
-    use_auto_rerank,
-    streaming=False,
-):
-
-    max_attempts = 2 if use_auto_rerank else 1
-    best_wer = float("inf")
-    best_audio = None
-    best_sample_rate = None
-
-    for attempt in range(max_attempts):
-        audio_generator = inference(
-            text,
-            enable_reference_audio,
-            reference_audio,
-            reference_text,
-            max_new_tokens,
-            chunk_length,
-            top_p,
-            repetition_penalty,
-            temperature,
-            streaming=False,
-        )
-
-        # 获取音频数据
-        for _ in audio_generator:
-            pass
-        _, (sample_rate, audio), message = _
-
-        if audio is None:
-            return None, None, message
-
-        if not use_auto_rerank:
-            return None, (sample_rate, audio), None
-
-        asr_result = batch_asr(asr_model, [audio], sample_rate)[0]
-        wer = calculate_wer(text, asr_result["text"])
-        if wer <= 0.3 and not asr_result["huge_gap"]:
-            return None, (sample_rate, audio), None
-
-        if wer < best_wer:
-            best_wer = wer
-            best_audio = audio
-            best_sample_rate = sample_rate
-
-        if attempt == max_attempts - 1:
-            break
-
-    return None, (best_sample_rate, best_audio), None
-
-
 inference_stream = partial(inference, streaming=True)
 
 n_audios = 4
@@ -239,13 +178,12 @@ def inference_wrapper(
     repetition_penalty,
     temperature,
     batch_infer_num,
-    if_load_asr_model,
 ):
     audios = []
     errors = []
 
     for _ in range(batch_infer_num):
-        result = inference_with_auto_rerank(
+        result = inference(
             text,
             enable_reference_audio,
             reference_audio,
@@ -255,10 +193,9 @@ def inference_wrapper(
             top_p,
             repetition_penalty,
             temperature,
-            if_load_asr_model,
         )
 
-        _, audio_data, error_message = result
+        _, audio_data, error_message = next(result)
 
         audios.append(
             gr.Audio(value=audio_data if audio_data else None, visible=True),
@@ -301,42 +238,6 @@ def normalize_text(user_input, use_normalization):
 asr_model = None
 
 
-def change_if_load_asr_model(if_load):
-    global asr_model
-
-    if if_load:
-        gr.Warning("Loading faster whisper model...")
-        if asr_model is None:
-            asr_model = load_model()
-        return gr.Checkbox(label="Unload faster whisper model", value=if_load)
-
-    if if_load is False:
-        gr.Warning("Unloading faster whisper model...")
-        del asr_model
-        asr_model = None
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
-            gc.collect()
-        return gr.Checkbox(label="Load faster whisper model", value=if_load)
-
-
-def change_if_auto_label(if_load, if_auto_label, enable_ref, ref_audio, ref_text):
-    if if_load and asr_model is not None:
-        if (
-            if_auto_label
-            and enable_ref
-            and ref_audio is not None
-            and ref_text.strip() == ""
-        ):
-            data, sample_rate = librosa.load(ref_audio)
-            res = batch_asr(asr_model, [data], sample_rate)[0]
-            ref_text = res["text"]
-    else:
-        gr.Warning("Whisper model not loaded!")
-
-    return gr.Textbox(value=ref_text)
-
-
 def build_app():
     with gr.Blocks(theme=gr.themes.Base()) as app:
         gr.Markdown(HEADER_MD)
@@ -371,12 +272,6 @@ def build_app():
                         scale=1,
                     )
 
-                    if_load_asr_model = gr.Checkbox(
-                        label=i18n("Load / Unload ASR model for auto-reranking"),
-                        value=False,
-                        scale=3,
-                    )
-
                 with gr.Row():
                     with gr.Tab(label=i18n("Advanced Config")):
                         chunk_length = gr.Slider(
@@ -434,12 +329,6 @@ def build_app():
                             type="filepath",
                         )
                         with gr.Row():
-                            if_auto_label = gr.Checkbox(
-                                label=i18n("Auto Labeling"),
-                                min_width=100,
-                                scale=0,
-                                value=False,
-                            )
                             reference_text = gr.Textbox(
                                 label=i18n("Reference Text"),
                                 lines=1,
@@ -494,28 +383,6 @@ def build_app():
             fn=normalize_text, inputs=[text, if_refine_text], outputs=[refined_text]
         )
 
-        if_load_asr_model.change(
-            fn=change_if_load_asr_model,
-            inputs=[if_load_asr_model],
-            outputs=[if_load_asr_model],
-        )
-
-        if_auto_label.change(
-            fn=lambda: gr.Textbox(value=""),
-            inputs=[],
-            outputs=[reference_text],
-        ).then(
-            fn=change_if_auto_label,
-            inputs=[
-                if_load_asr_model,
-                if_auto_label,
-                enable_reference_audio,
-                reference_audio,
-                reference_text,
-            ],
-            outputs=[reference_text],
-        )
-
         # # Submit
         generate.click(
             inference_wrapper,
@@ -530,7 +397,6 @@ def build_app():
                 repetition_penalty,
                 temperature,
                 batch_infer_num,
-                if_load_asr_model,
             ],
             [stream_audio, *global_audio_list, *global_error_list],
             concurrency_limit=1,
@@ -605,7 +471,7 @@ if __name__ == "__main__":
             enable_reference_audio=False,
             reference_audio=None,
             reference_text="",
-            max_new_tokens=0,
+            max_new_tokens=2048,
             chunk_length=100,
             top_p=0.7,
             repetition_penalty=1.2,