Fix docs images etc. (#385)

* Add quick start ipynb

* Remove redundant output

* Fix docs

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [Feature] Add Fast Whisper

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Different suffix

* Different audio format

* Fix README.md for ja docs

* Fix ZH docs

* Fix docs images & WebUI

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix spelling

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Leng Yue <lengyue@lengyue.me>
spicysama 1 year ago
Parent
Commit
cee143d213
6 changed files with 111 additions and 66 deletions
  1. docs/en/index.md (+2 -2)
  2. docs/ja/index.md (+2 -2)
  3. docs/zh/index.md (+2 -1)
  4. fish_speech/webui/manage.py (+88 -57)
  5. mkdocs.yml (+2 -2)
  6. tools/whisper_asr.py (+15 -2)

+ 2 - 2
docs/en/index.md

@@ -18,7 +18,7 @@ We assume no responsibility for any illegal use of the codebase. Please refer to
 This codebase is released under the `BSD-3-Clause` license, and all models are released under the CC-BY-NC-SA-4.0 license.
 
 <p align="center">
-   <img src="/docs/assets/figs/diagram.png" width="75%">
+   <img src="../assets/figs/diagram.png" width="75%">
 </p>
 
 ## Requirements
@@ -63,7 +63,7 @@ Non-professional Windows users can consider the following methods to run the cod
                   <li>After installing Visual Studio Installer, download Visual Studio Community 2022.</li>
                   <li>Click the <code>Modify</code> button as shown below, find the <code>Desktop development with C++</code> option, and check it for download.</li>
                   <p align="center">
-                     <img src="/docs/assets/figs/VS_1.jpg" width="75%">
+                     <img src="../assets/figs/VS_1.jpg" width="75%">
                   </p>
                </ul>
             </li>

+ 2 - 2
docs/ja/index.md

@@ -18,7 +18,7 @@
 このコードベースは `BSD-3-Clause` ライセンスの下でリリースされており、すべてのモデルは CC-BY-NC-SA-4.0 ライセンスの下でリリースされています。
 
 <p align="center">
-   <img src="/docs/assets/figs/diagram.png" width="75%">
+   <img src="../assets/figs/diagram.png" width="75%">
 </p>
 
 ## 要件
@@ -63,7 +63,7 @@ Windows のプロユーザーは、コードベースを実行するために WS
                   <li>Visual Studio Installerをインストールした後、Visual Studio Community 2022をダウンロードします。</li>
                   <li>以下の図のように<code>Modify</code>ボタンをクリックし、<code>Desktop development with C++</code>オプションを見つけてチェックしてダウンロードします。</li>
                   <p align="center">
-                     <img src="/docs/assets/figs/VS_1.jpg" width="75%">
+                     <img src="../assets/figs/VS_1.jpg" width="75%">
                   </p>
                </ul>
             </li>

+ 2 - 1
docs/zh/index.md

@@ -18,7 +18,7 @@
 此代码库根据 `BSD-3-Clause` 许可证发布, 所有模型根据 CC-BY-NC-SA-4.0 许可证发布.
 
 <p align="center">
-  <img src="https://s2.loli.net/2024/05/11/h9qSpRboTs5dGMQ.png" width="75%">
+   <img src="../assets/figs/diagram.png" width="75%">
 </p>
 
 ## 要求
@@ -32,6 +32,7 @@ Windows 专业用户可以考虑 WSL2 或 docker 来运行代码库。
 
 Windows 非专业用户可考虑以下为免 Linux 环境的基础运行方法(附带模型编译功能,即 `torch.compile`):
 
+
 1. 解压项目压缩包。
 2. 点击 `install_env.bat` 安装环境。
     - 可以通过编辑 `install_env.bat` 的 `USE_MIRROR` 项来决定是否使用镜像站下载。
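
Note on the image path change: the `index.md` hunks above replace the root-absolute `/docs/assets/...` URL (and, in the zh page, an externally hosted image) with `../assets/...`. As a rough illustration of where the relative form points inside the repository tree, the sketch below resolves it against each page's directory using only the standard library. The helper name `resolve_from_page` is hypothetical, and how MkDocs maps these files to site URLs depends on configuration that this commit does not show.

```python
import posixpath


def resolve_from_page(page_path: str, link: str) -> str:
    """Resolve a relative link against the directory that contains page_path."""
    page_dir = posixpath.dirname(page_path)
    return posixpath.normpath(posixpath.join(page_dir, link))


# All three localized pages sit one level below docs/, so the same
# relative link reaches the shared assets folder from each of them.
for page in ("docs/en/index.md", "docs/ja/index.md", "docs/zh/index.md"):
    print(page, "->", resolve_from_page(page, "../assets/figs/diagram.png"))
```

Because the link is relative to the page rather than to the site root, it should keep working if the docs are served from a sub-path, which is a plausible motivation for the change.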

+ 88 - 57
fish_speech/webui/manage.py

@@ -251,7 +251,13 @@ def new_explorer(data_path, max_depth):
     )
 
 
-def add_item(folder: str, method: str, label_lang: str):
+def add_item(
+    folder: str,
+    method: str,
+    label_lang: str,
+    if_initial_prompt: bool,
+    initial_prompt: str | None,
+):
     folder = folder.strip(" ").strip('"')
 
     folder_path = Path(folder)
@@ -260,7 +266,10 @@ def add_item(folder: str, method: str, label_lang: str):
         if folder_path.is_dir():
             items.append(folder)
             dict_items[folder] = dict(
-                type="folder", method=method, label_lang=label_lang
+                type="folder",
+                method=method,
+                label_lang=label_lang,
+                initial_prompt=initial_prompt if if_initial_prompt else None,
             )
         elif folder:
             err = folder
@@ -269,7 +278,8 @@ def add_item(folder: str, method: str, label_lang: str):
             )
 
     formatted_data = json.dumps(dict_items, ensure_ascii=False, indent=4)
-    logger.info(formatted_data)
+    logger.info("After Adding: " + formatted_data)
+    gr.Info(formatted_data)
     return gr.Checkboxgroup(choices=items), build_html_ok_message(
         i18n("Added path successfully!")
     )
@@ -283,6 +293,7 @@ def remove_items(selected_items):
     items = [item for item in items if item in dict_items.keys()]
     formatted_data = json.dumps(dict_items, ensure_ascii=False, indent=4)
     logger.info(formatted_data)
+    gr.Warning("After Removing: " + formatted_data)
     return gr.Checkboxgroup(choices=items, value=[]), build_html_ok_message(
         i18n("Removed path successfully!")
     )
@@ -351,6 +362,7 @@ def list_copy(list_file_path, method):
 def check_files(data_path: str, max_depth: int, label_model: str, label_device: str):
     global dict_items
     data_path = Path(data_path)
+    gr.Warning("Pre-processing begins...")
     for item, content in dict_items.items():
         item_path = Path(item)
         tar_path = data_path / item_path.name
@@ -369,23 +381,31 @@ def check_files(data_path: str, max_depth: int, label_model: str, label_device:
                     convert_to_mono_in_place(audio_path)
 
             cur_lang = content["label_lang"]
+            initial_prompt = content["initial_prompt"]
+
+            transcribe_cmd = [
+                PYTHON,
+                "tools/whisper_asr.py",
+                "--model-size",
+                label_model,
+                "--device",
+                label_device,
+                "--audio-dir",
+                tar_path,
+                "--save-dir",
+                tar_path,
+                "--language",
+                cur_lang,
+            ]
+
+            if initial_prompt is not None:
+                transcribe_cmd += ["--initial-prompt", initial_prompt]
+
             if cur_lang != "IGNORE":
                 try:
+                    gr.Warning("Begin To Transcribe")
                     subprocess.run(
-                        [
-                            PYTHON,
-                            "tools/whisper_asr.py",
-                            "--model-size",
-                            label_model,
-                            "--device",
-                            label_device,
-                            "--audio-dir",
-                            tar_path,
-                            "--save-dir",
-                            tar_path,
-                            "--language",
-                            cur_lang,
-                        ],
+                        transcribe_cmd,
                         env=env,
                     )
                 except Exception:
@@ -408,8 +428,6 @@ def generate_folder_name():
 def train_process(
     data_path: str,
     option: str,
-    min_duration: float,
-    max_duration: float,
     # llama config
     llama_ckpt,
     llama_base_config,
@@ -428,13 +446,17 @@ def train_process(
     backend = "nccl" if sys.platform == "linux" else "gloo"
 
     new_project = generate_folder_name()
-
     print("New Project Name: ", new_project)
 
-    if min_duration > max_duration:
-        min_duration, max_duration = max_duration, min_duration
+    if option == "VQGAN":
+        msg = "Skipped VQGAN Training."
+        gr.Warning(msg)
+        logger.info(msg)
 
     if option == "LLAMA":
+        msg = "LLAMA Training begins..."
+        gr.Warning(msg)
+        logger.info(msg)
         subprocess.run(
             [
                 PYTHON,
@@ -565,13 +587,16 @@ def list_llama_models():
     choices = [str(p.parent) for p in Path("checkpoints").glob("merged*/*model*.pth")]
     choices += [str(p.parent) for p in Path("checkpoints").glob("fish*/*model*.pth")]
     choices += [str(p.parent) for p in Path("checkpoints").glob("fs*/*model*.pth")]
+    choices = sorted(choices, reverse=True)
     if not choices:
         logger.warning("No LLaMA model found")
     return choices
 
 
 def list_lora_llama_models():
-    choices = [str(p) for p in Path("results").glob("lora*/**/*.ckpt")]
+    choices = sorted(
+        [str(p) for p in Path("results").glob("lora*/**/*.ckpt")], reverse=True
+    )
     if not choices:
         logger.warning("No LoRA LLaMA model found")
     return choices
@@ -607,7 +632,7 @@ def llama_lora_merge(llama_weight, lora_llama_config, lora_weight, llama_lora_ou
                 "Path error, please check the model file exists in the corresponding path"
             )
         )
-
+    gr.Warning("Merging begins...")
     merge_cmd = [
         PYTHON,
         "tools/llama/merge_lora.py",
@@ -630,6 +655,9 @@ def llama_quantify(llama_weight, quantify_mode):
                 "Path error, please check the model file exists in the corresponding path"
             )
         )
+
+    gr.Warning("Quantifying begins...")
+
     now = generate_folder_name()
     quantify_cmd = [
         PYTHON,
@@ -690,30 +718,6 @@ with gr.Blocks(
                         if_label = gr.Checkbox(
                             label=i18n("Open Labeler WebUI"), scale=0, show_label=True
                         )
-                with gr.Row():
-                    min_duration = gr.Slider(
-                        label=i18n("Minimum Audio Duration"),
-                        value=1.5,
-                        step=0.1,
-                        minimum=0.4,
-                        maximum=30,
-                    )
-                    max_duration = gr.Slider(
-                        label=i18n("Maximum Audio Duration"),
-                        value=30,
-                        step=0.1,
-                        minimum=0.4,
-                        maximum=30,
-                    )
-
-                with gr.Row():
-                    add_button = gr.Button(
-                        "\U000027A1 " + i18n("Add to Processing Area"),
-                        variant="primary",
-                    )
-                    remove_button = gr.Button(
-                        "\U000026D4 " + i18n("Remove Selected Data")
-                    )
 
                 with gr.Row():
                     label_device = gr.Dropdown(
@@ -728,9 +732,9 @@ with gr.Blocks(
                     label_model = gr.Dropdown(
                         label=i18n("Whisper Model"),
                         info=i18n("Faster Whisper, Up to 5g GPU memory usage"),
-                        choices=["large-v3"],
+                        choices=["large-v3", "medium"],
                         value="large-v3",
-                        interactive=False,
+                        interactive=True,
                     )
                     label_radio = gr.Dropdown(
                         label=i18n("Optional Label Language"),
@@ -738,9 +742,9 @@ with gr.Blocks(
                             "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format"
                         ),
                         choices=[
-                            (i18n("Chinese"), "ZH"),
-                            (i18n("English"), "EN"),
-                            (i18n("Japanese"), "JA"),
+                            (i18n("Chinese"), "zh"),
+                            (i18n("English"), "en"),
+                            (i18n("Japanese"), "ja"),
                             (i18n("Disabled"), "IGNORE"),
                             (i18n("auto"), "auto"),
                         ],
@@ -748,6 +752,31 @@ with gr.Blocks(
                         interactive=True,
                     )
 
+                with gr.Row():
+                    if_initial_prompt = gr.Checkbox(
+                        value=False,
+                        label=i18n("Enable Initial Prompt"),
+                        min_width=120,
+                        scale=0,
+                    )
+                    initial_prompt = gr.Textbox(
+                        label=i18n("Initial Prompt"),
+                        info=i18n(
+                            "Initial prompt can provide contextual or vocabulary-specific guidance to the model."
+                        ),
+                        placeholder="This audio introduces the basic concepts and applications of artificial intelligence and machine learning.",
+                        interactive=False,
+                    )
+
+                with gr.Row():
+                    add_button = gr.Button(
+                        "\U000027A1 " + i18n("Add to Processing Area"),
+                        variant="primary",
+                    )
+                    remove_button = gr.Button(
+                        "\U000026D4 " + i18n("Remove Selected Data")
+                    )
+
             with gr.Tab("\U0001F6E0 " + i18n("Training Configuration")):
                 with gr.Row():
                     model_type_radio = gr.Radio(
@@ -1103,7 +1132,7 @@ with gr.Blocks(
     llama_page.select(lambda: "LLAMA", None, model_type_radio)
     add_button.click(
         fn=add_item,
-        inputs=[textbox, output_radio, label_radio],
+        inputs=[textbox, output_radio, label_radio, if_initial_prompt, initial_prompt],
         outputs=[checkbox_group, error],
     )
     remove_button.click(
@@ -1116,14 +1145,16 @@ with gr.Blocks(
         'toolbar=no, menubar=no, scrollbars=no, resizable=no, location=no, status=no")}',
     )
     if_label.change(fn=change_label, inputs=[if_label], outputs=[error])
-
+    if_initial_prompt.change(
+        fn=lambda x: gr.Textbox(value="", interactive=x),
+        inputs=[if_initial_prompt],
+        outputs=[initial_prompt],
+    )
     train_btn.click(
         fn=train_process,
         inputs=[
             train_box,
             model_type_radio,
-            min_duration,
-            max_duration,
             # llama config
             llama_ckpt,
             llama_base_config,
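
Note on the refactored ASR invocation: `check_files` now builds the argument list first and appends `--initial-prompt` only when a prompt was stored for the folder. The sketch below isolates that pattern in a standalone function so it can be exercised without Gradio or the rest of the WebUI. The function name `build_transcribe_cmd` and its parameters are hypothetical; only the flag names and the script path are taken from the diff.

```python
from typing import List, Optional


def build_transcribe_cmd(
    python: str,
    model_size: str,
    device: str,
    audio_dir: str,
    language: str,
    initial_prompt: Optional[str] = None,
) -> List[str]:
    """Assemble the whisper_asr.py command, adding the prompt flag only if set."""
    cmd = [
        python,
        "tools/whisper_asr.py",
        "--model-size", model_size,
        "--device", device,
        "--audio-dir", audio_dir,
        "--save-dir", audio_dir,  # the diff passes the same directory for both
        "--language", language,
    ]
    # None means "no prompt"; an empty string is not None and would still be forwarded.
    if initial_prompt is not None:
        cmd += ["--initial-prompt", initial_prompt]
    return cmd


print(build_transcribe_cmd("python", "medium", "cpu", "data/demo", "en"))
print(build_transcribe_cmd("python", "medium", "cpu", "data/demo", "en", "Glossary: VQGAN, LoRA"))
```

Passing the arguments as a list rather than a shell string means the prompt text reaches the script as a single argument even when it contains spaces or quotes.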

+ 2 - 2
mkdocs.yml

@@ -33,12 +33,12 @@ theme:
       toggle:
         icon: material/brightness-auto
         name: Switch to light mode
-  
+
     # Palette toggle for light mode
     - media: "(prefers-color-scheme: light)"
       scheme: default
       toggle:
-        icon: material/brightness-7 
+        icon: material/brightness-7
         name: Switch to dark mode
       primary: black
       font:

+ 15 - 2
tools/whisper_asr.py

@@ -54,7 +54,17 @@ from fish_speech.utils.file import AUDIO_EXTENSIONS, list_files
 )
 @click.option("--device", default="cuda", help="Device to use [cuda / cpu]")
 @click.option("--language", default="auto", help="Language of the transcription")
-def main(model_size, compute_type, audio_dir, save_dir, sample_rate, device, language):
+@click.option("--initial-prompt", default=None, help="Initial prompt for transcribing")
+def main(
+    model_size,
+    compute_type,
+    audio_dir,
+    save_dir,
+    sample_rate,
+    device,
+    language,
+    initial_prompt,
+):
     logger.info("Loading / Downloading Faster Whisper model...")
 
     model = WhisperModel(
@@ -97,7 +107,10 @@ def main(model_size, compute_type, audio_dir, save_dir, sample_rate, device, lan
         audio = AudioSegment.from_file(file_path)
 
         segments, info = model.transcribe(
-            file_path, beam_size=5, language=None if language == "auto" else language
+            file_path,
+            beam_size=5,
+            language=None if language == "auto" else language,
+            initial_prompt=initial_prompt,
         )
 
         print(
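
Note on the `transcribe` arguments: the hunk above maps the CLI value `auto` to `language=None` and forwards `initial_prompt` unchanged, so an omitted `--initial-prompt` arrives as `None`. The sketch below reproduces only that mapping as a plain dictionary, without importing or calling faster-whisper. The helper name `transcribe_kwargs` is hypothetical and exists purely for illustration.

```python
from typing import Any, Dict, Optional


def transcribe_kwargs(language: str, initial_prompt: Optional[str] = None) -> Dict[str, Any]:
    """Mirror the keyword arguments passed to model.transcribe in the diff."""
    return {
        "beam_size": 5,
        # "auto" is a CLI-level sentinel; None is what the diff passes instead.
        "language": None if language == "auto" else language,
        "initial_prompt": initial_prompt,
    }


print(transcribe_kwargs("auto"))
print(transcribe_kwargs("zh", "Glossary: VQGAN, LoRA"))
```

This also gives one possible reading of the WebUI switch from `ZH`/`EN`/`JA` to lowercase codes in the same commit: any value other than `auto` is passed to the model verbatim, so it has to be in whatever form the model accepts.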