schema.py

import os
import queue
from dataclasses import dataclass
from typing import Annotated, Literal

import torch
from pydantic import BaseModel, Field, conint, conlist
from pydantic.functional_validators import SkipValidation

from fish_speech.conversation import Message, TextPart, VQPart


class ServeVQPart(BaseModel):
    type: Literal["vq"] = "vq"
    codes: SkipValidation[list[list[int]]]


class ServeTextPart(BaseModel):
    type: Literal["text"] = "text"
    text: str


class ServeAudioPart(BaseModel):
    type: Literal["audio"] = "audio"
    audio: bytes


@dataclass
class ASRPackRequest:
    audio: torch.Tensor
    result_queue: queue.Queue
    language: str


class ServeASRRequest(BaseModel):
    # The audio should be uncompressed PCM float16 audio
    audios: list[bytes]
    sample_rate: int = 44100
    language: Literal["zh", "en", "ja", "auto"] = "auto"


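# A hedged usage sketch (not part of the schema): building an ASR request from
# one second of silent float16 PCM. The byte layout and 44100 Hz rate follow
# the field comments above; the silent audio is only a placeholder.
def _example_asr_request() -> ServeASRRequest:
    pcm = torch.zeros(44100, dtype=torch.float16).numpy().tobytes()
    return ServeASRRequest(audios=[pcm], sample_rate=44100, language="en")

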
class ServeASRTranscription(BaseModel):
    text: str
    duration: float
    huge_gap: bool


class ServeASRSegment(BaseModel):
    text: str
    start: float
    end: float


class ServeTimedASRResponse(BaseModel):
    text: str
    segments: list[ServeASRSegment]
    duration: float


class ServeASRResponse(BaseModel):
    transcriptions: list[ServeASRTranscription]


class ServeMessage(BaseModel):
    role: Literal["system", "assistant", "user"]
    parts: list[ServeVQPart | ServeTextPart]

    def to_conversation_message(self):
        new_message = Message(role=self.role, parts=[])

        if self.role == "assistant":
            new_message.modality = "voice"

        for part in self.parts:
            if isinstance(part, ServeTextPart):
                new_message.parts.append(TextPart(text=part.text))
            elif isinstance(part, ServeVQPart):
                new_message.parts.append(
                    VQPart(codes=torch.tensor(part.codes, dtype=torch.int))
                )
            else:
                raise ValueError(f"Unsupported part type: {part}")

        return new_message


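# A hedged sketch of the wire-to-internal conversion above: a ServeMessage
# carrying text plus VQ codes becomes a conversation Message. The code values
# are dummies chosen for illustration.
def _example_message_roundtrip() -> Message:
    msg = ServeMessage(
        role="assistant",
        parts=[
            ServeTextPart(text="Hello!"),
            ServeVQPart(codes=[[1, 2, 3], [4, 5, 6]]),
        ],
    )
    return msg.to_conversation_message()

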
class ServeChatRequest(BaseModel):
    messages: Annotated[list[ServeMessage], conlist(ServeMessage, min_length=1)]
    max_new_tokens: int = 1024
    top_p: float = 0.7
    repetition_penalty: float = 1.2
    temperature: float = 0.7
    streaming: bool = False
    num_samples: int = 1
    early_stop_threshold: float = 1.0


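# A minimal sketch of a one-turn chat request and its JSON wire form, using
# pydantic v2's model_dump_json. The sampling fields keep their defaults;
# conlist enforces at least one message.
def _example_chat_request_json() -> str:
    request = ServeChatRequest(
        messages=[
            ServeMessage(role="user", parts=[ServeTextPart(text="Hi there")]),
        ],
        streaming=True,
    )
    return request.model_dump_json()

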
class ServeVQGANEncodeRequest(BaseModel):
    # The audio here should be in wav, mp3, etc.
    audios: list[bytes]


class ServeVQGANEncodeResponse(BaseModel):
    tokens: SkipValidation[list[list[list[int]]]]


class ServeVQGANDecodeRequest(BaseModel):
    tokens: SkipValidation[list[list[list[int]]]]


class ServeVQGANDecodeResponse(BaseModel):
    # The audio here should be in PCM float16 format
    audios: list[bytes]


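# A hedged sketch of the VQGAN token round trip. The exact nesting of `tokens`
# is not documented here; this example assumes batch x codebook x frame, and
# the token values are dummies standing in for a real server response.
def _example_vqgan_roundtrip(wav_bytes: bytes) -> ServeVQGANDecodeRequest:
    encode_req = ServeVQGANEncodeRequest(audios=[wav_bytes])
    # A server would answer encode_req with a ServeVQGANEncodeResponse;
    # fabricate one to show how its tokens feed the decode request.
    encode_resp = ServeVQGANEncodeResponse(tokens=[[[1, 2], [3, 4]]])
    return ServeVQGANDecodeRequest(tokens=encode_resp.tokens)

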
class ServeForwardMessage(BaseModel):
    role: str
    content: str


class ServeResponse(BaseModel):
    messages: list[ServeMessage]
    finish_reason: Literal["stop", "error"] | None = None
    stats: dict[str, int | float | str] = {}


class ServeStreamDelta(BaseModel):
    role: Literal["system", "assistant", "user"] | None = None
    part: ServeVQPart | ServeTextPart | None = None


class ServeStreamResponse(BaseModel):
    sample_id: int = 0
    delta: ServeStreamDelta | None = None
    finish_reason: Literal["stop", "error"] | None = None
    stats: dict[str, int | float | str] | None = None


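# A hedged sketch of consuming a stream: each ServeStreamResponse delta carries
# at most one part, so a client appends parts until finish_reason arrives.
def _example_collect_stream(chunks: list[ServeStreamResponse]) -> ServeMessage:
    message = ServeMessage(role="assistant", parts=[])
    for chunk in chunks:
        if chunk.delta is not None and chunk.delta.part is not None:
            message.parts.append(chunk.delta.part)
        if chunk.finish_reason is not None:
            break
    return message

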
class ServeReferenceAudio(BaseModel):
    audio: bytes
    text: str

    def __repr__(self) -> str:
        return f"ServeReferenceAudio(text={self.text!r}, audio_size={len(self.audio)})"


class ServeTTSRequest(BaseModel):
    text: str
    chunk_length: Annotated[int, conint(ge=100, le=300, strict=True)] = 200
    # Audio format
    format: Literal["wav", "pcm", "mp3"] = "wav"
    # Reference audios for in-context learning
    references: list[ServeReferenceAudio] = []
    # Reference id
    # For example, to use https://fish.audio/m/7f92f8afb8ec43bf81429cc1c9199cb1/,
    # just pass 7f92f8afb8ec43bf81429cc1c9199cb1
    reference_id: str | None = None
    seed: int | None = None
    use_memory_cache: Literal["on", "off"] = "off"
    # Normalize text for en & zh; this increases stability for numbers
    normalize: bool = True
    # The fields below are not usually needed
    streaming: bool = False
    max_new_tokens: int = 1024
    top_p: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7
    repetition_penalty: Annotated[float, Field(ge=0.9, le=2.0, strict=True)] = 1.2
    temperature: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7

    class Config:
        # Allow arbitrary types for pytorch-related types
        arbitrary_types_allowed = True
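

# A minimal sketch of a TTS request with one in-context reference clip.
# Reading the reference from "reference.wav" is an assumption for this example;
# any wav/mp3 bytes work per the ServeReferenceAudio fields above.
def _example_tts_request() -> ServeTTSRequest:
    with open("reference.wav", "rb") as f:
        ref_audio = f.read()
    return ServeTTSRequest(
        text="Hello from the schema example.",
        references=[ServeReferenceAudio(audio=ref_audio, text="reference transcript")],
        format="wav",
    )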