# fun_asr.py
  1. import gc
  2. import os
  3. import re
  4. from audio_separator.separator import Separator
  5. os.environ["MODELSCOPE_CACHE"] = "./.cache/funasr"
  6. os.environ["UVR5_CACHE"] = "./.cache/uvr5-models"
  7. import json
  8. import subprocess
  9. from pathlib import Path
  10. import click
  11. import torch
  12. from loguru import logger
  13. from pydub import AudioSegment
  14. from silero_vad import get_speech_timestamps, load_silero_vad, read_audio
  15. from tqdm import tqdm
  16. from tools.file import AUDIO_EXTENSIONS, VIDEO_EXTENSIONS, list_files
  17. from tools.sensevoice.auto_model import AutoModel
  18. def uvr5_cli(
  19. audio_dir: Path,
  20. output_folder: Path,
  21. audio_files: list[Path] | None = None,
  22. output_format: str = "flac",
  23. model: str = "BS-Roformer-Viperx-1297.ckpt",
  24. ):
  25. # ["BS-Roformer-Viperx-1297.ckpt", "BS-Roformer-Viperx-1296.ckpt", "BS-Roformer-Viperx-1053.ckpt", "Mel-Roformer-Viperx-1143.ckpt"]
  26. sepr = Separator(
  27. model_file_dir=os.environ["UVR5_CACHE"],
  28. output_dir=output_folder,
  29. output_format=output_format,
  30. )
  31. dictmodel = {
  32. "BS-Roformer-Viperx-1297.ckpt": "model_bs_roformer_ep_317_sdr_12.9755.ckpt",
  33. "BS-Roformer-Viperx-1296.ckpt": "model_bs_roformer_ep_368_sdr_12.9628.ckpt",
  34. "BS-Roformer-Viperx-1053.ckpt": "model_bs_roformer_ep_937_sdr_10.5309.ckpt",
  35. "Mel-Roformer-Viperx-1143.ckpt": "model_mel_band_roformer_ep_3005_sdr_11.4360.ckpt",
  36. }
  37. roformer_model = dictmodel[model]
  38. sepr.load_model(roformer_model)
  39. if audio_files is None:
  40. audio_files = list_files(
  41. path=audio_dir, extensions=AUDIO_EXTENSIONS, recursive=True
  42. )
  43. total_files = len(audio_files)
  44. print(f"{total_files} audio files found")
  45. res = []
  46. for audio in tqdm(audio_files, desc="Denoising: "):
  47. file_path = str(audio_dir / audio)
  48. sep_out = sepr.separate(file_path)
  49. if isinstance(sep_out, str):
  50. res.append(sep_out)
  51. elif isinstance(sep_out, list):
  52. res.extend(sep_out)
  53. del sepr
  54. gc.collect()
  55. if torch.cuda.is_available():
  56. torch.cuda.empty_cache()
  57. return res, roformer_model
  58. def get_sample_rate(media_path: Path):
  59. result = subprocess.run(
  60. [
  61. "ffprobe",
  62. "-v",
  63. "quiet",
  64. "-print_format",
  65. "json",
  66. "-show_streams",
  67. str(media_path),
  68. ],
  69. capture_output=True,
  70. text=True,
  71. check=True,
  72. )
  73. media_info = json.loads(result.stdout)
  74. for stream in media_info.get("streams", []):
  75. if stream.get("codec_type") == "audio":
  76. return stream.get("sample_rate")
  77. return "44100" # Default sample rate if not found
  78. def convert_to_mono(src_path: Path, out_path: Path, out_fmt: str = "wav"):
  79. sr = get_sample_rate(src_path)
  80. out_path.parent.mkdir(parents=True, exist_ok=True)
  81. if src_path.resolve() == out_path.resolve():
  82. output = str(out_path.with_stem(out_path.stem + f"_{sr}"))
  83. else:
  84. output = str(out_path)
  85. subprocess.run(
  86. [
  87. "ffmpeg",
  88. "-loglevel",
  89. "error",
  90. "-i",
  91. str(src_path),
  92. "-acodec",
  93. "pcm_s16le" if out_fmt == "wav" else "flac",
  94. "-ar",
  95. sr,
  96. "-ac",
  97. "1",
  98. "-y",
  99. output,
  100. ],
  101. check=True,
  102. )
  103. return out_path
  104. def convert_video_to_audio(video_path: Path, audio_dir: Path):
  105. cur_dir = audio_dir / video_path.relative_to(audio_dir).parent
  106. vocals = [
  107. p
  108. for p in cur_dir.glob(f"{video_path.stem}_(Vocals)*.*")
  109. if p.suffix in AUDIO_EXTENSIONS
  110. ]
  111. if len(vocals) > 0:
  112. return vocals[0]
  113. audio_path = cur_dir / f"{video_path.stem}.wav"
  114. convert_to_mono(video_path, audio_path)
  115. return audio_path
