# schema.py
  1. import base64
  2. import os
  3. import queue
  4. from dataclasses import dataclass
  5. from typing import Literal
  6. import torch
  7. from pydantic import BaseModel, Field, conint, conlist, model_validator
  8. from pydantic.functional_validators import SkipValidation
  9. from typing_extensions import Annotated
  10. from fish_speech.conversation import Message, TextPart, VQPart
class ServeVQPart(BaseModel):
    """Message part carrying VQ (vector-quantized) audio codes."""

    # Discriminator tag for the part union (see ServeMessage.parts).
    type: Literal["vq"] = "vq"
    # Nested codebook indices; validation is skipped, presumably to avoid the
    # cost of validating large token grids — payload layout is assumed to be
    # [codebook][frame]; TODO confirm against the VQGAN encoder output.
    codes: SkipValidation[list[list[int]]]
class ServeTextPart(BaseModel):
    """Message part carrying plain text."""

    # Discriminator tag for the part union (see ServeMessage.parts).
    type: Literal["text"] = "text"
    # The text content of this part.
    text: str
class ServeAudioPart(BaseModel):
    """Message part carrying raw audio bytes."""

    # Discriminator tag distinguishing this part from text/vq parts.
    type: Literal["audio"] = "audio"
    # Raw audio payload; encoding/format is not constrained here.
    audio: bytes
@dataclass
class ASRPackRequest:
    """Internal work item handed to the ASR worker.

    Plain dataclass (not pydantic) because it carries non-serializable
    runtime objects: a tensor and a queue.
    """

    # Waveform to transcribe — assumed a float PCM tensor; TODO confirm
    # shape/dtype with the producer.
    audio: torch.Tensor
    # Queue the worker pushes its transcription result onto.
    result_queue: queue.Queue
    # Language hint forwarded to the ASR model.
    language: str
class ServeASRRequest(BaseModel):
    """Request body for the ASR (speech-to-text) endpoint."""

    # The audio should be an uncompressed PCM float16 audio
    audios: list[bytes]
    # Sample rate of the supplied PCM buffers.
    sample_rate: int = 44100
    # Language of the audio; "auto" lets the model detect it.
    language: Literal["zh", "en", "ja", "auto"] = "auto"
class ServeASRTranscription(BaseModel):
    """Transcription result for a single input audio clip."""

    # Full transcribed text.
    text: str
    # Duration of the clip — presumably in seconds; TODO confirm unit.
    duration: float
    # Flag set by the ASR pipeline when a large silent/undecodable gap was
    # detected — exact criterion lives in the ASR worker, not here.
    huge_gap: bool
class ServeASRSegment(BaseModel):
    """One time-aligned segment of a transcription."""

    # Text spoken within [start, end].
    text: str
    # Segment start offset — presumably seconds; TODO confirm unit.
    start: float
    # Segment end offset.
    end: float
class ServeTimedASRResponse(BaseModel):
    """ASR response that includes per-segment timing information."""

    # Full transcribed text.
    text: str
    # Time-aligned segments covering the transcription.
    segments: list[ServeASRSegment]
    # Total duration of the transcribed audio.
    duration: float
class ServeASRResponse(BaseModel):
    """ASR response: one transcription per input audio, in request order."""

    transcriptions: list[ServeASRTranscription]
  44. class ServeMessage(BaseModel):
  45. role: Literal["system", "assistant", "user"]
  46. parts: list[ServeVQPart | ServeTextPart]
  47. def to_conversation_message(self):
  48. new_message = Message(role=self.role, parts=[])
  49. if self.role == "assistant":
  50. new_message.modality = "voice"
  51. for part in self.parts:
  52. if isinstance(part, ServeTextPart):
  53. new_message.parts.append(TextPart(text=part.text))
  54. elif isinstance(part, ServeVQPart):
  55. new_message.parts.append(
  56. VQPart(codes=torch.tensor(part.codes, dtype=torch.int))
  57. )
  58. else:
  59. raise ValueError(f"Unsupported part type: {part}")
  60. return new_message
class ServeChatRequest(BaseModel):
    """Request body for the chat endpoint."""

    # Conversation history; conlist enforces at least one message.
    messages: Annotated[list[ServeMessage], conlist(ServeMessage, min_length=1)]
    # Sampling / decoding controls.
    max_new_tokens: int = 1024
    top_p: float = 0.7
    repetition_penalty: float = 1.2
    temperature: float = 0.7
    # Stream partial results instead of a single final response.
    streaming: bool = False
    # Number of candidate samples to generate for this request.
    num_samples: int = 1
    # NOTE(review): semantics not visible here — presumably a score threshold
    # used to stop weak samples early; confirm against the sampler.
    early_stop_threshold: float = 1.0
