|
|
@@ -1,951 +0,0 @@
|
|
|
-import io
|
|
|
-import json
|
|
|
-import os
|
|
|
-import queue
|
|
|
-import re
|
|
|
-import time
|
|
|
-import traceback
|
|
|
-import wave
|
|
|
-from argparse import ArgumentParser
|
|
|
-from http import HTTPStatus
|
|
|
-from pathlib import Path
|
|
|
-from typing import Annotated, Any
|
|
|
-
|
|
|
-import librosa
|
|
|
-import numpy as np
|
|
|
-import ormsgpack
|
|
|
-import pyrootutils
|
|
|
-import soundfile as sf
|
|
|
-import torch
|
|
|
-import torchaudio
|
|
|
-from baize.datastructures import ContentType
|
|
|
-from kui.asgi import (
|
|
|
- Body,
|
|
|
- FactoryClass,
|
|
|
- HTTPException,
|
|
|
- HttpRequest,
|
|
|
- HttpView,
|
|
|
- JSONResponse,
|
|
|
- Kui,
|
|
|
- OpenAPI,
|
|
|
- StreamResponse,
|
|
|
- request,
|
|
|
-)
|
|
|
-from kui.asgi.routing import MultimethodRoutes
|
|
|
-from loguru import logger
|
|
|
-
|
|
|
-pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
|
|
|
-import struct
|
|
|
-from threading import Lock
|
|
|
-
|
|
|
-import httpx
|
|
|
-from cachetools import LRUCache, cached
|
|
|
-from funasr import AutoModel
|
|
|
-from silero_vad import get_speech_timestamps, load_silero_vad
|
|
|
-
|
|
|
-from fish_speech.models.text2semantic.llama import BaseModelArgs
|
|
|
-
|
|
|
-# from fish_speech.models.vqgan.lit_module import VQGAN
|
|
|
-from fish_speech.models.vqgan.modules.firefly import FireflyArchitecture
|
|
|
-from fish_speech.text.chn_text_norm.text import Text as ChnNormedText
|
|
|
-
|
|
|
-# from fish_speech.conversation import IM_END_TOKEN, SEMANTIC_TOKEN
|
|
|
-from fish_speech.tokenizer import IM_END_TOKEN, FishTokenizer
|
|
|
-from fish_speech.utils import autocast_exclude_mps, set_seed
|
|
|
-from tools.file import AUDIO_EXTENSIONS, audio_to_bytes, list_files, read_ref_text
|
|
|
-from tools.llama.generate import (
|
|
|
- GenerateRequest,
|
|
|
- GenerateResponse,
|
|
|
- WrappedGenerateResponse,
|
|
|
- launch_thread_safe_queue,
|
|
|
- launch_thread_safe_queue_agent,
|
|
|
-)
|
|
|
-from tools.schema import (
|
|
|
- GLOBAL_NUM_SAMPLES,
|
|
|
- ASRPackRequest,
|
|
|
- ServeASRRequest,
|
|
|
- ServeASRResponse,
|
|
|
- ServeASRSegment,
|
|
|
- ServeAudioPart,
|
|
|
- ServeForwardMessage,
|
|
|
- ServeMessage,
|
|
|
- ServeRequest,
|
|
|
- ServeResponse,
|
|
|
- ServeStreamDelta,
|
|
|
- ServeStreamResponse,
|
|
|
- ServeTextPart,
|
|
|
- ServeTimedASRResponse,
|
|
|
- ServeTTSRequest,
|
|
|
- ServeVQGANDecodeRequest,
|
|
|
- ServeVQGANDecodeResponse,
|
|
|
- ServeVQGANEncodeRequest,
|
|
|
- ServeVQGANEncodeResponse,
|
|
|
- ServeVQPart,
|
|
|
-)
|
|
|
-from tools.vqgan.inference import load_model as load_decoder_model
|
|
|
-
|
|
|
-global_lock = Lock()
|
|
|
-
|
|
|
-# Whether to disable keepalive (which is helpful if the server is in the same cluster)
|
|
|
-DISABLE_KEEPALIVE = os.getenv("DISABLE_KEEPALIVE", "false").lower() == "true"
|
|
|
-async_client = httpx.AsyncClient(
|
|
|
- timeout=120, limits=httpx.Limits(keepalive_expiry=0 if DISABLE_KEEPALIVE else None)
|
|
|
-)
|
|
|
-backends = torchaudio.list_audio_backends()
|
|
|
-
|
|
|
-if "ffmpeg" in backends:
|
|
|
- backend = "ffmpeg"
|
|
|
-else:
|
|
|
- backend = "soundfile"
|
|
|
-
|
|
|
-
|
|
|
-def wav_chunk_header(sample_rate=44100, bit_depth=16, channels=1):
|
|
|
- buffer = io.BytesIO()
|
|
|
-
|
|
|
- with wave.open(buffer, "wb") as wav_file:
|
|
|
- wav_file.setnchannels(channels)
|
|
|
- wav_file.setsampwidth(bit_depth // 8)
|
|
|
- wav_file.setframerate(sample_rate)
|
|
|
-
|
|
|
- wav_header_bytes = buffer.getvalue()
|
|
|
- buffer.close()
|
|
|
- return wav_header_bytes
|
|
|
-
|
|
|
-
|
|
|
-# Define utils for web server
|
|
|
-async def http_execption_handler(exc: HTTPException):
|
|
|
- return JSONResponse(
|
|
|
- dict(
|
|
|
- statusCode=exc.status_code,
|
|
|
- message=exc.content,
|
|
|
- error=HTTPStatus(exc.status_code).phrase,
|
|
|
- ),
|
|
|
- exc.status_code,
|
|
|
- exc.headers,
|
|
|
- )
|
|
|
-
|
|
|
-
|
|
|
-async def other_exception_handler(exc: "Exception"):
|
|
|
- traceback.print_exc()
|
|
|
-
|
|
|
- status = HTTPStatus.INTERNAL_SERVER_ERROR
|
|
|
- return JSONResponse(
|
|
|
- dict(statusCode=status, message=str(exc), error=status.phrase),
|
|
|
- status,
|
|
|
- )
|
|
|
-
|
|
|
-
|
|
|
-def load_audio(reference_audio, sr):
|
|
|
- if len(reference_audio) > 255 or not Path(reference_audio).exists():
|
|
|
- audio_data = reference_audio
|
|
|
- reference_audio = io.BytesIO(audio_data)
|
|
|
-
|
|
|
- waveform, original_sr = torchaudio.load(reference_audio, backend=backend)
|
|
|
-
|
|
|
- if waveform.shape[0] > 1:
|
|
|
- waveform = torch.mean(waveform, dim=0, keepdim=True)
|
|
|
-
|
|
|
- if original_sr != sr:
|
|
|
- resampler = torchaudio.transforms.Resample(orig_freq=original_sr, new_freq=sr)
|
|
|
- waveform = resampler(waveform)
|
|
|
-
|
|
|
- audio = waveform.squeeze().numpy()
|
|
|
- return audio
|
|
|
-
|
|
|
-
|
|
|
-def encode_reference(*, decoder_model, reference_audio, enable_reference_audio):
|
|
|
- if enable_reference_audio and reference_audio is not None:
|
|
|
- # Load audios, and prepare basic info here
|
|
|
- reference_audio_content = load_audio(
|
|
|
- reference_audio, decoder_model.spec_transform.sample_rate
|
|
|
- )
|
|
|
-
|
|
|
- audios = torch.from_numpy(reference_audio_content).to(decoder_model.device)[
|
|
|
- None, None, :
|
|
|
- ]
|
|
|
- audio_lengths = torch.tensor(
|
|
|
- [audios.shape[2]], device=decoder_model.device, dtype=torch.long
|
|
|
- )
|
|
|
- logger.info(
|
|
|
- f"Loaded audio with {audios.shape[2] / decoder_model.spec_transform.sample_rate:.2f} seconds"
|
|
|
- )
|
|
|
-
|
|
|
- # VQ Encoder
|
|
|
- if isinstance(decoder_model, FireflyArchitecture):
|
|
|
- prompt_tokens = decoder_model.encode(audios, audio_lengths)[0][0]
|
|
|
-
|
|
|
- logger.info(f"Encoded prompt: {prompt_tokens.shape}")
|
|
|
- else:
|
|
|
- prompt_tokens = None
|
|
|
- logger.info("No reference audio provided")
|
|
|
-
|
|
|
- return prompt_tokens
|
|
|
-
|
|
|
-
|
|
|
-def decode_vq_tokens(
|
|
|
- *,
|
|
|
- decoder_model,
|
|
|
- codes,
|
|
|
-):
|
|
|
- feature_lengths = torch.tensor([codes.shape[1]], device=decoder_model.device)
|
|
|
- logger.info(f"VQ features: {codes.shape}")
|
|
|
-
|
|
|
- if isinstance(decoder_model, FireflyArchitecture):
|
|
|
- # VQGAN Inference
|
|
|
- return decoder_model.decode(
|
|
|
- indices=codes[None],
|
|
|
- feature_lengths=feature_lengths,
|
|
|
- )[0].squeeze()
|
|
|
-
|
|
|
- raise ValueError(f"Unknown model type: {type(decoder_model)}")
|
|
|
-
|
|
|
-
|
|
|
-routes = MultimethodRoutes(base_class=HttpView)
|
|
|
-
|
|
|
-
|
|
|
-def get_content_type(audio_format):
|
|
|
- if audio_format == "wav":
|
|
|
- return "audio/wav"
|
|
|
- elif audio_format == "flac":
|
|
|
- return "audio/flac"
|
|
|
- elif audio_format == "mp3":
|
|
|
- return "audio/mpeg"
|
|
|
- else:
|
|
|
- return "application/octet-stream"
|
|
|
-
|
|
|
-
|
|
|
-@torch.no_grad()
|
|
|
-@torch.autocast(device_type="cuda", dtype=torch.half)
|
|
|
-def batch_encode(model, audios: list[bytes | torch.Tensor]):
|
|
|
- audios = [
|
|
|
- (
|
|
|
- torch.from_numpy(
|
|
|
- librosa.load(io.BytesIO(audio), sr=model.spec_transform.sample_rate)[0]
|
|
|
- )[None]
|
|
|
- if isinstance(audio, bytes)
|
|
|
- else audio
|
|
|
- )
|
|
|
- for audio in audios
|
|
|
- ]
|
|
|
-
|
|
|
- # if any(audio.shape[-1] > model.spec_transform.sample_rate * 120 for audio in audios):
|
|
|
- # raise ValueError("Single audio length is too long (>120s)")
|
|
|
-
|
|
|
- max_length = max(audio.shape[-1] for audio in audios)
|
|
|
- print(f"Encode max length: {max_length / model.spec_transform.sample_rate:.2f}s")
|
|
|
-
|
|
|
- lengths = torch.tensor([audio.shape[-1] for audio in audios], device=model.device)
|
|
|
- max_length = lengths.max().item()
|
|
|
- padded = torch.stack(
|
|
|
- [
|
|
|
- torch.nn.functional.pad(audio, (0, max_length - audio.shape[-1]))
|
|
|
- for audio in audios
|
|
|
- ]
|
|
|
- ).to(model.device)
|
|
|
-
|
|
|
- features, feature_lengths = model.encode(padded, audio_lengths=lengths)
|
|
|
- features, feature_lengths = features.cpu(), feature_lengths.cpu()
|
|
|
-
|
|
|
- return [feature[..., :length] for feature, length in zip(features, feature_lengths)]
|
|
|
-
|
|
|
-
|
|
|
-@cached(
|
|
|
- cache=LRUCache(maxsize=10000),
|
|
|
- key=lambda model, audios: (model.device, tuple(audios)),
|
|
|
-)
|
|
|
-def cached_vqgan_batch_encode(model, audios: list[bytes]):
|
|
|
- return batch_encode(model, audios)
|
|
|
-
|
|
|
-
|
|
|
-@routes.http.post("/v1/vqgan/encode")
|
|
|
-def api_vqgan_encode(payload: Annotated[ServeVQGANEncodeRequest, Body(exclusive=True)]):
|
|
|
-
|
|
|
- start_time = time.time()
|
|
|
- tokens = cached_vqgan_batch_encode(decoder_model, payload.audios)
|
|
|
- logger.info(f"[EXEC] VQGAN encode time: {(time.time() - start_time) * 1000:.2f}ms")
|
|
|
-
|
|
|
- return ormsgpack.packb(
|
|
|
- ServeVQGANEncodeResponse(tokens=[i.tolist() for i in tokens]),
|
|
|
- option=ormsgpack.OPT_SERIALIZE_PYDANTIC,
|
|
|
- )
|
|
|
-
|
|
|
-
|
|
|
-@torch.no_grad()
|
|
|
-@torch.autocast(device_type="cuda", dtype=torch.half)
|
|
|
-def vqgan_decode(model, features):
|
|
|
- lengths = torch.tensor(
|
|
|
- [feature.shape[-1] for feature in features], device=model.device
|
|
|
- )
|
|
|
- max_length = lengths.max().item()
|
|
|
- padded = torch.stack(
|
|
|
- [
|
|
|
- torch.nn.functional.pad(feature, (0, max_length - feature.shape[-1]))
|
|
|
- for feature in features
|
|
|
- ]
|
|
|
- ).to(model.device)
|
|
|
-
|
|
|
- # If bs too large, we do micro batch decode
|
|
|
- audios, audio_lengths = [], []
|
|
|
- for i in range(0, padded.shape[0], 8):
|
|
|
- audio, audio_length = model.decode(
|
|
|
- padded[i : i + 8], feature_lengths=lengths[i : i + 8]
|
|
|
- )
|
|
|
- audios.append(audio)
|
|
|
- audio_lengths.append(audio_length)
|
|
|
- audios = torch.cat(audios, dim=0)
|
|
|
- audio_lengths = torch.cat(audio_lengths, dim=0)
|
|
|
- audios, audio_lengths = audios.cpu(), audio_lengths.cpu()
|
|
|
-
|
|
|
- return [audio[..., :length].numpy() for audio, length in zip(audios, audio_lengths)]
|
|
|
-
|
|
|
-
|
|
|
-@routes.http.post("/v1/vqgan/decode")
|
|
|
-def api_vqgan_decode(payload: Annotated[ServeVQGANDecodeRequest, Body(exclusive=True)]):
|
|
|
- tokens = [torch.tensor(token, dtype=torch.int) for token in payload.tokens]
|
|
|
- start_time = time.time()
|
|
|
- audios = vqgan_decode(decoder_model, tokens)
|
|
|
- logger.info(f"[EXEC] VQGAN decode time: {(time.time() - start_time) * 1000:.2f}ms")
|
|
|
- audios = [audio.astype(np.float16).tobytes() for audio in audios]
|
|
|
- return ormsgpack.packb(
|
|
|
- ServeVQGANDecodeResponse(audios=audios), option=ormsgpack.OPT_SERIALIZE_PYDANTIC
|
|
|
- )
|
|
|
-
|
|
|
-
|
|
|
-@torch.no_grad()
|
|
|
-def batch_asr(model, audios, sr, language="auto"):
|
|
|
- resampled_audios = []
|
|
|
- for audio in audios:
|
|
|
- audio = torchaudio.functional.resample(audio, sr, 16000)
|
|
|
- assert audio.ndim == 1
|
|
|
- resampled_audios.append(audio)
|
|
|
-
|
|
|
- with global_lock:
|
|
|
- res = model.generate(
|
|
|
- input=resampled_audios,
|
|
|
- batch_size=len(resampled_audios),
|
|
|
- language=language,
|
|
|
- use_itn=True,
|
|
|
- )
|
|
|
-
|
|
|
- results = []
|
|
|
- for r, audio in zip(res, audios):
|
|
|
- text = r["text"]
|
|
|
- text = re.sub(r"<\|.*?\|>", "", text)
|
|
|
- duration = len(audio) / sr * 1000
|
|
|
- huge_gap = False
|
|
|
-
|
|
|
- if "timestamp" in r and len(r["timestamp"]) > 2:
|
|
|
- for timestamp_a, timestamp_b in zip(
|
|
|
- r["timestamp"][:-1], r["timestamp"][1:]
|
|
|
- ):
|
|
|
- # If there is a gap of more than 5 seconds, we consider it as a huge gap
|
|
|
- if timestamp_b[0] - timestamp_a[1] > 5000:
|
|
|
- huge_gap = True
|
|
|
- break
|
|
|
-
|
|
|
- # Doesn't make sense to have a huge gap at the end
|
|
|
- if duration - r["timestamp"][-1][1] > 3000:
|
|
|
- huge_gap = True
|
|
|
-
|
|
|
- results.append(
|
|
|
- {
|
|
|
- "text": text,
|
|
|
- "duration": duration,
|
|
|
- "huge_gap": huge_gap,
|
|
|
- }
|
|
|
- )
|
|
|
-
|
|
|
- return results
|
|
|
-
|
|
|
-
|
|
|
-@routes.http.post("/v1/asr")
|
|
|
-def api_invoke_asr(payload: Annotated[ServeASRRequest, Body(exclusive=True)]):
|
|
|
- start_time = time.time()
|
|
|
- audios = [np.frombuffer(audio, dtype=np.float16) for audio in payload.audios]
|
|
|
- audios = [torch.from_numpy(audio).float() for audio in audios]
|
|
|
-
|
|
|
- if any(audios.shape[-1] >= 30 * payload.sample_rate for audios in audios):
|
|
|
- raise HTTPException(status_code=400, detail="Audio length is too long")
|
|
|
-
|
|
|
- transcriptions = batch_asr(
|
|
|
- asr_model, audios=audios, sr=payload.sample_rate, language=payload.language
|
|
|
- )
|
|
|
- logger.info(f"[EXEC] ASR time: {(time.time() - start_time) * 1000:.2f}ms")
|
|
|
-
|
|
|
- return ormsgpack.packb(
|
|
|
- ServeASRResponse(transcriptions=transcriptions),
|
|
|
- option=ormsgpack.OPT_SERIALIZE_PYDANTIC,
|
|
|
- )
|
|
|
-
|
|
|
-
|
|
|
-from fish_speech.conversation import Conversation, Message
|
|
|
-
|
|
|
-
|
|
|
-def execute_request(
|
|
|
- input_queue: queue.Queue,
|
|
|
- tokenizer: FishTokenizer,
|
|
|
- config: BaseModelArgs,
|
|
|
- request: ServeRequest,
|
|
|
- device: str = "cuda:0",
|
|
|
-):
|
|
|
-
|
|
|
- im_end_id = tokenizer.get_token_id(IM_END_TOKEN)
|
|
|
- messages = []
|
|
|
- for message in request.messages:
|
|
|
- messages.append(message.to_conversation_message())
|
|
|
-
|
|
|
- assert len(messages) >= 1, "At least one message is required"
|
|
|
- # assert messages[-1].role == "user", "The last message must be from the user"
|
|
|
-
|
|
|
- if messages[-1].role == "user":
|
|
|
- messages.append(
|
|
|
- Message(role="assistant", parts=[], add_im_end=False, modality="voice")
|
|
|
- )
|
|
|
- elif messages[-1].role == "raw":
|
|
|
- messages[-1].add_im_start = False
|
|
|
- messages[-1].add_im_end = False
|
|
|
- messages[-1].modality = "voice"
|
|
|
- else:
|
|
|
- assert (
|
|
|
- messages[-1].role == "assistant"
|
|
|
- ), "The last message must be from the assistant"
|
|
|
- messages[-1].add_im_end = False
|
|
|
-
|
|
|
- conv = Conversation(messages=messages)
|
|
|
-
|
|
|
- # conv.visualize(tokenizer)
|
|
|
- prompt = conv.encode_for_inference(
|
|
|
- tokenizer=tokenizer, num_codebooks=config.num_codebooks
|
|
|
- ).to(device)
|
|
|
-
|
|
|
- if request.streaming:
|
|
|
- for i in range(request.num_samples):
|
|
|
- yield ServeStreamResponse(
|
|
|
- sample_id=i,
|
|
|
- delta=ServeStreamDelta(
|
|
|
- role="assistant",
|
|
|
- ),
|
|
|
- )
|
|
|
-
|
|
|
- req = {
|
|
|
- "prompt": prompt,
|
|
|
- "max_new_tokens": request.max_new_tokens,
|
|
|
- "im_end_id": im_end_id,
|
|
|
- "temperature": request.temperature,
|
|
|
- "top_p": request.top_p,
|
|
|
- "repetition_penalty": request.repetition_penalty,
|
|
|
- "num_samples": request.num_samples,
|
|
|
- "early_stop_threshold": request.early_stop_threshold,
|
|
|
- }
|
|
|
-
|
|
|
- start = time.time()
|
|
|
- response_queue = queue.Queue()
|
|
|
- input_queue.put(GenerateRequest(req, response_queue))
|
|
|
-
|
|
|
- # Decoding
|
|
|
- decode_buffer = [[] for _ in range(request.num_samples)]
|
|
|
- parts = [[] for _ in range(request.num_samples)]
|
|
|
-
|
|
|
- def send_reset_buffer(sample_id):
|
|
|
- nonlocal decode_buffer
|
|
|
- if len(decode_buffer[sample_id]) == 0:
|
|
|
- return
|
|
|
-
|
|
|
- decoded = tokenizer.decode(decode_buffer[sample_id])
|
|
|
- part = ServeTextPart(text=decoded)
|
|
|
-
|
|
|
- if request.streaming:
|
|
|
- yield ServeStreamResponse(delta=ServeStreamDelta(part=part))
|
|
|
- else:
|
|
|
- parts[sample_id].append(part)
|
|
|
-
|
|
|
- decode_buffer[sample_id] = []
|
|
|
-
|
|
|
- # Decode process
|
|
|
- finished = [False for _ in range(request.num_samples)]
|
|
|
- stats = {}
|
|
|
- idx = 0
|
|
|
- while True:
|
|
|
- response = response_queue.get()
|
|
|
-
|
|
|
- if response in ["stop", "error"]:
|
|
|
- break
|
|
|
-
|
|
|
- for sample_id, tokens in enumerate(response):
|
|
|
- if finished[sample_id]:
|
|
|
- continue
|
|
|
-
|
|
|
- if tokens[0] == im_end_id:
|
|
|
- finished[sample_id] = True
|
|
|
- if request.streaming:
|
|
|
- yield from send_reset_buffer(sample_id)
|
|
|
- yield ServeStreamResponse(
|
|
|
- sample_id=sample_id,
|
|
|
- finish_reason="stop",
|
|
|
- stats=stats,
|
|
|
- )
|
|
|
- continue
|
|
|
-
|
|
|
- is_semantic = (
|
|
|
- tokenizer.semantic_begin_id <= tokens[0] <= tokenizer.semantic_end_id
|
|
|
- )
|
|
|
- if is_semantic and request.streaming:
|
|
|
- yield from send_reset_buffer(sample_id)
|
|
|
- # Streaming vq
|
|
|
- _tokens = tokens[1:].clone()
|
|
|
-
|
|
|
- if config.share_codebook_embeddings is False:
|
|
|
- for i in range(len(_tokens)):
|
|
|
- _tokens[i] -= config.codebook_size * i
|
|
|
-
|
|
|
- yield ServeStreamResponse(
|
|
|
- sample_id=sample_id,
|
|
|
- delta=ServeStreamDelta(part=ServeVQPart(codes=_tokens.tolist())),
|
|
|
- )
|
|
|
- continue
|
|
|
-
|
|
|
- # Not streaming vq
|
|
|
- if is_semantic:
|
|
|
- yield from send_reset_buffer(sample_id)
|
|
|
- # None streaming vq
|
|
|
- if len(parts[sample_id]) == 0 or not isinstance(
|
|
|
- parts[sample_id][-1], ServeVQPart
|
|
|
- ):
|
|
|
- _tokens = tokens[1:].clone()
|
|
|
-
|
|
|
- if config.share_codebook_embeddings is False:
|
|
|
- for i in range(len(_tokens)):
|
|
|
- _tokens[i] -= config.codebook_size * i
|
|
|
-
|
|
|
- parts[sample_id].append(ServeVQPart(codes=_tokens.tolist()))
|
|
|
- else:
|
|
|
- for codebook_id, value in enumerate(tokens[1:, :]):
|
|
|
- val = value.item()
|
|
|
- if config.share_codebook_embeddings is False:
|
|
|
- val -= config.codebook_size * codebook_id
|
|
|
-
|
|
|
- parts[sample_id][-1].codes[codebook_id].append(val)
|
|
|
- continue
|
|
|
-
|
|
|
- if not is_semantic:
|
|
|
- # Stream text decode is not supported now
|
|
|
- decode_buffer[sample_id].append(tokens[0, 0])
|
|
|
-
|
|
|
- if idx == 0:
|
|
|
- stats["time_to_first_token"] = (time.time() - start) * 1000
|
|
|
-
|
|
|
- idx += 1
|
|
|
-
|
|
|
- for sample_id in range(request.num_samples):
|
|
|
- yield from send_reset_buffer(sample_id)
|
|
|
-
|
|
|
- stats["total_time"] = (time.time() - start) * 1000
|
|
|
- stats["total_tokens"] = idx
|
|
|
-
|
|
|
- if request.streaming:
|
|
|
- for sample_id in range(request.num_samples):
|
|
|
- if finished[sample_id]:
|
|
|
- continue
|
|
|
- yield ServeStreamResponse(
|
|
|
- finish_reason=response, stats=stats, sample_id=sample_id
|
|
|
- )
|
|
|
- return
|
|
|
-
|
|
|
- yield ServeResponse(
|
|
|
- messages=[
|
|
|
- ServeMessage(role="assistant", parts=parts[i])
|
|
|
- for i in range(request.num_samples)
|
|
|
- ],
|
|
|
- finish_reason=response,
|
|
|
- stats=stats,
|
|
|
- )
|
|
|
-
|
|
|
-
|
|
|
-@routes.http.post("/v1/chat")
|
|
|
-def api_invoke_chat(
|
|
|
- req: Annotated[ServeRequest, Body(exclusive=True)],
|
|
|
-):
|
|
|
- """
|
|
|
- Invoke model and generate audio
|
|
|
- """
|
|
|
-
|
|
|
- # This makes torch compile happy
|
|
|
- assert (
|
|
|
- req.num_samples == GLOBAL_NUM_SAMPLES
|
|
|
- ), f"num_samples must be {GLOBAL_NUM_SAMPLES}"
|
|
|
-
|
|
|
- content_type = request.headers.get("Content-Type", "application/json")
|
|
|
- json_mode = "application/json" in content_type
|
|
|
-
|
|
|
- async def wrapped_generator():
|
|
|
- generator = execute_request(llama_queue, tokenizer, config, req, args.device)
|
|
|
-
|
|
|
- for i in generator:
|
|
|
- if json_mode:
|
|
|
- body = i.model_dump_json().encode("utf-8")
|
|
|
- yield b"data: " + body + b"\n\n"
|
|
|
- else:
|
|
|
- body = ormsgpack.packb(i, option=ormsgpack.OPT_SERIALIZE_PYDANTIC)
|
|
|
- yield struct.pack("I", len(body)) + body
|
|
|
-
|
|
|
- # Naive mode
|
|
|
- if req.streaming is False:
|
|
|
- result = next(execute_request(llama_queue, tokenizer, config, req, args.device))
|
|
|
-
|
|
|
- if json_mode:
|
|
|
- return JSONResponse(result.model_dump())
|
|
|
- else:
|
|
|
- return ormsgpack.packb(result, option=ormsgpack.OPT_SERIALIZE_PYDANTIC)
|
|
|
-
|
|
|
- return StreamResponse(
|
|
|
- iterable=wrapped_generator(), content_type="text/event-stream"
|
|
|
- )
|
|
|
-
|
|
|
-
|
|
|
-@torch.inference_mode()
|
|
|
-def inference(req: ServeTTSRequest):
|
|
|
-
|
|
|
- idstr: str | None = req.reference_id
|
|
|
- if idstr is not None:
|
|
|
- ref_folder = Path("references") / idstr
|
|
|
- ref_folder.mkdir(parents=True, exist_ok=True)
|
|
|
- ref_audios = list_files(
|
|
|
- ref_folder, AUDIO_EXTENSIONS, recursive=True, sort=False
|
|
|
- )
|
|
|
-
|
|
|
- if req.use_memory_cache == "never" or (
|
|
|
- req.use_memory_cache == "on-demand" and len(prompt_tokens) == 0
|
|
|
- ):
|
|
|
- prompt_tokens = [
|
|
|
- encode_reference(
|
|
|
- decoder_model=decoder_model,
|
|
|
- reference_audio=audio_to_bytes(str(ref_audio)),
|
|
|
- enable_reference_audio=True,
|
|
|
- )
|
|
|
- for ref_audio in ref_audios
|
|
|
- ]
|
|
|
- prompt_texts = [
|
|
|
- read_ref_text(str(ref_audio.with_suffix(".lab")))
|
|
|
- for ref_audio in ref_audios
|
|
|
- ]
|
|
|
- else:
|
|
|
- logger.info("Use same references")
|
|
|
-
|
|
|
- else:
|
|
|
- # Parse reference audio aka prompt
|
|
|
- refs = req.references
|
|
|
-
|
|
|
- if req.use_memory_cache == "never" or (
|
|
|
- req.use_memory_cache == "on-demand" and len(prompt_tokens) == 0
|
|
|
- ):
|
|
|
- prompt_tokens = [
|
|
|
- encode_reference(
|
|
|
- decoder_model=decoder_model,
|
|
|
- reference_audio=ref.audio,
|
|
|
- enable_reference_audio=True,
|
|
|
- )
|
|
|
- for ref in refs
|
|
|
- ]
|
|
|
- prompt_texts = [ref.text for ref in refs]
|
|
|
- else:
|
|
|
- logger.info("Use same references")
|
|
|
-
|
|
|
- if req.seed is not None:
|
|
|
- set_seed(req.seed)
|
|
|
- logger.warning(f"set seed: {req.seed}")
|
|
|
-
|
|
|
- # LLAMA Inference
|
|
|
- request = dict(
|
|
|
- device=decoder_model.device,
|
|
|
- max_new_tokens=req.max_new_tokens,
|
|
|
- text=(
|
|
|
- req.text
|
|
|
- if not req.normalize
|
|
|
- else ChnNormedText(raw_text=req.text).normalize()
|
|
|
- ),
|
|
|
- top_p=req.top_p,
|
|
|
- repetition_penalty=req.repetition_penalty,
|
|
|
- temperature=req.temperature,
|
|
|
- compile=args.compile,
|
|
|
- iterative_prompt=req.chunk_length > 0,
|
|
|
- chunk_length=req.chunk_length,
|
|
|
- max_length=4096,
|
|
|
- prompt_tokens=prompt_tokens,
|
|
|
- prompt_text=prompt_texts,
|
|
|
- )
|
|
|
-
|
|
|
- response_queue = queue.Queue()
|
|
|
- llama_queue.put(
|
|
|
- GenerateRequest(
|
|
|
- request=request,
|
|
|
- response_queue=response_queue,
|
|
|
- )
|
|
|
- )
|
|
|
-
|
|
|
- if req.streaming:
|
|
|
- yield wav_chunk_header()
|
|
|
-
|
|
|
- segments = []
|
|
|
- while True:
|
|
|
- result: WrappedGenerateResponse = response_queue.get()
|
|
|
- if result.status == "error":
|
|
|
- raise result.response
|
|
|
- break
|
|
|
-
|
|
|
- result: GenerateResponse = result.response
|
|
|
- if result.action == "next":
|
|
|
- break
|
|
|
-
|
|
|
- with autocast_exclude_mps(
|
|
|
- device_type=decoder_model.device.type, dtype=args.precision
|
|
|
- ):
|
|
|
- fake_audios = decode_vq_tokens(
|
|
|
- decoder_model=decoder_model,
|
|
|
- codes=result.codes,
|
|
|
- )
|
|
|
-
|
|
|
- fake_audios = fake_audios.float().cpu().numpy()
|
|
|
-
|
|
|
- if req.streaming:
|
|
|
- yield (fake_audios * 32768).astype(np.int16).tobytes()
|
|
|
- else:
|
|
|
- segments.append(fake_audios)
|
|
|
-
|
|
|
- if req.streaming:
|
|
|
- return
|
|
|
-
|
|
|
- if len(segments) == 0:
|
|
|
- raise HTTPException(
|
|
|
- HTTPStatus.INTERNAL_SERVER_ERROR,
|
|
|
- content="No audio generated, please check the input text.",
|
|
|
- )
|
|
|
-
|
|
|
- fake_audios = np.concatenate(segments, axis=0)
|
|
|
- yield fake_audios
|
|
|
-
|
|
|
-
|
|
|
-async def inference_async(req: ServeTTSRequest):
|
|
|
- for chunk in inference(req):
|
|
|
- yield chunk
|
|
|
-
|
|
|
-
|
|
|
-async def buffer_to_async_generator(buffer):
|
|
|
- yield buffer
|
|
|
-
|
|
|
-
|
|
|
-@routes.http.post("/v1/tts")
|
|
|
-async def api_invoke_model(
|
|
|
- req: Annotated[ServeTTSRequest, Body(exclusive=True)],
|
|
|
-):
|
|
|
- """
|
|
|
- Invoke model and generate audio
|
|
|
- """
|
|
|
-
|
|
|
- if args.max_text_length > 0 and len(req.text) > args.max_text_length:
|
|
|
- raise HTTPException(
|
|
|
- HTTPStatus.BAD_REQUEST,
|
|
|
- content=f"Text is too long, max length is {args.max_text_length}",
|
|
|
- )
|
|
|
-
|
|
|
- if req.streaming and req.format != "wav":
|
|
|
- raise HTTPException(
|
|
|
- HTTPStatus.BAD_REQUEST,
|
|
|
- content="Streaming only supports WAV format",
|
|
|
- )
|
|
|
-
|
|
|
- if req.streaming:
|
|
|
- return StreamResponse(
|
|
|
- iterable=inference_async(req),
|
|
|
- headers={
|
|
|
- "Content-Disposition": f"attachment; filename=audio.{req.format}",
|
|
|
- },
|
|
|
- content_type=get_content_type(req.format),
|
|
|
- )
|
|
|
- else:
|
|
|
- fake_audios = next(inference(req))
|
|
|
- buffer = io.BytesIO()
|
|
|
- sf.write(
|
|
|
- buffer,
|
|
|
- fake_audios,
|
|
|
- decoder_model.spec_transform.sample_rate,
|
|
|
- format=req.format,
|
|
|
- )
|
|
|
-
|
|
|
- return StreamResponse(
|
|
|
- iterable=buffer_to_async_generator(buffer.getvalue()),
|
|
|
- headers={
|
|
|
- "Content-Disposition": f"attachment; filename=audio.{req.format}",
|
|
|
- },
|
|
|
- content_type=get_content_type(req.format),
|
|
|
- )
|
|
|
-
|
|
|
-
|
|
|
-@routes.http.post("/v1/health")
|
|
|
-async def api_health():
|
|
|
- """
|
|
|
- Health check
|
|
|
- """
|
|
|
- return JSONResponse({"status": "ok"})
|
|
|
-
|
|
|
-
|
|
|
-def parse_args():
|
|
|
- parser = ArgumentParser()
|
|
|
- parser.add_argument("--mode", type=str, choices=["agent", "tts"], default="tts")
|
|
|
- parser.add_argument("--load-asr-model", action="store_true")
|
|
|
- parser.add_argument(
|
|
|
- "--llama-checkpoint-path",
|
|
|
- type=str,
|
|
|
- default="checkpoints/fish-speech-1.4",
|
|
|
- )
|
|
|
- parser.add_argument(
|
|
|
- "--decoder-checkpoint-path",
|
|
|
- type=str,
|
|
|
- default="checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
|
|
|
- )
|
|
|
- parser.add_argument("--decoder-config-name", type=str, default="firefly_gan_vq")
|
|
|
- parser.add_argument("--device", type=str, default="cuda")
|
|
|
- parser.add_argument("--half", action="store_true")
|
|
|
- parser.add_argument("--compile", action="store_true")
|
|
|
- parser.add_argument("--max-text-length", type=int, default=0)
|
|
|
- parser.add_argument("--listen", type=str, default="127.0.0.1:8080")
|
|
|
- parser.add_argument("--workers", type=int, default=1)
|
|
|
-
|
|
|
- return parser.parse_args()
|
|
|
-
|
|
|
-
|
|
|
-# Define Kui app
|
|
|
-openapi = OpenAPI(
|
|
|
- {
|
|
|
- "title": "Fish Speech API",
|
|
|
- "version": "1.4.2",
|
|
|
- },
|
|
|
-).routes
|
|
|
-
|
|
|
-
|
|
|
-class MsgPackRequest(HttpRequest):
|
|
|
- async def data(
|
|
|
- self,
|
|
|
- ) -> Annotated[
|
|
|
- Any, ContentType("application/msgpack"), ContentType("application/json")
|
|
|
- ]:
|
|
|
- if self.content_type == "application/msgpack":
|
|
|
- return ormsgpack.unpackb(await self.body)
|
|
|
-
|
|
|
- elif self.content_type == "application/json":
|
|
|
- return await self.json
|
|
|
-
|
|
|
- raise HTTPException(
|
|
|
- HTTPStatus.UNSUPPORTED_MEDIA_TYPE,
|
|
|
- headers={"Accept": "application/msgpack, application/json"},
|
|
|
- )
|
|
|
-
|
|
|
-
|
|
|
-app = Kui(
|
|
|
- routes=routes + openapi[1:], # Remove the default route
|
|
|
- exception_handlers={
|
|
|
- HTTPException: http_execption_handler,
|
|
|
- Exception: other_exception_handler,
|
|
|
- },
|
|
|
- factory_class=FactoryClass(http=MsgPackRequest),
|
|
|
- cors_config={},
|
|
|
-)
|
|
|
-
|
|
|
-
|
|
|
-def load_asr_model(*, device="cuda", hub="ms"):
|
|
|
- return AutoModel(
|
|
|
- model="iic/SenseVoiceSmall",
|
|
|
- device=device,
|
|
|
- disable_pbar=True,
|
|
|
- hub=hub,
|
|
|
- )
|
|
|
-
|
|
|
-
|
|
|
-# Each worker process created by Uvicorn has its own memory space,
|
|
|
-# meaning that models and variables are not shared between processes.
|
|
|
-# Therefore, any global variables (like `llama_queue` or `decoder_model`)
|
|
|
-# will not be shared across workers.
|
|
|
-
|
|
|
-
|
|
|
-# Multi-threading for deep learning can cause issues, such as inconsistent
|
|
|
-# outputs if multiple threads access the same buffers simultaneously.
|
|
|
-# Instead, it's better to use multiprocessing or independent models per thread.
|
|
|
-@app.on_startup
|
|
|
-def initialize_app(app: Kui):
|
|
|
-
|
|
|
- global args, llama_queue, tokenizer, config, decoder_model, vad_model, asr_model, prompt_tokens, prompt_texts
|
|
|
-
|
|
|
- prompt_tokens, prompt_texts = [], []
|
|
|
-
|
|
|
- args = parse_args() # args same as ones in other processes
|
|
|
- args.precision = torch.half if args.half else torch.bfloat16
|
|
|
-
|
|
|
- if args.load_asr_model:
|
|
|
- logger.info(f"Loading ASR model...")
|
|
|
- asr_model = load_asr_model(device=args.device)
|
|
|
-
|
|
|
- logger.info("Loading Llama model...")
|
|
|
-
|
|
|
- if args.mode == "tts":
|
|
|
- llama_queue = launch_thread_safe_queue(
|
|
|
- checkpoint_path=args.llama_checkpoint_path,
|
|
|
- device=args.device,
|
|
|
- precision=args.precision,
|
|
|
- compile=args.compile,
|
|
|
- )
|
|
|
- else:
|
|
|
- llama_queue, tokenizer, config = launch_thread_safe_queue_agent(
|
|
|
- checkpoint_path=args.llama_checkpoint_path,
|
|
|
- device=args.device,
|
|
|
- precision=args.precision,
|
|
|
- compile=args.compile,
|
|
|
- )
|
|
|
-
|
|
|
- logger.info("Llama model loaded, loading VQ-GAN model...")
|
|
|
-
|
|
|
- decoder_model = load_decoder_model(
|
|
|
- config_name=args.decoder_config_name,
|
|
|
- checkpoint_path=args.decoder_checkpoint_path,
|
|
|
- device=args.device,
|
|
|
- )
|
|
|
-
|
|
|
- logger.info("VQ-GAN model loaded, warming up...")
|
|
|
-
|
|
|
- vad_model = load_silero_vad()
|
|
|
-
|
|
|
- logger.info("VAD model loaded, warming up...")
|
|
|
-
|
|
|
- if args.mode == "tts":
|
|
|
- # Dry run to ensure models work and avoid first-time latency
|
|
|
- list(
|
|
|
- inference(
|
|
|
- ServeTTSRequest(
|
|
|
- text="Hello world.",
|
|
|
- references=[],
|
|
|
- reference_id=None,
|
|
|
- max_new_tokens=0,
|
|
|
- chunk_length=200,
|
|
|
- top_p=0.7,
|
|
|
- repetition_penalty=1.5,
|
|
|
- temperature=0.7,
|
|
|
- emotion=None,
|
|
|
- format="wav",
|
|
|
- )
|
|
|
- )
|
|
|
- )
|
|
|
-
|
|
|
- logger.info(f"Warming up done, starting server at http://{args.listen}")
|
|
|
-
|
|
|
-
|
|
|
-if __name__ == "__main__":
|
|
|
-
|
|
|
- import uvicorn
|
|
|
-
|
|
|
- args = parse_args()
|
|
|
- host, port = args.listen.split(":")
|
|
|
- uvicorn.run(
|
|
|
- "tools.api:app",
|
|
|
- host=host,
|
|
|
- port=int(port),
|
|
|
- workers=args.workers,
|
|
|
- log_level="info",
|
|
|
- )
|