lit_module.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343
  1. import itertools
  2. from typing import Any, Callable
  3. import lightning as L
  4. import torch
  5. import torch.nn.functional as F
  6. import wandb
  7. from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger
  8. from matplotlib import pyplot as plt
  9. from torch import nn
  10. from vector_quantize_pytorch import VectorQuantize
  11. from fish_speech.models.vqgan.losses import (
  12. discriminator_loss,
  13. feature_loss,
  14. generator_loss,
  15. kl_loss,
  16. )
  17. from fish_speech.models.vqgan.modules.discriminator import EnsembleDiscriminator
  18. from fish_speech.models.vqgan.modules.models import SynthesizerTrn
  19. from fish_speech.models.vqgan.utils import plot_mel, sequence_mask, slice_segments
  20. class VQGAN(L.LightningModule):
  21. def __init__(
  22. self,
  23. optimizer: Callable,
  24. lr_scheduler: Callable,
  25. generator: SynthesizerTrn,
  26. discriminator: EnsembleDiscriminator,
  27. mel_transform: nn.Module,
  28. segment_size: int = 20480,
  29. hop_length: int = 640,
  30. sample_rate: int = 32000,
  31. ):
  32. super().__init__()
  33. # Model parameters
  34. self.optimizer_builder = optimizer
  35. self.lr_scheduler_builder = lr_scheduler
  36. # Generator and discriminators
  37. self.generator = generator
  38. self.discriminator = discriminator
  39. self.mel_transform = mel_transform
  40. # Crop length for saving memory
  41. self.segment_size = segment_size
  42. self.hop_length = hop_length
  43. self.sampling_rate = sample_rate
  44. # Disable automatic optimization
  45. self.automatic_optimization = False
  46. def configure_optimizers(self):
  47. # Need two optimizers and two schedulers
  48. optimizer_generator = self.optimizer_builder(self.generator.parameters())
  49. optimizer_discriminator = self.optimizer_builder(
  50. self.discriminator.parameters()
  51. )
  52. lr_scheduler_generator = self.lr_scheduler_builder(optimizer_generator)
  53. lr_scheduler_discriminator = self.lr_scheduler_builder(optimizer_discriminator)
  54. return (
  55. {
  56. "optimizer": optimizer_generator,
  57. "lr_scheduler": {
  58. "scheduler": lr_scheduler_generator,
  59. "interval": "step",
  60. "name": "optimizer/generator",
  61. },
  62. },
  63. {
  64. "optimizer": optimizer_discriminator,
  65. "lr_scheduler": {
  66. "scheduler": lr_scheduler_discriminator,
  67. "interval": "step",
  68. "name": "optimizer/discriminator",
  69. },
  70. },
  71. )
  72. def training_step(self, batch, batch_idx):
  73. optim_g, optim_d = self.optimizers()
  74. audios, audio_lengths = batch["audios"], batch["audio_lengths"]
  75. features, feature_lengths = batch["features"], batch["feature_lengths"]
  76. audios = audios[:, None, :]
  77. audios = audios.float()
  78. features = features.float()
  79. with torch.no_grad():
  80. gt_mels = self.mel_transform(audios)
  81. gt_mels = gt_mels[:, :, : features.shape[1]]
  82. (
  83. y_hat,
  84. ids_slice,
  85. x_mask,
  86. y_mask,
  87. (z_q, z_p),
  88. (m_p, logs_p),
  89. (m_q, logs_q),
  90. vq_loss,
  91. ) = self.generator(features, feature_lengths, gt_mels)
  92. y_hat_mel = self.mel_transform(y_hat.squeeze(1))
  93. y_mel = slice_segments(gt_mels, ids_slice, self.segment_size // self.hop_length)
  94. y = slice_segments(audios, ids_slice * self.hop_length, self.segment_size)
  95. # Discriminator
  96. y_d_hat_r, y_d_hat_g, _, _ = self.discriminator(y, y_hat.detach())
  97. with torch.autocast(device_type=audios.device.type, enabled=False):
  98. loss_disc_all, _, _ = discriminator_loss(y_d_hat_r, y_d_hat_g)
  99. self.log(
  100. "train/discriminator/loss",
  101. loss_disc_all,
  102. on_step=True,
  103. on_epoch=False,
  104. prog_bar=True,
  105. logger=True,
  106. sync_dist=True,
  107. )
  108. optim_d.zero_grad()
  109. self.manual_backward(loss_disc_all)
  110. self.clip_gradients(
  111. optim_d, gradient_clip_val=1000.0, gradient_clip_algorithm="norm"
  112. )
  113. optim_d.step()
  114. y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = self.discriminator(y, y_hat)
  115. with torch.autocast(device_type=audios.device.type, enabled=False):
  116. loss_mel = F.l1_loss(y_mel, y_hat_mel)
  117. loss_adv, _ = generator_loss(y_d_hat_g)
  118. loss_fm = feature_loss(fmap_r, fmap_g)
  119. loss_kl = kl_loss(
  120. z_p=z_p,
  121. logs_q=logs_q,
  122. m_p=m_p,
  123. logs_p=logs_p,
  124. z_mask=x_mask,
  125. )
  126. # Cyclical kl loss
  127. # then 500 steps linear: 0.1
  128. # then 500 steps 0.1
  129. # then go back to 0
  130. beta = self.global_step % 1000
  131. beta = min(beta, 500) / 500 * 0.1 + 1e-6
  132. loss_gen_all = loss_mel * 45 + loss_fm + loss_adv + loss_kl * beta + vq_loss
  133. self.log(
  134. "train/generator/loss",
  135. loss_gen_all,
  136. on_step=True,
  137. on_epoch=False,
  138. prog_bar=True,
  139. logger=True,
  140. sync_dist=True,
  141. )
  142. self.log(
  143. "train/generator/loss_mel",
  144. loss_mel,
  145. on_step=True,
  146. on_epoch=False,
  147. prog_bar=False,
  148. logger=True,
  149. sync_dist=True,
  150. )
  151. self.log(
  152. "train/generator/loss_fm",
  153. loss_fm,
  154. on_step=True,
  155. on_epoch=False,
  156. prog_bar=False,
  157. logger=True,
  158. sync_dist=True,
  159. )
  160. self.log(
  161. "train/generator/loss_adv",
  162. loss_adv,
  163. on_step=True,
  164. on_epoch=False,
  165. prog_bar=False,
  166. logger=True,
  167. sync_dist=True,
  168. )
  169. self.log(
  170. "train/generator/loss_kl",
  171. loss_kl,
  172. on_step=True,
  173. on_epoch=False,
  174. prog_bar=False,
  175. logger=True,
  176. sync_dist=True,
  177. )
  178. self.log(
  179. "train/generator/loss_vq",
  180. vq_loss,
  181. on_step=True,
  182. on_epoch=False,
  183. prog_bar=False,
  184. logger=True,
  185. sync_dist=True,
  186. )
  187. optim_g.zero_grad()
  188. self.manual_backward(loss_gen_all)
  189. self.clip_gradients(
  190. optim_g, gradient_clip_val=1000.0, gradient_clip_algorithm="norm"
  191. )
  192. optim_g.step()
  193. # Manual LR Scheduler
  194. scheduler_g, scheduler_d = self.lr_schedulers()
  195. scheduler_g.step()
  196. scheduler_d.step()
  197. def validation_step(self, batch: Any, batch_idx: int):
  198. audios, audio_lengths = batch["audios"], batch["audio_lengths"]
  199. features, feature_lengths = batch["features"], batch["feature_lengths"]
  200. audios = audios.float()
  201. features = features.float()
  202. audios = audios[:, None, :]
  203. gt_mels = self.mel_transform(audios)
  204. gt_mels = gt_mels[:, :, : features.shape[1]]
  205. fake_audios = self.generator.infer(features, feature_lengths, gt_mels)
  206. posterior_audios = self.generator.reconstruct(gt_mels, feature_lengths)
  207. fake_mels = self.mel_transform(fake_audios.squeeze(1))
  208. posterior_mels = self.mel_transform(posterior_audios.squeeze(1))
  209. min_mel_length = min(gt_mels.shape[-1], fake_mels.shape[-1])
  210. gt_mels = gt_mels[:, :, :min_mel_length]
  211. fake_mels = fake_mels[:, :, :min_mel_length]
  212. posterior_mels = posterior_mels[:, :, :min_mel_length]
  213. mel_loss = F.l1_loss(gt_mels, fake_mels)
  214. self.log(
  215. "val/mel_loss",
  216. mel_loss,
  217. on_step=False,
  218. on_epoch=True,
  219. prog_bar=True,
  220. logger=True,
  221. sync_dist=True,
  222. )
  223. for idx, (
  224. mel,
  225. gen_mel,
  226. post_mel,
  227. audio,
  228. gen_audio,
  229. post_audio,
  230. audio_len,
  231. ) in enumerate(
  232. zip(
  233. gt_mels,
  234. fake_mels,
  235. posterior_mels,
  236. audios,
  237. fake_audios,
  238. posterior_audios,
  239. audio_lengths,
  240. )
  241. ):
  242. mel_len = audio_len // self.hop_length
  243. image_mels = plot_mel(
  244. [
  245. gen_mel[:, :mel_len],
  246. post_mel[:, :mel_len],
  247. mel[:, :mel_len],
  248. ],
  249. [
  250. "Generated Spectrogram",
  251. "Posterior Spectrogram",
  252. "Ground-Truth Spectrogram",
  253. ],
  254. )
  255. if isinstance(self.logger, WandbLogger):
  256. self.logger.experiment.log(
  257. {
  258. "reconstruction_mel": wandb.Image(image_mels, caption="mels"),
  259. "wavs": [
  260. wandb.Audio(
  261. audio[0, :audio_len],
  262. sample_rate=self.sampling_rate,
  263. caption="gt",
  264. ),
  265. wandb.Audio(
  266. gen_audio[0, :audio_len],
  267. sample_rate=self.sampling_rate,
  268. caption="prediction",
  269. ),
  270. wandb.Audio(
  271. post_audio[0, :audio_len],
  272. sample_rate=self.sampling_rate,
  273. caption="posterior",
  274. ),
  275. ],
  276. },
  277. )
  278. if isinstance(self.logger, TensorBoardLogger):
  279. self.logger.experiment.add_figure(
  280. f"sample-{idx}/mels",
  281. image_mels,
  282. global_step=self.global_step,
  283. )
  284. self.logger.experiment.add_audio(
  285. f"sample-{idx}/wavs/gt",
  286. audio[0, :audio_len],
  287. self.global_step,
  288. sample_rate=self.sampling_rate,
  289. )
  290. self.logger.experiment.add_audio(
  291. f"sample-{idx}/wavs/prediction",
  292. gen_audio[0, :audio_len],
  293. self.global_step,
  294. sample_rate=self.sampling_rate,
  295. )
  296. self.logger.experiment.add_audio(
  297. f"sample-{idx}/wavs/posterior",
  298. post_audio[0, :audio_len],
  299. self.global_step,
  300. sample_rate=self.sampling_rate,
  301. )
  302. plt.close(image_mels)