ソースを参照

Add some corner cases (#158)

* Fix button height

* Streaming support

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Convert to 1 channel

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix Conversion bug

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix target path

* Add checkpoint selection

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix gpup decorator

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Add link for labeler

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Localize labeler

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Add LoRA llama config

* Allow download stream audio

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* asr

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Add cache auto recycling

* 多打了一个字母

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Check 'compile' available

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
spicysama 1 年間 前
コミット
12915468c8

+ 1 - 1
.gitignore

@@ -19,7 +19,7 @@ filelists
 /.idea
 ffmpeg.exe
 ffprobe.exe
-asr-label-win-x64.exe
+asr-label*
 /.cache
 /fishenv
 /.locale

+ 1 - 0
fish_speech/i18n/locale/zh_CN.json

@@ -56,6 +56,7 @@
     "Open Tensorboard": "打开 Tensorboard",
     "Opened labeler in browser": "在浏览器中打开标注工具",
     "Optional Label Language": "[可选] 标注语言",
+    "Optional online ver": "[可选] 使用在线版",
     "Output Path": "输出路径",
     "Path error, please check the model file exists in the corresponding path": "路径错误,请检查模型文件是否存在于相应路径",
     "Precision": "精度",

+ 6 - 0
fish_speech/webui/launch_utils.py

@@ -1,3 +1,4 @@
+import importlib.util
 import os
 import subprocess
 import sys
@@ -17,6 +18,11 @@ GIT = (
 GIT = str(GIT)
 
 
+def is_module_installed(module_name: str) -> bool:
+    spec = importlib.util.find_spec(module_name)
+    return spec is not None
+
+
 @lru_cache()
 def commit_hash():
     try:
+ 68 - 14
fish_speech/webui/manage.py

@@ -8,7 +8,6 @@ import shutil
 import signal
 import subprocess
 import sys
-import webbrowser
 from pathlib import Path
 
 import gradio as gr
@@ -18,7 +17,7 @@ from loguru import logger
 from tqdm import tqdm
 
 from fish_speech.i18n import i18n
-from fish_speech.webui.launch_utils import Seafoam, versions_html
+from fish_speech.webui.launch_utils import Seafoam, is_module_installed, versions_html
 
 PYTHON = os.path.join(os.environ.get("PYTHON_FOLDERPATH", ""), "python")
 sys.path.insert(0, "")
@@ -51,6 +50,15 @@ def build_html_ok_message(msg):
     """
     """
 
 
 
 
+def build_html_href(link, desc, msg):
+    return f"""
+    <span style="color: green; font-weight: bold; display: inline-block">
+        {html.escape(msg)}
+        <a href="{link}">{desc}</a>
+    </span>
+    """
+
+
 def load_data_in_raw(path):
 def load_data_in_raw(path):
     with open(path, "r", encoding="utf-8") as file:
     with open(path, "r", encoding="utf-8") as file:
         data = file.read()
         data = file.read()
@@ -94,14 +102,42 @@ def kill_process(pid):
 
 
 def change_label(if_label):
 def change_label(if_label):
     global p_label
     global p_label
-    if if_label == True:
-        # 设置要访问的URL
-        url = "https://text-labeler.pages.dev/"
-        webbrowser.open(url)
-        yield i18n("Opened labeler in browser")
-    elif if_label == False:
+    if if_label == True and p_label is None:
+        url = "http://localhost:3000"
+        remote_url = "https://text-labeler.pages.dev/"
+        p_label = subprocess.Popen(
+            [
+                "asr-label-linux-x64"
+                if sys.platform == "linux"
+                else "asr-label-win-x64.exe"
+            ]
+        )
+        yield build_html_href(
+            link=remote_url,
+            desc=i18n("Optional online ver"),
+            msg=i18n("Opened labeler in browser"),
+        )
+
+    elif if_label == False and p_label is not None:
+        kill_process(p_label.pid)
         p_label = None
         p_label = None
-        yield "Nothing"
+        yield build_html_ok_message("Nothing")
+
+
+def clean_infer_cache():
+    import tempfile
+
+    temp_dir = Path(tempfile.gettempdir())
+    gradio_dir = str(temp_dir / "gradio")
+    try:
+        shutil.rmtree(gradio_dir)
+        logger.info(f"Deleted cached audios: {gradio_dir}")
+    except PermissionError:
+        logger.info(f"Permission denied: Unable to delete {gradio_dir}")
+    except FileNotFoundError:
+        logger.info(f"{gradio_dir} was not found")
+    except Exception as e:
+        logger.info(f"An error occurred: {e}")
 
 
 
 
 def change_infer(
 def change_infer(
@@ -124,6 +160,9 @@ def change_infer(
         yield build_html_ok_message(
         yield build_html_ok_message(
             i18n("Inferring interface is launched at {}").format(url)
             i18n("Inferring interface is launched at {}").format(url)
         )
         )
+
+        clean_infer_cache()
+
         p_infer = subprocess.Popen(
         p_infer = subprocess.Popen(
             [
             [
                 PYTHON,
                 PYTHON,
@@ -141,7 +180,7 @@ def change_infer(
             env=env,
             env=env,
         )
         )
 
 
-    elif if_infer == False and p_infer != None:
+    elif if_infer == False and p_infer is not None:
         kill_process(p_infer.pid)
         kill_process(p_infer.pid)
         p_infer = None
         p_infer = None
         yield build_html_error_message(i18n("Infer interface is closed"))
         yield build_html_error_message(i18n("Infer interface is closed"))
@@ -585,7 +624,7 @@ def fresh_llama_model():
     )
     )
 
 
 
 
-def llama_lora_merge(llama_weight, lora_weight, llama_lora_output):
+def llama_lora_merge(llama_weight, lora_llama_config, lora_weight, llama_lora_output):
     if (
     if (
         lora_weight is None
         lora_weight is None
         or not Path(lora_weight).exists()
         or not Path(lora_weight).exists()
@@ -601,7 +640,7 @@ def llama_lora_merge(llama_weight, lora_weight, llama_lora_output):
         PYTHON,
         PYTHON,
         "tools/llama/merge_lora.py",
         "tools/llama/merge_lora.py",
         "--llama-config",
         "--llama-config",
-        "dual_ar_2_codebook_large",
+        lora_llama_config,
         "--lora-config",
         "--lora-config",
         "r_8_alpha_16",
         "r_8_alpha_16",
         "--llama-weight",
         "--llama-weight",
@@ -902,6 +941,15 @@ with gr.Blocks(
                                 allow_custom_value=True,
                                 allow_custom_value=True,
                                 interactive=True,
                                 interactive=True,
                             )
                             )
+                            lora_llama_config = gr.Dropdown(
+                                label=i18n("LLAMA Model Config"),
+                                choices=[
+                                    "dual_ar_2_codebook_large",
+                                    "dual_ar_2_codebook_medium",
+                                ],
+                                value="dual_ar_2_codebook_large",
+                                allow_custom_value=True,
+                            )
                         with gr.Row(equal_height=False):
                         with gr.Row(equal_height=False):
                             llama_lora_output = gr.Dropdown(
                             llama_lora_output = gr.Dropdown(
                                 label=i18n("Output Path"),
                                 label=i18n("Output Path"),
@@ -994,7 +1042,13 @@ with gr.Blocks(
                                         "Compile the model can significantly reduce the inference time, but will increase cold start time"
                                         "Compile the model can significantly reduce the inference time, but will increase cold start time"
                                     ),
                                     ),
                                     choices=["Yes", "No"],
                                     choices=["Yes", "No"],
-                                    value="Yes",
+                                    value="Yes"
+                                    if (
+                                        sys.platform == "linux"
+                                        or is_module_installed("triton")
+                                    )
+                                    else "No",
+                                    interactive=is_module_installed("triton"),
                                 )
                                 )
                                 infer_llama_config = gr.Dropdown(
                                 infer_llama_config = gr.Dropdown(
                                     label=i18n("LLAMA Model Config"),
                                     label=i18n("LLAMA Model Config"),
@@ -1134,7 +1188,7 @@ with gr.Blocks(
     llama_ckpt.change(fn=fresh_llama_ckpt, inputs=[], outputs=[llama_ckpt])
     llama_ckpt.change(fn=fresh_llama_ckpt, inputs=[], outputs=[llama_ckpt])
     llama_lora_merge_btn.click(
     llama_lora_merge_btn.click(
         fn=llama_lora_merge,
         fn=llama_lora_merge,
-        inputs=[llama_weight, lora_weight, llama_lora_output],
+        inputs=[llama_weight, lora_llama_config, lora_weight, llama_lora_output],
         outputs=[train_error],
         outputs=[train_error],
     )
     )
     infer_checkbox.change(
     infer_checkbox.change(

+ 10 - 1
tools/webui.py

@@ -39,6 +39,7 @@ HEADER_MD = f"""# Fish Speech
 
 
 TEXTBOX_PLACEHOLDER = i18n("Put your text here.")
 TEXTBOX_PLACEHOLDER = i18n("Put your text here.")
 SPACE_IMPORTED = False
 SPACE_IMPORTED = False
+cached_audio = np.zeros((1,))
 
 
 
 
 def build_html_error_message(error):
 def build_html_error_message(error):
@@ -122,6 +123,8 @@ def inference(
         yield wav_chunk_header(), None
         yield wav_chunk_header(), None
 
 
     segments = []
     segments = []
+    global cached_audio
+    cached_audio = np.zeros((1,))
     while True:
     while True:
         result = payload["response_queue"].get()
         result = payload["response_queue"].get()
         if result == "next":
         if result == "next":
@@ -141,6 +144,7 @@ def inference(
         fake_audios = fake_audios.float().cpu().numpy()
         fake_audios = fake_audios.float().cpu().numpy()
 
 
         if streaming:
         if streaming:
+            cached_audio = np.concatenate([cached_audio, fake_audios], axis=0)
             yield (fake_audios * 32768).astype(np.int16).tobytes(), None
             yield (fake_audios * 32768).astype(np.int16).tobytes(), None
         else:
         else:
             segments.append(fake_audios)
             segments.append(fake_audios)
@@ -296,6 +300,11 @@ def build_app():
             [audio, error],
             [audio, error],
             concurrency_limit=1,
             concurrency_limit=1,
         )
         )
+
+        def transfer_audio():
+            global cached_audio
+            return (vqgan_model.sampling_rate, cached_audio)
+
         generate_stream.click(
         generate_stream.click(
             inference_stream,
             inference_stream,
             [
             [
@@ -312,7 +321,7 @@ def build_app():
             ],
             ],
             [stream_audio, error],
             [stream_audio, error],
             concurrency_limit=10,
             concurrency_limit=10,
-        )
+        ).then(transfer_audio, None, audio)
     return app
     return app