
fix webui & api for fs 1.2

Lengyue 1 year ago
parent
commit
c7d9e3fcaa
4 changed files with 25 additions and 205 deletions
  1. tools/api.py (+9 −28)
  2. tools/llama/generate.py (+2 −6)
  3. tools/vits_decoder/inference.py (+0 −153)
  4. tools/webui.py (+14 −18)

+ 9 - 28
tools/api.py

@@ -3,7 +3,6 @@ import io
 import json
 import queue
 import random
-import threading
 import traceback
 import wave
 from argparse import ArgumentParser
@@ -18,7 +17,6 @@ import soundfile as sf
 import torch
 from kui.asgi import (
     Body,
-    FileResponse,
     HTTPException,
     HttpView,
     JSONResponse,
@@ -29,7 +27,6 @@ from kui.asgi import (
 from kui.asgi.routing import MultimethodRoutes
 from loguru import logger
 from pydantic import BaseModel, Field
-from transformers import AutoTokenizer
 
 pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
 
@@ -99,23 +96,19 @@ def encode_reference(*, decoder_model, reference_audio, enable_reference_audio):
         # VQ Encoder
         if isinstance(decoder_model, FireflyArchitecture):
             prompt_tokens = decoder_model.encode(audios, audio_lengths)[0][0]
-            reference_embedding = None  # VQGAN does not have reference embedding
 
         logger.info(f"Encoded prompt: {prompt_tokens.shape}")
     else:
         prompt_tokens = None
-        reference_embedding = None
         logger.info("No reference audio provided")
 
-    return prompt_tokens, reference_embedding
+    return prompt_tokens
 
 
 def decode_vq_tokens(
     *,
     decoder_model,
     codes,
-    text_tokens: torch.Tensor | None = None,
-    reference_embedding: torch.Tensor | None = None,
 ):
     feature_lengths = torch.tensor([codes.shape[1]], device=decoder_model.device)
     logger.info(f"VQ features: {codes.shape}")
@@ -172,17 +165,17 @@ class InvokeRequest(BaseModel):
     text: str = "你说的对, 但是原神是一款由米哈游自主研发的开放世界手游."
     reference_text: Optional[str] = None
     reference_audio: Optional[str] = None
-    max_new_tokens: int = 0
-    chunk_length: Annotated[int, Field(ge=0, le=500, strict=True)] = 150
+    max_new_tokens: int = 1024
+    chunk_length: Annotated[int, Field(ge=0, le=500, strict=True)] = 100
     top_p: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7
-    repetition_penalty: Annotated[float, Field(ge=0.9, le=2.0, strict=True)] = 1.5
+    repetition_penalty: Annotated[float, Field(ge=0.9, le=2.0, strict=True)] = 1.2
     temperature: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7
-    speaker: Optional[str] = None
     emotion: Optional[str] = None
     format: Literal["wav", "mp3", "flac"] = "wav"
     streaming: bool = False
     ref_json: Optional[str] = "ref_data.json"
     ref_base: Optional[str] = "ref_data"
+    speaker: Optional[str] = None
 
 
 def get_content_type(audio_format):
@@ -217,7 +210,7 @@ def inference(req: InvokeRequest):
         logger.info("ref_text: " + ref_text)
 
     # Parse reference audio aka prompt
-    prompt_tokens, reference_embedding = encode_reference(
+    prompt_tokens = encode_reference(
         decoder_model=decoder_model,
         reference_audio=(
             io.BytesIO(base64.b64decode(req.reference_audio))
@@ -229,7 +222,6 @@ def inference(req: InvokeRequest):
 
     # LLAMA Inference
     request = dict(
-        tokenizer=llama_tokenizer,
         device=decoder_model.device,
         max_new_tokens=req.max_new_tokens,
         text=req.text,
@@ -240,7 +232,6 @@ def inference(req: InvokeRequest):
         iterative_prompt=req.chunk_length > 0,
         chunk_length=req.chunk_length,
         max_length=2048,
-        speaker=req.speaker,
         prompt_tokens=prompt_tokens,
         prompt_text=req.reference_text,
     )
@@ -267,18 +258,12 @@ def inference(req: InvokeRequest):
         if result.action == "next":
             break
 
-        text_tokens = llama_tokenizer.encode(result.text, return_tensors="pt").to(
-            decoder_model.device
-        )
-
         with torch.autocast(
             device_type=decoder_model.device.type, dtype=args.precision
         ):
             fake_audios = decode_vq_tokens(
                 decoder_model=decoder_model,
                 codes=result.codes,
-                text_tokens=text_tokens,
-                reference_embedding=reference_embedding,
             )
 
         fake_audios = fake_audios.float().cpu().numpy()
@@ -379,7 +364,6 @@ def parse_args():
         default="checkpoints/fish-speech-1.2/firefly-gan-vq-fsq-4x1024-42hz-generator.pth",
     )
     parser.add_argument("--decoder-config-name", type=str, default="firefly_gan_vq")
-    parser.add_argument("--tokenizer", type=str, default="fishaudio/fish-speech-1")
     parser.add_argument("--device", type=str, default="cuda")
     parser.add_argument("--half", action="store_true")
     parser.add_argument("--compile", action="store_true")
@@ -422,7 +406,6 @@ if __name__ == "__main__":
         precision=args.precision,
         compile=args.compile,
     )
-    llama_tokenizer = AutoTokenizer.from_pretrained(args.tokenizer)
     logger.info("Llama model loaded, loading VQ-GAN model...")
 
     decoder_model = load_decoder_model(
@@ -437,15 +420,13 @@ if __name__ == "__main__":
     list(
         inference(
             InvokeRequest(
-                text="A warm-up sentence.",
+                text="Hello world.",
                 reference_text=None,
                 reference_audio=None,
-                max_new_tokens=0,
-                chunk_length=150,
+                max_new_tokens=1024,
                 top_p=0.7,
-                repetition_penalty=1.5,
+                repetition_penalty=1.2,
                 temperature=0.7,
-                speaker=None,
                 emotion=None,
                 format="wav",
                 ref_base=None,
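
With this change the API no longer loads a HuggingFace tokenizer, and encode_reference returns only prompt_tokens. A minimal client sketch against the updated InvokeRequest schema follows; the endpoint URL and port are assumptions (the route is not shown in this diff), while the field names and new defaults come straight from the model above.

import requests

# Assumed endpoint; the route and port in tools/api.py are not visible in this diff.
API_URL = "http://127.0.0.1:8000/v1/invoke"

payload = {
    "text": "Hello world.",
    "reference_text": None,
    "reference_audio": None,    # base64-encoded audio string when provided
    "max_new_tokens": 1024,     # new default (was 0)
    "chunk_length": 100,        # new default (was 150)
    "top_p": 0.7,
    "repetition_penalty": 1.2,  # new default (was 1.5)
    "temperature": 0.7,
    "emotion": None,
    "format": "wav",
    "streaming": False,
    "speaker": None,            # still accepted, now last in the schema
}

resp = requests.post(API_URL, json=payload)
resp.raise_for_status()
with open("generated.wav", "wb") as f:
    f.write(resp.content)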

+ 2 - 6
tools/llama/generate.py

@@ -1,6 +1,5 @@
 import os
 import queue
-import string
 import threading
 import time
 from dataclasses import dataclass
@@ -13,11 +12,8 @@ import numpy as np
 import torch
 import torch._dynamo.config
 import torch._inductor.config
-from hydra import compose, initialize
-from hydra.utils import instantiate
 from loguru import logger
 from tqdm import tqdm
-from transformers import AutoTokenizer
 
 from fish_speech.conversation import CODEBOOK_PAD_TOKEN_ID
 from fish_speech.text import clean_text, split_text
@@ -618,7 +614,7 @@ def launch_thread_safe_queue(
 @click.option("--num-samples", type=int, default=1)
 @click.option("--max-new-tokens", type=int, default=0)
 @click.option("--top-p", type=float, default=0.7)
-@click.option("--repetition-penalty", type=float, default=1.5)
+@click.option("--repetition-penalty", type=float, default=1.2)
 @click.option("--temperature", type=float, default=0.7)
 @click.option(
     "--checkpoint-path",
@@ -629,7 +625,7 @@ def launch_thread_safe_queue(
 @click.option("--seed", type=int, default=42)
 @click.option("--half/--no-half", default=False)
 @click.option("--iterative-prompt/--no-iterative-prompt", default=True)
-@click.option("--chunk-length", type=int, default=150)
+@click.option("--chunk-length", type=int, default=100)
 def main(
     text: str,
     prompt_text: Optional[list[str]],
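
The CLI defaults here track the API: --repetition-penalty drops from 1.5 to 1.2 and --chunk-length from 150 to 100. A quick programmatic sketch using click's test runner, assuming a --text option exists on this command (it is not visible in this hunk):

from click.testing import CliRunner

from tools.llama.generate import main

runner = CliRunner()
result = runner.invoke(
    main,
    [
        "--text", "Hello world.",       # assumed option name, not shown in this hunk
        "--repetition-penalty", "1.2",  # matches the new default
        "--chunk-length", "100",        # matches the new default
        "--top-p", "0.7",
        "--temperature", "0.7",
    ],
)
print(result.output)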

+ 0 - 153
tools/vits_decoder/inference.py

@@ -1,153 +0,0 @@
-from pathlib import Path
-
-import click
-import hydra
-import librosa
-import numpy as np
-import soundfile as sf
-import torch
-from hydra import compose, initialize
-from hydra.utils import instantiate
-from lightning import LightningModule
-from loguru import logger
-from omegaconf import OmegaConf
-from transformers import AutoTokenizer
-
-from fish_speech.utils.file import AUDIO_EXTENSIONS
-
-# register eval resolver
-OmegaConf.register_new_resolver("eval", eval)
-
-
-def load_model(config_name, checkpoint_path, device="cuda"):
-    hydra.core.global_hydra.GlobalHydra.instance().clear()
-    with initialize(version_base="1.3", config_path="../../fish_speech/configs"):
-        cfg = compose(config_name=config_name)
-
-    model: LightningModule = instantiate(cfg.model)
-    state_dict = torch.load(
-        checkpoint_path,
-        map_location=model.device,
-    )
-
-    if "state_dict" in state_dict:
-        state_dict = state_dict["state_dict"]
-
-    model.load_state_dict(state_dict, strict=False)
-    model.eval()
-    model.to(device)
-    logger.info("Restored model from checkpoint")
-
-    return model
-
-
-@torch.no_grad()
-@click.command()
-@click.option(
-    "--input-path",
-    "-i",
-    default="test.npy",
-    type=click.Path(exists=True, path_type=Path),
-)
-@click.option(
-    "--reference-path",
-    "-r",
-    type=click.Path(exists=True, path_type=Path),
-    default=None,
-)
-@click.option(
-    "--text",
-    type=str,
-    default="-",
-)
-@click.option(
-    "--tokenizer",
-    type=str,
-    default="fishaudio/fish-speech-1",
-)
-@click.option(
-    "--output-path", "-o", default="fake.wav", type=click.Path(path_type=Path)
-)
-@click.option("--config-name", "-cfg", default="vits_decoder_finetune")
-@click.option(
-    "--checkpoint-path",
-    "-ckpt",
-    default="checkpoints/fish-speech-1.2/firefly-gan-vq-fsq-4x1024-42hz-generator.pth",
-)
-@click.option(
-    "--device",
-    "-d",
-    default="cuda",
-)
-def main(
-    input_path,
-    reference_path,
-    text,
-    tokenizer,
-    output_path,
-    config_name,
-    checkpoint_path,
-    device,
-):
-    model = load_model(config_name, checkpoint_path, device=device)
-
-    assert input_path.suffix == ".npy", f"Expected .npy file, got {input_path.suffix}"
-
-    logger.info(f"Processing precomputed indices from {input_path}")
-    indices = np.load(input_path)
-    indices = torch.from_numpy(indices).to(model.device).long()
-    assert indices.ndim == 2, f"Expected 2D indices, got {indices.ndim}"
-
-    # Extract reference audio
-    if reference_path is not None:
-        assert (
-            reference_path.suffix in AUDIO_EXTENSIONS
-        ), f"Expected audio file, got {reference_path.suffix}"
-        reference_audio, sr = librosa.load(reference_path, sr=model.sampling_rate)
-        reference_audio = torch.from_numpy(reference_audio).to(model.device).float()
-        reference_spec = model.spec_transform(reference_audio[None])
-        reference_embedding = model.generator.encode_ref(
-            reference_spec,
-            torch.tensor([reference_spec.shape[-1]], device=model.device),
-        )
-        logger.info(
-            f"Loaded reference audio from {reference_path}, shape: {reference_audio.shape}"
-        )
-    else:
-        reference_embedding = torch.zeros(
-            1, model.generator.gin_channels, 1, device=model.device
-        )
-        logger.info("No reference audio provided, use zero embedding")
-
-    # Extract text
-    tokenizer = AutoTokenizer.from_pretrained(tokenizer)
-    encoded_text = tokenizer(text, return_tensors="pt").input_ids.to(model.device)
-    logger.info(f"Encoded text: {encoded_text.shape}")
-
-    # Restore
-    feature_lengths = torch.tensor([indices.shape[1]], device=model.device)
-    quantized = model.generator.vq.indicies_to_vq_features(
-        indices=indices[None], feature_lengths=feature_lengths
-    )
-    logger.info(f"Restored VQ features: {quantized.shape}")
-
-    # Decode
-    fake_audios = model.generator.decode(
-        quantized,
-        torch.tensor([quantized.shape[-1]], device=model.device),
-        encoded_text,
-        torch.tensor([encoded_text.shape[-1]], device=model.device),
-        ge=reference_embedding,
-    )
-    logger.info(
-        f"Generated audio: {fake_audios.shape}, equivalent to {fake_audios.shape[-1] / model.sampling_rate:.2f} seconds"
-    )
-
-    # Save audio
-    fake_audio = fake_audios[0, 0].float().cpu().numpy()
-    sf.write(output_path, fake_audio, model.sampling_rate)
-    logger.info(f"Saved audio to {output_path}")
-
-
-if __name__ == "__main__":
-    main()

+ 14 - 18
tools/webui.py

@@ -80,7 +80,7 @@ def inference(
         )
 
     # Parse reference audio aka prompt
-    prompt_tokens, reference_embedding = encode_reference(
+    prompt_tokens = encode_reference(
         decoder_model=decoder_model,
         reference_audio=reference_audio,
         enable_reference_audio=enable_reference_audio,
@@ -125,10 +125,6 @@ def inference(
         if result.action == "next":
             break
 
-        text_tokens = llama_tokenizer.encode(result.text, return_tensors="pt").to(
-            decoder_model.device
-        )
-
         with torch.autocast(
             device_type=(
                 "cpu"
@@ -140,8 +136,6 @@ def inference(
             fake_audios = decode_vq_tokens(
                 decoder_model=decoder_model,
                 codes=result.codes,
-                text_tokens=text_tokens,
-                reference_embedding=reference_embedding,
             )
 
         fake_audios = fake_audios.float().cpu().numpy()
@@ -287,7 +281,7 @@ def build_app():
                             label=i18n("Iterative Prompt Length, 0 means off"),
                             minimum=0,
                             maximum=500,
-                            value=150,
+                            value=100,
                             step=8,
                         )
 
@@ -295,26 +289,30 @@ def build_app():
                             label=i18n("Maximum tokens per batch, 0 means no limit"),
                             minimum=0,
                             maximum=2048,
-                            value=0,  # 0 means no limit
+                            value=1024,  # 0 means no limit
                             step=8,
                         )
 
                         top_p = gr.Slider(
-                            label="Top-P", minimum=0, maximum=1, value=0.7, step=0.01
+                            label="Top-P",
+                            minimum=0.6,
+                            maximum=0.9,
+                            value=0.7,
+                            step=0.01,
                         )
 
                         repetition_penalty = gr.Slider(
                             label=i18n("Repetition Penalty"),
-                            minimum=0,
-                            maximum=2,
-                            value=1.5,
+                            minimum=1,
+                            maximum=1.5,
+                            value=1.2,
                             step=0.01,
                         )
 
                         temperature = gr.Slider(
                             label="Temperature",
-                            minimum=0,
-                            maximum=2,
+                            minimum=0.6,
+                            maximum=0.9,
                             value=0.7,
                             step=0.01,
                         )
@@ -438,7 +436,6 @@ def parse_args():
         default="checkpoints/fish-speech-1.2/firefly-gan-vq-fsq-4x1024-42hz-generator.pth",
     )
     parser.add_argument("--decoder-config-name", type=str, default="firefly_gan_vq")
-    parser.add_argument("--tokenizer", type=str, default="fishaudio/fish-speech-1")
     parser.add_argument("--device", type=str, default="cuda")
     parser.add_argument("--half", action="store_true")
     parser.add_argument("--compile", action="store_true")
@@ -458,7 +455,6 @@ if __name__ == "__main__":
         precision=args.precision,
         compile=args.compile,
     )
-    llama_tokenizer = AutoTokenizer.from_pretrained(args.tokenizer)
     logger.info("Llama model loaded, loading VQ-GAN model...")
 
     decoder_model = load_decoder_model(
@@ -479,7 +475,7 @@ if __name__ == "__main__":
             max_new_tokens=0,
             chunk_length=100,
             top_p=0.7,
-            repetition_penalty=1.5,
+            repetition_penalty=1.2,
             temperature=0.7,
         )
     )
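
Since the --tokenizer argument is gone from both tools/api.py and tools/webui.py, launch scripts that still pass it will now fail with an "unrecognized arguments" error from argparse. Below is a minimal launch sketch using only flags visible in this diff, with the decoder checkpoint path taken from the shipped default:

import subprocess

# Launch the Gradio UI with the fish-speech 1.2 decoder defaults.
# All flags below appear in the parse_args() shown above; the removed
# --tokenizer flag must no longer be passed.
subprocess.run(
    [
        "python", "tools/webui.py",
        "--decoder-checkpoint-path",
        "checkpoints/fish-speech-1.2/firefly-gan-vq-fsq-4x1024-42hz-generator.pth",
        "--decoder-config-name", "firefly_gan_vq",
        "--device", "cuda",
        "--compile",
    ],
    check=True,
)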