# schema.py — request/response schemas for the fish-speech serving API.
import base64
import os
import queue
from dataclasses import dataclass
from typing import Literal

import torch
from pydantic import BaseModel, ConfigDict, Field, conint, model_validator
from pydantic.functional_validators import SkipValidation
from typing_extensions import Annotated

from fish_speech.content_sequence import TextPart, VQPart
  11. class ServeVQPart(BaseModel):
  12. type: Literal["vq"] = "vq"
  13. codes: SkipValidation[list[list[int]]]
  14. class ServeTextPart(BaseModel):
  15. type: Literal["text"] = "text"
  16. text: str
  17. class ServeAudioPart(BaseModel):
  18. type: Literal["audio"] = "audio"
  19. audio: bytes
  20. @dataclass
  21. class ASRPackRequest:
  22. audio: torch.Tensor
  23. result_queue: queue.Queue
  24. language: str
  25. class ServeASRRequest(BaseModel):
  26. # The audio should be an uncompressed PCM float16 audio
  27. audios: list[bytes]
  28. sample_rate: int = 44100
  29. language: Literal["zh", "en", "ja", "auto"] = "auto"
  30. class ServeASRTranscription(BaseModel):
  31. text: str
  32. duration: float
  33. huge_gap: bool
  34. class ServeASRSegment(BaseModel):
  35. text: str
  36. start: float
  37. end: float
  38. class ServeTimedASRResponse(BaseModel):
  39. text: str
  40. segments: list[ServeASRSegment]
  41. duration: float
  42. class ServeASRResponse(BaseModel):
  43. transcriptions: list[ServeASRTranscription]
  44. class ServeRequest(BaseModel):
  45. # Raw content sequence dict that we can use with ContentSequence(**content)
  46. content: dict
  47. max_new_tokens: int = 600
  48. top_p: float = 0.7
  49. repetition_penalty: float = 1.2
  50. temperature: float = 0.7
  51. streaming: bool = False
  52. num_samples: int = 1
  53. early_stop_threshold: float = 1.0
  54. class ServeVQGANEncodeRequest(BaseModel):
  55. # The audio here should be in wav, mp3, etc
  56. audios: list[bytes]
  57. class ServeVQGANEncodeResponse(BaseModel):
  58. tokens: SkipValidation[list[list[list[int]]]]
  59. class ServeVQGANDecodeRequest(BaseModel):
  60. tokens: SkipValidation[list[list[list[int]]]]
  61. class ServeVQGANDecodeResponse(BaseModel):
  62. # The audio here should be in PCM float16 format
  63. audios: list[bytes]
  64. class ServeContentSequenceParts(BaseModel):
  65. parts: list[VQPart | TextPart]
  66. class ServeResponse(BaseModel):
  67. content_sequences: list[ServeContentSequenceParts]
  68. finish_reason: Literal["stop", "error"] | None = None
  69. stats: dict[str, int | float | str] = {}
  70. finished: list[bool] | None = None
  71. class ServeStreamDelta(BaseModel):
  72. role: Literal["system", "assistant", "user"] | None = None
  73. part: ServeVQPart | ServeTextPart | None = None
  74. class ServeStreamResponse(BaseModel):
  75. sample_id: int = 0
  76. delta: ServeStreamDelta | None = None
  77. finish_reason: Literal["stop", "error"] | None = None
  78. stats: dict[str, int | float | str] | None = None
  79. class ServeReferenceAudio(BaseModel):
  80. audio: bytes
  81. text: str
  82. @model_validator(mode="before")
  83. def decode_audio(cls, values):
  84. audio = values.get("audio")
  85. if (
  86. isinstance(audio, str) and len(audio) > 255
  87. ): # Check if audio is a string (Base64)
  88. try:
  89. values["audio"] = base64.b64decode(audio)
  90. except Exception as e:
  91. # If the audio is not a valid base64 string, we will just ignore it and let the server handle it
  92. pass
  93. return values
  94. def __repr__(self) -> str:
  95. return f"ServeReferenceAudio(text={self.text!r}, audio_size={len(self.audio)})"
  96. class ServeTTSRequest(BaseModel):
  97. text: str
  98. chunk_length: Annotated[int, conint(ge=100, le=300, strict=True)] = 200
  99. # Audio format
  100. format: Literal["wav", "pcm", "mp3"] = "wav"
  101. # References audios for in-context learning
  102. references: list[ServeReferenceAudio] = []
  103. # Reference id
  104. # For example, if you want use https://fish.audio/m/7f92f8afb8ec43bf81429cc1c9199cb1/
  105. # Just pass 7f92f8afb8ec43bf81429cc1c9199cb1
  106. reference_id: str | None = None
  107. seed: int | None = None
  108. use_memory_cache: Literal["on", "off"] = "off"
  109. # Normalize text for en & zh, this increase stability for numbers
  110. normalize: bool = True
  111. # not usually used below
  112. streaming: bool = False
  113. max_new_tokens: int = 1024
  114. top_p: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7
  115. repetition_penalty: Annotated[float, Field(ge=0.9, le=2.0, strict=True)] = 1.2
  116. temperature: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7
  117. class Config:
  118. # Allow arbitrary types for pytorch related types
  119. arbitrary_types_allowed = True