Pārlūkot izejas kodu

Fix ServeReferenceAudio to allow base64 reference data in json (#777)

* Update schema.py to fix ServeStreamResponse

* Update schema.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
MithrilMan 1 gadu atpakaļ
vecāks
revīzija
b8bdcd454c
1 mainītis faili ar 15 papildinājumiem un 1 dzēšanām
  1. 15 1
      tools/schema.py

+ 15 - 1
tools/schema.py

@@ -1,10 +1,11 @@
+import base64
 import os
 import queue
 from dataclasses import dataclass
 from typing import Literal
 
 import torch
-from pydantic import BaseModel, Field, conint, conlist
+from pydantic import BaseModel, Field, conint, conlist, model_validator
 from pydantic.functional_validators import SkipValidation
 from typing_extensions import Annotated
 
@@ -140,6 +141,19 @@ class ServeReferenceAudio(BaseModel):
     audio: bytes
     text: str
 
+    @model_validator(mode="before")
+    def decode_audio(cls, values):
+        audio = values.get("audio")
+        if (
+            isinstance(audio, str) and len(audio) > 255
+        ):  # Check if audio is a string (Base64)
+            try:
+                values["audio"] = base64.b64decode(audio)
+            except Exception as e:
+                # If the audio is not a valid base64 string, we will just ignore it and let the server handle it
+                pass
+        return values
+
     def __repr__(self) -> str:
         return f"ServeReferenceAudio(text={self.text!r}, audio_size={len(self.audio)})"