vqgan.py

from dataclasses import dataclass
from pathlib import Path
from typing import Optional

import librosa
import numpy as np
import torch
from lightning import LightningDataModule
from torch.utils.data import DataLoader, Dataset

from fish_speech.utils import RankedLogger

logger = RankedLogger(__name__, rank_zero_only=False)
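
# Filelist format (inferred from the loading code below): each line holds a
# path, relative to the filelist's own directory, to an audio file, and a
# pre-computed feature matrix is expected next to it as <stem>.npy.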
class VQGANDataset(Dataset):
    def __init__(
        self,
        filelist: str,
        sample_rate: int = 32000,
        hop_length: int = 640,
        slice_frames: Optional[int] = None,
    ):
        super().__init__()

        filelist = Path(filelist)
        root = filelist.parent

        self.files = [
            root / line.strip() for line in filelist.read_text().splitlines()
        ]
        self.sample_rate = sample_rate
        self.hop_length = hop_length
        self.slice_frames = slice_frames

    def __len__(self):
        return len(self.files)

    def get_item(self, idx):
        file = self.files[idx]

        audio, _ = librosa.load(file, sr=self.sample_rate, mono=True)
        features = np.load(file.with_suffix(".npy"))  # (T, 1024)

        # Slice audio and features
        if self.slice_frames is not None and features.shape[0] > self.slice_frames:
            start = np.random.randint(0, features.shape[0] - self.slice_frames)
            features = features[start : start + self.slice_frames]
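
            # Assumption: the hard-coded 320 / 16000 below implies the
            # features were extracted from 16 kHz audio with a hop of 320
            # samples (20 ms per frame); adjust both constants if your
            # feature extractor uses a different rate or hop.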
            start_in_seconds, end_in_seconds = (
                start * 320 / 16000,
                (start + self.slice_frames) * 320 / 16000,
            )
            audio = audio[
                int(start_in_seconds * self.sample_rate) : int(
                    end_in_seconds * self.sample_rate
                )
            ]

        if len(audio) == 0:
            return None
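
        # Peak-normalize only clips that exceed [-1, 1]; quieter audio is
        # left at its original level.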
        max_value = np.abs(audio).max()
        if max_value > 1.0:
            audio = audio / max_value

        return {
            "audio": torch.from_numpy(audio),
            "features": torch.from_numpy(features),
        }

    def __getitem__(self, idx):
        try:
            return self.get_item(idx)
        except Exception as e:
            logger.error(f"Error loading {self.files[idx]}: {e}")
            return None


@dataclass
class VQGANCollator:
    def __call__(self, batch):
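        # Drop samples that failed to load (__getitem__ returns None on error)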
        batch = [x for x in batch if x is not None]

        audio_lengths = torch.tensor([len(x["audio"]) for x in batch])
        feature_lengths = torch.tensor([len(x["features"]) for x in batch])
        audio_maxlen = audio_lengths.max()
        feature_maxlen = feature_lengths.max()

        # Pad every sample to the longest audio / feature length in the batch
        audios, features = [], []
        for x in batch:
            audios.append(
                torch.nn.functional.pad(x["audio"], (0, audio_maxlen - len(x["audio"])))
            )
            features.append(
                # Pad along the time axis (dim 0): for a 2-D (T, 1024) tensor,
                # the last pair in the pad tuple applies to the first dimension
                torch.nn.functional.pad(
                    x["features"], (0, 0, 0, feature_maxlen - len(x["features"]))
                )
            )

        return {
            "audios": torch.stack(audios),
            "features": torch.stack(features),
            "audio_lengths": audio_lengths,
            "feature_lengths": feature_lengths,
        }


class VQGANDataModule(LightningDataModule):
    def __init__(
        self,
        train_dataset: VQGANDataset,
        val_dataset: VQGANDataset,
        batch_size: int = 32,
        num_workers: int = 4,
        val_batch_size: Optional[int] = None,
    ):
        super().__init__()

        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.batch_size = batch_size
        self.val_batch_size = val_batch_size or batch_size
        self.num_workers = num_workers

    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            collate_fn=VQGANCollator(),
            num_workers=self.num_workers,
            shuffle=True,
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_dataset,
            # Use the dedicated validation batch size (falls back to the
            # training batch size when not set)
            batch_size=self.val_batch_size,
            collate_fn=VQGANCollator(),
            num_workers=self.num_workers,
        )


if __name__ == "__main__":
    dataset = VQGANDataset("data/LibriTTS_R/vq_train_filelist.txt")
    dataloader = DataLoader(
        dataset, batch_size=4, shuffle=False, collate_fn=VQGANCollator()
    )

    for batch in dataloader:
        print(batch["audios"].shape)
        print(batch["features"].shape)
        print(batch["audio_lengths"])
        print(batch["feature_lengths"])
        break
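
# A minimal sketch of wiring the datamodule into a Lightning Trainer, assuming
# a LightningModule `model` defined elsewhere; the validation filelist path
# here is hypothetical:
#
#     datamodule = VQGANDataModule(
#         train_dataset=VQGANDataset("data/LibriTTS_R/vq_train_filelist.txt"),
#         val_dataset=VQGANDataset("data/LibriTTS_R/vq_val_filelist.txt"),
#         batch_size=32,
#         num_workers=4,
#     )
#     trainer = lightning.Trainer(max_epochs=1)
#     trainer.fit(model, datamodule=datamodule)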