# schema.py — request/response schemas for the fish-speech serving API.
  1. import os
  2. import queue
  3. from dataclasses import dataclass
  4. from typing import Annotated, Literal, Optional
  5. import torch
  6. from pydantic import AfterValidator, BaseModel, Field, confloat, conint, conlist
  7. from pydantic.functional_validators import SkipValidation
  8. from fish_speech.conversation import Message, TextPart, VQPart
  9. GLOBAL_NUM_SAMPLES = int(os.getenv("GLOBAL_NUM_SAMPLES", 1))
class ServeVQPart(BaseModel):
    """A message part carrying VQ (vector-quantized) audio codes."""

    # Pydantic discriminator so parts can be told apart in a union.
    type: Literal["vq"] = "vq"
    # Nested integer code indices; presumably [codebook][timestep] — confirm
    # against the VQGAN encoder output. SkipValidation avoids per-int
    # validation cost on large code grids.
    codes: SkipValidation[list[list[int]]]
class ServeTextPart(BaseModel):
    """A message part carrying plain text."""

    # Pydantic discriminator so parts can be told apart in a union.
    type: Literal["text"] = "text"
    text: str
class ServeAudioPart(BaseModel):
    """A message part carrying raw audio bytes."""

    # Pydantic discriminator so parts can be told apart in a union.
    type: Literal["audio"] = "audio"
    audio: bytes
@dataclass
class ASRPackRequest:
    """Internal (non-pydantic) work item handed to the ASR worker."""

    # Audio samples to transcribe.
    audio: torch.Tensor
    # Presumably receives the transcription result from the worker —
    # verify against the ASR worker implementation.
    result_queue: queue.Queue
    # Language hint for the ASR model.
    language: str
class ServeASRRequest(BaseModel):
    """Request schema for the ASR (speech-to-text) endpoint."""

    # The audio should be an uncompressed PCM float16 audio
    audios: list[bytes]
    sample_rate: int = 44100
    # "auto" presumably triggers language auto-detection — confirm.
    language: Literal["zh", "en", "ja", "auto"] = "auto"
class ServeASRTranscription(BaseModel):
    """Transcription result for a single input audio."""

    text: str
    # Audio duration; units not shown here — presumably seconds, confirm.
    duration: float
    # NOTE(review): meaning inferred from the name only — appears to flag a
    # large silent gap in the audio; confirm with the ASR implementation.
    huge_gap: bool
class ServeASRSegment(BaseModel):
    """A timed segment within a transcription."""

    text: str
    # Segment boundaries; presumably seconds from audio start — confirm.
    start: float
    end: float
class ServeTimedASRResponse(BaseModel):
    """ASR response that includes per-segment timestamps."""

    # Full transcription text.
    text: str
    segments: list[ServeASRSegment]
    duration: float
class ServeASRResponse(BaseModel):
    """ASR response: one transcription per input audio."""

    transcriptions: list[ServeASRTranscription]
  43. class ServeMessage(BaseModel):
  44. role: Literal["system", "assistant", "user"]
  45. parts: list[ServeVQPart | ServeTextPart]
  46. def to_conversation_message(self):
  47. new_message = Message(role=self.role, parts=[])
  48. for part in self.parts:
  49. if isinstance(part, ServeTextPart):
  50. new_message.parts.append(TextPart(text=part.text))
  51. elif isinstance(part, ServeVQPart):
  52. new_message.parts.append(
  53. VQPart(codes=torch.tensor(part.codes, dtype=torch.int))
  54. )
  55. else:
  56. raise ValueError(f"Unsupported part type: {part}")
  57. return new_message
  58. class ServeRequest(BaseModel):
  59. messages: Annotated[list[ServeMessage], conlist(ServeMessage, min_length=1)]
  60. max_new_tokens: int = 1024
  61. top_p: float = 0.7
  62. repetition_penalty: float = 1.2
  63. temperature: float = 0.7
  64. streaming: bool = False
  65. num_samples: int = 1
  66. early_stop_threshold: float = 1.0
class ServeVQGANEncodeRequest(BaseModel):
    """Request schema for encoding audio into VQ tokens."""

    # The audio here should be in wav, mp3, etc
    audios: list[bytes]
class ServeVQGANEncodeResponse(BaseModel):
    """VQ tokens produced by the encoder, one grid per input audio.

    Nesting presumably [audio][codebook][timestep] — confirm with the
    encoder. SkipValidation avoids per-int validation cost.
    """

    tokens: SkipValidation[list[list[list[int]]]]
