- import io
- import os
- import queue
- import re
- import time
- import traceback
- import wave
- from argparse import ArgumentParser
- from http import HTTPStatus
- from pathlib import Path
- from typing import Annotated, Any
- import librosa
- import numpy as np
- import ormsgpack
- import pyrootutils
- import soundfile as sf
- import torch
- import torchaudio
- from baize.datastructures import ContentType
- from kui.asgi import (
- Body,
- FactoryClass,
- HTTPException,
- HttpRequest,
- HttpView,
- JSONResponse,
- Kui,
- OpenAPI,
- StreamResponse,
- request,
- )
- from kui.asgi.routing import MultimethodRoutes
- from loguru import logger
- from transformers import AutoTokenizer
- pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
- import struct
- from threading import Lock
- import httpx
- from cachetools import LRUCache, cached
- from funasr import AutoModel
- from silero_vad import get_speech_timestamps, load_silero_vad
- from fish_speech.conversation import IM_END_TOKEN, SEMANTIC_TOKEN
- from fish_speech.models.text2semantic.llama import BaseModelArgs
- # from fish_speech.models.vqgan.lit_module import VQGAN
- from fish_speech.models.vqgan.modules.firefly import FireflyArchitecture
- from fish_speech.text.chn_text_norm.text import Text as ChnNormedText
- from fish_speech.utils import autocast_exclude_mps, set_seed
- from tools.file import AUDIO_EXTENSIONS, audio_to_bytes, list_files, read_ref_text
- from tools.llama.generate import (
- GenerateRequest,
- GenerateResponse,
- WrappedGenerateResponse,
- launch_thread_safe_queue,
- launch_thread_safe_queue_agent,
- )
- from tools.schema import (
- GLOBAL_NUM_SAMPLES,
- ASRPackRequest,
- ServeASRRequest,
- ServeASRResponse,
- ServeASRSegment,
- ServeAudioPart,
- ServeForwardMessage,
- ServeMessage,
- ServeRequest,
- ServeResponse,
- ServeStreamDelta,
- ServeStreamResponse,
- ServeTextPart,
- ServeTimedASRResponse,
- ServeTTSRequest,
- ServeVQGANDecodeRequest,
- ServeVQGANDecodeResponse,
- ServeVQGANEncodeRequest,
- ServeVQGANEncodeResponse,
- ServeVQPart,
- )
- from tools.vqgan.inference import load_model as load_decoder_model
- global_lock = Lock()
- # Whether to disable keepalive (which is helpful if the server is in the same cluster)
- DISABLE_KEEPALIVE = os.getenv("DISABLE_KEEPALIVE", "false").lower() == "true"
- async_client = httpx.AsyncClient(
- timeout=120, limits=httpx.Limits(keepalive_expiry=0 if DISABLE_KEEPALIVE else None)
- )
- backends = torchaudio.list_audio_backends()
- if "ffmpeg" in backends:
- backend = "ffmpeg"
- else:
- backend = "soundfile"
- def wav_chunk_header(sample_rate=44100, bit_depth=16, channels=1):
- buffer = io.BytesIO()
- with wave.open(buffer, "wb") as wav_file:
- wav_file.setnchannels(channels)
- wav_file.setsampwidth(bit_depth // 8)
- wav_file.setframerate(sample_rate)
- wav_header_bytes = buffer.getvalue()
- buffer.close()
- return wav_header_bytes
- # Define utils for web server
- async def http_exception_handler(exc: HTTPException):
- return JSONResponse(
- dict(
- statusCode=exc.status_code,
- message=exc.content,
- error=HTTPStatus(exc.status_code).phrase,
- ),
- exc.status_code,
- exc.headers,
- )
- async def other_exception_handler(exc: Exception):
- traceback.print_exc()
- status = HTTPStatus.INTERNAL_SERVER_ERROR
- return JSONResponse(
- dict(statusCode=status, message=str(exc), error=status.phrase),
- status,
- )
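- # Load reference audio from a file path or raw bytes (heuristic: treat it as bytes if it is
- # too long to be a path or the path does not exist), downmix to mono, and resample to `sr`.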
- def load_audio(reference_audio, sr):
- if len(reference_audio) > 255 or not Path(reference_audio).exists():
- audio_data = reference_audio
- reference_audio = io.BytesIO(audio_data)
- waveform, original_sr = torchaudio.load(reference_audio, backend=backend)
- if waveform.shape[0] > 1:
- waveform = torch.mean(waveform, dim=0, keepdim=True)
- if original_sr != sr:
- resampler = torchaudio.transforms.Resample(orig_freq=original_sr, new_freq=sr)
- waveform = resampler(waveform)
- audio = waveform.squeeze().numpy()
- return audio
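- # Encode the reference audio into VQ prompt tokens used to condition generation on a voice.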
- def encode_reference(*, decoder_model, reference_audio, enable_reference_audio):
- if enable_reference_audio and reference_audio is not None:
- # Load audios, and prepare basic info here
- reference_audio_content = load_audio(
- reference_audio, decoder_model.spec_transform.sample_rate
- )
- audios = torch.from_numpy(reference_audio_content).to(decoder_model.device)[
- None, None, :
- ]
- audio_lengths = torch.tensor(
- [audios.shape[2]], device=decoder_model.device, dtype=torch.long
- )
- logger.info(
- f"Loaded audio with {audios.shape[2] / decoder_model.spec_transform.sample_rate:.2f} seconds"
- )
- # VQ Encoder
- if isinstance(decoder_model, FireflyArchitecture):
- prompt_tokens = decoder_model.encode(audios, audio_lengths)[0][0]
- logger.info(f"Encoded prompt: {prompt_tokens.shape}")
- else:
- prompt_tokens = None
- logger.info("No reference audio provided")
- return prompt_tokens
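- # Decode VQ codes back into a waveform with the Firefly VQ-GAN decoder.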
- def decode_vq_tokens(
- *,
- decoder_model,
- codes,
- ):
- feature_lengths = torch.tensor([codes.shape[1]], device=decoder_model.device)
- logger.info(f"VQ features: {codes.shape}")
- if isinstance(decoder_model, FireflyArchitecture):
- # VQGAN Inference
- return decoder_model.decode(
- indices=codes[None],
- feature_lengths=feature_lengths,
- )[0].squeeze()
- raise ValueError(f"Unknown model type: {type(decoder_model)}")
- routes = MultimethodRoutes(base_class=HttpView)
- def get_content_type(audio_format):
- if audio_format == "wav":
- return "audio/wav"
- elif audio_format == "flac":
- return "audio/flac"
- elif audio_format == "mp3":
- return "audio/mpeg"
- else:
- return "application/octet-stream"
- @torch.no_grad()
- @torch.autocast(device_type="cuda", dtype=torch.half)
- def batch_encode(model, audios: list[bytes | torch.Tensor]):
- audios = [
- (
- torch.from_numpy(
- librosa.load(io.BytesIO(audio), sr=model.spec_transform.sample_rate)[0]
- )[None]
- if isinstance(audio, bytes)
- else audio
- )
- for audio in audios
- ]
- # if any(audio.shape[-1] > model.spec_transform.sample_rate * 120 for audio in audios):
- # raise ValueError("Single audio length is too long (>120s)")
- max_length = max(audio.shape[-1] for audio in audios)
- print(f"Encode max length: {max_length / model.spec_transform.sample_rate:.2f}s")
- lengths = torch.tensor([audio.shape[-1] for audio in audios], device=model.device)
- max_length = lengths.max().item()
- padded = torch.stack(
- [
- torch.nn.functional.pad(audio, (0, max_length - audio.shape[-1]))
- for audio in audios
- ]
- ).to(model.device)
- features, feature_lengths = model.encode(padded, audio_lengths=lengths)
- features, feature_lengths = features.cpu(), feature_lengths.cpu()
- return [feature[..., :length] for feature, length in zip(features, feature_lengths)]
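- # Cache encode results keyed by device and audio bytes so identical references are not re-encoded.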
- @cached(
- cache=LRUCache(maxsize=10000),
- key=lambda model, audios: (model.device, tuple(audios)),
- )
- def cached_vqgan_batch_encode(model, audios: list[bytes]):
- return batch_encode(model, audios)
- @routes.http.post("/v1/vqgan/encode")
- def api_vqgan_encode(payload: Annotated[ServeVQGANEncodeRequest, Body(exclusive=True)]):
- start_time = time.time()
- tokens = cached_vqgan_batch_encode(decoder_model, payload.audios)
- logger.info(f"[EXEC] VQGAN encode time: {(time.time() - start_time) * 1000:.2f}ms")
- return ormsgpack.packb(
- ServeVQGANEncodeResponse(tokens=[i.tolist() for i in tokens]),
- option=ormsgpack.OPT_SERIALIZE_PYDANTIC,
- )
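- # Decode VQ features back to waveforms, processing at most 8 items at a time to bound GPU memory.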
- @torch.no_grad()
- @torch.autocast(device_type="cuda", dtype=torch.half)
- def vqgan_decode(model, features):
- lengths = torch.tensor(
- [feature.shape[-1] for feature in features], device=model.device
- )
- max_length = lengths.max().item()
- padded = torch.stack(
- [
- torch.nn.functional.pad(feature, (0, max_length - feature.shape[-1]))
- for feature in features
- ]
- ).to(model.device)
- # If the batch is large, decode in micro-batches of 8 to limit GPU memory use
- audios, audio_lengths = [], []
- for i in range(0, padded.shape[0], 8):
- audio, audio_length = model.decode(
- padded[i : i + 8], feature_lengths=lengths[i : i + 8]
- )
- audios.append(audio)
- audio_lengths.append(audio_length)
- audios = torch.cat(audios, dim=0)
- audio_lengths = torch.cat(audio_lengths, dim=0)
- audios, audio_lengths = audios.cpu(), audio_lengths.cpu()
- return [audio[..., :length].numpy() for audio, length in zip(audios, audio_lengths)]
- @routes.http.post("/v1/vqgan/decode")
- def api_vqgan_decode(payload: Annotated[ServeVQGANDecodeRequest, Body(exclusive=True)]):
- tokens = [torch.tensor(token, dtype=torch.int) for token in payload.tokens]
- start_time = time.time()
- audios = vqgan_decode(decoder_model, tokens)
- logger.info(f"[EXEC] VQGAN decode time: {(time.time() - start_time) * 1000:.2f}ms")
- audios = [audio.astype(np.float16).tobytes() for audio in audios]
- return ormsgpack.packb(
- ServeVQGANDecodeResponse(audios=audios), option=ormsgpack.OPT_SERIALIZE_PYDANTIC
- )
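- # Resample to 16 kHz and run ASR on the whole batch; flag transcripts whose timestamps contain long silent gaps.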
- @torch.no_grad()
- def batch_asr(model, audios, sr, language="auto"):
- resampled_audios = []
- for audio in audios:
- audio = torchaudio.functional.resample(audio, sr, 16000)
- assert audio.ndim == 1
- resampled_audios.append(audio)
- with global_lock:
- res = model.generate(
- input=resampled_audios,
- batch_size=len(resampled_audios),
- language=language,
- use_itn=True,
- )
- results = []
- for r, audio in zip(res, audios):
- text = r["text"]
- text = re.sub(r"<\|.*?\|>", "", text)
- duration = len(audio) / sr * 1000
- huge_gap = False
- if "timestamp" in r and len(r["timestamp"]) > 2:
- for timestamp_a, timestamp_b in zip(
- r["timestamp"][:-1], r["timestamp"][1:]
- ):
- # If there is a gap of more than 5 seconds, we consider it as a huge gap
- if timestamp_b[0] - timestamp_a[1] > 5000:
- huge_gap = True
- break
- # Doesn't make sense to have a huge gap at the end
- if duration - r["timestamp"][-1][1] > 3000:
- huge_gap = True
- results.append(
- {
- "text": text,
- "duration": duration,
- "huge_gap": huge_gap,
- }
- )
- return results
- @routes.http.post("/v1/asr")
- def api_invoke_asr(payload: Annotated[ServeASRRequest, Body(exclusive=True)]):
- start_time = time.time()
- audios = [np.frombuffer(audio, dtype=np.float16) for audio in payload.audios]
- audios = [torch.from_numpy(audio).float() for audio in audios]
- if any(audio.shape[-1] >= 30 * payload.sample_rate for audio in audios):
- raise HTTPException(HTTPStatus.BAD_REQUEST, content="Audio length is too long (max 30 seconds)")
- transcriptions = batch_asr(
- asr_model, audios=audios, sr=payload.sample_rate, language=payload.language
- )
- logger.info(f"[EXEC] ASR time: {(time.time() - start_time) * 1000:.2f}ms")
- return ormsgpack.packb(
- ServeASRResponse(transcriptions=transcriptions),
- option=ormsgpack.OPT_SERIALIZE_PYDANTIC,
- )
- from fish_speech.conversation import Conversation, Message
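- # Turn a chat request into a prompt, stream generated tokens from the LLaMA worker queue,
- # and assemble text/VQ parts per sample (yielding ServeStreamResponse objects when streaming).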
- def execute_request(
- input_queue: queue.Queue,
- tokenizer: AutoTokenizer,
- config: BaseModelArgs,
- request: ServeRequest,
- device: str = "cuda:0",
- ):
- semantic_id, im_end_id = tokenizer.convert_tokens_to_ids(
- [SEMANTIC_TOKEN, IM_END_TOKEN]
- )
- messages = []
- for message in request.messages:
- messages.append(message.to_conversation_message())
- assert len(messages) >= 1, "At least one message is required"
- # assert messages[-1].role == "user", "The last message must be from the user"
- if messages[-1].role == "user":
- messages.append(Message(role="assistant", parts=[], add_im_end=False))
- else:
- assert (
- messages[-1].role == "assistant"
- ), "The last message must be from the assistant"
- messages[-1].add_im_end = False
- conv = Conversation(messages=messages)
- prompt = conv.encode_for_inference(
- tokenizer=tokenizer, num_codebooks=config.num_codebooks
- ).to(device)
- if request.streaming:
- for i in range(request.num_samples):
- yield ServeStreamResponse(
- sample_id=i,
- delta=ServeStreamDelta(
- role="assistant",
- ),
- )
- req = {
- "prompt": prompt,
- "max_new_tokens": request.max_new_tokens,
- "im_end_id": im_end_id,
- "semantic_id": semantic_id,
- "temperature": request.temperature,
- "top_p": request.top_p,
- "repetition_penalty": request.repetition_penalty,
- "num_samples": request.num_samples,
- "early_stop_threshold": request.early_stop_threshold,
- }
- start = time.time()
- response_queue = queue.Queue()
- input_queue.put(GenerateRequest(req, response_queue))
- # Decoding
- decode_buffer = [[] for _ in range(request.num_samples)]
- parts = [[] for _ in range(request.num_samples)]
- def send_reset_buffer(sample_id):
- nonlocal decode_buffer
- if len(decode_buffer[sample_id]) == 0:
- return
- decoded = tokenizer.decode(decode_buffer[sample_id])
- part = ServeTextPart(text=decoded)
- if request.streaming:
- yield ServeStreamResponse(delta=ServeStreamDelta(part=part))
- else:
- parts[sample_id].append(part)
- decode_buffer[sample_id] = []
- # Decode process
- finished = [False for _ in range(request.num_samples)]
- stats = {}
- idx = 0
- while True:
- response = response_queue.get()
- if response in ["stop", "error"]:
- break
- for sample_id, tokens in enumerate(response):
- if finished[sample_id]:
- continue
- if tokens[0] == im_end_id:
- finished[sample_id] = True
- if request.streaming:
- yield from send_reset_buffer(sample_id)
- yield ServeStreamResponse(
- sample_id=sample_id,
- finish_reason="stop",
- stats=stats,
- )
- continue
- if tokens[0] == semantic_id and request.streaming:
- yield from send_reset_buffer(sample_id)
- # Streaming vq
- _tokens = tokens[1:].clone() - 1
- if config.share_codebook_embeddings is False:
- for i in range(len(_tokens)):
- _tokens[i] -= config.codebook_size * i
- yield ServeStreamResponse(
- sample_id=sample_id,
- delta=ServeStreamDelta(part=ServeVQPart(codes=_tokens.tolist())),
- )
- continue
- # Not streaming vq
- if tokens[0] == semantic_id:
- yield from send_reset_buffer(sample_id)
- # Non-streaming vq
- if len(parts[sample_id]) == 0 or not isinstance(
- parts[sample_id][-1], ServeVQPart
- ):
- _tokens = tokens[1:].clone() - 1
- if config.share_codebook_embeddings is False:
- for i in range(len(_tokens)):
- _tokens[i] -= config.codebook_size * i
- parts[sample_id].append(ServeVQPart(codes=_tokens.tolist()))
- else:
- for codebook_id, value in enumerate(tokens[1:, :]):
- val = value.item() - 1
- if config.share_codebook_embeddings is False:
- val -= config.codebook_size * codebook_id
- parts[sample_id][-1].codes[codebook_id].append(val)
- continue
- if tokens[0] != semantic_id:
- # Stream text decode is not supported now
- decode_buffer[sample_id].append(tokens[0, 0])
- if idx == 0:
- stats["time_to_first_token"] = (time.time() - start) * 1000
- idx += 1
- for sample_id in range(request.num_samples):
- yield from send_reset_buffer(sample_id)
- stats["total_time"] = (time.time() - start) * 1000
- stats["total_tokens"] = idx
- if request.streaming:
- for sample_id in range(request.num_samples):
- if finished[sample_id]:
- continue
- yield ServeStreamResponse(
- finish_reason=response, stats=stats, sample_id=sample_id
- )
- return
- yield ServeResponse(
- messages=[
- ServeMessage(role="assistant", parts=parts[i])
- for i in range(request.num_samples)
- ],
- finish_reason=response,
- stats=stats,
- )
- @routes.http.post("/v1/chat")
- def api_invoke_chat(
- req: Annotated[ServeRequest, Body(exclusive=True)],
- ):
- """
- Invoke the model and generate a chat response
- """
- # This makes torch compile happy
- assert (
- req.num_samples == GLOBAL_NUM_SAMPLES
- ), f"num_samples must be {GLOBAL_NUM_SAMPLES}"
- content_type = request.headers.get("Content-Type", "application/json")
- json_mode = "application/json" in content_type
- async def wrapped_generator():
- generator = execute_request(llama_queue, tokenizer, config, req, args.device)
- for i in generator:
- if json_mode:
- body = i.model_dump_json().encode("utf-8")
- yield b"data: " + body + b"\n\n"
- else:
- body = ormsgpack.packb(i, option=ormsgpack.OPT_SERIALIZE_PYDANTIC)
- yield struct.pack("I", len(body)) + body
- # Naive mode
- if req.streaming is False:
- result = next(execute_request(llama_queue, tokenizer, config, req, args.device))
- if json_mode:
- return JSONResponse(result.model_dump())
- else:
- return ormsgpack.packb(result, option=ormsgpack.OPT_SERIALIZE_PYDANTIC)
- return StreamResponse(
- iterable=wrapped_generator(), content_type="text/event-stream"
- )
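- # Full TTS pipeline: resolve reference prompts, run LLaMA semantic generation, then decode
- # VQ codes into audio (streamed as raw PCM chunks or collected into a single waveform).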
- @torch.inference_mode()
- def inference(req: ServeTTSRequest):
- global prompt_tokens, prompt_texts
- idstr: str | None = req.reference_id
- if idstr is not None:
- ref_folder = Path("references") / idstr
- ref_folder.mkdir(parents=True, exist_ok=True)
- ref_audios = list_files(
- ref_folder, AUDIO_EXTENSIONS, recursive=True, sort=False
- )
- if req.use_memory_cache == "never" or (
- req.use_memory_cache == "on-demand" and len(prompt_tokens) == 0
- ):
- prompt_tokens = [
- encode_reference(
- decoder_model=decoder_model,
- reference_audio=audio_to_bytes(str(ref_audio)),
- enable_reference_audio=True,
- )
- for ref_audio in ref_audios
- ]
- prompt_texts = [
- read_ref_text(str(ref_audio.with_suffix(".lab")))
- for ref_audio in ref_audios
- ]
- else:
- logger.info("Use same references")
- else:
- # Parse reference audio aka prompt
- refs = req.references
- if req.use_memory_cache == "never" or (
- req.use_memory_cache == "on-demand" and len(prompt_tokens) == 0
- ):
- prompt_tokens = [
- encode_reference(
- decoder_model=decoder_model,
- reference_audio=ref.audio,
- enable_reference_audio=True,
- )
- for ref in refs
- ]
- prompt_texts = [ref.text for ref in refs]
- else:
- logger.info("Use same references")
- if req.seed is not None:
- set_seed(req.seed)
- logger.warning(f"set seed: {req.seed}")
- # LLAMA Inference
- request = dict(
- device=decoder_model.device,
- max_new_tokens=req.max_new_tokens,
- text=(
- req.text
- if not req.normalize
- else ChnNormedText(raw_text=req.text).normalize()
- ),
- top_p=req.top_p,
- repetition_penalty=req.repetition_penalty,
- temperature=req.temperature,
- compile=args.compile,
- iterative_prompt=req.chunk_length > 0,
- chunk_length=req.chunk_length,
- max_length=4096,
- prompt_tokens=prompt_tokens,
- prompt_text=prompt_texts,
- )
- response_queue = queue.Queue()
- llama_queue.put(
- GenerateRequest(
- request=request,
- response_queue=response_queue,
- )
- )
- if req.streaming:
- yield wav_chunk_header()
- segments = []
- while True:
- result: WrappedGenerateResponse = response_queue.get()
- if result.status == "error":
- raise result.response
- result: GenerateResponse = result.response
- if result.action == "next":
- break
- with autocast_exclude_mps(
- device_type=decoder_model.device.type, dtype=args.precision
- ):
- fake_audios = decode_vq_tokens(
- decoder_model=decoder_model,
- codes=result.codes,
- )
- fake_audios = fake_audios.float().cpu().numpy()
- if req.streaming:
- yield (fake_audios * 32768).astype(np.int16).tobytes()
- else:
- segments.append(fake_audios)
- if req.streaming:
- return
- if len(segments) == 0:
- raise HTTPException(
- HTTPStatus.INTERNAL_SERVER_ERROR,
- content="No audio generated, please check the input text.",
- )
- fake_audios = np.concatenate(segments, axis=0)
- yield fake_audios
- async def inference_async(req: ServeTTSRequest):
- for chunk in inference(req):
- yield chunk
- async def buffer_to_async_generator(buffer):
- yield buffer
- @routes.http.post("/v1/tts")
- async def api_invoke_model(
- req: Annotated[ServeTTSRequest, Body(exclusive=True)],
- ):
- """
- Invoke model and generate audio
- """
- if args.max_text_length > 0 and len(req.text) > args.max_text_length:
- raise HTTPException(
- HTTPStatus.BAD_REQUEST,
- content=f"Text is too long, max length is {args.max_text_length}",
- )
- if req.streaming and req.format != "wav":
- raise HTTPException(
- HTTPStatus.BAD_REQUEST,
- content="Streaming only supports WAV format",
- )
- if req.streaming:
- return StreamResponse(
- iterable=inference_async(req),
- headers={
- "Content-Disposition": f"attachment; filename=audio.{req.format}",
- },
- content_type=get_content_type(req.format),
- )
- else:
- fake_audios = next(inference(req))
- buffer = io.BytesIO()
- sf.write(
- buffer,
- fake_audios,
- decoder_model.spec_transform.sample_rate,
- format=req.format,
- )
- return StreamResponse(
- iterable=buffer_to_async_generator(buffer.getvalue()),
- headers={
- "Content-Disposition": f"attachment; filename=audio.{req.format}",
- },
- content_type=get_content_type(req.format),
- )
- @routes.http.post("/v1/health")
- async def api_health():
- """
- Health check
- """
- return JSONResponse({"status": "ok"})
- def parse_args():
- parser = ArgumentParser()
- parser.add_argument("--mode", type=str, choices=["agent", "tts"], default="tts")
- parser.add_argument("--load-asr-model", action="store_true")
- parser.add_argument(
- "--llama-checkpoint-path",
- type=str,
- default="checkpoints/fish-speech-1.4",
- )
- parser.add_argument(
- "--decoder-checkpoint-path",
- type=str,
- default="checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
- )
- parser.add_argument("--decoder-config-name", type=str, default="firefly_gan_vq")
- parser.add_argument("--device", type=str, default="cuda")
- parser.add_argument("--half", action="store_true")
- parser.add_argument("--compile", action="store_true")
- parser.add_argument("--max-text-length", type=int, default=0)
- parser.add_argument("--listen", type=str, default="127.0.0.1:8080")
- parser.add_argument("--workers", type=int, default=1)
- return parser.parse_args()
- # Define Kui app
- openapi = OpenAPI(
- {
- "title": "Fish Speech API",
- "version": "1.4.2",
- },
- ).routes
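- # Custom request class that accepts either MessagePack or JSON request bodies.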
- class MsgPackRequest(HttpRequest):
- async def data(
- self,
- ) -> Annotated[
- Any, ContentType("application/msgpack"), ContentType("application/json")
- ]:
- if self.content_type == "application/msgpack":
- return ormsgpack.unpackb(await self.body)
- elif self.content_type == "application/json":
- return await self.json
- raise HTTPException(
- HTTPStatus.UNSUPPORTED_MEDIA_TYPE,
- headers={"Accept": "application/msgpack, application/json"},
- )
- app = Kui(
- routes=routes + openapi[1:], # Remove the default route
- exception_handlers={
- HTTPException: http_exception_handler,
- Exception: other_exception_handler,
- },
- factory_class=FactoryClass(http=MsgPackRequest),
- cors_config={},
- )
- def load_asr_model(*, device="cuda", hub="ms"):
- return AutoModel(
- model="iic/SenseVoiceSmall",
- device=device,
- disable_pbar=True,
- hub=hub,
- )
- # Each worker process created by Uvicorn has its own memory space,
- # meaning that models and variables are not shared between processes.
- # Therefore, any global variables (like `llama_queue` or `decoder_model`)
- # will not be shared across workers.
- # Multi-threading for deep learning can cause issues, such as inconsistent
- # outputs if multiple threads access the same buffers simultaneously.
- # Instead, it's better to use multiprocessing or independent models per thread.
- @app.on_startup
- def initialize_app(app: Kui):
- global args, llama_queue, tokenizer, config, decoder_model, vad_model, asr_model, prompt_tokens, prompt_texts
- prompt_tokens, prompt_texts = [], []
- args = parse_args()  # each worker process re-parses the same CLI arguments
- args.precision = torch.half if args.half else torch.bfloat16
- if args.load_asr_model:
- logger.info(f"Loading ASR model...")
- asr_model = load_asr_model(device=args.device)
- logger.info("Loading Llama model...")
- if args.mode == "tts":
- llama_queue = launch_thread_safe_queue(
- checkpoint_path=args.llama_checkpoint_path,
- device=args.device,
- precision=args.precision,
- compile=args.compile,
- )
- else:
- llama_queue, tokenizer, config = launch_thread_safe_queue_agent(
- checkpoint_path=args.llama_checkpoint_path,
- device=args.device,
- precision=args.precision,
- compile=args.compile,
- )
- logger.info("Llama model loaded, loading VQ-GAN model...")
- decoder_model = load_decoder_model(
- config_name=args.decoder_config_name,
- checkpoint_path=args.decoder_checkpoint_path,
- device=args.device,
- )
- logger.info("VQ-GAN model loaded, warming up...")
- vad_model = load_silero_vad()
- logger.info("VAD model loaded, warming up...")
- if args.mode == "tts":
- # Dry run to ensure models work and avoid first-time latency
- list(
- inference(
- ServeTTSRequest(
- text="Hello world.",
- references=[],
- reference_id=None,
- max_new_tokens=0,
- chunk_length=200,
- top_p=0.7,
- repetition_penalty=1.2,
- temperature=0.7,
- emotion=None,
- format="wav",
- )
- )
- )
- logger.info(f"Warming up done, starting server at http://{args.listen}")
- if __name__ == "__main__":
- import uvicorn
- args = parse_args()
- host, port = args.listen.split(":")
- uvicorn.run(
- "tools.api:app",
- host=host,
- port=int(port),
- workers=args.workers,
- log_level="info",
- )
|