# lit_module.py
from typing import Any, Callable

import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.checkpoint import checkpoint as gradient_checkpointing

from fish_vocoder.models.vocoder import VocoderModel
from fish_vocoder.modules.losses.stft import MultiResolutionSTFTLoss
from fish_vocoder.utils.grad_norm import grad_norm
from fish_vocoder.utils.mask import sequence_mask
  10. class GANModel(VocoderModel):
  11. def __init__(
  12. self,
  13. sampling_rate: int,
  14. n_fft: int,
  15. hop_length: int,
  16. win_length: int,
  17. num_mels: int,
  18. optimizer: Callable,
  19. lr_scheduler: Callable,
  20. mel_transforms: nn.ModuleDict,
  21. generator: nn.Module,
  22. discriminators: nn.ModuleDict,
  23. multi_resolution_stft_loss: MultiResolutionSTFTLoss,
  24. num_frames: int,
  25. crop_length: int | None = None,
  26. checkpointing: bool = False,
  27. feature_matching: bool = False,
  28. ):
  29. super().__init__(
  30. sampling_rate=sampling_rate,
  31. n_fft=n_fft,
  32. hop_length=hop_length,
  33. win_length=win_length,
  34. num_mels=num_mels,
  35. )
  36. # Model parameters
  37. self.optimizer_builder = optimizer
  38. self.lr_scheduler_builder = lr_scheduler
  39. # Spectrogram transforms
  40. self.mel_transforms = mel_transforms
  41. # Generator and discriminators
  42. # Compile generator so that snake can save memory
  43. self.generator = generator
  44. self.discriminators = discriminators
  45. # Loss
  46. self.multi_resolution_stft_loss = multi_resolution_stft_loss
  47. # Crop length for saving memory
  48. self.num_frames = num_frames
  49. self.crop_length = crop_length
  50. # Disable automatic optimization
  51. self.automatic_optimization = False
  52. # Gradient checkpointing
  53. self.checkpointing = checkpointing
  54. # Feature matching
  55. self.feature_matching = feature_matching
  56. def configure_optimizers(self):
  57. # Need two optimizers and two schedulers
  58. optimizer_generator = self.optimizer_builder(self.generator.parameters())
  59. optimizer_discriminator = self.optimizer_builder(
  60. self.discriminators.parameters()
  61. )
  62. lr_scheduler_generator = self.lr_scheduler_builder(optimizer_generator)
  63. lr_scheduler_discriminator = self.lr_scheduler_builder(optimizer_discriminator)
  64. return (
  65. {
  66. "optimizer": optimizer_generator,
  67. "lr_scheduler": {
  68. "scheduler": lr_scheduler_generator,
  69. "interval": "step",
  70. "name": "optimizer/generator",
  71. },
  72. },
  73. {
  74. "optimizer": optimizer_discriminator,
  75. "lr_scheduler": {
  76. "scheduler": lr_scheduler_discriminator,
  77. "interval": "step",
  78. "name": "optimizer/discriminator",
  79. },
  80. },
  81. )
  82. def training_generator(self, audio, audio_mask):
  83. if self.training and self.checkpointing:
  84. fake_audio, base_loss = gradient_checkpointing(
  85. self.forward, audio, audio_mask, use_reentrant=False
  86. )
  87. else:
  88. fake_audio, base_loss = self.forward(audio, audio_mask)
  89. assert fake_audio.shape == audio.shape
  90. # Apply mask
  91. audio = audio * audio_mask
  92. fake_audio = fake_audio * audio_mask
  93. # Multi-Resolution STFT Loss
  94. sc_loss, mag_loss = self.multi_resolution_stft_loss(
  95. fake_audio.squeeze(1), audio.squeeze(1)
  96. )
  97. loss_stft = sc_loss + mag_loss
  98. self.log(
  99. "train/generator/stft",
  100. loss_stft,
  101. on_step=True,
  102. on_epoch=False,
  103. prog_bar=True,
  104. logger=True,
  105. sync_dist=True,
  106. )
  107. # L1 Mel-Spectrogram Loss
  108. # This is not used in backpropagation currently
  109. audio_mel = self.mel_transforms.loss(audio.squeeze(1))
  110. fake_audio_mel = self.mel_transforms.loss(fake_audio.squeeze(1))
  111. loss_mel = F.l1_loss(audio_mel, fake_audio_mel)
  112. self.log(
  113. "train/generator/mel",
  114. loss_mel,
  115. on_step=True,
  116. on_epoch=False,
  117. prog_bar=True,
  118. logger=True,
  119. sync_dist=True,
  120. )
  121. # Now, we need to reduce the length of the audio to save memory
  122. if self.crop_length is not None and audio.shape[2] > self.crop_length:
  123. slice_idx = torch.randint(0, audio.shape[-1] - self.crop_length, (1,))
  124. audio = audio[..., slice_idx : slice_idx + self.crop_length]
  125. fake_audio = fake_audio[..., slice_idx : slice_idx + self.crop_length]
  126. audio_mask = audio_mask[..., slice_idx : slice_idx + self.crop_length]
  127. assert audio.shape == fake_audio.shape == audio_mask.shape
  128. # Adv Loss
  129. loss_adv_all = 0
  130. for key, disc in self.discriminators.items():
  131. score_fakes, feat_fake = disc(fake_audio)
  132. # Adversarial Loss
  133. score_fakes = torch.cat(score_fakes, dim=1)
  134. loss_fake = torch.mean((1 - score_fakes) ** 2)
  135. self.log(
  136. f"train/generator/adv_{key}",
  137. loss_fake,
  138. on_step=True,
  139. on_epoch=False,
  140. prog_bar=False,
  141. logger=True,
  142. sync_dist=True,
  143. )
  144. loss_adv_all += loss_fake
  145. if self.feature_matching is False:
  146. continue
  147. # Feature Matching Loss
  148. _, feat_real = disc(audio)
  149. loss_fm = 0
  150. for dr, dg in zip(feat_real, feat_fake):
  151. for rl, gl in zip(dr, dg):
  152. loss_fm += F.l1_loss(rl, gl)
  153. loss_fm /= len(feat_real)
  154. self.log(
  155. f"train/generator/adv_fm_{key}",
  156. loss_fm,
  157. on_step=True,
  158. on_epoch=False,
  159. prog_bar=False,
  160. logger=True,
  161. sync_dist=True,
  162. )
  163. loss_adv_all += loss_fm
  164. loss_adv_all /= len(self.discriminators)
  165. loss_gen_all = base_loss + loss_stft * 2.5 + loss_mel * 45 + loss_adv_all
  166. self.log(
  167. "train/generator/all",
  168. loss_gen_all,
  169. on_step=True,
  170. on_epoch=False,
  171. prog_bar=True,
  172. logger=True,
  173. sync_dist=True,
  174. )
  175. return loss_gen_all, audio, fake_audio
  176. def training_discriminator(self, audio, fake_audio):
  177. loss_disc_all = 0
  178. for key, disc in self.discriminators.items():
  179. if self.training and self.checkpointing:
  180. scores, _ = gradient_checkpointing(disc, audio, use_reentrant=False)
  181. score_fakes, _ = gradient_checkpointing(
  182. disc, fake_audio.detach(), use_reentrant=False
  183. )
  184. else:
  185. scores, _ = disc(audio)
  186. score_fakes, _ = disc(fake_audio.detach())
  187. scores = torch.cat(scores, dim=1)
  188. score_fakes = torch.cat(score_fakes, dim=1)
  189. loss_disc = torch.mean((scores - 1) ** 2) + torch.mean((score_fakes) ** 2)
  190. self.log(
  191. f"train/discriminator/{key}",
  192. loss_disc,
  193. on_step=True,
  194. on_epoch=False,
  195. prog_bar=False,
  196. logger=True,
  197. sync_dist=True,
  198. )
  199. loss_disc_all += loss_disc
  200. loss_disc_all /= len(self.discriminators)
  201. self.log(
  202. "train/discriminator/all",
  203. loss_disc_all,
  204. on_step=True,
  205. on_epoch=False,
  206. prog_bar=True,
  207. logger=True,
  208. sync_dist=True,
  209. )
  210. return loss_disc_all
  211. def training_step(self, batch, batch_idx):
  212. optim_g, optim_d = self.optimizers()
  213. audio, lengths = batch["audio"], batch["lengths"]
  214. audio_mask = sequence_mask(lengths)[:, None, :].to(audio.device, torch.float32)
  215. # Generator
  216. optim_g.zero_grad()
  217. loss_gen_all, audio, fake_audio = self.training_generator(audio, audio_mask)
  218. self.manual_backward(loss_gen_all)
  219. self.log(
  220. "train/generator/grad_norm",
  221. grad_norm(self.generator.parameters()),
  222. on_step=True,
  223. on_epoch=False,
  224. prog_bar=False,
  225. logger=True,
  226. sync_dist=True,
  227. )
  228. self.clip_gradients(
  229. optim_g, gradient_clip_val=1000, gradient_clip_algorithm="norm"
  230. )
  231. optim_g.step()
  232. # Discriminator
  233. assert fake_audio.shape == audio.shape
  234. optim_d.zero_grad()
  235. loss_disc_all = self.training_discriminator(audio, fake_audio)
  236. self.manual_backward(loss_disc_all)
  237. for key, disc in self.discriminators.items():
  238. self.log(
  239. f"train/discriminator/grad_norm_{key}",
  240. grad_norm(disc.parameters()),
  241. on_step=True,
  242. on_epoch=False,
  243. prog_bar=False,
  244. logger=True,
  245. sync_dist=True,
  246. )
  247. self.clip_gradients(
  248. optim_d, gradient_clip_val=1000, gradient_clip_algorithm="norm"
  249. )
  250. optim_d.step()
  251. # Manual LR Scheduler
  252. scheduler_g, scheduler_d = self.lr_schedulers()
  253. scheduler_g.step()
  254. scheduler_d.step()
  255. def forward(self, audio, mask=None, input_spec=None):
  256. if input_spec is None:
  257. input_spec = self.mel_transforms.input(audio.squeeze(1))
  258. fake_audio = self.generator(input_spec)
  259. return fake_audio, 0
  260. def validation_step(self, batch: Any, batch_idx: int):
  261. audio, lengths = batch["audio"], batch["lengths"]
  262. audio_mask = sequence_mask(lengths)[:, None, :].to(audio.device, torch.float32)
  263. # Generator
  264. fake_audio, _ = self.forward(audio, audio_mask)
  265. assert fake_audio.shape == audio.shape
  266. # Apply mask
  267. audio = audio * audio_mask
  268. fake_audio = fake_audio * audio_mask
  269. # L1 Mel-Spectrogram Loss
  270. audio_mel = self.mel_transforms.loss(audio.squeeze(1))
  271. fake_audio_mel = self.mel_transforms.loss(fake_audio.squeeze(1))
  272. loss_mel = F.l1_loss(audio_mel, fake_audio_mel)
  273. self.log(
  274. "val/metrics/mel",
  275. loss_mel,
  276. on_step=False,
  277. on_epoch=True,
  278. prog_bar=True,
  279. logger=True,
  280. sync_dist=True,
  281. )
  282. # Report other metrics
  283. self.report_val_metrics(fake_audio, audio, lengths)