| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159 |
- import os
- os.environ["MODELSCOPE_CACHE"] = ".cache/"
- import string
- import time
- from threading import Lock
- import librosa
- import numpy as np
- import opencc
- import torch
- from faster_whisper import WhisperModel
# Traditional -> Simplified Chinese converter, shared by all transcriptions below.
t2s_converter = opencc.OpenCC("t2s")
def load_model(*, device="cuda"):
    """Load the faster-whisper "medium" model.

    Args:
        device: Torch device string to run inference on (keyword-only).

    Returns:
        A ready-to-use ``WhisperModel`` in float16, with weights cached
        under the local ``faster_whisper`` directory.
    """
    asr_model = WhisperModel(
        "medium",
        device=device,
        compute_type="float16",
        download_root="faster_whisper",
    )
    print("faster_whisper loaded!")
    return asr_model
@torch.no_grad()
def batch_asr_internal(model: WhisperModel, audios, sr):
    """Transcribe a batch of waveforms with faster-whisper.

    Args:
        model: A loaded faster-whisper ``WhisperModel``.
        audios: Sequence of mono waveforms (1-D ``np.ndarray`` or
            ``torch.Tensor``), all sampled at ``sr``.
        sr: Sample rate of the inputs, in Hz.

    Returns:
        One dict per input with keys:
            ``text``     -- simplified-Chinese transcription ("" if no speech),
            ``duration`` -- clip length in milliseconds,
            ``huge_gap`` -- True if two consecutive segments are > 3 s apart.
    """
    resampled_audios = []
    for audio in audios:
        # Normalize to a 1-D float32 numpy array; Whisper expects 16 kHz input.
        if isinstance(audio, torch.Tensor):
            # BUG FIX: .cpu() so CUDA tensors don't crash on .numpy().
            audio_np = audio.squeeze().float().cpu().numpy()
        else:
            # ndarray inputs no longer round-trip through torch.
            audio_np = np.asarray(audio, dtype=np.float32).squeeze()
        assert audio_np.ndim == 1
        resampled_audios.append(
            librosa.resample(audio_np, orig_sr=sr, target_sr=16000)
        )

    trans_results = []
    for resampled_audio in resampled_audios:
        segments, _info = model.transcribe(
            resampled_audio,
            language=None,
            beam_size=5,
            initial_prompt="Punctuation is needed in any language.",
        )
        trans_results.append(list(segments))

    results = []
    for trans_res, audio in zip(trans_results, audios):
        duration = len(audio) / sr * 1000  # milliseconds
        huge_gap = False
        max_gap = 0.0
        # BUG FIX: start from "" instead of None so an empty transcription
        # (e.g. pure silence) no longer crashes t2s_converter.convert(None).
        text = ""
        last_end = None
        for tr in trans_res:
            if last_end is not None:
                max_gap = max(tr.start - last_end, max_gap)
            text += tr.text.strip()
            last_end = tr.end
            if max_gap > 3.0:
                # A >3 s silence inside the clip: flag it and stop accumulating.
                huge_gap = True
                break

        results.append(
            {
                "text": t2s_converter.convert(text),
                "duration": duration,
                "huge_gap": huge_gap,
            }
        )
    return results
# Serializes model access across threads; a single WhisperModel instance is
# not safe to call concurrently.
global_lock = Lock()


def batch_asr(model, audios, sr):
    """Thread-safe entry point for batch transcription.

    BUG FIX: the lock existed but was never acquired, so concurrent callers
    could hit the model simultaneously. Now the whole batch runs under it.
    """
    with global_lock:
        return batch_asr_internal(model, audios, sr)
def is_chinese(text):
    """Language check stub.

    NOTE(review): currently classifies every input as Chinese — confirm
    whether a real detector is intended here.
    """
    return True
def calculate_wer(text1, text2, debug=False):
    """Character-level error rate between two strings, ignoring punctuation.

    Computes the Levenshtein distance over characters (two-row DP, O(min(m,n))
    memory) and divides by the longer length.

    Args:
        text1: Reference (ground-truth) string.
        text2: Hypothesis (predicted) string.
        debug: When True, print the cleaned strings and the edit count.

    Returns:
        edits / max(len) as a float; 0.0 when both strings are empty
        after punctuation removal.
    """
    chars1 = remove_punctuation(text1)
    chars2 = remove_punctuation(text2)

    tot = max(len(chars1), len(chars2))
    # BUG FIX: both strings empty previously raised ZeroDivisionError.
    if tot == 0:
        return 0.0

    # Put the shorter string in `short` so the DP rows stay minimal;
    # edit distance is symmetric, so this doesn't change the result.
    short, long = (chars1, chars2) if len(chars1) <= len(chars2) else (chars2, chars1)
    m, n = len(short), len(long)

    prev = list(range(m + 1))  # row 0 distance: [0, 1, 2, ...]
    curr = [0] * (m + 1)
    for j in range(1, n + 1):
        curr[0] = j
        for i in range(1, m + 1):
            if short[i - 1] == long[j - 1]:
                curr[i] = prev[i - 1]
            else:
                curr[i] = min(prev[i], curr[i - 1], prev[i - 1]) + 1
        prev, curr = curr, prev
    edits = prev[m]

    wer = edits / tot
    if debug:
        # BUG FIX: print the un-swapped strings so the gt/pred labels are
        # truthful even when the swap above fired.
        print("    gt: ", chars1)
        print("  pred: ", chars2)
        print("    edits/tot = wer: ", edits, "/", tot, "=", wer)
    return wer
def remove_punctuation(text):
    """Return *text* with ASCII/CJK punctuation and whitespace stripped.

    Note: the removal set deliberately includes spaces, tabs, and newlines,
    so the result is a contiguous run of content characters.
    """
    chinese_punctuation = (
        " \n\t”“!?。。"#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、、〃《》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—"
        '‛""„‟…‧﹏'
    )
    # Map every punctuation character to None => deleted in one C-level pass.
    strip_table = str.maketrans(
        dict.fromkeys(string.punctuation + chinese_punctuation)
    )
    return text.translate(strip_table)
if __name__ == "__main__":
    # Smoke test: load the model, transcribe two local wav files once,
    # then time 10 repeated batch transcriptions of the same pair.
    model = load_model()
    audios = [
        librosa.load("44100.wav", sr=44100)[0],
        librosa.load("lengyue.wav", sr=44100)[0],
    ]
    print(np.array(audios[0]))
    print(batch_asr(model, audios, 44100))
    start_time = time.time()
    for _ in range(10):
        print(batch_asr(model, audios, 44100))
    # Wall-clock time for the 10-iteration loop above.
    print("Time taken:", time.time() - start_time)
|