| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462 |
- import io
- import os
- import re
- import shutil
- import tempfile
- import time
- from http import HTTPStatus
- from pathlib import Path
- import numpy as np
- import ormsgpack
- import soundfile as sf
- import torch
- from kui.asgi import (
- Body,
- HTTPException,
- HttpView,
- JSONResponse,
- Routes,
- StreamResponse,
- UploadFile,
- request,
- )
- from loguru import logger
- from typing_extensions import Annotated
- from fish_speech.utils.schema import (
- AddReferenceRequest,
- AddReferenceResponse,
- DeleteReferenceResponse,
- ListReferencesResponse,
- ServeTTSRequest,
- ServeVQGANDecodeRequest,
- ServeVQGANDecodeResponse,
- ServeVQGANEncodeRequest,
- ServeVQGANEncodeResponse,
- UpdateReferenceResponse,
- )
- from tools.server.api_utils import (
- buffer_to_async_generator,
- format_response,
- get_content_type,
- inference_async,
- )
- from tools.server.inference import inference_wrapper as inference
- from tools.server.model_manager import ModelManager
- from tools.server.model_utils import (
- batch_vqgan_decode,
- cached_vqgan_batch_encode,
- )
# Integer value of the NUM_SAMPLES environment variable, read once at import
# time; falls back to 1 when the variable is unset.
MAX_NUM_SAMPLES = int(os.environ.get("NUM_SAMPLES", 1))

# Route registry that the handlers in this module register themselves on.
routes = Routes()
@routes.http("/v1/health")
class Health(HttpView):
    """Health-check view: GET and POST both answer with an "ok" status."""

    @classmethod
    async def get(cls):
        """Answer a GET health probe with ``{"status": "ok"}``."""
        payload = {"status": "ok"}
        return JSONResponse(payload)

    @classmethod
    async def post(cls):
        """Answer a POST health probe with the same payload as GET."""
        payload = {"status": "ok"}
        return JSONResponse(payload)
@routes.http.post("/v1/vqgan/encode")
async def vqgan_encode(req: Annotated[ServeVQGANEncodeRequest, Body(exclusive=True)]):
    """
    Encode audio using VQGAN model.

    ``req.audios`` is handed to ``cached_vqgan_batch_encode`` together with the
    decoder model held by the application's model manager.

    Returns:
        msgpack bytes of a ``ServeVQGANEncodeResponse`` whose ``tokens`` field
        holds each encoded result converted with ``.tolist()``.

    Raises:
        HTTPException: 500 when anything in the encode path fails.
    """
    try:
        # Get the model from the app
        model_manager: ModelManager = request.app.state.model_manager
        decoder_model = model_manager.decoder_model

        # Encode the audio and log how long it took
        start_time = time.time()
        tokens = cached_vqgan_batch_encode(decoder_model, req.audios)
        elapsed_ms = (time.time() - start_time) * 1000
        logger.info(f"[EXEC] VQGAN encode time: {elapsed_ms:.2f}ms")

        # Return the response
        return ormsgpack.packb(
            ServeVQGANEncodeResponse(tokens=[i.tolist() for i in tokens]),
            option=ormsgpack.OPT_SERIALIZE_PYDANTIC,
        )
    except Exception as e:
        # loguru ignores stdlib's ``exc_info=True`` and, when given kwargs,
        # runs str.format on the message; logger.exception attaches the
        # traceback and the "{}" placeholder keeps braces in ``e`` harmless.
        logger.exception("Error in VQGAN encode: {}", e)
        raise HTTPException(
            HTTPStatus.INTERNAL_SERVER_ERROR, content="Failed to encode audio"
        ) from e
@routes.http.post("/v1/vqgan/decode")
async def vqgan_decode(req: Annotated[ServeVQGANDecodeRequest, Body(exclusive=True)]):
    """
    Decode tokens to audio using VQGAN model.

    Each entry of ``req.tokens`` is turned into an integer tensor and the batch
    is passed to ``batch_vqgan_decode`` with the application's decoder model.

    Returns:
        msgpack bytes of a ``ServeVQGANDecodeResponse`` whose ``audios`` field
        holds each decoded array cast to float16 and dumped as raw bytes.

    Raises:
        HTTPException: 500 when anything in the decode path fails.
    """
    try:
        # Get the model from the app
        model_manager: ModelManager = request.app.state.model_manager
        decoder_model = model_manager.decoder_model

        # Decode the audio and log how long it took
        tokens = [torch.tensor(token, dtype=torch.int) for token in req.tokens]
        start_time = time.time()
        audios = batch_vqgan_decode(decoder_model, tokens)
        elapsed_ms = (time.time() - start_time) * 1000
        logger.info(f"[EXEC] VQGAN decode time: {elapsed_ms:.2f}ms")

        # Serialize every waveform as raw float16 bytes
        audios = [audio.astype(np.float16).tobytes() for audio in audios]

        # Return the response
        return ormsgpack.packb(
            ServeVQGANDecodeResponse(audios=audios),
            option=ormsgpack.OPT_SERIALIZE_PYDANTIC,
        )
    except Exception as e:
        # loguru ignores stdlib's ``exc_info=True`` and, when given kwargs,
        # runs str.format on the message; logger.exception attaches the
        # traceback and the "{}" placeholder keeps braces in ``e`` harmless.
        logger.exception("Error in VQGAN decode: {}", e)
        raise HTTPException(
            HTTPStatus.INTERNAL_SERVER_ERROR, content="Failed to decode tokens to audio"
        ) from e
