@@ -1,183 +1,126 @@
-# This file is used to convert the audio files to text files using the Whisper model.
-# It's mainly used to generate the training data for the VQ model.
-
+"""
+Transcribes every audio file in one folder and writes the results into another.
+
+Example directory structure:
+--pre_data_root
+----SP_1
+------01.wav
+------02.wav
+------......
+----SP_2
+------01.wav
+------02.wav
+------......
+
+Use
+python tools/whisper_asr.py --audio_dir pre_data_root/SP_1 --save_dir data/SP_1
+to transcribe the first speaker, and
+python tools/whisper_asr.py --audio_dir pre_data_root/SP_2 --save_dir data/SP_2
+to transcribe the second speaker.
+
+Each input file is split into Whisper-detected segments, saved as
+save_dir/<stem>/<stem>_<i>.wav with a matching .lab transcript per segment.
+
+Note: Check your audio sample rate; the output defaults to 44.1 kHz (override
+with --out_sr).
+"""
+
+import argparse
 import os
-import subprocess as sp
-import time
-from datetime import timedelta
-from functools import lru_cache
 from pathlib import Path
-from random import Random

-import click
+import librosa
 import numpy as np
-import torch
-from loguru import logger
-from transformers import WhisperProcessor
-from whisper.audio import load_audio, log_mel_spectrogram, pad_or_trim
+import whisper
+from scipy.io import wavfile
+from tqdm import tqdm

-from fish_speech.modules.flash_whisper import FlashWhisperForConditionalGeneration

-RANK_STR = ""
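+# Load a file as mono at its native rate, trim edge silence below 20 dB,
+# scale any clipped peaks back to 0.98, then resample to target_sr.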
+def load_and_normalize_audio(filepath, target_sr):
+    wav, sr = librosa.load(filepath, sr=None, mono=True)
+    wav, _ = librosa.effects.trim(wav, top_db=20)
+    peak = np.abs(wav).max()
+    if peak > 1.0:
+        wav /= peak / 0.98
+    return librosa.resample(wav, orig_sr=sr, target_sr=target_sr), target_sr


-@lru_cache(maxsize=1)
-def get_whisper_model():
-    model = FlashWhisperForConditionalGeneration.from_pretrained(
-        "openai/whisper-medium"
-    ).cuda()
-    model.eval()
-    logger.info(f"{RANK_STR}Loaded model")
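+# Transcribe the already-trimmed, normalized waveform (resampled to the 16 kHz
+# Whisper expects) so segment timestamps line up with the audio cut below.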
+def transcribe_audio(model, wav, sr):
+    audio = librosa.resample(wav, orig_sr=sr, target_sr=16000)
+    return model.transcribe(
+        audio, word_timestamps=True, task="transcribe", beam_size=5, best_of=5
+    )

-    return model

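+# Cut the waveform at Whisper's segment boundaries and write each slice as
+# 16-bit PCM under save_path.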
+def save_audio_segments(segments, wav, sr, save_path):
+    # Each source file gets its own segment directory; create it on demand.
+    save_path.mkdir(parents=True, exist_ok=True)
+    for i, seg in enumerate(segments):
+        start_time, end_time = seg["start"], seg["end"]
+        wav_seg = wav[int(start_time * sr) : int(end_time * sr)]
+        wav_seg_name = f"{save_path.stem}_{i}.wav"
+        out_fpath = save_path / wav_seg_name
+        wavfile.write(
+            out_fpath, rate=sr, data=(wav_seg * np.iinfo(np.int16).max).astype(np.int16)
+        )

-@lru_cache(maxsize=1)
-def get_whisper_processor():
-    return WhisperProcessor.from_pretrained("openai/whisper-medium")

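+# Detect the language of one saved segment, then decode it with beam search;
+# returns (text, detected_language).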
+def transcribe_segment(model, filepath):
+    audio = whisper.load_audio(str(filepath))
+    audio = whisper.pad_or_trim(audio)
+    # Match the mel bins to the loaded checkpoint (80 for most sizes, 128 for
+    # large-v3) instead of hard-coding one value.
+    mel = whisper.log_mel_spectrogram(audio, n_mels=model.dims.n_mels).to(model.device)
+    _, probs = model.detect_language(mel)
+    lang = max(probs, key=probs.get)
+    options = whisper.DecodingOptions(beam_size=5)
+    result = whisper.decode(model, mel, options)
+    return result.text, lang
+
+
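+# Unused below (see the commented-out call in main): collates every .lab
+# transcript into one pipe-delimited index file.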
+def process_output(save_dir, language, out_file):
+    with open(out_file, "w", encoding="utf-8") as wf:
+        ch_name = save_dir.stem
+        for file in save_dir.glob("*.lab"):
+            with open(file, "r", encoding="utf-8") as lab_file:
+                line = lab_file.readline().strip()
+                result = (
+                    f"{save_dir}/{ch_name}/{file.stem}.wav|{ch_name}|{language}|{line}"
+                )
+                wf.write(f"{result}\n")

-def transcribe_batch(files: list[str], language: str):
-    wavs = [load_audio(file, 16000) for file in files]
-    total_time = sum([len(wav) for wav in wavs]) / 16000
-    wavs = [pad_or_trim(wav) for wav in wavs]
-
-    wavs = torch.from_numpy(np.stack(wavs)).float().cuda()
-    mels = log_mel_spectrogram(wavs).cuda()
-    model = get_whisper_model()
-    processor = get_whisper_processor()
-    forced_decoder_ids = processor.get_decoder_prompt_ids(
-        language=language, task="transcribe"
-    )
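+# Pipeline: normalize each file, transcribe it to find segment boundaries,
+# cut and save the segments, then transcribe each segment into a .lab file.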
+def main(model_size, audio_dir, save_dir, out_sr, language):
+    model = whisper.load_model(model_size)
+    audio_dir, save_dir = Path(audio_dir), Path(save_dir)
+    save_dir.mkdir(parents=True, exist_ok=True)

-    with torch.no_grad():
-        outputs = model.generate(
-            input_features=mels,
-            max_length=448,
-            do_sample=False,
-            forced_decoder_ids=forced_decoder_ids,
-        )
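+    # First pass: segment each source file with Whisper and save the slices.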
+    for filepath in tqdm(list(audio_dir.glob("*.wav")), desc="Processing files"):
+        wav, sr = load_and_normalize_audio(filepath, out_sr)
+        transcription = transcribe_audio(model, wav, sr)
+        save_path = save_dir / filepath.stem
+        save_audio_segments(transcription["segments"], wav, sr, save_path)

-    outputs = outputs.cpu().tolist()
-
-    # Remove EOS token
-    for output in outputs:
-        while output[-1] in [
-            processor.tokenizer.pad_token_id,
-            processor.tokenizer.eos_token_id,
-        ]:
-            output.pop()
-        output.append(processor.tokenizer.eos_token_id)
-
-    transcriptions = processor.batch_decode(outputs, skip_special_tokens=False)
-    tokens = [",".join(map(str, line)) for line in outputs]
-    transcriptions = [
-        f"{token}\t{transcription}"
-        for token, transcription in zip(tokens, transcriptions)
-    ]
-
-    return transcriptions, total_time
-
-
-@click.command()
-@click.argument("folder")
-@click.option("--rank", default=0)
-@click.option("--world-size", default=1)
-@click.option("--num-workers", default=1)
-@click.option("--language", default="english")
-def main(folder: str, rank: int, world_size: int, num_workers: int, language: str):
-    global RANK_STR
-
-    if num_workers > 1 and world_size != num_workers:
-        RANK_STR = "[Master] "
-        logger.info(f"{RANK_STR}Spawning {num_workers} workers")
-
-        visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
-        if visible_devices is None:
-            visible_devices = list(range(torch.cuda.device_count()))
-        else:
-            visible_devices = visible_devices.split(",")
-
-        processes = []
-        for i in range(num_workers):
-            env = os.environ.copy()
-            env["CUDA_VISIBLE_DEVICES"] = str(visible_devices[i % len(visible_devices)])
-            args = [
-                "python",
-                __file__,
-                "--rank",
-                str(i),
-                "--world-size",
-                str(num_workers),
-                "--language",
-                language,
-                folder,
-            ]
-            processes.append(
-                sp.Popen(
-                    args,
-                    env=env,
-                )
-            )
-
-        for p in processes:
-            p.wait()
-
-        logger.info(f"{RANK_STR}All workers finished")
-        return
-
-    # This is a worker
-    RANK_STR = f"[Rank: {rank}] "
-    logger.info(f"{RANK_STR}Starting worker")
-
-    files = [
-        str(file)
-        for file in Path(folder).rglob("*")
-        if file.suffix in [".wav", ".flac"]
-    ]
-
-    logger.info(f"{RANK_STR}Found {len(files)} files")
-
-    files = sorted(files)
-    Random(42).shuffle(files)
-    files = files[rank::world_size]
-    logger.info(f"{RANK_STR}Processing {len(files)} files")
-
-    # Batch size 64
-    total_time = 0
-    begin_time = time.time()
-    processed_files = 0
-
-    for n_batch, idx in enumerate(range(0, len(files), 64)):
-        batch = files[idx : idx + 64]
-        trascriptions, batch_time = transcribe_batch(batch, language)
-        total_time += batch_time
-        processed_files += len(batch)
-
-        if (n_batch + 1) % 10 == 0:
-            eta = (
-                (time.time() - begin_time)
-                / processed_files
-                * (len(files) - processed_files)
-            )
-            logger.info(
-                f"{RANK_STR}Processed {processed_files} files, {total_time / 3600:.2f} hours of audio, ETA: {timedelta(seconds=round(eta))}s"
-            )
-
-        # Write to file
-        for file, transcription in zip(batch, trascriptions):
-            Path(file).with_suffix(".whisper.txt").write_text(
-                transcription, encoding="utf-8"
-            )
-
-        # Stop if total time is more than 1000 / world_size hours
-        if total_time > 1000 / world_size * 3600:
-            break
-
-    logger.info(
-        f"{RANK_STR}Finished processing {len(files)} files, {total_time / 3600:.2f} hours of audio"
-    )
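+        # Second pass: re-transcribe each saved segment so its .lab file
+        # contains exactly the text spoken in that slice.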
+        for segment_file in tqdm(
+            list(save_path.glob("*.wav")), desc="Transcribing segments"
+        ):
+            text, _ = transcribe_segment(model, segment_file)
+            with open(segment_file.with_suffix(".lab"), "w", encoding="utf-8") as f:
+                f.write(text)
+
+    # process_output(save_dir, language, save_dir / "output.txt")  # no need to collate everything into one file


 if __name__ == "__main__":
-    main()
+    parser = argparse.ArgumentParser(description="Audio Transcription with Whisper")
+    parser.add_argument(
+        "--model_size", type=str, default="large", help="Size of the Whisper model"
+    )
+    parser.add_argument(
+        "--audio_dir", type=str, required=True, help="Directory containing audio files"
+    )
+    parser.add_argument(
+        "--save_dir",
+        type=str,
+        required=True,
+        help="Directory to save processed audio files",
+    )
+    parser.add_argument(
+        "--language",
+        type=str,
+        default="ZH",
+        help="Language tag for the collated output file (the transcription language itself is auto-detected)",
+    )
+    parser.add_argument("--out_sr", type=int, default=44100, help="Output sample rate")
+    args = parser.parse_args()
+
+    main(args.model_size, args.audio_dir, args.save_dir, args.out_sr, args.language)