lit_module.py

import itertools
from dataclasses import dataclass
from typing import Any, Callable, Literal, Optional

import lightning as L
import torch
import torch.nn.functional as F
import wandb
from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger
from matplotlib import pyplot as plt
from torch import nn
from torch.utils.checkpoint import checkpoint as gradient_checkpoint

from fish_speech.models.vqgan.losses import (
    MultiResolutionSTFTLoss,
    discriminator_loss,
    feature_loss,
    generator_loss,
    kl_loss,
)
from fish_speech.models.vqgan.utils import plot_mel, sequence_mask, slice_segments


@dataclass
class VQEncodeResult:
    features: torch.Tensor
    indices: torch.Tensor
    loss: torch.Tensor
    feature_lengths: torch.Tensor


@dataclass
class VQDecodeResult:
    mels: torch.Tensor
    audios: Optional[torch.Tensor] = None


class VQGAN(L.LightningModule):
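    """VQ-GAN audio codec trained as a LightningModule.

    A summary of the recipe implemented below: the generator reconstructs
    audio from linear spectrograms through a VQ bottleneck and is optimized
    against mel, aux-mel, VQ, KL, adversarial, and feature-matching losses,
    while a discriminator is trained on real vs. generated audio. Optimization
    is manual: two optimizers, stepped per batch in ``training_step``.
    """
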
    def __init__(
        self,
        optimizer: Callable,
        lr_scheduler: Callable,
        generator: nn.Module,
        discriminator: nn.Module,
        mel_transform: nn.Module,
        spec_transform: nn.Module,
        hop_length: int = 640,
        sample_rate: int = 32000,
        freeze_discriminator: bool = False,
        weight_mel: float = 45.0,
        weight_kl: float = 0.1,
        weight_vq: float = 1.0,
        weight_aux_mel: float = 20.0,
    ):
        super().__init__()

        # Optimizer and scheduler builders (applied to parameters later)
        self.optimizer_builder = optimizer
        self.lr_scheduler_builder = lr_scheduler

        # Generator and discriminator
        self.generator = generator
        self.discriminator = discriminator
        self.mel_transform = mel_transform
        self.spec_transform = spec_transform
        self.freeze_discriminator = freeze_discriminator

        # Loss weights
        self.weight_mel = weight_mel
        self.weight_kl = weight_kl
        self.weight_vq = weight_vq
        self.weight_aux_mel = weight_aux_mel

        # Other parameters
        self.hop_length = hop_length
        self.sampling_rate = sample_rate

        # Disable automatic optimization
        self.automatic_optimization = False
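        # With automatic optimization off, Lightning no longer calls
        # backward()/step() for us; training_step below alternates the
        # discriminator and generator updates itself.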

        if self.freeze_discriminator:
            for p in self.discriminator.parameters():
                p.requires_grad = False

    def configure_optimizers(self):
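        """Build one optimizer/scheduler pair each for generator and discriminator.

        Both schedulers step per training step. Since automatic optimization is
        disabled, Lightning only exposes these via ``self.optimizers()`` and
        ``self.lr_schedulers()``; stepping happens in ``training_step``.
        """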
        # Need two optimizers and two schedulers
        optimizer_generator = self.optimizer_builder(self.generator.parameters())
        optimizer_discriminator = self.optimizer_builder(
            self.discriminator.parameters()
        )
        lr_scheduler_generator = self.lr_scheduler_builder(optimizer_generator)
        lr_scheduler_discriminator = self.lr_scheduler_builder(optimizer_discriminator)

        return (
            {
                "optimizer": optimizer_generator,
                "lr_scheduler": {
                    "scheduler": lr_scheduler_generator,
                    "interval": "step",
                    "name": "optimizer/generator",
                },
            },
            {
                "optimizer": optimizer_discriminator,
                "lr_scheduler": {
                    "scheduler": lr_scheduler_discriminator,
                    "interval": "step",
                    "name": "optimizer/discriminator",
                },
            },
        )

    def training_step(self, batch, batch_idx):
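        """One manual GAN step: discriminator first, then generator.

        The discriminator is updated on real audio vs. detached fakes; the
        generator is then updated on a weighted sum of mel, aux-mel, VQ, KL,
        adversarial, and feature-matching losses. Both schedulers step at the
        end of the batch.
        """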
        optim_g, optim_d = self.optimizers()

        audios, audio_lengths = batch["audios"], batch["audio_lengths"]
        audios = audios.float()
        audios = audios[:, None, :]

        # Targets need no gradients
        with torch.no_grad():
            gt_mels = self.mel_transform(audios)
            gt_specs = self.spec_transform(audios)

        spec_lengths = audio_lengths // self.hop_length
        spec_masks = torch.unsqueeze(
            sequence_mask(spec_lengths, gt_mels.shape[2]), 1
        ).to(gt_mels.dtype)

        (
            fake_audios,
            ids_slice,
            _,  # unused mask
            y_mask,
            (z, z_p, m_p, logs_p, m_q, logs_q),
            loss_vq,
            decoded_aux_mels,
        ) = self.generator(gt_specs, spec_lengths)
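
        # The generator trains on random windows: slice the ground truth,
        # masks, and raw audio to the same segments it generated.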
        gt_mels = slice_segments(gt_mels, ids_slice, self.generator.segment_size)
        decoded_aux_mels = slice_segments(
            decoded_aux_mels, ids_slice, self.generator.segment_size
        )
        spec_masks = slice_segments(spec_masks, ids_slice, self.generator.segment_size)
        audios = slice_segments(
            audios,
            ids_slice * self.hop_length,
            self.generator.segment_size * self.hop_length,
        )

        fake_mels = self.mel_transform(fake_audios.squeeze(1))
        assert (
            audios.shape == fake_audios.shape
        ), f"{audios.shape} != {fake_audios.shape}"

        # Discriminator update
        if not self.freeze_discriminator:
            y_d_hat_r, y_d_hat_g, _, _ = self.discriminator(
                audios, fake_audios.detach()
            )

            with torch.autocast(device_type=audios.device.type, enabled=False):
                loss_disc, _, _ = discriminator_loss(y_d_hat_r, y_d_hat_g)

            self.log(
                "train/discriminator/loss",
                loss_disc,
                on_step=True,
                on_epoch=False,
                prog_bar=False,
                logger=True,
                sync_dist=True,
            )

            optim_d.zero_grad()
            self.manual_backward(loss_disc)
            self.clip_gradients(
                optim_d, gradient_clip_val=1000.0, gradient_clip_algorithm="norm"
            )
            optim_d.step()

        # Adversarial loss
        y_d_hat_r, y_d_hat_g, _, _ = self.discriminator(audios, fake_audios)

        with torch.autocast(device_type=audios.device.type, enabled=False):
            loss_adv, _ = generator_loss(y_d_hat_g)

        self.log(
            "train/generator/adv",
            loss_adv,
            on_step=True,
            on_epoch=False,
            prog_bar=False,
            logger=True,
            sync_dist=True,
        )

        # Feature-matching loss
        with torch.autocast(device_type=audios.device.type, enabled=False):
            loss_fm = feature_loss(y_d_hat_r, y_d_hat_g)

        self.log(
            "train/generator/adv_fm",
            loss_fm,
            on_step=True,
            on_epoch=False,
            prog_bar=False,
            logger=True,
            sync_dist=True,
        )

        # Masked L1 mel losses for the decoder and the auxiliary decoder
        with torch.autocast(device_type=audios.device.type, enabled=False):
            loss_mel = F.l1_loss(gt_mels * spec_masks, fake_mels * spec_masks)
            loss_aux_mel = F.l1_loss(
                gt_mels * spec_masks, decoded_aux_mels * spec_masks
            )

        self.log(
            "train/generator/loss_mel",
            loss_mel,
            on_step=True,
            on_epoch=False,
            prog_bar=False,
            logger=True,
            sync_dist=True,
        )
        self.log(
            "train/generator/loss_aux_mel",
            loss_aux_mel,
            on_step=True,
            on_epoch=False,
            prog_bar=False,
            logger=True,
            sync_dist=True,
        )
        self.log(
            "train/generator/loss_vq",
            loss_vq,
            on_step=True,
            on_epoch=False,
            prog_bar=False,
            logger=True,
            sync_dist=True,
        )
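        # KL between the posterior (m_q, logs_q) and the prior (m_p, logs_p)
        # over the masked latent frames, VITS-style.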
        loss_kl = kl_loss(z_p, logs_q, m_p, logs_p, y_mask)
        self.log(
            "train/generator/loss_kl",
            loss_kl,
            on_step=True,
            on_epoch=False,
            prog_bar=False,
            logger=True,
            sync_dist=True,
        )

        loss = (
            loss_mel * self.weight_mel
            + loss_aux_mel * self.weight_aux_mel
            + loss_vq * self.weight_vq
            + loss_kl * self.weight_kl
            + loss_adv
            + loss_fm
        )
        self.log(
            "train/generator/loss",
            loss,
            on_step=True,
            on_epoch=False,
            prog_bar=False,
            logger=True,
            sync_dist=True,
        )

        # Backward
        optim_g.zero_grad()
        self.manual_backward(loss)
        self.clip_gradients(
            optim_g, gradient_clip_val=1000.0, gradient_clip_algorithm="norm"
        )
        optim_g.step()

        # Manual LR scheduler step
        scheduler_g, scheduler_d = self.lr_schedulers()
        scheduler_g.step()
        scheduler_d.step()

    def validation_step(self, batch: Any, batch_idx: int):
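        """Compare prior (VQ) and posterior reconstructions against ground truth.

        Logs masked L1 mel losses for both paths and, for the first batch only,
        renders mel comparisons and audio samples to W&B or TensorBoard.
        """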
        audios, audio_lengths = batch["audios"], batch["audio_lengths"]
        audios = audios.float()
        audios = audios[:, None, :]

        gt_mels = self.mel_transform(audios)
        gt_specs = self.spec_transform(audios)
        spec_lengths = audio_lengths // self.hop_length
        spec_masks = torch.unsqueeze(
            sequence_mask(spec_lengths, gt_mels.shape[2]), 1
        ).to(gt_mels.dtype)

        prior_audios, _, _ = self.generator.infer(gt_specs, spec_lengths)
        posterior_audios, _, _ = self.generator.infer_posterior(gt_specs, spec_lengths)
        prior_mels = self.mel_transform(prior_audios.squeeze(1))
        posterior_mels = self.mel_transform(posterior_audios.squeeze(1))

        min_mel_length = min(
            gt_mels.shape[-1], prior_mels.shape[-1], posterior_mels.shape[-1]
        )
        gt_mels = gt_mels[:, :, :min_mel_length]
        prior_mels = prior_mels[:, :, :min_mel_length]
        posterior_mels = posterior_mels[:, :, :min_mel_length]

        prior_mel_loss = F.l1_loss(gt_mels * spec_masks, prior_mels * spec_masks)
        posterior_mel_loss = F.l1_loss(
            gt_mels * spec_masks, posterior_mels * spec_masks
        )

        self.log(
            "val/prior_mel_loss",
            prior_mel_loss,
            on_step=False,
            on_epoch=True,
            prog_bar=False,
            logger=True,
            sync_dist=True,
        )
        self.log(
            "val/posterior_mel_loss",
            posterior_mel_loss,
            on_step=False,
            on_epoch=True,
            prog_bar=False,
            logger=True,
            sync_dist=True,
        )

        # Only log samples for the first batch
        if batch_idx != 0:
            return
        for idx, (
            mel,
            prior_mel,
            posterior_mel,
            audio,
            prior_audio,
            posterior_audio,
            audio_len,
        ) in enumerate(
            zip(
                gt_mels,
                prior_mels,
                posterior_mels,
                audios.detach().float(),
                prior_audios.detach().float(),
                posterior_audios.detach().float(),
                audio_lengths,
            )
        ):
            mel_len = audio_len // self.hop_length
            image_mels = plot_mel(
                [
                    prior_mel[:, :mel_len],
                    posterior_mel[:, :mel_len],
                    mel[:, :mel_len],
                ],
                [
                    "Prior (VQ)",
                    "Posterior (Reconstruction)",
                    "Ground-Truth",
                ],
            )

            if isinstance(self.logger, WandbLogger):
                self.logger.experiment.log(
                    {
                        "reconstruction_mel": wandb.Image(image_mels, caption="mels"),
                        "wavs": [
                            wandb.Audio(
                                audio[0, :audio_len],
                                sample_rate=self.sampling_rate,
                                caption="gt",
                            ),
                            wandb.Audio(
                                prior_audio[0, :audio_len],
                                sample_rate=self.sampling_rate,
                                caption="prior",
                            ),
                            wandb.Audio(
                                posterior_audio[0, :audio_len],
                                sample_rate=self.sampling_rate,
                                caption="posterior",
                            ),
                        ],
                    },
                )

            if isinstance(self.logger, TensorBoardLogger):
                self.logger.experiment.add_figure(
                    f"sample-{idx}/mels",
                    image_mels,
                    global_step=self.global_step,
                )
                self.logger.experiment.add_audio(
                    f"sample-{idx}/wavs/gt",
                    audio[0, :audio_len],
                    self.global_step,
                    sample_rate=self.sampling_rate,
                )
                self.logger.experiment.add_audio(
                    f"sample-{idx}/wavs/prior",
                    prior_audio[0, :audio_len],
                    self.global_step,
                    sample_rate=self.sampling_rate,
                )
                self.logger.experiment.add_audio(
                    f"sample-{idx}/wavs/posterior",
                    posterior_audio[0, :audio_len],
                    self.global_step,
                    sample_rate=self.sampling_rate,
                )

            plt.close(image_mels)

    # def encode(self, audios, audio_lengths=None):
    #     if audio_lengths is None:
    #         audio_lengths = torch.tensor(
    #             [audios.shape[-1]] * audios.shape[0],
    #             device=audios.device,
    #             dtype=torch.long,
    #         )
    #
    #     with torch.no_grad():
    #         features = self.mel_transform(audios, sample_rate=self.sampling_rate)
    #
    #     feature_lengths = (
    #         audio_lengths
    #         / self.hop_length
    #         # / self.vq.downsample
    #     ).long()
    #
    #     feature_masks = torch.unsqueeze(
    #         sequence_mask(feature_lengths, features.shape[2]), 1
    #     ).to(features.dtype)
    #
    #     features = (
    #         gradient_checkpoint(
    #             self.encoder, features, feature_masks, use_reentrant=False
    #         )
    #         * feature_masks
    #     )
    #
    #     vq_features, indices, loss = self.vq(features, feature_masks)
    #
    #     return VQEncodeResult(
    #         features=vq_features,
    #         indices=indices,
    #         loss=loss,
    #         feature_lengths=feature_lengths,
    #     )
    #
    # def calculate_audio_lengths(self, feature_lengths):
    #     return feature_lengths * self.hop_length * self.vq.downsample
    #
    # def decode(
    #     self,
    #     indices=None,
    #     features=None,
    #     audio_lengths=None,
    #     feature_lengths=None,
    #     return_audios=False,
    # ):
    #     assert (
    #         indices is not None or features is not None
    #     ), "indices or features must be provided"
    #     assert (
    #         feature_lengths is not None or audio_lengths is not None
    #     ), "feature_lengths or audio_lengths must be provided"
    #
    #     if audio_lengths is None:
    #         audio_lengths = self.calculate_audio_lengths(feature_lengths)
    #
    #     mel_lengths = audio_lengths // self.hop_length
    #     mel_masks = torch.unsqueeze(
    #         sequence_mask(mel_lengths, torch.max(mel_lengths)), 1
    #     ).float()
    #
    #     if indices is not None:
    #         features = self.vq.decode(indices)
    #
    #     # Sample mels
    #     decoded = gradient_checkpoint(self.decoder, features, use_reentrant=False)
    #
    #     return VQDecodeResult(
    #         mels=decoded,
    #         audios=self.generator(decoded) if return_audios else None,
    #     )
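

# A minimal wiring sketch (hypothetical; the project normally assembles this
# module from Hydra configs). `build_generator`, `build_discriminator`, and the
# two transform modules are stand-ins for whatever the config supplies.
#
#   from functools import partial
#
#   model = VQGAN(
#       optimizer=partial(torch.optim.AdamW, lr=2e-4, betas=(0.8, 0.99)),
#       lr_scheduler=partial(torch.optim.lr_scheduler.ExponentialLR, gamma=0.999),
#       generator=build_generator(),          # hypothetical factory
#       discriminator=build_discriminator(),  # hypothetical factory
#       mel_transform=mel_transform,          # mel-spectrogram nn.Module
#       spec_transform=spec_transform,        # linear-spectrogram nn.Module
#   )
#   trainer = L.Trainer(max_steps=100_000)
#   trainer.fit(model, datamodule=...)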