@routes.http.post("/v1/tts")
async def tts(req: Annotated[ServeTTSRequest, Body(exclusive=True)]):
    """
    Generate speech from text using TTS model.

    Returns:
        A ``StreamResponse`` served as an ``audio.<format>`` attachment. When
        ``req.streaming`` is set the body comes from ``inference_async``;
        otherwise the first result of ``inference`` is written to an in-memory
        buffer with soundfile and that buffer is sent.

    Raises:
        HTTPException: 400 when the text exceeds ``app_state.max_text_length``
            (only enforced when that limit is positive) or when streaming is
            requested with a format other than ``"wav"``; 500 on any other
            failure.
    """
    try:
        # Get the model from the app
        app_state = request.app.state
        model_manager: ModelManager = app_state.model_manager
        engine = model_manager.tts_inference_engine
        sample_rate = engine.decoder_model.sample_rate

        # Check if the text is too long (a non-positive limit disables the check)
        max_text_length = app_state.max_text_length
        if max_text_length > 0 and len(req.text) > max_text_length:
            raise HTTPException(
                HTTPStatus.BAD_REQUEST,
                content=f"Text is too long, max length is {max_text_length}",
            )

        # Streaming is only available for WAV output
        if req.streaming and req.format != "wav":
            raise HTTPException(
                HTTPStatus.BAD_REQUEST,
                content="Streaming only supports WAV format",
            )

        # Response metadata is identical for both delivery modes
        headers = {
            "Content-Disposition": f"attachment; filename=audio.{req.format}",
        }
        content_type = get_content_type(req.format)

        # Perform TTS
        if req.streaming:
            return StreamResponse(
                iterable=inference_async(req, engine),
                headers=headers,
                content_type=content_type,
            )

        # Non-streaming: take the first result and encode it in memory
        fake_audios = next(inference(req, engine))
        buffer = io.BytesIO()
        sf.write(
            buffer,
            fake_audios,
            sample_rate,
            format=req.format,
        )
        return StreamResponse(
            iterable=buffer_to_async_generator(buffer.getvalue()),
            headers=headers,
            content_type=content_type,
        )
    except HTTPException:
        # Re-raise HTTP exceptions as they are already properly formatted
        raise
    except Exception as e:
        # loguru ignores stdlib's ``exc_info=True`` and, when given kwargs,
        # runs str.format on the message; logger.exception attaches the
        # traceback and the "{}" placeholder keeps braces in ``e`` harmless.
        logger.exception("Error in TTS generation: {}", e)
        raise HTTPException(
            HTTPStatus.INTERNAL_SERVER_ERROR, content="Failed to generate speech"
        ) from e
