| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495 |
- import base64
- import io
- import json
- import queue
- import random
- import sys
- import traceback
- import wave
- from argparse import ArgumentParser
- from http import HTTPStatus
- from pathlib import Path
- from typing import Annotated, Literal, Optional
- import numpy as np
- import pyrootutils
- import soundfile as sf
- import torch
- import torchaudio
- from kui.asgi import (
- Body,
- HTTPException,
- HttpView,
- JSONResponse,
- Kui,
- OpenAPI,
- StreamResponse,
- )
- from kui.asgi.routing import MultimethodRoutes
- from loguru import logger
- from pydantic import BaseModel, Field
- pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
- # from fish_speech.models.vqgan.lit_module import VQGAN
- from fish_speech.models.vqgan.modules.firefly import FireflyArchitecture
- from fish_speech.utils import autocast_exclude_mps
- from tools.auto_rerank import batch_asr, calculate_wer, is_chinese, load_model
- from tools.llama.generate import (
- GenerateRequest,
- GenerateResponse,
- WrappedGenerateResponse,
- launch_thread_safe_queue,
- )
- from tools.vqgan.inference import load_model as load_decoder_model
def wav_chunk_header(sample_rate=44100, bit_depth=16, channels=1):
    """Build a PCM WAV header that declares zero audio frames.

    Used as the first chunk of a streamed WAV response; raw PCM frames are
    appended after it by the caller.

    Args:
        sample_rate: frames per second written into the header.
        bit_depth: bits per sample (converted to a byte width for ``wave``).
        channels: number of interleaved channels.

    Returns:
        The header bytes produced by the ``wave`` module.
    """
    with io.BytesIO() as stream:
        # Closing the Wave_write object flushes the header into ``stream``;
        # ``wave`` does not close a file object it did not open itself.
        with wave.open(stream, "wb") as writer:
            writer.setnchannels(channels)
            writer.setsampwidth(bit_depth // 8)
            writer.setframerate(sample_rate)
        return stream.getvalue()
- # Define utils for web server
async def http_execption_handler(exc: HTTPException):
    """Render an ``HTTPException`` as a JSON error response.

    The body carries ``statusCode``, ``message`` (the exception's content) and
    ``error`` (the standard HTTP reason phrase for the status code). The
    exception's own status code and headers are reused for the response.
    """
    code = exc.status_code
    body = {
        "statusCode": code,
        "message": exc.content,
        "error": HTTPStatus(code).phrase,
    }
    return JSONResponse(body, code, exc.headers)
async def other_exception_handler(exc: "Exception"):
    """Catch-all handler: print the traceback and answer with a JSON 500.

    The traceback of the exception currently being handled is printed to
    stderr. The body has ``statusCode``, ``message`` (``str(exc)``) and
    ``error`` (the reason phrase for 500).
    """
    # NOTE(review): str(exc) is sent to the client and may expose internal
    # details; consider a generic message in production.
    traceback.print_exc()
    internal_error = HTTPStatus.INTERNAL_SERVER_ERROR
    body = {
        "statusCode": internal_error,
        "message": str(exc),
        "error": internal_error.phrase,
    }
    return JSONResponse(body, internal_error)
def load_audio(reference_audio, sr):
    """Load reference audio as a mono numpy array resampled to ``sr``.

    Args:
        reference_audio: either a path to an existing audio file (``str`` or
            ``pathlib.Path``) or a base64-encoded audio file as a string.
        sr: target sample rate.

    Returns:
        A 1-D numpy array (channels averaged to mono, resampled when the
        source rate differs from ``sr``).

    Raises:
        ValueError: the input is neither an existing file nor valid base64.
    """
    # NOTE(review): reference_audio can come straight from an API request, so
    # any server-readable audio path may be loaded; consider restricting it to
    # a trusted directory.
    if isinstance(reference_audio, Path):
        # len()/base64 handling below expects a string, not a Path object.
        reference_audio = str(reference_audio)

    try:
        is_local_file = Path(reference_audio).is_file()
    except (OSError, ValueError):
        # Long base64 payloads can exceed OS name limits (ENAMETOOLONG) and
        # are then clearly not paths.
        is_local_file = False

    if not is_local_file:
        try:
            reference_audio = io.BytesIO(base64.b64decode(reference_audio))
        except ValueError as exc:
            # binascii.Error (bad padding/characters) and non-ASCII input both
            # derive from ValueError.
            raise ValueError("Invalid path or base64 string") from exc

    waveform, original_sr = torchaudio.load(
        reference_audio, backend="sox" if sys.platform == "linux" else "soundfile"
    )
    if waveform.shape[0] > 1:
        # Downmix multi-channel audio: (channels, samples) -> (1, samples).
        waveform = torch.mean(waveform, dim=0, keepdim=True)
    if original_sr != sr:
        resampler = torchaudio.transforms.Resample(orig_freq=original_sr, new_freq=sr)
        waveform = resampler(waveform)
    return waveform.squeeze().numpy()
def encode_reference(*, decoder_model, reference_audio, enable_reference_audio):
    """Encode a reference clip into VQ prompt tokens.

    Args:
        decoder_model: the VQ-GAN model; only ``FireflyArchitecture`` is
            supported. Its ``device`` and ``spec_transform.sample_rate`` are used.
        reference_audio: a path or base64 string accepted by ``load_audio``,
            or None.
        enable_reference_audio: when falsy, no encoding is done.

    Returns:
        ``decoder_model.encode(audios, audio_lengths)[0][0]`` (the prompt
        tokens), or None when reference audio is disabled or absent.

    Raises:
        ValueError: ``decoder_model`` is not a supported model type.
    """
    if not enable_reference_audio or reference_audio is None:
        logger.info("No reference audio provided")
        return None

    # Check the type before doing any work; otherwise prompt_tokens would be
    # left unbound for unsupported decoders.
    if not isinstance(decoder_model, FireflyArchitecture):
        raise ValueError(f"Unknown model type: {type(decoder_model)}")

    sample_rate = decoder_model.spec_transform.sample_rate
    reference_audio_content = load_audio(reference_audio, sample_rate)
    # Add batch and channel axes: (samples,) -> (1, 1, samples).
    audios = torch.from_numpy(reference_audio_content).to(decoder_model.device)[
        None, None, :
    ]
    audio_lengths = torch.tensor(
        [audios.shape[2]], device=decoder_model.device, dtype=torch.long
    )
    logger.info(f"Loaded audio with {audios.shape[2] / sample_rate:.2f} seconds")

    # VQ Encoder
    prompt_tokens = decoder_model.encode(audios, audio_lengths)[0][0]
    logger.info(f"Encoded prompt: {prompt_tokens.shape}")
    return prompt_tokens
def decode_vq_tokens(
    *,
    decoder_model,
    codes,
):
    """Turn VQ code indices back into audio with the VQ-GAN decoder.

    Args:
        decoder_model: decoder to run; only ``FireflyArchitecture`` is handled.
        codes: code tensor whose ``shape[1]`` is used as the sequence length;
            a leading batch axis is added before decoding. (Presumably
            ``(num_codebooks, frames)`` — TODO confirm against the generator.)

    Returns:
        The decoder output with all size-1 dimensions squeezed out.

    Raises:
        ValueError: for any other decoder type.
    """
    lengths = torch.tensor([codes.shape[1]], device=decoder_model.device)
    logger.info(f"VQ features: {codes.shape}")

    if not isinstance(decoder_model, FireflyArchitecture):
        raise ValueError(f"Unknown model type: {type(decoder_model)}")

    # VQGAN inference on a batch of one.
    decoded = decoder_model.decode(indices=codes[None], feature_lengths=lengths)
    return decoded.squeeze()
# Route table for this API: endpoint handlers register on it with
# ``@routes.http.post(...)`` and it is passed to ``Kui(routes=...)``.
routes = MultimethodRoutes(base_class=HttpView)
def get_random_paths(base_path, data, speaker, emotion):
    """Pick a random, matching (transcript, audio) reference pair.

    ``data`` is the reference index loaded from JSON; ``data[speaker][emotion]``
    is expected to list file names stored under
    ``<base_path>/<speaker>/<emotion>/``. A ``.wav`` file is only eligible when a
    ``.lab`` transcript with the same name (minus extension) is listed too, so
    the returned transcript always belongs to the returned audio.

    Returns:
        ``(lab_path, wav_path)`` as ``Path`` objects, or ``(None, None)`` when
        an argument is missing/empty, ``base_path`` does not exist, the
        speaker/emotion is not indexed, no matched pair is listed, or the
        chosen files are missing on disk.
    """
    if not (base_path and data and speaker and emotion and Path(base_path).exists()):
        return None, None
    if speaker not in data or emotion not in data[speaker]:
        return None, None

    files = data[speaker][emotion]
    # Key transcripts by their name without ".lab" so each .wav finds its own
    # transcript (picking the two independently would mismatch text and audio).
    labs_by_key = {name[: -len(".lab")]: name for name in files if name.endswith(".lab")}
    paired_wavs = [
        name
        for name in files
        if name.endswith(".wav") and name[: -len(".wav")] in labs_by_key
    ]
    if not paired_wavs:
        return None, None

    selected_wav = random.choice(paired_wavs)
    selected_lab = labs_by_key[selected_wav[: -len(".wav")]]

    directory = Path(base_path) / speaker / emotion
    lab_path = directory / selected_lab
    wav_path = directory / selected_wav
    if lab_path.exists() and wav_path.exists():
        return lab_path, wav_path
    return None, None
def load_json(json_file):
    """Read the reference-audio JSON index at ``json_file``.

    Best effort: when no path is given, the file does not exist, or it cannot
    be read or parsed, a message is logged and None is returned instead of
    raising.
    """
    if not json_file:
        logger.info("Not using a json file")
        return None

    try:
        with open(json_file, "r", encoding="utf-8") as handle:
            return json.load(handle)
    except FileNotFoundError:
        logger.warning(f"ref json not found: {json_file}")
    except Exception as e:
        logger.warning(f"Loading json failed: {e}")
    return None
class InvokeRequest(BaseModel):
    # Request body for POST /v1/invoke (also built directly for the warm-up run).

    # Text to synthesize; forwarded to the generation request as ``text``.
    text: str = "你说的对, 但是原神是一款由米哈游自主研发的开放世界手游."
    # Transcript of the reference audio, forwarded as ``prompt_text``. It is
    # overwritten with a .lab file's contents when a reference pair is picked.
    reference_text: Optional[str] = None
    # Reference audio for the prompt: an existing server-side file path or a
    # base64-encoded audio file (decoded by load_audio). None = no reference.
    reference_audio: Optional[str] = None
    # Forwarded unchanged as ``max_new_tokens``.
    max_new_tokens: int = 1024
    # Forwarded as ``chunk_length``; 0 turns ``iterative_prompt`` off.
    chunk_length: Annotated[int, Field(ge=0, le=500, strict=True)] = 100
    # Sampling parameters, forwarded unchanged to the generation request.
    top_p: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7
    repetition_penalty: Annotated[float, Field(ge=0.9, le=2.0, strict=True)] = 1.2
    temperature: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7
    # Emotion key used together with ``speaker`` to look up reference files.
    emotion: Optional[str] = None
    # Output container; streaming requests are rejected unless this is "wav".
    format: Literal["wav", "mp3", "flac"] = "wav"
    # When True the response is a WAV header followed by raw 16-bit PCM chunks.
    streaming: bool = False
    # NOTE(review): ref_json, ref_base and reference_audio are client-supplied
    # paths that the server opens directly; consider restricting them.
    # Path to the reference index, read as data[speaker][emotion] -> list of
    # file names; None or "" disables the lookup.
    ref_json: Optional[str] = "ref_data.json"
    # Directory holding reference files as <ref_base>/<speaker>/<emotion>/<file>.
    ref_base: Optional[str] = "ref_data"
    # Speaker key into the reference index.
    speaker: Optional[str] = None
def get_content_type(audio_format):
    """Return the MIME type for an output audio format name.

    Known formats are "wav", "flac" and "mp3"; anything else maps to the
    generic "application/octet-stream".
    """
    known_types = (
        ("wav", "audio/wav"),
        ("flac", "audio/flac"),
        ("mp3", "audio/mpeg"),
    )
    for name, mime in known_types:
        if audio_format == name:
            return mime
    return "application/octet-stream"
@torch.inference_mode()
def inference(req: InvokeRequest):
    """Run text-to-speech for ``req`` and yield the resulting audio.

    Steps: optionally pick a (transcript, audio) reference pair from the
    request's reference index, encode the reference audio into prompt tokens,
    submit a generation request to ``llama_queue``, then decode every code
    chunk that comes back on the response queue.

    Relies on the module globals ``decoder_model``, ``llama_queue`` and ``args``
    that are set up in the ``__main__`` block. Mutates ``req.reference_audio``
    and ``req.reference_text`` when a reference pair is picked.

    Yields:
        Streaming (``req.streaming``): a WAV header first, then one chunk of
        16-bit mono PCM bytes per decoded segment.
        Non-streaming: a single float32 numpy array of all segments joined.

    Raises:
        HTTPException: a non-streaming request produced no audio.
        The exception carried by an "error" response from the generation worker.
    """
    # Pick a random reference pair for the requested speaker/emotion, if any.
    ref_data = load_json(req.ref_json)
    lab_path, wav_path = get_random_paths(
        req.ref_base, ref_data, req.speaker, req.emotion
    )
    if lab_path and wav_path:
        with open(lab_path, "r", encoding="utf-8") as lab_file:
            ref_text = lab_file.read()
        # Store a plain str: the field is Optional[str] and load_audio calls
        # len() on it, which fails for a Path.
        req.reference_audio = str(wav_path)
        req.reference_text = ref_text
        logger.info("ref_path: " + str(wav_path))
        logger.info("ref_text: " + ref_text)

    # Parse reference audio aka prompt
    prompt_tokens = encode_reference(
        decoder_model=decoder_model,
        reference_audio=req.reference_audio,
        enable_reference_audio=req.reference_audio is not None,
    )
    logger.info(f"ref_text: {req.reference_text}")

    sample_rate = decoder_model.spec_transform.sample_rate

    # LLAMA Inference: the worker behind llama_queue answers on response_queue.
    request = dict(
        device=decoder_model.device,
        max_new_tokens=req.max_new_tokens,
        text=req.text,
        top_p=req.top_p,
        repetition_penalty=req.repetition_penalty,
        temperature=req.temperature,
        compile=args.compile,
        iterative_prompt=req.chunk_length > 0,
        chunk_length=req.chunk_length,
        max_length=2048,
        prompt_tokens=prompt_tokens,
        prompt_text=req.reference_text,
    )
    response_queue = queue.Queue()
    llama_queue.put(
        GenerateRequest(
            request=request,
            response_queue=response_queue,
        )
    )

    if req.streaming:
        # Advertise the decoder's real sample rate, the same rate used for
        # non-streaming output.
        yield wav_chunk_header(sample_rate=sample_rate)

    segments = []
    while True:
        wrapped: WrappedGenerateResponse = response_queue.get()
        if wrapped.status == "error":
            raise wrapped.response

        result: GenerateResponse = wrapped.response
        # An action of "next" ends the stream of code chunks.
        if result.action == "next":
            break

        with autocast_exclude_mps(
            device_type=decoder_model.device.type, dtype=args.precision
        ):
            fake_audios = decode_vq_tokens(
                decoder_model=decoder_model,
                codes=result.codes,
            )
        fake_audios = fake_audios.float().cpu().numpy()

        if req.streaming:
            # Clip before casting: a sample of exactly +1.0 scales to 32768,
            # which would wrap around to -32768 in int16 and produce a click.
            pcm = np.clip(fake_audios * 32768, -32768, 32767).astype(np.int16)
            yield pcm.tobytes()
        else:
            segments.append(fake_audios)

    if req.streaming:
        return

    if not segments:
        raise HTTPException(
            HTTPStatus.INTERNAL_SERVER_ERROR,
            content="No audio generated, please check the input text.",
        )

    yield np.concatenate(segments, axis=0)
def auto_rerank_inference(
    req: InvokeRequest,
    use_auto_rerank: bool = True,
    max_attempts: int = 5,
    wer_threshold: float = 0.1,
):
    """Generate audio, re-rolling until ASR agrees with the requested text.

    Each attempt runs ``inference`` once, transcribes the audio with an ASR
    model (Chinese or English, chosen by ``is_chinese(req.text)``) and scores it
    with ``calculate_wer``. The first attempt with WER <= ``wer_threshold`` and
    no ``huge_gap`` flag is returned at once; otherwise the lowest-WER attempt
    is returned after ``max_attempts`` tries.

    Args:
        req: synthesis request; must be non-streaming when re-ranking.
        use_auto_rerank: when False, re-ranking is skipped entirely.
        max_attempts: number of generations to try (at least 1).
        wer_threshold: WER at or below which an attempt is accepted.

    Returns:
        With re-ranking: the selected audio array.
        Without re-ranking: the ``inference(req)`` generator itself.
        NOTE(review): the two branches return different types; kept for
        backward compatibility.

    Raises:
        ValueError: re-ranking was requested for a streaming request, or
            ``max_attempts`` is below 1.
    """
    if not use_auto_rerank:
        # Skip re-ranking and hand back the raw inference generator.
        return inference(req)

    if req.streaming:
        # Streaming inference yields a WAV header first, not an audio array,
        # so its output cannot be transcribed.
        raise ValueError("auto_rerank_inference does not support streaming requests")
    if max_attempts < 1:
        raise ValueError("max_attempts must be at least 1")

    zh_model, en_model = load_model()
    asr_model = zh_model if is_chinese(req.text) else en_model
    # The audio comes out at the decoder's rate (the same rate api_invoke_model
    # writes files with); assumes batch_asr's third argument is that rate.
    sample_rate = decoder_model.spec_transform.sample_rate

    best_wer = float("inf")
    best_audio = None
    for _ in range(max_attempts):
        fake_audios = next(inference(req))
        asr_result = batch_asr(asr_model, [fake_audios], sample_rate)[0]
        wer = calculate_wer(req.text, asr_result["text"])
        if wer <= wer_threshold and not asr_result["huge_gap"]:
            return fake_audios
        if wer < best_wer:
            best_wer = wer
            best_audio = fake_audios
    return best_audio
async def inference_async(req: InvokeRequest):
    """Expose the synchronous ``inference`` generator as an async generator.

    Chunks are passed through unchanged, in order.

    NOTE(review): each chunk is produced by advancing the blocking generator
    directly (it waits on a ``queue.Queue``), so the event loop is blocked
    while audio is generated.
    """
    chunks = inference(req)
    while True:
        try:
            chunk = next(chunks)
        except StopIteration:
            return
        yield chunk
async def buffer_to_async_generator(buffer):
    """Wrap one in-memory payload as a single-item async iterable.

    Lets a fully rendered response body be handed to ``StreamResponse``,
    which takes an async iterable.
    """
    for payload in (buffer,):
        yield payload
@routes.http.post("/v1/invoke")
async def api_invoke_model(
    req: Annotated[InvokeRequest, Body(exclusive=True)],
):
    """
    Invoke model and generate audio
    """
    # Validate the request. A --max-text-length of 0 or less disables the
    # length check; streaming is only implemented for WAV output.
    if args.max_text_length > 0 and len(req.text) > args.max_text_length:
        raise HTTPException(
            HTTPStatus.BAD_REQUEST,
            content=f"Text is too long, max length is {args.max_text_length}",
        )
    if req.streaming and req.format != "wav":
        raise HTTPException(
            HTTPStatus.BAD_REQUEST,
            content="Streaming only supports WAV format",
        )

    headers = {"Content-Disposition": f"attachment; filename=audio.{req.format}"}
    content_type = get_content_type(req.format)

    if req.streaming:
        body = inference_async(req)
    else:
        # NOTE(review): this runs the blocking inference on the event loop.
        # Render the whole clip into an in-memory file of the requested format.
        fake_audios = next(inference(req))
        buffer = io.BytesIO()
        sf.write(
            buffer,
            fake_audios,
            decoder_model.spec_transform.sample_rate,
            format=req.format,
        )
        body = buffer_to_async_generator(buffer.getvalue())

    return StreamResponse(
        iterable=body,
        headers=headers,
        content_type=content_type,
    )
@routes.http.post("/v1/health")
async def api_health():
    """
    Health check
    """
    # Static response; it does not touch the models or the generation queue.
    status_payload = {"status": "ok"}
    return JSONResponse(status_payload)
def parse_args():
    """Parse the API server's command-line options.

    Returns:
        An ``argparse.Namespace`` with the model paths, device/precision flags,
        text limit, listen address, worker count and ``use_auto_rerank``.
    """
    from argparse import ArgumentTypeError

    def _str2bool(value):
        # argparse's ``type=bool`` maps every non-empty string (even "False")
        # to True, so boolean text is parsed explicitly here.
        normalized = value.strip().lower()
        if normalized in ("1", "true", "t", "yes", "y", "on"):
            return True
        if normalized in ("0", "false", "f", "no", "n", "off"):
            return False
        raise ArgumentTypeError(f"expected a boolean value, got {value!r}")

    parser = ArgumentParser()
    parser.add_argument(
        "--llama-checkpoint-path",
        type=str,
        default="checkpoints/fish-speech-1.2-sft",
    )
    parser.add_argument(
        "--decoder-checkpoint-path",
        type=str,
        default="checkpoints/fish-speech-1.2-sft/firefly-gan-vq-fsq-4x1024-42hz-generator.pth",
    )
    parser.add_argument("--decoder-config-name", type=str, default="firefly_gan_vq")
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--half", action="store_true")
    parser.add_argument("--compile", action="store_true")
    parser.add_argument("--max-text-length", type=int, default=0)
    parser.add_argument("--listen", type=str, default="127.0.0.1:8000")
    parser.add_argument("--workers", type=int, default=1)
    # Accepts "--use-auto-rerank false" etc.; a bare flag means True.
    # NOTE(review): the parsed value does not appear to be read anywhere in this file.
    parser.add_argument(
        "--use-auto-rerank",
        type=_str2bool,
        nargs="?",
        const=True,
        default=True,
    )
    return parser.parse_args()
# Define Kui app
# Route list generated by OpenAPI for the "Fish Speech API" documentation.
openapi = OpenAPI(
    {
        "title": "Fish Speech API",
    },
).routes

# The ASGI application: API routes plus the OpenAPI routes after the first one,
# with JSON error handlers for HTTPException and any other Exception.
app = Kui(
    routes=routes + openapi[1:],  # Remove the default route
    exception_handlers={
        HTTPException: http_execption_handler,
        Exception: other_exception_handler,
    },
    cors_config={},
)
if __name__ == "__main__":
    # Script entry point: parse options, load the LLaMA worker and the VQ-GAN
    # decoder into module globals (used by inference), warm up, then serve.
    import uvicorn

    args = parse_args()
    args.precision = torch.half if args.half else torch.bfloat16

    # Validate --listen before the slow model loading. rpartition splits on the
    # last colon so "[::1]:8000" works; uvicorn expects the bare IPv6 address.
    host, _, port_text = args.listen.rpartition(":")
    host = host.strip("[]")
    if not host or not port_text.isdigit():
        raise SystemExit(f"--listen must look like HOST:PORT, got {args.listen!r}")
    port = int(port_text)

    # The models live in this process and the app object (not an import
    # string) is given to uvicorn, which cannot start extra workers that way.
    if args.workers > 1:
        logger.warning(
            f"--workers={args.workers} is not supported by this server; using 1 worker."
        )
        args.workers = 1

    logger.info("Loading Llama model...")
    llama_queue = launch_thread_safe_queue(
        checkpoint_path=args.llama_checkpoint_path,
        device=args.device,
        precision=args.precision,
        compile=args.compile,
    )
    logger.info("Llama model loaded, loading VQ-GAN model...")

    decoder_model = load_decoder_model(
        config_name=args.decoder_config_name,
        checkpoint_path=args.decoder_checkpoint_path,
        device=args.device,
    )
    logger.info("VQ-GAN model loaded, warming up...")

    # Dry run to check if the model is loaded correctly and avoid the first-time latency
    list(
        inference(
            InvokeRequest(
                text="Hello world.",
                reference_text=None,
                reference_audio=None,
                max_new_tokens=0,
                top_p=0.7,
                repetition_penalty=1.2,
                temperature=0.7,
                emotion=None,
                format="wav",
                ref_base=None,
                ref_json=None,
            )
        )
    )

    logger.info(f"Warming up done, starting server at http://{args.listen}")
    uvicorn.run(app, host=host, port=port, workers=args.workers, log_level="info")
|