|
|
@@ -27,13 +27,6 @@ class ServeAudioPart(BaseModel):
|
|
|
audio: bytes
|
|
|
|
|
|
|
|
|
-@dataclass
|
|
|
-class ASRPackRequest:
|
|
|
- audio: torch.Tensor
|
|
|
- result_queue: queue.Queue
|
|
|
- language: str
|
|
|
-
|
|
|
-
|
|
|
class ServeASRRequest(BaseModel):
|
|
|
# The audio should be an uncompressed PCM float16 audio
|
|
|
audios: list[bytes]
|
|
|
@@ -93,17 +86,6 @@ class ServeVQGANDecodeResponse(BaseModel):
|
|
|
audios: list[bytes]
|
|
|
|
|
|
|
|
|
-class ServeContentSequenceParts(BaseModel):
|
|
|
- parts: list[VQPart | TextPart]
|
|
|
-
|
|
|
-
|
|
|
-class ServeResponse(BaseModel):
|
|
|
- content_sequences: list[ServeContentSequenceParts]
|
|
|
- finish_reason: Literal["stop", "error"] | None = None
|
|
|
- stats: dict[str, int | float | str] = {}
|
|
|
- finished: list[bool] | None = None
|
|
|
-
|
|
|
-
|
|
|
class ServeStreamDelta(BaseModel):
|
|
|
role: Literal["system", "assistant", "user"] | None = None
|
|
|
part: ServeVQPart | ServeTextPart | None = None
|
|
|
@@ -155,9 +137,9 @@ class ServeTTSRequest(BaseModel):
|
|
|
# not usually used below
|
|
|
streaming: bool = False
|
|
|
max_new_tokens: int = 1024
|
|
|
- top_p: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7
|
|
|
- repetition_penalty: Annotated[float, Field(ge=0.9, le=2.0, strict=True)] = 1.2
|
|
|
- temperature: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7
|
|
|
+ top_p: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.8
|
|
|
+ repetition_penalty: Annotated[float, Field(ge=0.9, le=2.0, strict=True)] = 1.1
|
|
|
+ temperature: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.8
|
|
|
|
|
|
class Config:
|
|
|
# Allow arbitrary types for pytorch related types
|