# vqgan.py — VQGAN audio dataset, collator, and Lightning data module.
import traceback
from dataclasses import dataclass
from pathlib import Path
from typing import Optional

import librosa
import numpy as np
import torch
from lightning import LightningDataModule
from torch.utils.data import DataLoader, Dataset

from fish_speech.utils import RankedLogger
  10. logger = RankedLogger(__name__, rank_zero_only=False)
  11. class VQGANDataset(Dataset):
  12. def __init__(
  13. self,
  14. filelist: str,
  15. sample_rate: int = 32000,
  16. hop_length: int = 640,
  17. slice_frames: Optional[int] = None,
  18. ):
  19. super().__init__()
  20. filelist = Path(filelist)
  21. root = filelist.parent
  22. self.files = [
  23. root / line.strip()
  24. for line in filelist.read_text(encoding="utf-8").splitlines()
  25. if line.strip()
  26. ]
  27. self.sample_rate = sample_rate
  28. self.hop_length = hop_length
  29. self.slice_frames = slice_frames
  30. def __len__(self):
  31. return len(self.files)
  32. def get_item(self, idx):
  33. file = self.files[idx]
  34. audio, _ = librosa.load(file, sr=self.sample_rate, mono=True)
  35. # Slice audio and features
  36. if (
  37. self.slice_frames is not None
  38. and audio.shape[0] > self.slice_frames * self.hop_length
  39. ):
  40. start = np.random.randint(
  41. 0, audio.shape[0] - self.slice_frames * self.hop_length
  42. )
  43. audio = audio[start : start + self.slice_frames * self.hop_length]
  44. if len(audio) == 0:
  45. return None
  46. max_value = np.abs(audio).max()
  47. if max_value > 1.0:
  48. audio = audio / max_value
  49. return {
  50. "audio": torch.from_numpy(audio),
  51. }
  52. def __getitem__(self, idx):
  53. try:
  54. return self.get_item(idx)
  55. except Exception as e:
  56. import traceback
  57. traceback.print_exc()
  58. logger.error(f"Error loading {self.files[idx]}: {e}")
  59. return None
  60. @dataclass
  61. class VQGANCollator:
  62. def __call__(self, batch):
  63. batch = [x for x in batch if x is not None]
  64. audio_lengths = torch.tensor([len(x["audio"]) for x in batch])
  65. audio_maxlen = audio_lengths.max()
  66. # Rounds up to nearest multiple of 2 (audio_lengths)
  67. audios = []
  68. for x in batch:
  69. audios.append(
  70. torch.nn.functional.pad(x["audio"], (0, audio_maxlen - len(x["audio"])))
  71. )
  72. return {
  73. "audios": torch.stack(audios),
  74. "audio_lengths": audio_lengths,
  75. }
  76. class VQGANDataModule(LightningDataModule):
  77. def __init__(
  78. self,
  79. train_dataset: VQGANDataset,
  80. val_dataset: VQGANDataset,
  81. batch_size: int = 32,
  82. num_workers: int = 4,
  83. val_batch_size: Optional[int] = None,
  84. ):
  85. super().__init__()
  86. self.train_dataset = train_dataset
  87. self.val_dataset = val_dataset
  88. self.batch_size = batch_size
  89. self.val_batch_size = val_batch_size or batch_size
  90. self.num_workers = num_workers
  91. def train_dataloader(self):
  92. return DataLoader(
  93. self.train_dataset,
  94. batch_size=self.batch_size,
  95. collate_fn=VQGANCollator(),
  96. num_workers=self.num_workers,
  97. shuffle=True,
  98. persistent_workers=True,
  99. )
  100. def val_dataloader(self):
  101. return DataLoader(
  102. self.val_dataset,
  103. batch_size=self.val_batch_size,
  104. collate_fn=VQGANCollator(),
  105. num_workers=self.num_workers,
  106. persistent_workers=True,
  107. )
  108. if __name__ == "__main__":
  109. dataset = VQGANDataset("data/LibriTTS_R/vq_train_filelist.txt")
  110. dataloader = DataLoader(
  111. dataset, batch_size=4, shuffle=False, collate_fn=VQGANCollator()
  112. )
  113. for batch in dataloader:
  114. print(batch["audios"].shape)
  115. print(batch["features"].shape)
  116. print(batch["audio_lengths"])
  117. print(batch["feature_lengths"])
  118. break