# lit_module.py
  1. from typing import Any, Callable
  2. import lightning as L
  3. import torch
  4. import torch.nn.functional as F
  5. from torch import nn
  6. from torch.utils.checkpoint import checkpoint as gradient_checkpointing
  7. class VQGAN(L.LightningModule):
  8. def __init__(
  9. self,
  10. optimizer: Callable,
  11. lr_scheduler: Callable,
  12. encoder: nn.Module,
  13. generator: nn.Module,
  14. discriminator: nn.Module,
  15. mel_transform: nn.Module,
  16. segment_size: int = 20480,
  17. ):
  18. super().__init__()
  19. # Model parameters
  20. self.optimizer_builder = optimizer
  21. self.lr_scheduler_builder = lr_scheduler
  22. # Generator and discriminators
  23. # Compile generator so that snake can save memory
  24. self.encoder = encoder
  25. self.generator = generator
  26. self.discriminator = discriminator
  27. self.mel_transform = mel_transform
  28. # Crop length for saving memory
  29. self.segment_size = segment_size
  30. # Disable automatic optimization
  31. self.automatic_optimization = False
  32. def configure_optimizers(self):
  33. # Need two optimizers and two schedulers
  34. optimizer_generator = self.optimizer_builder(self.generator.parameters())
  35. optimizer_discriminator = self.optimizer_builder(
  36. self.discriminators.parameters()
  37. )
  38. lr_scheduler_generator = self.lr_scheduler_builder(optimizer_generator)
  39. lr_scheduler_discriminator = self.lr_scheduler_builder(optimizer_discriminator)
  40. return (
  41. {
  42. "optimizer": optimizer_generator,
  43. "lr_scheduler": {
  44. "scheduler": lr_scheduler_generator,
  45. "interval": "step",
  46. "name": "optimizer/generator",
  47. },
  48. },
  49. {
  50. "optimizer": optimizer_discriminator,
  51. "lr_scheduler": {
  52. "scheduler": lr_scheduler_discriminator,
  53. "interval": "step",
  54. "name": "optimizer/discriminator",
  55. },
  56. },
  57. )
  58. def training_generator(self, audio, audio_mask):
  59. # fake_audio, base_loss = self.forward(audio, audio_mask)
  60. assert fake_audio.shape == audio.shape
  61. # Apply mask
  62. audio = audio * audio_mask
  63. fake_audio = fake_audio * audio_mask
  64. # Multi-Resolution STFT Loss
  65. sc_loss, mag_loss = self.multi_resolution_stft_loss(
  66. fake_audio.squeeze(1), audio.squeeze(1)
  67. )
  68. loss_stft = sc_loss + mag_loss
  69. self.log(
  70. "train/generator/stft",
  71. loss_stft,
  72. on_step=True,
  73. on_epoch=False,
  74. prog_bar=True,
  75. logger=True,
  76. sync_dist=True,
  77. )
  78. # L1 Mel-Spectrogram Loss
  79. # This is not used in backpropagation currently
  80. audio_mel = self.mel_transforms.loss(audio.squeeze(1))
  81. fake_audio_mel = self.mel_transforms.loss(fake_audio.squeeze(1))
  82. loss_mel = F.l1_loss(audio_mel, fake_audio_mel)
  83. self.log(
  84. "train/generator/mel",
  85. loss_mel,
  86. on_step=True,
  87. on_epoch=False,
  88. prog_bar=True,
  89. logger=True,
  90. sync_dist=True,
  91. )
  92. # Now, we need to reduce the length of the audio to save memory
  93. if self.crop_length is not None and audio.shape[2] > self.crop_length:
  94. slice_idx = torch.randint(0, audio.shape[-1] - self.crop_length, (1,))
  95. audio = audio[..., slice_idx : slice_idx + self.crop_length]
  96. fake_audio = fake_audio[..., slice_idx : slice_idx + self.crop_length]
  97. audio_mask = audio_mask[..., slice_idx : slice_idx + self.crop_length]
  98. assert audio.shape == fake_audio.shape == audio_mask.shape
  99. # Adv Loss
  100. loss_adv_all = 0
  101. for key, disc in self.discriminators.items():
  102. score_fakes, feat_fake = disc(fake_audio)
  103. # Adversarial Loss
  104. score_fakes = torch.cat(score_fakes, dim=1)
  105. loss_fake = torch.mean((1 - score_fakes) ** 2)
  106. self.log(
  107. f"train/generator/adv_{key}",
  108. loss_fake,
  109. on_step=True,
  110. on_epoch=False,
  111. prog_bar=False,
  112. logger=True,
  113. sync_dist=True,
  114. )
  115. loss_adv_all += loss_fake
  116. if self.feature_matching is False:
  117. continue
  118. # Feature Matching Loss
  119. _, feat_real = disc(audio)
  120. loss_fm = 0
  121. for dr, dg in zip(feat_real, feat_fake):
  122. for rl, gl in zip(dr, dg):
  123. loss_fm += F.l1_loss(rl, gl)
  124. loss_fm /= len(feat_real)
  125. self.log(
  126. f"train/generator/adv_fm_{key}",
  127. loss_fm,
  128. on_step=True,
  129. on_epoch=False,
  130. prog_bar=False,
  131. logger=True,
  132. sync_dist=True,
  133. )
  134. loss_adv_all += loss_fm
  135. loss_adv_all /= len(self.discriminators)
  136. loss_gen_all = base_loss + loss_stft * 2.5 + loss_mel * 45 + loss_adv_all
  137. self.log(
  138. "train/generator/all",
  139. loss_gen_all,
  140. on_step=True,
  141. on_epoch=False,
  142. prog_bar=True,
  143. logger=True,
  144. sync_dist=True,
  145. )
  146. return loss_gen_all, audio, fake_audio
  147. def training_discriminator(self, audio, fake_audio):
  148. loss_disc_all = 0
  149. for key, disc in self.discriminators.items():
  150. if self.training and self.checkpointing:
  151. scores, _ = gradient_checkpointing(disc, audio, use_reentrant=False)
  152. score_fakes, _ = gradient_checkpointing(
  153. disc, fake_audio.detach(), use_reentrant=False
  154. )
  155. else:
  156. scores, _ = disc(audio)
  157. score_fakes, _ = disc(fake_audio.detach())
  158. scores = torch.cat(scores, dim=1)
  159. score_fakes = torch.cat(score_fakes, dim=1)
  160. loss_disc = torch.mean((scores - 1) ** 2) + torch.mean((score_fakes) ** 2)
  161. self.log(
  162. f"train/discriminator/{key}",
  163. loss_disc,
  164. on_step=True,
  165. on_epoch=False,
  166. prog_bar=False,
  167. logger=True,
  168. sync_dist=True,
  169. )
  170. loss_disc_all += loss_disc
  171. loss_disc_all /= len(self.discriminators)
  172. self.log(
  173. "train/discriminator/all",
  174. loss_disc_all,
  175. on_step=True,
  176. on_epoch=False,
  177. prog_bar=True,
  178. logger=True,
  179. sync_dist=True,
  180. )
  181. return loss_disc_all
  182. def training_step(self, batch, batch_idx):
  183. optim_g, optim_d = self.optimizers()
  184. audio, lengths = batch["audio"], batch["lengths"]
  185. audio_mask = sequence_mask(lengths)[:, None, :].to(audio.device, torch.float32)
  186. # Generator
  187. optim_g.zero_grad()
  188. loss_gen_all, audio, fake_audio = self.training_generator(audio, audio_mask)
  189. self.manual_backward(loss_gen_all)
  190. self.log(
  191. "train/generator/grad_norm",
  192. grad_norm(self.generator.parameters()),
  193. on_step=True,
  194. on_epoch=False,
  195. prog_bar=False,
  196. logger=True,
  197. sync_dist=True,
  198. )
  199. self.clip_gradients(
  200. optim_g, gradient_clip_val=1000, gradient_clip_algorithm="norm"
  201. )
  202. optim_g.step()
  203. # Discriminator
  204. assert fake_audio.shape == audio.shape
  205. optim_d.zero_grad()
  206. loss_disc_all = self.training_discriminator(audio, fake_audio)
  207. self.manual_backward(loss_disc_all)
  208. for key, disc in self.discriminators.items():
  209. self.log(
  210. f"train/discriminator/grad_norm_{key}",
  211. grad_norm(disc.parameters()),
  212. on_step=True,
  213. on_epoch=False,
  214. prog_bar=False,
  215. logger=True,
  216. sync_dist=True,
  217. )
  218. self.clip_gradients(
  219. optim_d, gradient_clip_val=1000, gradient_clip_algorithm="norm"
  220. )
  221. optim_d.step()
  222. # Manual LR Scheduler
  223. scheduler_g, scheduler_d = self.lr_schedulers()
  224. scheduler_g.step()
  225. scheduler_d.step()
    def validation_step(self, batch: Any, batch_idx: int):
        """Validation: reconstruct the batch, log the L1 mel loss (epoch
        aggregated), then delegate extra metrics to report_val_metrics."""
        audio, lengths = batch["audio"], batch["lengths"]
        # NOTE(review): `sequence_mask` is not imported in the visible file
        # header — confirm where it comes from.
        audio_mask = sequence_mask(lengths)[:, None, :].to(audio.device, torch.float32)
        # Generator
        fake_audio, _ = self.forward(audio, audio_mask)
        assert fake_audio.shape == audio.shape
        # Apply mask so padded samples do not contribute to the loss.
        audio = audio * audio_mask
        fake_audio = fake_audio * audio_mask
        # L1 Mel-Spectrogram Loss
        # NOTE(review): __init__ stores `self.mel_transform` (singular);
        # the plural `self.mel_transforms.loss` here presumes a differently
        # named attribute exists — confirm.
        audio_mel = self.mel_transforms.loss(audio.squeeze(1))
        fake_audio_mel = self.mel_transforms.loss(fake_audio.squeeze(1))
        loss_mel = F.l1_loss(audio_mel, fake_audio_mel)
        self.log(
            "val/metrics/mel",
            loss_mel,
            on_step=False,
            on_epoch=True,
            prog_bar=True,
            logger=True,
            sync_dist=True,
        )
        # Report other metrics (defined outside this view).
        self.report_val_metrics(fake_audio, audio, lengths)