| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298 |
- import base64
- import ctypes
- import io
- import json
- import os
- import struct
- from dataclasses import dataclass
- from enum import Enum
- from typing import AsyncGenerator, Union
- import httpx
- import numpy as np
- import ormsgpack
- import soundfile as sf
- from .schema import (
- ServeMessage,
- ServeRequest,
- ServeTextPart,
- ServeVQGANDecodeRequest,
- ServeVQGANEncodeRequest,
- ServeVQPart,
- )
class CustomAudioFrame:
    """Container for one chunk of 16-bit PCM audio.

    Holds the raw int16 sample bytes together with the metadata needed to
    interpret them: sample rate, channel count, and frames per channel.
    """

    def __init__(self, data, sample_rate, num_channels, samples_per_channel):
        # The payload must hold at least one C int16 per sample across
        # all channels; anything shorter cannot be a valid frame.
        min_bytes = num_channels * samples_per_channel * ctypes.sizeof(ctypes.c_int16)
        if len(data) < min_bytes:
            raise ValueError(
                "data length must be >= num_channels * samples_per_channel * sizeof(int16)"
            )

        self._data = bytearray(data)
        self._sample_rate = sample_rate
        self._num_channels = num_channels
        self._samples_per_channel = samples_per_channel

    @property
    def data(self):
        # View the backing buffer as int16 samples ("h") without copying.
        return memoryview(self._data).cast("h")

    @property
    def sample_rate(self):
        return self._sample_rate

    @property
    def num_channels(self):
        return self._num_channels

    @property
    def samples_per_channel(self):
        return self._samples_per_channel

    @property
    def duration(self):
        # Length of this frame in seconds.
        return self.samples_per_channel / self.sample_rate

    def __repr__(self):
        return (
            f"CustomAudioFrame(sample_rate={self.sample_rate}, "
            f"num_channels={self.num_channels}, "
            f"samples_per_channel={self.samples_per_channel}, "
            f"duration={self.duration:.3f})"
        )
class FishE2EEventType(Enum):
    """Kinds of events yielded by FishE2EAgent.stream()."""

    SPEECH_SEGMENT = 1  # decoded audio chunk (carries `frame` and `vq_codes`)
    TEXT_SEGMENT = 2  # incremental text from the LLM (carries `text`)
    END_OF_TEXT = 3  # no more text will follow
    END_OF_SPEECH = 4  # no more audio will follow
    ASR_RESULT = 5  # transcription result (carries `text`)
    USER_CODES = 6  # VQ codes encoded from the user's audio (carries `vq_codes`)
@dataclass
class FishE2EEvent:
    """One event in the agent's output stream.

    Which optional fields are populated depends on `type` (see
    FishE2EEventType member comments).
    """

    type: FishE2EEventType
    # Audio payload for SPEECH_SEGMENT events.  NOTE(review): within this
    # file stream() actually passes a CustomAudioFrame here, not a raw
    # ndarray — confirm which type downstream consumers expect.
    frame: np.ndarray | None = None
    # Text payload for TEXT_SEGMENT / ASR_RESULT events.
    text: str | None = None
    # VQ token codes for USER_CODES / SPEECH_SEGMENT events.
    vq_codes: list[list[int]] | None = None
