Explorar o código

Streaming Agent (#659)

* fix e2e_webui

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Agent: Streaming audio

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix text streaming

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
spicysama hai 1 ano
pai
achega
672000d594
Modificáronse 1 ficheiros con 25 adicións e 13 borrados
  1. 25 13
      tools/e2e_webui.py

+ 25 - 13
tools/e2e_webui.py

@@ -1,4 +1,6 @@
+import io
 import re
+import wave
 
 import gradio as gr
 import numpy as np
@@ -7,6 +9,19 @@ from .fish_e2e import FishE2EAgent, FishE2EEventType
 from .schema import ServeMessage, ServeTextPart, ServeVQPart
 
 
+def wav_chunk_header(sample_rate=44100, bit_depth=16, channels=1):
+    buffer = io.BytesIO()
+
+    with wave.open(buffer, "wb") as wav_file:
+        wav_file.setnchannels(channels)
+        wav_file.setsampwidth(bit_depth // 8)
+        wav_file.setframerate(sample_rate)
+
+    wav_header_bytes = buffer.getvalue()
+    buffer.close()
+    return wav_header_bytes
+
+
 class ChatState:
     def __init__(self):
         self.conversation = []
@@ -62,7 +77,7 @@ async def process_audio_input(
 
     if isinstance(sys_audio_input, tuple):
         sr, sys_audio_data = sys_audio_input
-    elif text_input:
+    else:
         sr = 44100
         sys_audio_data = None
 
@@ -95,21 +110,13 @@ async def process_audio_input(
         if event.type == FishE2EEventType.USER_CODES:
             append_to_chat_ctx(ServeVQPart(codes=event.vq_codes), role="user")
         elif event.type == FishE2EEventType.SPEECH_SEGMENT:
-            result_audio += event.frame.data
-            np_audio = np.frombuffer(result_audio, dtype=np.int16)
             append_to_chat_ctx(ServeVQPart(codes=event.vq_codes))
-
-            yield state.get_history(), (44100, np_audio), None, None
+            yield state.get_history(), wav_chunk_header() + event.frame.data, None, None
         elif event.type == FishE2EEventType.TEXT_SEGMENT:
             append_to_chat_ctx(ServeTextPart(text=event.text))
-            if result_audio:
-                np_audio = np.frombuffer(result_audio, dtype=np.int16)
-                yield state.get_history(), (44100, np_audio), None, None
-            else:
-                yield state.get_history(), None, None, None
+            yield state.get_history(), None, None, None
 
-    np_audio = np.frombuffer(result_audio, dtype=np.int16)
-    yield state.get_history(), (44100, np_audio), None, None
+    yield state.get_history(), None, None, None
 
 
 async def process_text_input(
@@ -179,7 +186,12 @@ def create_demo():
 
                 text_input = gr.Textbox(label="Or type your message", type="text")
 
-                output_audio = gr.Audio(label="Assistant's Voice", type="numpy")
+                output_audio = gr.Audio(
+                    label="Assistant's Voice",
+                    streaming=True,
+                    autoplay=True,
+                    interactive=False,
+                )
 
                 send_button = gr.Button("Send", variant="primary")
                 clear_button = gr.Button("Clear")