Просмотр исходного кода

API extension for reference voices. (#1099)

* Update inference documentation to clarify steps for obtaining VQ tokens and organizing reference files in the WebUI.

* Implement reference management features in the inference engine
- Added methods to list valid reference IDs and add new references with audio and text.
- Enhanced error handling for adding references, including validation for ID format and audio file existence.
- Updated API endpoints to support adding and listing references, returning structured responses.
- Added possibility for API to return JSON

* Enhance API to support multipart/form-data.
Fixed add_reference endpoint.

* Refactor API response handling and enhance error management across endpoints.

* Add delete reference functionality to API and reference_loader.

* Pre-Commit fixes.

* Added reference update endpoint.

* Revert unintendet uv.lock edits
Valentin Schröter 6 месяцев назад
Родитель
Сommit
b25daedd60

+ 5 - 2
docs/en/inference.md

@@ -14,11 +14,11 @@ huggingface-cli download fishaudio/openaudio-s1-mini --local-dir checkpoints/ope
 
 
 ## Command Line Inference
 ## Command Line Inference
 
 
+### 1. Get VQ tokens from reference audio
+
 !!! note
 !!! note
     If you plan to let the model randomly choose a voice timbre, you can skip this step.
     If you plan to let the model randomly choose a voice timbre, you can skip this step.
 
 
-### 1. Get VQ tokens from reference audio
-
 ```bash
 ```bash
 python fish_speech/models/dac/inference.py \
 python fish_speech/models/dac/inference.py \
     -i "ref_audio_name.wav" \
     -i "ref_audio_name.wav" \
@@ -36,6 +36,8 @@ python fish_speech/models/text2semantic/inference.py \
     --prompt-tokens "fake.npy" \
     --prompt-tokens "fake.npy" \
     --compile
     --compile
 ```
 ```
+with `--prompt-tokens "fake.npy"` and `--prompt-text "Your reference text"` from step 1.
+If you want to let the model randomly choose a voice timbre, skip the two parameters.
 
 
 This command will create a `codes_N` file in the working directory, where N is an integer starting from 0.
 This command will create a `codes_N` file in the working directory, where N is an integer starting from 0.
 
 
@@ -96,6 +98,7 @@ python -m tools.run_webui
 
 
 !!! note
 !!! note
     You can save the label file and reference audio file in advance to the `references` folder in the main directory (which you need to create yourself), so that you can directly call them in the WebUI.
     You can save the label file and reference audio file in advance to the `references` folder in the main directory (which you need to create yourself), so that you can directly call them in the WebUI.
+    Inside the `references` folder, put subdirectories named `<voice_id>`, and put the label file (`sample.lab`, containing the reference text) and reference audio file (`sample.wav`) in the subdirectory.
 
 
 !!! note
 !!! note
     You can use Gradio environment variables, such as `GRADIO_SHARE`, `GRADIO_SERVER_PORT`, `GRADIO_SERVER_NAME` to configure WebUI.
     You can use Gradio environment variables, such as `GRADIO_SHARE`, `GRADIO_SERVER_PORT`, `GRADIO_SERVER_NAME` to configure WebUI.

+ 143 - 4
fish_speech/inference_engine/reference_loader.py

@@ -18,7 +18,6 @@ from fish_speech.utils.schema import ServeReferenceAudio
 
 
 
 
 class ReferenceLoader:
 class ReferenceLoader:
-
     def __init__(self) -> None:
     def __init__(self) -> None:
         """
         """
         Component of the TTSInferenceEngine class.
         Component of the TTSInferenceEngine class.
@@ -43,7 +42,6 @@ class ReferenceLoader:
         id: str,
         id: str,
         use_cache: Literal["on", "off"],
         use_cache: Literal["on", "off"],
     ) -> Tuple:
     ) -> Tuple:
-
         # Load the references audio and text by id
         # Load the references audio and text by id
         ref_folder = Path("references") / id
         ref_folder = Path("references") / id
         ref_folder.mkdir(parents=True, exist_ok=True)
         ref_folder.mkdir(parents=True, exist_ok=True)
@@ -79,7 +77,6 @@ class ReferenceLoader:
         references: list[ServeReferenceAudio],
         references: list[ServeReferenceAudio],
         use_cache: Literal["on", "off"],
         use_cache: Literal["on", "off"],
     ) -> Tuple:
     ) -> Tuple:
-
         # Load the references audio and text by hash
         # Load the references audio and text by hash
         audio_hashes = [sha256(ref.audio).hexdigest() for ref in references]
         audio_hashes = [sha256(ref.audio).hexdigest() for ref in references]
 
 
@@ -109,7 +106,7 @@ class ReferenceLoader:
 
 
         return prompt_tokens, prompt_texts
         return prompt_tokens, prompt_texts
 
 
