# e2e_webui.py
  1. import re
  2. import gradio as gr
  3. import numpy as np
  4. from .fish_e2e import FishE2EAgent, FishE2EEventType
  5. from .schema import ServeMessage, ServeTextPart, ServeVQPart
  6. class ChatState:
  7. def __init__(self):
  8. self.conversation = []
  9. self.added_systext = False
  10. self.added_sysaudio = False
  11. def get_history(self):
  12. results = []
  13. for msg in self.conversation:
  14. results.append({"role": msg.role, "content": self.repr_message(msg)})
  15. # Process assistant messages to extract questions and update user messages
  16. for i, msg in enumerate(results):
  17. if msg["role"] == "assistant":
  18. match = re.search(r"Question: (.*?)\n\nResponse:", msg["content"])
  19. if match and i > 0 and results[i - 1]["role"] == "user":
  20. # Update previous user message with extracted question
  21. results[i - 1]["content"] += "\n" + match.group(1)
  22. # Remove the Question/Answer format from assistant message
  23. msg["content"] = msg["content"].split("\n\nResponse: ", 1)[1]
  24. return results
  25. def repr_message(self, msg: ServeMessage):
  26. response = ""
  27. for part in msg.parts:
  28. if isinstance(part, ServeTextPart):
  29. response += part.text
  30. elif isinstance(part, ServeVQPart):
  31. response += f"<audio {len(part.codes[0]) / 21:.2f}s>"
  32. return response
  33. def clear_fn():
  34. return [], ChatState(), None, None, None
async def process_audio_input(
    sys_audio_input, sys_text_input, audio_input, state: ChatState, text_input: str
):
    """Run one chat turn through the end-to-end agent, streaming updates.

    Async generator: yields (chat history, (44100, np_audio) or None,
    audio-input reset, text-input reset) tuples as speech/text segments
    arrive, matching the Gradio outputs [chatbot, output_audio,
    audio_input, text_input].
    """
    if audio_input is None and not text_input:
        raise gr.Error("No input provided")

    agent = FishE2EAgent()  # Create new agent instance for each request

    # Convert audio input to numpy array
    if isinstance(audio_input, tuple):
        sr, audio_data = audio_input
    elif text_input:
        sr = 44100
        audio_data = None
    else:
        raise gr.Error("Invalid audio format")

    # System (timbre) reference audio, same tuple convention as above.
    # NOTE(review): this overwrites `sr` with the system audio's sample
    # rate — confirm both recordings are expected to share one rate.
    if isinstance(sys_audio_input, tuple):
        sr, sys_audio_data = sys_audio_input
    elif text_input:
        sr = 44100
        sys_audio_data = None
    else:
        raise gr.Error("Invalid audio format")

    def append_to_chat_ctx(
        part: ServeTextPart | ServeVQPart, role: str = "assistant"
    ) -> None:
        # Append to the last message when the role continues, else start a
        # new message for this role.
        if not state.conversation or state.conversation[-1].role != role:
            state.conversation.append(ServeMessage(role=role, parts=[part]))
        else:
            state.conversation[-1].parts.append(part)

    # Inject the system text prompt exactly once per session.
    if state.added_systext is False and sys_text_input:
        state.added_systext = True
        append_to_chat_ctx(ServeTextPart(text=sys_text_input), role="system")
    if text_input:
        append_to_chat_ctx(ServeTextPart(text=text_input), role="user")
        # Typed text takes precedence: drop any recorded audio for this turn.
        audio_data = None

    result_audio = b""
    # NOTE(review): the positional `1` is passed straight to agent.stream —
    # presumably a channel count; confirm against FishE2EAgent.stream.
    async for event in agent.stream(
        sys_audio_data,
        audio_data,
        sr,
        1,
        chat_ctx={
            "messages": state.conversation,
            "added_sysaudio": state.added_sysaudio,
        },
    ):
        if event.type == FishE2EEventType.USER_CODES:
            # Agent echoes the user's VQ codes; record them as the user turn.
            append_to_chat_ctx(ServeVQPart(codes=event.vq_codes), role="user")
        elif event.type == FishE2EEventType.SPEECH_SEGMENT:
            # Accumulate raw PCM bytes and re-yield the full waveform so the
            # player updates progressively.
            result_audio += event.frame.data
            np_audio = np.frombuffer(result_audio, dtype=np.int16)
            append_to_chat_ctx(ServeVQPart(codes=event.vq_codes))
            yield state.get_history(), (44100, np_audio), None, None
        elif event.type == FishE2EEventType.TEXT_SEGMENT:
            append_to_chat_ctx(ServeTextPart(text=event.text))
            if result_audio:
                np_audio = np.frombuffer(result_audio, dtype=np.int16)
                yield state.get_history(), (44100, np_audio), None, None
            else:
                yield state.get_history(), None, None, None

    # Final yield with the complete audio and history.
    np_audio = np.frombuffer(result_audio, dtype=np.int16)
    yield state.get_history(), (44100, np_audio), None, None
  96. async def process_text_input(
  97. sys_audio_input, sys_text_input, state: ChatState, text_input: str
  98. ):
  99. async for event in process_audio_input(
  100. sys_audio_input, sys_text_input, None, state, text_input
  101. ):
  102. yield event
