
Fix docs images etc. (#385)

* Add quick start ipynb

* Remove redundant output

* Fix docs

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [Feature] Add Fast Whisper

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Different suffix

* Different audio format

* Fix README.md for ja docs

* Fix ZH docs

* Fix docs images & WebUI

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix spelling

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Leng Yue <lengyue@lengyue.me>
spicysama 1 year ago
parent
commit
cee143d213
6 files changed, 111 additions and 66 deletions
  1. docs/en/index.md (+2 −2)
  2. docs/ja/index.md (+2 −2)
  3. docs/zh/index.md (+2 −1)
  4. fish_speech/webui/manage.py (+88 −57)
  5. mkdocs.yml (+2 −2)
  6. tools/whisper_asr.py (+15 −2)

+ 2 - 2
docs/en/index.md

@@ -18,7 +18,7 @@ We assume no responsibility for any illegal use of the codebase. Please refer to
 This codebase is released under the `BSD-3-Clause` license, and all models are released under the CC-BY-NC-SA-4.0 license.
 
 <p align="center">
-   <img src="/docs/assets/figs/diagram.png" width="75%">
+   <img src="../assets/figs/diagram.png" width="75%">
 </p>
 
 ## Requirements
@@ -63,7 +63,7 @@ Non-professional Windows users can consider the following methods to run the cod
                   <li>After installing Visual Studio Installer, download Visual Studio Community 2022.</li>
                   <li>Click the <code>Modify</code> button as shown below, find the <code>Desktop development with C++</code> option, and check it for download.</li>
                   <p align="center">
-                     <img src="/docs/assets/figs/VS_1.jpg" width="75%">
+                     <img src="../assets/figs/VS_1.jpg" width="75%">
                   </p>
                </ul>
             </li>

+ 2 - 2
docs/ja/index.md

@@ -18,7 +18,7 @@
 このコードベースは `BSD-3-Clause` ライセンスの下でリリースされており、すべてのモデルは CC-BY-NC-SA-4.0 ライセンスの下でリリースされています。
 
 <p align="center">
-   <img src="/docs/assets/figs/diagram.png" width="75%">
+   <img src="../assets/figs/diagram.png" width="75%">
 </p>
 
 ## 要件
@@ -63,7 +63,7 @@ Windows のプロユーザーは、コードベースを実行するために WS
                   <li>Visual Studio Installerをインストールした後、Visual Studio Community 2022をダウンロードします。</li>
                   <li>以下の図のように<code>Modify</code>ボタンをクリックし、<code>Desktop development with C++</code>オプションを見つけてチェックしてダウンロードします。</li>
                   <p align="center">
-                     <img src="/docs/assets/figs/VS_1.jpg" width="75%">
+                     <img src="../assets/figs/VS_1.jpg" width="75%">
                   </p>
                </ul>
             </li>

+ 2 - 1
docs/zh/index.md

@@ -18,7 +18,7 @@
 此代码库根据 `BSD-3-Clause` 许可证发布, 所有模型根据 CC-BY-NC-SA-4.0 许可证发布.
 
 <p align="center">
-  <img src="https://s2.loli.net/2024/05/11/h9qSpRboTs5dGMQ.png" width="75%">
+   <img src="../assets/figs/diagram.png" width="75%">
 </p>
 
 ## 要求
@@ -32,6 +32,7 @@ Windows 专业用户可以考虑 WSL2 或 docker 来运行代码库。
 
 Windows 非专业用户可考虑以下为免 Linux 环境的基础运行方法(附带模型编译功能,即 `torch.compile`):
 
+
 1. 解压项目压缩包。
 2. 点击 `install_env.bat` 安装环境。
     - 可以通过编辑 `install_env.bat` 的 `USE_MIRROR` 项来决定是否使用镜像站下载。

+ 88 - 57
fish_speech/webui/manage.py

@@ -251,7 +251,13 @@ def new_explorer(data_path, max_depth):
     )
 
 
-def add_item(folder: str, method: str, label_lang: str):
+def add_item(
+    folder: str,
+    method: str,
+    label_lang: str,
+    if_initial_prompt: bool,
+    initial_prompt: str | None,
+):
     folder = folder.strip(" ").strip('"')
 
     folder_path = Path(folder)
@@ -260,7 +266,10 @@ def add_item(folder: str, method: str, label_lang: str):
         if folder_path.is_dir():
             items.append(folder)
             dict_items[folder] = dict(
-                type="folder", method=method, label_lang=label_lang
+                type="folder",
+                method=method,
+                label_lang=label_lang,
+                initial_prompt=initial_prompt if if_initial_prompt else None,
             )
         elif folder:
             err = folder
@@ -269,7 +278,8 @@ def add_item(folder: str, method: str, label_lang: str):
             )
 
     formatted_data = json.dumps(dict_items, ensure_ascii=False, indent=4)
-    logger.info(formatted_data)
+    logger.info("After Adding: " + formatted_data)
+    gr.Info(formatted_data)
     return gr.Checkboxgroup(choices=items), build_html_ok_message(
         i18n("Added path successfully!")
     )
@@ -283,6 +293,7 @@ def remove_items(selected_items):
     items = [item for item in items if item in dict_items.keys()]
     formatted_data = json.dumps(dict_items, ensure_ascii=False, indent=4)
     logger.info(formatted_data)
+    gr.Warning("After Removing: " + formatted_data)
     return gr.Checkboxgroup(choices=items, value=[]), build_html_ok_message(
         i18n("Removed path successfully!")
     )
@@ -351,6 +362,7 @@ def list_copy(list_file_path, method):
 def check_files(data_path: str, max_depth: int, label_model: str, label_device: str):
     global dict_items
     data_path = Path(data_path)
+    gr.Warning("Pre-processing begins...")
     for item, content in dict_items.items():
         item_path = Path(item)
         tar_path = data_path / item_path.name
@@ -369,23 +381,31 @@ def check_files(data_path: str, max_depth: int, label_model: str, label_device:
                     convert_to_mono_in_place(audio_path)
 
             cur_lang = content["label_lang"]
+            initial_prompt = content["initial_prompt"]
+
+            transcribe_cmd = [
+                PYTHON,
+                "tools/whisper_asr.py",
+                "--model-size",
+                label_model,
+                "--device",
+                label_device,
+                "--audio-dir",
+                tar_path,
+                "--save-dir",
+                tar_path,
+                "--language",
+                cur_lang,
+            ]
+
+            if initial_prompt is not None:
+                transcribe_cmd += ["--initial-prompt", initial_prompt]
+
             if cur_lang != "IGNORE":
                 try:
+                    gr.Warning("Begin To Transcribe")
                     subprocess.run(
-                        [
-                            PYTHON,
-                            "tools/whisper_asr.py",
-                            "--model-size",
-                            label_model,
-                            "--device",
-                            label_device,
-                            "--audio-dir",
-                            tar_path,
-                            "--save-dir",
-                            tar_path,
-                            "--language",
-                            cur_lang,
-                        ],
+                        transcribe_cmd,
                         env=env,
                     )
                 except Exception:
@@ -408,8 +428,6 @@ def generate_folder_name():
 def train_process(
     data_path: str,
     option: str,
-    min_duration: float,
-    max_duration: float,
     # llama config
     llama_ckpt,
     llama_base_config,
@@ -428,13 +446,17 @@ def train_process(
     backend = "nccl" if sys.platform == "linux" else "gloo"
 
     new_project = generate_folder_name()
-
     print("New Project Name: ", new_project)
 
-    if min_duration > max_duration:
-        min_duration, max_duration = max_duration, min_duration
+    if option == "VQGAN":
+        msg = "Skipped VQGAN Training."
+        gr.Warning(msg)
+        logger.info(msg)
 
     if option == "LLAMA":
+        msg = "LLAMA Training begins..."
+        gr.Warning(msg)
+        logger.info(msg)
         subprocess.run(
             [
                 PYTHON,
@@ -565,13 +587,16 @@ def list_llama_models():
     choices = [str(p.parent) for p in Path("checkpoints").glob("merged*/*model*.pth")]
     choices += [str(p.parent) for p in Path("checkpoints").glob("fish*/*model*.pth")]
     choices += [str(p.parent) for p in Path("checkpoints").glob("fs*/*model*.pth")]
+    choices = sorted(choices, reverse=True)
     if not choices:
         logger.warning("No LLaMA model found")
     return choices
 
 
 def list_lora_llama_models():
-    choices = [str(p) for p in Path("results").glob("lora*/**/*.ckpt")]
+    choices = sorted(
+        [str(p) for p in Path("results").glob("lora*/**/*.ckpt")], reverse=True
+    )
     if not choices:
         logger.warning("No LoRA LLaMA model found")
     return choices
@@ -607,7 +632,7 @@ def llama_lora_merge(llama_weight, lora_llama_config, lora_weight, llama_lora_ou
                 "Path error, please check the model file exists in the corresponding path"
             )
         )
-
+    gr.Warning("Merging begins...")
     merge_cmd = [
         PYTHON,
         "tools/llama/merge_lora.py",
@@ -630,6 +655,9 @@ def llama_quantify(llama_weight, quantify_mode):
                 "Path error, please check the model file exists in the corresponding path"
             )
         )
+
+    gr.Warning("Quantifying begins...")
+
     now = generate_folder_name()
     quantify_cmd = [
         PYTHON,
@@ -690,30 +718,6 @@ with gr.Blocks(
                         if_label = gr.Checkbox(
                             label=i18n("Open Labeler WebUI"), scale=0, show_label=True
                         )
-                with gr.Row():
-                    min_duration = gr.Slider(
-                        label=i18n("Minimum Audio Duration"),
-                        value=1.5,
-                        step=0.1,
-                        minimum=0.4,
-                        maximum=30,
-                    )
-                    max_duration = gr.Slider(
-                        label=i18n("Maximum Audio Duration"),
-                        value=30,
-                        step=0.1,
-                        minimum=0.4,
-                        maximum=30,
-                    )
-
-                with gr.Row():
-                    add_button = gr.Button(
-                        "\U000027A1 " + i18n("Add to Processing Area"),
-                        variant="primary",
-                    )
-                    remove_button = gr.Button(
-                        "\U000026D4 " + i18n("Remove Selected Data")
-                    )
 
                 with gr.Row():
                     label_device = gr.Dropdown(
@@ -728,9 +732,9 @@ with gr.Blocks(
                     label_model = gr.Dropdown(
                         label=i18n("Whisper Model"),
                         info=i18n("Faster Whisper, Up to 5g GPU memory usage"),
-                        choices=["large-v3"],
+                        choices=["large-v3", "medium"],
                         value="large-v3",
-                        interactive=False,
+                        interactive=True,
                     )
                     label_radio = gr.Dropdown(
                         label=i18n("Optional Label Language"),
@@ -738,9 +742,9 @@ with gr.Blocks(
                             "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format"
                         ),
                         choices=[
-                            (i18n("Chinese"), "ZH"),
-                            (i18n("English"), "EN"),
-                            (i18n("Japanese"), "JA"),
+                            (i18n("Chinese"), "zh"),
+                            (i18n("English"), "en"),
+                            (i18n("Japanese"), "ja"),
                             (i18n("Disabled"), "IGNORE"),
                             (i18n("auto"), "auto"),
                         ],
@@ -748,6 +752,31 @@ with gr.Blocks(
                         interactive=True,
                     )
 
+                with gr.Row():
+                    if_initial_prompt = gr.Checkbox(
+                        value=False,
+                        label=i18n("Enable Initial Prompt"),
+                        min_width=120,
+                        scale=0,
+                    )
+                    initial_prompt = gr.Textbox(
+                        label=i18n("Initial Prompt"),
+                        info=i18n(
+                            "Initial prompt can provide contextual or vocabulary-specific guidance to the model."
+                        ),
+                        placeholder="This audio introduces the basic concepts and applications of artificial intelligence and machine learning.",
+                        interactive=False,
+                    )
+
+                with gr.Row():
+                    add_button = gr.Button(
+                        "\U000027A1 " + i18n("Add to Processing Area"),
+                        variant="primary",
+                    )
+                    remove_button = gr.Button(
+                        "\U000026D4 " + i18n("Remove Selected Data")
+                    )
+
             with gr.Tab("\U0001F6E0 " + i18n("Training Configuration")):
                 with gr.Row():
                     model_type_radio = gr.Radio(
@@ -1103,7 +1132,7 @@ with gr.Blocks(
     llama_page.select(lambda: "LLAMA", None, model_type_radio)
     add_button.click(
         fn=add_item,
-        inputs=[textbox, output_radio, label_radio],
+        inputs=[textbox, output_radio, label_radio, if_initial_prompt, initial_prompt],
         outputs=[checkbox_group, error],
     )
     remove_button.click(
@@ -1116,14 +1145,16 @@ with gr.Blocks(
         'toolbar=no, menubar=no, scrollbars=no, resizable=no, location=no, status=no")}',
     )
     if_label.change(fn=change_label, inputs=[if_label], outputs=[error])
-
+    if_initial_prompt.change(
+        fn=lambda x: gr.Textbox(value="", interactive=x),
+        inputs=[if_initial_prompt],
+        outputs=[initial_prompt],
+    )
     train_btn.click(
         fn=train_process,
         inputs=[
             train_box,
             model_type_radio,
-            min_duration,
-            max_duration,
             # llama config
             llama_ckpt,
             llama_base_config,
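The manage.py hunk above refactors the ASR invocation into a `transcribe_cmd` list so that `--initial-prompt` is appended only when the user actually supplied one. A minimal standalone sketch of that pattern (function name and arguments are illustrative, not the repo's exact helpers):

```python
from typing import Optional


def build_transcribe_cmd(
    python: str,
    model_size: str,
    device: str,
    audio_dir: str,
    language: str,
    initial_prompt: Optional[str] = None,
) -> list[str]:
    """Build the whisper_asr.py argv, appending --initial-prompt only when set."""
    cmd = [
        python,
        "tools/whisper_asr.py",
        "--model-size", model_size,
        "--device", device,
        "--audio-dir", audio_dir,
        "--save-dir", audio_dir,
        "--language", language,
    ]
    # Optional flags are appended conditionally, so subprocess.run never
    # receives a bare "--initial-prompt" paired with a None value.
    if initial_prompt is not None:
        cmd += ["--initial-prompt", initial_prompt]
    return cmd


cmd = build_transcribe_cmd(
    "python", "large-v3", "cuda", "data/raw", "zh",
    initial_prompt="AI terminology",
)
print(cmd[-2:])  # the optional flag lands at the end of the argv list
```

Keeping the command as a list (rather than a shell string) avoids quoting issues when the initial prompt contains spaces.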

+ 2 - 2
mkdocs.yml

@@ -33,12 +33,12 @@ theme:
       toggle:
         icon: material/brightness-auto
         name: Switch to light mode
-  
+
     # Palette toggle for light mode
     - media: "(prefers-color-scheme: light)"
       scheme: default
       toggle:
-        icon: material/brightness-7 
+        icon: material/brightness-7
         name: Switch to dark mode
       primary: black
       font:

+ 15 - 2
tools/whisper_asr.py

@@ -54,7 +54,17 @@ from fish_speech.utils.file import AUDIO_EXTENSIONS, list_files
 )
 @click.option("--device", default="cuda", help="Device to use [cuda / cpu]")
 @click.option("--language", default="auto", help="Language of the transcription")
-def main(model_size, compute_type, audio_dir, save_dir, sample_rate, device, language):
+@click.option("--initial-prompt", default=None, help="Initial prompt for transcribing")
+def main(
+    model_size,
+    compute_type,
+    audio_dir,
+    save_dir,
+    sample_rate,
+    device,
+    language,
+    initial_prompt,
+):
     logger.info("Loading / Downloading Faster Whisper model...")
 
     model = WhisperModel(
@@ -97,7 +107,10 @@ def main(model_size, compute_type, audio_dir, save_dir, sample_rate, device, lan
         audio = AudioSegment.from_file(file_path)
 
         segments, info = model.transcribe(
-            file_path, beam_size=5, language=None if language == "auto" else language
+            file_path,
+            beam_size=5,
+            language=None if language == "auto" else language,
+            initial_prompt=initial_prompt,
         )
 
         print(
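The whisper_asr.py hunk forwards the new `--initial-prompt` option straight into faster-whisper's `transcribe()` while keeping the existing `"auto"` → `None` language mapping. A minimal sketch of that pass-through, with `fake_transcribe` standing in for `WhisperModel.transcribe` (the stub is illustrative; the real call also takes `beam_size` and returns segments):

```python
from typing import Optional


def normalize_language(language: str) -> Optional[str]:
    """faster-whisper auto-detects when language is None, so map "auto" to None."""
    return None if language == "auto" else language


def fake_transcribe(
    path: str, language: Optional[str], initial_prompt: Optional[str]
) -> dict:
    # Stand-in for WhisperModel.transcribe: echo back the kwargs it received.
    return {"path": path, "language": language, "initial_prompt": initial_prompt}


result = fake_transcribe(
    "sample.wav",
    language=normalize_language("auto"),
    initial_prompt="This audio covers machine learning basics.",
)
print(result["language"])  # None -> the model auto-detects the language
```

Passing `initial_prompt=None` is safe here: faster-whisper treats a missing prompt and `None` identically, so the CLI default needs no special casing.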