|
|
@@ -1,7 +1,11 @@
|
|
|
import io
|
|
|
import os
|
|
|
+import re
|
|
|
+import shutil
|
|
|
+import tempfile
|
|
|
import time
|
|
|
from http import HTTPStatus
|
|
|
+from pathlib import Path
|
|
|
|
|
|
import numpy as np
|
|
|
import ormsgpack
|
|
|
@@ -14,20 +18,27 @@ from kui.asgi import (
|
|
|
JSONResponse,
|
|
|
Routes,
|
|
|
StreamResponse,
|
|
|
+ UploadFile,
|
|
|
request,
|
|
|
)
|
|
|
from loguru import logger
|
|
|
from typing_extensions import Annotated
|
|
|
|
|
|
from fish_speech.utils.schema import (
|
|
|
+ AddReferenceRequest,
|
|
|
+ AddReferenceResponse,
|
|
|
+ DeleteReferenceResponse,
|
|
|
+ ListReferencesResponse,
|
|
|
ServeTTSRequest,
|
|
|
ServeVQGANDecodeRequest,
|
|
|
ServeVQGANDecodeResponse,
|
|
|
ServeVQGANEncodeRequest,
|
|
|
ServeVQGANEncodeResponse,
|
|
|
+ UpdateReferenceResponse,
|
|
|
)
|
|
|
from tools.server.api_utils import (
|
|
|
buffer_to_async_generator,
|
|
|
+ format_response,
|
|
|
get_content_type,
|
|
|
inference_async,
|
|
|
)
|
|
|
@@ -56,87 +67,396 @@ class Health(HttpView):
|
|
|
|
|
|
@routes.http.post("/v1/vqgan/encode")
|
|
|
async def vqgan_encode(req: Annotated[ServeVQGANEncodeRequest, Body(exclusive=True)]):
|
|
|
- # Get the model from the app
|
|
|
- model_manager: ModelManager = request.app.state.model_manager
|
|
|
- decoder_model = model_manager.decoder_model
|
|
|
+ """
|
|
|
+ Encode audio using VQGAN model.
|
|
|
+ """
|
|
|
+ try:
|
|
|
+ # Get the model from the app
|
|
|
+ model_manager: ModelManager = request.app.state.model_manager
|
|
|
+ decoder_model = model_manager.decoder_model
|
|
|
|
|
|
- # Encode the audio
|
|
|
- start_time = time.time()
|
|
|
- tokens = cached_vqgan_batch_encode(decoder_model, req.audios)
|
|
|
- logger.info(f"[EXEC] VQGAN encode time: {(time.time() - start_time) * 1000:.2f}ms")
|
|
|
+ # Encode the audio
|
|
|
+ start_time = time.time()
|
|
|
+ tokens = cached_vqgan_batch_encode(decoder_model, req.audios)
|
|
|
+ logger.info(
|
|
|
+ f"[EXEC] VQGAN encode time: {(time.time() - start_time) * 1000:.2f}ms"
|
|
|
+ )
|
|
|
|
|
|
- # Return the response
|
|
|
- return ormsgpack.packb(
|
|
|
- ServeVQGANEncodeResponse(tokens=[i.tolist() for i in tokens]),
|
|
|
- option=ormsgpack.OPT_SERIALIZE_PYDANTIC,
|
|
|
- )
|
|
|
+ # Return the response
|
|
|
+ return ormsgpack.packb(
|
|
|
+ ServeVQGANEncodeResponse(tokens=[i.tolist() for i in tokens]),
|
|
|
+ option=ormsgpack.OPT_SERIALIZE_PYDANTIC,
|
|
|
+ )
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(f"Error in VQGAN encode: {e}", exc_info=True)
|
|
|
+ raise HTTPException(
|
|
|
+ HTTPStatus.INTERNAL_SERVER_ERROR, content="Failed to encode audio"
|
|
|
+ )
|
|
|
|
|
|
|
|
|
@routes.http.post("/v1/vqgan/decode")
|
|
|
async def vqgan_decode(req: Annotated[ServeVQGANDecodeRequest, Body(exclusive=True)]):
|
|
|
- # Get the model from the app
|
|
|
- model_manager: ModelManager = request.app.state.model_manager
|
|
|
- decoder_model = model_manager.decoder_model
|
|
|
+ """
|
|
|
+ Decode tokens to audio using VQGAN model.
|
|
|
+ """
|
|
|
+ try:
|
|
|
+ # Get the model from the app
|
|
|
+ model_manager: ModelManager = request.app.state.model_manager
|
|
|
+ decoder_model = model_manager.decoder_model
|
|
|
|
|
|
- # Decode the audio
|
|
|
- tokens = [torch.tensor(token, dtype=torch.int) for token in req.tokens]
|
|
|
- start_time = time.time()
|
|
|
- audios = batch_vqgan_decode(decoder_model, tokens)
|
|
|
- logger.info(f"[EXEC] VQGAN decode time: {(time.time() - start_time) * 1000:.2f}ms")
|
|
|
- audios = [audio.astype(np.float16).tobytes() for audio in audios]
|
|
|
+ # Decode the audio
|
|
|
+ tokens = [torch.tensor(token, dtype=torch.int) for token in req.tokens]
|
|
|
+ start_time = time.time()
|
|
|
+ audios = batch_vqgan_decode(decoder_model, tokens)
|
|
|
+ logger.info(
|
|
|
+ f"[EXEC] VQGAN decode time: {(time.time() - start_time) * 1000:.2f}ms"
|
|
|
+ )
|
|
|
+ audios = [audio.astype(np.float16).tobytes() for audio in audios]
|
|
|
|
|
|
- # Return the response
|
|
|
- return ormsgpack.packb(
|
|
|
- ServeVQGANDecodeResponse(audios=audios),
|
|
|
- option=ormsgpack.OPT_SERIALIZE_PYDANTIC,
|
|
|
- )
|
|
|
+ # Return the response
|
|
|
+ return ormsgpack.packb(
|
|
|
+ ServeVQGANDecodeResponse(audios=audios),
|
|
|
+ option=ormsgpack.OPT_SERIALIZE_PYDANTIC,
|
|
|
+ )
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(f"Error in VQGAN decode: {e}", exc_info=True)
|
|
|
+ raise HTTPException(
|
|
|
+ HTTPStatus.INTERNAL_SERVER_ERROR, content="Failed to decode tokens to audio"
|
|
|
+ )
|
|
|
|
|
|
|
|
|
@routes.http.post("/v1/tts")
|
|
|
async def tts(req: Annotated[ServeTTSRequest, Body(exclusive=True)]):
|
|
|
- # Get the model from the app
|
|
|
- app_state = request.app.state
|
|
|
- model_manager: ModelManager = app_state.model_manager
|
|
|
- engine = model_manager.tts_inference_engine
|
|
|
- sample_rate = engine.decoder_model.sample_rate
|
|
|
-
|
|
|
- # Check if the text is too long
|
|
|
- if app_state.max_text_length > 0 and len(req.text) > app_state.max_text_length:
|
|
|
+ """
|
|
|
+ Generate speech from text using TTS model.
|
|
|
+ """
|
|
|
+ try:
|
|
|
+ # Get the model from the app
|
|
|
+ app_state = request.app.state
|
|
|
+ model_manager: ModelManager = app_state.model_manager
|
|
|
+ engine = model_manager.tts_inference_engine
|
|
|
+ sample_rate = engine.decoder_model.sample_rate
|
|
|
+
|
|
|
+ # Check if the text is too long
|
|
|
+ if app_state.max_text_length > 0 and len(req.text) > app_state.max_text_length:
|
|
|
+ raise HTTPException(
|
|
|
+ HTTPStatus.BAD_REQUEST,
|
|
|
+ content=f"Text is too long, max length is {app_state.max_text_length}",
|
|
|
+ )
|
|
|
+
|
|
|
+ # Check if streaming is enabled
|
|
|
+ if req.streaming and req.format != "wav":
|
|
|
+ raise HTTPException(
|
|
|
+ HTTPStatus.BAD_REQUEST,
|
|
|
+ content="Streaming only supports WAV format",
|
|
|
+ )
|
|
|
+
|
|
|
+ # Perform TTS
|
|
|
+ if req.streaming:
|
|
|
+ return StreamResponse(
|
|
|
+ iterable=inference_async(req, engine),
|
|
|
+ headers={
|
|
|
+ "Content-Disposition": f"attachment; filename=audio.{req.format}",
|
|
|
+ },
|
|
|
+ content_type=get_content_type(req.format),
|
|
|
+ )
|
|
|
+ else:
|
|
|
+ fake_audios = next(inference(req, engine))
|
|
|
+ buffer = io.BytesIO()
|
|
|
+ sf.write(
|
|
|
+ buffer,
|
|
|
+ fake_audios,
|
|
|
+ sample_rate,
|
|
|
+ format=req.format,
|
|
|
+ )
|
|
|
+
|
|
|
+ return StreamResponse(
|
|
|
+ iterable=buffer_to_async_generator(buffer.getvalue()),
|
|
|
+ headers={
|
|
|
+ "Content-Disposition": f"attachment; filename=audio.{req.format}",
|
|
|
+ },
|
|
|
+ content_type=get_content_type(req.format),
|
|
|
+ )
|
|
|
+ except HTTPException:
|
|
|
+ # Re-raise HTTP exceptions as they are already properly formatted
|
|
|
+ raise
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(f"Error in TTS generation: {e}", exc_info=True)
|
|
|
raise HTTPException(
|
|
|
- HTTPStatus.BAD_REQUEST,
|
|
|
- content=f"Text is too long, max length is {app_state.max_text_length}",
|
|
|
+ HTTPStatus.INTERNAL_SERVER_ERROR, content="Failed to generate speech"
|
|
|
)
|
|
|
|
|
|
- # Check if streaming is enabled
|
|
|
- if req.streaming and req.format != "wav":
|
|
|
- raise HTTPException(
|
|
|
- HTTPStatus.BAD_REQUEST,
|
|
|
- content="Streaming only supports WAV format",
|
|
|
- )
|
|
|
-
|
|
|
- # Perform TTS
|
|
|
- if req.streaming:
|
|
|
- return StreamResponse(
|
|
|
- iterable=inference_async(req, engine),
|
|
|
- headers={
|
|
|
- "Content-Disposition": f"attachment; filename=audio.{req.format}",
|
|
|
- },
|
|
|
- content_type=get_content_type(req.format),
|
|
|
- )
|
|
|
- else:
|
|
|
- fake_audios = next(inference(req, engine))
|
|
|
- buffer = io.BytesIO()
|
|
|
- sf.write(
|
|
|
- buffer,
|
|
|
- fake_audios,
|
|
|
- sample_rate,
|
|
|
- format=req.format,
|
|
|
- )
|
|
|
-
|
|
|
- return StreamResponse(
|
|
|
- iterable=buffer_to_async_generator(buffer.getvalue()),
|
|
|
- headers={
|
|
|
- "Content-Disposition": f"attachment; filename=audio.{req.format}",
|
|
|
- },
|
|
|
- content_type=get_content_type(req.format),
|
|
|
+
|
|
|
+@routes.http.post("/v1/references/add")
|
|
|
+async def add_reference(
|
|
|
+ id: str = Body(...), audio: UploadFile = Body(...), text: str = Body(...)
|
|
|
+):
|
|
|
+ """
|
|
|
+ Add a new reference voice with audio file and text.
|
|
|
+ """
|
|
|
+ temp_file_path = None
|
|
|
+
|
|
|
+ try:
|
|
|
+ # Validate input parameters
|
|
|
+ if not id or not id.strip():
|
|
|
+ raise ValueError("Reference ID cannot be empty")
|
|
|
+
|
|
|
+ if not text or not text.strip():
|
|
|
+ raise ValueError("Reference text cannot be empty")
|
|
|
+
|
|
|
+ # Get the model manager to access the reference loader
|
|
|
+ app_state = request.app.state
|
|
|
+ model_manager: ModelManager = app_state.model_manager
|
|
|
+ engine = model_manager.tts_inference_engine
|
|
|
+
|
|
|
+ # Read the uploaded audio file
|
|
|
+ audio_content = audio.read()
|
|
|
+ if not audio_content:
|
|
|
+ raise ValueError("Audio file is empty or could not be read")
|
|
|
+
|
|
|
+ # Create a temporary file for the audio data
|
|
|
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_file:
|
|
|
+ temp_file.write(audio_content)
|
|
|
+ temp_file_path = temp_file.name
|
|
|
+
|
|
|
+ # Add the reference using the engine's reference loader
|
|
|
+ engine.add_reference(id, temp_file_path, text)
|
|
|
+
|
|
|
+ response = AddReferenceResponse(
|
|
|
+ success=True,
|
|
|
+ message=f"Reference voice '{id}' added successfully",
|
|
|
+ reference_id=id,
|
|
|
+ )
|
|
|
+ return format_response(response)
|
|
|
+
|
|
|
+ except FileExistsError as e:
|
|
|
+ logger.warning(f"Reference ID '{id}' already exists: {e}")
|
|
|
+ response = AddReferenceResponse(
|
|
|
+ success=False,
|
|
|
+ message=f"Reference ID '{id}' already exists",
|
|
|
+ reference_id=id,
|
|
|
+ )
|
|
|
+ return format_response(response, status_code=409) # Conflict
|
|
|
+
|
|
|
+ except ValueError as e:
|
|
|
+ logger.warning(f"Invalid input for reference '{id}': {e}")
|
|
|
+ response = AddReferenceResponse(success=False, message=str(e), reference_id=id)
|
|
|
+ return format_response(response, status_code=400)
|
|
|
+
|
|
|
+ except (FileNotFoundError, OSError) as e:
|
|
|
+ logger.error(f"File system error for reference '{id}': {e}")
|
|
|
+ response = AddReferenceResponse(
|
|
|
+ success=False, message="File system error occurred", reference_id=id
|
|
|
+ )
|
|
|
+ return format_response(response, status_code=500)
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(f"Unexpected error adding reference '{id}': {e}", exc_info=True)
|
|
|
+ response = AddReferenceResponse(
|
|
|
+ success=False, message="Internal server error occurred", reference_id=id
|
|
|
+ )
|
|
|
+ return format_response(response, status_code=500)
|
|
|
+
|
|
|
+ finally:
|
|
|
+ # Clean up temporary file
|
|
|
+ if temp_file_path and os.path.exists(temp_file_path):
|
|
|
+ try:
|
|
|
+ os.unlink(temp_file_path)
|
|
|
+ except OSError as e:
|
|
|
+ logger.warning(
|
|
|
+ f"Failed to clean up temporary file {temp_file_path}: {e}"
|
|
|
+ )
|
|
|
+
|
|
|
+
|
|
|
+@routes.http.get("/v1/references/list")
|
|
|
+async def list_references():
|
|
|
+ """
|
|
|
+ Get a list of all available reference voice IDs.
|
|
|
+ """
|
|
|
+ try:
|
|
|
+ # Get the model manager to access the reference loader
|
|
|
+ app_state = request.app.state
|
|
|
+ model_manager: ModelManager = app_state.model_manager
|
|
|
+ engine = model_manager.tts_inference_engine
|
|
|
+
|
|
|
+ # Get the list of reference IDs
|
|
|
+ reference_ids = engine.list_reference_ids()
|
|
|
+
|
|
|
+ response = ListReferencesResponse(
|
|
|
+ success=True,
|
|
|
+ reference_ids=reference_ids,
|
|
|
+ message=f"Found {len(reference_ids)} reference voices",
|
|
|
+ )
|
|
|
+ return format_response(response)
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(f"Unexpected error listing references: {e}", exc_info=True)
|
|
|
+ response = ListReferencesResponse(
|
|
|
+ success=False, reference_ids=[], message="Internal server error occurred"
|
|
|
+ )
|
|
|
+ return format_response(response, status_code=500)
|
|
|
+
|
|
|
+
|
|
|
+@routes.http.delete("/v1/references/delete")
|
|
|
+async def delete_reference(reference_id: str = Body(...)):
|
|
|
+ """
|
|
|
+ Delete a reference voice by ID.
|
|
|
+ """
|
|
|
+ try:
|
|
|
+ # Validate input parameters
|
|
|
+ if not reference_id or not reference_id.strip():
|
|
|
+ raise ValueError("Reference ID cannot be empty")
|
|
|
+
|
|
|
+ # Get the model manager to access the reference loader
|
|
|
+ app_state = request.app.state
|
|
|
+ model_manager: ModelManager = app_state.model_manager
|
|
|
+ engine = model_manager.tts_inference_engine
|
|
|
+
|
|
|
+ # Delete the reference using the engine's reference loader
|
|
|
+ engine.delete_reference(reference_id)
|
|
|
+
|
|
|
+ response = DeleteReferenceResponse(
|
|
|
+ success=True,
|
|
|
+ message=f"Reference voice '{reference_id}' deleted successfully",
|
|
|
+ reference_id=reference_id,
|
|
|
+ )
|
|
|
+ return format_response(response)
|
|
|
+
|
|
|
+ except FileNotFoundError as e:
|
|
|
+ logger.warning(f"Reference ID '{reference_id}' not found: {e}")
|
|
|
+ response = DeleteReferenceResponse(
|
|
|
+ success=False,
|
|
|
+ message=f"Reference ID '{reference_id}' not found",
|
|
|
+ reference_id=reference_id,
|
|
|
+ )
|
|
|
+ return format_response(response, status_code=404) # Not Found
|
|
|
+
|
|
|
+ except ValueError as e:
|
|
|
+ logger.warning(f"Invalid input for reference '{reference_id}': {e}")
|
|
|
+ response = DeleteReferenceResponse(
|
|
|
+ success=False, message=str(e), reference_id=reference_id
|
|
|
+ )
|
|
|
+ return format_response(response, status_code=400)
|
|
|
+
|
|
|
+ except OSError as e:
|
|
|
+ logger.error(f"File system error deleting reference '{reference_id}': {e}")
|
|
|
+ response = DeleteReferenceResponse(
|
|
|
+ success=False,
|
|
|
+ message="File system error occurred",
|
|
|
+ reference_id=reference_id,
|
|
|
+ )
|
|
|
+ return format_response(response, status_code=500)
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(
|
|
|
+ f"Unexpected error deleting reference '{reference_id}': {e}", exc_info=True
|
|
|
+ )
|
|
|
+ response = DeleteReferenceResponse(
|
|
|
+ success=False,
|
|
|
+ message="Internal server error occurred",
|
|
|
+ reference_id=reference_id,
|
|
|
+ )
|
|
|
+ return format_response(response, status_code=500)
|
|
|
+
|
|
|
+
|
|
|
+@routes.http.post("/v1/references/update")
|
|
|
+async def update_reference(
|
|
|
+ old_reference_id: str = Body(...), new_reference_id: str = Body(...)
|
|
|
+):
|
|
|
+ """
|
|
|
+ Rename a reference voice directory from old_reference_id to new_reference_id.
|
|
|
+ """
|
|
|
+ try:
|
|
|
+ # Validate input parameters
|
|
|
+ if not old_reference_id or not old_reference_id.strip():
|
|
|
+ raise ValueError("Old reference ID cannot be empty")
|
|
|
+ if not new_reference_id or not new_reference_id.strip():
|
|
|
+ raise ValueError("New reference ID cannot be empty")
|
|
|
+ if old_reference_id == new_reference_id:
|
|
|
+ raise ValueError("New reference ID must be different from old reference ID")
|
|
|
+
|
|
|
+ # Validate ID format per ReferenceLoader rules
|
|
|
+ id_pattern = r"^[a-zA-Z0-9\-_ ]+$"
|
|
|
+ if not re.match(id_pattern, new_reference_id) or len(new_reference_id) > 255:
|
|
|
+ raise ValueError(
|
|
|
+ "New reference ID contains invalid characters or is too long"
|
|
|
+ )
|
|
|
+
|
|
|
+ # Access engine to update caches after renaming
|
|
|
+ app_state = request.app.state
|
|
|
+ model_manager: ModelManager = app_state.model_manager
|
|
|
+ engine = model_manager.tts_inference_engine
|
|
|
+
|
|
|
+ refs_base = Path("references")
|
|
|
+ old_dir = refs_base / old_reference_id
|
|
|
+ new_dir = refs_base / new_reference_id
|
|
|
+
|
|
|
+ # Existence checks
|
|
|
+ if not old_dir.exists() or not old_dir.is_dir():
|
|
|
+ raise FileNotFoundError(f"Reference ID '{old_reference_id}' not found")
|
|
|
+ if new_dir.exists():
|
|
|
+ # Conflict: destination already exists
|
|
|
+ response = UpdateReferenceResponse(
|
|
|
+ success=False,
|
|
|
+ message=f"Reference ID '{new_reference_id}' already exists",
|
|
|
+ old_reference_id=old_reference_id,
|
|
|
+ new_reference_id=new_reference_id,
|
|
|
+ )
|
|
|
+ return format_response(response, status_code=409)
|
|
|
+
|
|
|
+ # Perform rename
|
|
|
+ old_dir.rename(new_dir)
|
|
|
+
|
|
|
+ # Update in-memory cache key if present
|
|
|
+ if old_reference_id in engine.ref_by_id:
|
|
|
+ engine.ref_by_id[new_reference_id] = engine.ref_by_id.pop(old_reference_id)
|
|
|
+
|
|
|
+ response = UpdateReferenceResponse(
|
|
|
+ success=True,
|
|
|
+ message=(
|
|
|
+ f"Reference voice renamed from '{old_reference_id}' to '{new_reference_id}' successfully"
|
|
|
+ ),
|
|
|
+ old_reference_id=old_reference_id,
|
|
|
+ new_reference_id=new_reference_id,
|
|
|
+ )
|
|
|
+ return format_response(response)
|
|
|
+
|
|
|
+ except FileNotFoundError as e:
|
|
|
+ logger.warning(str(e))
|
|
|
+ response = UpdateReferenceResponse(
|
|
|
+ success=False,
|
|
|
+ message=str(e),
|
|
|
+ old_reference_id=old_reference_id,
|
|
|
+ new_reference_id=new_reference_id,
|
|
|
+ )
|
|
|
+ return format_response(response, status_code=404)
|
|
|
+
|
|
|
+ except ValueError as e:
|
|
|
+ logger.warning(f"Invalid input for update reference: {e}")
|
|
|
+ response = UpdateReferenceResponse(
|
|
|
+ success=False,
|
|
|
+ message=str(e),
|
|
|
+ old_reference_id=old_reference_id if "old_reference_id" in locals() else "",
|
|
|
+ new_reference_id=new_reference_id if "new_reference_id" in locals() else "",
|
|
|
+ )
|
|
|
+ return format_response(response, status_code=400)
|
|
|
+
|
|
|
+ except OSError as e:
|
|
|
+ logger.error(f"File system error renaming reference: {e}")
|
|
|
+ response = UpdateReferenceResponse(
|
|
|
+ success=False,
|
|
|
+ message="File system error occurred",
|
|
|
+ old_reference_id=old_reference_id,
|
|
|
+ new_reference_id=new_reference_id,
|
|
|
+ )
|
|
|
+ return format_response(response, status_code=500)
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(f"Unexpected error updating reference: {e}", exc_info=True)
|
|
|
+ response = UpdateReferenceResponse(
|
|
|
+ success=False,
|
|
|
+ message="Internal server error occurred",
|
|
|
+ old_reference_id=old_reference_id if "old_reference_id" in locals() else "",
|
|
|
+ new_reference_id=new_reference_id if "new_reference_id" in locals() else "",
|
|
|
)
|
|
|
+ return format_response(response, status_code=500)
|