build_vq_dataset.py

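# Builds an instruction-style text dataset pairing each clip's transcript (and
# its phoneme sequence) with HuBERT VQ semantic tokens, then pushes it to the
# Hugging Face Hub.
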
from functools import lru_cache
from pathlib import Path

import numpy as np
from datasets import Dataset, DatasetDict


@lru_cache(maxsize=1)
def get_phonemes():
    """Load and merge the phoneme lookup tables (item name -> phoneme string)."""
    phones = {}
    phones.update(np.load("dump/phoneme_dev.npy", allow_pickle=True).item())
    phones.update(np.load("dump/phoneme_train.npy", allow_pickle=True).item())
    # Entries from the second dump directory take precedence on duplicate keys.
    phones.update(
        np.load(
            "/home/fish/hubert-vq-vits/dump/phoneme_dev.npy", allow_pickle=True
        ).item()
    )
    phones.update(
        np.load(
            "/home/fish/hubert-vq-vits/dump/phoneme_train.npy", allow_pickle=True
        ).item()
    )
    print("Loaded phonemes")
    return phones
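

# Note: lru_cache caches per process, so with `num_proc=32` below each worker
# loads the phoneme tables once rather than once per batch.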


def parse_data(items):
    """Convert a batch of rows into "[INST] prompt [/INST] semantic tokens" strings."""
    results = []
    phones = get_phonemes()

    for item_name, semantic_audio in zip(items["item_name"], items["semantic_audio"]):
        file_name = item_name
        if item_name.startswith("/wenet-speech-vocals"):
            file_name = "/home/fish/wenetspeech/dsall" + item_name

        # The transcript lives next to the audio as .txt, falling back to .lab.
        wav_file = Path(file_name)
        text_file = wav_file.with_suffix(".txt")
        if not text_file.exists():
            text_file = wav_file.with_suffix(".lab")

        if not text_file.exists():
            # Skip items without a transcript; returning None here would make
            # the batched `map` fail instead of dropping the item.
            print(f"Missing {text_file}")
            continue

        text = text_file.read_text().strip()
        semantic = " ".join(f"<semantic_{x}>" for x in semantic_audio.split(" "))

        # Each item yields two examples: text -> semantic and phonemes -> semantic.
        results.append(f"[INST] {text} [/INST] {semantic} </s>")
        results.append(f"[INST] {phones[item_name]} [/INST] {semantic} </s>")

    return {
        "text": results,
    }
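
# For reference, each produced training string has the form below (the
# semantic token ids are made up for illustration):
#
#   [INST] <transcript or phoneme sequence> [/INST] <semantic_12> <semantic_7> ... </s>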


if __name__ == "__main__":
    # Dev TSVs from both dump directories form the test split.
    test_dataset = Dataset.from_csv(
        ["dump/semantic_dev.tsv", "/home/fish/hubert-vq-vits/dump/semantic_dev.tsv"],
        delimiter="\t",
        split="test",
    )
    test_dataset = test_dataset.map(
        parse_data,
        num_proc=32,
        remove_columns=test_dataset.column_names,
        batched=True,
        batch_size=10000,
    )

    train_dataset = Dataset.from_csv(
        [
            "dump/semantic_train.tsv",
            "/home/fish/hubert-vq-vits/dump/semantic_train.tsv",
        ],
        delimiter="\t",
        split="train",
    )
    train_dataset = train_dataset.map(
        parse_data,
        num_proc=32,
        remove_columns=train_dataset.column_names,
        batched=True,
        batch_size=10000,
    )

    dataset = DatasetDict(
        {
            "train": train_dataset,
            "test": test_dataset,
        }
    )

    print(
        f"There are {len(dataset['train'])} training examples "
        f"and {len(dataset['test'])} test examples"
    )
    print(dataset["train"][0])
    print(dataset["test"][1])

    dataset.push_to_hub("fishaudio/cn-hubert-25hz-vq", private=True)
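
# Sketch: loading the pushed dataset back (assumes you are authenticated to
# the Hub, e.g. via `huggingface-cli login`, since the repo is private):
#
#   from datasets import load_dataset
#
#   dataset = load_dataset("fishaudio/cn-hubert-25hz-vq")
#   print(dataset["train"][0]["text"])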