import os
import queue
from dataclasses import dataclass
from typing import Annotated, Literal, Optional

import torch
from pydantic import BaseModel, Field, conint
from pydantic.functional_validators import SkipValidation

from fish_speech.conversation import Message, TextPart, VQPart
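
# Module-wide default sample count, read from the GLOBAL_NUM_SAMPLES
# environment variable (falls back to 1).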
GLOBAL_NUM_SAMPLES = int(os.getenv("GLOBAL_NUM_SAMPLES", 1))


class ServeVQPart(BaseModel):
    type: Literal["vq"] = "vq"
    # Nested VQ token ids; SkipValidation avoids the cost of validating
    # large nested lists on every request.
    codes: SkipValidation[list[list[int]]]


class ServeTextPart(BaseModel):
    type: Literal["text"] = "text"
    text: str


class ServeAudioPart(BaseModel):
    type: Literal["audio"] = "audio"
    audio: bytes


@dataclass
class ASRPackRequest:
    audio: torch.Tensor
    result_queue: queue.Queue
    language: str


class ServeASRRequest(BaseModel):
    # The audio should be uncompressed PCM float16 data
    audios: list[bytes]
    sample_rate: int = 44100
    language: Literal["zh", "en", "ja", "auto"] = "auto"


class ServeASRTranscription(BaseModel):
    text: str
    duration: float
    huge_gap: bool


class ServeASRSegment(BaseModel):
    text: str
    start: float
    end: float


class ServeTimedASRResponse(BaseModel):
    text: str
    segments: list[ServeASRSegment]
    duration: float


class ServeASRResponse(BaseModel):
    transcriptions: list[ServeASRTranscription]


class ServeMessage(BaseModel):
    role: Literal["system", "assistant", "user"]
    parts: list[ServeVQPart | ServeTextPart]

    def to_conversation_message(self):
        new_message = Message(role=self.role, parts=[])
        for part in self.parts:
            if isinstance(part, ServeTextPart):
                new_message.parts.append(TextPart(text=part.text))
            elif isinstance(part, ServeVQPart):
                new_message.parts.append(
                    VQPart(codes=torch.tensor(part.codes, dtype=torch.int))
                )
            else:
                raise ValueError(f"Unsupported part type: {part}")
        return new_message
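
# Illustrative use of the conversion above (sketch; the codes are placeholder
# values, not real codebook indices):
#
#     msg = ServeMessage(
#         role="user",
#         parts=[ServeTextPart(text="Hello"), ServeVQPart(codes=[[0, 1], [2, 3]])],
#     )
#     msg.to_conversation_message()  # -> fish_speech.conversation.Message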


class ServeRequest(BaseModel):
    messages: Annotated[list[ServeMessage], Field(min_length=1)]
    max_new_tokens: int = 1024
    top_p: float = 0.7
    repetition_penalty: float = 1.2
    temperature: float = 0.7
    streaming: bool = False
    num_samples: int = 1
    early_stop_threshold: float = 1.0


class ServeVQGANEncodeRequest(BaseModel):
    # The audio should be in an encoded format such as wav or mp3
    audios: list[bytes]


class ServeVQGANEncodeResponse(BaseModel):
    tokens: SkipValidation[list[list[list[int]]]]


class ServeVQGANDecodeRequest(BaseModel):
    tokens: SkipValidation[list[list[list[int]]]]


class ServeVQGANDecodeResponse(BaseModel):
    # The audio here is in PCM float16 format
    audios: list[bytes]
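
# Sketch of the intended encode/decode round trip. `client` is a hypothetical
# wrapper around the HTTP endpoints, not part of this module:
#
#     enc = client.encode(ServeVQGANEncodeRequest(audios=[wav_bytes]))
#     dec = client.decode(ServeVQGANDecodeRequest(tokens=enc.tokens))
#     # dec.audios now holds PCM float16 bytes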


class ServeForwardMessage(BaseModel):
    role: str
    content: str


class ServeResponse(BaseModel):
    messages: list[ServeMessage]
    finish_reason: Literal["stop", "error"] | None = None
    stats: dict[str, int | float | str] = {}


class ServeStreamDelta(BaseModel):
    role: Literal["system", "assistant", "user"] | None = None
    part: ServeVQPart | ServeTextPart | None = None


class ServeStreamResponse(BaseModel):
    sample_id: int = 0
    delta: ServeStreamDelta | None = None
    finish_reason: Literal["stop", "error"] | None = None
    stats: dict[str, int | float | str] | None = None
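
# A stream is a sequence of ServeStreamResponse objects per sample: deltas
# carry parts as they are generated, and the final chunk sets finish_reason.
# Illustrative accumulation (the `stream` and `collected` objects are assumed):
#
#     for chunk in stream:  # chunk: ServeStreamResponse
#         if chunk.delta and chunk.delta.part:
#             collected[chunk.sample_id].append(chunk.delta.part)
#         if chunk.finish_reason is not None:
#             break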


class ServeReferenceAudio(BaseModel):
    audio: bytes
    text: str

    def __repr__(self) -> str:
        # Avoid dumping the raw audio bytes when the model is printed.
        return f"ServeReferenceAudio(text={self.text!r}, audio_size={len(self.audio)})"


class ServeChatRequestV1(BaseModel):
    model: str = "llama3-8b"
    messages: list[ServeForwardMessage] = []
    audio: bytes | None = None
    temperature: float = 1.0
    top_p: float = 1.0
    max_tokens: int = 256
    voice: str = "jessica"
    tts_audio_format: Literal["mp3", "pcm", "opus"] = "mp3"
    tts_audio_bitrate: Literal[16, 24, 32, 48, 64, 96, 128, 192] = 128


class ServeTTSRequest(BaseModel):
    text: str
    chunk_length: Annotated[int, conint(ge=100, le=300, strict=True)] = 200
    # Audio format
    format: Literal["wav", "pcm", "mp3"] = "wav"
    mp3_bitrate: Literal[64, 128, 192] = 128
    # -1000 (OPUS_AUTO) lets the encoder pick the bitrate
    opus_bitrate: Optional[int] = -1000
    # Reference audios for in-context learning
    references: list[ServeReferenceAudio] = []
    # Reference id
    # For example, if you want to use https://fish.audio/m/7f92f8afb8ec43bf81429cc1c9199cb1/
    # just pass 7f92f8afb8ec43bf81429cc1c9199cb1
    reference_id: str | None = None
    seed: int | None = None
    use_memory_cache: Literal["on-demand", "never"] = "never"
    # Normalize text for en & zh; this increases stability for numbers
    normalize: bool = True
    # Balanced mode reduces latency to about 300 ms but may decrease stability
    latency: Literal["normal", "balanced"] = "normal"
    # The options below are not usually used
    streaming: bool = False
    max_new_tokens: int = 1024
    top_p: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7
    repetition_penalty: Annotated[float, Field(ge=0.9, le=2.0, strict=True)] = 1.2
    temperature: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7
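

if __name__ == "__main__":
    # Minimal smoke test of the request schemas. The text and reference audio
    # below are placeholder values for illustration only.
    request = ServeTTSRequest(
        text="Hello, world!",
        format="mp3",
        references=[ServeReferenceAudio(audio=b"\x00\x00", text="reference text")],
    )
    # Serialize without the (binary) references to keep the output readable.
    print(request.model_dump_json(exclude={"references"}))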