|
@@ -1,71 +1,81 @@
|
|
|
-import time
|
|
|
|
|
|
|
+import os
|
|
|
|
|
+
|
|
|
|
|
# Redirect the ModelScope download cache into the working directory
# (./.cache/) instead of the default location under the user's home.
# Must be set before any modelscope-backed library is imported.
os.environ["MODELSCOPE_CACHE"] = ".cache/"
|
|
|
|
|
+
|
|
|
|
|
+import string
|
|
|
|
|
+import time
|
|
|
from threading import Lock
|
|
from threading import Lock
|
|
|
|
|
|
|
|
|
|
+import librosa
|
|
|
import numpy as np
|
|
import numpy as np
|
|
|
|
|
+import opencc
|
|
|
import torch
|
|
import torch
|
|
|
-import torchaudio
|
|
|
|
|
-from funasr import AutoModel
|
|
|
|
|
-from funasr.models.seaco_paraformer.model import SeacoParaformer
|
|
|
|
|
|
|
+from faster_whisper import WhisperModel
|
|
|
|
|
|
|
|
-# Monkey patching to disable hotwords
|
|
|
|
|
-SeacoParaformer.generate_hotwords_list = lambda self, *args, **kwargs: None
|
|
|
|
|
|
|
# Module-wide OpenCC converter: Traditional -> Simplified Chinese ("t2s"),
# used to normalise ASR output before it is returned.
t2s_converter = opencc.OpenCC("t2s")
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_model(
    *,
    device="cuda",
    model_size="medium",
    compute_type="float16",
    download_root="faster_whisper",
):
    """Load a faster-whisper ASR model.

    Args:
        device: Device string passed to ``WhisperModel`` ("cuda" or "cpu").
        model_size: Whisper checkpoint name; default "medium" preserves the
            previous hard-coded behaviour.
        compute_type: Inference precision (e.g. "float16", "int8").
            NOTE(review): "float16" generally requires a CUDA device —
            callers using device="cpu" should pass "int8" or "float32".
        download_root: Directory where checkpoints are downloaded/cached.

    Returns:
        A ready-to-use ``WhisperModel`` instance.
    """
    model = WhisperModel(
        model_size,
        device=device,
        compute_type=compute_type,
        download_root=download_root,
    )
    print("faster_whisper loaded!")
    return model
