| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126 |
- import time
- from threading import Lock
- import numpy as np
- import torch
- import torchaudio
- from funasr import AutoModel
- from funasr.models.seaco_paraformer.model import SeacoParaformer
# Monkey patch: force SeacoParaformer.generate_hotwords_list to return None,
# which disables hotword biasing during decoding.
# NOTE(review): this reaches into funasr internals — re-verify on any
# funasr version upgrade.
SeacoParaformer.generate_hotwords_list = lambda self, *args, **kwargs: None
def load_model(*, device="cuda"):
    """Load the Chinese and English Paraformer ASR models.

    Args:
        device: device string passed to funasr (default "cuda").

    Returns:
        Tuple of (chinese_model, english_model) AutoModel instances.
    """

    def _build(model_name):
        # Both models share the same construction options.
        return AutoModel(model=model_name, device=device, disable_pbar=True)

    return _build("paraformer-zh"), _build("paraformer-en")
@torch.no_grad()
def batch_asr_internal(model, audios, sr):
    """Transcribe a batch of audio clips and flag suspicious silences.

    Args:
        model: a funasr AutoModel exposing ``generate``.
        audios: list of 1-D waveforms (np.ndarray or torch.Tensor) sampled
            at ``sr`` Hz; singleton leading dims are squeezed away.
        sr: sample rate of the input clips in Hz.

    Returns:
        List of dicts, one per clip, with keys ``text`` (transcript),
        ``duration`` (clip length in ms) and ``huge_gap`` (True when the
        timestamps contain a long internal or trailing silence).
    """
    TARGET_SR = 16000       # sample rate expected by the Paraformer models
    GAP_THRESHOLD_MS = 5000  # internal silence longer than this is "huge"
    TAIL_THRESHOLD_MS = 3000  # trailing silence longer than this is "huge"

    prepared = []
    for clip in audios:
        # Accept NumPy input by converting to a float tensor.
        if isinstance(clip, np.ndarray):
            clip = torch.from_numpy(clip).float()
        # Drop singleton dims so the waveform is 1-D.
        if clip.dim() > 1:
            clip = clip.squeeze()
        clip = torchaudio.functional.resample(clip, sr, TARGET_SR)
        assert clip.dim() == 1
        prepared.append(clip)

    raw = model.generate(input=prepared, batch_size=len(prepared))

    results = []
    for entry, clip in zip(raw, audios):
        duration = len(clip) / sr * 1000  # clip length in milliseconds
        huge_gap = False
        if "timestamp" in entry and len(entry["timestamp"]) > 2:
            stamps = entry["timestamp"]
            # A gap of more than GAP_THRESHOLD_MS between consecutive
            # segments counts as a huge gap.
            for prev, nxt in zip(stamps[:-1], stamps[1:]):
                if nxt[0] - prev[1] > GAP_THRESHOLD_MS:
                    huge_gap = True
                    break
            # A long stretch of silence at the end is equally suspicious.
            if duration - stamps[-1][1] > TAIL_THRESHOLD_MS:
                huge_gap = True
        results.append(
            {
                "text": entry["text"],
                "duration": duration,
                "huge_gap": huge_gap,
            }
        )
    return results
# Process-wide lock, intended to serialize concurrent ASR batches on the
# shared models.
global_lock = Lock()
def batch_asr(model, audios, sr):
    """Thread-safe entry point for batched ASR.

    Fix: ``global_lock`` was defined at module level but never acquired,
    making this wrapper a no-op pass-through. Acquire it here so concurrent
    callers do not interleave batches on the shared model.

    Args:
        model: a funasr AutoModel exposing ``generate``.
        audios: list of waveforms (np.ndarray or torch.Tensor) at ``sr`` Hz.
        sr: sample rate of the input clips in Hz.

    Returns:
        Same as :func:`batch_asr_internal`.
    """
    with global_lock:
        return batch_asr_internal(model, audios, sr)
def is_chinese(text):
    """Return True if *text* contains at least one CJK unified ideograph.

    Fix: the previous implementation ignored its argument and returned
    True unconditionally, so every input was classified as Chinese. This
    checks the basic CJK Unified Ideographs block (U+4E00–U+9FFF).

    Args:
        text: string to classify.

    Returns:
        bool: True when any character falls in the CJK range.
    """
    return any("\u4e00" <= ch <= "\u9fff" for ch in text)
def calculate_wer(text1, text2):
    """Compute the word error rate of *text2* against reference *text1*.

    Words are whitespace-separated tokens. WER is the word-level edit
    distance (insertions + deletions + substitutions) divided by the
    reference length.

    Fix: the original divided by ``len(words1)`` unguarded and raised
    ``ZeroDivisionError`` on an empty reference; the divisor is now
    clamped to at least 1, so an empty reference yields the number of
    hypothesis words (0.0 when both are empty).

    Args:
        text1: reference transcript.
        text2: hypothesis transcript.

    Returns:
        float: word error rate (>= 0; can exceed 1.0 for long hypotheses).
    """
    words1 = text1.split()
    words2 = text2.split()

    # Word-level edit distance via classic dynamic programming.
    m, n = len(words1), len(words2)
    dp = [[0] * (n + 1) for _ in range(m + 1)]
    for i in range(m + 1):
        dp[i][0] = i  # delete all i reference words
    for j in range(n + 1):
        dp[0][j] = j  # insert all j hypothesis words
    for i in range(1, m + 1):
        for j in range(1, n + 1):
            if words1[i - 1] == words2[j - 1]:
                dp[i][j] = dp[i - 1][j - 1]
            else:
                dp[i][j] = min(dp[i - 1][j], dp[i][j - 1], dp[i - 1][j - 1]) + 1

    edits = dp[m][n]
    # Clamp divisor so an empty reference does not crash.
    return edits / max(m, 1)
if __name__ == "__main__":
    # Smoke test: needs a local "lengyue.wav" and a CUDA device (the
    # load_model default).
    zh_model, en_model = load_model()
    # Two clips from the same file: the full first channel and its first
    # 5 seconds (assumes the file is 44.1 kHz — TODO confirm).
    audios = [
        torchaudio.load("lengyue.wav")[0][0],
        torchaudio.load("lengyue.wav")[0][0, : 44100 * 5],
    ]
    print(batch_asr(zh_model, audios, 44100))
    # Rough wall-clock timing over 10 repeated batches.
    start_time = time.time()
    for _ in range(10):
        batch_asr(zh_model, audios, 44100)
    print("Time taken:", time.time() - start_time)
|