瀏覽代碼

fix webui

Lengyue 10 月之前
父節點
當前提交
ef1af20791

+ 4 - 2
fish_speech/inference_engine/__init__.py

@@ -65,7 +65,10 @@ class TTSInferenceEngine(ReferenceLoader, VQManager):
         response_queue = self.send_Llama_request(req, prompt_tokens, prompt_texts)
 
         # Get the sample rate from the decoder model
-        sample_rate = self.decoder_model.spec_transform.sample_rate
+        if hasattr(self.decoder_model, "spec_transform"):
+            sample_rate = self.decoder_model.spec_transform.sample_rate
+        else:
+            sample_rate = self.decoder_model.sample_rate
 
         # If streaming, send the header
         if req.streaming:
@@ -156,7 +159,6 @@ class TTSInferenceEngine(ReferenceLoader, VQManager):
             compile=self.compile,
             iterative_prompt=req.chunk_length > 0,
             chunk_length=req.chunk_length,
-            max_length=4096,
             prompt_tokens=prompt_tokens,
             prompt_text=prompt_texts,
         )

+ 6 - 4
fish_speech/inference_engine/vq_manager.py

@@ -30,9 +30,11 @@ class VQManager:
     def encode_reference(self, reference_audio, enable_reference_audio):
         if enable_reference_audio and reference_audio is not None:
             # Load audios, and prepare basic info here
-            reference_audio_content = self.load_audio(
-                reference_audio, self.decoder_model.spec_transform.sample_rate
-            )
+            if hasattr(self.decoder_model, "spec_transform"):
+                sample_rate = self.decoder_model.spec_transform.sample_rate
+            else:
+                sample_rate = self.decoder_model.sample_rate
+            reference_audio_content = self.load_audio(reference_audio, sample_rate)
 
             audios = torch.from_numpy(reference_audio_content).to(
                 self.decoder_model.device
@@ -41,7 +43,7 @@ class VQManager:
                 [audios.shape[2]], device=self.decoder_model.device, dtype=torch.long
             )
             logger.info(
-                f"Loaded audio with {audios.shape[2] / self.decoder_model.spec_transform.sample_rate:.2f} seconds"
+                f"Loaded audio with {audios.shape[2] / sample_rate:.2f} seconds"
             )
 
             # VQ Encoder

+ 12 - 4
fish_speech/models/text2semantic/inference.py

