# schema.py (3.8 KB)
  1. import base64
  2. import os
  3. import queue
  4. from dataclasses import dataclass
  5. from typing import Literal
  6. import torch
  7. from pydantic import BaseModel, Field, conint, model_validator
  8. from pydantic.functional_validators import SkipValidation
  9. from typing_extensions import Annotated
  10. from fish_speech.content_sequence import TextPart, VQPart
  class ServeVQPart(BaseModel):
      # Content part whose `type` tag is fixed to the literal "vq".
      type: Literal["vq"] = "vq"
      # Nested lists of ints; SkipValidation disables pydantic validation for this field.
      codes: SkipValidation[list[list[int]]]
  class ServeTextPart(BaseModel):
      # Content part whose `type` tag is fixed to the literal "text".
      type: Literal["text"] = "text"
      # The text payload carried by this part.
      text: str
  class ServeAudioPart(BaseModel):
      # Content part whose `type` tag is fixed to the literal "audio".
      type: Literal["audio"] = "audio"
      # The audio payload as bytes (container/encoding is not specified here).
      audio: bytes
  class ServeRequest(BaseModel):
      # Raw content sequence dict that we can use with ContentSequence(**content)
      content: dict
      # Generation settings and their default values.
      max_new_tokens: int = 600
      top_p: float = 0.7
      repetition_penalty: float = 1.2
      temperature: float = 0.7
      streaming: bool = False
      num_samples: int = 1
      early_stop_threshold: float = 1.0
  class ServeVQGANEncodeRequest(BaseModel):
      # Request body for encoding: a list of audio blobs.
      # The audio here should be in wav, mp3, etc
      audios: list[bytes]
  class ServeVQGANEncodeResponse(BaseModel):
      # Triple-nested lists of ints; SkipValidation disables pydantic validation
      # for this field.
      tokens: SkipValidation[list[list[list[int]]]]
  class ServeVQGANDecodeRequest(BaseModel):
      # Triple-nested lists of ints (same field type as ServeVQGANEncodeResponse.tokens);
      # SkipValidation disables pydantic validation for this field.
      tokens: SkipValidation[list[list[list[int]]]]
  class ServeVQGANDecodeResponse(BaseModel):
      # Response body for decoding: a list of audio blobs.
      # The audio here should be in PCM float16 format
      audios: list[bytes]
  class ServeReferenceAudio(BaseModel):
      # A reference clip: `audio` bytes paired with a `text` string.
      audio: bytes
      text: str

      @model_validator(mode="before")
      @classmethod
      def decode_audio(cls, values):
          """Decode a long base64 ``audio`` string into bytes before field validation.

          Only ``str`` values longer than 255 characters are treated as base64
          (NOTE(review): the 255 cut-off presumably separates encoded payloads
          from short plain strings — confirm). Any other input is returned
          unchanged so regular field validation can deal with it.
          """
          # A "before" validator receives the raw input, which is not guaranteed
          # to be a dict; leave anything else untouched instead of calling .get on it.
          if not isinstance(values, dict):
              return values

          audio = values.get("audio")
          if not (isinstance(audio, str) and len(audio) > 255):
              return values

          try:
              decoded = base64.b64decode(audio)
          except ValueError:
              # binascii.Error is a ValueError subclass, so this covers bad padding
              # as well as non-ASCII input. If the audio is not a valid base64
              # string, we just ignore it and let the server handle it.
              return values

          # Return a new dict instead of mutating the caller's input in place.
          return {**values, "audio": decoded}

      def __repr__(self) -> str:
          # Report the audio size rather than dumping the raw bytes.
          return f"ServeReferenceAudio(text={self.text!r}, audio_size={len(self.audio)})"
  class ServeTTSRequest(BaseModel):
      text: str
      # NOTE(review): conint(...) already yields a constrained int type; nesting it
      # as metadata inside Annotated[int, ...] may leave the 100-1000 bounds
      # unenforced, unlike the Field(...) constraints used below — confirm.
      chunk_length: Annotated[int, conint(ge=100, le=1000, strict=True)] = 200
      # Audio format
      format: Literal["wav", "pcm", "mp3", "opus"] = "wav"
      # Latency mode (used by api.fish.audio; "normal" or "balanced")
      latency: Literal["normal", "balanced"] = "normal"
      # Reference audios for in-context learning
      references: list[ServeReferenceAudio] = []
      # Reference id
      # For example, if you want to use https://fish.audio/m/7f92f8afb8ec43bf81429cc1c9199cb1/
      # Just pass 7f92f8afb8ec43bf81429cc1c9199cb1
      reference_id: str | None = None
      seed: int | None = None
      use_memory_cache: Literal["on", "off"] = "off"
      # Normalize text for en & zh; this increases stability for numbers
      normalize: bool = True

      # The fields below are not usually set by callers
      streaming: bool = False
      max_new_tokens: int = 1024
      top_p: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.8
      repetition_penalty: Annotated[float, Field(ge=0.9, le=2.0, strict=True)] = 1.1
      temperature: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.8

      # Allow arbitrary types for pytorch related types.
      # Declared via `model_config` because the class-based `Config` is
      # deprecated in pydantic v2 and emits a deprecation warning.
      model_config = {"arbitrary_types_allowed": True}
  class AddReferenceRequest(BaseModel):
      # Required; 1-255 characters drawn from ASCII letters, digits, "-", "_" and space.
      id: str = Field(..., min_length=1, max_length=255, pattern=r"^[a-zA-Z0-9\-_ ]+$")
      # Audio payload as bytes.
      audio: bytes
      # Required, non-empty text.
      text: str = Field(..., min_length=1)
  class AddReferenceResponse(BaseModel):
      # Response for an add-reference call: outcome flag, message, and the id concerned.
      success: bool
      message: str
      reference_id: str
  class ListReferencesResponse(BaseModel):
      # Response for a list-references call: outcome flag plus the reference ids.
      success: bool
      reference_ids: list[str]
      # Defaults to "Success" when the caller supplies no message.
      message: str = "Success"
  class DeleteReferenceResponse(BaseModel):
      # Response for a delete-reference call: outcome flag, message, and the id concerned.
      success: bool
      message: str
      reference_id: str
  class UpdateReferenceResponse(BaseModel):
      # Response for an update-reference call: outcome flag, message, and both
      # the previous and the new reference id.
      success: bool
      message: str
      old_reference_id: str
      new_reference_id: str
  103. new_reference_id: str