Преглед изворног кода

Fix backend (#627)

* Linux pyaudio dependencies

* revert generate.py

* Better bug report & feat request

* Auto-select torchaudio backend

* safety

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* feat: manual seed for restore

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Gradio > 5

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
spicysama пре 1 годину
родитељ
комит
e37a445f51

+ 1 - 1
docs/en/inference.md

@@ -122,7 +122,7 @@ python -m tools.webui \
 ```
 
 !!! note
-    You can save the label file and reference audio file in advance to the examples folder in the main directory (which you need to create yourself), so that you can directly call them in the WebUI.
+    You can save the label file and reference audio file in advance to the `references` folder in the main directory (which you need to create yourself), so that you can directly call them in the WebUI.
 
 !!! note
     You can use Gradio environment variables, such as `GRADIO_SHARE`, `GRADIO_SERVER_PORT`, `GRADIO_SERVER_NAME` to configure WebUI.

+ 1 - 1
docs/ja/inference.md

@@ -152,7 +152,7 @@ python -m tools.webui \
 ```
 
 !!! note
-    ラベルファイルと参照音声ファイルをメインディレクトリの examples フォルダ(自分で作成する必要があります)に事前に保存しておくことで、WebUI で直接呼び出すことができます。
+    ラベルファイルと参照音声ファイルをメインディレクトリの `references` フォルダ(自分で作成する必要があります)に事前に保存しておくことで、WebUI で直接呼び出すことができます。
 
 !!! note
     Gradio 環境変数(`GRADIO_SHARE`、`GRADIO_SERVER_PORT`、`GRADIO_SERVER_NAME`など)を使用して WebUI を構成できます。

+ 1 - 1
docs/pt/inference.md

@@ -148,7 +148,7 @@ python -m tools.webui \
 ```
 
 !!! note
-    Você pode salvar antecipadamente o arquivo de rótulos e o arquivo de áudio de referência na pasta examples do diretório principal (que você precisa criar), para que possa chamá-los diretamente na WebUI.
+    Você pode salvar antecipadamente o arquivo de rótulos e o arquivo de áudio de referência na pasta `references` do diretório principal (que você precisa criar), para que possa chamá-los diretamente na WebUI.
     
 !!! note
     É possível usar variáveis de ambiente do Gradio, como `GRADIO_SHARE`, `GRADIO_SERVER_PORT`, `GRADIO_SERVER_NAME`, para configurar a WebUI.

+ 1 - 1
docs/zh/inference.md

@@ -132,7 +132,7 @@ python -m tools.webui \
 ```
 
 !!! note
-    你可以提前将label文件和参考音频文件保存到主目录下的examples文件夹(需要自行创建),这样你可以直接在WebUI中调用它们。
+    你可以提前将label文件和参考音频文件保存到主目录下的 `references` 文件夹(需要自行创建),这样你可以直接在WebUI中调用它们。
 
 !!! note
     你可以使用 Gradio 环境变量, 如 `GRADIO_SHARE`, `GRADIO_SERVER_PORT`, `GRADIO_SERVER_NAME` 来配置 WebUI.

+ 4 - 1
fish_speech/models/text2semantic/llama.py

@@ -369,7 +369,10 @@ class BaseTransformer(nn.Module):
                 model = simple_quantizer.convert_for_runtime()
 
             weights = torch.load(
-                Path(path) / "model.pth", map_location="cpu", mmap=True
+                Path(path) / "model.pth",
+                map_location="cpu",
+                mmap=True,
+                weights_only=True,
             )
 
             if "state_dict" in weights:

+ 2 - 1
fish_speech/utils/__init__.py

@@ -5,7 +5,7 @@ from .instantiators import instantiate_callbacks, instantiate_loggers
 from .logger import RankedLogger
 from .logging_utils import log_hyperparameters
 from .rich_utils import enforce_tags, print_config_tree
-from .utils import extras, get_metric_value, task_wrapper
+from .utils import extras, get_metric_value, set_seed, task_wrapper
 
 __all__ = [
     "enforce_tags",
@@ -20,4 +20,5 @@ __all__ = [
     "braceexpand",
     "get_latest_checkpoint",
     "autocast_exclude_mps",
+    "set_seed",
 ]

+ 22 - 0
fish_speech/utils/utils.py

@@ -1,7 +1,10 @@
+import random
 import warnings
 from importlib.util import find_spec
 from typing import Callable
 
+import numpy as np
+import torch
 from omegaconf import DictConfig
 
 from .logger import RankedLogger
@@ -112,3 +115,22 @@ def get_metric_value(metric_dict: dict, metric_name: str) -> float:
     log.info(f"Retrieved metric value! <{metric_name}={metric_value}>")
 
     return metric_value
+
+
+def set_seed(seed: int):
+    if seed < 0:
+        seed = -seed
+    if seed > (1 << 31):
+        seed = 1 << 31
+
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+
+    if torch.cuda.is_available():
+        torch.cuda.manual_seed(seed)
+        torch.cuda.manual_seed_all(seed)
+
+    if torch.backends.cudnn.is_available():
+        torch.backends.cudnn.deterministic = True
+        torch.backends.cudnn.benchmark = False

+ 1 - 1
fish_speech/webui/launch_utils.py

@@ -114,7 +114,7 @@ class Seafoam(Base):
             block_title_text_weight="600",
             block_border_width="3px",
             block_shadow="*shadow_drop_lg",
-            button_shadow="*shadow_drop_lg",
+            # button_shadow="*shadow_drop_lg",
             button_small_padding="0px",
             button_large_padding="3px",
         )

+ 1 - 1
fish_speech/webui/manage.py

@@ -794,7 +794,7 @@ with gr.Blocks(
                         value="VQGAN",
                     )
                 with gr.Row():
-                    with gr.Tabs():
+                    with gr.Column():
                         with gr.Tab(label=i18n("VQGAN Configuration")) as vqgan_page:
                             gr.HTML("You don't need to train this model!")
 

+ 1 - 1
pyproject.toml

@@ -23,7 +23,7 @@ dependencies = [
     "einops>=0.7.0",
     "librosa>=0.10.1",
     "rich>=13.5.3",
-    "gradio<5.0.0",
+    "gradio>5.0.0",
     "wandb>=0.15.11",
     "grpcio>=1.58.0",
     "kui>=1.6.0",

+ 14 - 5
tools/api.py

@@ -35,7 +35,7 @@ pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
 # from fish_speech.models.vqgan.lit_module import VQGAN
 from fish_speech.models.vqgan.modules.firefly import FireflyArchitecture
 from fish_speech.text.chn_text_norm.text import Text as ChnNormedText
-from fish_speech.utils import autocast_exclude_mps
+from fish_speech.utils import autocast_exclude_mps, set_seed
 from tools.commons import ServeTTSRequest
 from tools.file import AUDIO_EXTENSIONS, audio_to_bytes, list_files, read_ref_text
 from tools.llama.generate import (
@@ -46,6 +46,14 @@ from tools.llama.generate import (
 )
 from tools.vqgan.inference import load_model as load_decoder_model
 
+backends = torchaudio.list_audio_backends()
+if "sox" in backends:
+    backend = "sox"
+elif "ffmpeg" in backends:
+    backend = "ffmpeg"
+else:
+    backend = "soundfile"
+
 
 def wav_chunk_header(sample_rate=44100, bit_depth=16, channels=1):
     buffer = io.BytesIO()
@@ -88,10 +96,7 @@ def load_audio(reference_audio, sr):
         audio_data = reference_audio
         reference_audio = io.BytesIO(audio_data)
 
-    waveform, original_sr = torchaudio.load(
-        reference_audio,
-        backend="soundfile",  # not every linux release supports 'sox' or 'ffmpeg'
-    )
+    waveform, original_sr = torchaudio.load(reference_audio, backend=backend)
 
     if waveform.shape[0] > 1:
         waveform = torch.mean(waveform, dim=0, keepdim=True)
@@ -215,6 +220,10 @@ def inference(req: ServeTTSRequest):
         else:
             logger.info("Use same references")
 
+    if req.seed is not None:
+        set_seed(req.seed)
+        logger.warning(f"set seed: {req.seed}")
+
     # LLAMA Inference
     request = dict(
         device=decoder_model.device,

+ 1 - 0
tools/commons.py

@@ -20,6 +20,7 @@ class ServeTTSRequest(BaseModel):
     # For example, if you want use https://fish.audio/m/7f92f8afb8ec43bf81429cc1c9199cb1/
     # Just pass 7f92f8afb8ec43bf81429cc1c9199cb1
     reference_id: str | None = None
+    seed: int | None = None
     use_memory_cache: Literal["on-demand", "never"] = "never"
     # Normalize text for en & zh, this increase stability for numbers
     normalize: bool = True

+ 7 - 0
tools/post_api.py

@@ -109,6 +109,12 @@ def parse_args():
         default="never",
         help="Cache encoded references codes in memory",
     )
+    parser.add_argument(
+        "--seed",
+        type=int,
+        default=None,
+        help="None means randomized inference, otherwise deterministic",
+    )
 
     return parser.parse_args()
 
@@ -155,6 +161,7 @@ if __name__ == "__main__":
         "emotion": args.emotion,
         "streaming": args.streaming,
         "use_memory_cache": args.use_memory_cache,
+        "seed": args.seed,
     }
 
     pydantic_data = ServeTTSRequest(**data)

+ 8 - 1
tools/vqgan/extract_vq.py

@@ -24,6 +24,13 @@ OmegaConf.register_new_resolver("eval", eval)
 # This file is used to convert the audio files to text files using the Whisper model.
 # It's mainly used to generate the training data for the VQ model.
 
+backends = torchaudio.list_audio_backends()
+if "sox" in backends:
+    backend = "sox"
+elif "ffmpeg" in backends:
+    backend = "ffmpeg"
+else:
+    backend = "soundfile"
 
 RANK = int(os.environ.get("SLURM_PROCID", 0))
 WORLD_SIZE = int(os.environ.get("SLURM_NTASKS", 1))
@@ -81,7 +88,7 @@ def process_batch(files: list[Path], model) -> float:
     for file in files:
         try:
             wav, sr = torchaudio.load(
-                str(file), backend="sox" if sys.platform == "linux" else "soundfile"
+                str(file), backend=backend
             )  # Need to install libsox-dev
         except Exception as e:
             logger.error(f"Error reading {file}: {e}")

+ 2 - 3
tools/vqgan/inference.py

@@ -24,8 +24,7 @@ def load_model(config_name, checkpoint_path, device="cuda"):
 
     model = instantiate(cfg)
     state_dict = torch.load(
-        checkpoint_path,
-        map_location=device,
+        checkpoint_path, map_location=device, mmap=True, weights_only=True
     )
     if "state_dict" in state_dict:
         state_dict = state_dict["state_dict"]
@@ -37,7 +36,7 @@ def load_model(config_name, checkpoint_path, device="cuda"):
             if "generator." in k
         }
 
-    result = model.load_state_dict(state_dict, strict=False)
+    result = model.load_state_dict(state_dict, strict=False, assign=True)
     model.eval()
     model.to(device)
 

+ 118 - 91
tools/webui.py

@@ -21,8 +21,9 @@ pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
 
 from fish_speech.i18n import i18n
 from fish_speech.text.chn_text_norm.text import Text as ChnNormedText
-from fish_speech.utils import autocast_exclude_mps
+from fish_speech.utils import autocast_exclude_mps, set_seed
 from tools.api import decode_vq_tokens, encode_reference
+from tools.file import AUDIO_EXTENSIONS, list_files
 from tools.llama.generate import (
     GenerateRequest,
     GenerateResponse,
@@ -70,6 +71,7 @@ def inference(
     top_p,
     repetition_penalty,
     temperature,
+    seed="0",
     streaming=False,
 ):
     if args.max_gradio_length > 0 and len(text) > args.max_gradio_length:
@@ -81,6 +83,11 @@ def inference(
             ),
         )
 
+    seed = int(seed)
+    if seed != 0:
+        set_seed(seed)
+        logger.warning(f"set seed: {seed}")
+
     # Parse reference audio aka prompt
     prompt_tokens = encode_reference(
         decoder_model=decoder_model,
@@ -177,6 +184,7 @@ def inference_wrapper(
     top_p,
     repetition_penalty,
     temperature,
+    seed,
     batch_infer_num,
 ):
     audios = []
@@ -193,6 +201,7 @@ def inference_wrapper(
             top_p,
             repetition_penalty,
             temperature,
+            seed,
         )
 
         _, audio_data, error_message = next(result)
@@ -235,7 +244,11 @@ def normalize_text(user_input, use_normalization):
         return user_input
 
 
-asr_model = None
+def update_examples():
+    examples_dir = Path("references")
+    examples_dir.mkdir(parents=True, exist_ok=True)
+    example_audios = list_files(examples_dir, AUDIO_EXTENSIONS, recursive=True)
+    return gr.Dropdown(choices=example_audios + [""])
 
 
 def build_app():
@@ -273,90 +286,100 @@ def build_app():
                     )
 
                 with gr.Row():
-                    with gr.Tab(label=i18n("Advanced Config")):
-                        chunk_length = gr.Slider(
-                            label=i18n("Iterative Prompt Length, 0 means off"),
-                            minimum=50,
-                            maximum=300,
-                            value=200,
-                            step=8,
-                        )
-
-                        max_new_tokens = gr.Slider(
-                            label=i18n("Maximum tokens per batch, 0 means no limit"),
-                            minimum=0,
-                            maximum=2048,
-                            value=0,  # 0 means no limit
-                            step=8,
-                        )
-
-                        top_p = gr.Slider(
-                            label="Top-P",
-                            minimum=0.6,
-                            maximum=0.9,
-                            value=0.7,
-                            step=0.01,
-                        )
-
-                        repetition_penalty = gr.Slider(
-                            label=i18n("Repetition Penalty"),
-                            minimum=1,
-                            maximum=1.5,
-                            value=1.2,
-                            step=0.01,
-                        )
-
-                        temperature = gr.Slider(
-                            label="Temperature",
-                            minimum=0.6,
-                            maximum=0.9,
-                            value=0.7,
-                            step=0.01,
-                        )
-
-                    with gr.Tab(label=i18n("Reference Audio")):
-                        gr.Markdown(
-                            i18n(
-                                "5 to 10 seconds of reference audio, useful for specifying speaker."
-                            )
-                        )
-
-                        enable_reference_audio = gr.Checkbox(
-                            label=i18n("Enable Reference Audio"),
-                        )
-
-                        # Add dropdown for selecting example audio files
-                        examples_dir = Path("examples")
-                        if not examples_dir.exists():
-                            examples_dir.mkdir()
-                        example_audio_files = [
-                            f.name for f in examples_dir.glob("*.wav")
-                        ] + [f.name for f in examples_dir.glob("*.mp3")]
-                        example_audio_dropdown = gr.Dropdown(
-                            label=i18n("Select Example Audio"),
-                            choices=[""] + example_audio_files,
-                            value="",
-                        )
-
-                        reference_audio = gr.Audio(
-                            label=i18n("Reference Audio"),
-                            type="filepath",
-                        )
-                        with gr.Row():
-                            reference_text = gr.Textbox(
-                                label=i18n("Reference Text"),
-                                lines=1,
-                                placeholder="在一无所知中,梦里的一天结束了,一个新的「轮回」便会开始。",
-                                value="",
-                            )
-                    with gr.Tab(label=i18n("Batch Inference")):
-                        batch_infer_num = gr.Slider(
-                            label="Batch infer nums",
-                            minimum=1,
-                            maximum=n_audios,
-                            step=1,
-                            value=1,
-                        )
+                    with gr.Column():
+                        with gr.Tab(label=i18n("Advanced Config")):
+                            with gr.Row():
+                                chunk_length = gr.Slider(
+                                    label=i18n("Iterative Prompt Length, 0 means off"),
+                                    minimum=50,
+                                    maximum=300,
+                                    value=200,
+                                    step=8,
+                                )
+
+                                max_new_tokens = gr.Slider(
+                                    label=i18n(
+                                        "Maximum tokens per batch, 0 means no limit"
+                                    ),
+                                    minimum=0,
+                                    maximum=2048,
+                                    value=0,  # 0 means no limit
+                                    step=8,
+                                )
+
+                            with gr.Row():
+                                top_p = gr.Slider(
+                                    label="Top-P",
+                                    minimum=0.6,
+                                    maximum=0.9,
+                                    value=0.7,
+                                    step=0.01,
+                                )
+
+                                repetition_penalty = gr.Slider(
+                                    label=i18n("Repetition Penalty"),
+                                    minimum=1,
+                                    maximum=1.5,
+                                    value=1.2,
+                                    step=0.01,
+                                )
+
+                            with gr.Row():
+                                temperature = gr.Slider(
+                                    label="Temperature",
+                                    minimum=0.6,
+                                    maximum=0.9,
+                                    value=0.7,
+                                    step=0.01,
+                                )
+                                seed = gr.Textbox(
+                                    label="Seed",
+                                    info="0 means randomized inference, otherwise deterministic",
+                                    placeholder="any 32-bit-integer",
+                                    value="0",
+                                )
+
+                        with gr.Tab(label=i18n("Reference Audio")):
+                            with gr.Row():
+                                gr.Markdown(
+                                    i18n(
+                                        "5 to 10 seconds of reference audio, useful for specifying speaker."
+                                    )
+                                )
+                            with gr.Row():
+                                enable_reference_audio = gr.Checkbox(
+                                    label=i18n("Enable Reference Audio"),
+                                )
+
+                            with gr.Row():
+                                example_audio_dropdown = gr.Dropdown(
+                                    label=i18n("Select Example Audio"),
+                                    choices=[""],
+                                    value="",
+                                    interactive=True,
+                                    allow_custom_value=True,
+                                )
+                            with gr.Row():
+                                reference_audio = gr.Audio(
+                                    label=i18n("Reference Audio"),
+                                    type="filepath",
+                                )
+                            with gr.Row():
+                                reference_text = gr.Textbox(
+                                    label=i18n("Reference Text"),
+                                    lines=1,
+                                    placeholder="在一无所知中,梦里的一天结束了,一个新的「轮回」便会开始。",
+                                    value="",
+                                )
+                        with gr.Tab(label=i18n("Batch Inference")):
+                            with gr.Row():
+                                batch_infer_num = gr.Slider(
+                                    label="Batch infer nums",
+                                    minimum=1,
+                                    maximum=n_audios,
+                                    step=1,
+                                    value=1,
+                                )
 
             with gr.Column(scale=3):
                 for _ in range(n_audios):
@@ -397,10 +420,10 @@ def build_app():
             fn=normalize_text, inputs=[text, if_refine_text], outputs=[refined_text]
         )
 