@@ -322,9 +322,9 @@ def generate_long(
     text: str,
     num_samples: int = 1,
     max_new_tokens: int = 0,
-    top_p: int = 0.7,
-    repetition_penalty: float = 1.5,
-    temperature: float = 0.7,
+    top_p: int = 0.8,
+    repetition_penalty: float = 1.1,
+    temperature: float = 0.8,
     compile: bool = False,
     iterative_prompt: bool = True,
     chunk_length: int = 150,
@@ -344,6 +344,8 @@ def generate_long(
         prompt_tokens
     ), "Prompt text and tokens must have the same length"
 
+    prompt_tokens = [i.cpu() for i in prompt_tokens]
+
     model_size = sum(p.numel() for p in model.parameters() if p.requires_grad)
     tokenizer = model.tokenizer
     base_content_sequence = ContentSequence(modality="interleave")
@@ -423,7 +425,13 @@ def generate_long(
             #     partial_encoded = global_encoded
 
             # cat_encoded = torch.cat([encoded_prompts, *partial_encoded], dim=1)
-            cat_encoded = torch.cat([encoded_prompts, seg], dim=1)
+            if len(base_content_sequence.parts) <= 1 and len(global_encoded) >= 2:
+                cat_encoded = torch.cat(
+                    [encoded_prompts, global_encoded[0], global_encoded[1], seg], dim=1
+                )
+            else:
+                cat_encoded = torch.cat([encoded_prompts, seg], dim=1)
+
             cat_encoded = cat_encoded.to(device=device)
             prompt_length = cat_encoded.size(1)
 

+ 3 - 21
fish_speech/utils/schema.py

@@ -27,13 +27,6 @@ class ServeAudioPart(BaseModel):
     audio: bytes
 
 
-@dataclass
-class ASRPackRequest:
-    audio: torch.Tensor
-    result_queue: queue.Queue
-    language: str
-
-
 class ServeASRRequest(BaseModel):
     # The audio should be an uncompressed PCM float16 audio
     audios: list[bytes]
@@ -93,17 +86,6 @@ class ServeVQGANDecodeResponse(BaseModel):
     audios: list[bytes]
 
 
-class ServeContentSequenceParts(BaseModel):
-    parts: list[VQPart | TextPart]
-
-
-class ServeResponse(BaseModel):
-    content_sequences: list[ServeContentSequenceParts]
-    finish_reason: Literal["stop", "error"] | None = None
-    stats: dict[str, int | float | str] = {}
-    finished: list[bool] | None = None
-
-
 class ServeStreamDelta(BaseModel):
     role: Literal["system", "assistant", "user"] | None = None
     part: ServeVQPart | ServeTextPart | None = None
@@ -155,9 +137,9 @@ class ServeTTSRequest(BaseModel):
     # not usually used below
     streaming: bool = False
     max_new_tokens: int = 1024
-    top_p: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7
-    repetition_penalty: Annotated[float, Field(ge=0.9, le=2.0, strict=True)] = 1.2
-    temperature: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7
+    top_p: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.8
+    repetition_penalty: Annotated[float, Field(ge=0.9, le=2.0, strict=True)] = 1.1
+    temperature: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.8
 
     class Config:
         # Allow arbitrary types for pytorch related types

+ 4 - 4
tools/api_client.py

@@ -82,19 +82,19 @@ def parse_args():
         help="Maximum new tokens to generate. \n0 means no limit.",
     )
     parser.add_argument(
-        "--chunk_length", type=int, default=200, help="Chunk length for synthesis"
+        "--chunk_length", type=int, default=300, help="Chunk length for synthesis"
     )
     parser.add_argument(
-        "--top_p", type=float, default=0.7, help="Top-p sampling for synthesis"
+        "--top_p", type=float, default=0.8, help="Top-p sampling for synthesis"
     )
     parser.add_argument(
         "--repetition_penalty",
         type=float,
-        default=1.2,
+        default=1.1,
         help="Repetition penalty for synthesis",
     )
     parser.add_argument(
-        "--temperature", type=float, default=0.7, help="Temperature for sampling"
+        "--temperature", type=float, default=0.8, help="Temperature for sampling"
     )
 
     parser.add_argument(

+ 2 - 2
tools/run_webui.py

@@ -29,9 +29,9 @@ def parse_args():
     parser.add_argument(
         "--decoder-checkpoint-path",
         type=Path,
-        default="checkpoints/openaudio-s1-mini/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
+        default="checkpoints/openaudio-s1-mini/codec.pth",
     )
-    parser.add_argument("--decoder-config-name", type=str, default="firefly_gan_vq")
+    parser.add_argument("--decoder-config-name", type=str, default="modded_dac_vq")
     parser.add_argument("--device", type=str, default="cuda")
     parser.add_argument("--half", action="store_true")
     parser.add_argument("--compile", action="store_true")

+ 1 - 1
tools/server/api_utils.py

@@ -25,7 +25,7 @@ def parse_args():
         type=str,
         default="checkpoints/openaudio-s1-mini/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
     )
-    parser.add_argument("--decoder-config-name", type=str, default="firefly_gan_vq")
+    parser.add_argument("--decoder-config-name", type=str, default="modded_dac_vq")
     parser.add_argument("--device", type=str, default="cuda")
     parser.add_argument("--half", action="store_true")
     parser.add_argument("--compile", action="store_true")

+ 2 - 2
tools/vqgan/extract_vq.py

@@ -46,7 +46,7 @@ logger.add(sys.stderr, format=logger_format)
 
 @lru_cache(maxsize=1)
 def get_model(
-    config_name: str = "firefly_gan_vq",
+    config_name: str = "modded_dac_vq",
     checkpoint_path: str = "checkpoints/openaudio-s1-mini/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
     device: str | torch.device = "cuda",
 ):
@@ -135,7 +135,7 @@ def process_batch(files: list[Path], model) -> float:
 @click.command()
 @click.argument("folder")
 @click.option("--num-workers", default=1)
-@click.option("--config-name", default="firefly_gan_vq")
+@click.option("--config-name", default="modded_dac_vq")
 @click.option(
     "--checkpoint-path",
     default="checkpoints/openaudio-s1-mini/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",

+ 11 - 11
tools/webui/__init__.py

@@ -31,9 +31,9 @@ def build_app(inference_fct: Callable, theme: str = "light") -> gr.Blocks:
                             with gr.Row():
                                 chunk_length = gr.Slider(
                                     label=i18n("Iterative Prompt Length, 0 means off"),
-                                    minimum=0,
-                                    maximum=300,
-                                    value=200,
+                                    minimum=100,
+                                    maximum=400,
+                                    value=300,
                                     step=8,
                                 )
 
@@ -50,26 +50,26 @@ def build_app(inference_fct: Callable, theme: str = "light") -> gr.Blocks:
                             with gr.Row():
                                 top_p = gr.Slider(
                                     label="Top-P",
-                                    minimum=0.6,
-                                    maximum=0.9,
-                                    value=0.7,
+                                    minimum=0.7,
+                                    maximum=0.95,
+                                    value=0.8,
                                     step=0.01,
                                 )
 
                                 repetition_penalty = gr.Slider(
                                     label=i18n("Repetition Penalty"),
                                     minimum=1,
-                                    maximum=1.5,
-                                    value=1.2,
+                                    maximum=1.2,
+                                    value=1.1,
                                     step=0.01,
                                 )
 
                             with gr.Row():
                                 temperature = gr.Slider(
                                     label="Temperature",
-                                    minimum=0.6,
-                                    maximum=0.9,
-                                    value=0.7,
+                                    minimum=0.7,
+                                    maximum=1.0,
+                                    value=0.8,
                                     step=0.01,
                                 )
                                 seed = gr.Number(