"""Pydantic request/response schemas for the fish-speech serving API."""
  1. import base64
  2. import os
  3. import queue
  4. from dataclasses import dataclass
  5. from typing import Literal
  6. import torch
  7. from pydantic import BaseModel, Field, conint, model_validator
  8. from pydantic.functional_validators import SkipValidation
  9. from typing_extensions import Annotated
  10. from fish_speech.content_sequence import TextPart, VQPart
class ServeVQPart(BaseModel):
    """Message part carrying VQ (vector-quantized) audio codes."""

    # Discriminator for the part union (see ServeStreamDelta.part).
    type: Literal["vq"] = "vq"
    # Nested list of integer code indices — presumably [codebook][frame];
    # TODO confirm against the VQGAN encoder output. SkipValidation avoids
    # per-element validation cost on large matrices.
    codes: SkipValidation[list[list[int]]]
class ServeTextPart(BaseModel):
    """Message part carrying plain text."""

    # Discriminator for the part union (see ServeStreamDelta.part).
    type: Literal["text"] = "text"
    text: str
class ServeAudioPart(BaseModel):
    """Message part carrying raw audio bytes."""

    # Discriminator for the part union.
    type: Literal["audio"] = "audio"
    audio: bytes
class ServeASRRequest(BaseModel):
    """Request for batch speech recognition."""

    # The audio should be an uncompressed PCM float16 audio
    audios: list[bytes]
    # Sample rate of every entry in `audios`, in Hz.
    sample_rate: int = 44100
    # Recognition language; "auto" lets the backend detect it.
    language: Literal["zh", "en", "ja", "auto"] = "auto"
class ServeASRTranscription(BaseModel):
    """Transcription result for one input audio."""

    text: str
    # Duration of the audio — presumably seconds; TODO confirm with the ASR backend.
    duration: float
    # Backend-reported flag; semantics defined by the ASR model (likely a
    # large silence/timestamp gap) — NOTE(review): confirm with the producer.
    huge_gap: bool
class ServeASRSegment(BaseModel):
    """A timed segment of a transcription."""

    text: str
    # Segment boundaries — presumably seconds from the start of the audio.
    start: float
    end: float
class ServeTimedASRResponse(BaseModel):
    """ASR response with per-segment timing information."""

    # Full transcription text.
    text: str
    segments: list[ServeASRSegment]
    duration: float
class ServeASRResponse(BaseModel):
    """Batch ASR response: one transcription per input audio."""

    transcriptions: list[ServeASRTranscription]
class ServeRequest(BaseModel):
    """Low-level generation request driven by a raw content sequence."""

    # Raw content sequence dict that we can use with ContentSequence(**content)
    content: dict
    # Sampling / decoding parameters.
    max_new_tokens: int = 600
    top_p: float = 0.7
    repetition_penalty: float = 1.2
    temperature: float = 0.7
    # If True, results are streamed back incrementally.
    streaming: bool = False
    # Number of independent samples to generate.
    num_samples: int = 1
    # Threshold used by the backend to stop early — NOTE(review): exact
    # semantics live in the generation loop; confirm there.
    early_stop_threshold: float = 1.0
class ServeVQGANEncodeRequest(BaseModel):
    """Request to encode audio files into VQ token grids."""

    # The audio here should be in wav, mp3, etc
    audios: list[bytes]
class ServeVQGANEncodeResponse(BaseModel):
    """Encoded VQ tokens, one grid per input audio."""

    # Presumably [audio][codebook][frame]; SkipValidation skips per-element
    # checks on these large nested lists.
    tokens: SkipValidation[list[list[list[int]]]]
class ServeVQGANDecodeRequest(BaseModel):
    """Request to decode VQ token grids back into audio."""

    # Same layout as ServeVQGANEncodeResponse.tokens.
    tokens: SkipValidation[list[list[list[int]]]]
class ServeVQGANDecodeResponse(BaseModel):
    """Decoded audio, one buffer per input token grid."""

    # The audio here should be in PCM float16 format
    audios: list[bytes]
class ServeStreamDelta(BaseModel):
    """Incremental piece of a streamed generation."""

    # Set on the first delta of a message; None on continuations.
    role: Literal["system", "assistant", "user"] | None = None
    # The newly generated part, if any.
    part: ServeVQPart | ServeTextPart | None = None
class ServeStreamResponse(BaseModel):
    """One streamed event for a given sample."""

    # Index of the sample this delta belongs to (num_samples may be > 1).
    sample_id: int = 0
    delta: ServeStreamDelta | None = None
    # Set on the terminal event; None while streaming continues.
    finish_reason: Literal["stop", "error"] | None = None
    # Optional backend statistics (timings, token counts, ...).
    stats: dict[str, int | float | str] | None = None
  67. class ServeReferenceAudio(BaseModel):
  68. audio: bytes
  69. text: str
  70. @model_validator(mode="before")
  71. def decode_audio(cls, values):
  72. audio = values.get("audio")
  73. if (
  74. isinstance(audio, str) and len(audio) > 255
  75. ): # Check if audio is a string (Base64)
  76. try:
  77. values["audio"] = base64.b64decode(audio)
  78. except Exception as e:
  79. # If the audio is not a valid base64 string, we will just ignore it and let the server handle it
  80. pass
  81. return values
  82. def __repr__(self) -> str:
  83. return f"ServeReferenceAudio(text={self.text!r}, audio_size={len(self.audio)})"
  84. class ServeTTSRequest(BaseModel):
  85. text: str
  86. chunk_length: Annotated[int, conint(ge=100, le=300, strict=True)] = 200
  87. # Audio format
  88. format: Literal["wav", "pcm", "mp3"] = "wav"
  89. # References audios for in-context learning
  90. references: list[ServeReferenceAudio] = []
  91. # Reference id
  92. # For example, if you want use https://fish.audio/m/7f92f8afb8ec43bf81429cc1c9199cb1/
  93. # Just pass 7f92f8afb8ec43bf81429cc1c9199cb1
  94. reference_id: str | None = None
  95. seed: int | None = None
  96. use_memory_cache: Literal["on", "off"] = "off"
  97. # Normalize text for en & zh, this increase stability for numbers
  98. normalize: bool = True
  99. # not usually used below
  100. streaming: bool = False
  101. max_new_tokens: int = 1024
  102. top_p: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.8
  103. repetition_penalty: Annotated[float, Field(ge=0.9, le=2.0, strict=True)] = 1.1
  104. temperature: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.8
  105. class Config:
  106. # Allow arbitrary types for pytorch related types
  107. arbitrary_types_allowed = True