-        def select_example_audio(audio_file):
-            if audio_file:
-                audio_path = examples_dir / audio_file
-                lab_file = audio_path.with_suffix(".lab")
+        def select_example_audio(audio_path):
+            audio_path = Path(audio_path)
+            if audio_path.is_file():
+                lab_file = Path(audio_path.with_suffix(".lab"))
 
                 if lab_file.exists():
                     lab_content = lab_file.read_text(encoding="utf-8").strip()
@@ -412,6 +435,8 @@ def build_app():
 
         # Connect the dropdown to update reference audio and text
         example_audio_dropdown.change(
+            fn=update_examples, inputs=[], outputs=[example_audio_dropdown]
+        ).then(
             fn=select_example_audio,
             inputs=[example_audio_dropdown],
             outputs=[reference_audio, reference_text, enable_reference_audio],
@@ -430,6 +455,7 @@ def build_app():
                 top_p,
                 repetition_penalty,
                 temperature,
+                seed,
                 batch_infer_num,
             ],
             [stream_audio, *global_audio_list, *global_error_list],
@@ -448,9 +474,10 @@ def build_app():
                 top_p,
                 repetition_penalty,
                 temperature,
+                seed,
             ],
             [stream_audio, global_audio_list[0], global_error_list[0]],
-            concurrency_limit=10,
+            concurrency_limit=1,
         )
     return app