# lit_module.py
  1. import itertools
  2. from dataclasses import dataclass
  3. from typing import Any, Callable, Literal, Optional
  4. import lightning as L
  5. import torch
  6. import torch.nn.functional as F
  7. import wandb
  8. from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger
  9. from matplotlib import pyplot as plt
  10. from torch import nn
  11. from fish_speech.models.vqgan.modules.wavenet import WaveNet
  12. from fish_speech.models.vqgan.utils import plot_mel, sequence_mask
  13. class VQGAN(L.LightningModule):
  14. def __init__(
  15. self,
  16. optimizer: Callable,
  17. lr_scheduler: Callable,
  18. encoder: WaveNet,
  19. quantizer: nn.Module,
  20. decoder: WaveNet,
  21. reflow: nn.Module,
  22. vocoder: nn.Module,
  23. mel_transform: nn.Module,
  24. weight_reflow: float = 1.0,
  25. weight_vq: float = 1.0,
  26. weight_mel: float = 1.0,
  27. sampling_rate: int = 44100,
  28. freeze_encoder: bool = False,
  29. reflow_use_shallow: bool = False,
  30. reflow_inference_steps: int = 10,
  31. reflow_inference_start_t: float = 0.5,
  32. ):
  33. super().__init__()
  34. # Model parameters
  35. self.optimizer_builder = optimizer
  36. self.lr_scheduler_builder = lr_scheduler
  37. # Modules
  38. self.encoder = encoder
  39. self.quantizer = quantizer
  40. self.decoder = decoder
  41. self.vocoder = vocoder
  42. self.reflow = reflow
  43. self.mel_transform = mel_transform
  44. # Freeze vocoder
  45. for param in self.vocoder.parameters():
  46. param.requires_grad = False
  47. # Loss weights
  48. self.weight_reflow = weight_reflow
  49. self.weight_vq = weight_vq
  50. self.weight_mel = weight_mel
  51. # Other parameters
  52. self.spec_min = -12
  53. self.spec_max = 3
  54. self.sampling_rate = sampling_rate
  55. self.reflow_use_shallow = reflow_use_shallow
  56. self.reflow_inference_steps = reflow_inference_steps
  57. self.reflow_inference_start_t = reflow_inference_start_t
  58. # Disable strict loading
  59. self.strict_loading = False
  60. # If encoder is frozen
  61. if freeze_encoder:
  62. for param in self.encoder.parameters():
  63. param.requires_grad = False
  64. for param in self.quantizer.parameters():
  65. param.requires_grad = False
  66. def on_save_checkpoint(self, checkpoint):
  67. # Do not save vocoder
  68. state_dict = checkpoint["state_dict"]
  69. for name in list(state_dict.keys()):
  70. if "vocoder" in name:
  71. state_dict.pop(name)
  72. def configure_optimizers(self):
  73. optimizer = self.optimizer_builder(self.parameters())
  74. lr_scheduler = self.lr_scheduler_builder(optimizer)
  75. return {
  76. "optimizer": optimizer,
  77. "lr_scheduler": {
  78. "scheduler": lr_scheduler,
  79. "interval": "step",
  80. },
  81. }
  82. def norm_spec(self, x):
  83. return (x - self.spec_min) / (self.spec_max - self.spec_min) * 2 - 1
  84. def denorm_spec(self, x):
  85. return (x + 1) / 2 * (self.spec_max - self.spec_min) + self.spec_min
  86. def training_step(self, batch, batch_idx):
  87. audios, audio_lengths = batch["audios"], batch["audio_lengths"]
  88. audios = audios.float()
  89. audios = audios[:, None, :]
  90. with torch.no_grad():
  91. gt_mels = self.mel_transform(audios)
  92. mel_lengths = audio_lengths // self.mel_transform.hop_length
  93. mel_masks = sequence_mask(mel_lengths, gt_mels.shape[2])
  94. mel_masks_float_conv = mel_masks[:, None, :].float()
  95. gt_mels = gt_mels * mel_masks_float_conv
  96. # Encode
  97. encoded_features = self.encoder(gt_mels) * mel_masks_float_conv
  98. # Quantize
  99. vq_result = self.quantizer(encoded_features)
  100. loss_vq = getattr("vq_result", "loss", 0.0)
  101. vq_recon_features = vq_result.z * mel_masks_float_conv
  102. # VQ Decode
  103. gen_mel = self.decoder(vq_recon_features) * mel_masks_float_conv
  104. # Mel Loss
  105. loss_mel = (gen_mel - gt_mels).abs().mean(
  106. dim=1, keepdim=True
  107. ).sum() / mel_masks_float_conv.sum()
  108. # Reflow, given x_1_aux, we want to reconstruct x_1
  109. x_1 = self.norm_spec(gt_mels)
  110. if self.reflow_use_shallow:
  111. x_1_aux = self.norm_spec(gen_mel)
  112. else:
  113. x_1_aux = x_1
  114. t = torch.rand(gt_mels.shape[0], device=gt_mels.device)
  115. x_0 = torch.randn_like(x_1)
  116. # X_t = t * X_1 + (1 - t) * X_0
  117. x_t = x_0 + t[:, None, None] * (x_1_aux - x_0)
  118. v_pred = self.reflow(
  119. x_t,
  120. 1000 * t,
  121. vq_recon_features.detach(), # Stop gradients, avoid reflow to destroy the VQ
  122. )
  123. # Log L2 loss with
  124. weights = 0.398942 / t / (1 - t) * torch.exp(-0.5 * torch.log(t / (1 - t)) ** 2)
  125. loss_reflow = weights[:, None, None] * F.mse_loss(
  126. x_1 - x_0, v_pred, reduction="none"
  127. )
  128. loss_reflow = (loss_reflow * mel_masks_float_conv).mean(
  129. dim=1
  130. ).sum() / mel_masks_float_conv.sum()
  131. # Total loss
  132. loss = (
  133. self.weight_vq * loss_vq
  134. + self.weight_mel * loss_mel
  135. + self.weight_reflow * loss_reflow
  136. )
  137. # Log losses
  138. self.log(
  139. "train/generator/loss",
  140. loss,
  141. on_step=True,
  142. on_epoch=False,
  143. prog_bar=True,
  144. logger=True,
  145. )
  146. self.log(
  147. "train/generator/loss_vq",
  148. loss_vq,
  149. on_step=True,
  150. on_epoch=False,
  151. prog_bar=False,
  152. logger=True,
  153. )
  154. self.log(
  155. "train/generator/loss_mel",
  156. loss_mel,
  157. on_step=True,
  158. on_epoch=False,
  159. prog_bar=False,
  160. logger=True,
  161. )
  162. self.log(
  163. "train/generator/loss_reflow",
  164. loss_reflow,
  165. on_step=True,
  166. on_epoch=False,
  167. prog_bar=False,
  168. logger=True,
  169. )
  170. return loss
  171. def validation_step(self, batch: Any, batch_idx: int):
  172. audios, audio_lengths = batch["audios"], batch["audio_lengths"]
  173. audios = audios.float()
  174. audios = audios[:, None, :]
  175. gt_mels = self.mel_transform(audios)
  176. mel_lengths = audio_lengths // self.mel_transform.hop_length
  177. mel_masks = sequence_mask(mel_lengths, gt_mels.shape[2])
  178. mel_masks_float_conv = mel_masks[:, None, :].float()
  179. gt_mels = gt_mels * mel_masks_float_conv
  180. # Encode
  181. encoded_features = self.encoder(gt_mels) * mel_masks_float_conv
  182. # Quantize
  183. vq_recon_features = self.quantizer(encoded_features).z * mel_masks_float_conv
  184. # VQ Decode
  185. gen_aux_mels = self.decoder(vq_recon_features) * mel_masks_float_conv
  186. loss_mel = (gen_aux_mels - gt_mels).abs().mean(
  187. dim=1, keepdim=True
  188. ).sum() / mel_masks_float_conv.sum()
  189. self.log(
  190. "val/loss_mel",
  191. loss_mel,
  192. on_step=False,
  193. on_epoch=True,
  194. prog_bar=False,
  195. logger=True,
  196. sync_dist=True,
  197. )
  198. # Reflow inference
  199. t_start = self.reflow_inference_start_t if self.reflow_use_shallow else 0.0
  200. x_1 = self.norm_spec(gen_aux_mels)
  201. x_0 = torch.randn_like(x_1)
  202. gen_reflow_mels = (1 - t_start) * x_0 + t_start * x_1
  203. t = torch.zeros(gt_mels.shape[0], device=gt_mels.device)
  204. dt = (1.0 - t_start) / self.reflow_inference_steps
  205. for _ in range(self.reflow_inference_steps):
  206. gen_reflow_mels += (
  207. self.reflow(
  208. gen_reflow_mels,
  209. 1000 * t,
  210. vq_recon_features,
  211. )
  212. * dt
  213. )
  214. t += dt
  215. gen_reflow_mels = self.denorm_spec(gen_reflow_mels) * mel_masks_float_conv
  216. loss_reflow_mel = (gen_reflow_mels - gt_mels).abs().mean(
  217. dim=1, keepdim=True
  218. ).sum() / mel_masks_float_conv.sum()
  219. self.log(
  220. "val/loss_reflow_mel",
  221. loss_reflow_mel,
  222. on_step=False,
  223. on_epoch=True,
  224. prog_bar=False,
  225. logger=True,
  226. sync_dist=True,
  227. )
  228. recon_audios = self.vocoder(gt_mels)
  229. gen_aux_audios = self.vocoder(gen_aux_mels)
  230. gen_reflow_audios = self.vocoder(gen_reflow_mels)
  231. # only log the first batch
  232. if batch_idx != 0:
  233. return
  234. for idx, (
  235. gt_mel,
  236. gen_aux_mel,
  237. gen_reflow_mel,
  238. audio,
  239. gen_aux_audio,
  240. gen_reflow_audio,
  241. recon_audio,
  242. audio_len,
  243. ) in enumerate(
  244. zip(
  245. gt_mels,
  246. gen_aux_mels,
  247. gen_reflow_mels,
  248. audios.float(),
  249. gen_aux_audios.float(),
  250. gen_reflow_audios.float(),
  251. recon_audios.float(),
  252. audio_lengths,
  253. )
  254. ):
  255. mel_len = audio_len // self.mel_transform.hop_length
  256. image_mels = plot_mel(
  257. [
  258. gt_mel[:, :mel_len],
  259. gen_aux_mel[:, :mel_len],
  260. gen_reflow_mel[:, :mel_len],
  261. ],
  262. [
  263. "Ground-Truth",
  264. "Auxiliary",
  265. "Reflow",
  266. ],
  267. )
  268. if isinstance(self.logger, WandbLogger):
  269. self.logger.experiment.log(
  270. {
  271. "reconstruction_mel": wandb.Image(image_mels, caption="mels"),
  272. "wavs": [
  273. wandb.Audio(
  274. audio[0, :audio_len],
  275. sample_rate=self.sampling_rate,
  276. caption="gt",
  277. ),
  278. wandb.Audio(
  279. gen_aux_audio[0, :audio_len],
  280. sample_rate=self.sampling_rate,
  281. caption="aux",
  282. ),
  283. wandb.Audio(
  284. gen_reflow_audio[0, :audio_len],
  285. sample_rate=self.sampling_rate,
  286. caption="reflow",
  287. ),
  288. wandb.Audio(
  289. recon_audio[0, :audio_len],
  290. sample_rate=self.sampling_rate,
  291. caption="recon",
  292. ),
  293. ],
  294. },
  295. )
  296. if isinstance(self.logger, TensorBoardLogger):
  297. self.logger.experiment.add_figure(
  298. f"sample-{idx}/mels",
  299. image_mels,
  300. global_step=self.global_step,
  301. )
  302. self.logger.experiment.add_audio(
  303. f"sample-{idx}/wavs/gt",
  304. audio[0, :audio_len],
  305. self.global_step,
  306. sample_rate=self.sampling_rate,
  307. )
  308. self.logger.experiment.add_audio(
  309. f"sample-{idx}/wavs/gen",
  310. gen_aux_audio[0, :audio_len],
  311. self.global_step,
  312. sample_rate=self.sampling_rate,
  313. )
  314. self.logger.experiment.add_audio(
  315. f"sample-{idx}/wavs/reflow",
  316. gen_reflow_audio[0, :audio_len],
  317. self.global_step,
  318. sample_rate=self.sampling_rate,
  319. )
  320. self.logger.experiment.add_audio(
  321. f"sample-{idx}/wavs/recon",
  322. recon_audio[0, :audio_len],
  323. self.global_step,
  324. sample_rate=self.sampling_rate,
  325. )
  326. plt.close(image_mels)