def create_demo():
    """Build and return the Gradio Blocks UI for the Fish Agent demo.

    Layout: left column (70%) holds the chatbot and usage notes; right
    column (30%) holds the system prompt controls, audio/text inputs, the
    assistant's audio output, and the Send/Clear buttons.
    """
    with gr.Blocks() as demo:
        # Conversation state held by Gradio between event callbacks.
        state = gr.State(ChatState())

        with gr.Row():
            # Left column (70%) for chatbot and notes
            with gr.Column(scale=7):
                chatbot = gr.Chatbot(
                    [],
                    elem_id="chatbot",
                    bubble_full_width=False,
                    height=600,
                    type="messages",
                )

                notes = gr.Markdown(
                    """
# Fish Agent
1. 此Demo为Fish Audio自研端到端语言模型Fish Agent 3B版本.
2. 你可以在我们的官方仓库找到代码以及权重,但是相关内容全部基于 CC BY-NC-SA 4.0 许可证发布.
3. Demo为早期灰度测试版本,推理速度尚待优化.
# 特色
1. 该模型自动集成ASR与TTS部分,不需要外挂其它模型,即真正的端到端,而非三段式(ASR+LLM+TTS).
2. 模型可以使用reference audio控制说话音色.
3. 可以生成具有较强情感与韵律的音频.
"""
                )

            # Right column (30%) for controls
            with gr.Column(scale=3):
                sys_audio_input = gr.Audio(
                    sources=["upload"],
                    type="numpy",
                    label="Give a timbre for your assistant",
                )
                sys_text_input = gr.Textbox(
                    label="What is your assistant's role?",
                    value='您是由 Fish Audio 设计的语音助手,提供端到端的语音交互,实现无缝用户体验。首先转录用户的语音,然后使用以下格式回答:"Question: [用户语音]\n\nResponse: [你的回答]\n"。',
                    type="text",
                )
                audio_input = gr.Audio(
                    sources=["microphone"], type="numpy", label="Speak your message"
                )
                text_input = gr.Textbox(label="Or type your message", type="text")

                output_audio = gr.Audio(label="Assistant's Voice", type="numpy")
                send_button = gr.Button("Send", variant="primary")
                clear_button = gr.Button("Clear")

        # Event handlers
        # Finishing a microphone recording runs the full audio pipeline.
        audio_input.stop_recording(
            process_audio_input,
            inputs=[sys_audio_input, sys_text_input, audio_input, state, text_input],
            outputs=[chatbot, output_audio, audio_input, text_input],
            show_progress=True,
        )

        # Typed messages (button or Enter) go through the text wrapper.
        send_button.click(
            process_text_input,
            inputs=[sys_audio_input, sys_text_input, state, text_input],
            outputs=[chatbot, output_audio, audio_input, text_input],
            show_progress=True,
        )

        text_input.submit(
            process_text_input,
            inputs=[sys_audio_input, sys_text_input, state, text_input],
            outputs=[chatbot, output_audio, audio_input, text_input],
            show_progress=True,
        )

        # Clear resets every component, including the per-session state.
        clear_button.click(
            clear_fn,
            inputs=[],
            outputs=[chatbot, state, audio_input, output_audio, text_input],
        )

    return demo
  172. if __name__ == "__main__":
  173. demo = create_demo()
  174. demo.launch(server_name="127.0.0.1", server_port=7860, share=True)