e2e_webui.py 8.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232
  1. import io
  2. import re
  3. import wave
  4. import gradio as gr
  5. import numpy as np
  6. from .fish_e2e import FishE2EAgent, FishE2EEventType
  7. from .schema import ServeMessage, ServeTextPart, ServeVQPart
  8. def wav_chunk_header(sample_rate=44100, bit_depth=16, channels=1):
  9. buffer = io.BytesIO()
  10. with wave.open(buffer, "wb") as wav_file:
  11. wav_file.setnchannels(channels)
  12. wav_file.setsampwidth(bit_depth // 8)
  13. wav_file.setframerate(sample_rate)
  14. wav_header_bytes = buffer.getvalue()
  15. buffer.close()
  16. return wav_header_bytes
  17. class ChatState:
  18. def __init__(self):
  19. self.conversation = []
  20. self.added_systext = False
  21. self.added_sysaudio = False
  22. def get_history(self):
  23. results = []
  24. for msg in self.conversation:
  25. results.append({"role": msg.role, "content": self.repr_message(msg)})
  26. # Process assistant messages to extract questions and update user messages
  27. for i, msg in enumerate(results):
  28. if msg["role"] == "assistant":
  29. match = re.search(r"Question: (.*?)\n\nResponse:", msg["content"])
  30. if match and i > 0 and results[i - 1]["role"] == "user":
  31. # Update previous user message with extracted question
  32. results[i - 1]["content"] += "\n" + match.group(1)
  33. # Remove the Question/Answer format from assistant message
  34. msg["content"] = msg["content"].split("\n\nResponse: ", 1)[1]
  35. return results
  36. def repr_message(self, msg: ServeMessage):
  37. response = ""
  38. for part in msg.parts:
  39. if isinstance(part, ServeTextPart):
  40. response += part.text
  41. elif isinstance(part, ServeVQPart):
  42. response += f"<audio {len(part.codes[0]) / 21:.2f}s>"
  43. return response
  44. def clear_fn():
  45. return [], ChatState(), None, None, None
async def process_audio_input(
    sys_audio_input, sys_text_input, audio_input, state: ChatState, text_input: str
):
    """Stream one chat turn through the FishE2E agent.

    Async generator yielding 4-tuples
    (chat history, wav bytes or None, None, None) that Gradio maps to the
    (chatbot, output_audio, audio_input, text_input) outputs.

    Args:
        sys_audio_input: optional (sample_rate, ndarray) reference-timbre audio.
        sys_text_input: optional system prompt text (injected once per session).
        audio_input: optional (sample_rate, ndarray) user speech, or None.
        state: mutable ChatState accumulating the conversation.
        text_input: optional typed user message; takes precedence over audio.

    Raises:
        gr.Error: when neither audio nor text input is provided, or the audio
            value has an unexpected format.
    """
    if audio_input is None and not text_input:
        raise gr.Error("No input provided")

    agent = FishE2EAgent()  # Create new agent instance for each request

    # Convert audio input to numpy array
    if isinstance(audio_input, tuple):
        sr, audio_data = audio_input
    elif text_input:
        # Text-only turn: no user audio, default sample rate.
        sr = 44100
        audio_data = None
    else:
        raise gr.Error("Invalid audio format")

    # NOTE(review): this branch unconditionally overwrites `sr`, so when no
    # system audio is supplied the user audio's real sample rate is replaced
    # with 44100 — confirm which rate agent.stream() actually expects.
    if isinstance(sys_audio_input, tuple):
        sr, sys_audio_data = sys_audio_input
    else:
        sr = 44100
        sys_audio_data = None

    def append_to_chat_ctx(
        part: ServeTextPart | ServeVQPart, role: str = "assistant"
    ) -> None:
        # Append `part` to the conversation, merging consecutive parts from
        # the same role into a single ServeMessage.
        if not state.conversation or state.conversation[-1].role != role:
            state.conversation.append(ServeMessage(role=role, parts=[part]))
        else:
            state.conversation[-1].parts.append(part)

    # Inject the system text prompt only once per session.
    if state.added_systext is False and sys_text_input:
        state.added_systext = True
        append_to_chat_ctx(ServeTextPart(text=sys_text_input), role="system")
    if text_input:
        append_to_chat_ctx(ServeTextPart(text=text_input), role="user")
        # Typed text takes precedence: drop any recorded audio for this turn.
        audio_data = None

    result_audio = b""  # NOTE(review): never used; audio is streamed per segment below
    async for event in agent.stream(
        sys_audio_data,
        audio_data,
        sr,
        1,
        chat_ctx={
            "messages": state.conversation,
            # NOTE(review): added_sysaudio is never set True anywhere in this
            # file — verify whether the agent is expected to flip it.
            "added_sysaudio": state.added_sysaudio,
        },
    ):
        if event.type == FishE2EEventType.USER_CODES:
            # Echo of the user's speech as VQ codes — stored under the user role.
            append_to_chat_ctx(ServeVQPart(codes=event.vq_codes), role="user")
        elif event.type == FishE2EEventType.SPEECH_SEGMENT:
            # Assistant audio chunk: prepend a WAV header so the streaming
            # gr.Audio widget can play the raw frame bytes.
            append_to_chat_ctx(ServeVQPart(codes=event.vq_codes))
            yield state.get_history(), wav_chunk_header() + event.frame.data, None, None
        elif event.type == FishE2EEventType.TEXT_SEGMENT:
            append_to_chat_ctx(ServeTextPart(text=event.text))
            yield state.get_history(), None, None, None

    # Final refresh of the chat display with no audio payload.
    yield state.get_history(), None, None, None
  98. async def process_text_input(
  99. sys_audio_input, sys_text_input, state: ChatState, text_input: str
  100. ):
  101. async for event in process_audio_input(
  102. sys_audio_input, sys_text_input, None, state, text_input
  103. ):
  104. yield event
  105. def create_demo():
  106. with gr.Blocks() as demo:
  107. state = gr.State(ChatState())
  108. with gr.Row():
  109. # Left column (70%) for chatbot and notes
  110. with gr.Column(scale=7):
  111. chatbot = gr.Chatbot(
  112. [],
  113. elem_id="chatbot",
  114. bubble_full_width=False,
  115. height=600,
  116. type="messages",
  117. )
  118. # notes = gr.Markdown(
  119. # """
  120. # # Fish Agent
  121. # 1. 此Demo为Fish Audio自研端到端语言模型Fish Agent 3B版本.
  122. # 2. 你可以在我们的官方仓库找到代码以及权重,但是相关内容全部基于 CC BY-NC-SA 4.0 许可证发布.
  123. # 3. Demo为早期灰度测试版本,推理速度尚待优化.
  124. # # 特色
  125. # 1. 该模型自动集成ASR与TTS部分,不需要外挂其它模型,即真正的端到端,而非三段式(ASR+LLM+TTS).
  126. # 2. 模型可以使用reference audio控制说话音色.
  127. # 3. 可以生成具有较强情感与韵律的音频.
  128. # """
  129. # )
  130. notes = gr.Markdown(
  131. """
  132. # Fish Agent
  133. 1. This demo is Fish Audio's self-researh end-to-end language model, Fish Agent version 3B.
  134. 2. You can find the code and weights in our official repo in [gitub](https://github.com/fishaudio/fish-speech) and [hugging face](https://huggingface.co/fishaudio/fish-agent-v0.1-3b), but the content is released under a CC BY-NC-SA 4.0 licence.
  135. 3. The demo is an early alpha test version, the inference speed needs to be optimised.
  136. # Features
  137. 1. The model automatically integrates ASR and TTS parts, no need to plug-in other models, i.e., true end-to-end, not three-stage (ASR+LLM+TTS).
  138. 2. The model can use reference audio to control the speech timbre.
  139. 3. The model can generate speech with strong emotion.
  140. """
  141. )
  142. # Right column (30%) for controls
  143. with gr.Column(scale=3):
  144. sys_audio_input = gr.Audio(
  145. sources=["upload"],
  146. type="numpy",
  147. label="Give a timbre for your assistant",
  148. )
  149. sys_text_input = gr.Textbox(
  150. label="What is your assistant's role?",
  151. value="You are a voice assistant created by Fish Audio, offering end-to-end voice interaction for a seamless user experience. You are required to first transcribe the user's speech, then answer it in the following format: 'Question: [USER_SPEECH]\n\nAnswer: [YOUR_RESPONSE]\n'. You are required to use the following voice in this conversation.",
  152. type="text",
  153. )
  154. audio_input = gr.Audio(
  155. sources=["microphone"], type="numpy", label="Speak your message"
  156. )
  157. text_input = gr.Textbox(label="Or type your message", type="text")
  158. output_audio = gr.Audio(
  159. label="Assistant's Voice",
  160. streaming=True,
  161. autoplay=True,
  162. interactive=False,
  163. )
  164. send_button = gr.Button("Send", variant="primary")
  165. clear_button = gr.Button("Clear")
  166. # Event handlers
  167. audio_input.stop_recording(
  168. process_audio_input,
  169. inputs=[sys_audio_input, sys_text_input, audio_input, state, text_input],
  170. outputs=[chatbot, output_audio, audio_input, text_input],
  171. show_progress=True,
  172. )
  173. send_button.click(
  174. process_text_input,
  175. inputs=[sys_audio_input, sys_text_input, state, text_input],
  176. outputs=[chatbot, output_audio, audio_input, text_input],
  177. show_progress=True,
  178. )
  179. text_input.submit(
  180. process_text_input,
  181. inputs=[sys_audio_input, sys_text_input, state, text_input],
  182. outputs=[chatbot, output_audio, audio_input, text_input],
  183. show_progress=True,
  184. )
  185. clear_button.click(
  186. clear_fn,
  187. inputs=[],
  188. outputs=[chatbot, state, audio_input, output_audio, text_input],
  189. )
  190. return demo
  191. if __name__ == "__main__":
  192. demo = create_demo()
  193. demo.launch(server_name="127.0.0.1", server_port=7860, share=True)