Просмотр исходного кода

feat(demo): Add Colab badge and ensure output directory is created (#844)

* feat(demo): Add Colab Demo link and ensure output directory creation

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update README.md

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Deep Chavda 1 год назад
Родитель
Сommit
91bf150031
2 измененных файлов с 8 добавлено и 5 удалено
  1. 2 2
      README.md
  2. 6 3
      fish_speech/models/text2semantic/inference.py

+ 2 - 2
README.md

@@ -78,9 +78,9 @@ We do not hold any responsibility for any illegal usage of the codebase. Please
 
 [Fish Agent](https://fish.audio/demo/live)
 
-## Quick Start for Local Inference
+## Quick Start for Local Inference 
 
-[inference.ipynb](/inference.ipynb)
+[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://github.com/fishaudio/fish-speech/blob/main/inference.ipynb)
 
 ## Videos
 

+ 6 - 3
fish_speech/models/text2semantic/inference.py

@@ -1026,6 +1026,7 @@ def launch_thread_safe_queue_agent(
 @click.option("--half/--no-half", default=False)
 @click.option("--iterative-prompt/--no-iterative-prompt", default=True)
 @click.option("--chunk-length", type=int, default=100)
+@click.option("--output-dir", type=Path, default="temp")
 def main(
     text: str,
     prompt_text: Optional[list[str]],
@@ -1042,8 +1043,9 @@ def main(
     half: bool,
     iterative_prompt: bool,
     chunk_length: int,
+    output_dir: Path,
 ) -> None:
-
+    os.makedirs(output_dir, exist_ok=True)
     precision = torch.half if half else torch.bfloat16
 
     if prompt_text is not None and len(prompt_text) != len(prompt_tokens):
@@ -1101,8 +1103,9 @@ def main(
             logger.info(f"Sampled text: {response.text}")
         elif response.action == "next":
             if codes:
-                np.save(f"codes_{idx}.npy", torch.cat(codes, dim=1).cpu().numpy())
-                logger.info(f"Saved codes to codes_{idx}.npy")
+                codes_npy_path = os.path.join(output_dir, f"codes_{idx}.npy")
+                np.save(codes_npy_path, torch.cat(codes, dim=1).cpu().numpy())
+                logger.info(f"Saved codes to {codes_npy_path}")
             logger.info(f"Next sample")
             codes = []
             idx += 1