# fish_e2e.py
  1. import base64
  2. import io
  3. import json
  4. import os
  5. import struct
  6. from dataclasses import dataclass
  7. from enum import Enum
  8. from typing import AsyncGenerator
  9. import httpx
  10. import numpy as np
  11. import ormsgpack
  12. import soundfile as sf
  13. from livekit import rtc
  14. from livekit.agents.llm.chat_context import ChatContext
  15. from .schema import (
  16. ServeMessage,
  17. ServeRequest,
  18. ServeTextPart,
  19. ServeVQGANDecodeRequest,
  20. ServeVQGANEncodeRequest,
  21. ServeVQPart,
  22. )
  23. class FishE2EEventType(Enum):
  24. SPEECH_SEGMENT = 1
  25. TEXT_SEGMENT = 2
  26. END_OF_TEXT = 3
  27. END_OF_SPEECH = 4
  28. ASR_RESULT = 5
  29. USER_CODES = 6
@dataclass
class FishE2EEvent:
    """A single event yielded by FishE2EAgent.stream()."""

    # Discriminator: tells the consumer which of the payload fields below
    # are populated for this event.
    type: FishE2EEventType
    # Decoded PCM audio; set for SPEECH_SEGMENT events, None otherwise.
    frame: rtc.AudioFrame | None = None
    # Text payload; set for TEXT_SEGMENT events, None otherwise.
    text: str | None = None
    # VQ token grid; set for SPEECH_SEGMENT and USER_CODES events.
    vq_codes: list[list[int]] | None = None