@click.command()
@click.option("--audio-dir", required=True, help="Directory containing audio files")
@click.option(
    "--save-dir", required=True, help="Directory to save processed audio files"
)
@click.option("--device", default="cuda", help="Device to use [cuda / cpu]")
@click.option("--language", default="auto", help="Language of the transcription")
@click.option(
    "--max_single_segment_time",
    default=20000,
    type=int,
    help="Maximum of Output single audio duration(ms)",
)
@click.option("--fsmn-vad/--silero-vad", default=False)
@click.option("--punc/--no-punc", default=False)
@click.option("--denoise/--no-denoise", default=False)
@click.option("--save_emo/--no_save_emo", default=False)
def main(
    audio_dir: str,
    save_dir: str,
    device: str,
    language: str,
    max_single_segment_time: int,
    fsmn_vad: bool,
    punc: bool,
    denoise: bool,
    save_emo: bool,
):
    """Transcribe and slice audio under --audio-dir into --save-dir.

    Pipeline: extract audio from any videos; optionally denoise (vocal
    separation via UVR5, which DELETES the original files in-place under
    --audio-dir); run VAD (fsmn-vad or silero-vad) plus SenseVoice ASR;
    export each detected speech segment as a numbered audio slice with a
    matching .lab transcript (and optional .emo emotion label).
    """
    audios_path = Path(audio_dir)
    save_path = Path(save_dir)
    save_path.mkdir(parents=True, exist_ok=True)

    # Turn every video found under audio_dir into a mono WAV sitting next
    # to it, so the rest of the pipeline only deals with audio files.
    video_files = list_files(
        path=audio_dir, extensions=VIDEO_EXTENSIONS, recursive=True
    )
    v2a_files = [convert_video_to_audio(p, audio_dir) for p in video_files]

    if denoise:
        # Marker that audio-separator appends to separated vocal stems.
        VOCAL = "_(Vocals)"
        # Source files that have not been through separation yet.
        original_files = [
            p
            for p in audios_path.glob("**/*")
            if p.suffix in AUDIO_EXTENSIONS and VOCAL not in p.stem
        ]
        _, cur_model = uvr5_cli(
            audio_dir=audio_dir, output_folder=audio_dir, audio_files=original_files
        )
        # DESTRUCTIVE: discard the instrumental stems and the originals;
        # only the vocal stems are kept for transcription.
        need_remove = [p for p in audios_path.glob("**/*(Instrumental)*")]
        need_remove.extend(original_files)
        for _ in need_remove:
            _.unlink()
        vocal_files = [
            p
            for p in audios_path.glob("**/*")
            if p.suffix in AUDIO_EXTENSIONS and VOCAL in p.stem
        ]
        for f in vocal_files:
            fn, ext = f.stem, f.suffix
            # Separator names outputs "<stem>_(Vocals)_<model>"; trim the
            # model-name suffix, then re-encode as mono flac and drop the
            # intermediate stereo stem.
            v_pos = fn.find(VOCAL + "_" + cur_model.split(".")[0])
            if v_pos != -1:
                new_fn = fn[: v_pos + len(VOCAL)]
                new_f = f.with_name(new_fn + ext)
                f = f.rename(new_f)
                convert_to_mono(f, f, "flac")
                f.unlink()

    audio_files = list_files(
        path=audio_dir, extensions=AUDIO_EXTENSIONS, recursive=True
    )

    logger.info("Loading / Downloading Funasr model...")

    model_dir = "iic/SenseVoiceSmall"

    # vad_model starts as a funasr model NAME (or None); if silero is
    # selected it is later rebound to the loaded silero model object.
    vad_model = "fsmn-vad" if fsmn_vad else None
    vad_kwargs = {"max_single_segment_time": max_single_segment_time}
    punc_model = "ct-punc" if punc else None

    manager = AutoModel(
        model=model_dir,
        trust_remote_code=False,
        vad_model=vad_model,
        vad_kwargs=vad_kwargs,
        punc_model=punc_model,
        device=device,
    )

    if not fsmn_vad and vad_model is None:
        vad_model = load_silero_vad()

    logger.info("Model loaded.")

    # Output slices are named "<stem>_NNN.<ext>"; files matching that
    # pattern were produced by a previous run and are skipped below.
    pattern = re.compile(r"_\d{3}\.")

    for file_path in tqdm(audio_files, desc="Processing audio file"):

        if pattern.search(file_path.name):
            # logger.info(f"Skipping {file_path} as it has already been processed.")
            continue

        file_stem = file_path.stem
        file_suffix = file_path.suffix

        # Mirror the input's directory layout under save_path.
        rel_path = Path(file_path).relative_to(audio_dir)
        (save_path / rel_path.parent).mkdir(parents=True, exist_ok=True)

        audio = AudioSegment.from_file(file_path)

        cfg = dict(
            cache={},
            language=language,  # "zh", "en", "yue", "ja", "ko", "nospeech"
            use_itn=False,
            batch_size_s=60,
        )

        if fsmn_vad:
            elapsed, vad_res = manager.vad(input=str(file_path), **cfg)
        else:
            wav = read_audio(
                str(file_path)
            )  # backend (sox, soundfile, or ffmpeg) required!
            audio_key = file_path.stem
            audio_val = []
            speech_timestamps = get_speech_timestamps(
                wav,
                vad_model,
                max_speech_duration_s=max_single_segment_time // 1000,
                return_seconds=True,
            )

            # Convert silero's second-based timestamps into the
            # millisecond [start, end] pairs funasr's vad_res expects.
            audio_val = [
                [int(timestamp["start"] * 1000), int(timestamp["end"] * 1000)]
                for timestamp in speech_timestamps
            ]
            vad_res = []
            vad_res.append(dict(key=audio_key, value=audio_val))

        res = manager.inference_with_vadres(
            input=str(file_path), vad_res=vad_res, **cfg
        )

        # Export one audio slice + transcript (+ optional emotion) per
        # recognized segment, numbered _000, _001, ...
        for i, info in enumerate(res):
            [start_ms, end_ms] = info["interval"]
            text = info["text"]
            emo = info["emo"]
            sliced_audio = audio[start_ms:end_ms]
            audio_save_path = (
                save_path / rel_path.parent / f"{file_stem}_{i:03d}{file_suffix}"
            )
            sliced_audio.export(audio_save_path, format=file_suffix[1:])
            print(f"Exported {audio_save_path}: {text}")

            transcript_save_path = (
                save_path / rel_path.parent / f"{file_stem}_{i:03d}.lab"
            )
            with open(
                transcript_save_path,
                "w",
                encoding="utf-8",
            ) as f:
                f.write(text)

            if save_emo:
                emo_save_path = save_path / rel_path.parent / f"{file_stem}_{i:03d}.emo"
                with open(
                    emo_save_path,
                    "w",
                    encoding="utf-8",
                ) as f:
                    f.write(emo)

        # When processing in-place (save dir == source dir), remove the
        # original file once its slices have been written.
        if audios_path.resolve() == save_path.resolve():
            file_path.unlink()
  266. if __name__ == "__main__":
  267. main()
  268. exit(0)
  269. from funasr.utils.postprocess_utils import rich_transcription_postprocess
  270. # Load the audio file
  271. audio_path = Path(r"D:\PythonProject\ok\1_output_(Vocals).wav")
  272. model_dir = "iic/SenseVoiceSmall"
  273. m, kwargs = SenseVoiceSmall.from_pretrained(model=model_dir, device="cuda:0")
  274. m.eval()
  275. res = m.inference(
  276. data_in=f"{kwargs['model_path']}/example/zh.mp3",
  277. language="auto", # "zh", "en", "yue", "ja", "ko", "nospeech"
  278. use_itn=False,
  279. ban_emo_unk=False,
  280. **kwargs,
  281. )
  282. print(res)
  283. text = rich_transcription_postprocess(res[0][0]["text"])
  284. print(text)