lit_module.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348
  1. import itertools
  2. from typing import Any, Callable
  3. import lightning as L
  4. import torch
  5. import torch.nn.functional as F
  6. import wandb
  7. from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger
  8. from matplotlib import pyplot as plt
  9. from torch import nn
  10. from vector_quantize_pytorch import VectorQuantize
  11. from fish_speech.models.vqgan.losses import (
  12. discriminator_loss,
  13. feature_loss,
  14. generator_loss,
  15. kl_loss,
  16. )
  17. from fish_speech.models.vqgan.modules.discriminator import EnsembleDiscriminator
  18. from fish_speech.models.vqgan.modules.models import SynthesizerTrn
  19. from fish_speech.models.vqgan.utils import plot_mel, sequence_mask, slice_segments
  20. class VQGAN(L.LightningModule):
  21. def __init__(
  22. self,
  23. optimizer: Callable,
  24. lr_scheduler: Callable,
  25. generator: SynthesizerTrn,
  26. discriminator: EnsembleDiscriminator,
  27. mel_transform: nn.Module,
  28. segment_size: int = 20480,
  29. hop_length: int = 640,
  30. sample_rate: int = 32000,
  31. ):
  32. super().__init__()
  33. # Model parameters
  34. self.optimizer_builder = optimizer
  35. self.lr_scheduler_builder = lr_scheduler
  36. # Generator and discriminators
  37. self.generator = generator
  38. self.discriminator = discriminator
  39. self.mel_transform = mel_transform
  40. # Crop length for saving memory
  41. self.segment_size = segment_size
  42. self.hop_length = hop_length
  43. self.sampling_rate = sample_rate
  44. # Disable automatic optimization
  45. self.automatic_optimization = False
  46. def configure_optimizers(self):
  47. # Need two optimizers and two schedulers
  48. optimizer_generator = self.optimizer_builder(self.generator.parameters())
  49. optimizer_discriminator = self.optimizer_builder(
  50. self.discriminator.parameters()
  51. )
  52. lr_scheduler_generator = self.lr_scheduler_builder(optimizer_generator)
  53. lr_scheduler_discriminator = self.lr_scheduler_builder(optimizer_discriminator)
  54. return (
  55. {
  56. "optimizer": optimizer_generator,
  57. "lr_scheduler": {
  58. "scheduler": lr_scheduler_generator,
  59. "interval": "step",
  60. "name": "optimizer/generator",
  61. },
  62. },
  63. {
  64. "optimizer": optimizer_discriminator,
  65. "lr_scheduler": {
  66. "scheduler": lr_scheduler_discriminator,
  67. "interval": "step",
  68. "name": "optimizer/discriminator",
  69. },
  70. },
  71. )
  72. def training_step(self, batch, batch_idx):
  73. optim_g, optim_d = self.optimizers()
  74. audios, audio_lengths = batch["audios"], batch["audio_lengths"]
  75. features, feature_lengths = batch["features"], batch["feature_lengths"]
  76. audios = audios[:, None, :]
  77. audios = audios.float()
  78. # features = features.long()
  79. with torch.no_grad():
  80. gt_mels = self.mel_transform(audios)
  81. gt_mels = gt_mels[:, :, : features.shape[1]]
  82. (
  83. y_hat,
  84. ids_slice,
  85. x_mask,
  86. y_mask,
  87. (z_q, z_p),
  88. (m_p, logs_p),
  89. (m_q, logs_q),
  90. # vq_loss,
  91. ) = self.generator(features, feature_lengths, gt_mels)
  92. y_hat_mel = self.mel_transform(y_hat.squeeze(1))
  93. y_mel = slice_segments(gt_mels, ids_slice, self.segment_size // self.hop_length)
  94. y = slice_segments(audios, ids_slice * self.hop_length, self.segment_size)
  95. # Discriminator
  96. y_d_hat_r, y_d_hat_g, _, _ = self.discriminator(y, y_hat.detach())
  97. with torch.autocast(device_type=audios.device.type, enabled=False):
  98. loss_disc_all, _, _ = discriminator_loss(y_d_hat_r, y_d_hat_g)
  99. self.log(
  100. "train/discriminator/loss",
  101. loss_disc_all,
  102. on_step=True,
  103. on_epoch=False,
  104. prog_bar=True,
  105. logger=True,
  106. sync_dist=True,
  107. )
  108. optim_d.zero_grad()
  109. self.manual_backward(loss_disc_all)
  110. self.clip_gradients(
  111. optim_d, gradient_clip_val=1000.0, gradient_clip_algorithm="norm"
  112. )
  113. optim_d.step()
  114. y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = self.discriminator(y, y_hat)
  115. with torch.autocast(device_type=audios.device.type, enabled=False):
  116. loss_mel = F.l1_loss(y_mel, y_hat_mel)
  117. loss_adv, _ = generator_loss(y_d_hat_g)
  118. loss_fm = feature_loss(fmap_r, fmap_g)
  119. loss_kl = kl_loss(
  120. z_p=z_p,
  121. logs_q=logs_q,
  122. m_p=m_p,
  123. logs_p=logs_p,
  124. z_mask=x_mask,
  125. )
  126. # Cyclical kl loss
  127. # then 500 steps linear: 0.1
  128. # then 500 steps 0.1
  129. # then go back to 0
  130. if self.global_step < 100000:
  131. beta = 1e-6
  132. else:
  133. beta = self.global_step % 1000
  134. beta = min(beta, 500) / 500 * 0.1 + 1e-6
  135. loss_gen_all = (
  136. loss_mel * 45 + loss_fm + loss_adv + loss_kl * beta
  137. ) # + vq_loss
  138. self.log(
  139. "train/generator/loss",
  140. loss_gen_all,
  141. on_step=True,
  142. on_epoch=False,
  143. prog_bar=True,
  144. logger=True,
  145. sync_dist=True,
  146. )
  147. self.log(
  148. "train/generator/loss_mel",
  149. loss_mel,
  150. on_step=True,
  151. on_epoch=False,
  152. prog_bar=False,
  153. logger=True,
  154. sync_dist=True,
  155. )
  156. self.log(
  157. "train/generator/loss_fm",
  158. loss_fm,
  159. on_step=True,
  160. on_epoch=False,
  161. prog_bar=False,
  162. logger=True,
  163. sync_dist=True,
  164. )
  165. self.log(
  166. "train/generator/loss_adv",
  167. loss_adv,
  168. on_step=True,
  169. on_epoch=False,
  170. prog_bar=False,
  171. logger=True,
  172. sync_dist=True,
  173. )
  174. self.log(
  175. "train/generator/loss_kl",
  176. loss_kl,
  177. on_step=True,
  178. on_epoch=False,
  179. prog_bar=False,
  180. logger=True,
  181. sync_dist=True,
  182. )
  183. # self.log(
  184. # "train/generator/loss_vq",
  185. # vq_loss,
  186. # on_step=True,
  187. # on_epoch=False,
  188. # prog_bar=False,
  189. # logger=True,
  190. # sync_dist=True,
  191. # )
  192. optim_g.zero_grad()
  193. self.manual_backward(loss_gen_all)
  194. self.clip_gradients(
  195. optim_g, gradient_clip_val=1000.0, gradient_clip_algorithm="norm"
  196. )
  197. optim_g.step()
  198. # Manual LR Scheduler
  199. scheduler_g, scheduler_d = self.lr_schedulers()
  200. scheduler_g.step()
  201. scheduler_d.step()
  202. def validation_step(self, batch: Any, batch_idx: int):
  203. audios, audio_lengths = batch["audios"], batch["audio_lengths"]
  204. features, feature_lengths = batch["features"], batch["feature_lengths"]
  205. audios = audios.float()
  206. # features = features.float()
  207. audios = audios[:, None, :]
  208. gt_mels = self.mel_transform(audios)
  209. gt_mels = gt_mels[:, :, : features.shape[1]]
  210. fake_audios = self.generator.infer(features, feature_lengths, gt_mels)
  211. posterior_audios = self.generator.reconstruct(gt_mels, feature_lengths)
  212. fake_mels = self.mel_transform(fake_audios.squeeze(1))
  213. posterior_mels = self.mel_transform(posterior_audios.squeeze(1))
  214. min_mel_length = min(gt_mels.shape[-1], fake_mels.shape[-1])
  215. gt_mels = gt_mels[:, :, :min_mel_length]
  216. fake_mels = fake_mels[:, :, :min_mel_length]
  217. posterior_mels = posterior_mels[:, :, :min_mel_length]
  218. mel_loss = F.l1_loss(gt_mels, fake_mels)
  219. self.log(
  220. "val/mel_loss",
  221. mel_loss,
  222. on_step=False,
  223. on_epoch=True,
  224. prog_bar=True,
  225. logger=True,
  226. sync_dist=True,
  227. )
  228. for idx, (
  229. mel,
  230. gen_mel,
  231. post_mel,
  232. audio,
  233. gen_audio,
  234. post_audio,
  235. audio_len,
  236. ) in enumerate(
  237. zip(
  238. gt_mels,
  239. fake_mels,
  240. posterior_mels,
  241. audios,
  242. fake_audios,
  243. posterior_audios,
  244. audio_lengths,
  245. )
  246. ):
  247. mel_len = audio_len // self.hop_length
  248. image_mels = plot_mel(
  249. [
  250. gen_mel[:, :mel_len],
  251. post_mel[:, :mel_len],
  252. mel[:, :mel_len],
  253. ],
  254. [
  255. "Generated Spectrogram",
  256. "Posterior Spectrogram",
  257. "Ground-Truth Spectrogram",
  258. ],
  259. )
  260. if isinstance(self.logger, WandbLogger):
  261. self.logger.experiment.log(
  262. {
  263. "reconstruction_mel": wandb.Image(image_mels, caption="mels"),
  264. "wavs": [
  265. wandb.Audio(
  266. audio[0, :audio_len],
  267. sample_rate=self.sampling_rate,
  268. caption="gt",
  269. ),
  270. wandb.Audio(
  271. gen_audio[0, :audio_len],
  272. sample_rate=self.sampling_rate,
  273. caption="prediction",
  274. ),
  275. wandb.Audio(
  276. post_audio[0, :audio_len],
  277. sample_rate=self.sampling_rate,
  278. caption="posterior",
  279. ),
  280. ],
  281. },
  282. )
  283. if isinstance(self.logger, TensorBoardLogger):
  284. self.logger.experiment.add_figure(
  285. f"sample-{idx}/mels",
  286. image_mels,
  287. global_step=self.global_step,
  288. )
  289. self.logger.experiment.add_audio(
  290. f"sample-{idx}/wavs/gt",
  291. audio[0, :audio_len],
  292. self.global_step,
  293. sample_rate=self.sampling_rate,
  294. )
  295. self.logger.experiment.add_audio(
  296. f"sample-{idx}/wavs/prediction",
  297. gen_audio[0, :audio_len],
  298. self.global_step,
  299. sample_rate=self.sampling_rate,
  300. )
  301. self.logger.experiment.add_audio(
  302. f"sample-{idx}/wavs/posterior",
  303. post_audio[0, :audio_len],
  304. self.global_step,
  305. sample_rate=self.sampling_rate,
  306. )
  307. plt.close(image_mels)