@routes.http.post("/v1/references/add")
async def add_reference(
    id: str = Body(...), audio: UploadFile = Body(...), text: str = Body(...)
):
    """
    Add a new reference voice with audio file and text.

    The uploaded audio is copied to a temporary ``.wav`` file, which is passed
    to ``engine.add_reference`` and removed again before returning.

    Returns:
        ``format_response`` of an ``AddReferenceResponse``. Failures are
        reported through the response rather than raised: 409 for
        ``FileExistsError``, 400 for empty id/text/audio or any other
        ``ValueError``, 500 for other ``OSError``s and unexpected errors.
    """
    temp_file_path = None
    try:
        # Validate input parameters
        if not id or not id.strip():
            raise ValueError("Reference ID cannot be empty")
        if not text or not text.strip():
            raise ValueError("Reference text cannot be empty")

        # Get the model manager to access the reference loader
        app_state = request.app.state
        model_manager: ModelManager = app_state.model_manager
        engine = model_manager.tts_inference_engine

        # Read the uploaded audio file
        audio_content = audio.read()
        if not audio_content:
            raise ValueError("Audio file is empty or could not be read")

        # Create a temporary file for the audio data; delete=False because the
        # file is handed over by path and removed in the finally block below.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_file:
            temp_file.write(audio_content)
            temp_file_path = temp_file.name

        # Add the reference using the engine's reference loader
        engine.add_reference(id, temp_file_path, text)

        response = AddReferenceResponse(
            success=True,
            message=f"Reference voice '{id}' added successfully",
            reference_id=id,
        )
        return format_response(response)
    except FileExistsError as e:
        logger.warning(f"Reference ID '{id}' already exists: {e}")
        response = AddReferenceResponse(
            success=False,
            message=f"Reference ID '{id}' already exists",
            reference_id=id,
        )
        return format_response(response, status_code=409)  # Conflict
    except ValueError as e:
        logger.warning(f"Invalid input for reference '{id}': {e}")
        response = AddReferenceResponse(success=False, message=str(e), reference_id=id)
        return format_response(response, status_code=400)
    except OSError as e:
        # FileNotFoundError is a subclass of OSError, so this one clause covers
        # both; FileExistsError has already been handled above.
        logger.error(f"File system error for reference '{id}': {e}")
        response = AddReferenceResponse(
            success=False, message="File system error occurred", reference_id=id
        )
        return format_response(response, status_code=500)
    except Exception as e:
        # loguru ignores stdlib's ``exc_info=True`` and, when given kwargs,
        # runs str.format on the message; logger.exception attaches the
        # traceback and the "{}" placeholders keep braces in the values harmless.
        logger.exception("Unexpected error adding reference '{}': {}", id, e)
        response = AddReferenceResponse(
            success=False, message="Internal server error occurred", reference_id=id
        )
        return format_response(response, status_code=500)
    finally:
        # Clean up temporary file (best effort: a failure is only logged)
        if temp_file_path:
            try:
                os.unlink(temp_file_path)
            except FileNotFoundError:
                # Already gone; nothing left to clean up.
                pass
            except OSError as e:
                logger.warning(
                    f"Failed to clean up temporary file {temp_file_path}: {e}"
                )
@routes.http.get("/v1/references/list")
async def list_references():
    """
    Get a list of all available reference voice IDs.

    Returns:
        ``format_response`` of a ``ListReferencesResponse`` carrying the IDs
        reported by ``engine.list_reference_ids()``. On any failure the
        response has ``success=False``, an empty ID list and status 500.
    """
    try:
        # Get the model manager to access the reference loader
        app_state = request.app.state
        model_manager: ModelManager = app_state.model_manager
        engine = model_manager.tts_inference_engine

        # Get the list of reference IDs
        reference_ids = engine.list_reference_ids()

        response = ListReferencesResponse(
            success=True,
            reference_ids=reference_ids,
            message=f"Found {len(reference_ids)} reference voices",
        )
        return format_response(response)
    except Exception as e:
        # loguru ignores stdlib's ``exc_info=True`` and, when given kwargs,
        # runs str.format on the message; logger.exception attaches the
        # traceback and the "{}" placeholder keeps braces in ``e`` harmless.
        logger.exception("Unexpected error listing references: {}", e)
        response = ListReferencesResponse(
            success=False, reference_ids=[], message="Internal server error occurred"
        )
        return format_response(response, status_code=500)
@routes.http.delete("/v1/references/delete")
async def delete_reference(reference_id: str = Body(...)):
    """
    Delete a reference voice by ID.

    Returns:
        ``format_response`` of a ``DeleteReferenceResponse``. Failures are
        reported through the response rather than raised: 404 for
        ``FileNotFoundError``, 400 for an empty ID or any other ``ValueError``,
        500 for other ``OSError``s and unexpected errors.
    """
    try:
        # Validate input parameters
        if not reference_id or not reference_id.strip():
            raise ValueError("Reference ID cannot be empty")

        # Get the model manager to access the reference loader
        app_state = request.app.state
        model_manager: ModelManager = app_state.model_manager
        engine = model_manager.tts_inference_engine

        # Delete the reference using the engine's reference loader
        engine.delete_reference(reference_id)

        response = DeleteReferenceResponse(
            success=True,
            message=f"Reference voice '{reference_id}' deleted successfully",
            reference_id=reference_id,
        )
        return format_response(response)
    except FileNotFoundError as e:
        logger.warning(f"Reference ID '{reference_id}' not found: {e}")
        response = DeleteReferenceResponse(
            success=False,
            message=f"Reference ID '{reference_id}' not found",
            reference_id=reference_id,
        )
        return format_response(response, status_code=404)  # Not Found
    except ValueError as e:
        logger.warning(f"Invalid input for reference '{reference_id}': {e}")
        response = DeleteReferenceResponse(
            success=False, message=str(e), reference_id=reference_id
        )
        return format_response(response, status_code=400)
    except OSError as e:
        logger.error(f"File system error deleting reference '{reference_id}': {e}")
        response = DeleteReferenceResponse(
            success=False,
            message="File system error occurred",
            reference_id=reference_id,
        )
        return format_response(response, status_code=500)
    except Exception as e:
        # loguru ignores stdlib's ``exc_info=True`` and, when given kwargs,
        # runs str.format on the message; logger.exception attaches the
        # traceback and the "{}" placeholders keep braces in the values harmless.
        logger.exception(
            "Unexpected error deleting reference '{}': {}", reference_id, e
        )
        response = DeleteReferenceResponse(
            success=False,
            message="Internal server error occurred",
            reference_id=reference_id,
        )
        return format_response(response, status_code=500)
