@@ -29,9 +29,15 @@ from transformers import AutoTokenizer
 pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)

-from tools.llama.generate import launch_thread_safe_queue
-from tools.vqgan.inference import load_model as load_vqgan_model
-from tools.webui import inference
+from fish_speech.models.vits_decoder.lit_module import VITSDecoder
+from fish_speech.models.vqgan.lit_module import VQGAN
+from tools.llama.generate import (
+    GenerateRequest,
+    GenerateResponse,
+    WrappedGenerateResponse,
+    launch_thread_safe_queue,
+)
+from tools.vqgan.inference import load_model as load_decoder_model


 def wav_chunk_header(sample_rate=44100, bit_depth=16, channels=1):
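Note for reviewers: the llama worker protocol moves from bare dicts and the `"next"`/`"done"` sentinel strings (removed further down) to typed messages. Their definitions live in `tools/llama/generate.py` and are not part of this diff; the sketch below is a hypothetical reconstruction inferred purely from how the fields are used in this file (`status`, `response`, `action`, `codes`, `text`), so the real classes may differ.

```python
# Hypothetical sketch of the imported message types, inferred from their
# usage in this file; the real definitions in tools/llama/generate.py
# may differ.
import queue
from dataclasses import dataclass
from typing import Literal, Union

import torch


@dataclass
class GenerateRequest:
    request: dict                # kwargs for the llama generation worker
    response_queue: queue.Queue  # per-request channel the worker replies on


@dataclass
class GenerateResponse:
    # Only "next" is visible in this diff; the name of the "carries a
    # segment" action is an assumption.
    action: Literal["sample", "next"]
    codes: torch.Tensor | None = None  # generated VQ codes for the decoder
    text: str | None = None            # text of the generated segment


@dataclass
class WrappedGenerateResponse:
    status: Literal["success", "error"]
    response: Union[GenerateResponse, Exception]
```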
@@ -70,6 +76,94 @@ def other_exception_handler(exc: "Exception"):
     )


+def encode_reference(*, decoder_model, reference_audio, enable_reference_audio):
+    if enable_reference_audio and reference_audio is not None:
+        # Load audios, and prepare basic info here
+        reference_audio_content, _ = librosa.load(
+            reference_audio, sr=decoder_model.sampling_rate, mono=True
+        )
+        audios = torch.from_numpy(reference_audio_content).to(decoder_model.device)[
+            None, None, :
+        ]
+        audio_lengths = torch.tensor(
+            [audios.shape[2]], device=decoder_model.device, dtype=torch.long
+        )
+        logger.info(
+            f"Loaded audio with {audios.shape[2] / decoder_model.sampling_rate:.2f} seconds"
+        )
+
+        # VQ Encoder
+        if isinstance(decoder_model, VQGAN):
+            prompt_tokens = decoder_model.encode(audios, audio_lengths)[0][0]
+            reference_embedding = None  # VQGAN does not have reference embedding
+        elif isinstance(decoder_model, VITSDecoder):
+            reference_spec = decoder_model.spec_transform(audios[0])
+            reference_embedding = decoder_model.generator.encode_ref(
+                reference_spec,
+                torch.tensor([reference_spec.shape[-1]], device=decoder_model.device),
+            )
+            logger.info(f"Loaded reference audio from {reference_audio}")
+
+            audio_lengths = torch.tensor(
+                [audios.shape[-1]], device=decoder_model.device, dtype=torch.long
+            )
+            prompt_tokens = decoder_model.generator.vq.encode(audios, audio_lengths)[0][
+                0
+            ]
+        else:
+            raise ValueError(f"Unknown model type: {type(decoder_model)}")
+
+        logger.info(f"Encoded prompt: {prompt_tokens.shape}")
+    elif isinstance(decoder_model, VITSDecoder):
+        prompt_tokens = None
+        reference_embedding = torch.zeros(
+            1, decoder_model.generator.gin_channels, 1, device=decoder_model.device
+        )
+        logger.info("No reference audio provided, use zero embedding")
+    else:
+        prompt_tokens = None
+        reference_embedding = None
+        logger.info("No reference audio provided")
+
+    return prompt_tokens, reference_embedding
+
+
+def decode_vq_tokens(
+    *,
+    decoder_model,
+    codes,
+    text_tokens: torch.Tensor | None = None,
+    reference_embedding: torch.Tensor | None = None,
+):
+    feature_lengths = torch.tensor([codes.shape[1]], device=decoder_model.device)
+    logger.info(f"VQ features: {codes.shape}")
+
+    if isinstance(decoder_model, VQGAN):
+        # VQGAN Inference
+        return decoder_model.decode(
+            indices=codes[None],
+            feature_lengths=feature_lengths,
+            return_audios=True,
+        ).squeeze()
+
+    if isinstance(decoder_model, VITSDecoder):
+        # VITS Inference
+        quantized = decoder_model.generator.vq.indicies_to_vq_features(
+            indices=codes[None], feature_lengths=feature_lengths
+        )
+        logger.info(f"Restored VQ features: {quantized.shape}")
+
+        return decoder_model.generator.decode(
+            quantized,
+            torch.tensor([quantized.shape[-1]], device=decoder_model.device),
+            text_tokens,
+            torch.tensor([text_tokens.shape[-1]], device=decoder_model.device),
+            ge=reference_embedding,
+        ).squeeze()
+
+    raise ValueError(f"Unknown model type: {type(decoder_model)}")
+
+
 routes = MultimethodRoutes(base_class=HttpView)
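The two helpers above can be exercised outside the HTTP layer. Below is a minimal round-trip sketch for the VQGAN decoder, assuming the default checkpoint and config from `parse_args()` further down and a hypothetical `reference.wav`; the VITS path would additionally require `text_tokens` and the returned reference embedding.

```python
# Minimal offline sketch of encode_reference / decode_vq_tokens for the
# VQGAN decoder. Checkpoint and config names are the parse_args() defaults;
# "reference.wav" is a hypothetical input path.
import soundfile as sf
import torch

decoder_model = load_decoder_model(
    config_name="vqgan_pretrain",
    checkpoint_path="checkpoints/vq-gan-group-fsq-2x1024.pth",
    device="cuda",
)

# Encode the reference clip into prompt tokens; for VQGAN the returned
# reference embedding is None.
prompt_tokens, reference_embedding = encode_reference(
    decoder_model=decoder_model,
    reference_audio="reference.wav",
    enable_reference_audio=True,
)

# Round-trip: decode the freshly encoded codes back into a waveform.
with torch.no_grad():
    audio = decode_vq_tokens(decoder_model=decoder_model, codes=prompt_tokens)

sf.write("roundtrip.wav", audio.float().cpu().numpy(), decoder_model.sampling_rate)
```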
@@ -91,29 +185,22 @@ class InvokeRequest(BaseModel):
 def inference(req: InvokeRequest):
     # Parse reference audio aka prompt
     prompt_tokens = None
-    if req.reference_audio is not None:
-        buffer = io.BytesIO(base64.b64decode(req.reference_audio))
-        reference_audio_content, _ = librosa.load(
-            buffer, sr=vqgan_model.sampling_rate, mono=True
-        )
-        audios = torch.from_numpy(reference_audio_content).to(vqgan_model.device)[
-            None, None, :
-        ]
-
-        logger.info(
-            f"Loaded audio with {audios.shape[2] / vqgan_model.sampling_rate:.2f} seconds"
-        )
-
-        # VQ Encoder
-        audio_lengths = torch.tensor(
-            [audios.shape[2]], device=vqgan_model.device, dtype=torch.long
-        )
-        prompt_tokens = vqgan_model.encode(audios, audio_lengths)[0][0]
+    # Parse reference audio aka prompt
+    prompt_tokens, reference_embedding = encode_reference(
+        decoder_model=decoder_model,
+        reference_audio=(
+            io.BytesIO(base64.b64decode(req.reference_audio))
+            if req.reference_audio is not None
+            else None
+        ),
+        enable_reference_audio=req.reference_audio is not None,
+    )

     # LLAMA Inference
     request = dict(
         tokenizer=llama_tokenizer,
-        device=vqgan_model.device,
+        device=decoder_model.device,
         max_new_tokens=req.max_new_tokens,
         text=req.text,
         top_p=req.top_p,
@@ -126,35 +213,44 @@ def inference(req: InvokeRequest):
         speaker=req.speaker,
         prompt_tokens=prompt_tokens,
         prompt_text=req.reference_text,
-        is_streaming=True,
     )

-    payload = dict(
-        response_queue=queue.Queue(),
-        request=request,
+    response_queue = queue.Queue()
+    llama_queue.put(
+        GenerateRequest(
+            request=request,
+            response_queue=response_queue,
+        )
     )
-    llama_queue.put(payload)

     if req.streaming:
         yield wav_chunk_header()

     segments = []
     while True:
-        result = payload["response_queue"].get()
-        if result == "next":
-            # TODO: handle next sentence
-            continue
-
-        if result == "done":
-            if payload["success"] is False:
-                raise payload["response"]
+        result: WrappedGenerateResponse = response_queue.get()
+        if result.status == "error":
+            raise result.response
             break

-        # VQGAN Inference
-        feature_lengths = torch.tensor([result.shape[1]], device=vqgan_model.device)
-        fake_audios = vqgan_model.decode(
-            indices=result[None], feature_lengths=feature_lengths, return_audios=True
-        )[0, 0]
+        result: GenerateResponse = result.response
+        if result.action == "next":
+            break
+
+        text_tokens = llama_tokenizer.encode(result.text, return_tensors="pt").to(
+            decoder_model.device
+        )
+
+        with torch.autocast(
+            device_type=decoder_model.device.type, dtype=args.precision
+        ):
+            fake_audios = decode_vq_tokens(
+                decoder_model=decoder_model,
+                codes=result.codes,
+                text_tokens=text_tokens,
+                reference_embedding=reference_embedding,
+            )
+
         fake_audios = fake_audios.float().cpu().numpy()

         if req.streaming:
@@ -162,14 +258,17 @@
         else:
             segments.append(fake_audios)

+    if req.streaming:
+        return
+
     if len(segments) == 0:
         raise HTTPException(
             HTTPStatus.INTERNAL_SERVER_ERROR,
             content="No audio generated, please check the input text.",
         )
-    elif req.streaming is False:
-        fake_audios = np.concatenate(segments, axis=0)
-        yield fake_audios
+
+    fake_audios = np.concatenate(segments, axis=0)
+    yield fake_audios


 @routes.http.post("/v1/invoke")
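Because the generator now returns early in streaming mode and only concatenates segments on the non-streaming path, a streaming client just writes the response bytes in order: the header from `wav_chunk_header()` arrives first, followed by the audio chunks (the per-segment yield in streaming mode falls outside this diff's hunks, so its exact encoding is not visible here). A hypothetical client, assuming the server listens on 127.0.0.1:8000 and accepts a JSON body:

```python
# Hypothetical streaming client for /v1/invoke; host, port, and the JSON
# content type are assumptions not confirmed by this diff.
import requests

with requests.post(
    "http://127.0.0.1:8000/v1/invoke",
    json={"text": "Hello, world!", "streaming": True, "format": "wav"},
    stream=True,
) as response:
    response.raise_for_status()
    with open("streamed.wav", "wb") as f:
        # Write the WAV chunk header and subsequent audio chunks verbatim.
        for chunk in response.iter_content(chunk_size=4096):
            f.write(chunk)
```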
@@ -204,7 +303,7 @@ def api_invoke_model(
     else:
         fake_audios = next(generator)
         buffer = io.BytesIO()
-        sf.write(buffer, fake_audios, vqgan_model.sampling_rate, format=req.format)
+        sf.write(buffer, fake_audios, decoder_model.sampling_rate, format=req.format)

         return StreamResponse(
             iterable=[buffer.getvalue()],
@@ -235,11 +334,11 @@ def parse_args():
         "--llama-config-name", type=str, default="dual_ar_2_codebook_medium"
     )
     parser.add_argument(
-        "--vqgan-checkpoint-path",
+        "--decoder-checkpoint-path",
         type=str,
         default="checkpoints/vq-gan-group-fsq-2x1024.pth",
     )
-    parser.add_argument("--vqgan-config-name", type=str, default="vqgan_pretrain")
+    parser.add_argument("--decoder-config-name", type=str, default="vqgan_pretrain")
     parser.add_argument("--tokenizer", type=str, default="fishaudio/fish-speech-1")
     parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--half", action="store_true")
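Downstream launch scripts need the matching rename: anything still passing `--vqgan-checkpoint-path` or `--vqgan-config-name` will fail to parse. A hypothetical invocation with the new flag names and the defaults above (the script's path is assumed, as it is not visible in this excerpt):

```bash
python tools/api.py \
    --llama-config-name dual_ar_2_codebook_medium \
    --decoder-config-name vqgan_pretrain \
    --decoder-checkpoint-path checkpoints/vq-gan-group-fsq-2x1024.pth \
    --tokenizer fishaudio/fish-speech-1 \
    --device cuda
```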
@@ -288,9 +387,9 @@ if __name__ == "__main__":
     llama_tokenizer = AutoTokenizer.from_pretrained(args.tokenizer)
     logger.info("Llama model loaded, loading VQ-GAN model...")

-    vqgan_model = load_vqgan_model(
-        config_name=args.vqgan_config_name,
-        checkpoint_path=args.vqgan_checkpoint_path,
+    decoder_model = load_decoder_model(
+        config_name=args.decoder_config_name,
+        checkpoint_path=args.decoder_checkpoint_path,
         device=args.device,
     )
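End to end, a non-streaming call against the renamed models looks like the hypothetical client below. Host, port, and the JSON content type are assumptions, and `InvokeRequest` may carry more fields than the ones visible in this diff (`text`, `reference_audio`, `reference_text`, `max_new_tokens`, `top_p`, `speaker`, `streaming`, `format`).

```python
# Hypothetical non-streaming client for /v1/invoke, using only the
# InvokeRequest fields visible in this diff. Host/port are assumptions.
import base64

import requests

with open("reference.wav", "rb") as f:
    reference_audio = base64.b64encode(f.read()).decode()

response = requests.post(
    "http://127.0.0.1:8000/v1/invoke",
    json={
        "text": "Hello, world!",
        "reference_audio": reference_audio,  # base64, decoded server-side
        "reference_text": "Transcript of the reference clip.",
        "streaming": False,
        "format": "wav",
    },
)
response.raise_for_status()

# The server writes the whole file into one buffer and returns it as a
# single StreamResponse chunk.
with open("output.wav", "wb") as f:
    f.write(response.content)
```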