
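"""Streaming voice agent service.

Pipeline: a Silero VAD segments incoming audio into speech chunks, a Whisper
ASR model transcribes each chunk, and the Fish-Speech dual-AR transformer
generates a response whose semantic tokens are decoded back to audio by the
firefly GAN-VQ decoder.
"""
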
import time

import librosa
import numpy as np
import torch
import torchaudio
from loguru import logger
from torchaudio import functional as AF
from transformers import (
    AutoModelForSpeechSeq2Seq,
    AutoProcessor,
    AutoTokenizer,
    pipeline,
)

from fish_speech.conversation import (
    CODEBOOK_EOS_TOKEN_ID,
    Conversation,
    Message,
    TokensPart,
    encode_conversation,
)
from fish_speech.models.text2semantic.llama import DualARTransformer
from tools.api import decode_vq_tokens, encode_reference
from tools.llama.generate_test import convert_string
from tools.llama.generate_test import generate as llama_generate
from tools.llama.generate_test import load_model as load_llama_model
from tools.vqgan.inference import load_model as load_decoder_model


class FishStreamVAD:
    def __init__(self) -> None:
        # Args
        self.sample_rate = 16000
        self.threshold = 0.5
        self.neg_threshold = self.threshold - 0.15
        self.min_speech_duration_ms = 100
        self.min_silence_ms = 500
        self.speech_pad_ms = 30
        self.chunk_size = 512

        # Convert to samples
        self.min_speech_duration_samples = (
            self.min_speech_duration_ms * self.sample_rate // 1000
        )
        self.min_silence_samples = self.min_silence_ms * self.sample_rate // 1000
        self.speech_pad_samples = self.speech_pad_ms * self.sample_rate // 1000

        # Core buffers
        self.reset()

        # Load models
        logger.info("Loading VAD model")
        vad_model, vad_utils = torch.hub.load(
            repo_or_dir="snakers4/silero-vad",
            model="silero_vad",
            force_reload=True,
            onnx=True,
        )
        self.vad_model = vad_model
        self.get_speech_timestamps = vad_utils[0]
        logger.info("VAD model loaded")

    def reset(self):
        self.audio_chunks = None
        self.vad_pointer = 0
        self.speech_probs = []
        self.triggered = False
        self.start = self.end = self.temp_end = 0
        self.last_seen_end = 0
        self.speech_segments = []

    def add_chunk(self, chunk, sr=None):
        """
        Add a chunk to the buffer and yield any detected speech segments
        """
        if isinstance(chunk, np.ndarray):
            chunk = torch.from_numpy(chunk)

        if sr is not None and sr != self.sample_rate:
            chunk = AF.resample(chunk, sr, self.sample_rate)

        if self.audio_chunks is None:
            self.audio_chunks = chunk
        else:
            self.audio_chunks = torch.cat([self.audio_chunks, chunk])

        # Trigger VAD
        yield from self.detect_speech()

    def detect_speech(self):
        """
        Run the VAD model on the current buffer and yield finished speech segments
        """
        speech_prob_start_idx = len(self.speech_probs)

        # Score every full chunk that has not yet been passed through the VAD
        while len(self.audio_chunks) - self.vad_pointer >= self.chunk_size:
            chunk = self.audio_chunks[
                self.vad_pointer : self.vad_pointer + self.chunk_size
            ]
            speech_prob = self.vad_model(chunk, self.sample_rate)
            self.speech_probs.append(speech_prob)
            self.vad_pointer += self.chunk_size

        # Process speech probs with hysteresis: trigger at `threshold`,
        # release only after `min_silence_samples` below `neg_threshold`
        for i in range(speech_prob_start_idx, len(self.speech_probs)):
            speech_prob = self.speech_probs[i]

            if speech_prob >= self.threshold and self.temp_end:
                self.temp_end = 0

            if speech_prob >= self.threshold and self.triggered is False:
                self.triggered = True
                self.start = i * self.chunk_size
                continue

            if speech_prob < self.neg_threshold and self.triggered is True:
                if self.temp_end == 0:
                    self.temp_end = i * self.chunk_size

                if i * self.chunk_size - self.temp_end < self.min_silence_samples:
                    continue

                self.end = self.temp_end
                if self.end - self.start > self.min_speech_duration_samples:
                    yield self.audio_chunks[
                        self.start : self.end + self.speech_pad_samples
                    ]

                self.triggered = False
                self.start = self.end = self.temp_end = 0


class FishASR:
    def __init__(self) -> None:
        self.audio_chunks = None
        self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
        torch_dtype = torch.bfloat16
        model_id = "openai/whisper-medium.en"

        logger.info("Loading ASR model")
        model = AutoModelForSpeechSeq2Seq.from_pretrained(
            model_id, torch_dtype=torch_dtype, use_safetensors=True
        ).to(self.device)
        processor = AutoProcessor.from_pretrained(model_id)

        self.pipe = pipeline(
            "automatic-speech-recognition",
            model=model,
            tokenizer=processor.tokenizer,
            feature_extractor=processor.feature_extractor,
            max_new_tokens=256,
            torch_dtype=torch_dtype,
            device=self.device,
        )
        logger.info("ASR model loaded")

    @torch.inference_mode()
    def run(self, chunk):
        # The HF pipeline returns a dict; return only the transcript string
        return self.pipe(chunk.numpy())["text"]


class FishE2EAgent:
    def __init__(self) -> None:
        self.device = device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"Using device: {device}")

        decoder_model = load_decoder_model(
            config_name="firefly_gan_vq",
            checkpoint_path="checkpoints/fish-speech-1.2/firefly-gan-vq-fsq-4x1024-42hz-generator.pth",
            device=device,
        )
        self.decoder_model = decoder_model
        logger.info("Decoder model loaded")

        llama_model, decode_one_token = load_llama_model(
            config_name="dual_ar_2_codebook_1.3b",
            checkpoint_path="checkpoints/step_000206000.ckpt",
            device=device,
            precision=torch.bfloat16,
            max_length=2048,
            compile=True,
        )
        self.llama_model: DualARTransformer = llama_model
        self.decode_one_token = decode_one_token
        logger.info("LLAMA model loaded")

        self.tokenizer = AutoTokenizer.from_pretrained(
            "checkpoints/fish-speech-agent-1"
        )
        self.semantic_id = self.tokenizer.convert_tokens_to_ids("<|semantic|>")
        self.im_end_id = self.tokenizer.convert_tokens_to_ids("<|im_end|>")
        self.decoder_tokenizer = AutoTokenizer.from_pretrained(
            "fishaudio/fish-speech-1"
        )

        # Control params
        self.temperature = torch.tensor(0.7, device=device, dtype=torch.float)
        self.top_p = torch.tensor(0.7, device=device, dtype=torch.float)
        self.repetition_penalty = torch.tensor(1.2, device=device, dtype=torch.float)

        # This is used to control the timbre of the generated audio
        self.base_messages = [
            # Message(
            #     role="user",
            #     parts=[np.load("example/q0.npy")],
            # ),
            # Message(
            #     role="assistant",
            #     parts=[
            #         "Transcribed: Hi, can you briefly describe what is machine learning?\nResponse: Sure! Machine learning is the process of automating tasks that humans are capable of doing with a computer. It involves training computers to make decisions based on data.",
            #         np.load("example/a0.npy"),
            #     ],
            # ),
        ]
        self.reference = encode_reference(
            decoder_model=self.decoder_model,
            reference_audio="example/a0.wav",
            enable_reference_audio=True,
        )
        self.messages = self.base_messages.copy()

    def reset(self):
        self.messages = self.base_messages.copy()

    @torch.inference_mode()
    def vq_encode(self, audios, sr=None):
        if isinstance(audios, np.ndarray):
            audios = torch.from_numpy(audios)

        if audios.ndim == 1:
            audios = audios[None, None, :]

        audios = audios.to(self.decoder_model.device)
        if sr is not None and sr != self.decoder_model.sampling_rate:
            audios = AF.resample(audios, sr, self.decoder_model.sampling_rate)

        audio_lengths = torch.tensor(
            [audios.shape[2]], device=self.decoder_model.device, dtype=torch.long
        )
        return self.decoder_model.encode(audios, audio_lengths)[0][0]

    @torch.inference_mode()
    def generate(self, audio_chunk, sr=None, text=None):
        vq_output = self.vq_encode(audio_chunk, sr)
        logger.info(f"VQ output: {vq_output.shape}")

        # Encode conversation
        self.messages.append(
            Message(
                role="user",
                parts=[vq_output],
            )
        )

        parts = []
        if text is not None:
            parts.append(f"Transcribed: {text}\nResponse:")

        self.messages.append(
            Message(
                role="assistant",
                parts=parts,
            )
        )
        conversation = Conversation(self.messages)

        # Encode the conversation
        prompt, _ = encode_conversation(
            conversation, self.tokenizer, self.llama_model.config.num_codebooks
        )
        prompt = prompt[:, :-1].to(dtype=torch.int, device=self.device)
        prompt_length = prompt.shape[1]

        # Generate
        y = llama_generate(
            model=self.llama_model,
            prompt=prompt,
            max_new_tokens=0,
            eos_token_id=self.tokenizer.eos_token_id,
            im_end_id=self.im_end_id,
            decode_one_token=self.decode_one_token,
            temperature=self.temperature,
            top_p=self.top_p,
            repetition_penalty=self.repetition_penalty,
        )

        tokens = self.tokenizer.decode(
            y[0, prompt_length:].tolist(), skip_special_tokens=False
        )
        logger.info(f"Generated: {convert_string(tokens)}")

        # Store the generated tokens;
        # since there are <im_end> and <eos> tokens, remove the last 2 tokens
        code_mask = y[0, prompt_length:-2] == self.semantic_id
        codes = y[1:, prompt_length:-2][:, code_mask].clone()
        codes = codes - 2
        assert (codes >= 0).all(), "Negative code found"

        decoded = y[:, prompt_length:-1].clone()
        if decoded[0, -1] != self.im_end_id:  # <im_end>
            val = [[self.im_end_id]] + [[CODEBOOK_EOS_TOKEN_ID]] * (decoded.size(0) - 1)
            decoded = torch.cat(
                (decoded, torch.tensor(val, device=self.device, dtype=torch.int)), dim=1
            )
        decoded = decoded.cpu()

        self.messages[-1].parts.append(
            TokensPart(
                tokens=decoded[:1],
                codes=decoded[1:],
            )
        )

        # Less than 5 * 20 = 100ms of generated audio; skip decoding
        if codes.shape[1] <= 5:
            return

        # Generate audio
        main_tokens = decoded[0]
        text_tokens = main_tokens[main_tokens != self.semantic_id]
        text = self.tokenizer.decode(text_tokens.tolist(), skip_special_tokens=True)
        text_tokens = self.decoder_tokenizer.encode(text, return_tensors="pt").to(
            self.device
        )

        audio = decode_vq_tokens(
            decoder_model=self.decoder_model,
            codes=codes,
            text_tokens=text_tokens,
            reference_embedding=self.reference,
        )

        if sr is not None and sr != self.decoder_model.sampling_rate:
            audio = AF.resample(audio, self.decoder_model.sampling_rate, sr)

        return audio.float()


class FishAgentPipeline:
    def __init__(self) -> None:
        self.vad = FishStreamVAD()
        # Currently use ASR model as intermediate
        self.asr = FishASR()
        self.agent = FishE2EAgent()
        self.vad_segments = []
        self.text_segments = []

    def add_chunk(self, chunk, sr=None):
        use_np = isinstance(chunk, np.ndarray)
        if use_np:
            chunk = torch.from_numpy(chunk)

        if sr is not None and sr != 16000:
            chunk = AF.resample(chunk, sr, 16000)

        for vad_audio in self.vad.add_chunk(chunk, 16000):
            self.vad_segments.append(vad_audio)
            asr_text = self.asr.run(vad_audio)
            self.text_segments.append(asr_text)
            logger.info(f"ASR: {asr_text}")

            # Actually should detect if intent is finished here
            result = self.agent.generate(vad_audio, 16000, text=asr_text)
            if result is None:
                continue

            if sr is not None and sr != 16000:
                result = AF.resample(result, 16000, sr)

            if use_np:
                result = result.cpu().numpy()

            yield result

    def reset(self):
        self.vad.reset()
        self.agent.reset()
        self.vad_segments = []
        self.text_segments = []

    def warmup(self):
        logger.info("Warming up the pipeline")
        audio, sr = librosa.load("example/q0.mp3", sr=16000)
        for i in range(0, len(audio), 882):
            # Drain the generator; use a separate name so the input buffer
            # is not overwritten by generated audio
            for _out in self.add_chunk(audio[i : i + 882], sr):
                pass
        logger.info("Pipeline warmed up")
        self.reset()


if __name__ == "__main__":
    import soundfile as sf

    service = FishAgentPipeline()
    service.warmup()
    logger.info("Stream service started")

    audio, sr = librosa.load("example/q1.mp3", sr=16000)
    seg = []
    for i in range(0, len(audio), 882):
        # Use a separate name for generated audio so the input buffer is not overwritten
        for out in service.add_chunk(audio[i : i + 882], sr):
            seg.append(out)

    audio = np.concatenate(seg)
    sf.write("output.wav", audio, 16000)