# clean_wenet_speech.py (3.6 KB)
  1. import json
  2. from pathlib import Path
  3. import subprocess
  4. import librosa
  5. import soundfile as sf
  6. import torch
  7. import torchaudio
  8. from fish_audio_preprocess.utils.separate_audio import (
  9. separate_audio,
  10. merge_tracks,
  11. init_model,
  12. )
  13. from tqdm import tqdm
  14. import time
  15. import os
  16. import tempfile
# SLURM-provided task identity; the defaults let the script run standalone
# as a single process (rank 0 of 1).
rank = int(os.environ.get("SLURM_PROCID", 0))
world_size = int(os.environ.get("SLURM_NTASKS", 1))
# NOTE(review): every rank targets cuda:0 — assumes exactly one GPU is
# visible per task (e.g. via CUDA_VISIBLE_DEVICES); confirm on multi-GPU nodes.
device = torch.device("cuda:0")
print(f"Rank {rank}/{world_size} on {device}")
  21. def main():
  22. meta_path = Path("dataset/tts/WenetSpeech/WenetSpeech.json")
  23. dataset_path = Path("dataset/tts/WenetSpeech")
  24. cleaned_path = Path("dataset/tts/WenetSpeech/cleaned")
  25. if not cleaned_path.exists():
  26. cleaned_path.mkdir(parents=True)
  27. demucs = init_model("htdemucs", device)
  28. print("Model loaded")
  29. with open(meta_path) as f:
  30. dataset = json.load(f)["audios"]
  31. print(f"Dataset loaded, {len(dataset)} samples")
  32. dataset = dataset[rank::world_size]
  33. print(f"Dataset split, {len(dataset)} samples")
  34. for data_idx, data in enumerate(dataset):
  35. done_path = cleaned_path / data["aid"] / "done"
  36. done_path.parent.mkdir(parents=True, exist_ok=True)
  37. if done_path.exists():
  38. continue
  39. print(f"Processing {data_idx}/{len(dataset)} at rank {rank}")
  40. try:
  41. with tempfile.NamedTemporaryFile(suffix=".wav") as f:
  42. subprocess.check_call(
  43. [
  44. "ffmpeg",
  45. "-y",
  46. "-i",
  47. str(dataset_path / data["path"]),
  48. "-c:a",
  49. "pcm_s16le",
  50. "-threads",
  51. "0",
  52. "-ar",
  53. "24000",
  54. str(f.name),
  55. ],
  56. stdout=subprocess.DEVNULL,
  57. stderr=subprocess.DEVNULL,
  58. )
  59. raw_audio, sr = librosa.load(f.name, sr=None, mono=True)
  60. raw_audio = torch.from_numpy(raw_audio[None]).to(device)
  61. audio = torchaudio.functional.resample(
  62. raw_audio, orig_freq=sr, new_freq=demucs.samplerate
  63. )
  64. # Make it 2 channels
  65. audio = torch.cat([audio, audio], dim=0)
  66. tracks = separate_audio(demucs, audio, shifts=1, num_workers=0, progress=False)
  67. audio = merge_tracks(tracks, filter=["vocals"])[0]
  68. vocals, sr = (
  69. torchaudio.functional.resample(
  70. audio, orig_freq=demucs.samplerate, new_freq=24000
  71. ),
  72. 24000,
  73. )
  74. vocals = vocals.cpu().numpy()
  75. for idx, segment in enumerate(data["segments"]):
  76. if segment["confidence"] <= 0.95:
  77. continue
  78. # Load audio
  79. begin = int(segment["begin_time"] * sr)
  80. end = int(segment["end_time"] * sr)
  81. segment_audio = vocals[begin:end]
  82. # Write audio
  83. temp_path = cleaned_path / data["aid"] / f"S{idx:05d}.wav"
  84. temp_path.parent.mkdir(parents=True, exist_ok=True)
  85. sf.write(temp_path, segment_audio, samplerate=sr)
  86. # Write text
  87. temp_path = temp_path.with_suffix(".txt")
  88. temp_path.write_text(segment["text"])
  89. # Write done file
  90. done_path.write_text("")
  91. except Exception as e:
  92. print(f"Error {e} on {data_idx}/{len(dataset)} at rank {rank}")
  93. time.sleep(10)
  94. continue
  95. print("Done")
# Entry point when invoked as a script (e.g. one process per SLURM task).
if __name__ == "__main__":
    main()