# Module-level HTTP client with timeouts and all connection/keep-alive
# limits disabled, so long-lived streaming requests are never cut off.
# NOTE(review): nothing in this file uses it (FishE2EAgent builds its own
# client) — presumably importers rely on it; verify before removing.
client = httpx.AsyncClient(
    timeout=None,
    limits=httpx.Limits(
        max_connections=None,
        max_keepalive_connections=None,
        keepalive_expiry=None,
    ),
)
class FishE2EAgent:
    """Async client for a local fish-speech server.

    Encodes audio into VQGAN token codes, streams a chat completion from
    the LLM endpoint, decodes returned VQ codes back into audio, and
    surfaces everything as FishE2EEvent objects.
    """

    def __init__(self):
        # Chat and VQGAN endpoints are assumed to live on the same local server.
        self.llm_url = "http://localhost:8080/v1/chat"
        self.vqgan_url = "http://localhost:8080"
        # Timeout disabled: responses are streamed and may be long-lived.
        self.client = httpx.AsyncClient(timeout=None)

    async def get_codes(self, audio_data, sample_rate: int):
        """Encode raw audio samples into VQGAN token codes.

        The samples are written to an in-memory WAV, msgpack-serialized,
        and POSTed to /v1/vqgan/encode; the first (only) token sequence
        from the response is returned.
        """
        audio_buffer = io.BytesIO()
        sf.write(audio_buffer, audio_data, sample_rate, format="WAV")
        audio_buffer.seek(0)
        # Step 1: Encode audio using VQGAN
        encode_request = ServeVQGANEncodeRequest(audios=[audio_buffer.read()])
        encode_request_bytes = ormsgpack.packb(
            encode_request, option=ormsgpack.OPT_SERIALIZE_PYDANTIC
        )
        encode_response = await self.client.post(
            f"{self.vqgan_url}/v1/vqgan/encode",
            data=encode_request_bytes,
            headers={"Content-Type": "application/msgpack"},
        )
        encode_response_data = ormsgpack.unpackb(encode_response.content)
        codes = encode_response_data["tokens"][0]
        return codes

    async def stream(
        self,
        system_audio_data: np.ndarray | None,
        user_audio_data: np.ndarray | None,
        sample_rate: int,
        num_channels: int,
        chat_ctx: dict | None = None,
    ) -> AsyncGenerator[FishE2EEvent, None]:
        """Run one conversational turn against the server.

        Encodes the optional system/user audio, sends the chat request,
        and yields FishE2EEvent objects: USER_CODES first (if user audio
        was given), then interleaved TEXT_SEGMENT / SPEECH_SEGMENT
        events, and finally END_OF_TEXT and END_OF_SPEECH.

        If `chat_ctx` is None a fresh context with a default system
        prompt is built; otherwise the caller's context is extended
        in place.  `num_channels` is accepted but not used below.
        """
        if system_audio_data is not None:
            sys_codes = await self.get_codes(system_audio_data, sample_rate)
        else:
            sys_codes = None
        if user_audio_data is not None:
            user_codes = await self.get_codes(user_audio_data, sample_rate)

        # Step 2: Prepare LLM request
        if chat_ctx is None:
            sys_parts = [
                ServeTextPart(
                    text='您是由 Fish Audio 设计的语音助手,提供端到端的语音交互,实现无缝用户体验。首先转录用户的语音,然后使用以下格式回答:"Question: [用户语音]\n\nAnswer: [你的回答]\n"。'
                ),
            ]
            if system_audio_data is not None:
                sys_parts.append(ServeVQPart(codes=sys_codes))
            chat_ctx = {
                "messages": [
                    ServeMessage(
                        role="system",
                        parts=sys_parts,
                    ),
                ],
            }
        else:
            # NOTE(review): assumes a caller-supplied chat_ctx always has
            # an "added_sysaudio" key — a missing key raises KeyError;
            # confirm against callers.
            if chat_ctx["added_sysaudio"] is False and sys_codes:
                chat_ctx["added_sysaudio"] = True
                chat_ctx["messages"][0].parts.append(ServeVQPart(codes=sys_codes))

        prev_messages = chat_ctx["messages"].copy()
        if user_audio_data is not None:
            # Let the caller observe the user's encoded audio tokens.
            yield FishE2EEvent(
                type=FishE2EEventType.USER_CODES,
                vq_codes=user_codes,
            )
        else:
            user_codes = None

        request = ServeRequest(
            messages=prev_messages
            + (
                [
                    ServeMessage(
                        role="user",
                        parts=[ServeVQPart(codes=user_codes)],
                    )
                ]
                if user_codes
                else []
            ),
            streaming=True,
            num_samples=1,
        )

        # Step 3: Stream LLM response and decode audio
        buffer = b""
        vq_codes = []
        current_vq = False

        async def decode_send():
            # Flush accumulated VQ codes: decode them into one audio
            # segment, yield it, then reset the accumulator.
            nonlocal current_vq
            nonlocal vq_codes

            data = np.concatenate(vq_codes, axis=1).tolist()
            # Decode VQ codes to audio
            decode_request = ServeVQGANDecodeRequest(tokens=[data])
            decode_response = await self.client.post(
                f"{self.vqgan_url}/v1/vqgan/decode",
                data=ormsgpack.packb(
                    decode_request,
                    option=ormsgpack.OPT_SERIALIZE_PYDANTIC,
                ),
                headers={"Content-Type": "application/msgpack"},
            )
            decode_data = ormsgpack.unpackb(decode_response.content)

            # Convert float16 audio data to int16
            audio_data = np.frombuffer(decode_data["audios"][0], dtype=np.float16)
            audio_data = (audio_data * 32768).astype(np.int16).tobytes()

            # NOTE(review): server output is assumed to be 44.1 kHz mono
            # here — confirm against the VQGAN decoder's configuration.
            audio_frame = CustomAudioFrame(
                data=audio_data,
                samples_per_channel=len(audio_data) // 2,
                sample_rate=44100,
                num_channels=1,
            )
            yield FishE2EEvent(
                type=FishE2EEventType.SPEECH_SEGMENT,
                frame=audio_frame,
                vq_codes=data,
            )

            current_vq = False
            vq_codes = []

        async with self.client.stream(
            "POST",
            self.llm_url,
            data=ormsgpack.packb(request, option=ormsgpack.OPT_SERIALIZE_PYDANTIC),
            headers={"Content-Type": "application/msgpack"},
        ) as response:
            async for chunk in response.aiter_bytes():
                buffer += chunk

                # Each server message is framed as a 4-byte unsigned-int
                # length prefix (native byte order) followed by a msgpack
                # payload; loop until the buffer holds no complete frame.
                while len(buffer) >= 4:
                    read_length = struct.unpack("I", buffer[:4])[0]
                    if len(buffer) < 4 + read_length:
                        break

                    body = buffer[4 : 4 + read_length]
                    buffer = buffer[4 + read_length :]
                    data = ormsgpack.unpackb(body)

                    if data["delta"] and data["delta"]["part"]:
                        # A text part arriving while VQ codes are pending
                        # marks a segment boundary: flush the audio first.
                        if current_vq and data["delta"]["part"]["type"] == "text":
                            async for event in decode_send():
                                yield event
                        if data["delta"]["part"]["type"] == "text":
                            yield FishE2EEvent(
                                type=FishE2EEventType.TEXT_SEGMENT,
                                text=data["delta"]["part"]["text"],
                            )
                        elif data["delta"]["part"]["type"] == "vq":
                            vq_codes.append(np.array(data["delta"]["part"]["codes"]))
                            current_vq = True

        # Flush any trailing audio, then signal the end of both streams.
        if current_vq and vq_codes:
            async for event in decode_send():
                yield event

        yield FishE2EEvent(type=FishE2EEventType.END_OF_TEXT)
        yield FishE2EEvent(type=FishE2EEventType.END_OF_SPEECH)