-    def load_audio(self, reference_audio, sr):
+    def load_audio(self, reference_audio: bytes | str, sr: int):
         """
         """
         Load the audio data from a file or bytes.
         Load the audio data from a file or bytes.
         """
         """
@@ -130,3 +127,145 @@ class ReferenceLoader:
 
 
         audio = waveform.squeeze().numpy()
         audio = waveform.squeeze().numpy()
         return audio
         return audio
+
+    def list_reference_ids(self) -> list[str]:
+        """
+        List all valid reference IDs (subdirectory names containing valid audio and .lab files).
+
+        Returns:
+            list[str]: List of valid reference IDs
+        """
+        ref_base_path = Path("references")
+        if not ref_base_path.exists():
+            return []
+
+        valid_ids = []
+        for ref_dir in ref_base_path.iterdir():
+            if not ref_dir.is_dir():
+                continue
+
+            # Check if directory contains at least one audio file and corresponding .lab file
+            audio_files = list_files(
+                ref_dir, AUDIO_EXTENSIONS, recursive=False, sort=False
+            )
+            if not audio_files:
+                continue
+
+            # Check if corresponding .lab file exists for at least one audio file
+            has_valid_pair = False
+            for audio_file in audio_files:
+                lab_file = audio_file.with_suffix(".lab")
+                if lab_file.exists():
+                    has_valid_pair = True
+                    break
+
+            if has_valid_pair:
+                valid_ids.append(ref_dir.name)
+
+        return sorted(valid_ids)
+
+    def add_reference(self, id: str, wav_file_path: str, reference_text: str) -> None:
+        """
+        Add a new reference voice by creating a new directory and copying files.
+
+        Args:
+            id: Reference ID (directory name)
+            wav_file_path: Path to the audio file to copy
+            reference_text: Text content for the .lab file
+
+        Raises:
+            FileExistsError: If the reference ID already exists
+            FileNotFoundError: If the audio file doesn't exist
+            OSError: If file operations fail
+        """
+        # Validate ID format
+        import re
+
+        if not re.match(r"^[a-zA-Z0-9\-_ ]+$", id):
+            raise ValueError(
+                "Reference ID contains invalid characters. Only alphanumeric, hyphens, underscores, and spaces are allowed."
+            )
+
+        if len(id) > 255:
+            raise ValueError(
+                "Reference ID is too long. Maximum length is 255 characters."
+            )
+
+        # Check if reference already exists
+        ref_dir = Path("references") / id
+        if ref_dir.exists():
+            raise FileExistsError(f"Reference ID '{id}' already exists")
+
+        # Check if audio file exists
+        audio_path = Path(wav_file_path)
+        if not audio_path.exists():
+            raise FileNotFoundError(f"Audio file not found: {wav_file_path}")
+
+        # Validate audio file extension
+        if audio_path.suffix.lower() not in AUDIO_EXTENSIONS:
+            raise ValueError(
+                f"Unsupported audio format: {audio_path.suffix}. Supported formats: {', '.join(AUDIO_EXTENSIONS)}"
+            )
+
+        try:
+            # Create reference directory
+            ref_dir.mkdir(parents=True, exist_ok=False)
+
+            # Determine the target audio filename with original extension
+            target_audio_path = ref_dir / f"sample{audio_path.suffix}"
+
+            # Copy audio file
+            import shutil
+
+            shutil.copy2(audio_path, target_audio_path)
+
+            # Create .lab file
+            lab_path = ref_dir / "sample.lab"
+            with open(lab_path, "w", encoding="utf-8") as f:
+                f.write(reference_text)
+
+            # Clear cache for this ID if it exists
+            if id in self.ref_by_id:
+                del self.ref_by_id[id]
+
+            logger.info(f"Successfully added reference voice with ID: {id}")
+
+        except Exception as e:
+            # Clean up on failure
+            if ref_dir.exists():
+                import shutil
+
+                shutil.rmtree(ref_dir)
+            raise e
+
+    def delete_reference(self, id: str) -> None:
+        """
+        Delete a reference voice by removing its directory and files.
+
+        Args:
+            id: Reference ID (directory name) to delete
+
+        Raises:
+            FileNotFoundError: If the reference ID doesn't exist
+            OSError: If file operations fail
+        """
+        # Check if reference exists
+        ref_dir = Path("references") / id
+        if not ref_dir.exists():
+            raise FileNotFoundError(f"Reference ID '{id}' does not exist")
+
+        try:
+            # Remove the entire reference directory
+            import shutil
+
+            shutil.rmtree(ref_dir)
+
+            # Clear cache for this ID if it exists
+            if id in self.ref_by_id:
+                del self.ref_by_id[id]
+
+            logger.info(f"Successfully deleted reference voice with ID: {id}")
+
+        except Exception as e:
+            logger.error(f"Failed to delete reference '{id}': {e}")
+            raise OSError(f"Failed to delete reference '{id}': {e}")

