
Docker Compose and Data Preprocessing Script (#10)

* 1. add Docker Compose for development; 2. add pre_data for preprocessing the dataset

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* rename pre_dataset to whisper_asr.py

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
paidax committed 2 years ago
commit cf69582a0a
2 changed files with 125 additions and 164 deletions
  1. docker-compose.dev.yml  +18 -0
  2. tools/whisper_asr.py  +107 -164

+ 18 - 0
docker-compose.dev.yml

@@ -0,0 +1,18 @@
+version: '3.8'
+
+services:
+  fish-speech:
+    build: .
+    container_name: fish-speech
+    volumes:
+      - ./data:/exp/data
+      - ./raw_data:/exp/raw_data
+    deploy:
+      resources:
+        reservations:
+          devices:
+            - driver: nvidia
+              count: all
+              capabilities: [gpu]
+    command: tail -f /dev/null
+
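
A minimal usage sketch for the new compose file (assuming Docker Compose v2 and the NVIDIA Container Toolkit are installed on the host, and that the image built from the repo's Dockerfile provides bash):

    # Build the image and start the dev container in the background
    docker compose -f docker-compose.dev.yml up -d --build

    # Open a shell inside it; ./data and ./raw_data are mounted under /exp
    docker exec -it fish-speech bash

The command: tail -f /dev/null entry keeps the otherwise idle container alive so it stays available for docker exec.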

+ 107 - 164
tools/whisper_asr.py

@@ -1,183 +1,126 @@
-# This file is used to convert the audio files to text files using the Whisper model.
-# It's mainly used to generate the training data for the VQ model.
-
+"""
+Used to transcribe all audio files in one folder into another folder.
+e.g.
+Directory structure:
+--pre_data_root
+----SP_1
+------01.wav
+------02.wav
+------......
+----SP_2
+------01.wav
+------02.wav
+------......
+Use 
+python tools/whisper_asr.py --audio_dir pre_data_root/SP_1 --save_dir data/SP_1 
+to transcribe the first speaker.
+
+Use 
+python tools/whisper_asr.py --audio_dir pre_data_root/SP_2 --save_dir data/SP_2 
+to transcribe the second speaker.
+
+Note: Be aware of your audio sample rate, which defaults to 44.1kHz.
+"""
+
+import argparse
 import os
-import subprocess as sp
-import time
-from datetime import timedelta
-from functools import lru_cache
 from pathlib import Path
-from random import Random
 
-import click
+import librosa
 import numpy as np
-import torch
-from loguru import logger
-from transformers import WhisperProcessor
-from whisper.audio import load_audio, log_mel_spectrogram, pad_or_trim
+import whisper
+from scipy.io import wavfile
+from tqdm import tqdm
 
-from fish_speech.modules.flash_whisper import FlashWhisperForConditionalGeneration
 
-RANK_STR = ""
+def load_and_normalize_audio(filepath, target_sr):
+    wav, sr = librosa.load(filepath, sr=None, mono=True)
+    wav, _ = librosa.effects.trim(wav, top_db=20)  # strip leading/trailing silence
+    peak = np.abs(wav).max()
+    if peak > 1.0:
+        wav /= peak / 0.98  # rescale so the peak sits at 0.98
+    return librosa.resample(wav, orig_sr=sr, target_sr=target_sr), target_sr
 
 
-@lru_cache(maxsize=1)
-def get_whisper_model():
-    model = FlashWhisperForConditionalGeneration.from_pretrained(
-        "openai/whisper-medium"
-    ).cuda()
-    model.eval()
-    logger.info(f"{RANK_STR}Loaded model")
+def transcribe_audio(model, filepath):
+    return model.transcribe(
+        str(filepath), word_timestamps=True, task="transcribe", beam_size=5, best_of=5
+    )
 
-    return model
 
+def save_audio_segments(segments, wav, sr, save_path):
+    save_path.mkdir(parents=True, exist_ok=True)
+    for i, seg in enumerate(segments):
+        start_time, end_time = seg["start"], seg["end"]
+        wav_seg = wav[int(start_time * sr) : int(end_time * sr)]
+        out_fpath = save_path / f"{save_path.stem}_{i}.wav"
+        wavfile.write(
+            out_fpath, rate=sr, data=(wav_seg * np.iinfo(np.int16).max).astype(np.int16)
+        )
 
-@lru_cache(maxsize=1)
-def get_whisper_processor():
-    return WhisperProcessor.from_pretrained("openai/whisper-medium")
 
+def transcribe_segment(model, filepath):
+    audio = whisper.load_audio(str(filepath))
+    audio = whisper.pad_or_trim(audio)
+    mel = whisper.log_mel_spectrogram(audio, n_mels=model.dims.n_mels).to(model.device)
+    _, probs = model.detect_language(mel)
+    lang = max(probs, key=probs.get)  # most probable language
+    options = whisper.DecodingOptions(beam_size=5)
+    result = whisper.decode(model, mel, options)
+    return result.text, lang
+
+
+def process_output(save_dir, language, out_file):
+    with open(out_file, "w", encoding="utf-8") as wf:
+        ch_name = save_dir.stem
+        for file in save_dir.glob("*.lab"):
+            with open(file, "r", encoding="utf-8") as perFile:
+                line = perFile.readline().strip()
+                result = (
+                    f"{save_dir}/{ch_name}/{file.stem}.wav|{ch_name}|{language}|{line}"
+                )
+                wf.write(f"{result}\n")
 
-def transcribe_batch(files: list[str], language: str):
-    wavs = [load_audio(file, 16000) for file in files]
-    total_time = sum([len(wav) for wav in wavs]) / 16000
-    wavs = [pad_or_trim(wav) for wav in wavs]
 
-    wavs = torch.from_numpy(np.stack(wavs)).float().cuda()
-    mels = log_mel_spectrogram(wavs).cuda()
-    model = get_whisper_model()
-    processor = get_whisper_processor()
-    forced_decoder_ids = processor.get_decoder_prompt_ids(
-        language=language, task="transcribe"
-    )
+def main(model_size, audio_dir, save_dir, out_sr, language):
+    model = whisper.load_model(model_size)
+    audio_dir, save_dir = Path(audio_dir), Path(save_dir)
+    save_dir.mkdir(parents=True, exist_ok=True)
 
-    with torch.no_grad():
-        outputs = model.generate(
-            input_features=mels,
-            max_length=448,
-            do_sample=False,
-            forced_decoder_ids=forced_decoder_ids,
-        )
+    for filepath in tqdm(list(audio_dir.glob("*.wav")), desc="Processing files"):
+        wav, sr = load_and_normalize_audio(filepath, out_sr)
+        transcription = transcribe_audio(model, filepath)
+        save_path = save_dir / filepath.stem
+        save_audio_segments(transcription["segments"], wav, sr, save_path)
 
-    outputs = outputs.cpu().tolist()
-
-    # Remove EOS token
-    for output in outputs:
-        while output[-1] in [
-            processor.tokenizer.pad_token_id,
-            processor.tokenizer.eos_token_id,
-        ]:
-            output.pop()
-        output.append(processor.tokenizer.eos_token_id)
-
-    transcriptions = processor.batch_decode(outputs, skip_special_tokens=False)
-    tokens = [",".join(map(str, line)) for line in outputs]
-    transcriptions = [
-        f"{token}\t{transcription}"
-        for token, transcription in zip(tokens, transcriptions)
-    ]
-
-    return transcriptions, total_time
-
-
-@click.command()
-@click.argument("folder")
-@click.option("--rank", default=0)
-@click.option("--world-size", default=1)
-@click.option("--num-workers", default=1)
-@click.option("--language", default="english")
-def main(folder: str, rank: int, world_size: int, num_workers: int, language: str):
-    global RANK_STR
-
-    if num_workers > 1 and world_size != num_workers:
-        RANK_STR = "[Master] "
-        logger.info(f"{RANK_STR}Spawning {num_workers} workers")
-
-        visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
-        if visible_devices is None:
-            visible_devices = list(range(torch.cuda.device_count()))
-        else:
-            visible_devices = visible_devices.split(",")
-
-        processes = []
-        for i in range(num_workers):
-            env = os.environ.copy()
-            env["CUDA_VISIBLE_DEVICES"] = str(visible_devices[i % len(visible_devices)])
-            args = [
-                "python",
-                __file__,
-                "--rank",
-                str(i),
-                "--world-size",
-                str(num_workers),
-                "--language",
-                language,
-                folder,
-            ]
-            processes.append(
-                sp.Popen(
-                    args,
-                    env=env,
-                )
-            )
-
-        for p in processes:
-            p.wait()
-
-        logger.info(f"{RANK_STR}All workers finished")
-        return
-
-    # This is a worker
-    RANK_STR = f"[Rank: {rank}] "
-    logger.info(f"{RANK_STR}Starting worker")
-
-    files = [
-        str(file)
-        for file in Path(folder).rglob("*")
-        if file.suffix in [".wav", ".flac"]
-    ]
-
-    logger.info(f"{RANK_STR}Found {len(files)} files")
-
-    files = sorted(files)
-    Random(42).shuffle(files)
-    files = files[rank::world_size]
-    logger.info(f"{RANK_STR}Processing {len(files)} files")
-
-    # Batch size 64
-    total_time = 0
-    begin_time = time.time()
-    processed_files = 0
-
-    for n_batch, idx in enumerate(range(0, len(files), 64)):
-        batch = files[idx : idx + 64]
-        trascriptions, batch_time = transcribe_batch(batch, language)
-        total_time += batch_time
-        processed_files += len(batch)
-
-        if (n_batch + 1) % 10 == 0:
-            eta = (
-                (time.time() - begin_time)
-                / processed_files
-                * (len(files) - processed_files)
-            )
-            logger.info(
-                f"{RANK_STR}Processed {processed_files} files, {total_time / 3600:.2f} hours of audio, ETA: {timedelta(seconds=round(eta))}s"
-            )
-
-        # Write to file
-        for file, transcription in zip(batch, trascriptions):
-            Path(file).with_suffix(".whisper.txt").write_text(
-                transcription, encoding="utf-8"
-            )
-
-        # Stop if total time is more than 1000 / world_size hours
-        if total_time > 1000 / world_size * 3600:
-            break
-
-    logger.info(
-        f"{RANK_STR}Finished processing {len(files)} files, {total_time / 3600:.2f} hours of audio"
-    )
+        for segment_file in tqdm(
+            list(save_path.glob("*.wav")), desc="Transcribing segments"
+        ):
+            text, _ = transcribe_segment(model, segment_file)
+            with open(segment_file.with_suffix(".lab"), "w", encoding="utf-8") as f:
+                f.write(text)
+
+    # process_output(save_dir, language, save_dir / "output.txt")  # no need to merge transcripts into one file
 
 
 if __name__ == "__main__":
-    main()
+    parser = argparse.ArgumentParser(description="Audio Transcription with Whisper")
+    parser.add_argument(
+        "--model_size", type=str, default="large", help="Size of the Whisper model"
+    )
+    parser.add_argument(
+        "--audio_dir", type=str, required=True, help="Directory containing audio files"
+    )
+    parser.add_argument(
+        "--save_dir",
+        type=str,
+        required=True,
+        help="Directory to save processed audio files",
+    )
+    parser.add_argument(
+        "--language", type=str, default="ZH", help="Language of the transcription"
+    )
+    parser.add_argument("--out_sr", type=int, default=44100, help="Output sample rate")
+    args = parser.parse_args()
+
+    main(args.model_size, args.audio_dir, args.save_dir, args.out_sr, args.language)
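
A usage sketch for the reworked script (paths are illustrative and follow the docstring's layout; --model_size medium is only an example, the default being large):

    # Segment one speaker's recordings at Whisper's timestamps, writing
    # data/SP_1/<clip>/<clip>_<i>.wav plus a matching .lab transcript per segment
    python tools/whisper_asr.py --audio_dir raw_data/SP_1 --save_dir data/SP_1 --model_size medium --out_sr 44100

When run inside the compose container above, raw_data and data resolve to the host directories mounted by docker-compose.dev.yml.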