# build_vq_text.py

  1. from functools import partial
  2. from pathlib import Path
  3. import numpy as np
  4. from datasets import Dataset
  5. def parse_data(phones, items):
  6. results = []
  7. for item_name, semantic_audio in zip(items["item_name"], items["semantic_audio"]):
  8. wav_file = Path(item_name)
  9. text_file = wav_file.with_suffix(".txt")
  10. if not text_file.exists():
  11. text_file = wav_file.with_suffix(".lab")
  12. if not text_file.exists():
  13. print(f"Missing {text_file}")
  14. return None
  15. text = text_file.read_text().strip()
  16. semantic = [f"<semantic_{x}>" for x in semantic_audio.split(" ")]
  17. semantic = " ".join(semantic)
  18. results.append(f"[INST] {text} [/INST] {semantic} </s>")
  19. results.append(f"[INST] {phones[item_name]} [/INST] {semantic} </s>")
  20. return {
  21. "text": results,
  22. }
  23. if __name__ == "__main__":
  24. phones = np.load("dump/phoneme_train.npy", allow_pickle=True).item()
  25. phones1 = np.load(
  26. "/home/fish/hubert-vq-vits/dump/phoneme_train.npy", allow_pickle=True
  27. ).item()
  28. phones.update(phones1)
  29. print(len(phones))
  30. dataset = Dataset.from_csv(
  31. [
  32. "dump/semantic_train.tsv",
  33. "/home/fish/hubert-vq-vits/dump/semantic_train.tsv",
  34. ],
  35. delimiter="\t",
  36. split="train",
  37. )
  38. dataset = dataset.map(
  39. partial(parse_data, phones),
  40. num_proc=32,
  41. remove_columns=dataset.column_names,
  42. batched=True,
  43. )
  44. print(len(dataset), dataset[0])
  45. dataset.push_to_hub("fishaudio/cn-hubert-25hz-vq", private=True)