+ 32 - 1
fish_speech/utils/schema.py

@@ -69,7 +69,7 @@ class ServeReferenceAudio(BaseModel):
         ):  # Check if audio is a string (Base64)
         ):  # Check if audio is a string (Base64)
             try:
             try:
                 values["audio"] = base64.b64decode(audio)
                 values["audio"] = base64.b64decode(audio)
-            except Exception as e:
+            except Exception:
                 # If the audio is not a valid base64 string, we will just ignore it and let the server handle it
                 # If the audio is not a valid base64 string, we will just ignore it and let the server handle it
                 pass
                 pass
         return values
         return values
@@ -103,3 +103,34 @@ class ServeTTSRequest(BaseModel):
     class Config:
     class Config:
         # Allow arbitrary types for pytorch related types
         # Allow arbitrary types for pytorch related types
         arbitrary_types_allowed = True
         arbitrary_types_allowed = True
+
+
+class AddReferenceRequest(BaseModel):
+    id: str = Field(..., min_length=1, max_length=255, pattern=r"^[a-zA-Z0-9\-_ ]+$")
+    audio: bytes
+    text: str = Field(..., min_length=1)
+
+
+class AddReferenceResponse(BaseModel):
+    success: bool
+    message: str
+    reference_id: str
+
+
+class ListReferencesResponse(BaseModel):
+    success: bool
+    reference_ids: list[str]
+    message: str = "Success"
+
+
+class DeleteReferenceResponse(BaseModel):
+    success: bool
+    message: str
+    reference_id: str
+
+
+class UpdateReferenceResponse(BaseModel):
+    success: bool
+    message: str
+    old_reference_id: str
+    new_reference_id: str

+ 9 - 5
tools/api_client.py

@@ -1,5 +1,6 @@
 import argparse
 import argparse
 import base64
 import base64
+import time
 import wave
 import wave
 
 
 import ormsgpack
 import ormsgpack
@@ -13,7 +14,6 @@ from fish_speech.utils.schema import ServeReferenceAudio, ServeTTSRequest
 
 
 
 
 def parse_args():
 def parse_args():
-
     parser = argparse.ArgumentParser(
     parser = argparse.ArgumentParser(
         description="Send a WAV file and text to a server and receive synthesized audio.",
         description="Send a WAV file and text to a server and receive synthesized audio.",
         formatter_class=argparse.RawTextHelpFormatter,
         formatter_class=argparse.RawTextHelpFormatter,
@@ -97,8 +97,9 @@ def parse_args():
         "--temperature", type=float, default=0.8, help="Temperature for sampling"
         "--temperature", type=float, default=0.8, help="Temperature for sampling"
     )
     )
 
 
