# auto_rerank.py
  1. import time
  2. from threading import Lock
  3. import numpy as np
  4. import torch
  5. import torchaudio
  6. from funasr import AutoModel
  7. from funasr.models.seaco_paraformer.model import SeacoParaformer
# Monkey-patch: disable funasr's hotword feature by making SeacoParaformer's
# hotword-list generator always return None (no hotword biasing is wanted here).
SeacoParaformer.generate_hotwords_list = lambda self, *args, **kwargs: None
  10. def load_model(*, device="cuda"):
  11. zh_model = AutoModel(
  12. model="paraformer-zh",
  13. device=device,
  14. disable_pbar=True,
  15. )
  16. en_model = AutoModel(
  17. model="paraformer-en",
  18. device=device,
  19. disable_pbar=True,
  20. )
  21. return zh_model, en_model
  22. @torch.no_grad()
  23. def batch_asr_internal(model, audios, sr):
  24. resampled_audios = []
  25. for audio in audios:
  26. # 将 NumPy 数组转换为 PyTorch 张量
  27. if isinstance(audio, np.ndarray):
  28. audio = torch.from_numpy(audio).float()
  29. # 确保音频是一维的
  30. if audio.dim() > 1:
  31. audio = audio.squeeze()
  32. audio = torchaudio.functional.resample(audio, sr, 16000)
  33. assert audio.dim() == 1
  34. resampled_audios.append(audio)
  35. res = model.generate(input=resampled_audios, batch_size=len(resampled_audios))
  36. results = []
  37. for r, audio in zip(res, audios):
  38. text = r["text"]
  39. duration = len(audio) / sr * 1000
  40. huge_gap = False
  41. if "timestamp" in r and len(r["timestamp"]) > 2:
  42. for timestamp_a, timestamp_b in zip(
  43. r["timestamp"][:-1], r["timestamp"][1:]
  44. ):
  45. # If there is a gap of more than 5 seconds, we consider it as a huge gap
  46. if timestamp_b[0] - timestamp_a[1] > 5000:
  47. huge_gap = True
  48. break
  49. # Doesn't make sense to have a huge gap at the end
  50. if duration - r["timestamp"][-1][1] > 3000:
  51. huge_gap = True
  52. results.append(
  53. {
  54. "text": text,
  55. "duration": duration,
  56. "huge_gap": huge_gap,
  57. }
  58. )
  59. return results
  60. global_lock = Lock()
  61. def batch_asr(model, audios, sr):
  62. return batch_asr_internal(model, audios, sr)
  63. def is_chinese(text):
  64. return True
  65. def calculate_wer(text1, text2):
  66. words1 = text1.split()
  67. words2 = text2.split()
  68. # 计算编辑距离
  69. m, n = len(words1), len(words2)
  70. dp = [[0] * (n + 1) for _ in range(m + 1)]
  71. for i in range(m + 1):
  72. dp[i][0] = i
  73. for j in range(n + 1):
  74. dp[0][j] = j
  75. for i in range(1, m + 1):
  76. for j in range(1, n + 1):
  77. if words1[i - 1] == words2[j - 1]:
  78. dp[i][j] = dp[i - 1][j - 1]
  79. else:
  80. dp[i][j] = min(dp[i - 1][j], dp[i][j - 1], dp[i - 1][j - 1]) + 1
  81. # 计算WER
  82. edits = dp[m][n]
  83. wer = edits / len(words1)
  84. return wer
if __name__ == "__main__":
    # Smoke test: transcribe a local sample file (full clip plus its first
    # five seconds) and time ten repeated batched inferences.
    zh_model, en_model = load_model()
    # assumes lengyue.wav is sampled at 44100 Hz — TODO confirm; the sample
    # rate passed to batch_asr below must match the file's actual rate.
    audios = [
        torchaudio.load("lengyue.wav")[0][0],
        torchaudio.load("lengyue.wav")[0][0, : 44100 * 5],
    ]
    print(batch_asr(zh_model, audios, 44100))
    # Rough throughput measurement over 10 identical batches.
    start_time = time.time()
    for _ in range(10):
        batch_asr(zh_model, audios, 44100)
    print("Time taken:", time.time() - start_time)