@routes.http.post("/v1/references/update")
async def update_reference(
    old_reference_id: str = Body(...), new_reference_id: str = Body(...)
):
    """
    Rename a reference voice directory from old_reference_id to new_reference_id.

    The directory ``references/<old_reference_id>`` is renamed to
    ``references/<new_reference_id>`` and, if the old ID is a key of
    ``engine.ref_by_id``, that entry is moved to the new key.

    Returns:
        ``format_response`` of an ``UpdateReferenceResponse``. Failures are
        reported through the response rather than raised: 400 for invalid IDs,
        404 when the old directory does not exist, 409 when the new one already
        exists, 500 for other ``OSError``s and unexpected errors.
    """

    def _reply(success, message, status_code=None):
        # Build the response envelope shared by every outcome of this handler.
        response = UpdateReferenceResponse(
            success=success,
            message=message,
            old_reference_id=old_reference_id,
            new_reference_id=new_reference_id,
        )
        if status_code is None:
            return format_response(response)
        return format_response(response, status_code=status_code)

    try:
        # Validate input parameters
        if not old_reference_id or not old_reference_id.strip():
            raise ValueError("Old reference ID cannot be empty")
        if not new_reference_id or not new_reference_id.strip():
            raise ValueError("New reference ID cannot be empty")
        if old_reference_id == new_reference_id:
            raise ValueError("New reference ID must be different from old reference ID")

        # The old ID is joined into a filesystem path below, so it must be a
        # single path component; otherwise "..", separators or an absolute path
        # would let a caller move directories outside of references/.
        if (
            old_reference_id in (".", "..")
            or Path(old_reference_id).name != old_reference_id
        ):
            raise ValueError("Old reference ID contains invalid characters")

        # Validate ID format per ReferenceLoader rules. fullmatch is used
        # because "$" in re.match would also accept a trailing newline.
        id_pattern = r"[a-zA-Z0-9\-_ ]+"
        if not re.fullmatch(id_pattern, new_reference_id) or len(new_reference_id) > 255:
            raise ValueError(
                "New reference ID contains invalid characters or is too long"
            )

        # Access engine to update caches after renaming
        app_state = request.app.state
        model_manager: ModelManager = app_state.model_manager
        engine = model_manager.tts_inference_engine

        refs_base = Path("references")
        old_dir = refs_base / old_reference_id
        new_dir = refs_base / new_reference_id

        # Existence checks
        if not old_dir.exists() or not old_dir.is_dir():
            raise FileNotFoundError(f"Reference ID '{old_reference_id}' not found")
        if new_dir.exists():
            # Conflict: destination already exists
            return _reply(
                False, f"Reference ID '{new_reference_id}' already exists", 409
            )

        # Perform rename
        old_dir.rename(new_dir)

        # Update in-memory cache key if present
        if old_reference_id in engine.ref_by_id:
            engine.ref_by_id[new_reference_id] = engine.ref_by_id.pop(old_reference_id)

        return _reply(
            True,
            f"Reference voice renamed from '{old_reference_id}' to '{new_reference_id}' successfully",
        )
    except FileNotFoundError as e:
        logger.warning(str(e))
        return _reply(False, str(e), 404)
    except ValueError as e:
        logger.warning(f"Invalid input for update reference: {e}")
        return _reply(False, str(e), 400)
    except OSError as e:
        logger.error(f"File system error renaming reference: {e}")
        return _reply(False, "File system error occurred", 500)
    except Exception as e:
        # loguru ignores stdlib's ``exc_info=True`` and, when given kwargs,
        # runs str.format on the message; logger.exception attaches the
        # traceback and the "{}" placeholder keeps braces in ``e`` harmless.
        logger.exception("Unexpected error updating reference: {}", e)
        return _reply(False, "Internal server error occurred", 500)
|