vits.py
import random
from dataclasses import dataclass
from pathlib import Path
from typing import Optional

import librosa
import numpy as np
import torch
from lightning import LightningDataModule
from torch.utils.data import DataLoader, Dataset
from transformers import AutoTokenizer

from fish_speech.utils import RankedLogger

logger = RankedLogger(__name__, rank_zero_only=False)
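

# Loads (audio, transcript) pairs from a filelist: each line is an audio path
# relative to the filelist's directory, and the transcript lives in a sibling
# file with the same stem and the configured suffix (".lab" by default).
# Audio shorter than min_duration is zero-padded; longer audio is randomly
# cropped to max_duration.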
class VITSDataset(Dataset):
    def __init__(
        self,
        filelist: str,
        tokenizer: AutoTokenizer,
        sample_rate: int = 44100,
        hop_length: int = 512,
        min_duration: float = 1.5,
        max_duration: float = 30.0,
        suffix: str = ".lab",
        sentence_mask_ratio: float = 0.0,
    ):
        super().__init__()

        filelist = Path(filelist)
        root = filelist.parent

        self.files = [
            root / line for line in filelist.read_text(encoding="utf-8").splitlines()
        ]

        self.sample_rate = sample_rate
        self.hop_length = hop_length
        self.min_duration = min_duration
        self.max_duration = max_duration
        self.tokenizer = tokenizer
        self.suffix = suffix
        self.sentence_mask_ratio = sentence_mask_ratio

    def __len__(self):
        return len(self.files)

    def get_item(self, idx):
        audio_file = self.files[idx]
        text_file = audio_file.with_suffix(self.suffix)

        if not text_file.exists() or not audio_file.exists():
            return None

        audio, _ = librosa.load(audio_file, sr=self.sample_rate, mono=True)
        duration = len(audio) / self.sample_rate

        # Pad to minimum duration
        if duration < self.min_duration:
            pad_duration = self.min_duration - duration
            pad_samples = int(pad_duration * self.sample_rate)
            audio = np.pad(audio, (0, pad_samples))

        # Truncate to maximum duration with a random crop
        if duration > self.max_duration:
            max_samples = int(self.max_duration * self.sample_rate)
            random_start = random.randint(0, len(audio) - max_samples - 1)
            audio = audio[random_start : random_start + max_samples]

        # Peak-normalize only if the waveform clips
        max_value = np.abs(audio).max()
        if max_value > 1.0:
            audio = audio / max_value

        # Randomly mask the transcript with a placeholder token ("-")
        if random.random() < self.sentence_mask_ratio:
            text = "-"
        else:
            text = text_file.read_text(encoding="utf-8")

        input_ids = self.tokenizer(text, return_tensors="pt").input_ids.squeeze(0)

        return {
            "audio": torch.from_numpy(audio),
            "text": input_ids,
        }

    def __getitem__(self, idx):
        try:
            return self.get_item(idx)
        except Exception as e:
            import traceback

            traceback.print_exc()
            logger.error(f"Error loading {self.files[idx]}: {e}")
            return None
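

# Pads each field to the longest item in the batch: audio with zeros, text
# with the tokenizer's EOS token. The original lengths are returned alongside
# the padded tensors so the padding can be masked downstream.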
@dataclass
class VITSCollator:
    tokenizer: AutoTokenizer

    def __call__(self, batch):
        # Drop samples that failed to load
        batch = [x for x in batch if x is not None]

        audio_lengths = torch.tensor([len(x["audio"]) for x in batch])
        audio_maxlen = audio_lengths.max()

        text_lengths = torch.tensor([len(x["text"]) for x in batch])
        text_maxlen = text_lengths.max()

        audios = []
        texts = []
        for x in batch:
            audios.append(
                torch.nn.functional.pad(x["audio"], (0, audio_maxlen - len(x["audio"])))
            )
            texts.append(
                torch.nn.functional.pad(
                    x["text"],
                    (0, text_maxlen - len(x["text"])),
                    value=self.tokenizer.eos_token_id,
                )
            )

        return {
            "audios": torch.stack(audios),
            "audio_lengths": audio_lengths,
            "texts": torch.stack(texts),
            "text_lengths": text_lengths,
        }
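

# Thin LightningDataModule that wires the datasets to DataLoaders using the
# collator above. No sampler is constructed here; under a distributed
# strategy, Lightning injects a DistributedSampler by default.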
class VITSDataModule(LightningDataModule):
    def __init__(
        self,
        train_dataset: VITSDataset,
        val_dataset: VITSDataset,
        tokenizer: AutoTokenizer,
        batch_size: int = 32,
        num_workers: int = 4,
        val_batch_size: Optional[int] = None,
    ):
        super().__init__()

        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.batch_size = batch_size
        self.val_batch_size = val_batch_size or batch_size
        self.num_workers = num_workers
        self.tokenizer = tokenizer

    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            collate_fn=VITSCollator(self.tokenizer),
            num_workers=self.num_workers,
            shuffle=False,
            # persistent_workers requires at least one worker
            persistent_workers=self.num_workers > 0,
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_dataset,
            batch_size=self.val_batch_size,
            collate_fn=VITSCollator(self.tokenizer),
            num_workers=self.num_workers,
            persistent_workers=self.num_workers > 0,
        )
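

# Smoke test: load one batch and print the padded tensor shapes. The filelist
# path and tokenizer name below are specific to the original training setup.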
if __name__ == "__main__":
    tokenizer = AutoTokenizer.from_pretrained("fishaudio/fish-speech-1")
    dataset = VITSDataset(
        "data/source/Genshin/filelist.train.txt", tokenizer=tokenizer, suffix=".lab"
    )
    dataloader = DataLoader(
        dataset, batch_size=4, shuffle=False, collate_fn=VITSCollator(tokenizer)
    )

    for batch in dataloader:
        print(batch["audios"].shape)
        print(batch["audio_lengths"])
        print(batch["texts"].shape)
        print(batch["text_lengths"])
        break