|
@@ -3,7 +3,6 @@ import io
|
|
|
import json
|
|
import json
|
|
|
import queue
|
|
import queue
|
|
|
import random
|
|
import random
|
|
|
-import threading
|
|
|
|
|
import traceback
|
|
import traceback
|
|
|
import wave
|
|
import wave
|
|
|
from argparse import ArgumentParser
|
|
from argparse import ArgumentParser
|
|
@@ -18,7 +17,6 @@ import soundfile as sf
|
|
|
import torch
|
|
import torch
|
|
|
from kui.asgi import (
|
|
from kui.asgi import (
|
|
|
Body,
|
|
Body,
|
|
|
- FileResponse,
|
|
|
|
|
HTTPException,
|
|
HTTPException,
|
|
|
HttpView,
|
|
HttpView,
|
|
|
JSONResponse,
|
|
JSONResponse,
|
|
@@ -29,7 +27,6 @@ from kui.asgi import (
|
|
|
from kui.asgi.routing import MultimethodRoutes
|
|
from kui.asgi.routing import MultimethodRoutes
|
|
|
from loguru import logger
|
|
from loguru import logger
|
|
|
from pydantic import BaseModel, Field
|
|
from pydantic import BaseModel, Field
|
|
|
-from transformers import AutoTokenizer
|
|
|
|
|
|
|
|
|
|
pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
|
|
pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
|
|
|
|
|
|
|
@@ -99,23 +96,19 @@ def encode_reference(*, decoder_model, reference_audio, enable_reference_audio):
|
|
|
# VQ Encoder
|
|
# VQ Encoder
|
|
|
if isinstance(decoder_model, FireflyArchitecture):
|
|
if isinstance(decoder_model, FireflyArchitecture):
|
|
|
prompt_tokens = decoder_model.encode(audios, audio_lengths)[0][0]
|
|
prompt_tokens = decoder_model.encode(audios, audio_lengths)[0][0]
|
|
|
- reference_embedding = None # VQGAN does not have reference embedding
|
|
|
|
|
|
|
|
|
|
logger.info(f"Encoded prompt: {prompt_tokens.shape}")
|
|
logger.info(f"Encoded prompt: {prompt_tokens.shape}")
|
|
|
else:
|
|
else:
|
|
|
prompt_tokens = None
|
|
prompt_tokens = None
|
|
|
- reference_embedding = None
|
|
|
|
|
logger.info("No reference audio provided")
|
|
logger.info("No reference audio provided")
|
|
|
|
|
|
|
|
- return prompt_tokens, reference_embedding
|
|
|
|
|
|
|
+ return prompt_tokens
|
|
|
|
|
|
|
|
|
|
|
|
|
def decode_vq_tokens(
|
|
def decode_vq_tokens(
|
|
|
*,
|
|
*,
|
|
|
decoder_model,
|
|
decoder_model,
|
|
|
codes,
|
|
codes,
|
|
|
- text_tokens: torch.Tensor | None = None,
|
|
|
|
|
- reference_embedding: torch.Tensor | None = None,
|
|
|
|
|
):
|
|
):
|
|
|
feature_lengths = torch.tensor([codes.shape[1]], device=decoder_model.device)
|
|
feature_lengths = torch.tensor([codes.shape[1]], device=decoder_model.device)
|
|
|
logger.info(f"VQ features: {codes.shape}")
|
|
logger.info(f"VQ features: {codes.shape}")
|
|
@@ -172,17 +165,17 @@ class InvokeRequest(BaseModel):
|
|
|
text: str = "你说的对, 但是原神是一款由米哈游自主研发的开放世界手游."
|
|
text: str = "你说的对, 但是原神是一款由米哈游自主研发的开放世界手游."
|
|
|
reference_text: Optional[str] = None
|
|
reference_text: Optional[str] = None
|
|
|
reference_audio: Optional[str] = None
|
|
reference_audio: Optional[str] = None
|
|
|
- max_new_tokens: int = 0
|
|
|
|
|
- chunk_length: Annotated[int, Field(ge=0, le=500, strict=True)] = 150
|
|
|
|
|
|
|
+ max_new_tokens: int = 1024
|
|
|
|
|
+ chunk_length: Annotated[int, Field(ge=0, le=500, strict=True)] = 100
|
|
|
top_p: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7
|
|
top_p: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7
|
|
|
- repetition_penalty: Annotated[float, Field(ge=0.9, le=2.0, strict=True)] = 1.5
|
|
|
|
|
|
|
+ repetition_penalty: Annotated[float, Field(ge=0.9, le=2.0, strict=True)] = 1.2
|
|
|
temperature: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7
|
|
temperature: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7
|
|
|
- speaker: Optional[str] = None
|
|
|
|
|
emotion: Optional[str] = None
|
|
emotion: Optional[str] = None
|
|
|
format: Literal["wav", "mp3", "flac"] = "wav"
|
|
format: Literal["wav", "mp3", "flac"] = "wav"
|
|
|
streaming: bool = False
|
|
streaming: bool = False
|
|
|
ref_json: Optional[str] = "ref_data.json"
|
|
ref_json: Optional[str] = "ref_data.json"
|
|
|
ref_base: Optional[str] = "ref_data"
|
|
ref_base: Optional[str] = "ref_data"
|
|
|
|
|
+ speaker: Optional[str] = None
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_content_type(audio_format):
|
|
def get_content_type(audio_format):
|
|
@@ -217,7 +210,7 @@ def inference(req: InvokeRequest):
|
|
|
logger.info("ref_text: " + ref_text)
|
|
logger.info("ref_text: " + ref_text)
|
|
|
|
|
|
|
|
# Parse reference audio aka prompt
|
|
# Parse reference audio aka prompt
|
|
|
- prompt_tokens, reference_embedding = encode_reference(
|
|
|
|
|
|
|
+ prompt_tokens = encode_reference(
|
|
|
decoder_model=decoder_model,
|
|
decoder_model=decoder_model,
|
|
|
reference_audio=(
|
|
reference_audio=(
|
|
|
io.BytesIO(base64.b64decode(req.reference_audio))
|
|
io.BytesIO(base64.b64decode(req.reference_audio))
|
|
@@ -229,7 +222,6 @@ def inference(req: InvokeRequest):
|
|
|
|
|
|
|
|
# LLAMA Inference
|
|
# LLAMA Inference
|
|
|
request = dict(
|
|
request = dict(
|
|
|
- tokenizer=llama_tokenizer,
|
|
|
|
|
device=decoder_model.device,
|
|
device=decoder_model.device,
|
|
|
max_new_tokens=req.max_new_tokens,
|
|
max_new_tokens=req.max_new_tokens,
|
|
|
text=req.text,
|
|
text=req.text,
|
|
@@ -240,7 +232,6 @@ def inference(req: InvokeRequest):
|
|
|
iterative_prompt=req.chunk_length > 0,
|
|
iterative_prompt=req.chunk_length > 0,
|
|
|
chunk_length=req.chunk_length,
|
|
chunk_length=req.chunk_length,
|
|
|
max_length=2048,
|
|
max_length=2048,
|
|
|
- speaker=req.speaker,
|
|
|
|
|
prompt_tokens=prompt_tokens,
|
|
prompt_tokens=prompt_tokens,
|
|
|
prompt_text=req.reference_text,
|
|
prompt_text=req.reference_text,
|
|
|
)
|
|
)
|
|
@@ -267,18 +258,12 @@ def inference(req: InvokeRequest):
|
|
|
if result.action == "next":
|
|
if result.action == "next":
|
|
|
break
|
|
break
|
|
|
|
|
|
|
|
- text_tokens = llama_tokenizer.encode(result.text, return_tensors="pt").to(
|
|
|
|
|
- decoder_model.device
|
|
|
|
|
- )
|
|
|
|
|
-
|
|
|
|
|
with torch.autocast(
|
|
with torch.autocast(
|
|
|
device_type=decoder_model.device.type, dtype=args.precision
|
|
device_type=decoder_model.device.type, dtype=args.precision
|
|
|
):
|
|
):
|
|
|
fake_audios = decode_vq_tokens(
|
|
fake_audios = decode_vq_tokens(
|
|
|
decoder_model=decoder_model,
|
|
decoder_model=decoder_model,
|
|
|
codes=result.codes,
|
|
codes=result.codes,
|
|
|
- text_tokens=text_tokens,
|
|
|
|
|
- reference_embedding=reference_embedding,
|
|
|
|
|
)
|
|
)
|
|
|
|
|
|
|
|
fake_audios = fake_audios.float().cpu().numpy()
|
|
fake_audios = fake_audios.float().cpu().numpy()
|
|
@@ -379,7 +364,6 @@ def parse_args():
|
|
|
default="checkpoints/fish-speech-1.2/firefly-gan-vq-fsq-4x1024-42hz-generator.pth",
|
|
default="checkpoints/fish-speech-1.2/firefly-gan-vq-fsq-4x1024-42hz-generator.pth",
|
|
|
)
|
|
)
|
|
|
parser.add_argument("--decoder-config-name", type=str, default="firefly_gan_vq")
|
|
parser.add_argument("--decoder-config-name", type=str, default="firefly_gan_vq")
|
|
|
- parser.add_argument("--tokenizer", type=str, default="fishaudio/fish-speech-1")
|
|
|
|
|
parser.add_argument("--device", type=str, default="cuda")
|
|
parser.add_argument("--device", type=str, default="cuda")
|
|
|
parser.add_argument("--half", action="store_true")
|
|
parser.add_argument("--half", action="store_true")
|
|
|
parser.add_argument("--compile", action="store_true")
|
|
parser.add_argument("--compile", action="store_true")
|
|
@@ -422,7 +406,6 @@ if __name__ == "__main__":
|
|
|
precision=args.precision,
|
|
precision=args.precision,
|
|
|
compile=args.compile,
|
|
compile=args.compile,
|
|
|
)
|
|
)
|
|
|
- llama_tokenizer = AutoTokenizer.from_pretrained(args.tokenizer)
|
|
|
|
|
logger.info("Llama model loaded, loading VQ-GAN model...")
|
|
logger.info("Llama model loaded, loading VQ-GAN model...")
|
|
|
|
|
|
|
|
decoder_model = load_decoder_model(
|
|
decoder_model = load_decoder_model(
|
|
@@ -437,15 +420,13 @@ if __name__ == "__main__":
|
|
|
list(
|
|
list(
|
|
|
inference(
|
|
inference(
|
|
|
InvokeRequest(
|
|
InvokeRequest(
|
|
|
- text="A warm-up sentence.",
|
|
|
|
|
|
|
+ text="Hello world.",
|
|
|
reference_text=None,
|
|
reference_text=None,
|
|
|
reference_audio=None,
|
|
reference_audio=None,
|
|
|
- max_new_tokens=0,
|
|
|
|
|
- chunk_length=150,
|
|
|
|
|
|
|
+ max_new_tokens=1024,
|
|
|
top_p=0.7,
|
|
top_p=0.7,
|
|
|
- repetition_penalty=1.5,
|
|
|
|
|
|
|
+ repetition_penalty=1.2,
|
|
|
temperature=0.7,
|
|
temperature=0.7,
|
|
|
- speaker=None,
|
|
|
|
|
emotion=None,
|
|
emotion=None,
|
|
|
format="wav",
|
|
format="wav",
|
|
|
ref_base=None,
|
|
ref_base=None,
|