# prepare_dataset.py
import json
import os
from pathlib import Path

import librosa
import torch
from datasets import Dataset
from multiprocess import set_start_method
from transformers import AutoProcessor, EncodecModel

# Force the "spawn" start method for the dataset.map worker processes —
# presumably so each worker gets a fresh CUDA context (fork + CUDA breaks);
# force=True overrides any start method set earlier in the process.
set_start_method("spawn", force=True)

# EnCodec 24 kHz neural codec: the processor prepares raw audio, the model
# turns it into discrete code indices. eval() disables dropout/batch-norm
# updates for deterministic encoding.
encodec_name = "facebook/encodec_24khz"
encodec_processor = AutoProcessor.from_pretrained(encodec_name)
encodec_model = EncodecModel.from_pretrained(encodec_name)
encodec_model.eval()
  14. def tokenize(text, audio, sr=None, speaker=None):
  15. assert sr is None or sr == encodec_processor.sampling_rate
  16. if isinstance(audio, (str, Path)):
  17. audio, sr = librosa.load(audio, sr=sr, mono=True)
  18. prompt = "[INST] "
  19. if speaker:
  20. prompt += f"[SPK] {speaker} [/SPK] "
  21. prompt += f"{text} [/INST] "
  22. inputs = encodec_processor(
  23. raw_audio=audio, sampling_rate=sr, return_tensors="pt"
  24. ).to(encodec_model.device)
  25. outputs = encodec_model.encode(
  26. inputs["input_values"], inputs["padding_mask"], bandwidth=1.5, return_dict=True
  27. )
  28. assert outputs.audio_codes.dim() == 4 # [batch, channel, codebook, code]
  29. assert outputs.audio_codes.shape[0] == outputs.audio_codes.shape[1] == 1
  30. codes = outputs.audio_codes[0, 0, 0, :].long()
  31. codes_str = " ".join([f"<encodec_{int(c)}>" for c in codes.tolist()])
  32. prompt += codes_str
  33. return {
  34. "prompt": prompt,
  35. "codes": codes,
  36. }
  37. def wrap_tokenize(x):
  38. device = torch.device("cuda", 0)
  39. if encodec_model.device != device:
  40. encodec_model.to(device)
  41. return tokenize(
  42. text=x["text"],
  43. audio=x["raw_path"],
  44. sr=encodec_processor.sampling_rate,
  45. speaker=x["speaker"],
  46. )
  47. def generator_libritts_r():
  48. base = Path("dataset/tts/LibriTTS_R")
  49. for i in base.rglob("*.wav"):
  50. text_file = i.with_suffix(".normalized.txt")
  51. if not text_file.exists():
  52. continue
  53. text = text_file.read_text().strip()
  54. yield {
  55. "text": text,
  56. "speaker": f"libritts_{i.parent.parent.name}",
  57. "raw_path": str(i),
  58. "path": str(i.relative_to(base)),
  59. }
if __name__ == "__main__":
    # Enumerate raw examples, then encode audio in 12 parallel worker
    # processes (spawn start method was forced at import time).
    dataset = Dataset.from_generator(generator_libritts_r)
    dataset = dataset.map(wrap_tokenize, num_proc=12)
    # Absolute audio paths are host-specific; drop them before publishing.
    dataset = dataset.remove_columns(["raw_path"])
    dataset.save_to_disk("dataset/tts/libritts-r-encodec")
    dataset.push_to_hub("fishaudio/libritts-r-encodec", private=True)