+    # parser.add_argument("--streaming", type=bool, default=False, help="Enable streaming response")
     parser.add_argument(
     parser.add_argument(
-        "--streaming", type=bool, default=False, help="Enable streaming response"
+        "--streaming", action="store_true", help="Enable streaming response"
     )
     )
     parser.add_argument(
     parser.add_argument(
         "--channels", type=int, default=1, help="Number of audio channels"
         "--channels", type=int, default=1, help="Number of audio channels"
@@ -115,8 +116,7 @@ def parse_args():
         "--seed",
         "--seed",
         type=int,
         type=int,
         default=None,
         default=None,
-        help="`None` means randomized inference, otherwise deterministic.\n"
-        "It can't be used for fixing a timbre.",
+        help="`None` means randomized inference, otherwise deterministic.\nIt can't be used for fixing a timbre.",
     )
     )
     parser.add_argument(
     parser.add_argument(
         "--api_key",
         "--api_key",
@@ -129,7 +129,6 @@ def parse_args():
 
 
 
 
 if __name__ == "__main__":
 if __name__ == "__main__":
-
     args = parse_args()
     args = parse_args()
 
 
     idstr: str | None = args.reference_id
     idstr: str | None = args.reference_id
@@ -172,8 +171,11 @@ if __name__ == "__main__":
 
 
     pydantic_data = ServeTTSRequest(**data)
     pydantic_data = ServeTTSRequest(**data)
 
 
+    print("Sending request")
+    start_time = time.time()
     response = requests.post(
     response = requests.post(
         args.url,
         args.url,
+        params={"format": "msgpack"},
         data=ormsgpack.packb(pydantic_data, option=ormsgpack.OPT_SERIALIZE_PYDANTIC),
         data=ormsgpack.packb(pydantic_data, option=ormsgpack.OPT_SERIALIZE_PYDANTIC),
         stream=args.streaming,
         stream=args.streaming,
         headers={
         headers={
@@ -181,6 +183,8 @@ if __name__ == "__main__":
             "content-type": "application/msgpack",
             "content-type": "application/msgpack",
         },
         },
     )
     )
+    end_time = time.time()
+    print(f"Request took {end_time - start_time} seconds")
 
 
     if response.status_code == 200:
     if response.status_code == 200:
         if args.streaming:
         if args.streaming:

+ 77 - 3
tools/server/api_utils.py

@@ -4,7 +4,14 @@ from typing import Annotated, Any
 
 
 import ormsgpack
 import ormsgpack
 from baize.datastructures import ContentType
 from baize.datastructures import ContentType
-from kui.asgi import HTTPException, HttpRequest
+from kui.asgi import (
+    HTTPException,
+    HttpRequest,
+    JSONResponse,
+    request,
+)
+from loguru import logger
+from pydantic import BaseModel
 
 
 from fish_speech.inference_engine import TTSInferenceEngine
 from fish_speech.inference_engine import TTSInferenceEngine
 from fish_speech.utils.schema import ServeTTSRequest
 from fish_speech.utils.schema import ServeTTSRequest
@@ -40,7 +47,10 @@ class MsgPackRequest(HttpRequest):
     async def data(
     async def data(
         self,
         self,
     ) -> Annotated[
     ) -> Annotated[
-        Any, ContentType("application/msgpack"), ContentType("application/json")
+        Any,
+        ContentType("application/msgpack"),
+        ContentType("application/json"),
+        ContentType("multipart/form-data"),
     ]:
     ]:
         if self.content_type == "application/msgpack":
         if self.content_type == "application/msgpack":
             return ormsgpack.unpackb(await self.body)
             return ormsgpack.unpackb(await self.body)
@@ -48,14 +58,20 @@ class MsgPackRequest(HttpRequest):
         elif self.content_type == "application/json":
         elif self.content_type == "application/json":
             return await self.json
             return await self.json
 
 
+        elif self.content_type == "multipart/form-data":
+            return await self.form
+
         raise HTTPException(
         raise HTTPException(
             HTTPStatus.UNSUPPORTED_MEDIA_TYPE,
             HTTPStatus.UNSUPPORTED_MEDIA_TYPE,
-            headers={"Accept": "application/msgpack, application/json"},
+            headers={
+                "Accept": "application/msgpack, application/json, multipart/form-data"
+            },
         )
         )
 
 
 
 
 async def inference_async(req: ServeTTSRequest, engine: TTSInferenceEngine):
 async def inference_async(req: ServeTTSRequest, engine: TTSInferenceEngine):
     for chunk in inference(req, engine):
     for chunk in inference(req, engine):
+        print("Got chunk")
         if isinstance(chunk, bytes):
         if isinstance(chunk, bytes):
             yield chunk
             yield chunk
 
 
@@ -73,3 +89,61 @@ def get_content_type(audio_format):
         return "audio/mpeg"
         return "audio/mpeg"
     else:
     else:
         return "application/octet-stream"
         return "application/octet-stream"
+
+
+def wants_json(req):
+    """Helper method to determine if the client wants a JSON response
+
+    Parameters
+    ----------
+    req : Request
+        The request object
+
+    Returns
+    -------
+    bool
+        True if the client wants a JSON response, False otherwise
+    """
+    q = req.query_params.get("format", "").strip().lower()
+    if q in {"json", "application/json", "msgpack", "application/msgpack"}:
+        return q == "json"
+    accept = req.headers.get("Accept", "").strip().lower()
+    return "application/json" in accept and "application/msgpack" not in accept
+
+
+def format_response(response: BaseModel, status_code=200):
+    """
+    Helper function to format responses consistently based on client preference.
+
+    Parameters
+    ----------
+    response : BaseModel
+        The response object to format
+    status_code : int
+        HTTP status code (default: 200)
+
+    Returns
+    -------
+    Response
+        Formatted response in the client's preferred format
+    """
+    try:
+        if wants_json(request):
+            return JSONResponse(
+                response.model_dump(mode="json"), status_code=status_code
+            )
+
+        return (
+            ormsgpack.packb(
+                response,
+                option=ormsgpack.OPT_SERIALIZE_PYDANTIC,
+            ),
+            status_code,
+            {"Content-Type": "application/msgpack"},
+        )
+    except Exception as e:
+        logger.error(f"Error formatting response: {e}", exc_info=True)
+        # Fallback to JSON response if formatting fails
+        return JSONResponse(
+            {"error": "Response formatting failed", "details": str(e)}, status_code=500
+        )

+ 388 - 68
tools/server/views.py

@@ -1,7 +1,11 @@
 import io
 import io
 import os
 import os
+import re
+import shutil
+import tempfile
 import time
 import time
 from http import HTTPStatus
 from http import HTTPStatus
+from pathlib import Path
 
 
 import numpy as np
 import numpy as np
 import ormsgpack
 import ormsgpack
@@ -14,20 +18,27 @@ from kui.asgi import (
     JSONResponse,
     JSONResponse,
     Routes,
     Routes,
     StreamResponse,
     StreamResponse,
+    UploadFile,
     request,
     request,
 )
 )
 from loguru import logger
 from loguru import logger
 from typing_extensions import Annotated
 from typing_extensions import Annotated
 
 
 from fish_speech.utils.schema import (
 from fish_speech.utils.schema import (
+    AddReferenceRequest,
+    AddReferenceResponse,
+    DeleteReferenceResponse,
+    ListReferencesResponse,
     ServeTTSRequest,
     ServeTTSRequest,
     ServeVQGANDecodeRequest,
     ServeVQGANDecodeRequest,
     ServeVQGANDecodeResponse,
     ServeVQGANDecodeResponse,
     ServeVQGANEncodeRequest,
     ServeVQGANEncodeRequest,
     ServeVQGANEncodeResponse,
     ServeVQGANEncodeResponse,
+    UpdateReferenceResponse,
 )
 )
 from tools.server.api_utils import (
 from tools.server.api_utils import (
     buffer_to_async_generator,
     buffer_to_async_generator,
+    format_response,
     get_content_type,
     get_content_type,
     inference_async,
     inference_async,
 )
 )
@@ -56,87 +67,396 @@ class Health(HttpView):
 
 
 @routes.http.post("/v1/vqgan/encode")
 @routes.http.post("/v1/vqgan/encode")
 async def vqgan_encode(req: Annotated[ServeVQGANEncodeRequest, Body(exclusive=True)]):
 async def vqgan_encode(req: Annotated[ServeVQGANEncodeRequest, Body(exclusive=True)]):
-    # Get the model from the app
-    model_manager: ModelManager = request.app.state.model_manager
-    decoder_model = model_manager.decoder_model
+    """
+    Encode audio using VQGAN model.
+    """
+    try:
+        # Get the model from the app
+        model_manager: ModelManager = request.app.state.model_manager
+        decoder_model = model_manager.decoder_model
 
 
-    # Encode the audio
-    start_time = time.time()
-    tokens = cached_vqgan_batch_encode(decoder_model, req.audios)
-    logger.info(f"[EXEC] VQGAN encode time: {(time.time() - start_time) * 1000:.2f}ms")
+        # Encode the audio
+        start_time = time.time()
+        tokens = cached_vqgan_batch_encode(decoder_model, req.audios)
+        logger.info(
+            f"[EXEC] VQGAN encode time: {(time.time() - start_time) * 1000:.2f}ms"
+        )
 
 
-    # Return the response
-    return ormsgpack.packb(
-        ServeVQGANEncodeResponse(tokens=[i.tolist() for i in tokens]),
-        option=ormsgpack.OPT_SERIALIZE_PYDANTIC,
-    )
+        # Return the response
+        return ormsgpack.packb(
+            ServeVQGANEncodeResponse(tokens=[i.tolist() for i in tokens]),
+            option=ormsgpack.OPT_SERIALIZE_PYDANTIC,
+        )
+    except Exception as e:
+        logger.error(f"Error in VQGAN encode: {e}", exc_info=True)
+        raise HTTPException(
+            HTTPStatus.INTERNAL_SERVER_ERROR, content="Failed to encode audio"
+        )
 
 
 
 
 @routes.http.post("/v1/vqgan/decode")
 @routes.http.post("/v1/vqgan/decode")
 async def vqgan_decode(req: Annotated[ServeVQGANDecodeRequest, Body(exclusive=True)]):
 async def vqgan_decode(req: Annotated[ServeVQGANDecodeRequest, Body(exclusive=True)]):
-    # Get the model from the app
-    model_manager: ModelManager = request.app.state.model_manager
-    decoder_model = model_manager.decoder_model
+    """
+    Decode tokens to audio using VQGAN model.
+    """
+    try:
+        # Get the model from the app
+        model_manager: ModelManager = request.app.state.model_manager
+        decoder_model = model_manager.decoder_model
 
 
-    # Decode the audio
-    tokens = [torch.tensor(token, dtype=torch.int) for token in req.tokens]
-    start_time = time.time()
-    audios = batch_vqgan_decode(decoder_model, tokens)
-    logger.info(f"[EXEC] VQGAN decode time: {(time.time() - start_time) * 1000:.2f}ms")
-    audios = [audio.astype(np.float16).tobytes() for audio in audios]
+        # Decode the audio
+        tokens = [torch.tensor(token, dtype=torch.int) for token in req.tokens]
+        start_time = time.time()
+        audios = batch_vqgan_decode(decoder_model, tokens)
+        logger.info(
+            f"[EXEC] VQGAN decode time: {(time.time() - start_time) * 1000:.2f}ms"
+        )
+        audios = [audio.astype(np.float16).tobytes() for audio in audios]
 
 
-    # Return the response
-    return ormsgpack.packb(
-        ServeVQGANDecodeResponse(audios=audios),
-        option=ormsgpack.OPT_SERIALIZE_PYDANTIC,
-    )
+        # Return the response
+        return ormsgpack.packb(
+            ServeVQGANDecodeResponse(audios=audios),
+            option=ormsgpack.OPT_SERIALIZE_PYDANTIC,
+        )
+    except Exception as e:
+        logger.error(f"Error in VQGAN decode: {e}", exc_info=True)
+        raise HTTPException(
+            HTTPStatus.INTERNAL_SERVER_ERROR, content="Failed to decode tokens to audio"
+        )
 
 
 
 
 @routes.http.post("/v1/tts")
 @routes.http.post("/v1/tts")
 async def tts(req: Annotated[ServeTTSRequest, Body(exclusive=True)]):
 async def tts(req: Annotated[ServeTTSRequest, Body(exclusive=True)]):
-    # Get the model from the app
-    app_state = request.app.state
-    model_manager: ModelManager = app_state.model_manager
-    engine = model_manager.tts_inference_engine
-    sample_rate = engine.decoder_model.sample_rate
-
-    # Check if the text is too long
-    if app_state.max_text_length > 0 and len(req.text) > app_state.max_text_length:
+    """
+    Generate speech from text using TTS model.
+    """
+    try:
+        # Get the model from the app
+        app_state = request.app.state
+        model_manager: ModelManager = app_state.model_manager
+        engine = model_manager.tts_inference_engine
+        sample_rate = engine.decoder_model.sample_rate
+
+        # Check if the text is too long
+        if app_state.max_text_length > 0 and len(req.text) > app_state.max_text_length:
+            raise HTTPException(
+                HTTPStatus.BAD_REQUEST,
+                content=f"Text is too long, max length is {app_state.max_text_length}",
+            )
+
+        # Check if streaming is enabled
+        if req.streaming and req.format != "wav":
+            raise HTTPException(
+                HTTPStatus.BAD_REQUEST,
+                content="Streaming only supports WAV format",
+            )
+
+        # Perform TTS
+        if req.streaming:
+            return StreamResponse(
+                iterable=inference_async(req, engine),
+                headers={
+                    "Content-Disposition": f"attachment; filename=audio.{req.format}",
+                },
+                content_type=get_content_type(req.format),
+            )
+        else:
+            fake_audios = next(inference(req, engine))
+            buffer = io.BytesIO()
+            sf.write(
+                buffer,
+                fake_audios,
+                sample_rate,
+                format=req.format,
+            )
+
+            return StreamResponse(
+                iterable=buffer_to_async_generator(buffer.getvalue()),
+                headers={
+                    "Content-Disposition": f"attachment; filename=audio.{req.format}",
+                },
+                content_type=get_content_type(req.format),
+            )
+    except HTTPException:
+        # Re-raise HTTP exceptions as they are already properly formatted
+        raise
+    except Exception as e:
+        logger.error(f"Error in TTS generation: {e}", exc_info=True)
         raise HTTPException(
         raise HTTPException(
-            HTTPStatus.BAD_REQUEST,
-            content=f"Text is too long, max length is {app_state.max_text_length}",
+            HTTPStatus.INTERNAL_SERVER_ERROR, content="Failed to generate speech"
         )
         )
 
 
-    # Check if streaming is enabled
-    if req.streaming and req.format != "wav":
-        raise HTTPException(
-            HTTPStatus.BAD_REQUEST,
-            content="Streaming only supports WAV format",
-        )
-
-    # Perform TTS
-    if req.streaming:
-        return StreamResponse(
-            iterable=inference_async(req, engine),
-            headers={
-                "Content-Disposition": f"attachment; filename=audio.{req.format}",
-            },
-            content_type=get_content_type(req.format),
-        )
-    else:
-        fake_audios = next(inference(req, engine))
-        buffer = io.BytesIO()
-        sf.write(
-            buffer,
-            fake_audios,
-            sample_rate,
-            format=req.format,
-        )
-
-        return StreamResponse(
-            iterable=buffer_to_async_generator(buffer.getvalue()),
-            headers={
-                "Content-Disposition": f"attachment; filename=audio.{req.format}",
-            },
-            content_type=get_content_type(req.format),
+
+@routes.http.post("/v1/references/add")
+async def add_reference(
+    id: str = Body(...), audio: UploadFile = Body(...), text: str = Body(...)
+):
+    """
+    Add a new reference voice with audio file and text.
+    """
+    temp_file_path = None
+
+    try:
+        # Validate input parameters
+        if not id or not id.strip():
+            raise ValueError("Reference ID cannot be empty")
+
+        if not text or not text.strip():
+            raise ValueError("Reference text cannot be empty")
+
+        # Get the model manager to access the reference loader
+        app_state = request.app.state
+        model_manager: ModelManager = app_state.model_manager
+        engine = model_manager.tts_inference_engine
+
+        # Read the uploaded audio file
+        audio_content = audio.read()
+        if not audio_content:
+            raise ValueError("Audio file is empty or could not be read")
+
+        # Create a temporary file for the audio data
+        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_file:
+            temp_file.write(audio_content)
+            temp_file_path = temp_file.name
+
+        # Add the reference using the engine's reference loader
+        engine.add_reference(id, temp_file_path, text)
+
+        response = AddReferenceResponse(
+            success=True,
+            message=f"Reference voice '{id}' added successfully",
+            reference_id=id,
+        )
+        return format_response(response)
+
+    except FileExistsError as e:
+        logger.warning(f"Reference ID '{id}' already exists: {e}")
+        response = AddReferenceResponse(
+            success=False,
+            message=f"Reference ID '{id}' already exists",
+            reference_id=id,
+        )
+        return format_response(response, status_code=409)  # Conflict
+
+    except ValueError as e:
+        logger.warning(f"Invalid input for reference '{id}': {e}")
+        response = AddReferenceResponse(success=False, message=str(e), reference_id=id)
+        return format_response(response, status_code=400)
+
+    except (FileNotFoundError, OSError) as e:
+        logger.error(f"File system error for reference '{id}': {e}")
+        response = AddReferenceResponse(
+            success=False, message="File system error occurred", reference_id=id
+        )
+        return format_response(response, status_code=500)
+
+    except Exception as e:
+        logger.error(f"Unexpected error adding reference '{id}': {e}", exc_info=True)
+        response = AddReferenceResponse(
+            success=False, message="Internal server error occurred", reference_id=id
+        )
+        return format_response(response, status_code=500)
+
+    finally:
+        # Clean up temporary file
+        if temp_file_path and os.path.exists(temp_file_path):
+            try:
+                os.unlink(temp_file_path)
+            except OSError as e:
+                logger.warning(
+                    f"Failed to clean up temporary file {temp_file_path}: {e}"
+                )
+
+
@routes.http.get("/v1/references/list")
async def list_references():
    """
    Return all currently available reference voice IDs.

    Responds with a ListReferencesResponse carrying the IDs; any
    unexpected failure yields a 500 with an empty ID list.
    """
    try:
        # Reach the TTS engine through the app-wide model manager.
        model_manager: ModelManager = request.app.state.model_manager
        engine = model_manager.tts_inference_engine

        # Ask the engine's reference loader for every valid reference ID.
        ids = engine.list_reference_ids()

        return format_response(
            ListReferencesResponse(
                success=True,
                reference_ids=ids,
                message=f"Found {len(ids)} reference voices",
            )
        )

    except Exception as e:
        logger.error(f"Unexpected error listing references: {e}", exc_info=True)
        return format_response(
            ListReferencesResponse(
                success=False,
                reference_ids=[],
                message="Internal server error occurred",
            ),
            status_code=500,
        )
+
+
@routes.http.delete("/v1/references/delete")
async def delete_reference(reference_id: str = Body(...)):
    """
    Delete a reference voice by ID.

    Returns 400 for an empty ID, 404 when the reference does not exist,
    409 is never used here, and 500 for file-system or unexpected errors.
    """
    try:
        # Reject empty / whitespace-only IDs up front.
        if not (reference_id and reference_id.strip()):
            raise ValueError("Reference ID cannot be empty")

        # Resolve the inference engine via the shared model manager.
        model_manager: ModelManager = request.app.state.model_manager
        engine = model_manager.tts_inference_engine

        # The engine's reference loader performs the actual removal.
        engine.delete_reference(reference_id)

        return format_response(
            DeleteReferenceResponse(
                success=True,
                message=f"Reference voice '{reference_id}' deleted successfully",
                reference_id=reference_id,
            )
        )

    except FileNotFoundError as e:
        logger.warning(f"Reference ID '{reference_id}' not found: {e}")
        return format_response(
            DeleteReferenceResponse(
                success=False,
                message=f"Reference ID '{reference_id}' not found",
                reference_id=reference_id,
            ),
            status_code=404,  # Not Found
        )

    except ValueError as e:
        logger.warning(f"Invalid input for reference '{reference_id}': {e}")
        return format_response(
            DeleteReferenceResponse(
                success=False, message=str(e), reference_id=reference_id
            ),
            status_code=400,
        )

    # Note: FileNotFoundError is handled above, so this catches the
    # remaining OSError family (permissions, busy directory, ...).
    except OSError as e:
        logger.error(f"File system error deleting reference '{reference_id}': {e}")
        return format_response(
            DeleteReferenceResponse(
                success=False,
                message="File system error occurred",
                reference_id=reference_id,
            ),
            status_code=500,
        )

    except Exception as e:
        logger.error(
            f"Unexpected error deleting reference '{reference_id}': {e}", exc_info=True
        )
        return format_response(
            DeleteReferenceResponse(
                success=False,
                message="Internal server error occurred",
                reference_id=reference_id,
            ),
            status_code=500,
        )
+
+
@routes.http.post("/v1/references/update")
async def update_reference(
    old_reference_id: str = Body(...), new_reference_id: str = Body(...)
):
    """
    Rename a reference voice directory from old_reference_id to new_reference_id.

    Returns 400 for invalid input, 404 when the source reference does not
    exist, 409 when the destination already exists, and 500 for file-system
    or unexpected errors.
    """
    try:
        # Validate input parameters
        if not old_reference_id or not old_reference_id.strip():
            raise ValueError("Old reference ID cannot be empty")
        if not new_reference_id or not new_reference_id.strip():
            raise ValueError("New reference ID cannot be empty")
        if old_reference_id == new_reference_id:
            raise ValueError("New reference ID must be different from old reference ID")

        # Validate BOTH IDs per ReferenceLoader rules. Checking the old ID
        # too is a security fix: an unvalidated value such as "../x" would
        # otherwise be joined below and could rename directories outside
        # the references base (path traversal).
        id_pattern = r"^[a-zA-Z0-9\-_ ]+$"
        if not re.match(id_pattern, old_reference_id) or len(old_reference_id) > 255:
            raise ValueError(
                "Old reference ID contains invalid characters or is too long"
            )
        if not re.match(id_pattern, new_reference_id) or len(new_reference_id) > 255:
            raise ValueError(
                "New reference ID contains invalid characters or is too long"
            )

        # Access engine to update caches after renaming
        app_state = request.app.state
        model_manager: ModelManager = app_state.model_manager
        engine = model_manager.tts_inference_engine

        refs_base = Path("references")
        old_dir = refs_base / old_reference_id
        new_dir = refs_base / new_reference_id

        # Existence checks (is_dir() is False for a missing path, covering
        # both "does not exist" and "exists but is not a directory").
        if not old_dir.is_dir():
            raise FileNotFoundError(f"Reference ID '{old_reference_id}' not found")
        if new_dir.exists():
            # Conflict: destination already exists
            response = UpdateReferenceResponse(
                success=False,
                message=f"Reference ID '{new_reference_id}' already exists",
                old_reference_id=old_reference_id,
                new_reference_id=new_reference_id,
            )
            return format_response(response, status_code=409)

        # Perform rename
        old_dir.rename(new_dir)

        # Update in-memory cache key if present, so subsequent requests
        # resolve the renamed reference without a reload from disk.
        if old_reference_id in engine.ref_by_id:
            engine.ref_by_id[new_reference_id] = engine.ref_by_id.pop(old_reference_id)

        response = UpdateReferenceResponse(
            success=True,
            message=(
                f"Reference voice renamed from '{old_reference_id}' to '{new_reference_id}' successfully"
            ),
            old_reference_id=old_reference_id,
            new_reference_id=new_reference_id,
        )
        return format_response(response)

    except FileNotFoundError as e:
        logger.warning(str(e))
        response = UpdateReferenceResponse(
            success=False,
            message=str(e),
            old_reference_id=old_reference_id,
            new_reference_id=new_reference_id,
        )
        return format_response(response, status_code=404)

    except ValueError as e:
        logger.warning(f"Invalid input for update reference: {e}")
        # Both IDs are function parameters and are always bound, so no
        # `locals()` guard is needed (the original checks were dead code).
        response = UpdateReferenceResponse(
            success=False,
            message=str(e),
            old_reference_id=old_reference_id,
            new_reference_id=new_reference_id,
        )
        return format_response(response, status_code=400)

    # FileNotFoundError (a subclass) is handled above; this catches the
    # rest of the OSError family raised by rename().
    except OSError as e:
        logger.error(f"File system error renaming reference: {e}")
        response = UpdateReferenceResponse(
            success=False,
            message="File system error occurred",
            old_reference_id=old_reference_id,
            new_reference_id=new_reference_id,
        )
        return format_response(response, status_code=500)

    except Exception as e:
        logger.error(f"Unexpected error updating reference: {e}", exc_info=True)
        response = UpdateReferenceResponse(
            success=False,
            message="Internal server error occurred",
            old_reference_id=old_reference_id,
            new_reference_id=new_reference_id,
        )
        return format_response(response, status_code=500)