# auto_rerank.py
import os

# Must be set before importing any library that reads it (e.g. modelscope),
# so the cache lands under the project-local .cache/ directory.
os.environ["MODELSCOPE_CACHE"] = ".cache/"

import string
import time
from threading import Lock

import librosa
import numpy as np
import opencc
import torch
from faster_whisper import WhisperModel

# OpenCC "t2s": converts Traditional Chinese to Simplified Chinese; used to
# normalise Whisper transcriptions before comparison.
t2s_converter = opencc.OpenCC("t2s")
  12. def load_model(*, device="cuda"):
  13. model = WhisperModel(
  14. "medium",
  15. device=device,
  16. compute_type="float16",
  17. download_root="faster_whisper",
  18. )
  19. print("faster_whisper loaded!")
  20. return model
  21. @torch.no_grad()
  22. def batch_asr_internal(model: WhisperModel, audios, sr):
  23. resampled_audios = []
  24. for audio in audios:
  25. if isinstance(audio, np.ndarray):
  26. audio = torch.from_numpy(audio).float()
  27. if audio.dim() > 1:
  28. audio = audio.squeeze()
  29. assert audio.dim() == 1
  30. audio_np = audio.numpy()
  31. resampled_audio = librosa.resample(audio_np, orig_sr=sr, target_sr=16000)
  32. resampled_audios.append(torch.from_numpy(resampled_audio))
  33. trans_results = []
  34. for resampled_audio in resampled_audios:
  35. segments, info = model.transcribe(
  36. resampled_audio.numpy(), language=None, beam_size=5
  37. )
  38. trans_results.append(list(segments))
  39. results = []
  40. for trans_res, audio in zip(trans_results, audios):
  41. duration = len(audio) / sr * 1000
  42. huge_gap = False
  43. max_gap = 0.0
  44. text = None
  45. last_tr = None
  46. for tr in trans_res:
  47. delta = tr.text.strip()
  48. if tr.id > 1:
  49. max_gap = max(tr.start - last_tr.end, max_gap)
  50. text += delta
  51. else:
  52. text = delta
  53. last_tr = tr
  54. if max_gap > 3.0:
  55. huge_gap = True
  56. sim_text = t2s_converter.convert(text)
  57. results.append(
  58. {
  59. "text": sim_text,
  60. "duration": duration,
  61. "huge_gap": huge_gap,
  62. }
  63. )
  64. return results
  65. global_lock = Lock()
  66. def batch_asr(model, audios, sr):
  67. return batch_asr_internal(model, audios, sr)
def is_chinese(text):
    """Report whether *text* is Chinese — currently a stub that always
    returns True.

    NOTE(review): ``text`` is never inspected; every transcription is
    treated as Chinese. Replace with a real detector if non-Chinese audio
    must be handled.
    """
    return True
  70. def calculate_wer(text1, text2):
  71. # 将文本分割成字符列表
  72. chars1 = remove_punctuation(text1)
  73. chars2 = remove_punctuation(text2)
  74. # 计算编辑距离
  75. m, n = len(chars1), len(chars2)
  76. dp = [[0] * (n + 1) for _ in range(m + 1)]
  77. for i in range(m + 1):
  78. dp[i][0] = i
  79. for j in range(n + 1):
  80. dp[0][j] = j
  81. for i in range(1, m + 1):
  82. for j in range(1, n + 1):
  83. if chars1[i - 1] == chars2[j - 1]:
  84. dp[i][j] = dp[i - 1][j - 1]
  85. else:
  86. dp[i][j] = min(dp[i - 1][j], dp[i][j - 1], dp[i - 1][j - 1]) + 1
  87. # WER
  88. edits = dp[m][n]
  89. tot = max(len(chars1), len(chars2))
  90. wer = edits / tot
  91. print(" gt: ", chars1)
  92. print(" pred: ", chars2)
  93. print(" edits/tot = wer: ", edits, "/", tot, "=", wer)
  94. return wer
  95. def remove_punctuation(text):
  96. chinese_punctuation = (
  97. " \n\t”“!?。。"#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、、〃《》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—"
  98. '‛""„‟…‧﹏'
  99. )
  100. all_punctuation = string.punctuation + chinese_punctuation
  101. translator = str.maketrans("", "", all_punctuation)
  102. text_without_punctuation = text.translate(translator)
  103. return text_without_punctuation
if __name__ == "__main__":
    # Smoke test: requires "44100.wav" and "lengyue.wav" in the working
    # directory, plus a CUDA device (load_model defaults to device="cuda").
    model = load_model()
    audios = [
        librosa.load("44100.wav", sr=44100)[0],
        librosa.load("lengyue.wav", sr=44100)[0],
    ]
    print(np.array(audios[0]))
    print(batch_asr(model, audios, 44100))

    # Rough throughput check: 10 repeated transcriptions of the same batch.
    start_time = time.time()
    for _ in range(10):
        print(batch_asr(model, audios, 44100))
    print("Time taken:", time.time() - start_time)