#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)
  5. import copy
  6. import json
  7. import logging
  8. import os.path
  9. import random
  10. import re
  11. import string
  12. import time
  13. import numpy as np
  14. import torch
  15. from funasr.download.download_model_from_hub import download_model
  16. from funasr.download.file import download_from_url
  17. from funasr.register import tables
  18. from funasr.train_utils.load_pretrained_model import load_pretrained_model
  19. from funasr.train_utils.set_all_random_seed import set_all_random_seed
  20. from funasr.utils import export_utils, misc
  21. from funasr.utils.load_utils import load_audio_text_image_video, load_bytes
  22. from funasr.utils.misc import deep_update
  23. from funasr.utils.timestamp_tools import timestamp_sentence, timestamp_sentence_en
  24. from tqdm import tqdm
  25. from .vad_utils import merge_vad, slice_padding_audio_samples
  26. try:
  27. from funasr.models.campplus.cluster_backend import ClusterBackend
  28. from funasr.models.campplus.utils import distribute_spk, postprocess, sv_chunk
  29. except:
  30. pass
  31. def prepare_data_iterator(data_in, input_len=None, data_type=None, key=None):
  32. """ """
  33. data_list = []
  34. key_list = []
  35. filelist = [".scp", ".txt", ".json", ".jsonl", ".text"]
  36. chars = string.ascii_letters + string.digits
  37. if isinstance(data_in, str):
  38. if data_in.startswith("http://") or data_in.startswith("https://"): # url
  39. data_in = download_from_url(data_in)
  40. if isinstance(data_in, str) and os.path.exists(
  41. data_in
  42. ): # wav_path; filelist: wav.scp, file.jsonl;text.txt;
  43. _, file_extension = os.path.splitext(data_in)
  44. file_extension = file_extension.lower()
  45. if file_extension in filelist: # filelist: wav.scp, file.jsonl;text.txt;
  46. with open(data_in, encoding="utf-8") as fin:
  47. for line in fin:
  48. key = "rand_key_" + "".join(random.choice(chars) for _ in range(13))
  49. if data_in.endswith(
  50. ".jsonl"
  51. ): # file.jsonl: json.dumps({"source": data})
  52. lines = json.loads(line.strip())
  53. data = lines["source"]
  54. key = data["key"] if "key" in data else key
  55. else: # filelist, wav.scp, text.txt: id \t data or data
  56. lines = line.strip().split(maxsplit=1)
  57. data = lines[1] if len(lines) > 1 else lines[0]
  58. key = lines[0] if len(lines) > 1 else key
  59. data_list.append(data)
  60. key_list.append(key)
  61. else:
  62. if key is None:
  63. # key = "rand_key_" + "".join(random.choice(chars) for _ in range(13))
  64. key = misc.extract_filename_without_extension(data_in)
  65. data_list = [data_in]
  66. key_list = [key]
  67. elif isinstance(data_in, (list, tuple)):
  68. if data_type is not None and isinstance(
  69. data_type, (list, tuple)
  70. ): # mutiple inputs
  71. data_list_tmp = []
  72. for data_in_i, data_type_i in zip(data_in, data_type):
  73. key_list, data_list_i = prepare_data_iterator(
  74. data_in=data_in_i, data_type=data_type_i
  75. )
  76. data_list_tmp.append(data_list_i)
  77. data_list = []
  78. for item in zip(*data_list_tmp):
  79. data_list.append(item)
  80. else:
  81. # [audio sample point, fbank, text]
  82. data_list = data_in
  83. key_list = []
  84. for data_i in data_in:
  85. if isinstance(data_i, str) and os.path.exists(data_i):
  86. key = misc.extract_filename_without_extension(data_i)
  87. else:
  88. if key is None:
  89. key = "rand_key_" + "".join(
  90. random.choice(chars) for _ in range(13)
  91. )
  92. key_list.append(key)
  93. else: # raw text; audio sample point, fbank; bytes
  94. if isinstance(data_in, bytes): # audio bytes
  95. data_in = load_bytes(data_in)
  96. if key is None:
  97. key = "rand_key_" + "".join(random.choice(chars) for _ in range(13))
  98. data_list = [data_in]
  99. key_list = [key]
  100. return key_list, data_list
