# lit_module.py
import itertools
from dataclasses import dataclass
from typing import Any, Callable, Literal, Optional

import lightning as L
import torch
import torch.nn.functional as F
import wandb
from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger
from matplotlib import pyplot as plt
from torch import nn

from fish_speech.models.vqgan.modules.wavenet import WaveNet
from fish_speech.models.vqgan.utils import plot_mel, sequence_mask
  13. class VQGAN(L.LightningModule):
  14. def __init__(
  15. self,
  16. optimizer: Callable,
  17. lr_scheduler: Callable,
  18. encoder: WaveNet,
  19. quantizer: nn.Module,
  20. decoder: WaveNet,
  21. reflow: nn.Module,
  22. vocoder: nn.Module,
  23. mel_transform: nn.Module,
  24. weight_reflow: float = 1.0,
  25. weight_vq: float = 1.0,
  26. weight_mel: float = 1.0,
  27. sampling_rate: int = 44100,
  28. freeze_encoder: bool = False,
  29. reflow_use_shallow: bool = False,
  30. reflow_inference_steps: int = 10,
  31. reflow_inference_start_t: float = 0.5,
  32. ):
  33. super().__init__()
  34. # Model parameters
  35. self.optimizer_builder = optimizer
  36. self.lr_scheduler_builder = lr_scheduler
  37. # Modules
  38. self.encoder = encoder
  39. self.quantizer = quantizer
  40. self.decoder = decoder
  41. self.vocoder = vocoder
  42. self.reflow = reflow
  43. self.mel_transform = mel_transform
  44. # Freeze vocoder
  45. for param in self.vocoder.parameters():
  46. param.requires_grad = False
  47. # Loss weights
  48. self.weight_reflow = weight_reflow
  49. self.weight_vq = weight_vq
  50. self.weight_mel = weight_mel
  51. # Other parameters
  52. self.spec_min = -12
  53. self.spec_max = 3
  54. self.sampling_rate = sampling_rate
  55. self.reflow_use_shallow = reflow_use_shallow
  56. self.reflow_inference_steps = reflow_inference_steps
  57. self.reflow_inference_start_t = reflow_inference_start_t
  58. # Disable strict loading
  59. self.strict_loading = False
  60. # If encoder is frozen
  61. if freeze_encoder:
  62. for param in self.encoder.parameters():
  63. param.requires_grad = False
  64. for param in self.quantizer.parameters():
  65. param.requires_grad = False
  66. def on_save_checkpoint(self, checkpoint):
  67. # Do not save vocoder
  68. state_dict = checkpoint["state_dict"]
  69. for name in list(state_dict.keys()):
  70. if "vocoder" in name:
  71. state_dict.pop(name)
  72. def configure_optimizers(self):
  73. optimizer = self.optimizer_builder(self.parameters())
  74. lr_scheduler = self.lr_scheduler_builder(optimizer)
  75. return {
  76. "optimizer": optimizer,
  77. "lr_scheduler": {
  78. "scheduler": lr_scheduler,
  79. "interval": "step",
  80. },
  81. }
  82. def norm_spec(self, x):
  83. return (x - self.spec_min) / (self.spec_max - self.spec_min) * 2 - 1
  84. def denorm_spec(self, x):
  85. return (x + 1) / 2 * (self.spec_max - self.spec_min) + self.spec_min
  86. def training_step(self, batch, batch_idx):
  87. audios, audio_lengths = batch["audios"], batch["audio_lengths"]
  88. audios = audios.float()
  89. audios = audios[:, None, :]
  90. with torch.no_grad():
  91. gt_mels = self.mel_transform(audios)
  92. mel_lengths = audio_lengths // self.mel_transform.hop_length
  93. mel_masks = sequence_mask(mel_lengths, gt_mels.shape[2])
  94. mel_masks_float_conv = mel_masks[:, None, :].float()
  95. gt_mels = gt_mels * mel_masks_float_conv
  96. # Encode
  97. encoded_features = self.encoder(gt_mels) * mel_masks_float_conv
  98. # Quantize
  99. vq_result = self.quantizer(encoded_features)
  100. loss_vq = getattr("vq_result", "loss", 0.0)
  101. vq_recon_features = vq_result.z * mel_masks_float_conv
  102. # VQ Decode
  103. gen_mel = self.decoder(vq_recon_features) * mel_masks_float_conv
  104. # Mel Loss
  105. loss_mel = (gen_mel - gt_mels).abs().mean(
  106. dim=1, keepdim=True
  107. ).sum() / mel_masks_float_conv.sum()
  108. # Reflow, given x_1_aux, we want to reconstruct x_1
  109. x_1 = self.norm_spec(gt_mels)
  110. if self.reflow_use_shallow:
  111. # Detach the gradient of the generated mel
  112. x_1_aux = self.norm_spec(gen_mel.detach())
  113. else:
  114. x_1_aux = x_1
  115. t = torch.rand(gt_mels.shape[0], device=gt_mels.device)
  116. x_0 = torch.randn_like(x_1)
  117. # X_t = t * X_1 + (1 - t) * X_0
  118. x_t = x_0 + t[:, None, None] * (x_1_aux - x_0)
  119. v_pred = self.reflow(
  120. x_t,
  121. 1000 * t,
  122. vq_recon_features.detach(), # Stop gradients, avoid reflow to destroy the VQ
  123. )
  124. # Log L2 loss with
  125. weights = 0.398942 / t / (1 - t) * torch.exp(-0.5 * torch.log(t / (1 - t)) ** 2)
  126. loss_reflow = weights[:, None, None] * F.mse_loss(
  127. x_1 - x_0, v_pred, reduction="none"
  128. )
  129. loss_reflow = (loss_reflow * mel_masks_float_conv).mean(
  130. dim=1
  131. ).sum() / mel_masks_float_conv.sum()
  132. # Total loss
  133. loss = (
  134. self.weight_vq * loss_vq
  135. + self.weight_mel * loss_mel
  136. + self.weight_reflow * loss_reflow
  137. )
  138. # Log losses
  139. self.log(
  140. "train/generator/loss",
  141. loss,
  142. on_step=True,
  143. on_epoch=False,
  144. prog_bar=True,
  145. logger=True,
  146. )
  147. self.log(
  148. "train/generator/loss_vq",
  149. loss_vq,
  150. on_step=True,
  151. on_epoch=False,
  152. prog_bar=False,
  153. logger=True,
  154. )
  155. self.log(
  156. "train/generator/loss_mel",
  157. loss_mel,
  158. on_step=True,
  159. on_epoch=False,
  160. prog_bar=False,
  161. logger=True,
  162. )
  163. self.log(
  164. "train/generator/loss_reflow",
  165. loss_reflow,
  166. on_step=True,
  167. on_epoch=False,
  168. prog_bar=False,
  169. logger=True,
  170. )
  171. return loss
  172. def validation_step(self, batch: Any, batch_idx: int):
  173. audios, audio_lengths = batch["audios"], batch["audio_lengths"]
  174. audios = audios.float()
  175. audios = audios[:, None, :]
  176. gt_mels = self.mel_transform(audios)
  177. mel_lengths = audio_lengths // self.mel_transform.hop_length
  178. mel_masks = sequence_mask(mel_lengths, gt_mels.shape[2])
  179. mel_masks_float_conv = mel_masks[:, None, :].float()
  180. gt_mels = gt_mels * mel_masks_float_conv
  181. # Encode
  182. encoded_features = self.encoder(gt_mels) * mel_masks_float_conv
  183. # Quantize
  184. vq_recon_features = self.quantizer(encoded_features).z * mel_masks_float_conv
  185. # VQ Decode
  186. gen_aux_mels = self.decoder(vq_recon_features) * mel_masks_float_conv
  187. loss_mel = (gen_aux_mels - gt_mels).abs().mean(
  188. dim=1, keepdim=True
  189. ).sum() / mel_masks_float_conv.sum()
  190. self.log(
  191. "val/loss_mel",
  192. loss_mel,
  193. on_step=False,
  194. on_epoch=True,
  195. prog_bar=False,
  196. logger=True,
  197. sync_dist=True,
  198. )
  199. # Reflow inference
  200. t_start = self.reflow_inference_start_t if self.reflow_use_shallow else 0.0
  201. x_1 = self.norm_spec(gen_aux_mels)
  202. x_0 = torch.randn_like(x_1)
  203. gen_reflow_mels = (1 - t_start) * x_0 + t_start * x_1
  204. t = torch.zeros(gt_mels.shape[0], device=gt_mels.device)
  205. dt = (1.0 - t_start) / self.reflow_inference_steps
  206. for _ in range(self.reflow_inference_steps):
  207. gen_reflow_mels += (
  208. self.reflow(
  209. gen_reflow_mels,
  210. 1000 * t,
  211. vq_recon_features,
  212. )
  213. * dt
  214. )
  215. t += dt
  216. gen_reflow_mels = self.denorm_spec(gen_reflow_mels) * mel_masks_float_conv
  217. loss_reflow_mel = (gen_reflow_mels - gt_mels).abs().mean(
  218. dim=1, keepdim=True
  219. ).sum() / mel_masks_float_conv.sum()
  220. self.log(
  221. "val/loss_reflow_mel",
  222. loss_reflow_mel,
  223. on_step=False,
  224. on_epoch=True,
  225. prog_bar=False,
  226. logger=True,
  227. sync_dist=True,
  228. )
  229. recon_audios = self.vocoder(gt_mels)
  230. gen_aux_audios = self.vocoder(gen_aux_mels)
  231. gen_reflow_audios = self.vocoder(gen_reflow_mels)
  232. # only log the first batch
  233. if batch_idx != 0:
  234. return
  235. for idx, (
  236. gt_mel,
  237. gen_aux_mel,
  238. gen_reflow_mel,
  239. audio,
  240. gen_aux_audio,
  241. gen_reflow_audio,
  242. recon_audio,
  243. audio_len,
  244. ) in enumerate(
  245. zip(
  246. gt_mels,
  247. gen_aux_mels,
  248. gen_reflow_mels,
  249. audios.float(),
  250. gen_aux_audios.float(),
  251. gen_reflow_audios.float(),
  252. recon_audios.float(),
  253. audio_lengths,
  254. )
  255. ):
  256. mel_len = audio_len // self.mel_transform.hop_length
  257. image_mels = plot_mel(
  258. [
  259. gt_mel[:, :mel_len],
  260. gen_aux_mel[:, :mel_len],
  261. gen_reflow_mel[:, :mel_len],
  262. ],
  263. [
  264. "Ground-Truth",
  265. "Auxiliary",
  266. "Reflow",
  267. ],
  268. )
  269. if isinstance(self.logger, WandbLogger):
  270. self.logger.experiment.log(
  271. {
  272. "reconstruction_mel": wandb.Image(image_mels, caption="mels"),
  273. "wavs": [
  274. wandb.Audio(
  275. audio[0, :audio_len],
  276. sample_rate=self.sampling_rate,
  277. caption="gt",
  278. ),
  279. wandb.Audio(
  280. gen_aux_audio[0, :audio_len],
  281. sample_rate=self.sampling_rate,
  282. caption="aux",
  283. ),
  284. wandb.Audio(
  285. gen_reflow_audio[0, :audio_len],
  286. sample_rate=self.sampling_rate,
  287. caption="reflow",
  288. ),
  289. wandb.Audio(
  290. recon_audio[0, :audio_len],
  291. sample_rate=self.sampling_rate,
  292. caption="recon",
  293. ),
  294. ],
  295. },
  296. )
  297. if isinstance(self.logger, TensorBoardLogger):
  298. self.logger.experiment.add_figure(
  299. f"sample-{idx}/mels",
  300. image_mels,
  301. global_step=self.global_step,
  302. )
  303. self.logger.experiment.add_audio(
  304. f"sample-{idx}/wavs/gt",
  305. audio[0, :audio_len],
  306. self.global_step,
  307. sample_rate=self.sampling_rate,
  308. )
  309. self.logger.experiment.add_audio(
  310. f"sample-{idx}/wavs/gen",
  311. gen_aux_audio[0, :audio_len],
  312. self.global_step,
  313. sample_rate=self.sampling_rate,
  314. )
  315. self.logger.experiment.add_audio(
  316. f"sample-{idx}/wavs/reflow",
  317. gen_reflow_audio[0, :audio_len],
  318. self.global_step,
  319. sample_rate=self.sampling_rate,
  320. )
  321. self.logger.experiment.add_audio(
  322. f"sample-{idx}/wavs/recon",
  323. recon_audio[0, :audio_len],
  324. self.global_step,
  325. sample_rate=self.sampling_rate,
  326. )
  327. plt.close(image_mels)