# vqgan.py

from dataclasses import dataclass
from pathlib import Path

import librosa
import numpy as np
import torch
from lightning import LightningDataModule
from torch.utils.data import DataLoader, Dataset

class VQGANDataset(Dataset):
    def __init__(
        self,
        filelist: str,
        sample_rate: int = 32000,
    ):
        super().__init__()

        # Each line in the filelist is a path relative to the filelist's directory.
        filelist = Path(filelist)
        root = filelist.parent
        self.files = [
            root / line.strip() for line in filelist.read_text().splitlines()
        ]
        self.sample_rate = sample_rate

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        file = self.files[idx]

        # Decode the audio, resampling to the target rate and downmixing to mono.
        audio, _ = librosa.load(file, sr=self.sample_rate, mono=True)

        # Precomputed frame-level features stored next to the audio as (T, 1024).
        features = np.load(file.with_suffix(".npy"))

        return {
            "audio": torch.from_numpy(audio),
            "features": torch.from_numpy(features),
        }
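
# On-disk layout assumed by VQGANDataset (illustrative names, not from the
# repo): the filelist contains lines like "train/1234_000001.wav", resolved
# relative to the filelist's own directory, and each clip must have a sibling
# "train/1234_000001.npy" holding its (T, 1024) feature matrix.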

@dataclass
class VQGANCollator:
    hop_length: int = 640

    def __call__(self, batch):
        audio_lengths = torch.tensor([len(x["audio"]) for x in batch])
        feature_lengths = torch.tensor([len(x["features"]) for x in batch])

        audio_maxlen = audio_lengths.max()
        feature_maxlen = feature_lengths.max()

        # Round the padded audio length up to a multiple of hop_length so the
        # waveform stays aligned with the frame-level features.
        if audio_maxlen % self.hop_length != 0:
            audio_maxlen += self.hop_length - (audio_maxlen % self.hop_length)

        audios, features = [], []
        for x in batch:
            # Right-pad the waveform with zeros up to the padded batch maximum.
            audios.append(
                torch.nn.functional.pad(x["audio"], (0, audio_maxlen - len(x["audio"])))
            )
            # Pad only the time axis of the (T, 1024) feature matrix.
            features.append(
                torch.nn.functional.pad(
                    x["features"], (0, 0, 0, feature_maxlen - len(x["features"]))
                )
            )

        return {
            "audios": torch.stack(audios),
            "features": torch.stack(features),
            "audio_lengths": audio_lengths,
            "feature_lengths": feature_lengths,
        }
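
# Worked example of the hop alignment above, using the defaults in this file
# (sample_rate=32000, hop_length=640, i.e. 50 feature frames per second): if
# the longest clip in a batch is 32,320 samples, the batch is padded to
# 32,640 samples, which is exactly 32,640 / 640 = 51 frames, so the padded
# audio and the padded (T, 1024) features stay frame-aligned.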

class VQGANDataModule(LightningDataModule):
    def __init__(
        self,
        train_dataset: VQGANDataset,
        val_dataset: VQGANDataset,
        batch_size: int = 32,
        hop_length: int = 640,
        num_workers: int = 4,
    ):
        super().__init__()

        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.batch_size = batch_size
        self.hop_length = hop_length
        self.num_workers = num_workers

    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            collate_fn=VQGANCollator(self.hop_length),
            num_workers=self.num_workers,
            shuffle=True,
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            collate_fn=VQGANCollator(self.hop_length),
            num_workers=self.num_workers,
        )
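
# Sketch of how this DataModule would typically be wired into training
# (`VQGANModel` is a hypothetical LightningModule, not defined in this file;
# the val filelist path is likewise illustrative):
#
#   datamodule = VQGANDataModule(
#       train_dataset=VQGANDataset("data/LibriTTS_R/vq_train_filelist.txt"),
#       val_dataset=VQGANDataset("data/LibriTTS_R/vq_val_filelist.txt"),
#       batch_size=32,
#   )
#   trainer = lightning.Trainer(max_steps=100_000)
#   trainer.fit(VQGANModel(), datamodule=datamodule)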

if __name__ == "__main__":
    # Smoke test: load one batch and check the collated shapes.
    dataset = VQGANDataset("data/LibriTTS_R/vq_train_filelist.txt")
    dataloader = DataLoader(
        dataset, batch_size=4, shuffle=False, collate_fn=VQGANCollator()
    )

    for batch in dataloader:
        print(batch["audios"].shape)
        print(batch["features"].shape)
        print(batch["audio_lengths"])
        print(batch["feature_lengths"])
        break
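
    # Expected output: audios is (4, L) with L a multiple of hop_length (640),
    # features is (4, T, 1024), followed by the raw per-clip lengths; exact
    # sizes depend on the clips listed in the filelist.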