class ServeVQGANDecodeRequest(BaseModel):
    """Request schema for decoding VQ tokens back into audio.

    Same token layout as ServeVQGANEncodeResponse; SkipValidation avoids
    per-int validation cost.
    """

    tokens: SkipValidation[list[list[list[int]]]]
class ServeVQGANDecodeResponse(BaseModel):
    """Decoded audio, one buffer per input token grid."""

    # The audio here should be in PCM float16 format
    audios: list[bytes]
  77. class ServeReferenceAudio(BaseModel):
  78. audio: bytes
  79. text: str
class ServeForwardMessage(BaseModel):
    """A chat message with free-form role and text content."""

    role: str
    content: str
class ServeResponse(BaseModel):
    """Final (non-streaming) generation response."""

    messages: list[ServeMessage]
    finish_reason: Literal["stop", "error"] | None = None
    # Mutable default is safe here: pydantic deep-copies field defaults
    # per instance.
    stats: dict[str, int | float | str] = {}
class ServeStreamDelta(BaseModel):
    """Incremental payload inside a streaming response chunk."""

    role: Literal["system", "assistant", "user"] | None = None
    part: ServeVQPart | ServeTextPart | None = None
class ServeStreamResponse(BaseModel):
    """One chunk of a streamed generation response."""

    # Index of the sample this chunk belongs to (multi-sample requests).
    sample_id: int = 0
    delta: ServeStreamDelta | None = None
    # Presumably set only on the final chunk — confirm with the
    # streaming handler.
    finish_reason: Literal["stop", "error"] | None = None
    stats: dict[str, int | float | str] | None = None
class ServeReferenceAudio(BaseModel):
    """A reference audio/transcript pair for in-context voice cloning.

    NOTE(review): duplicate definition — this class is also declared earlier
    in this module; this later definition shadows the earlier one. One of
    the two should be deleted.
    """

    audio: bytes
    text: str

    def __repr__(self) -> str:
        # Avoid dumping raw audio bytes into logs; report only their size.
        return f"ServeReferenceAudio(text={self.text!r}, audio_size={len(self.audio)})"
class ServeChatRequestV1(BaseModel):
    """OpenAI-style chat request (v1) with TTS options for the reply."""

    model: str = "llama3-8b"
    # Mutable default is safe here: pydantic deep-copies field defaults
    # per instance.
    messages: list[ServeForwardMessage] = []
    # Optional audio input; encoding/format not shown here — confirm
    # with the endpoint handler.
    audio: bytes | None = None
    temperature: float = 1.0
    top_p: float = 1.0
    max_tokens: int = 256
    # Voice preset used when synthesizing the reply.
    voice: str = "jessica"
    tts_audio_format: Literal["mp3", "pcm", "opus"] = "mp3"
    # Bitrate in kbit/s — presumably applies to compressed formats only.
    tts_audio_bitrate: Literal[16, 24, 32, 48, 64, 96, 128, 192] = 128
  110. class ServeTTSRequest(BaseModel):
  111. text: str
  112. chunk_length: Annotated[int, conint(ge=100, le=300, strict=True)] = 200
  113. # Audio format
  114. format: Literal["wav", "pcm", "mp3"] = "wav"
  115. mp3_bitrate: Literal[64, 128, 192] = 128
  116. # References audios for in-context learning
  117. references: list[ServeReferenceAudio] = []
  118. # Reference id
  119. # For example, if you want use https://fish.audio/m/7f92f8afb8ec43bf81429cc1c9199cb1/
  120. # Just pass 7f92f8afb8ec43bf81429cc1c9199cb1
  121. reference_id: str | None = None
  122. seed: int | None = None
  123. use_memory_cache: Literal["on-demand", "never"] = "never"
  124. # Normalize text for en & zh, this increase stability for numbers
  125. normalize: bool = True
  126. mp3_bitrate: Optional[int] = 64
  127. opus_bitrate: Optional[int] = -1000
  128. # Balance mode will reduce latency to 300ms, but may decrease stability
  129. latency: Literal["normal", "balanced"] = "normal"
  130. # not usually used below
  131. streaming: bool = False
  132. max_new_tokens: int = 1024
  133. top_p: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7
  134. repetition_penalty: Annotated[float, Field(ge=0.9, le=2.0, strict=True)] = 1.2
  135. temperature: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7