# lit_module.py

import itertools
import math
from typing import Any, Callable

import lightning as L
import torch
import torch.nn.functional as F
import wandb
from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger
from matplotlib import pyplot as plt
from torch import nn

from fish_speech.models.vqgan.modules.discriminator import Discriminator
from fish_speech.models.vqgan.modules.wavenet import WaveNet
from fish_speech.models.vqgan.utils import avg_with_mask, plot_mel, sequence_mask


class VQGAN(L.LightningModule):
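    """Mel-spectrogram VQ-GAN.

    A WaveNet encoder compresses ground-truth mels into features that a vector
    quantizer discretizes; a WaveNet decoder reconstructs mels from the
    quantized features and is trained adversarially against a mel
    discriminator. A frozen, pretrained vocoder is used only to render mels
    back to audio.
    """
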
    def __init__(
        self,
        optimizer: Callable,
        lr_scheduler: Callable,
        encoder: WaveNet,
        quantizer: nn.Module,
        decoder: WaveNet,
        discriminator: Discriminator,
        vocoder: nn.Module,
        mel_transform: nn.Module,
        weight_adv: float = 1.0,
        weight_vq: float = 1.0,
        weight_mel: float = 1.0,
        sampling_rate: int = 44100,
        freeze_encoder: bool = False,
    ):
        super().__init__()

        # Optimizer and scheduler factories
        self.optimizer_builder = optimizer
        self.lr_scheduler_builder = lr_scheduler

        # Modules
        self.encoder = encoder
        self.quantizer = quantizer
        self.decoder = decoder
        self.vocoder = vocoder
        self.discriminator = discriminator
        self.mel_transform = mel_transform

        # The vocoder is pretrained and always frozen
        for param in self.vocoder.parameters():
            param.requires_grad = False

        # Loss weights
        self.weight_adv = weight_adv
        self.weight_vq = weight_vq
        self.weight_mel = weight_mel

        # Other parameters
        self.sampling_rate = sampling_rate

        # Disable strict loading
        self.strict_loading = False

        # Optionally freeze the encoder and quantizer
        if freeze_encoder:
            for param in self.encoder.parameters():
                param.requires_grad = False

            for param in self.quantizer.parameters():
                param.requires_grad = False

        self.automatic_optimization = False
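
    # NOTE: checkpoints omit the frozen vocoder (see `on_save_checkpoint`
    # below), so `strict_loading` is disabled above to let such checkpoints
    # load cleanly.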
    def on_save_checkpoint(self, checkpoint):
        # Do not save vocoder
        state_dict = checkpoint["state_dict"]
        for name in list(state_dict.keys()):
            if "vocoder" in name:
                state_dict.pop(name)

    def configure_optimizers(self):
        optimizer_generator = self.optimizer_builder(
            itertools.chain(
                self.encoder.parameters(),
                self.quantizer.parameters(),
                self.decoder.parameters(),
            )
        )
        optimizer_discriminator = self.optimizer_builder(
            self.discriminator.parameters()
        )

        lr_scheduler_generator = self.lr_scheduler_builder(optimizer_generator)
        lr_scheduler_discriminator = self.lr_scheduler_builder(optimizer_discriminator)

        return (
            {
                "optimizer": optimizer_generator,
                "lr_scheduler": {
                    "scheduler": lr_scheduler_generator,
                    "interval": "step",
                    "name": "optimizer/generator",
                },
            },
            {
                "optimizer": optimizer_discriminator,
                "lr_scheduler": {
                    "scheduler": lr_scheduler_discriminator,
                    "interval": "step",
                    "name": "optimizer/discriminator",
                },
            },
        )
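
    # Manual optimization: each step first updates the discriminator on
    # detached generated mels, then the generator on the combined
    # VQ + mel-reconstruction + adversarial loss. Both use the least-squares
    # GAN objective, masked to valid frames.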
    def training_step(self, batch, batch_idx):
        optim_g, optim_d = self.optimizers()

        audios, audio_lengths = batch["audios"], batch["audio_lengths"]

        audios = audios.float()
        audios = audios[:, None, :]

        with torch.no_grad():
            gt_mels = self.mel_transform(audios)

        mel_lengths = audio_lengths // self.mel_transform.hop_length
        mel_masks = sequence_mask(mel_lengths, gt_mels.shape[2])
        mel_masks_float_conv = mel_masks[:, None, :].float()
        gt_mels = gt_mels * mel_masks_float_conv

        # Encode
        encoded_features = self.encoder(gt_mels) * mel_masks_float_conv

        # Quantize
        vq_result = self.quantizer(encoded_features)
        loss_vq = getattr(vq_result, "loss", 0.0)  # 0 if the quantizer reports no loss
        vq_recon_features = vq_result.z * mel_masks_float_conv

        # VQ Decode
        gen_mel = (
            self.decoder(
                torch.randn_like(vq_recon_features) * mel_masks_float_conv,
                condition=vq_recon_features,
            )
            * mel_masks_float_conv
        )

        # Discriminator
        real_logits = self.discriminator(gt_mels)
        fake_logits = self.discriminator(gen_mel.detach())
        d_mask = F.interpolate(
            mel_masks_float_conv, size=(real_logits.shape[2],), mode="nearest"
        )

        loss_real = avg_with_mask((real_logits - 1) ** 2, d_mask)
        loss_fake = avg_with_mask(fake_logits**2, d_mask)

        loss_d = loss_real + loss_fake

        self.log(
            "train/discriminator/loss",
            loss_d,
            on_step=True,
            on_epoch=False,
            prog_bar=True,
            logger=True,
        )

        # Discriminator backward
        optim_d.zero_grad()
        self.manual_backward(loss_d)
        self.clip_gradients(
            optim_d, gradient_clip_val=1000.0, gradient_clip_algorithm="norm"
        )
        optim_d.step()

        # Mel Loss
        loss_mel = avg_with_mask((gen_mel - gt_mels).abs(), mel_masks_float_conv)

        # Adversarial Loss
        fake_logits = self.discriminator(gen_mel)
        loss_adv = avg_with_mask((fake_logits - 1) ** 2, d_mask)

        # Total loss
        loss = (
            self.weight_vq * loss_vq
            + self.weight_mel * loss_mel
            + self.weight_adv * loss_adv
        )

        # Log losses
        self.log(
            "train/generator/loss",
            loss,
            on_step=True,
            on_epoch=False,
            prog_bar=True,
            logger=True,
        )
        self.log(
            "train/generator/loss_vq",
            loss_vq,
            on_step=True,
            on_epoch=False,
            prog_bar=False,
            logger=True,
        )
        self.log(
            "train/generator/loss_mel",
            loss_mel,
            on_step=True,
            on_epoch=False,
            prog_bar=False,
            logger=True,
        )

        # Generator backward
        optim_g.zero_grad()
        self.manual_backward(loss)
        self.clip_gradients(
            optim_g, gradient_clip_val=1000.0, gradient_clip_algorithm="norm"
        )
        optim_g.step()

        scheduler_g, scheduler_d = self.lr_schedulers()
        scheduler_g.step()
        scheduler_d.step()
    def validation_step(self, batch: Any, batch_idx: int):
        audios, audio_lengths = batch["audios"], batch["audio_lengths"]

        audios = audios.float()
        audios = audios[:, None, :]

        gt_mels = self.mel_transform(audios)

        mel_lengths = audio_lengths // self.mel_transform.hop_length
        mel_masks = sequence_mask(mel_lengths, gt_mels.shape[2])
        mel_masks_float_conv = mel_masks[:, None, :].float()
        gt_mels = gt_mels * mel_masks_float_conv

        # Encode
        encoded_features = self.encoder(gt_mels) * mel_masks_float_conv

        # Quantize
        vq_recon_features = self.quantizer(encoded_features).z * mel_masks_float_conv

        # VQ Decode
        gen_aux_mels = (
            self.decoder(
                torch.randn_like(vq_recon_features) * mel_masks_float_conv,
                condition=vq_recon_features,
            )
            * mel_masks_float_conv
        )

        loss_mel = avg_with_mask((gen_aux_mels - gt_mels).abs(), mel_masks_float_conv)
        self.log(
            "val/loss_mel",
            loss_mel,
            on_step=False,
            on_epoch=True,
            prog_bar=False,
            logger=True,
            sync_dist=True,
        )

        recon_audios = self.vocoder(gt_mels)
        gen_aux_audios = self.vocoder(gen_aux_mels)

        # Only log the first batch
        if batch_idx != 0:
            return

        for idx, (
            gt_mel,
            gen_aux_mel,
            audio,
            gen_aux_audio,
            recon_audio,
            audio_len,
        ) in enumerate(
            zip(
                gt_mels,
                gen_aux_mels,
                audios.cpu().float(),
                gen_aux_audios.cpu().float(),
                recon_audios.cpu().float(),
                audio_lengths,
            )
        ):
            if idx > 4:
                break

            mel_len = audio_len // self.mel_transform.hop_length

            image_mels = plot_mel(
                [
                    gt_mel[:, :mel_len],
                    gen_aux_mel[:, :mel_len],
                ],
                [
                    "Ground-Truth",
                    "Auxiliary",
                ],
            )

            if isinstance(self.logger, WandbLogger):
                self.logger.experiment.log(
                    {
                        "reconstruction_mel": wandb.Image(image_mels, caption="mels"),
                        "wavs": [
                            wandb.Audio(
                                audio[0, :audio_len],
                                sample_rate=self.sampling_rate,
                                caption="gt",
                            ),
                            wandb.Audio(
                                gen_aux_audio[0, :audio_len],
                                sample_rate=self.sampling_rate,
                                caption="aux",
                            ),
                            wandb.Audio(
                                recon_audio[0, :audio_len],
                                sample_rate=self.sampling_rate,
                                caption="recon",
                            ),
                        ],
                    },
                )

            if isinstance(self.logger, TensorBoardLogger):
                self.logger.experiment.add_figure(
                    f"sample-{idx}/mels",
                    image_mels,
                    global_step=self.global_step,
                )
                self.logger.experiment.add_audio(
                    f"sample-{idx}/wavs/gt",
                    audio[0, :audio_len],
                    self.global_step,
                    sample_rate=self.sampling_rate,
                )
                self.logger.experiment.add_audio(
                    f"sample-{idx}/wavs/gen",
                    gen_aux_audio[0, :audio_len],
                    self.global_step,
                    sample_rate=self.sampling_rate,
                )
                self.logger.experiment.add_audio(
                    f"sample-{idx}/wavs/recon",
                    recon_audio[0, :audio_len],
                    self.global_step,
                    sample_rate=self.sampling_rate,
                )

            plt.close(image_mels)
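
    # Inference helpers: `encode` maps audio to quantizer indices at the
    # downsampled feature rate; `decode` maps indices back to mels and,
    # optionally, to waveforms through the frozen vocoder.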
    def encode(self, audios, audio_lengths):
        audios = audios.float()

        gt_mels = self.mel_transform(audios)
        mel_lengths = audio_lengths // self.mel_transform.hop_length
        mel_masks = sequence_mask(mel_lengths, gt_mels.shape[2])
        mel_masks_float_conv = mel_masks[:, None, :].float()
        gt_mels = gt_mels * mel_masks_float_conv

        # Encode
        encoded_features = self.encoder(gt_mels) * mel_masks_float_conv
        feature_lengths = mel_lengths // math.prod(self.quantizer.downsample_factor)

        return self.quantizer.encode(encoded_features), feature_lengths

    def decode(self, indices, feature_lengths, return_audios=False):
        factor = math.prod(self.quantizer.downsample_factor)
        mel_masks = sequence_mask(feature_lengths * factor, indices.shape[2] * factor)
        mel_masks_float_conv = mel_masks[:, None, :].float()

        z = self.quantizer.decode(indices) * mel_masks_float_conv

        gen_mel = (
            self.decoder(
                torch.randn_like(z) * mel_masks_float_conv,
                condition=z,
            )
            * mel_masks_float_conv
        )

        if return_audios:
            return self.vocoder(gen_mel)

        return gen_mel
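

# A minimal usage sketch, not part of the original module: `reconstruct` is a
# hypothetical helper that assumes a trained `VQGAN` instance and `audios`
# shaped (batch, 1, samples), with per-item sample counts in `audio_lengths`.
def reconstruct(
    model: VQGAN, audios: torch.Tensor, audio_lengths: torch.Tensor
) -> torch.Tensor:
    """Round-trip audio through the quantizer and back to waveforms."""
    model.eval()

    with torch.no_grad():
        indices, feature_lengths = model.encode(audios, audio_lengths)
        return model.decode(indices, feature_lengths, return_audios=True)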