# schema.py (3.7 KB)

  1. import base64
  2. import os
  3. import queue
  4. from dataclasses import dataclass
  5. from typing import Literal
  6. import torch
  7. from pydantic import BaseModel, Field, conint, model_validator
  8. from pydantic.functional_validators import SkipValidation
  9. from typing_extensions import Annotated
  10. from fish_speech.content_sequence import TextPart, VQPart
  11. class ServeVQPart(BaseModel):
  12. type: Literal["vq"] = "vq"
  13. codes: SkipValidation[list[list[int]]]
  14. class ServeTextPart(BaseModel):
  15. type: Literal["text"] = "text"
  16. text: str
  17. class ServeAudioPart(BaseModel):
  18. type: Literal["audio"] = "audio"
  19. audio: bytes
  20. class ServeRequest(BaseModel):
  21. # Raw content sequence dict that we can use with ContentSequence(**content)
  22. content: dict
  23. max_new_tokens: int = 600
  24. top_p: float = 0.7
  25. repetition_penalty: float = 1.2
  26. temperature: float = 0.7
  27. streaming: bool = False
  28. num_samples: int = 1
  29. early_stop_threshold: float = 1.0
  30. class ServeVQGANEncodeRequest(BaseModel):
  31. # The audio here should be in wav, mp3, etc
  32. audios: list[bytes]
  33. class ServeVQGANEncodeResponse(BaseModel):
  34. tokens: SkipValidation[list[list[list[int]]]]
  35. class ServeVQGANDecodeRequest(BaseModel):
  36. tokens: SkipValidation[list[list[list[int]]]]
  37. class ServeVQGANDecodeResponse(BaseModel):
  38. # The audio here should be in PCM float16 format
  39. audios: list[bytes]
  40. class ServeReferenceAudio(BaseModel):
  41. audio: bytes
  42. text: str
  43. @model_validator(mode="before")
  44. def decode_audio(cls, values):
  45. audio = values.get("audio")
  46. if (
  47. isinstance(audio, str) and len(audio) > 255
  48. ): # Check if audio is a string (Base64)
  49. try:
  50. values["audio"] = base64.b64decode(audio)
  51. except Exception:
  52. # If the audio is not a valid base64 string, we will just ignore it and let the server handle it
  53. pass
  54. return values
  55. def __repr__(self) -> str:
  56. return f"ServeReferenceAudio(text={self.text!r}, audio_size={len(self.audio)})"
  57. class ServeTTSRequest(BaseModel):
  58. text: str
  59. chunk_length: Annotated[int, conint(ge=100, le=300, strict=True)] = 200
  60. # Audio format
  61. format: Literal["wav", "pcm", "mp3"] = "wav"
  62. # References audios for in-context learning
  63. references: list[ServeReferenceAudio] = []
  64. # Reference id
  65. # For example, if you want use https://fish.audio/m/7f92f8afb8ec43bf81429cc1c9199cb1/
  66. # Just pass 7f92f8afb8ec43bf81429cc1c9199cb1
  67. reference_id: str | None = None
  68. seed: int | None = None
  69. use_memory_cache: Literal["on", "off"] = "off"
  70. # Normalize text for en & zh, this increase stability for numbers
  71. normalize: bool = True
  72. # not usually used below
  73. streaming: bool = False
  74. max_new_tokens: int = 1024
  75. top_p: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.8
  76. repetition_penalty: Annotated[float, Field(ge=0.9, le=2.0, strict=True)] = 1.1
  77. temperature: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.8
  78. class Config:
  79. # Allow arbitrary types for pytorch related types
  80. arbitrary_types_allowed = True
  81. class AddReferenceRequest(BaseModel):
  82. id: str = Field(..., min_length=1, max_length=255, pattern=r"^[a-zA-Z0-9\-_ ]+$")
  83. audio: bytes
  84. text: str = Field(..., min_length=1)
  85. class AddReferenceResponse(BaseModel):
  86. success: bool
  87. message: str
  88. reference_id: str
  89. class ListReferencesResponse(BaseModel):
  90. success: bool
  91. reference_ids: list[str]
  92. message: str = "Success"
  93. class DeleteReferenceResponse(BaseModel):
  94. success: bool
  95. message: str
  96. reference_id: str
  97. class UpdateReferenceResponse(BaseModel):
  98. success: bool
  99. message: str
  100. old_reference_id: str
  101. new_reference_id: str