# msgpack_api.py (2.2 KB)

from pathlib import Path
from typing import Annotated, AsyncGenerator, Literal, Optional

import httpx
import ormsgpack
from pydantic import AfterValidator, BaseModel, Field, conint
  5. class ServeReferenceAudio(BaseModel):
  6. audio: bytes
  7. text: str
  8. class ServeTTSRequest(BaseModel):
  9. text: str
  10. chunk_length: Annotated[int, conint(ge=100, le=300, strict=True)] = 200
  11. # Audio format
  12. format: Literal["wav", "pcm", "mp3"] = "wav"
  13. mp3_bitrate: Literal[64, 128, 192] = 128
  14. # References audios for in-context learning
  15. references: list[ServeReferenceAudio] = []
  16. # Reference id
  17. # For example, if you want use https://fish.audio/m/7f92f8afb8ec43bf81429cc1c9199cb1/
  18. # Just pass 7f92f8afb8ec43bf81429cc1c9199cb1
  19. reference_id: str | None = None
  20. # Normalize text for en & zh, this increase stability for numbers
  21. normalize: bool = True
  22. mp3_bitrate: Optional[int] = 64
  23. opus_bitrate: Optional[int] = -1000
  24. # Balance mode will reduce latency to 300ms, but may decrease stability
  25. latency: Literal["normal", "balanced"] = "normal"
  26. # not usually used below
  27. streaming: bool = False
  28. emotion: Optional[str] = None
  29. max_new_tokens: int = 1024
  30. top_p: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7
  31. repetition_penalty: Annotated[float, Field(ge=0.9, le=2.0, strict=True)] = 1.2
  32. temperature: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7
  33. # priority: ref_id > references
  34. request = ServeTTSRequest(
  35. text="你说的对, 但是原神是一款由米哈游自主研发的开放世界手游.",
  36. # reference_id="114514",
  37. references=[
  38. ServeReferenceAudio(
  39. audio=open("lengyue.wav", "rb").read(),
  40. text=open("lengyue.lab", "r", encoding="utf-8").read(),
  41. )
  42. ],
  43. streaming=True,
  44. )
  45. with (
  46. httpx.Client() as client,
  47. open("hello.wav", "wb") as f,
  48. ):
  49. with client.stream(
  50. "POST",
  51. "http://127.0.0.1:8080/v1/tts",
  52. content=ormsgpack.packb(request, option=ormsgpack.OPT_SERIALIZE_PYDANTIC),
  53. headers={
  54. "authorization": "Bearer YOUR_API_KEY",
  55. "content-type": "application/msgpack",
  56. },
  57. timeout=None,
  58. ) as response:
  59. for chunk in response.iter_bytes():
  60. f.write(chunk)