"""Pydantic request/response schemas for the fish-speech serving API."""
  1. import os
  2. import queue
  3. from dataclasses import dataclass
  4. from typing import Annotated, Literal, Optional
  5. import torch
  6. from pydantic import AfterValidator, BaseModel, Field, confloat, conint, conlist
  7. from pydantic.functional_validators import SkipValidation
  8. from fish_speech.conversation import Message, TextPart, VQPart
  9. GLOBAL_NUM_SAMPLES = int(os.getenv("GLOBAL_NUM_SAMPLES", 1))
class ServeVQPart(BaseModel):
    # Message part carrying a 2-D grid of integer codebook ("vq") codes.
    type: Literal["vq"] = "vq"
    # SkipValidation: the nested int lists are accepted as-is, without
    # per-element pydantic validation.
    codes: SkipValidation[list[list[int]]]
class ServeTextPart(BaseModel):
    # Message part carrying plain text.
    type: Literal["text"] = "text"
    text: str
class ServeAudioPart(BaseModel):
    # Message part carrying raw audio bytes (encoding not specified here).
    type: Literal["audio"] = "audio"
    audio: bytes
@dataclass
class ASRPackRequest:
    # One in-process ASR work item (plain dataclass, not a wire schema).
    audio: torch.Tensor
    # presumably the worker puts the transcription result on this queue for
    # the caller to collect — verify against the consumer code
    result_queue: queue.Queue
    language: str
class ServeASRRequest(BaseModel):
    # The audio should be an uncompressed PCM float16 audio
    audios: list[bytes]
    sample_rate: int = 44100
    # "auto" lets the backend detect the language.
    language: Literal["zh", "en", "ja", "auto"] = "auto"
class ServeASRTranscription(BaseModel):
    # One transcription result: text, audio duration, and a flag for a
    # detected large gap (semantics defined by the ASR backend).
    text: str
    duration: float
    huge_gap: bool
class ServeASRSegment(BaseModel):
    # A timed slice of a transcription; start/end are offsets (units not
    # shown here — presumably seconds, matching `duration` elsewhere).
    text: str
    start: float
    end: float
class ServeTimedASRResponse(BaseModel):
    # Full transcription plus its per-segment timing breakdown.
    text: str
    segments: list[ServeASRSegment]
    duration: float
class ServeASRResponse(BaseModel):
    # One transcription per input audio in the corresponding request.
    transcriptions: list[ServeASRTranscription]
  43. class ServeMessage(BaseModel):
  44. role: Literal["system", "assistant", "user", "raw"]
  45. parts: list[ServeVQPart | ServeTextPart]
  46. def to_conversation_message(self):
  47. new_message = Message(role=self.role, parts=[])
  48. if self.role == "assistant":
  49. new_message.modality = "voice"
  50. for part in self.parts:
  51. if isinstance(part, ServeTextPart):
  52. new_message.parts.append(TextPart(text=part.text))
  53. elif isinstance(part, ServeVQPart):
  54. new_message.parts.append(
  55. VQPart(codes=torch.tensor(part.codes, dtype=torch.int))
  56. )
  57. else:
  58. raise ValueError(f"Unsupported part type: {part}")
  59. return new_message
class ServeRequest(BaseModel):
    # Generation request for the chat/serving endpoint.
    # conlist enforces at least one message.
    messages: Annotated[list[ServeMessage], conlist(ServeMessage, min_length=1)]
    max_new_tokens: int = 1024
    top_p: float = 0.7
    repetition_penalty: float = 1.2
    temperature: float = 0.7
    streaming: bool = False
    # NOTE(review): the module-level GLOBAL_NUM_SAMPLES constant (env-driven,
    # defaults to 1) is otherwise unused — it may have been intended as this
    # field's default. Confirm before changing.
    num_samples: int = 1
    # Per-sample early stopping threshold; semantics defined by the sampler.
    early_stop_threshold: float = 1.0
class ServeVQGANEncodeRequest(BaseModel):
    # The audio here should be in wav, mp3, etc
    audios: list[bytes]
