# NOTE: extraction artifact (fused line-number column from a table render) removed.
- import base64
- import os
- import queue
- from dataclasses import dataclass
- from typing import Literal
- import torch
- from pydantic import BaseModel, Field, conint, model_validator
- from pydantic.functional_validators import SkipValidation
- from typing_extensions import Annotated
- from fish_speech.content_sequence import TextPart, VQPart
class ServeVQPart(BaseModel):
    """Message part carrying VQ (vector-quantized) audio codes."""

    # Literal discriminator so tagged unions can select this part type.
    type: Literal["vq"] = "vq"
    # Code matrix (codebooks x frames); SkipValidation skips per-element
    # validation, which would be costly for large code grids.
    codes: SkipValidation[list[list[int]]]
class ServeTextPart(BaseModel):
    """Message part carrying plain text."""

    # Literal discriminator so tagged unions can select this part type.
    type: Literal["text"] = "text"
    text: str
class ServeAudioPart(BaseModel):
    """Message part carrying raw audio bytes."""

    # Literal discriminator so tagged unions can select this part type.
    type: Literal["audio"] = "audio"
    audio: bytes
class ServeRequest(BaseModel):
    """Generation request with raw content plus sampling parameters."""

    # Raw content sequence dict that we can use with ContentSequence(**content)
    content: dict
    max_new_tokens: int = 600
    top_p: float = 0.7
    repetition_penalty: float = 1.2
    temperature: float = 0.7
    # When True, tokens are streamed back as they are produced.
    streaming: bool = False
    num_samples: int = 1
    early_stop_threshold: float = 1.0
class ServeVQGANEncodeRequest(BaseModel):
    """Request to encode audio files into VQ token grids."""

    # The audio here should be in wav, mp3, etc
    audios: list[bytes]
class ServeVQGANEncodeResponse(BaseModel):
    """Response holding one VQ token grid per input audio."""

    # One [codebooks x frames] grid per audio; SkipValidation avoids
    # per-element validation cost on large nested lists.
    tokens: SkipValidation[list[list[list[int]]]]
class ServeVQGANDecodeRequest(BaseModel):
    """Request to decode VQ token grids back into audio."""

    # Same layout as ServeVQGANEncodeResponse.tokens.
    tokens: SkipValidation[list[list[list[int]]]]
class ServeVQGANDecodeResponse(BaseModel):
    """Response holding one decoded audio buffer per input token grid."""

    # The audio here should be in PCM float16 format
    audios: list[bytes]
class ServeReferenceAudio(BaseModel):
    """A reference (audio, transcript) pair used for in-context learning.

    ``audio`` may be supplied either as raw bytes or as a base64-encoded
    string; the pre-validator below transparently decodes the latter.
    """

    audio: bytes
    text: str

    @model_validator(mode="before")
    def decode_audio(cls, values):
        """Decode ``audio`` from base64 when it arrives as a long string.

        Strings of length <= 255 are assumed not to be base64 payloads and
        are left untouched for downstream field validation to handle.
        """
        # mode="before" validators may receive non-dict input; only touch dicts.
        if not isinstance(values, dict):
            return values

        audio = values.get("audio")
        # Heuristic: a long string is treated as a base64-encoded payload.
        if isinstance(audio, str) and len(audio) > 255:
            try:
                values["audio"] = base64.b64decode(audio)
            except ValueError:
                # b64decode raises binascii.Error (a ValueError subclass) on
                # malformed input. Leave the value as-is and let pydantic's
                # normal field validation report the problem to the caller.
                pass
        return values

    def __repr__(self) -> str:
        # Avoid dumping the (potentially huge) audio payload into logs.
        return f"ServeReferenceAudio(text={self.text!r}, audio_size={len(self.audio)})"
class ServeTTSRequest(BaseModel):
    """Request schema for the TTS endpoint."""

    text: str
    # Field(...) constraints for consistency with top_p/temperature below;
    # conint() is deprecated in pydantic v2 in favor of Annotated + Field.
    chunk_length: Annotated[int, Field(ge=100, le=300, strict=True)] = 200

    # Audio format
    format: Literal["wav", "pcm", "mp3"] = "wav"

    # Reference audios for in-context learning.
    # (Mutable default is safe: pydantic deep-copies field defaults per instance.)
    references: list[ServeReferenceAudio] = []

    # Reference id
    # For example, if you want use https://fish.audio/m/7f92f8afb8ec43bf81429cc1c9199cb1/
    # Just pass 7f92f8afb8ec43bf81429cc1c9199cb1
    reference_id: str | None = None
    seed: int | None = None
    use_memory_cache: Literal["on", "off"] = "off"

    # Normalize text for en & zh, this increase stability for numbers
    normalize: bool = True

    # Not usually used below
    streaming: bool = False
    max_new_tokens: int = 1024
    top_p: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.8
    repetition_penalty: Annotated[float, Field(ge=0.9, le=2.0, strict=True)] = 1.1
    temperature: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.8

    class Config:
        # Allow arbitrary types for pytorch related types
        arbitrary_types_allowed = True