import io
import os
import re
import shutil
import tempfile
import time
from http import HTTPStatus
from pathlib import Path

import numpy as np
import ormsgpack
import soundfile as sf
import torch
from kui.asgi import (
    Body,
    HTTPException,
    HttpView,
    JSONResponse,
    Routes,
    StreamResponse,
    UploadFile,
    request,
)
from loguru import logger
from typing_extensions import Annotated

from fish_speech.utils.schema import (
    AddReferenceRequest,
    AddReferenceResponse,
    DeleteReferenceResponse,
    ListReferencesResponse,
    ServeTTSRequest,
    ServeVQGANDecodeRequest,
    ServeVQGANDecodeResponse,
    ServeVQGANEncodeRequest,
    ServeVQGANEncodeResponse,
    UpdateReferenceResponse,
)
from tools.server.api_utils import (
    buffer_to_async_generator,
    format_response,
    get_content_type,
    inference_async,
)
from tools.server.inference import inference_wrapper as inference
from tools.server.model_manager import ModelManager
from tools.server.model_utils import (
    batch_vqgan_decode,
    cached_vqgan_batch_encode,
)

MAX_NUM_SAMPLES = int(os.getenv("NUM_SAMPLES", 1))

# Characters accepted in a *new* reference ID (letters, digits, '-', '_', space).
# Compiled once; used with ``fullmatch`` so a trailing newline cannot slip past
# the ``$`` anchor the way it does with ``re.match``.
_REFERENCE_ID_PATTERN = re.compile(r"^[a-zA-Z0-9\-_ ]+$")
_MAX_REFERENCE_ID_LENGTH = 255

routes = Routes()


def _is_single_path_component(name: str) -> bool:
    """Return True when *name* is one plain path component.

    Rejects ``.``, ``..``, absolute paths and anything containing a path
    separator, so joining it onto a base directory cannot escape that
    directory.
    """
    return name not in {".", ".."} and Path(name).name == name


@routes.http("/v1/health")
class Health(HttpView):
    """Liveness endpoint; GET and POST both answer ``{"status": "ok"}``."""

    @classmethod
    async def get(cls):
        return JSONResponse({"status": "ok"})

    @classmethod
    async def post(cls):
        return JSONResponse({"status": "ok"})


@routes.http.post("/v1/vqgan/encode")
async def vqgan_encode(req: Annotated[ServeVQGANEncodeRequest, Body(exclusive=True)]):
    """
    Encode audio using VQGAN model.

    Passes ``req.audios`` to ``cached_vqgan_batch_encode`` with the decoder
    model held by the app's model manager and returns a msgpack-serialized
    ``ServeVQGANEncodeResponse`` whose ``tokens`` are the results converted
    with ``.tolist()``.

    Raises:
        HTTPException: 500 if anything fails while encoding or serializing.
    """
    try:
        # Get the model from the app
        model_manager: ModelManager = request.app.state.model_manager
        decoder_model = model_manager.decoder_model

        # Encode the audio
        start_time = time.time()
        tokens = cached_vqgan_batch_encode(decoder_model, req.audios)
        logger.info(
            f"[EXEC] VQGAN encode time: {(time.time() - start_time) * 1000:.2f}ms"
        )

        # Return the response
        return ormsgpack.packb(
            ServeVQGANEncodeResponse(tokens=[i.tolist() for i in tokens]),
            option=ormsgpack.OPT_SERIALIZE_PYDANTIC,
        )
    except Exception as e:
        # loguru has no ``exc_info`` kwarg; ``exception`` attaches the traceback.
        logger.exception("Error in VQGAN encode: {}", e)
        raise HTTPException(
            HTTPStatus.INTERNAL_SERVER_ERROR, content="Failed to encode audio"
        ) from e


