# clean_wenet_speech.py
# Clean the WenetSpeech dataset: isolate vocals with Demucs and export
# high-confidence transcript segments as paired 24 kHz .wav / .txt files.
  1. import json
  2. import os
  3. import subprocess
  4. import tempfile
  5. import time
  6. from pathlib import Path
  7. import librosa
  8. import soundfile as sf
  9. import torch
  10. import torchaudio
  11. from fish_audio_preprocess.utils.separate_audio import (
  12. init_model,
  13. merge_tracks,
  14. separate_audio,
  15. )
  16. from tqdm import tqdm
  17. rank = int(os.environ.get("SLURM_PROCID", 0))
  18. world_size = int(os.environ.get("SLURM_NTASKS", 1))
  19. device = torch.device("cuda:0")
  20. print(f"Rank {rank}/{world_size} on {device}")
  21. def main():
  22. meta_path = Path("dataset/tts/WenetSpeech/WenetSpeech.json")
  23. dataset_path = Path("dataset/tts/WenetSpeech")
  24. cleaned_path = Path("dataset/tts/WenetSpeech/cleaned")
  25. if not cleaned_path.exists():
  26. cleaned_path.mkdir(parents=True)
  27. demucs = init_model("htdemucs", device)
  28. print("Model loaded")
  29. with open(meta_path) as f:
  30. dataset = json.load(f)["audios"]
  31. print(f"Dataset loaded, {len(dataset)} samples")
  32. dataset = dataset[rank::world_size]
  33. print(f"Dataset split, {len(dataset)} samples")
  34. for data_idx, data in enumerate(dataset):
  35. done_path = cleaned_path / data["aid"] / "done"
  36. done_path.parent.mkdir(parents=True, exist_ok=True)
  37. if done_path.exists():
  38. continue
  39. print(f"Processing {data_idx}/{len(dataset)} at rank {rank}")
  40. try:
  41. with tempfile.NamedTemporaryFile(suffix=".wav") as f:
  42. subprocess.check_call(
  43. [
  44. "ffmpeg",
  45. "-y",
  46. "-i",
  47. str(dataset_path / data["path"]),
  48. "-c:a",
  49. "pcm_s16le",
  50. "-threads",
  51. "0",
  52. "-ar",
  53. "24000",
  54. str(f.name),
  55. ],
  56. stdout=subprocess.DEVNULL,
  57. stderr=subprocess.DEVNULL,
  58. )
  59. raw_audio, sr = librosa.load(f.name, sr=None, mono=True)
  60. raw_audio = torch.from_numpy(raw_audio[None]).to(device)
  61. audio = torchaudio.functional.resample(
  62. raw_audio, orig_freq=sr, new_freq=demucs.samplerate
  63. )
  64. # Make it 2 channels
  65. audio = torch.cat([audio, audio], dim=0)
  66. tracks = separate_audio(
  67. demucs, audio, shifts=1, num_workers=0, progress=False
  68. )
  69. audio = merge_tracks(tracks, filter=["vocals"])[0]
  70. vocals, sr = (
  71. torchaudio.functional.resample(
  72. audio, orig_freq=demucs.samplerate, new_freq=24000
  73. ),
  74. 24000,
  75. )
  76. vocals = vocals.cpu().numpy()
  77. for idx, segment in enumerate(data["segments"]):
  78. if segment["confidence"] <= 0.95:
  79. continue
  80. # Load audio
  81. begin = int(segment["begin_time"] * sr)
  82. end = int(segment["end_time"] * sr)
  83. segment_audio = vocals[begin:end]
  84. # Write audio
  85. temp_path = cleaned_path / data["aid"] / f"S{idx:05d}.wav"
  86. temp_path.parent.mkdir(parents=True, exist_ok=True)
  87. sf.write(temp_path, segment_audio, samplerate=sr)
  88. # Write text
  89. temp_path = temp_path.with_suffix(".txt")
  90. temp_path.write_text(segment["text"])
  91. # Write done file
  92. done_path.write_text("")
  93. except Exception as e:
  94. print(f"Error {e} on {data_idx}/{len(dataset)} at rank {rank}")
  95. time.sleep(10)
  96. continue
  97. print("Done")
# Script entry point: run the cleaning pipeline only when executed
# directly (not on import).
if __name__ == "__main__":
    main()