|
|
|
|
|
|
|
|
|
|
|
|
|
@torch.no_grad()
|
|
@torch.no_grad()
|
|
|
-def batch_asr_internal(model, audios, sr):
|
|
|
|
|
|
|
+def batch_asr_internal(model: WhisperModel, audios, sr):
|
|
|
resampled_audios = []
|
|
resampled_audios = []
|
|
|
for audio in audios:
|
|
for audio in audios:
|
|
|
- # 将 NumPy 数组转换为 PyTorch 张量
|
|
|
|
|
|
|
+
|
|
|
if isinstance(audio, np.ndarray):
|
|
if isinstance(audio, np.ndarray):
|
|
|
audio = torch.from_numpy(audio).float()
|
|
audio = torch.from_numpy(audio).float()
|
|
|
|
|
|
|
|
- # 确保音频是一维的
|
|
|
|
|
if audio.dim() > 1:
|
|
if audio.dim() > 1:
|
|
|
audio = audio.squeeze()
|
|
audio = audio.squeeze()
|
|
|
|
|
|
|
|
- audio = torchaudio.functional.resample(audio, sr, 16000)
|
|
|
|
|
assert audio.dim() == 1
|
|
assert audio.dim() == 1
|
|
|
- resampled_audios.append(audio)
|
|
|
|
|
|
|
+ audio_np = audio.numpy()
|
|
|
|
|
+ resampled_audio = librosa.resample(audio_np, orig_sr=sr, target_sr=16000)
|
|
|
|
|
+ resampled_audios.append(torch.from_numpy(resampled_audio))
|
|
|
|
|
|
|
|
- res = model.generate(input=resampled_audios, batch_size=len(resampled_audios))
|
|
|
|
|
|
|
+ trans_results = []
|
|
|
|
|
+
|
|
|
|
|
+ for resampled_audio in resampled_audios:
|
|
|
|
|
+ segments, info = model.transcribe(
|
|
|
|
|
+ resampled_audio.numpy(), language=None, beam_size=5
|
|
|
|
|
+ )
|
|
|
|
|
+ trans_results.append(list(segments))
|
|
|
|
|
|
|
|
results = []
|
|
results = []
|
|
|
- for r, audio in zip(res, audios):
|
|
|
|
|
- text = r["text"]
|
|
|
|
|
|
|
+ for trans_res, audio in zip(trans_results, audios):
|
|
|
|
|
+
|
|
|
duration = len(audio) / sr * 1000
|
|
duration = len(audio) / sr * 1000
|
|
|
huge_gap = False
|
|
huge_gap = False
|
|
|
|
|
+ max_gap = 0.0
|
|
|
|
|
+
|
|
|
|
|
+ text = None
|
|
|
|
|
+ last_tr = None
|
|
|
|
|
+
|
|
|
|
|
+ for tr in trans_res:
|
|
|
|
|
+ delta = tr.text.strip()
|
|
|
|
|
+ if tr.id > 1:
|
|
|
|
|
+ max_gap = max(tr.start - last_tr.end, max_gap)
|
|
|
|
|
+ text += delta
|
|
|
|
|
+ else:
|
|
|
|
|
+ text = delta
|
|
|
|
|
|
|
|
- if "timestamp" in r and len(r["timestamp"]) > 2:
|
|
|
|
|
- for timestamp_a, timestamp_b in zip(
|
|
|
|
|
- r["timestamp"][:-1], r["timestamp"][1:]
|
|
|
|
|
- ):
|
|
|
|
|
- # If there is a gap of more than 5 seconds, we consider it as a huge gap
|
|
|
|
|
- if timestamp_b[0] - timestamp_a[1] > 5000:
|
|
|
|
|
- huge_gap = True
|
|
|
|
|
- break
|
|
|
|
|
-
|
|
|
|
|
- # Doesn't make sense to have a huge gap at the end
|
|
|
|
|
- if duration - r["timestamp"][-1][1] > 3000:
|
|
|
|
|
|
|
+ last_tr = tr
|
|
|
|
|
+ if max_gap > 3.0:
|
|
|
huge_gap = True
|
|
huge_gap = True
|
|
|
|
|
|
|
|
|
|
+ sim_text = t2s_converter.convert(text)
|
|
|
results.append(
|
|
results.append(
|
|
|
{
|
|
{
|
|
|
- "text": text,
|
|
|
|
|
|
|
+ "text": sim_text,
|
|
|
"duration": duration,
|
|
"duration": duration,
|
|
|
"huge_gap": huge_gap,
|
|
"huge_gap": huge_gap,
|
|
|
}
|
|
}
|
|
@@ -86,11 +96,12 @@ def is_chinese(text):
|
|
|
|
|
|
|
|
|
|
|
|
|
def calculate_wer(text1, text2):
    """Character-level error rate between *text1* (reference) and *text2*.

    Both strings are stripped of punctuation and whitespace first, then the
    Levenshtein edit distance is computed over individual characters (word
    segmentation is ambiguous for Chinese, so character units are used).
    The distance is normalised by the length of the longer cleaned string.

    Returns:
        float in [0, 1]-ish range; 0.0 when both strings are empty after
        cleaning (instead of raising ZeroDivisionError).
    """
    # Compare character sequences, not whitespace-delimited words.
    chars1 = remove_punctuation(text1)
    chars2 = remove_punctuation(text2)

    # Standard Levenshtein DP: dp[i][j] = edits to turn chars1[:i]
    # into chars2[:j].
    m, n = len(chars1), len(chars2)
    dp = [[0] * (n + 1) for _ in range(m + 1)]

    # Base cases: transforming to/from the empty prefix costs its length.
    for i in range(m + 1):
        dp[i][0] = i
    for j in range(n + 1):
        dp[0][j] = j

    for i in range(1, m + 1):
        for j in range(1, n + 1):
            if chars1[i - 1] == chars2[j - 1]:
                dp[i][j] = dp[i - 1][j - 1]
            else:
                dp[i][j] = min(dp[i - 1][j], dp[i][j - 1], dp[i - 1][j - 1]) + 1

    edits = dp[m][n]
    # Normalise by the longer string; guard the 0/0 case for empty inputs.
    tot = max(len(chars1), len(chars2))
    wer = edits / tot if tot else 0.0
    # Debug output kept from the original implementation.
    print("    gt: ", chars1)
    print("  pred: ", chars2)
    print("  edits/tot = wer: ", edits, "/", tot, "=", wer)
    return wer


def remove_punctuation(text):
    """Strip ASCII plus common fullwidth/CJK punctuation from *text*.

    NOTE(review): the table below also contains space, newline and tab,
    so ordinary whitespace is removed as well — confirm this is intended
    by callers comparing "word"-separated text.
    """
    chinese_punctuation = (
        " \n\t”“!?。。\"#$%&'()*+,-/:;<=>@[\\]^_`{|}~⦅⦆「」、、〃《》「」『』"
        "【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—"
        "’‛“”„‟…‧﹏"
    )
    all_punctuation = string.punctuation + chinese_punctuation
    translator = str.maketrans("", "", all_punctuation)
    return text.translate(translator)
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
if __name__ == "__main__":
    # Smoke test: load the model, transcribe two local wav files once,
    # then time ten repeated batch transcriptions.
    asr_model = load_model()
    audios = [
        librosa.load("44100.wav", sr=44100)[0],
        librosa.load("lengyue.wav", sr=44100)[0],
    ]
    print(np.array(audios[0]))
    print(batch_asr(asr_model, audios, 44100))

    start_time = time.time()
    for _ in range(10):
        print(batch_asr(asr_model, audios, 44100))
    print("Time taken:", time.time() - start_time)
|