schema.py

import os
import queue
from dataclasses import dataclass
from typing import Literal

import torch
from pydantic import BaseModel, Field, conint, conlist
from pydantic.functional_validators import SkipValidation
from typing_extensions import Annotated

from fish_speech.conversation import Message, TextPart, VQPart


class ServeVQPart(BaseModel):
    type: Literal["vq"] = "vq"
    codes: SkipValidation[list[list[int]]]


class ServeTextPart(BaseModel):
    type: Literal["text"] = "text"
    text: str


class ServeAudioPart(BaseModel):
    type: Literal["audio"] = "audio"
    audio: bytes


@dataclass
class ASRPackRequest:
    audio: torch.Tensor
    result_queue: queue.Queue
    language: str


class ServeASRRequest(BaseModel):
    # Each entry should be uncompressed PCM float16 audio
    audios: list[bytes]
    sample_rate: int = 44100
    language: Literal["zh", "en", "ja", "auto"] = "auto"
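

# Usage sketch (illustrative, not part of the original schema): ServeASRRequest
# expects raw float16 PCM bytes, so a tensor has to be serialized first. The
# one-second buffer of silence below is a placeholder.
if __name__ == "__main__":
    silence = torch.zeros(44100, dtype=torch.float16)
    asr_request = ServeASRRequest(audios=[silence.numpy().tobytes()], language="en")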


class ServeASRTranscription(BaseModel):
    text: str
    duration: float
    huge_gap: bool


class ServeASRSegment(BaseModel):
    text: str
    start: float
    end: float


class ServeTimedASRResponse(BaseModel):
    text: str
    segments: list[ServeASRSegment]
    duration: float


class ServeASRResponse(BaseModel):
    transcriptions: list[ServeASRTranscription]


class ServeMessage(BaseModel):
    role: Literal["system", "assistant", "user"]
    parts: list[ServeVQPart | ServeTextPart]

    def to_conversation_message(self):
        new_message = Message(role=self.role, parts=[])

        if self.role == "assistant":
            new_message.modality = "voice"

        for part in self.parts:
            if isinstance(part, ServeTextPart):
                new_message.parts.append(TextPart(text=part.text))
            elif isinstance(part, ServeVQPart):
                new_message.parts.append(
                    VQPart(codes=torch.tensor(part.codes, dtype=torch.int))
                )
            else:
                raise ValueError(f"Unsupported part type: {part}")

        return new_message
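

# Usage sketch (illustrative, not part of the original schema): building a
# ServeMessage and lowering it to a conversation-level Message. The VQ codes
# here are placeholder values.
if __name__ == "__main__":
    demo_message = ServeMessage(
        role="assistant",
        parts=[
            ServeTextPart(text="Hello there."),
            ServeVQPart(codes=[[1, 2, 3], [4, 5, 6]]),
        ],
    )
    conversation_message = demo_message.to_conversation_message()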


class ServeChatRequest(BaseModel):
    messages: Annotated[list[ServeMessage], conlist(ServeMessage, min_length=1)]
    max_new_tokens: int = 1024
    top_p: float = 0.7
    repetition_penalty: float = 1.2
    temperature: float = 0.7
    streaming: bool = False
    num_samples: int = 1
    early_stop_threshold: float = 1.0
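

# Usage sketch (illustrative): the smallest valid chat request. The conlist
# constraint rejects an empty `messages` list; the sampling parameters fall
# back to the defaults declared above.
if __name__ == "__main__":
    chat_request = ServeChatRequest(
        messages=[ServeMessage(role="user", parts=[ServeTextPart(text="Hi!")])]
    )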


class ServeVQGANEncodeRequest(BaseModel):
    # The audio here should be in a container format such as wav or mp3
    audios: list[bytes]


class ServeVQGANEncodeResponse(BaseModel):
    tokens: SkipValidation[list[list[list[int]]]]


class ServeVQGANDecodeRequest(BaseModel):
    tokens: SkipValidation[list[list[list[int]]]]


class ServeVQGANDecodeResponse(BaseModel):
    # The audio here should be in PCM float16 format
    audios: list[bytes]
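

# Usage sketch (illustrative): the encode -> decode round trip at the schema
# level. Reading the triple nesting of `tokens` as [audio][codebook][frame]
# is an assumption here; the file itself does not document the axis order.
if __name__ == "__main__":
    encode_response = ServeVQGANEncodeResponse(tokens=[[[1, 2, 3], [4, 5, 6]]])
    decode_request = ServeVQGANDecodeRequest(tokens=encode_response.tokens)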


class ServeForwardMessage(BaseModel):
    role: str
    content: str


class ServeResponse(BaseModel):
    messages: list[ServeMessage]
    finish_reason: Literal["stop", "error"] | None = None
    stats: dict[str, int | float | str] = {}


class ServeStreamDelta(BaseModel):
    role: Literal["system", "assistant", "user"] | None = None
    part: ServeVQPart | ServeTextPart | None = None


class ServeStreamResponse(BaseModel):
    sample_id: int = 0
    delta: ServeStreamDelta | None = None
    finish_reason: Literal["stop", "error"] | None = None
    stats: dict[str, int | float | str] | None = None
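

# Usage sketch (illustrative): how a client might fold a stream of deltas back
# into a single ServeMessage. The delta sequence below is a made-up example of
# what a server following this schema could emit.
if __name__ == "__main__":
    stream = [
        ServeStreamResponse(delta=ServeStreamDelta(role="assistant")),
        ServeStreamResponse(delta=ServeStreamDelta(part=ServeTextPart(text="Hi!"))),
        ServeStreamResponse(finish_reason="stop"),
    ]
    assembled = ServeMessage(role="assistant", parts=[])
    for chunk in stream:
        if chunk.delta is not None and chunk.delta.part is not None:
            assembled.parts.append(chunk.delta.part)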


class ServeReferenceAudio(BaseModel):
    audio: bytes
    text: str

    def __repr__(self) -> str:
        return f"ServeReferenceAudio(text={self.text!r}, audio_size={len(self.audio)})"


class ServeTTSRequest(BaseModel):
    text: str
    chunk_length: Annotated[int, conint(ge=100, le=300, strict=True)] = 200
    # Audio format
    format: Literal["wav", "pcm", "mp3"] = "wav"
    # Reference audios for in-context learning
    references: list[ServeReferenceAudio] = []
    # Reference id
    # For example, if you want to use https://fish.audio/m/7f92f8afb8ec43bf81429cc1c9199cb1/
    # just pass 7f92f8afb8ec43bf81429cc1c9199cb1
    reference_id: str | None = None
    seed: int | None = None
    use_memory_cache: Literal["on", "off"] = "off"
    # Normalize text for en & zh; this increases stability for numbers
    normalize: bool = True
    # The parameters below are not usually used
    streaming: bool = False
    max_new_tokens: int = 1024
    top_p: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7
    repetition_penalty: Annotated[float, Field(ge=0.9, le=2.0, strict=True)] = 1.2
    temperature: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7

    class Config:
        # Allow arbitrary types for PyTorch-related fields
        arbitrary_types_allowed = True
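

# Usage sketch (illustrative): a minimal TTS request. chunk_length is strictly
# validated to [100, 300] and top_p / temperature to [0.1, 1.0], so values
# outside those ranges raise a pydantic ValidationError. The reference audio
# bytes are placeholders.
if __name__ == "__main__":
    tts_request = ServeTTSRequest(
        text="Hello, world!",
        references=[ServeReferenceAudio(audio=b"\x00\x00", text="reference text")],
        format="wav",
    )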