@routes.http.post("/v1/vqgan/decode")
async def vqgan_decode(req: Annotated[ServeVQGANDecodeRequest, Body(exclusive=True)]):
    """
    Decode tokens to audio using VQGAN model.

    Each entry of ``req.tokens`` is turned into an int tensor and decoded with
    ``batch_vqgan_decode``. Every decoded audio is cast to float16 and sent as
    raw bytes inside a msgpack-serialized ``ServeVQGANDecodeResponse``.

    Raises:
        HTTPException: 500 if anything fails while decoding or serializing.
    """
    try:
        # Get the model from the app
        model_manager: ModelManager = request.app.state.model_manager
        decoder_model = model_manager.decoder_model

        # Decode the audio
        tokens = [torch.tensor(token, dtype=torch.int) for token in req.tokens]
        start_time = time.time()
        audios = batch_vqgan_decode(decoder_model, tokens)
        logger.info(
            f"[EXEC] VQGAN decode time: {(time.time() - start_time) * 1000:.2f}ms"
        )
        audios = [audio.astype(np.float16).tobytes() for audio in audios]

        # Return the response
        return ormsgpack.packb(
            ServeVQGANDecodeResponse(audios=audios),
            option=ormsgpack.OPT_SERIALIZE_PYDANTIC,
        )
    except Exception as e:
        logger.exception("Error in VQGAN decode: {}", e)
        raise HTTPException(
            HTTPStatus.INTERNAL_SERVER_ERROR, content="Failed to decode tokens to audio"
        ) from e


@routes.http.post("/v1/tts")
async def tts(req: Annotated[ServeTTSRequest, Body(exclusive=True)]):
    """
    Generate speech from text using TTS model.

    Rejects requests whose text exceeds ``app.state.max_text_length`` (when
    that limit is positive) and streaming requests in any format other than
    WAV. Streaming requests are answered with the ``inference_async``
    generator; otherwise the first result of ``inference`` is written to an
    in-memory buffer with ``soundfile`` and returned as an attachment.

    Raises:
        HTTPException: 400 for over-long text or unsupported streaming
            format; 500 for any other failure.
    """
    try:
        # Get the model from the app
        app_state = request.app.state
        model_manager: ModelManager = app_state.model_manager
        engine = model_manager.tts_inference_engine
        sample_rate = engine.decoder_model.sample_rate

        # Check if the text is too long
        if app_state.max_text_length > 0 and len(req.text) > app_state.max_text_length:
            raise HTTPException(
                HTTPStatus.BAD_REQUEST,
                content=f"Text is too long, max length is {app_state.max_text_length}",
            )

        # Check if streaming is enabled
        if req.streaming and req.format != "wav":
            raise HTTPException(
                HTTPStatus.BAD_REQUEST,
                content="Streaming only supports WAV format",
            )

        # Perform TTS
        if req.streaming:
            return StreamResponse(
                iterable=inference_async(req, engine),
                headers={
                    "Content-Disposition": f"attachment; filename=audio.{req.format}",
                },
                content_type=get_content_type(req.format),
            )

        fake_audios = next(inference(req, engine))
        buffer = io.BytesIO()
        sf.write(
            buffer,
            fake_audios,
            sample_rate,
            format=req.format,
        )

        return StreamResponse(
            iterable=buffer_to_async_generator(buffer.getvalue()),
            headers={
                "Content-Disposition": f"attachment; filename=audio.{req.format}",
            },
            content_type=get_content_type(req.format),
        )
    except HTTPException:
        # Re-raise HTTP exceptions as they are already properly formatted
        raise
    except Exception as e:
        logger.exception("Error in TTS generation: {}", e)
        raise HTTPException(
            HTTPStatus.INTERNAL_SERVER_ERROR, content="Failed to generate speech"
        ) from e