# Example usage:
async def main():
    """Demo driver: send a local audio file as the user turn, stream the
    agent's reply, and append decoded speech to audio_segment.wav.

    Note the output file receives raw int16 samples only (no WAV header,
    despite the name).
    """
    import torchaudio

    agent = FishE2EAgent()

    # Load the prompt audio and convert float samples to int16 PCM.
    # (A dead pre-read of the same file as raw bytes was removed: its
    # result was immediately overwritten by torchaudio.load.)
    audio_data, sample_rate = torchaudio.load("uz_story_en.m4a")
    audio_data = (audio_data.numpy() * 32768).astype(np.int16)

    # BUG FIX: stream() requires (system_audio, user_audio, sample_rate,
    # num_channels); the original call passed only three positional
    # arguments and raised TypeError.  Treat the file as the *user* turn
    # and send no system audio.
    # NOTE(review): torchaudio returns (channels, samples) while
    # soundfile (used inside get_codes) expects (frames, channels) — a
    # transpose may be needed; confirm against the server.
    stream = agent.stream(None, audio_data, sample_rate, 1)

    # Start from a clean output file on every run.
    if os.path.exists("audio_segment.wav"):
        os.remove("audio_segment.wav")

    async for event in stream:
        if event.type == FishE2EEventType.SPEECH_SEGMENT:
            # Handle speech segment (e.g., play audio or save to file)
            with open("audio_segment.wav", "ab+") as f:
                f.write(event.frame.data)
        elif event.type == FishE2EEventType.ASR_RESULT:
            print(event.text, flush=True)
        elif event.type == FishE2EEventType.TEXT_SEGMENT:
            print(event.text, flush=True, end="")
        elif event.type == FishE2EEventType.END_OF_TEXT:
            print("\nEnd of text reached.")
        elif event.type == FishE2EEventType.END_OF_SPEECH:
            print("End of speech reached.")
# Run the example driver only when executed as a script.
if __name__ == "__main__":
    import asyncio

    asyncio.run(main())
|