|
|
@@ -1,4 +1,6 @@
|
|
|
+import io
|
|
|
import re
|
|
|
+import wave
|
|
|
|
|
|
import gradio as gr
|
|
|
import numpy as np
|
|
|
@@ -7,6 +9,19 @@ from .fish_e2e import FishE2EAgent, FishE2EEventType
|
|
|
from .schema import ServeMessage, ServeTextPart, ServeVQPart
|
|
|
|
|
|
|
|
|
+def wav_chunk_header(sample_rate=44100, bit_depth=16, channels=1):
|
|
|
+ buffer = io.BytesIO()
|
|
|
+
|
|
|
+ with wave.open(buffer, "wb") as wav_file:
|
|
|
+ wav_file.setnchannels(channels)
|
|
|
+ wav_file.setsampwidth(bit_depth // 8)
|
|
|
+ wav_file.setframerate(sample_rate)
|
|
|
+
|
|
|
+ wav_header_bytes = buffer.getvalue()
|
|
|
+ buffer.close()
|
|
|
+ return wav_header_bytes
|
|
|
+
|
|
|
+
|
|
|
class ChatState:
|
|
|
def __init__(self):
|
|
|
self.conversation = []
|
|
|
@@ -62,7 +77,7 @@ async def process_audio_input(
|
|
|
|
|
|
if isinstance(sys_audio_input, tuple):
|
|
|
sr, sys_audio_data = sys_audio_input
|
|
|
- elif text_input:
|
|
|
+ else:
|
|
|
sr = 44100
|
|
|
sys_audio_data = None
|
|
|
|
|
|
@@ -95,21 +110,13 @@ async def process_audio_input(
|
|
|
if event.type == FishE2EEventType.USER_CODES:
|
|
|
append_to_chat_ctx(ServeVQPart(codes=event.vq_codes), role="user")
|
|
|
elif event.type == FishE2EEventType.SPEECH_SEGMENT:
|
|
|
- result_audio += event.frame.data
|
|
|
- np_audio = np.frombuffer(result_audio, dtype=np.int16)
|
|
|
append_to_chat_ctx(ServeVQPart(codes=event.vq_codes))
|
|
|
-
|
|
|
- yield state.get_history(), (44100, np_audio), None, None
|
|
|
+ yield state.get_history(), wav_chunk_header() + event.frame.data, None, None
|
|
|
elif event.type == FishE2EEventType.TEXT_SEGMENT:
|
|
|
append_to_chat_ctx(ServeTextPart(text=event.text))
|
|
|
- if result_audio:
|
|
|
- np_audio = np.frombuffer(result_audio, dtype=np.int16)
|
|
|
- yield state.get_history(), (44100, np_audio), None, None
|
|
|
- else:
|
|
|
- yield state.get_history(), None, None, None
|
|
|
+ yield state.get_history(), None, None, None
|
|
|
|
|
|
- np_audio = np.frombuffer(result_audio, dtype=np.int16)
|
|
|
- yield state.get_history(), (44100, np_audio), None, None
|
|
|
+ yield state.get_history(), None, None, None
|
|
|
|
|
|
|
|
|
async def process_text_input(
|
|
|
@@ -179,7 +186,12 @@ def create_demo():
|
|
|
|
|
|
text_input = gr.Textbox(label="Or type your message", type="text")
|
|
|
|
|
|
- output_audio = gr.Audio(label="Assistant's Voice", type="numpy")
|
|
|
+ output_audio = gr.Audio(
|
|
|
+ label="Assistant's Voice",
|
|
|
+ streaming=True,
|
|
|
+ autoplay=True,
|
|
|
+ interactive=False,
|
|
|
+ )
|
|
|
|
|
|
send_button = gr.Button("Send", variant="primary")
|
|
|
clear_button = gr.Button("Clear")
|