|
|
@@ -1,18 +1,17 @@
|
|
|
import base64
|
|
|
+import ctypes
|
|
|
import io
|
|
|
import json
|
|
|
import os
|
|
|
import struct
|
|
|
from dataclasses import dataclass
|
|
|
from enum import Enum
|
|
|
-from typing import AsyncGenerator
|
|
|
+from typing import AsyncGenerator, Union
|
|
|
|
|
|
import httpx
|
|
|
import numpy as np
|
|
|
import ormsgpack
|
|
|
import soundfile as sf
|
|
|
-from livekit import rtc
|
|
|
-from livekit.agents.llm.chat_context import ChatContext
|
|
|
|
|
|
from .schema import (
|
|
|
ServeMessage,
|
|
|
@@ -24,6 +23,49 @@ from .schema import (
|
|
|
)
|
|
|
|
|
|
|
|
|
+class CustomAudioFrame:
|
|
|
+ def __init__(self, data, sample_rate, num_channels, samples_per_channel):
|
|
|
+ if len(data) < num_channels * samples_per_channel * ctypes.sizeof(
|
|
|
+ ctypes.c_int16
|
|
|
+ ):
|
|
|
+ raise ValueError(
|
|
|
+ "data length must be >= num_channels * samples_per_channel * sizeof(int16)"
|
|
|
+ )
|
|
|
+
|
|
|
+ self._data = bytearray(data)
|
|
|
+ self._sample_rate = sample_rate
|
|
|
+ self._num_channels = num_channels
|
|
|
+ self._samples_per_channel = samples_per_channel
|
|
|
+
|
|
|
+ @property
|
|
|
+ def data(self):
|
|
|
+ return memoryview(self._data).cast("h")
|
|
|
+
|
|
|
+ @property
|
|
|
+ def sample_rate(self):
|
|
|
+ return self._sample_rate
|
|
|
+
|
|
|
+ @property
|
|
|
+ def num_channels(self):
|
|
|
+ return self._num_channels
|
|
|
+
|
|
|
+ @property
|
|
|
+ def samples_per_channel(self):
|
|
|
+ return self._samples_per_channel
|
|
|
+
|
|
|
+ @property
|
|
|
+ def duration(self):
|
|
|
+ return self.samples_per_channel / self.sample_rate
|
|
|
+
|
|
|
+ def __repr__(self):
|
|
|
+ return (
|
|
|
+ f"CustomAudioFrame(sample_rate={self.sample_rate}, "
|
|
|
+ f"num_channels={self.num_channels}, "
|
|
|
+ f"samples_per_channel={self.samples_per_channel}, "
|
|
|
+ f"duration={self.duration:.3f})"
|
|
|
+ )
|
|
|
+
|
|
|
+
|
|
|
class FishE2EEventType(Enum):
|
|
|
SPEECH_SEGMENT = 1
|
|
|
TEXT_SEGMENT = 2
|
|
|
@@ -36,7 +78,7 @@ class FishE2EEventType(Enum):
|
|
|
@dataclass
|
|
|
class FishE2EEvent:
|
|
|
type: FishE2EEventType
|
|
|
- frame: rtc.AudioFrame = None
|
|
|
+ frame: np.ndarray = None
|
|
|
text: str = None
|
|
|
vq_codes: list[list[int]] = None
|
|
|
|
|
|
@@ -81,7 +123,7 @@ class FishE2EAgent:
|
|
|
user_audio_data: np.ndarray | None,
|
|
|
sample_rate: int,
|
|
|
num_channels: int,
|
|
|
- chat_ctx: ChatContext | None = None,
|
|
|
+ chat_ctx: dict | None = None,
|
|
|
) -> AsyncGenerator[bytes, None]:
|
|
|
|
|
|
if system_audio_data is not None:
|
|
|
@@ -163,7 +205,7 @@ class FishE2EAgent:
|
|
|
audio_data = np.frombuffer(decode_data["audios"][0], dtype=np.float16)
|
|
|
audio_data = (audio_data * 32768).astype(np.int16).tobytes()
|
|
|
|
|
|
- audio_frame = rtc.AudioFrame(
|
|
|
+ audio_frame = CustomAudioFrame(
|
|
|
data=audio_data,
|
|
|
samples_per_channel=len(audio_data) // 2,
|
|
|
sample_rate=44100,
|