|
|
@@ -40,13 +40,16 @@ def batch_asr_internal(model: WhisperModel, audios, sr):
|
|
|
assert audio.dim() == 1
|
|
|
audio_np = audio.numpy()
|
|
|
resampled_audio = librosa.resample(audio_np, orig_sr=sr, target_sr=16000)
|
|
|
- resampled_audios.append(torch.from_numpy(resampled_audio))
|
|
|
+ resampled_audios.append(resampled_audio)
|
|
|
|
|
|
trans_results = []
|
|
|
|
|
|
for resampled_audio in resampled_audios:
|
|
|
segments, info = model.transcribe(
|
|
|
- resampled_audio.numpy(), language=None, beam_size=5
|
|
|
+ resampled_audio,
|
|
|
+ language=None,
|
|
|
+ beam_size=5,
|
|
|
+ initial_prompt="Punctuation is needed in any language.",
|
|
|
)
|
|
|
trans_results.append(list(segments))
|
|
|
|
|
|
@@ -71,6 +74,7 @@ def batch_asr_internal(model: WhisperModel, audios, sr):
|
|
|
last_tr = tr
|
|
|
if max_gap > 3.0:
|
|
|
huge_gap = True
|
|
|
+ break
|
|
|
|
|
|
sim_text = t2s_converter.convert(text)
|
|
|
results.append(
|
|
|
@@ -95,34 +99,37 @@ def is_chinese(text):
|
|
|
return True
|
|
|
|
|
|
|
|
|
-def calculate_wer(text1, text2):
|
|
|
- # 将文本分割成字符列表
|
|
|
+def calculate_wer(text1, text2, debug=False):
|
|
|
chars1 = remove_punctuation(text1)
|
|
|
chars2 = remove_punctuation(text2)
|
|
|
|
|
|
- # 计算编辑距离
|
|
|
m, n = len(chars1), len(chars2)
|
|
|
- dp = [[0] * (n + 1) for _ in range(m + 1)]
|
|
|
|
|
|
- for i in range(m + 1):
|
|
|
- dp[i][0] = i
|
|
|
- for j in range(n + 1):
|
|
|
- dp[0][j] = j
|
|
|
+ if m > n:
|
|
|
+ chars1, chars2 = chars2, chars1
|
|
|
+ m, n = n, m
|
|
|
|
|
|
- for i in range(1, m + 1):
|
|
|
- for j in range(1, n + 1):
|
|
|
+ prev = list(range(m + 1)) # row 0 distance: [0, 1, 2, ...]
|
|
|
+ curr = [0] * (m + 1)
|
|
|
+
|
|
|
+ for j in range(1, n + 1):
|
|
|
+ curr[0] = j
|
|
|
+ for i in range(1, m + 1):
|
|
|
if chars1[i - 1] == chars2[j - 1]:
|
|
|
- dp[i][j] = dp[i - 1][j - 1]
|
|
|
+ curr[i] = prev[i - 1]
|
|
|
else:
|
|
|
- dp[i][j] = min(dp[i - 1][j], dp[i][j - 1], dp[i - 1][j - 1]) + 1
|
|
|
+ curr[i] = min(prev[i], curr[i - 1], prev[i - 1]) + 1
|
|
|
+ prev, curr = curr, prev
|
|
|
|
|
|
- # WER
|
|
|
- edits = dp[m][n]
|
|
|
+ edits = prev[m]
|
|
|
tot = max(len(chars1), len(chars2))
|
|
|
wer = edits / tot
|
|
|
- print(" gt: ", chars1)
|
|
|
- print(" pred: ", chars2)
|
|
|
- print(" edits/tot = wer: ", edits, "/", tot, "=", wer)
|
|
|
+
|
|
|
+ if debug:
|
|
|
+ print(" gt: ", chars1)
|
|
|
+ print(" pred: ", chars2)
|
|
|
+ print(" edits/tot = wer: ", edits, "/", tot, "=", wer)
|
|
|
+
|
|
|
return wer
|
|
|
|
|
|
|