# vits.py
from dataclasses import dataclass
from pathlib import Path
from typing import Optional

import librosa
import numpy as np
import torch
import torch.distributed as dist
from lightning import LightningDataModule
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.distributed import DistributedSampler
from transformers import AutoTokenizer

from fish_speech.utils import RankedLogger

logger = RankedLogger(__name__, rank_zero_only=False)
  14. class VITSDataset(Dataset):
  15. def __init__(
  16. self,
  17. filelist: str,
  18. tokenizer: AutoTokenizer,
  19. sample_rate: int = 44100,
  20. hop_length: int = 512,
  21. min_duration: float = 1.5,
  22. max_duration: float = 30.0,
  23. suffix: str = ".lab",
  24. ):
  25. super().__init__()
  26. filelist = Path(filelist)
  27. root = filelist.parent
  28. self.files = []
  29. for line in filelist.read_text(encoding="utf-8").splitlines():
  30. path = root / line
  31. self.files.append(path)
  32. self.sample_rate = sample_rate
  33. self.hop_length = hop_length
  34. self.min_duration = min_duration
  35. self.max_duration = max_duration
  36. self.tokenizer = tokenizer
  37. self.suffix = suffix
  38. def __len__(self):
  39. return len(self.files)
  40. def get_item(self, idx):
  41. audio_file = self.files[idx]
  42. text_file = audio_file.with_suffix(self.suffix)
  43. if text_file.exists() is False or audio_file.exists() is False:
  44. return None
  45. audio, _ = librosa.load(audio_file, sr=self.sample_rate, mono=True)
  46. duration = len(audio) / self.sample_rate
  47. if (
  48. len(audio) == 0
  49. or duration < self.min_duration
  50. or duration > self.max_duration
  51. ):
  52. return None
  53. max_value = np.abs(audio).max()
  54. if max_value > 1.0:
  55. audio = audio / max_value
  56. text = text_file.read_text(encoding="utf-8")
  57. input_ids = self.tokenizer(text, return_tensors="pt").input_ids.squeeze(0)
  58. return {
  59. "audio": torch.from_numpy(audio),
  60. "text": input_ids,
  61. }
  62. def __getitem__(self, idx):
  63. try:
  64. return self.get_item(idx)
  65. except Exception as e:
  66. import traceback
  67. traceback.print_exc()
  68. logger.error(f"Error loading {self.files[idx]}: {e}")
  69. return None
  70. @dataclass
  71. class VITSCollator:
  72. tokenizer: AutoTokenizer
  73. def __call__(self, batch):
  74. batch = [x for x in batch if x is not None]
  75. audio_lengths = torch.tensor([len(x["audio"]) for x in batch])
  76. audio_maxlen = audio_lengths.max()
  77. text_lengths = torch.tensor([len(x["text"]) for x in batch])
  78. text_maxlen = text_lengths.max()
  79. # Rounds up to nearest multiple of 2 (audio_lengths)
  80. audios = []
  81. texts = []
  82. for x in batch:
  83. audios.append(
  84. torch.nn.functional.pad(x["audio"], (0, audio_maxlen - len(x["audio"])))
  85. )
  86. texts.append(
  87. torch.nn.functional.pad(
  88. x["text"],
  89. (0, text_maxlen - len(x["text"])),
  90. value=self.tokenizer.eos_token_id,
  91. )
  92. )
  93. return {
  94. "audios": torch.stack(audios),
  95. "audio_lengths": audio_lengths,
  96. "texts": torch.stack(texts),
  97. "text_lengths": text_lengths,
  98. }
  99. class VITSDataModule(LightningDataModule):
  100. def __init__(
  101. self,
  102. train_dataset: VITSDataset,
  103. val_dataset: VITSDataset,
  104. tokenizer: AutoTokenizer,
  105. batch_size: int = 32,
  106. num_workers: int = 4,
  107. val_batch_size: Optional[int] = None,
  108. ):
  109. super().__init__()
  110. self.train_dataset = train_dataset
  111. self.val_dataset = val_dataset
  112. self.batch_size = batch_size
  113. self.val_batch_size = val_batch_size or batch_size
  114. self.num_workers = num_workers
  115. self.tokenizer = tokenizer
  116. def train_dataloader(self):
  117. return DataLoader(
  118. self.train_dataset,
  119. batch_size=self.batch_size,
  120. collate_fn=VITSCollator(self.tokenizer),
  121. num_workers=self.num_workers,
  122. shuffle=False,
  123. persistent_workers=True,
  124. )
  125. def val_dataloader(self):
  126. return DataLoader(
  127. self.val_dataset,
  128. batch_size=self.val_batch_size,
  129. collate_fn=VITSCollator(self.tokenizer),
  130. num_workers=self.num_workers,
  131. persistent_workers=True,
  132. )
  133. if __name__ == "__main__":
  134. tokenizer = AutoTokenizer.from_pretrained("fishaudio/fish-speech-1")
  135. dataset = VITSDataset(
  136. "data/source/Genshin/filelist.train.txt", tokenizer=tokenizer, suffix=".lab"
  137. )
  138. dataloader = DataLoader(
  139. dataset, batch_size=4, shuffle=False, collate_fn=VITSCollator(tokenizer)
  140. )
  141. for batch in dataloader:
  142. print(batch["audios"].shape)
  143. print(batch["audio_lengths"])
  144. print(batch["texts"].shape)
  145. print(batch["text_lengths"])
  146. break