"""smart_pad.py — normalize silence at the edges of audio files, in place.

Each file gets its leading silence trimmed down to ~0.02 s and, when its
trailing silence is 0.3 s or shorter, padded so it ends with a random
0.3–0.7 s of quiet.
"""

import random
from multiprocessing import Pool
from pathlib import Path

import click
import librosa
import torch.nn.functional as F
import torchaudio
from tqdm import tqdm

from tools.file import AUDIO_EXTENSIONS, list_files

# RMS amplitude equivalent of -50 dBFS; frames quieter than this are "silence".
threshold = 10 ** (-50 / 20.0)
  11. def process(file):
  12. waveform, sample_rate = torchaudio.load(str(file), backend="sox")
  13. if waveform.size(0) > 1:
  14. waveform = waveform.mean(dim=0, keepdim=True)
  15. loudness = librosa.feature.rms(
  16. y=waveform.numpy().squeeze(), frame_length=2048, hop_length=512, center=True
  17. )[0]
  18. for i in range(len(loudness) - 1, 0, -1):
  19. if loudness[i] > threshold:
  20. break
  21. end_silent_time = (len(loudness) - i) * 512 / sample_rate
  22. if end_silent_time <= 0.3:
  23. random_time = random.uniform(0.3, 0.7) - end_silent_time
  24. waveform = F.pad(
  25. waveform, (0, int(random_time * sample_rate)), mode="constant", value=0
  26. )
  27. for i in range(len(loudness)):
  28. if loudness[i] > threshold:
  29. break
  30. start_silent_time = i * 512 / sample_rate
  31. if start_silent_time > 0.02:
  32. waveform = waveform[:, int((start_silent_time - 0.02) * sample_rate) :]
  33. torchaudio.save(uri=str(file), src=waveform, sample_rate=sample_rate)
  34. @click.command()
  35. @click.argument("source", type=Path)
  36. @click.option("--num-workers", type=int, default=12)
  37. def main(source, num_workers):
  38. files = list(list_files(source, AUDIO_EXTENSIONS, recursive=True))
  39. with Pool(num_workers) as p:
  40. list(tqdm(p.imap_unordered(process, files), total=len(files)))
  41. if __name__ == "__main__":
  42. main()