|
|
@@ -1,10 +1,11 @@
|
|
|
+import base64
|
|
|
import os
|
|
|
import queue
|
|
|
from dataclasses import dataclass
|
|
|
from typing import Literal
|
|
|
|
|
|
import torch
|
|
|
-from pydantic import BaseModel, Field, conint, conlist
|
|
|
+from pydantic import BaseModel, Field, conint, conlist, model_validator
|
|
|
from pydantic.functional_validators import SkipValidation
|
|
|
from typing_extensions import Annotated
|
|
|
|
|
|
@@ -140,6 +141,19 @@ class ServeReferenceAudio(BaseModel):
|
|
|
audio: bytes
|
|
|
text: str
|
|
|
|
|
|
+ @model_validator(mode="before")
|
|
|
+ def decode_audio(cls, values):
|
|
|
+ audio = values.get("audio")
|
|
|
+ if (
|
|
|
+ isinstance(audio, str) and len(audio) > 255
|
|
|
+ ): # Check if audio is a string (Base64)
|
|
|
+ try:
|
|
|
+ values["audio"] = base64.b64decode(audio)
|
|
|
+ except Exception as e:
|
|
|
+ # If the audio is not a valid base64 string, we will just ignore it and let the server handle it
|
|
|
+ pass
|
|
|
+ return values
|
|
|
+
|
|
|
def __repr__(self) -> str:
|
|
|
return f"ServeReferenceAudio(text={self.text!r}, audio_size={len(self.audio)})"
|
|
|
|