from typing import Any, Callable

import lightning as L
import torch
import torch.nn.functional as F
import wandb
from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger
from matplotlib import pyplot as plt
from torch import nn

from fish_speech.models.vits_decoder.losses import (
    discriminator_loss,
    feature_loss,
    generator_loss,
    kl_loss,
)
from fish_speech.models.vqgan.utils import (
    avg_with_mask,
    plot_mel,
    sequence_mask,
    slice_segments,
)


class VITSDecoder(L.LightningModule):
    def __init__(
        self,
        optimizer: Callable,
        lr_scheduler: Callable,
        generator: nn.Module,
        discriminator: nn.Module,
        mel_transform: nn.Module,
        spec_transform: nn.Module,
        hop_length: int = 512,
        sample_rate: int = 44100,
        freeze_discriminator: bool = False,
        weight_mel: float = 45,
        weight_kl: float = 0.1,
    ):
        super().__init__()

        # Optimizer and scheduler builders (callables that receive the
        # parameters / optimizer and return the configured object)
        self.optimizer_builder = optimizer
        self.lr_scheduler_builder = lr_scheduler

        # Generator and discriminator
        self.generator = generator
        self.discriminator = discriminator
        self.mel_transform = mel_transform
        self.spec_transform = spec_transform
        self.freeze_discriminator = freeze_discriminator

        # Loss weights
        self.weight_mel = weight_mel
        self.weight_kl = weight_kl

        # Other parameters
        self.hop_length = hop_length
        self.sampling_rate = sample_rate

        # GAN training uses two optimizers, so disable automatic optimization
        self.automatic_optimization = False

        if self.freeze_discriminator:
            for p in self.discriminator.parameters():
                p.requires_grad = False

    def configure_optimizers(self):
        # Need two optimizers and two schedulers
        optimizer_generator = self.optimizer_builder(self.generator.parameters())
        optimizer_discriminator = self.optimizer_builder(
            self.discriminator.parameters()
        )
        lr_scheduler_generator = self.lr_scheduler_builder(optimizer_generator)
        lr_scheduler_discriminator = self.lr_scheduler_builder(optimizer_discriminator)

        return (
            {
                "optimizer": optimizer_generator,
                "lr_scheduler": {
                    "scheduler": lr_scheduler_generator,
                    "interval": "step",
                    "name": "optimizer/generator",
                },
            },
            {
                "optimizer": optimizer_discriminator,
                "lr_scheduler": {
                    "scheduler": lr_scheduler_discriminator,
                    "interval": "step",
                    "name": "optimizer/discriminator",
                },
            },
        )
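
    # Lightning reads the tuple above as the (generator, discriminator)
    # optimizer configs. Note that with automatic optimization disabled,
    # Lightning does not act on `interval: "step"` itself; the schedulers are
    # stepped manually at the end of `training_step` below.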

    def training_step(self, batch, batch_idx):
        optim_g, optim_d = self.optimizers()

        audios, audio_lengths = batch["audios"], batch["audio_lengths"]
        texts, text_lengths = batch["texts"], batch["text_lengths"]

        audios = audios.float()
        audios = audios[:, None, :]

        with torch.no_grad():
            gt_mels = self.mel_transform(audios)
            gt_specs = self.spec_transform(audios)
            spec_lengths = audio_lengths // self.hop_length
            spec_masks = torch.unsqueeze(
                sequence_mask(spec_lengths, gt_mels.shape[2]), 1
            ).to(gt_mels.dtype)
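
        # The features above are training targets only, hence no_grad; one
        # spectrogram frame spans `hop_length` audio samples, and `spec_masks`
        # zeroes out padded frames.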

        (
            fake_audios,
            ids_slice,
            y_mask,
            (z, z_p, m_p, logs_p, m_q, logs_q),
        ) = self.generator(
            audios,
            audio_lengths,
            gt_specs,
            spec_lengths,
            texts,
            text_lengths,
        )

        # The generator only synthesizes random segments, so slice the ground
        # truth to match before computing losses
        gt_mels = slice_segments(gt_mels, ids_slice, self.generator.segment_size)
        spec_masks = slice_segments(spec_masks, ids_slice, self.generator.segment_size)
        audios = slice_segments(
            audios,
            ids_slice * self.hop_length,
            self.generator.segment_size * self.hop_length,
        )

        fake_mels = self.mel_transform(fake_audios.squeeze(1))

        assert (
            audios.shape == fake_audios.shape
        ), f"{audios.shape} != {fake_audios.shape}"

        # Discriminator update (skipped when frozen), using detached fakes
        if not self.freeze_discriminator:
            y_d_hat_r, y_d_hat_g, _, _ = self.discriminator(
                audios, fake_audios.detach()
            )

            with torch.autocast(device_type=audios.device.type, enabled=False):
                loss_disc, _, _ = discriminator_loss(y_d_hat_r, y_d_hat_g)

            self.log(
                "train/discriminator/loss",
                loss_disc,
                on_step=True,
                on_epoch=False,
                prog_bar=False,
                logger=True,
                sync_dist=True,
            )

            optim_d.zero_grad()
            self.manual_backward(loss_disc)
            self.clip_gradients(
                optim_d, gradient_clip_val=1000.0, gradient_clip_algorithm="norm"
            )
            optim_d.step()

        # Adversarial loss
        y_d_hat_r, y_d_hat_g, _, _ = self.discriminator(audios, fake_audios)

        with torch.autocast(device_type=audios.device.type, enabled=False):
            loss_adv, _ = generator_loss(y_d_hat_g)

        self.log(
            "train/generator/adv",
            loss_adv,
            on_step=True,
            on_epoch=False,
            prog_bar=False,
            logger=True,
            sync_dist=True,
        )

        # Feature-matching loss
        with torch.autocast(device_type=audios.device.type, enabled=False):
            loss_fm = feature_loss(y_d_hat_r, y_d_hat_g)

        self.log(
            "train/generator/adv_fm",
            loss_fm,
            on_step=True,
            on_epoch=False,
            prog_bar=False,
            logger=True,
            sync_dist=True,
        )

        # Masked L1 mel reconstruction loss
        with torch.autocast(device_type=audios.device.type, enabled=False):
            loss_mel = avg_with_mask(
                F.l1_loss(gt_mels, fake_mels, reduction="none"), spec_masks
            )

        self.log(
            "train/generator/loss_mel",
            loss_mel,
            on_step=True,
            on_epoch=False,
            prog_bar=False,
            logger=True,
            sync_dist=True,
        )

        # KL divergence between prior and posterior
        loss_kl = kl_loss(z_p, logs_q, m_p, logs_p, y_mask)

        self.log(
            "train/generator/loss_kl",
            loss_kl,
            on_step=True,
            on_epoch=False,
            prog_bar=False,
            logger=True,
            sync_dist=True,
        )

        loss = (
            loss_mel * self.weight_mel + loss_kl * self.weight_kl + loss_adv + loss_fm
        )

        self.log(
            "train/generator/loss",
            loss,
            on_step=True,
            on_epoch=False,
            prog_bar=True,
            logger=True,
            sync_dist=True,
        )

        # Generator update
        optim_g.zero_grad()
        self.manual_backward(loss)
        self.clip_gradients(
            optim_g, gradient_clip_val=1000.0, gradient_clip_algorithm="norm"
        )
        optim_g.step()

        # With manual optimization, the schedulers must also be stepped by hand
        scheduler_g, scheduler_d = self.lr_schedulers()
        scheduler_g.step()
        scheduler_d.step()
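
    # Update order in `training_step`: the discriminator trains first on
    # `fake_audios.detach()` so no gradient reaches the generator; the
    # generator then trains against the refreshed discriminator outputs with
    # loss = weight_mel * L1_mel + weight_kl * KL + adv + feature_matching.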

    def validation_step(self, batch: Any, batch_idx: int):
        audios, audio_lengths = batch["audios"], batch["audio_lengths"]
        texts, text_lengths = batch["texts"], batch["text_lengths"]

        audios = audios.float()
        audios = audios[:, None, :]

        gt_mels = self.mel_transform(audios)
        gt_specs = self.spec_transform(audios)

        spec_lengths = audio_lengths // self.hop_length
        spec_masks = torch.unsqueeze(
            sequence_mask(spec_lengths, gt_mels.shape[2]), 1
        ).to(gt_mels.dtype)

        prior_audios = self.generator.infer(
            audios, audio_lengths, gt_specs, spec_lengths, texts, text_lengths
        )
        posterior_audios = self.generator.infer_posterior(gt_specs, spec_lengths)

        prior_mels = self.mel_transform(prior_audios.squeeze(1))
        posterior_mels = self.mel_transform(posterior_audios.squeeze(1))

        # Synthesized audio may differ in length by a frame, so align all mels
        # (and the mask) to the shortest length
        min_mel_length = min(
            gt_mels.shape[-1], prior_mels.shape[-1], posterior_mels.shape[-1]
        )
        gt_mels = gt_mels[:, :, :min_mel_length]
        prior_mels = prior_mels[:, :, :min_mel_length]
        posterior_mels = posterior_mels[:, :, :min_mel_length]
        spec_masks = spec_masks[:, :, :min_mel_length]

        prior_mel_loss = avg_with_mask(
            F.l1_loss(gt_mels, prior_mels, reduction="none"), spec_masks
        )
        posterior_mel_loss = avg_with_mask(
            F.l1_loss(gt_mels, posterior_mels, reduction="none"), spec_masks
        )

        self.log(
            "val/prior_mel_loss",
            prior_mel_loss,
            on_step=False,
            on_epoch=True,
            prog_bar=False,
            logger=True,
            sync_dist=True,
        )
        self.log(
            "val/posterior_mel_loss",
            posterior_mel_loss,
            on_step=False,
            on_epoch=True,
            prog_bar=False,
            logger=True,
            sync_dist=True,
        )

        # Only log media for the first batch
        if batch_idx != 0:
            return

        for idx, (
            mel,
            prior_mel,
            posterior_mel,
            audio,
            prior_audio,
            posterior_audio,
            audio_len,
        ) in enumerate(
            zip(
                gt_mels,
                prior_mels,
                posterior_mels,
                audios.detach().float().cpu(),
                prior_audios.detach().float().cpu(),
                posterior_audios.detach().float().cpu(),
                audio_lengths,
            )
        ):
            mel_len = audio_len // self.hop_length

            image_mels = plot_mel(
                [
                    prior_mel[:, :mel_len],
                    posterior_mel[:, :mel_len],
                    mel[:, :mel_len],
                ],
                [
                    "Prior (VQ)",
                    "Posterior (Reconstruction)",
                    "Ground-Truth",
                ],
            )

            if isinstance(self.logger, WandbLogger):
                self.logger.experiment.log(
                    {
                        "reconstruction_mel": wandb.Image(image_mels, caption="mels"),
                        "wavs": [
                            wandb.Audio(
                                audio[0, :audio_len],
                                sample_rate=self.sampling_rate,
                                caption="gt",
                            ),
                            wandb.Audio(
                                prior_audio[0, :audio_len],
                                sample_rate=self.sampling_rate,
                                caption="prior",
                            ),
                            wandb.Audio(
                                posterior_audio[0, :audio_len],
                                sample_rate=self.sampling_rate,
                                caption="posterior",
                            ),
                        ],
                    },
                )

            if isinstance(self.logger, TensorBoardLogger):
                self.logger.experiment.add_figure(
                    f"sample-{idx}/mels",
                    image_mels,
                    global_step=self.global_step,
                )
                self.logger.experiment.add_audio(
                    f"sample-{idx}/wavs/gt",
                    audio[0, :audio_len],
                    self.global_step,
                    sample_rate=self.sampling_rate,
                )
                self.logger.experiment.add_audio(
                    f"sample-{idx}/wavs/prior",
                    prior_audio[0, :audio_len],
                    self.global_step,
                    sample_rate=self.sampling_rate,
                )
                self.logger.experiment.add_audio(
                    f"sample-{idx}/wavs/posterior",
                    posterior_audio[0, :audio_len],
                    self.global_step,
                    sample_rate=self.sampling_rate,
                )

            plt.close(image_mels)
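

if __name__ == "__main__":
    # Minimal smoke test of the optimizer wiring only (a sketch: the real
    # generator/discriminator and mel/spec transforms come from the fish_speech
    # training configs; the nn.Linear / nn.Identity stand-ins and the
    # hyperparameter values here are purely illustrative and cannot
    # synthesize audio).
    from functools import partial

    module = VITSDecoder(
        optimizer=partial(torch.optim.AdamW, lr=2e-4, betas=(0.8, 0.99)),
        lr_scheduler=partial(torch.optim.lr_scheduler.ExponentialLR, gamma=0.999875),
        generator=nn.Linear(1, 1),
        discriminator=nn.Linear(1, 1),
        mel_transform=nn.Identity(),
        spec_transform=nn.Identity(),
    )
    gen_cfg, disc_cfg = module.configure_optimizers()
    print(gen_cfg["lr_scheduler"]["name"], disc_cfg["lr_scheduler"]["name"])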