class ServeVQGANEncodeRequest(BaseModel):
    """Request to encode audio files into VQ token grids."""

    # The audio here should be in wav, mp3, etc
    audios: list[bytes]
class ServeVQGANEncodeResponse(BaseModel):
    """VQGAN encode result: one token grid per input audio.

    Validation is skipped on the triple-nested list, presumably to avoid the
    cost of per-element checks on large token grids.
    """

    tokens: SkipValidation[list[list[list[int]]]]
class ServeVQGANDecodeRequest(BaseModel):
    """Request to decode VQ token grids back into audio.

    Shape mirrors ServeVQGANEncodeResponse.tokens; validation skipped for the
    same large-payload reason.
    """

    tokens: SkipValidation[list[list[list[int]]]]
class ServeVQGANDecodeResponse(BaseModel):
    """VQGAN decode result: one audio buffer per input token grid."""

    # The audio here should be in PCM float16 format
    audios: list[bytes]
class ServeForwardMessage(BaseModel):
    """Free-form role/content message — roles unconstrained, unlike ServeMessage."""

    role: str
    content: str
class ServeResponse(BaseModel):
    """Non-streaming chat response."""

    # Generated messages (one per requested sample).
    messages: list[ServeMessage]
    # None while generation is in progress; "stop" or "error" when finished.
    finish_reason: Literal["stop", "error"] | None = None
    # Free-form generation statistics (pydantic copies this default per instance).
    stats: dict[str, int | float | str] = {}
class ServeStreamDelta(BaseModel):
    """Incremental update within a streaming response."""

    # Set on the first delta of a message to announce its role.
    role: Literal["system", "assistant", "user"] | None = None
    # The newly generated part, if any.
    part: ServeVQPart | ServeTextPart | None = None
class ServeStreamResponse(BaseModel):
    """One streaming chunk of a chat response."""

    # Index of the sample this chunk belongs to (num_samples may be > 1).
    sample_id: int = 0
    # Incremental content; None for pure status/final chunks.
    delta: ServeStreamDelta | None = None
    # Set on the terminal chunk of a sample.
    finish_reason: Literal["stop", "error"] | None = None
    # Generation statistics, typically attached to the final chunk.
    stats: dict[str, int | float | str] | None = None
  95. class ServeReferenceAudio(BaseModel):
  96. audio: bytes
  97. text: str
  98. @model_validator(mode="before")
  99. def decode_audio(cls, values):
  100. audio = values.get("audio")
  101. if (
  102. isinstance(audio, str) and len(audio) > 255
  103. ): # Check if audio is a string (Base64)
  104. try:
  105. values["audio"] = base64.b64decode(audio)
  106. except Exception as e:
  107. # If the audio is not a valid base64 string, we will just ignore it and let the server handle it
  108. pass
  109. return values
  110. def __repr__(self) -> str:
  111. return f"ServeReferenceAudio(text={self.text!r}, audio_size={len(self.audio)})"
class ServeTTSRequest(BaseModel):
    """Request body for the text-to-speech endpoint."""

    # Text to synthesize.
    text: str
    # Characters per synthesis chunk; strict int in [100, 300].
    chunk_length: Annotated[int, conint(ge=100, le=300, strict=True)] = 200
    # Audio format
    format: Literal["wav", "pcm", "mp3"] = "wav"
    # References audios for in-context learning
    references: list[ServeReferenceAudio] = []
    # Reference id
    # For example, if you want use https://fish.audio/m/7f92f8afb8ec43bf81429cc1c9199cb1/
    # Just pass 7f92f8afb8ec43bf81429cc1c9199cb1
    reference_id: str | None = None
    # RNG seed; None presumably means non-deterministic sampling.
    seed: int | None = None
    # Keep decoded reference audio cached in memory between requests.
    use_memory_cache: Literal["on", "off"] = "off"
    # Normalize text for en & zh; this increases stability for numbers.
    normalize: bool = True
    # Fields below are not usually set by clients.
    streaming: bool = False
    max_new_tokens: int = 1024
    top_p: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7
    repetition_penalty: Annotated[float, Field(ge=0.9, le=2.0, strict=True)] = 1.2
    temperature: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7

    # NOTE(review): pydantic v2 prefers `model_config = ConfigDict(...)`;
    # the legacy inner Config class still works and is kept as-is.
    class Config:
        # Allow arbitrary types for pytorch related types
        arbitrary_types_allowed = True