# auto_rerank.py
  1. import os
  2. os.environ["MODELSCOPE_CACHE"] = ".cache/"
  3. import string
  4. import time
  5. from threading import Lock
  6. import librosa
  7. import numpy as np
  8. import opencc
  9. import torch
  10. from faster_whisper import WhisperModel
# Shared OpenCC converter: Traditional -> Simplified Chinese ("t2s" profile),
# used to normalize every transcription before returning it.
t2s_converter = opencc.OpenCC("t2s")
  12. def load_model(*, device="cuda"):
  13. model = WhisperModel(
  14. "medium",
  15. device=device,
  16. compute_type="float16",
  17. download_root="faster_whisper",
  18. )
  19. print("faster_whisper loaded!")
  20. return model
  21. @torch.no_grad()
  22. def batch_asr_internal(model: WhisperModel, audios, sr):
  23. resampled_audios = []
  24. for audio in audios:
  25. if isinstance(audio, np.ndarray):
  26. audio = torch.from_numpy(audio).float()
  27. if audio.dim() > 1:
  28. audio = audio.squeeze()
  29. assert audio.dim() == 1
  30. audio_np = audio.numpy()
  31. resampled_audio = librosa.resample(audio_np, orig_sr=sr, target_sr=16000)
  32. resampled_audios.append(resampled_audio)
  33. trans_results = []
  34. for resampled_audio in resampled_audios:
  35. segments, info = model.transcribe(
  36. resampled_audio,
  37. language=None,
  38. beam_size=5,
  39. initial_prompt="Punctuation is needed in any language.",
  40. )
  41. trans_results.append(list(segments))
  42. results = []
  43. for trans_res, audio in zip(trans_results, audios):
  44. duration = len(audio) / sr * 1000
  45. huge_gap = False
  46. max_gap = 0.0
  47. text = None
  48. last_tr = None
  49. for tr in trans_res:
  50. delta = tr.text.strip()
  51. if tr.id > 1:
  52. max_gap = max(tr.start - last_tr.end, max_gap)
  53. text += delta
  54. else:
  55. text = delta
  56. last_tr = tr
  57. if max_gap > 3.0:
  58. huge_gap = True
  59. break
  60. sim_text = t2s_converter.convert(text)
  61. results.append(
  62. {
  63. "text": sim_text,
  64. "duration": duration,
  65. "huge_gap": huge_gap,
  66. }
  67. )
  68. return results
  69. global_lock = Lock()
  70. def batch_asr(model, audios, sr):
  71. return batch_asr_internal(model, audios, sr)
  72. def is_chinese(text):
  73. return True
  74. def calculate_wer(text1, text2, debug=False):
  75. chars1 = remove_punctuation(text1)
  76. chars2 = remove_punctuation(text2)
  77. m, n = len(chars1), len(chars2)
  78. if m > n:
  79. chars1, chars2 = chars2, chars1
  80. m, n = n, m
  81. prev = list(range(m + 1)) # row 0 distance: [0, 1, 2, ...]
  82. curr = [0] * (m + 1)
  83. for j in range(1, n + 1):
  84. curr[0] = j
  85. for i in range(1, m + 1):
  86. if chars1[i - 1] == chars2[j - 1]:
  87. curr[i] = prev[i - 1]
  88. else:
  89. curr[i] = min(prev[i], curr[i - 1], prev[i - 1]) + 1
  90. prev, curr = curr, prev
  91. edits = prev[m]
  92. tot = max(len(chars1), len(chars2))
  93. wer = edits / tot
  94. if debug:
  95. print(" gt: ", chars1)
  96. print(" pred: ", chars2)
  97. print(" edits/tot = wer: ", edits, "/", tot, "=", wer)
  98. return wer
  99. def remove_punctuation(text):
  100. chinese_punctuation = (
  101. " \n\t”“!?。。"#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、、〃《》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—"
  102. '‛""„‟…‧﹏'
  103. )
  104. all_punctuation = string.punctuation + chinese_punctuation
  105. translator = str.maketrans("", "", all_punctuation)
  106. text_without_punctuation = text.translate(translator)
  107. return text_without_punctuation
  108. if __name__ == "__main__":
  109. model = load_model()
  110. audios = [
  111. librosa.load("44100.wav", sr=44100)[0],
  112. librosa.load("lengyue.wav", sr=44100)[0],
  113. ]
  114. print(np.array(audios[0]))
  115. print(batch_asr(model, audios, 44100))
  116. start_time = time.time()
  117. for _ in range(10):
  118. print(batch_asr(model, audios, 44100))
  119. print("Time taken:", time.time() - start_time)