Просмотр исходного кода

Add some corner cases (#158)

* Fix button height

* Streaming support

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Convert to 1 channel

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix Conversion bug

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix target path

* Add checkpoint selection

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix gpup decorator

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Add link for labeler

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Localize labeler

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Add LoRA llama config

* Allow download stream audio

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* asr

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Add cache auto recycling

* 多打了一个字母

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Check 'compile' available

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
spicysama 1 год назад
Родитель
Commit
12915468c8
5 измененных файлов с 86 добавлено и 16 удалено
  1. 1 1
      .gitignore
  2. 1 0
      fish_speech/i18n/locale/zh_CN.json
  3. 6 0
      fish_speech/webui/launch_utils.py
  4. 68 14
      fish_speech/webui/manage.py
  5. 10 1
      tools/webui.py

+ 1 - 1
.gitignore

@@ -19,7 +19,7 @@ filelists
 /.idea
 ffmpeg.exe
 ffprobe.exe
-asr-label-win-x64.exe
+asr-label*
 /.cache
 /fishenv
 /.locale

+ 1 - 0
fish_speech/i18n/locale/zh_CN.json

@@ -56,6 +56,7 @@
     "Open Tensorboard": "打开 Tensorboard",
     "Opened labeler in browser": "在浏览器中打开标注工具",
     "Optional Label Language": "[可选] 标注语言",
+    "Optional online ver": "[可选] 使用在线版",
     "Output Path": "输出路径",
     "Path error, please check the model file exists in the corresponding path": "路径错误,请检查模型文件是否存在于相应路径",
     "Precision": "精度",

+ 6 - 0
fish_speech/webui/launch_utils.py

@@ -1,3 +1,4 @@
+import importlib.util
 import os
 import subprocess
 import sys
@@ -17,6 +18,11 @@ GIT = (
 GIT = str(GIT)
 
 
+def is_module_installed(module_name: str) -> bool:
+    spec = importlib.util.find_spec(module_name)
+    return spec is not None
+
+
 @lru_cache()
 def commit_hash():
     try:

+ 68 - 14
fish_speech/webui/manage.py

@@ -8,7 +8,6 @@ import shutil
 import signal
 import subprocess
 import sys
-import webbrowser
 from pathlib import Path
 
 import gradio as gr
@@ -18,7 +17,7 @@ from loguru import logger
 from tqdm import tqdm
 
 from fish_speech.i18n import i18n
-from fish_speech.webui.launch_utils import Seafoam, versions_html
+from fish_speech.webui.launch_utils import Seafoam, is_module_installed, versions_html
 
 PYTHON = os.path.join(os.environ.get("PYTHON_FOLDERPATH", ""), "python")
 sys.path.insert(0, "")
@@ -51,6 +50,15 @@ def build_html_ok_message(msg):
     """
 
 
+def build_html_href(link, desc, msg):
+    return f"""
+    <span style="color: green; font-weight: bold; display: inline-block">
+        {html.escape(msg)}
+        <a href="{link}">{desc}</a>
+    </span>
+    """
+
+
 def load_data_in_raw(path):
     with open(path, "r", encoding="utf-8") as file:
         data = file.read()
@@ -94,14 +102,42 @@ def kill_process(pid):
 
 def change_label(if_label):
     global p_label
-    if if_label == True:
-        # 设置要访问的URL
-        url = "https://text-labeler.pages.dev/"
-        webbrowser.open(url)
-        yield i18n("Opened labeler in browser")
-    elif if_label == False:
+    if if_label == True and p_label is None:
+        url = "http://localhost:3000"
+        remote_url = "https://text-labeler.pages.dev/"
+        p_label = subprocess.Popen(
+            [
+                "asr-label-linux-x64"
+                if sys.platform == "linux"
+                else "asr-label-win-x64.exe"
+            ]
+        )
+        yield build_html_href(
+            link=remote_url,
+            desc=i18n("Optional online ver"),
+            msg=i18n("Opened labeler in browser"),
+        )
+
+    elif if_label == False and p_label is not None:
+        kill_process(p_label.pid)
         p_label = None
-        yield "Nothing"
+        yield build_html_ok_message("Nothing")
+
+
+def clean_infer_cache():
+    import tempfile
+
+    temp_dir = Path(tempfile.gettempdir())
+    gradio_dir = str(temp_dir / "gradio")
+    try:
+        shutil.rmtree(gradio_dir)
+        logger.info(f"Deleted cached audios: {gradio_dir}")
+    except PermissionError:
+        logger.info(f"Permission denied: Unable to delete {gradio_dir}")
+    except FileNotFoundError:
+        logger.info(f"{gradio_dir} was not found")
+    except Exception as e:
+        logger.info(f"An error occurred: {e}")
 
 
 def change_infer(
@@ -124,6 +160,9 @@ def change_infer(
         yield build_html_ok_message(
             i18n("Inferring interface is launched at {}").format(url)
         )
+
+        clean_infer_cache()
+
         p_infer = subprocess.Popen(
             [
                 PYTHON,
@@ -141,7 +180,7 @@ def change_infer(
             env=env,
         )
 
-    elif if_infer == False and p_infer != None:
+    elif if_infer == False and p_infer is not None:
         kill_process(p_infer.pid)
         p_infer = None
         yield build_html_error_message(i18n("Infer interface is closed"))
@@ -585,7 +624,7 @@ def fresh_llama_model():
     )
 
 
-def llama_lora_merge(llama_weight, lora_weight, llama_lora_output):
+def llama_lora_merge(llama_weight, lora_llama_config, lora_weight, llama_lora_output):
     if (
         lora_weight is None
         or not Path(lora_weight).exists()
@@ -601,7 +640,7 @@ def llama_lora_merge(llama_weight, lora_weight, llama_lora_output):
         PYTHON,
         "tools/llama/merge_lora.py",
         "--llama-config",
-        "dual_ar_2_codebook_large",
+        lora_llama_config,
         "--lora-config",
         "r_8_alpha_16",
         "--llama-weight",
@@ -902,6 +941,15 @@ with gr.Blocks(
                                 allow_custom_value=True,
                                 interactive=True,
                             )
+                            lora_llama_config = gr.Dropdown(
+                                label=i18n("LLAMA Model Config"),
+                                choices=[
+                                    "dual_ar_2_codebook_large",
+                                    "dual_ar_2_codebook_medium",
+                                ],
+                                value="dual_ar_2_codebook_large",
+                                allow_custom_value=True,
+                            )
                         with gr.Row(equal_height=False):
                             llama_lora_output = gr.Dropdown(
                                 label=i18n("Output Path"),
@@ -994,7 +1042,13 @@ with gr.Blocks(
                                         "Compile the model can significantly reduce the inference time, but will increase cold start time"
                                     ),
                                     choices=["Yes", "No"],
-                                    value="Yes",
+                                    value="Yes"
+                                    if (
+                                        sys.platform == "linux"
+                                        or is_module_installed("triton")
+                                    )
+                                    else "No",
+                                    interactive=is_module_installed("triton"),
                                 )
                                 infer_llama_config = gr.Dropdown(
                                     label=i18n("LLAMA Model Config"),
@@ -1134,7 +1188,7 @@ with gr.Blocks(
     llama_ckpt.change(fn=fresh_llama_ckpt, inputs=[], outputs=[llama_ckpt])
     llama_lora_merge_btn.click(
         fn=llama_lora_merge,
-        inputs=[llama_weight, lora_weight, llama_lora_output],
+        inputs=[llama_weight, lora_llama_config, lora_weight, llama_lora_output],
         outputs=[train_error],
     )
     infer_checkbox.change(

+ 10 - 1
tools/webui.py

@@ -39,6 +39,7 @@ HEADER_MD = f"""# Fish Speech
 
 TEXTBOX_PLACEHOLDER = i18n("Put your text here.")
 SPACE_IMPORTED = False
+cached_audio = np.zeros((1,))
 
 
 def build_html_error_message(error):
@@ -122,6 +123,8 @@ def inference(
         yield wav_chunk_header(), None
 
     segments = []
+    global cached_audio
+    cached_audio = np.zeros((1,))
     while True:
         result = payload["response_queue"].get()
         if result == "next":
@@ -141,6 +144,7 @@ def inference(
         fake_audios = fake_audios.float().cpu().numpy()
 
         if streaming:
+            cached_audio = np.concatenate([cached_audio, fake_audios], axis=0)
             yield (fake_audios * 32768).astype(np.int16).tobytes(), None
         else:
             segments.append(fake_audios)
@@ -296,6 +300,11 @@ def build_app():
             [audio, error],
             concurrency_limit=1,
         )
+
+        def transfer_audio():
+            global cached_audio
+            return (vqgan_model.sampling_rate, cached_audio)
+
         generate_stream.click(
             inference_stream,
             [
@@ -312,7 +321,7 @@ def build_app():
             ],
             [stream_audio, error],
             concurrency_limit=10,
-        )
+        ).then(transfer_audio, None, audio)
     return app