# Module-level shared HTTP client: no request timeout and fully unbounded
# connection pooling (no cap on connections, keep-alives never expire).
# NOTE(review): FishE2EAgent creates its own client in __init__, so this one
# appears unused within this file — possibly relied on by importers; confirm
# before removing.
client = httpx.AsyncClient(
    timeout=None,
    limits=httpx.Limits(
        max_connections=None,
        max_keepalive_connections=None,
        keepalive_expiry=None,
    ),
)
  44. class FishE2EAgent:
  45. def __init__(self):
  46. self.llm_url = "http://localhost:8080/v1/chat"
  47. self.vqgan_url = "http://localhost:8080"
  48. self.client = httpx.AsyncClient(timeout=None)
  49. async def get_codes(self, audio_data, sample_rate):
  50. audio_buffer = io.BytesIO()
  51. sf.write(audio_buffer, audio_data, sample_rate, format="WAV")
  52. audio_buffer.seek(0)
  53. # Step 1: Encode audio using VQGAN
  54. encode_request = ServeVQGANEncodeRequest(audios=[audio_buffer.read()])
  55. encode_request_bytes = ormsgpack.packb(
  56. encode_request, option=ormsgpack.OPT_SERIALIZE_PYDANTIC
  57. )
  58. encode_response = await self.client.post(
  59. f"{self.vqgan_url}/v1/vqgan/encode",
  60. data=encode_request_bytes,
  61. headers={"Content-Type": "application/msgpack"},
  62. )
  63. encode_response_data = ormsgpack.unpackb(encode_response.content)
  64. codes = encode_response_data["tokens"][0]
  65. return codes
  66. async def stream(
  67. self,
  68. system_audio_data: np.ndarray | None,
  69. user_audio_data: np.ndarray | None,
  70. sample_rate: int,
  71. num_channels: int,
  72. chat_ctx: ChatContext | None = None,
  73. ) -> AsyncGenerator[bytes, None]:
  74. if system_audio_data is not None:
  75. sys_codes = await self.get_codes(system_audio_data, sample_rate)
  76. else:
  77. sys_codes = None
  78. if user_audio_data is not None:
  79. user_codes = await self.get_codes(user_audio_data, sample_rate)
  80. # Step 2: Prepare LLM request
  81. if chat_ctx is None:
  82. sys_parts = [
  83. ServeTextPart(
  84. text='您是由 Fish Audio 设计的语音助手,提供端到端的语音交互,实现无缝用户体验。首先转录用户的语音,然后使用以下格式回答:"Question: [用户语音]\n\nAnswer: [你的回答]\n"。'
  85. ),
  86. ]
  87. if system_audio_data is not None:
  88. sys_parts.append(ServeVQPart(codes=sys_codes))
  89. chat_ctx = {
  90. "messages": [
  91. ServeMessage(
  92. role="system",
  93. parts=sys_parts,
  94. ),
  95. ],
  96. }
  97. else:
  98. if chat_ctx["added_sysaudio"] is False and sys_codes:
  99. chat_ctx["added_sysaudio"] = True
  100. chat_ctx["messages"][0].parts.append(ServeVQPart(codes=sys_codes))
  101. prev_messages = chat_ctx["messages"].copy()
  102. if user_audio_data is not None:
  103. yield FishE2EEvent(
  104. type=FishE2EEventType.USER_CODES,
  105. vq_codes=user_codes,
  106. )
  107. else:
  108. user_codes = None
  109. request = ServeRequest(
  110. messages=prev_messages
  111. + (
  112. [
  113. ServeMessage(
  114. role="user",
  115. parts=[ServeVQPart(codes=user_codes)],
  116. )
  117. ]
  118. if user_codes
  119. else []
  120. ),
  121. streaming=True,
  122. num_samples=1,
  123. )
  124. # Step 3: Stream LLM response and decode audio
  125. buffer = b""
  126. vq_codes = []
  127. current_vq = False
  128. async def decode_send():
  129. nonlocal current_vq
  130. nonlocal vq_codes
  131. data = np.concatenate(vq_codes, axis=1).tolist()
  132. # Decode VQ codes to audio
  133. decode_request = ServeVQGANDecodeRequest(tokens=[data])
  134. decode_response = await self.client.post(
  135. f"{self.vqgan_url}/v1/vqgan/decode",
  136. data=ormsgpack.packb(
  137. decode_request,
  138. option=ormsgpack.OPT_SERIALIZE_PYDANTIC,
  139. ),
  140. headers={"Content-Type": "application/msgpack"},
  141. )
  142. decode_data = ormsgpack.unpackb(decode_response.content)
  143. # Convert float16 audio data to int16
  144. audio_data = np.frombuffer(decode_data["audios"][0], dtype=np.float16)
  145. audio_data = (audio_data * 32768).astype(np.int16).tobytes()
  146. audio_frame = rtc.AudioFrame(
  147. data=audio_data,
  148. samples_per_channel=len(audio_data) // 2,
  149. sample_rate=44100,
  150. num_channels=1,
  151. )
  152. yield FishE2EEvent(
  153. type=FishE2EEventType.SPEECH_SEGMENT,
  154. frame=audio_frame,
  155. vq_codes=data,
  156. )
  157. current_vq = False
  158. vq_codes = []
  159. async with self.client.stream(
  160. "POST",
  161. self.llm_url,
  162. data=ormsgpack.packb(request, option=ormsgpack.OPT_SERIALIZE_PYDANTIC),
  163. headers={"Content-Type": "application/msgpack"},
  164. ) as response:
  165. async for chunk in response.aiter_bytes():
  166. buffer += chunk
  167. while len(buffer) >= 4:
  168. read_length = struct.unpack("I", buffer[:4])[0]
  169. if len(buffer) < 4 + read_length:
  170. break
  171. body = buffer[4 : 4 + read_length]
  172. buffer = buffer[4 + read_length :]
  173. data = ormsgpack.unpackb(body)
  174. if data["delta"] and data["delta"]["part"]:
  175. if current_vq and data["delta"]["part"]["type"] == "text":
  176. async for event in decode_send():
  177. yield event
  178. if data["delta"]["part"]["type"] == "text":
  179. yield FishE2EEvent(
  180. type=FishE2EEventType.TEXT_SEGMENT,
  181. text=data["delta"]["part"]["text"],
  182. )
  183. elif data["delta"]["part"]["type"] == "vq":
  184. vq_codes.append(np.array(data["delta"]["part"]["codes"]))
  185. current_vq = True
  186. if current_vq and vq_codes:
  187. async for event in decode_send():
  188. yield event
  189. yield FishE2EEvent(type=FishE2EEventType.END_OF_TEXT)
  190. yield FishE2EEvent(type=FishE2EEventType.END_OF_SPEECH)
  191. # Example usage:
  192. async def main():
  193. import torchaudio
  194. agent = FishE2EAgent()
  195. # Replace this with actual audio data loading
  196. with open("uz_story_en.m4a", "rb") as f:
  197. audio_data = f.read()
  198. audio_data, sample_rate = torchaudio.load("uz_story_en.m4a")
  199. audio_data = (audio_data.numpy() * 32768).astype(np.int16)
  200. stream = agent.stream(audio_data, sample_rate, 1)
  201. if os.path.exists("audio_segment.wav"):
  202. os.remove("audio_segment.wav")
  203. async for event in stream:
  204. if event.type == FishE2EEventType.SPEECH_SEGMENT:
  205. # Handle speech segment (e.g., play audio or save to file)
  206. with open("audio_segment.wav", "ab+") as f:
  207. f.write(event.frame.data)
  208. elif event.type == FishE2EEventType.ASR_RESULT:
  209. print(event.text, flush=True)
  210. elif event.type == FishE2EEventType.TEXT_SEGMENT:
  211. print(event.text, flush=True, end="")
  212. elif event.type == FishE2EEventType.END_OF_TEXT:
  213. print("\nEnd of text reached.")
  214. elif event.type == FishE2EEventType.END_OF_SPEECH:
  215. print("End of speech reached.")
  216. if __name__ == "__main__":
  217. import asyncio
  218. asyncio.run(main())