Kaynağa Gözat

fix: tts reference and config (#713)

* fix: tts reference and config

* fix streaming

* fix bug

* fix wav header
spicysama 1 yıl önce
ebeveyn
işleme
b11bcf834a

+ 1 - 1
pyproject.toml

@@ -46,13 +46,13 @@ dependencies = [
     "ormsgpack",
     "tiktoken>=0.8.0",
     "pydantic==2.9.2",
+    "cachetools",
 ]
 
 [project.optional-dependencies]
 stable = [
     "torch<=2.4.1",
     "torchaudio",
-    "cachetools",
 ]
 
 [build-system]

+ 1 - 1
tools/api_client.py

@@ -79,7 +79,7 @@ def parse_args():
     parser.add_argument(
         "--max_new_tokens",
         type=int,
-        default=0,
+        default=1024,
         help="Maximum new tokens to generate. \n0 means no limit.",
     )
     parser.add_argument(

+ 5 - 5
tools/download_models.py

@@ -22,14 +22,14 @@ def check_and_download_files(repo_id, file_list, local_dir):
 
 
 
 
 # 1st
-repo_id_1 = "fishaudio/fish-speech-1.4"
-local_dir_1 = "./checkpoints/fish-speech-1.4"
+repo_id_1 = "fishaudio/fish-speech-1.5"
+local_dir_1 = "./checkpoints/fish-speech-1.5"
 files_1 = [
+    "gitattributes",
     "model.pth",
     "README.md",
-    "special_tokens_map.json",
-    "tokenizer_config.json",
-    "tokenizer.json",
+    "special_tokens.json",
+    "tokenizer.tiktoken",
     "config.json",
     "firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
 ]

+ 1 - 2
tools/inference_engine/__init__.py

@@ -109,8 +109,7 @@ class TTSInferenceEngine(ReferenceLoader, VQManager):
                         audio=(sample_rate, segment),
                         error=None,
                     )
-                else:
-                    segments.append(segment)
+                segments.append(segment)
             else:
                 break
 

+ 0 - 1
tools/inference_engine/reference_loader.py

@@ -85,7 +85,6 @@ class ReferenceLoader:
                 # If the references are not already loaded, encode them
                 prompt_tokens.append(
                     self.encode_reference(
-                        decoder_model=self.decoder_model,
                         reference_audio=ref.audio,
                         enable_reference_audio=True,
                     )

+ 3 - 6
tools/inference_engine/utils.py

@@ -11,7 +11,7 @@ from fish_speech.text.chn_text_norm.text import Text as ChnNormedText
 @dataclass
 class InferenceResult:
     code: Literal["header", "segment", "error", "final"]
-    audio: Optional[Tuple[int, np.ndarray]]
+    audio: Optional[Tuple[int, np.ndarray | bytes]]
     error: Optional[Exception]
 
 
@@ -25,7 +25,7 @@ def normalize_text(user_input: str, use_normalization: bool) -> str:
 
 def wav_chunk_header(
     sample_rate: int = 44100, bit_depth: int = 16, channels: int = 1
-) -> np.ndarray:
+) -> bytes:
     buffer = io.BytesIO()
 
     with wave.open(buffer, "wb") as wav_file:
@@ -36,7 +36,4 @@ def wav_chunk_header(
     wav_header_bytes = buffer.getvalue()
     buffer.close()
 
-    # Convert to numpy array
-    wav_header = np.frombuffer(wav_header_bytes, dtype=np.uint8)
-
-    return wav_header
+    return wav_header_bytes

+ 1 - 1
tools/run_webui.py

@@ -87,7 +87,7 @@ if __name__ == "__main__":
                 text="Hello world.",
                 references=[],
                 reference_id=None,
-                max_new_tokens=0,
+                max_new_tokens=1024,
                 chunk_length=200,
                 top_p=0.7,
                 repetition_penalty=1.5,

+ 8 - 4
tools/server/inference.py

@@ -14,6 +14,7 @@ def inference_wrapper(req: ServeTTSRequest, engine: TTSInferenceEngine):
     Wrapper for the inference function.
     Used in the API server.
     """
+    count = 0
     for result in engine.inference(req):
         match result.code:
             case "header":
@@ -27,15 +28,18 @@ def inference_wrapper(req: ServeTTSRequest, engine: TTSInferenceEngine):
                 )
 
             case "segment":
+                count += 1
                 if isinstance(result.audio, tuple):
                     yield (result.audio[1] * AMPLITUDE).astype(np.int16).tobytes()
 
             case "final":
+                count += 1
                 if isinstance(result.audio, tuple):
                     yield result.audio[1]
                 return None  # Stop the generator
 
-    raise HTTPException(
-        HTTPStatus.INTERNAL_SERVER_ERROR,
-        content="No audio generated, please check the input text.",
-    )
+    if count == 0:
+        raise HTTPException(
+            HTTPStatus.INTERNAL_SERVER_ERROR,
+            content="No audio generated, please check the input text.",
+        )

+ 2 - 2
tools/server/model_manager.py

@@ -113,10 +113,10 @@ class ModelManager:
             text="Hello world.",
             references=[],
             reference_id=None,
-            max_new_tokens=0,
+            max_new_tokens=1024,
             chunk_length=200,
             top_p=0.7,
-            repetition_penalty=1.5,
+            repetition_penalty=1.2,
             temperature=0.7,
             format="wav",
         )
         )