@routes.http.post("/v1/references/add")
async def add_reference(
    id: str = Body(...), audio: UploadFile = Body(...), text: str = Body(...)
):
    """
    Add a new reference voice with audio file and text.

    The uploaded audio is written to a temporary ``.wav`` file, handed to
    ``engine.add_reference`` and the temporary file is removed afterwards
    regardless of outcome.

    Responses (all ``AddReferenceResponse``):
        200 on success; 409 if the ID already exists (``FileExistsError``);
        400 on invalid input (``ValueError``); 500 on file-system or
        unexpected errors.
    """
    temp_file_path = None

    try:
        # Validate input parameters
        if not id or not id.strip():
            raise ValueError("Reference ID cannot be empty")

        if not text or not text.strip():
            raise ValueError("Reference text cannot be empty")

        # Get the model manager to access the reference loader
        app_state = request.app.state
        model_manager: ModelManager = app_state.model_manager
        engine = model_manager.tts_inference_engine

        # Read the uploaded audio file
        audio_content = audio.read()
        if not audio_content:
            raise ValueError("Audio file is empty or could not be read")

        # Create a temporary file for the audio data; delete=False because the
        # engine reads it by path after this handle is closed.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_file:
            temp_file.write(audio_content)
            temp_file_path = temp_file.name

        # Add the reference using the engine's reference loader
        engine.add_reference(id, temp_file_path, text)

        response = AddReferenceResponse(
            success=True,
            message=f"Reference voice '{id}' added successfully",
            reference_id=id,
        )
        return format_response(response)

    except FileExistsError as e:
        logger.warning(f"Reference ID '{id}' already exists: {e}")
        response = AddReferenceResponse(
            success=False,
            message=f"Reference ID '{id}' already exists",
            reference_id=id,
        )
        return format_response(response, status_code=409)  # Conflict

    except ValueError as e:
        logger.warning(f"Invalid input for reference '{id}': {e}")
        response = AddReferenceResponse(success=False, message=str(e), reference_id=id)
        return format_response(response, status_code=400)

    except (FileNotFoundError, OSError) as e:
        logger.error(f"File system error for reference '{id}': {e}")
        response = AddReferenceResponse(
            success=False, message="File system error occurred", reference_id=id
        )
        return format_response(response, status_code=500)

    except Exception as e:
        logger.exception("Unexpected error adding reference '{}': {}", id, e)
        response = AddReferenceResponse(
            success=False, message="Internal server error occurred", reference_id=id
        )
        return format_response(response, status_code=500)

    finally:
        # Clean up temporary file
        if temp_file_path and os.path.exists(temp_file_path):
            try:
                os.unlink(temp_file_path)
            except OSError as e:
                logger.warning(
                    f"Failed to clean up temporary file {temp_file_path}: {e}"
                )


@routes.http.get("/v1/references/list")
async def list_references():
    """
    Get a list of all available reference voice IDs.

    Returns a ``ListReferencesResponse`` built from
    ``engine.list_reference_ids()``; on any failure responds with status 500
    and an empty ID list.
    """
    try:
        # Get the model manager to access the reference loader
        app_state = request.app.state
        model_manager: ModelManager = app_state.model_manager
        engine = model_manager.tts_inference_engine

        # Get the list of reference IDs
        reference_ids = engine.list_reference_ids()

        response = ListReferencesResponse(
            success=True,
            reference_ids=reference_ids,
            message=f"Found {len(reference_ids)} reference voices",
        )
        return format_response(response)

    except Exception as e:
        logger.exception("Unexpected error listing references: {}", e)
        response = ListReferencesResponse(
            success=False, reference_ids=[], message="Internal server error occurred"
        )
        return format_response(response, status_code=500)


@routes.http.delete("/v1/references/delete")
async def delete_reference(reference_id: str = Body(...)):
    """
    Delete a reference voice by ID.

    Delegates to ``engine.delete_reference``.

    Responses (all ``DeleteReferenceResponse``):
        200 on success; 404 if the reference is missing
        (``FileNotFoundError``); 400 on invalid input (``ValueError``);
        500 on file-system or unexpected errors.
    """
    try:
        # Validate input parameters
        if not reference_id or not reference_id.strip():
            raise ValueError("Reference ID cannot be empty")

        # Get the model manager to access the reference loader
        app_state = request.app.state
        model_manager: ModelManager = app_state.model_manager
        engine = model_manager.tts_inference_engine

        # Delete the reference using the engine's reference loader
        engine.delete_reference(reference_id)

        response = DeleteReferenceResponse(
            success=True,
            message=f"Reference voice '{reference_id}' deleted successfully",
            reference_id=reference_id,
        )
        return format_response(response)

    except FileNotFoundError as e:
        logger.warning(f"Reference ID '{reference_id}' not found: {e}")
        response = DeleteReferenceResponse(
            success=False,
            message=f"Reference ID '{reference_id}' not found",
            reference_id=reference_id,
        )
        return format_response(response, status_code=404)  # Not Found

    except ValueError as e:
        logger.warning(f"Invalid input for reference '{reference_id}': {e}")
        response = DeleteReferenceResponse(
            success=False, message=str(e), reference_id=reference_id
        )
        return format_response(response, status_code=400)

    except OSError as e:
        logger.error(f"File system error deleting reference '{reference_id}': {e}")
        response = DeleteReferenceResponse(
            success=False,
            message="File system error occurred",
            reference_id=reference_id,
        )
        return format_response(response, status_code=500)

    except Exception as e:
        logger.exception(
            "Unexpected error deleting reference '{}': {}", reference_id, e
        )
        response = DeleteReferenceResponse(
            success=False,
            message="Internal server error occurred",
            reference_id=reference_id,
        )
        return format_response(response, status_code=500)


