lit_module.py

from typing import Any, Callable

import lightning as L
import torch
import torch.nn.functional as F
import wandb
from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger
from matplotlib import pyplot as plt
from torch import nn

from fish_speech.models.vits_decoder.losses import (
    discriminator_loss,
    feature_loss,
    generator_loss,
    kl_loss,
)
from fish_speech.models.vqgan.utils import (
    avg_with_mask,
    plot_mel,
    sequence_mask,
    slice_segments,
)

class VITSDecoder(L.LightningModule):
    def __init__(
        self,
        optimizer: Callable,
        lr_scheduler: Callable,
        generator: nn.Module,
        discriminator: nn.Module,
        mel_transform: nn.Module,
        spec_transform: nn.Module,
        hop_length: int = 512,
        sample_rate: int = 44100,
        freeze_discriminator: bool = False,
        weight_mel: float = 45,
        weight_kl: float = 0.1,
    ):
        super().__init__()

        # Optimizer and scheduler builders, invoked in configure_optimizers()
        self.optimizer_builder = optimizer
        self.lr_scheduler_builder = lr_scheduler

        # Generator and discriminator
        self.generator = generator
        self.discriminator = discriminator
        self.mel_transform = mel_transform
        self.spec_transform = spec_transform
        self.freeze_discriminator = freeze_discriminator

        # Loss weights
        self.weight_mel = weight_mel
        self.weight_kl = weight_kl

        # Other parameters
        self.hop_length = hop_length
        self.sampling_rate = sample_rate

        # GAN training alternates two optimizers, so optimization is manual
        self.automatic_optimization = False

        if self.freeze_discriminator:
            for p in self.discriminator.parameters():
                p.requires_grad = False

    def configure_optimizers(self):
        # One optimizer/scheduler pair for the generator, one for the discriminator
        optimizer_generator = self.optimizer_builder(self.generator.parameters())
        optimizer_discriminator = self.optimizer_builder(
            self.discriminator.parameters()
        )

        lr_scheduler_generator = self.lr_scheduler_builder(optimizer_generator)
        lr_scheduler_discriminator = self.lr_scheduler_builder(optimizer_discriminator)

        return (
            {
                "optimizer": optimizer_generator,
                "lr_scheduler": {
                    "scheduler": lr_scheduler_generator,
                    "interval": "step",
                    "name": "optimizer/generator",
                },
            },
            {
                "optimizer": optimizer_discriminator,
                "lr_scheduler": {
                    "scheduler": lr_scheduler_discriminator,
                    "interval": "step",
                    "name": "optimizer/discriminator",
                },
            },
        )
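
    # A minimal sketch of what the injected `optimizer` / `lr_scheduler`
    # builders could look like. This is an assumption for illustration: the
    # real builders are supplied by the training config and are not defined
    # in this file.
    #
    #   from functools import partial
    #
    #   optimizer = partial(torch.optim.AdamW, lr=2e-4, betas=(0.8, 0.99))
    #   lr_scheduler = partial(torch.optim.lr_scheduler.ExponentialLR, gamma=0.999)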

    def training_step(self, batch, batch_idx):
        optim_g, optim_d = self.optimizers()

        audios, audio_lengths = batch["audios"], batch["audio_lengths"]
        texts, text_lengths = batch["texts"], batch["text_lengths"]

        audios = audios.float()
        audios = audios[:, None, :]

        with torch.no_grad():
            gt_mels = self.mel_transform(audios)
            gt_specs = self.spec_transform(audios)

        spec_lengths = audio_lengths // self.hop_length
        spec_masks = torch.unsqueeze(
            sequence_mask(spec_lengths, gt_mels.shape[2]), 1
        ).to(gt_mels.dtype)

        (
            fake_audios,
            ids_slice,
            y_mask,
            (z, z_p, m_p, logs_p, m_q, logs_q),
        ) = self.generator(
            audios,
            audio_lengths,
            gt_specs,
            spec_lengths,
            texts,
            text_lengths,
        )

        # The generator synthesizes a random segment; slice the targets to match
        gt_mels = slice_segments(gt_mels, ids_slice, self.generator.segment_size)
        spec_masks = slice_segments(spec_masks, ids_slice, self.generator.segment_size)
        audios = slice_segments(
            audios,
            ids_slice * self.hop_length,
            self.generator.segment_size * self.hop_length,
        )

        fake_mels = self.mel_transform(fake_audios.squeeze(1))

        assert (
            audios.shape == fake_audios.shape
        ), f"{audios.shape} != {fake_audios.shape}"

        # Discriminator step
        if not self.freeze_discriminator:
            y_d_hat_r, y_d_hat_g, _, _ = self.discriminator(
                audios, fake_audios.detach()
            )

            with torch.autocast(device_type=audios.device.type, enabled=False):
                loss_disc, _, _ = discriminator_loss(y_d_hat_r, y_d_hat_g)

            self.log(
                "train/discriminator/loss",
                loss_disc,
                on_step=True,
                on_epoch=False,
                prog_bar=False,
                logger=True,
                sync_dist=True,
            )

            optim_d.zero_grad()
            self.manual_backward(loss_disc)
            self.clip_gradients(
                optim_d, gradient_clip_val=1000.0, gradient_clip_algorithm="norm"
            )
            optim_d.step()

        # Adversarial loss
        y_d_hat_r, y_d_hat_g, _, _ = self.discriminator(audios, fake_audios)

        with torch.autocast(device_type=audios.device.type, enabled=False):
            loss_adv, _ = generator_loss(y_d_hat_g)

        self.log(
            "train/generator/adv",
            loss_adv,
            on_step=True,
            on_epoch=False,
            prog_bar=False,
            logger=True,
            sync_dist=True,
        )

        # Feature-matching loss
        with torch.autocast(device_type=audios.device.type, enabled=False):
            loss_fm = feature_loss(y_d_hat_r, y_d_hat_g)

        self.log(
            "train/generator/adv_fm",
            loss_fm,
            on_step=True,
            on_epoch=False,
            prog_bar=False,
            logger=True,
            sync_dist=True,
        )

        # Masked L1 mel reconstruction loss
        with torch.autocast(device_type=audios.device.type, enabled=False):
            loss_mel = avg_with_mask(
                F.l1_loss(gt_mels, fake_mels, reduction="none"), spec_masks
            )

        self.log(
            "train/generator/loss_mel",
            loss_mel,
            on_step=True,
            on_epoch=False,
            prog_bar=False,
            logger=True,
            sync_dist=True,
        )

        # KL divergence between the prior and the posterior
        loss_kl = kl_loss(z_p, logs_q, m_p, logs_p, y_mask)

        self.log(
            "train/generator/loss_kl",
            loss_kl,
            on_step=True,
            on_epoch=False,
            prog_bar=False,
            logger=True,
            sync_dist=True,
        )

        loss = (
            loss_mel * self.weight_mel + loss_kl * self.weight_kl + loss_adv + loss_fm
        )

        self.log(
            "train/generator/loss",
            loss,
            on_step=True,
            on_epoch=False,
            prog_bar=True,
            logger=True,
            sync_dist=True,
        )

        # Generator backward pass
        optim_g.zero_grad()
        self.manual_backward(loss)
        self.clip_gradients(
            optim_g, gradient_clip_val=1000.0, gradient_clip_algorithm="norm"
        )
        optim_g.step()

        # Step the schedulers manually (automatic optimization is disabled)
        scheduler_g, scheduler_d = self.lr_schedulers()
        scheduler_g.step()
        scheduler_d.step()

    def validation_step(self, batch: Any, batch_idx: int):
        audios, audio_lengths = batch["audios"], batch["audio_lengths"]
        texts, text_lengths = batch["texts"], batch["text_lengths"]

        audios = audios.float()
        audios = audios[:, None, :]

        gt_mels = self.mel_transform(audios)
        gt_specs = self.spec_transform(audios)

        spec_lengths = audio_lengths // self.hop_length
        spec_masks = torch.unsqueeze(
            sequence_mask(spec_lengths, gt_mels.shape[2]), 1
        ).to(gt_mels.dtype)

        prior_audios = self.generator.infer(
            audios, audio_lengths, gt_specs, spec_lengths, texts, text_lengths
        )
        posterior_audios = self.generator.infer_posterior(gt_specs, spec_lengths)

        prior_mels = self.mel_transform(prior_audios.squeeze(1))
        posterior_mels = self.mel_transform(posterior_audios.squeeze(1))

        # Trim all mels (and the masks) to a common length so the L1 losses align
        min_mel_length = min(
            gt_mels.shape[-1], prior_mels.shape[-1], posterior_mels.shape[-1]
        )
        gt_mels = gt_mels[:, :, :min_mel_length]
        prior_mels = prior_mels[:, :, :min_mel_length]
        posterior_mels = posterior_mels[:, :, :min_mel_length]
        spec_masks = spec_masks[:, :, :min_mel_length]

        prior_mel_loss = avg_with_mask(
            F.l1_loss(gt_mels, prior_mels, reduction="none"), spec_masks
        )
        posterior_mel_loss = avg_with_mask(
            F.l1_loss(gt_mels, posterior_mels, reduction="none"), spec_masks
        )

        self.log(
            "val/prior_mel_loss",
            prior_mel_loss,
            on_step=False,
            on_epoch=True,
            prog_bar=False,
            logger=True,
            sync_dist=True,
        )
        self.log(
            "val/posterior_mel_loss",
            posterior_mel_loss,
            on_step=False,
            on_epoch=True,
            prog_bar=False,
            logger=True,
            sync_dist=True,
        )

        # Only log audio/mel samples for the first batch
        if batch_idx != 0:
            return

        for idx, (
            mel,
            prior_mel,
            posterior_mel,
            audio,
            prior_audio,
            posterior_audio,
            audio_len,
        ) in enumerate(
            zip(
                gt_mels,
                prior_mels,
                posterior_mels,
                # Move audio to CPU so it can be converted to numpy for logging
                audios.detach().float().cpu(),
                prior_audios.detach().float().cpu(),
                posterior_audios.detach().float().cpu(),
                audio_lengths,
            )
        ):
            mel_len = audio_len // self.hop_length

            image_mels = plot_mel(
                [
                    prior_mel[:, :mel_len],
                    posterior_mel[:, :mel_len],
                    mel[:, :mel_len],
                ],
                [
                    "Prior (VQ)",
                    "Posterior (Reconstruction)",
                    "Ground-Truth",
                ],
            )

            if isinstance(self.logger, WandbLogger):
                self.logger.experiment.log(
                    {
                        "reconstruction_mel": wandb.Image(image_mels, caption="mels"),
                        "wavs": [
                            wandb.Audio(
                                audio[0, :audio_len],
                                sample_rate=self.sampling_rate,
                                caption="gt",
                            ),
                            wandb.Audio(
                                prior_audio[0, :audio_len],
                                sample_rate=self.sampling_rate,
                                caption="prior",
                            ),
                            wandb.Audio(
                                posterior_audio[0, :audio_len],
                                sample_rate=self.sampling_rate,
                                caption="posterior",
                            ),
                        ],
                    },
                )

            if isinstance(self.logger, TensorBoardLogger):
                self.logger.experiment.add_figure(
                    f"sample-{idx}/mels",
                    image_mels,
                    global_step=self.global_step,
                )
                self.logger.experiment.add_audio(
                    f"sample-{idx}/wavs/gt",
                    audio[0, :audio_len],
                    self.global_step,
                    sample_rate=self.sampling_rate,
                )
                self.logger.experiment.add_audio(
                    f"sample-{idx}/wavs/prior",
                    prior_audio[0, :audio_len],
                    self.global_step,
                    sample_rate=self.sampling_rate,
                )
                self.logger.experiment.add_audio(
                    f"sample-{idx}/wavs/posterior",
                    posterior_audio[0, :audio_len],
                    self.global_step,
                    sample_rate=self.sampling_rate,
                )

            plt.close(image_mels)
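
# Usage sketch (an assumption for illustration, not part of this module):
# wiring VITSDecoder into a Lightning Trainer. `build_generator`,
# `build_discriminator`, the transforms, and `datamodule` are hypothetical
# stand-ins for the real fish-speech components defined elsewhere.
#
#   from functools import partial
#
#   model = VITSDecoder(
#       optimizer=partial(torch.optim.AdamW, lr=2e-4, betas=(0.8, 0.99)),
#       lr_scheduler=partial(torch.optim.lr_scheduler.ExponentialLR, gamma=0.999),
#       generator=build_generator(),          # hypothetical factory
#       discriminator=build_discriminator(),  # hypothetical factory
#       mel_transform=mel_transform,          # e.g. a log-mel spectrogram module
#       spec_transform=spec_transform,        # e.g. a linear spectrogram module
#   )
#   trainer = L.Trainer(max_steps=100_000)
#   trainer.fit(model, datamodule=datamodule)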