post_api.py 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124
  1. import argparse
  2. import base64
  3. import json
  4. import pyaudio
  5. import requests
  6. def play_audio(audio_content, format, channels, rate):
  7. p = pyaudio.PyAudio()
  8. stream = p.open(format=format, channels=channels, rate=rate, output=True)
  9. stream.write(audio_content)
  10. stream.stop_stream()
  11. stream.close()
  12. p.terminate()
  13. if __name__ == "__main__":
  14. parser = argparse.ArgumentParser(
  15. description="Send a WAV file and text to a server and receive synthesized audio."
  16. )
  17. parser.add_argument(
  18. "--url",
  19. "-u",
  20. type=str,
  21. default="http://127.0.0.1:8000/v1/invoke",
  22. help="URL of the server",
  23. )
  24. parser.add_argument(
  25. "--text", "-t", type=str, required=True, help="Text to be synthesized"
  26. )
  27. parser.add_argument(
  28. "--reference_audio",
  29. "-ra",
  30. type=str,
  31. required=False,
  32. help="Path to the WAV file",
  33. )
  34. parser.add_argument(
  35. "--reference_text",
  36. "-rt",
  37. type=str,
  38. required=False,
  39. help="Reference text for voice synthesis",
  40. )
  41. parser.add_argument(
  42. "--max_new_tokens",
  43. type=int,
  44. default=1024,
  45. help="Maximum new tokens to generate",
  46. )
  47. parser.add_argument(
  48. "--chunk_length", type=int, default=100, help="Chunk length for synthesis"
  49. )
  50. parser.add_argument(
  51. "--top_p", type=float, default=0.7, help="Top-p sampling for synthesis"
  52. )
  53. parser.add_argument(
  54. "--repetition_penalty",
  55. type=float,
  56. default=1.2,
  57. help="Repetition penalty for synthesis",
  58. )
  59. parser.add_argument(
  60. "--temperature", type=float, default=0.7, help="Temperature for sampling"
  61. )
  62. parser.add_argument(
  63. "--speaker", type=str, default=None, help="Speaker ID for voice synthesis"
  64. )
  65. parser.add_argument("--emotion", type=str, default=None, help="Speaker's Emotion")
  66. parser.add_argument("--format", type=str, default="wav", help="Audio format")
  67. parser.add_argument(
  68. "--streaming", type=bool, default=False, help="Enable streaming response"
  69. )
  70. parser.add_argument(
  71. "--channels", type=int, default=1, help="Number of audio channels"
  72. )
  73. parser.add_argument("--rate", type=int, default=44100, help="Sample rate for audio")
  74. args = parser.parse_args()
  75. data = {
  76. "text": args.text,
  77. "reference_text": args.reference_text,
  78. "reference_audio": args.reference_audio,
  79. "max_new_tokens": args.max_new_tokens,
  80. "chunk_length": args.chunk_length,
  81. "top_p": args.top_p,
  82. "repetition_penalty": args.repetition_penalty,
  83. "temperature": args.temperature,
  84. "speaker": args.speaker,
  85. "emotion": args.emotion,
  86. "format": args.format,
  87. "streaming": args.streaming,
  88. }
  89. response = requests.post(args.url, json=data, stream=args.streaming)
  90. audio_format = pyaudio.paInt16 # Assuming 16-bit PCM format
  91. if response.status_code == 200:
  92. if args.streaming:
  93. p = pyaudio.PyAudio()
  94. stream = p.open(
  95. format=audio_format, channels=args.channels, rate=args.rate, output=True
  96. )
  97. for chunk in response.iter_content(chunk_size=1024):
  98. if chunk:
  99. stream.write(chunk)
  100. stream.stop_stream()
  101. stream.close()
  102. p.terminate()
  103. else:
  104. audio_content = response.content
  105. with open("generated_audio.wav", "wb") as audio_file:
  106. audio_file.write(audio_content)
  107. play_audio(audio_content, audio_format, args.channels, args.rate)
  108. print("Audio has been saved to 'generated_audio.wav'.")
  109. else:
  110. print(f"Request failed with status code {response.status_code}")
  111. print(response.json())