@routes.http.post("/v1/references/update")
async def update_reference(
    old_reference_id: str = Body(...), new_reference_id: str = Body(...)
):
    """
    Rename a reference voice directory from old_reference_id to new_reference_id.

    The directory ``references/<old_reference_id>`` is renamed to
    ``references/<new_reference_id>`` and, if the old ID is a key in
    ``engine.ref_by_id``, that entry is moved to the new key.

    Responses (all ``UpdateReferenceResponse``):
        200 on success; 400 when either ID is empty, the IDs are equal, the
        new ID has invalid characters or is longer than 255, or the old ID is
        not a single path component; 404 when the old directory does not
        exist; 409 when the destination already exists; 500 on file-system or
        unexpected errors.
    """
    try:
        # Validate input parameters
        if not old_reference_id or not old_reference_id.strip():
            raise ValueError("Old reference ID cannot be empty")
        if not new_reference_id or not new_reference_id.strip():
            raise ValueError("New reference ID cannot be empty")
        if old_reference_id == new_reference_id:
            raise ValueError("New reference ID must be different from old reference ID")

        # Validate ID format per ReferenceLoader rules
        if (
            not _REFERENCE_ID_PATTERN.fullmatch(new_reference_id)
            or len(new_reference_id) > _MAX_REFERENCE_ID_LENGTH
        ):
            raise ValueError(
                "New reference ID contains invalid characters or is too long"
            )

        # The old ID comes straight from the request body and is joined onto
        # the references directory below; refuse anything that could point
        # outside it (e.g. "../x" or an absolute path).
        if not _is_single_path_component(old_reference_id):
            raise ValueError("Old reference ID contains invalid characters")

        # Access engine to update caches after renaming
        app_state = request.app.state
        model_manager: ModelManager = app_state.model_manager
        engine = model_manager.tts_inference_engine

        refs_base = Path("references")
        old_dir = refs_base / old_reference_id
        new_dir = refs_base / new_reference_id

        # Existence checks
        if not old_dir.exists() or not old_dir.is_dir():
            raise FileNotFoundError(f"Reference ID '{old_reference_id}' not found")
        if new_dir.exists():
            # Conflict: destination already exists
            response = UpdateReferenceResponse(
                success=False,
                message=f"Reference ID '{new_reference_id}' already exists",
                old_reference_id=old_reference_id,
                new_reference_id=new_reference_id,
            )
            return format_response(response, status_code=409)

        # Perform rename
        old_dir.rename(new_dir)

        # Update in-memory cache key if present
        if old_reference_id in engine.ref_by_id:
            engine.ref_by_id[new_reference_id] = engine.ref_by_id.pop(old_reference_id)

        response = UpdateReferenceResponse(
            success=True,
            message=(
                f"Reference voice renamed from '{old_reference_id}' to '{new_reference_id}' successfully"
            ),
            old_reference_id=old_reference_id,
            new_reference_id=new_reference_id,
        )
        return format_response(response)

    except FileNotFoundError as e:
        logger.warning(str(e))
        response = UpdateReferenceResponse(
            success=False,
            message=str(e),
            old_reference_id=old_reference_id,
            new_reference_id=new_reference_id,
        )
        return format_response(response, status_code=404)

    except ValueError as e:
        logger.warning(f"Invalid input for update reference: {e}")
        response = UpdateReferenceResponse(
            success=False,
            message=str(e),
            old_reference_id=old_reference_id,
            new_reference_id=new_reference_id,
        )
        return format_response(response, status_code=400)

    except OSError as e:
        logger.error(f"File system error renaming reference: {e}")
        response = UpdateReferenceResponse(
            success=False,
            message="File system error occurred",
            old_reference_id=old_reference_id,
            new_reference_id=new_reference_id,
        )
        return format_response(response, status_code=500)

    except Exception as e:
        logger.exception("Unexpected error updating reference: {}", e)
        response = UpdateReferenceResponse(
            success=False,
            message="Internal server error occurred",
            old_reference_id=old_reference_id,
            new_reference_id=new_reference_id,
        )
        return format_response(response, status_code=500)