class AutoModel:
    """Unified entry point that builds an ASR model plus optional VAD,
    punctuation and speaker sub-models, and exposes ``generate`` /
    ``inference`` / ``export`` over them."""
    def __init__(self, **kwargs):
        """Build the main model plus optional VAD / punctuation / speaker models.

        kwargs are forwarded to ``build_model``; keys consumed here include
        ``disable_update``, ``log_level``, ``vad_model``/``vad_kwargs``,
        ``punc_model``/``punc_kwargs``, ``spk_model``/``spk_kwargs``, the
        matching ``*_model_revision`` entries, ``device`` and ``spk_mode``.
        """
        try:
            from funasr.utils.version_checker import check_for_update

            print(
                "Check update of funasr, and it would cost few times. You may disable it by set `disable_update=True` in AutoModel"
            )
            check_for_update(disable=kwargs.get("disable_update", False))
        except:
            # Best effort only: never let the version check break model loading.
            pass

        log_level = getattr(logging, kwargs.get("log_level", "INFO").upper())
        logging.basicConfig(level=log_level)

        # Main model; build_model returns the fully resolved kwargs as well.
        model, kwargs = self.build_model(**kwargs)

        # if vad_model is not None, build vad model else None
        vad_model = kwargs.get("vad_model", None)
        # Guard against an explicit vad_kwargs=None from the caller.
        vad_kwargs = (
            {} if kwargs.get("vad_kwargs", {}) is None else kwargs.get("vad_kwargs", {})
        )
        if vad_model is not None:
            logging.info("Building VAD model.")
            vad_kwargs["model"] = vad_model
            vad_kwargs["model_revision"] = kwargs.get("vad_model_revision", "master")
            vad_kwargs["device"] = kwargs["device"]
            vad_model, vad_kwargs = self.build_model(**vad_kwargs)

        # if punc_model is not None, build punc model else None
        punc_model = kwargs.get("punc_model", None)
        punc_kwargs = (
            {}
            if kwargs.get("punc_kwargs", {}) is None
            else kwargs.get("punc_kwargs", {})
        )
        if punc_model is not None:
            logging.info("Building punc model.")
            punc_kwargs["model"] = punc_model
            punc_kwargs["model_revision"] = kwargs.get("punc_model_revision", "master")
            punc_kwargs["device"] = kwargs["device"]
            punc_model, punc_kwargs = self.build_model(**punc_kwargs)

        # if spk_model is not None, build spk model else None
        spk_model = kwargs.get("spk_model", None)
        spk_kwargs = (
            {} if kwargs.get("spk_kwargs", {}) is None else kwargs.get("spk_kwargs", {})
        )
        if spk_model is not None:
            logging.info("Building SPK model.")
            spk_kwargs["model"] = spk_model
            spk_kwargs["model_revision"] = kwargs.get("spk_model_revision", "master")
            spk_kwargs["device"] = kwargs["device"]
            spk_model, spk_kwargs = self.build_model(**spk_kwargs)
            # Clustering backend for grouping speaker embeddings; relies on the
            # optional campplus import at module level.
            self.cb_model = ClusterBackend().to(kwargs["device"])
            spk_mode = kwargs.get("spk_mode", "punc_segment")
            if spk_mode not in ["default", "vad_segment", "punc_segment"]:
                # Logged but not raised: an invalid mode is kept as-is.
                logging.error(
                    "spk_mode should be one of default, vad_segment and punc_segment."
                )
            self.spk_mode = spk_mode

        self.kwargs = kwargs
        self.model = model
        self.vad_model = vad_model
        self.vad_kwargs = vad_kwargs
        self.punc_model = punc_model
        self.punc_kwargs = punc_kwargs
        self.spk_model = spk_model
        self.spk_kwargs = spk_kwargs
        self.model_path = kwargs.get("model_path")
    @staticmethod
    def build_model(**kwargs):
        """Instantiate one model (tokenizer + frontend + network) from kwargs.

        Downloads assets from the model hub when no ``model_conf`` is present,
        resolves device/threads, builds the registered tokenizer, frontend and
        model classes, optionally loads ``init_param`` weights and applies
        fp16/bf16 casting.

        Returns:
            (model, kwargs): the built module and the fully resolved config.
        """
        assert "model" in kwargs
        if "model_conf" not in kwargs:
            # No inline config: fetch config + weights from the model hub
            # ("ms" = ModelScope by default).
            logging.info(
                "download models from model hub: {}".format(kwargs.get("hub", "ms"))
            )
            kwargs = download_model(**kwargs)

        set_all_random_seed(kwargs.get("seed", 0))

        device = kwargs.get("device", "cuda")
        if not torch.cuda.is_available() or kwargs.get("ngpu", 1) == 0:
            device = "cpu"
            # CPU inference is forced to batch size 1.
            kwargs["batch_size"] = 1
        kwargs["device"] = device

        torch.set_num_threads(kwargs.get("ncpu", 4))

        # build tokenizer
        tokenizer = kwargs.get("tokenizer", None)
        if tokenizer is not None:
            tokenizer_class = tables.tokenizer_classes.get(tokenizer)
            tokenizer = tokenizer_class(**kwargs.get("tokenizer_conf", {}))
            kwargs["token_list"] = (
                tokenizer.token_list if hasattr(tokenizer, "token_list") else None
            )
            # get_vocab(), when available, takes precedence over token_list.
            kwargs["token_list"] = (
                tokenizer.get_vocab()
                if hasattr(tokenizer, "get_vocab")
                else kwargs["token_list"]
            )
            vocab_size = (
                len(kwargs["token_list"]) if kwargs["token_list"] is not None else -1
            )
            if vocab_size == -1 and hasattr(tokenizer, "get_vocab_size"):
                vocab_size = tokenizer.get_vocab_size()
        else:
            # No tokenizer: vocab size is unknown / not applicable.
            vocab_size = -1
        kwargs["tokenizer"] = tokenizer

        # build frontend
        frontend = kwargs.get("frontend", None)
        kwargs["input_size"] = None
        if frontend is not None:
            frontend_class = tables.frontend_classes.get(frontend)
            frontend = frontend_class(**kwargs.get("frontend_conf", {}))
            # Feature dimension produced by the frontend feeds the model.
            kwargs["input_size"] = (
                frontend.output_size() if hasattr(frontend, "output_size") else None
            )
        kwargs["frontend"] = frontend

        # build model
        model_class = tables.model_classes.get(kwargs["model"])
        assert model_class is not None, f'{kwargs["model"]} is not registered'
        # kwargs entries override model_conf entries on key collision.
        model_conf = {}
        deep_update(model_conf, kwargs.get("model_conf", {}))
        deep_update(model_conf, kwargs)
        model = model_class(**model_conf, vocab_size=vocab_size)

        # init_param
        init_param = kwargs.get("init_param", None)
        if init_param is not None:
            if os.path.exists(init_param):
                logging.info(f"Loading pretrained params from {init_param}")
                load_pretrained_model(
                    model=model,
                    path=init_param,
                    ignore_init_mismatch=kwargs.get("ignore_init_mismatch", True),
                    oss_bucket=kwargs.get("oss_bucket", None),
                    scope_map=kwargs.get("scope_map", []),
                    excludes=kwargs.get("excludes", None),
                )
            else:
                # Deliberately non-fatal: the randomly initialized model is kept.
                print(f"error, init_param does not exist!: {init_param}")

        # fp16
        if kwargs.get("fp16", False):
            model.to(torch.float16)
        elif kwargs.get("bf16", False):
            model.to(torch.bfloat16)
        model.to(device)

        if not kwargs.get("disable_log", True):
            tables.print()

        return model, kwargs
  242. def __call__(self, *args, **cfg):
  243. kwargs = self.kwargs
  244. deep_update(kwargs, cfg)
  245. res = self.model(*args, kwargs)
  246. return res
  247. def generate(self, input, input_len=None, **cfg):
  248. if self.vad_model is None:
  249. return self.inference(input, input_len=input_len, **cfg)
  250. else:
  251. return self.inference_with_vad(input, input_len=input_len, **cfg)
    def inference(
        self, input, input_len=None, model=None, kwargs=None, key=None, **cfg
    ):
        """Batched inference over ``input``; returns a list of result dicts.

        Args:
            input: path / url / raw text / bytes / list of these
                (see ``prepare_data_iterator``).
            input_len: optional length, forwarded as ``data_lengths`` for a
                single precomputed-fbank item.
            model: model to run; defaults to ``self.model``.
            kwargs: base config; defaults to ``self.kwargs`` (mutated by cfg).
            key: optional sample-id override.
            **cfg: per-call overrides, deep-merged into ``kwargs``.
        """
        kwargs = self.kwargs if kwargs is None else kwargs
        # A streaming cache from a previous session must not leak into this run.
        if "cache" in kwargs:
            kwargs.pop("cache")
        deep_update(kwargs, cfg)
        model = self.model if model is None else model
        model.eval()

        batch_size = kwargs.get("batch_size", 1)
        # if kwargs.get("device", "cpu") == "cpu":
        #     batch_size = 1

        key_list, data_list = prepare_data_iterator(
            input, input_len=input_len, data_type=kwargs.get("data_type", None), key=key
        )

        speed_stats = {}
        asr_result_list = []
        num_samples = len(data_list)
        disable_pbar = self.kwargs.get("disable_pbar", False)
        pbar = (
            tqdm(colour="blue", total=num_samples, dynamic_ncols=True)
            if not disable_pbar
            else None
        )
        time_speech_total = 0.0
        time_escape_total = 0.0
        for beg_idx in range(0, num_samples, batch_size):
            end_idx = min(num_samples, beg_idx + batch_size)
            data_batch = data_list[beg_idx:end_idx]
            key_batch = key_list[beg_idx:end_idx]
            batch = {"data_in": data_batch, "key": key_batch}

            if (end_idx - beg_idx) == 1 and kwargs.get(
                "data_type", None
            ) == "fbank":  # fbank
                # Single precomputed-fbank item: pass the features directly.
                batch["data_in"] = data_batch[0]
                batch["data_lengths"] = input_len

            time1 = time.perf_counter()
            with torch.no_grad():
                res = model.inference(**batch, **kwargs)
                # NOTE(review): if model.inference ever returns a non-sequence,
                # `results`/`meta_data` below stay unbound and raise NameError —
                # confirm all registered models return (results, meta) tuples.
                if isinstance(res, (list, tuple)):
                    results = res[0] if len(res) > 0 else [{"text": ""}]
                    meta_data = res[1] if len(res) > 1 else {}
            time2 = time.perf_counter()

            asr_result_list.extend(results)

            # batch_data_time = time_per_frame_s * data_batch_i["speech_lengths"].sum().item()
            batch_data_time = meta_data.get("batch_data_time", -1)
            time_escape = time2 - time1
            speed_stats["load_data"] = meta_data.get("load_data", 0.0)
            speed_stats["extract_feat"] = meta_data.get("extract_feat", 0.0)
            speed_stats["forward"] = f"{time_escape:0.3f}"
            speed_stats["batch_size"] = f"{len(results)}"
            # rtf < 1 means faster than real time; batch_data_time defaults to
            # -1 when the model reports no audio duration, which makes this
            # ratio (and the totals below) meaningless in that case.
            speed_stats["rtf"] = f"{(time_escape) / batch_data_time:0.3f}"
            description = f"{speed_stats}, "
            if pbar:
                pbar.update(end_idx - beg_idx)
                pbar.set_description(description)
            time_speech_total += batch_data_time
            time_escape_total += time_escape

        if pbar:
            # pbar.update(1)
            pbar.set_description(f"rtf_avg: {time_escape_total/time_speech_total:0.3f}")
        torch.cuda.empty_cache()
        return asr_result_list
  315. def vad(self, input, input_len=None, **cfg):
  316. kwargs = self.kwargs
  317. # step.1: compute the vad model
  318. deep_update(self.vad_kwargs, cfg)
  319. beg_vad = time.time()
  320. res = self.inference(
  321. input,
  322. input_len=input_len,
  323. model=self.vad_model,
  324. kwargs=self.vad_kwargs,
  325. **cfg,
  326. )
  327. end_vad = time.time()
  328. # FIX(gcf): concat the vad clips for sense vocie model for better aed
  329. if cfg.get("merge_vad", False):
  330. for i in range(len(res)):
  331. res[i]["value"] = merge_vad(
  332. res[i]["value"], kwargs.get("merge_length_s", 15) * 1000
  333. )
  334. elapsed = end_vad - beg_vad
  335. return elapsed, res
    def inference_with_vadres(self, input, vad_res, input_len=None, **cfg):
        """Decode VAD segments, longest-first, in duration-bounded batches.

        Args:
            input: same forms as ``inference``; one entry per item in vad_res.
            vad_res: list of {"key": ..., "value": [[beg_ms, end_ms], ...]}.
            input_len: unused here beyond data preparation.
            **cfg: per-call overrides, deep-merged into ``self.kwargs``.

        NOTE(review): only the LAST sample's ``restored_data`` is returned and
        the ``results_ret_list`` entries appended for empty speech are never
        returned — looks intentional only for single-input use; confirm.
        """
        kwargs = self.kwargs
        # step.2 compute asr model
        model = self.model
        deep_update(kwargs, cfg)
        # Batch budget is expressed in milliseconds of audio, not items.
        batch_size = max(int(kwargs.get("batch_size_s", 300)) * 1000, 1)
        batch_size_threshold_ms = int(kwargs.get("batch_size_threshold_s", 60)) * 1000
        kwargs["batch_size"] = batch_size

        key_list, data_list = prepare_data_iterator(
            input, input_len=input_len, data_type=kwargs.get("data_type", None)
        )
        results_ret_list = []
        time_speech_total_all_samples = 1e-6  # avoid division by zero in rtf stats

        beg_total = time.time()
        pbar_total = (
            tqdm(colour="red", total=len(vad_res), dynamic_ncols=True)
            if not kwargs.get("disable_pbar", False)
            else None
        )

        for i in range(len(vad_res)):
            key = vad_res[i]["key"]
            vadsegments = vad_res[i]["value"]
            input_i = data_list[i]
            fs = kwargs["frontend"].fs if hasattr(kwargs["frontend"], "fs") else 16000
            speech = load_audio_text_image_video(
                input_i, fs=fs, audio_fs=kwargs.get("fs", 16000)
            )
            speech_lengths = len(speech)
            n = len(vadsegments)
            # Remember each segment's original position, then sort by duration
            # so similar-length segments batch together with minimal padding.
            data_with_index = [(vadsegments[i], i) for i in range(n)]
            sorted_data = sorted(data_with_index, key=lambda x: x[0][1] - x[0][0])
            results_sorted = []

            if not len(sorted_data):
                results_ret_list.append({"key": key, "text": "", "timestamp": []})
                logging.info("decoding, utt: {}, empty speech".format(key))
                continue

            if len(sorted_data) > 0 and len(sorted_data[0]) > 0:
                # Ensure the budget covers at least the shortest segment.
                batch_size = max(
                    batch_size, sorted_data[0][0][1] - sorted_data[0][0][0]
                )

            if kwargs["device"] == "cpu":
                # Forces one segment per inference call on CPU.
                batch_size = 0

            beg_idx = 0
            beg_asr_total = time.time()
            # NOTE(review): hard-codes 16 kHz for the rtf denominator even
            # though `fs` above may differ — confirm.
            time_speech_total_per_sample = speech_lengths / 16000
            time_speech_total_all_samples += time_speech_total_per_sample

            # pbar_sample = tqdm(colour="blue", total=n, dynamic_ncols=True)

            all_segments = []
            max_len_in_batch = 0
            end_idx = 1

            for j, _ in enumerate(range(0, n)):
                # pbar_sample.update(1)
                sample_length = sorted_data[j][0][1] - sorted_data[j][0][0]
                # Padded cost if segment j joined the current batch.
                potential_batch_length = max(max_len_in_batch, sample_length) * (
                    j + 1 - beg_idx
                )
                # batch_size_ms_cum += sorted_data[j][0][1] - sorted_data[j][0][0]
                if (
                    j < n - 1
                    and sample_length < batch_size_threshold_ms
                    and potential_batch_length < batch_size
                ):
                    # Segment fits — keep accumulating the batch.
                    max_len_in_batch = max(max_len_in_batch, sample_length)
                    end_idx += 1
                    continue

                # Flush batch [beg_idx, end_idx): slice audio per segment.
                speech_j, speech_lengths_j, intervals = slice_padding_audio_samples(
                    speech, speech_lengths, sorted_data[beg_idx:end_idx]
                )
                results = self.inference(
                    speech_j, input_len=None, model=model, kwargs=kwargs, **cfg
                )

                for _b in range(len(speech_j)):
                    results[_b]["interval"] = intervals[_b]

                if self.spk_model is not None:
                    # compose vad segments: [[start_time_sec, end_time_sec, speech], [...]]
                    for _b in range(len(speech_j)):
                        vad_segments = [
                            [
                                sorted_data[beg_idx:end_idx][_b][0][0] / 1000.0,
                                sorted_data[beg_idx:end_idx][_b][0][1] / 1000.0,
                                np.array(speech_j[_b]),
                            ]
                        ]
                        segments = sv_chunk(vad_segments)
                        all_segments.extend(segments)
                        speech_b = [i[2] for i in segments]
                        spk_res = self.inference(
                            speech_b,
                            input_len=None,
                            model=self.spk_model,
                            kwargs=kwargs,
                            **cfg,
                        )
                        results[_b]["spk_embedding"] = spk_res[0]["spk_embedding"]

                # Start the next batch from the segment that triggered the flush.
                beg_idx = end_idx
                end_idx += 1
                max_len_in_batch = sample_length
                if len(results) < 1:
                    continue
                results_sorted.extend(results)

            # end_asr_total = time.time()
            # time_escape_total_per_sample = end_asr_total - beg_asr_total
            # pbar_sample.update(1)
            # pbar_sample.set_description(f"rtf_avg_per_sample: {time_escape_total_per_sample / time_speech_total_per_sample:0.3f}, "
            #                             f"time_speech_total_per_sample: {time_speech_total_per_sample: 0.3f}, "
            #                             f"time_escape_total_per_sample: {time_escape_total_per_sample:0.3f}")

            # Undo the duration sort: place each result back at its original
            # segment index, stripping <|...|> tags into an "emo" field.
            restored_data = [0] * n
            for j in range(n):
                index = sorted_data[j][1]
                cur = results_sorted[j]
                pattern = r"<\|([^|]+)\|>"
                emotion_string = re.findall(pattern, cur["text"])
                cur["text"] = re.sub(pattern, "", cur["text"])
                cur["emo"] = "".join([f"<|{t}|>" for t in emotion_string])
                if self.punc_model is not None and len(cur["text"].strip()) > 0:
                    deep_update(self.punc_kwargs, cfg)
                    punc_res = self.inference(
                        cur["text"],
                        model=self.punc_model,
                        kwargs=self.punc_kwargs,
                        **cfg,
                    )
                    cur["text"] = punc_res[0]["text"]
                restored_data[index] = cur

            end_asr_total = time.time()
            time_escape_total_per_sample = end_asr_total - beg_asr_total
            if pbar_total:
                pbar_total.update(1)
                pbar_total.set_description(
                    f"rtf_avg: {time_escape_total_per_sample / time_speech_total_per_sample:0.3f}, "
                    f"time_speech: {time_speech_total_per_sample: 0.3f}, "
                    f"time_escape: {time_escape_total_per_sample:0.3f}"
                )

        # end_total = time.time()
        # time_escape_total_all_samples = end_total - beg_total
        # print(f"rtf_avg_all: {time_escape_total_all_samples / time_speech_total_all_samples:0.3f}, "
        #       f"time_speech_all: {time_speech_total_all_samples: 0.3f}, "
        #       f"time_escape_all: {time_escape_total_all_samples:0.3f}")
        return restored_data
  475. def export(self, input=None, **cfg):
  476. """
  477. :param input:
  478. :param type:
  479. :param quantize:
  480. :param fallback_num:
  481. :param calib_num:
  482. :param opset_version:
  483. :param cfg:
  484. :return:
  485. """
  486. device = cfg.get("device", "cpu")
  487. model = self.model.to(device=device)
  488. kwargs = self.kwargs
  489. deep_update(kwargs, cfg)
  490. kwargs["device"] = device
  491. del kwargs["model"]
  492. model.eval()
  493. type = kwargs.get("type", "onnx")
  494. key_list, data_list = prepare_data_iterator(
  495. input, input_len=None, data_type=kwargs.get("data_type", None), key=None
  496. )
  497. with torch.no_grad():
  498. export_dir = export_utils.export(model=model, data_in=data_list, **kwargs)
  499. return export_dir