class ServeVQGANEncodeResponse(BaseModel):
    # Encoded codes, one [codebook][frame] int grid per input audio.
    # SkipValidation: the nested lists are accepted without per-element checks.
    tokens: SkipValidation[list[list[list[int]]]]
class ServeVQGANDecodeRequest(BaseModel):
    # Same token layout as ServeVQGANEncodeResponse; validation skipped.
    tokens: SkipValidation[list[list[list[int]]]]
class ServeVQGANDecodeResponse(BaseModel):
    # The audio here should be in PCM float16 format
    audios: list[bytes]
  79. class ServeReferenceAudio(BaseModel):
  80. audio: bytes
  81. text: str
class ServeForwardMessage(BaseModel):
    # Plain role/content message (free-form role string, unlike ServeMessage).
    role: str
    content: str
class ServeResponse(BaseModel):
    # Non-streaming generation response.
    messages: list[ServeMessage]
    # None while/if generation has not finished with a terminal state.
    finish_reason: Literal["stop", "error"] | None = None
    # Free-form stats; pydantic copies mutable defaults per instance.
    stats: dict[str, int | float | str] = {}
class ServeStreamDelta(BaseModel):
    # Incremental update for streaming: either/both fields may be absent.
    role: Literal["system", "assistant", "user"] | None = None
    part: ServeVQPart | ServeTextPart | None = None
class ServeStreamResponse(BaseModel):
    # One streamed event; sample_id distinguishes parallel samples.
    sample_id: int = 0
    delta: ServeStreamDelta | None = None
    finish_reason: Literal["stop", "error"] | None = None
    stats: dict[str, int | float | str] | None = None
class ServeReferenceAudio(BaseModel):
    # NOTE(review): duplicate definition — an earlier ServeReferenceAudio
    # (without __repr__) appears above in this module; this later definition
    # shadows it. The redundant earlier declaration should be removed.
    audio: bytes
    text: str

    def __repr__(self) -> str:
        # Report only the audio payload size instead of dumping raw bytes.
        return f"ServeReferenceAudio(text={self.text!r}, audio_size={len(self.audio)})"
class ServeChatRequestV1(BaseModel):
    # v1 chat request with optional audio input and TTS output settings.
    model: str = "llama3-8b"
    # pydantic copies mutable defaults per instance, so [] is safe here.
    messages: list[ServeForwardMessage] = []
    audio: bytes | None = None
    temperature: float = 1.0
    top_p: float = 1.0
    max_tokens: int = 256
    voice: str = "jessica"
    tts_audio_format: Literal["mp3", "pcm", "opus"] = "mp3"
    # Constrained to common bitrates (kbps — presumably; confirm with encoder).
    tts_audio_bitrate: Literal[16, 24, 32, 48, 64, 96, 128, 192] = 128
  112. class ServeTTSRequest(BaseModel):
  113. text: str
  114. chunk_length: Annotated[int, conint(ge=100, le=300, strict=True)] = 200
  115. # Audio format
  116. format: Literal["wav", "pcm", "mp3"] = "wav"
  117. mp3_bitrate: Literal[64, 128, 192] = 128
  118. # References audios for in-context learning
  119. references: list[ServeReferenceAudio] = []
  120. # Reference id
  121. # For example, if you want use https://fish.audio/m/7f92f8afb8ec43bf81429cc1c9199cb1/
  122. # Just pass 7f92f8afb8ec43bf81429cc1c9199cb1
  123. reference_id: str | None = None
  124. seed: int | None = None
  125. use_memory_cache: Literal["on-demand", "never"] = "never"
  126. # Normalize text for en & zh, this increase stability for numbers
  127. normalize: bool = True
  128. mp3_bitrate: Optional[int] = 64
  129. opus_bitrate: Optional[int] = -1000
  130. # Balance mode will reduce latency to 300ms, but may decrease stability
  131. latency: Literal["normal", "balanced"] = "normal"
  132. # not usually used below
  133. streaming: bool = False
  134. max_new_tokens: int = 1024
  135. top_p: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7
  136. repetition_penalty: Annotated[float, Field(ge=0.9, le=2.0, strict=True)] = 1.2
  137. temperature: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7