lit_module.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365
  1. import itertools
  2. from dataclasses import dataclass
  3. from typing import Any, Callable, Literal, Optional
  4. import lightning as L
  5. import torch
  6. import torch.nn.functional as F
  7. import wandb
  8. from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger
  9. from matplotlib import pyplot as plt
  10. from torch import nn
  11. from fish_speech.models.vqgan.utils import plot_mel, sequence_mask, slice_segments
@dataclass
class VQEncodeResult:
    # Result bundle returned by the encode path (encoder + quantizer).
    # NOTE(review): field semantics inferred from names — confirm against
    # the quantizer implementation.
    features: torch.Tensor  # presumably the (quantized) latent feature frames
    indices: torch.Tensor  # presumably codebook indices per frame
    loss: torch.Tensor  # presumably the VQ commitment/codebook loss
    feature_lengths: torch.Tensor  # valid frame count per batch item
@dataclass
class VQDecodeResult:
    # Result bundle for the decode path.
    mels: torch.Tensor  # reconstructed mel spectrograms
    audios: Optional[torch.Tensor] = None  # vocoded waveforms, when rendered
  22. class VQGAN(L.LightningModule):
  23. def __init__(
  24. self,
  25. optimizer: Callable,
  26. lr_scheduler: Callable,
  27. encoder: nn.Module,
  28. quantizer: nn.Module,
  29. aux_decoder: nn.Module,
  30. reflow: nn.Module,
  31. vocoder: nn.Module,
  32. mel_transform: nn.Module,
  33. weight_reflow: float = 1.0,
  34. weight_vq: float = 1.0,
  35. weight_aux_mel: float = 1.0,
  36. sampling_rate: int = 44100,
  37. ):
  38. super().__init__()
  39. # Model parameters
  40. self.optimizer_builder = optimizer
  41. self.lr_scheduler_builder = lr_scheduler
  42. # Modules
  43. self.encoder = encoder
  44. self.quantizer = quantizer
  45. self.aux_decoder = aux_decoder
  46. self.reflow = reflow
  47. self.mel_transform = mel_transform
  48. self.vocoder = vocoder
  49. # Freeze vocoder
  50. for param in self.vocoder.parameters():
  51. param.requires_grad = False
  52. # Loss weights
  53. self.weight_reflow = weight_reflow
  54. self.weight_vq = weight_vq
  55. self.weight_aux_mel = weight_aux_mel
  56. self.spec_min = -12
  57. self.spec_max = 3
  58. self.sampling_rate = sampling_rate
  59. self.strict_loading = False
  60. def on_save_checkpoint(self, checkpoint):
  61. # Do not save vocoder
  62. state_dict = checkpoint["state_dict"]
  63. for name in list(state_dict.keys()):
  64. if "vocoder" in name:
  65. state_dict.pop(name)
  66. def configure_optimizers(self):
  67. # Need two optimizers and two schedulers
  68. optimizer = self.optimizer_builder(self.parameters())
  69. lr_scheduler = self.lr_scheduler_builder(optimizer)
  70. return {
  71. "optimizer": optimizer,
  72. "lr_scheduler": {
  73. "scheduler": lr_scheduler,
  74. "interval": "step",
  75. },
  76. }
  77. def norm_spec(self, x):
  78. return (x - self.spec_min) / (self.spec_max - self.spec_min) * 2 - 1
  79. def denorm_spec(self, x):
  80. return (x + 1) / 2 * (self.spec_max - self.spec_min) + self.spec_min
  81. # @torch.autocast(device_type="cuda", dtype=torch.bfloat16)
  82. def training_step(self, batch, batch_idx):
  83. audios, audio_lengths = batch["audios"], batch["audio_lengths"]
  84. audios = audios.float()
  85. audios = audios[:, None, :]
  86. with torch.no_grad():
  87. gt_mels = self.mel_transform(audios)
  88. mel_lengths = audio_lengths // self.mel_transform.hop_length
  89. mel_masks = sequence_mask(mel_lengths, gt_mels.shape[2])
  90. mel_masks_float_conv = mel_masks[:, None, :].float()
  91. # Encode
  92. encoded_features = self.encoder(gt_mels) * mel_masks_float_conv
  93. # Quantize
  94. vq_result = self.quantizer(encoded_features)
  95. loss_vq = getattr("vq_result", "loss", 0.0)
  96. vq_recon_features = vq_result.z * mel_masks_float_conv
  97. # VQ Decode
  98. aux_mel = self.aux_decoder(vq_recon_features)
  99. loss_aux_mel = F.l1_loss(
  100. aux_mel * mel_masks_float_conv, gt_mels * mel_masks_float_conv
  101. )
  102. # Reflow
  103. x_1 = self.norm_spec(gt_mels)
  104. t = torch.rand(gt_mels.shape[0], device=gt_mels.device)
  105. x_0 = torch.randn_like(x_1)
  106. # X_t = t * X_1 + (1 - t) * X_0
  107. x_t = x_0 + t[:, None, None] * (x_1 - x_0)
  108. v_pred = self.reflow(
  109. x_t,
  110. 1000 * t,
  111. vq_recon_features, # .detach()
  112. x_masks=mel_masks_float_conv,
  113. cond_masks=mel_masks_float_conv,
  114. )
  115. # Log L2 loss with
  116. weights = 0.398942 / t / (1 - t) * torch.exp(-0.5 * torch.log(t / (1 - t)) ** 2)
  117. loss_reflow = weights[:, None, None] * F.mse_loss(
  118. x_1 - x_0, v_pred, reduction="none"
  119. )
  120. loss_reflow = (loss_reflow * mel_masks_float_conv).mean()
  121. # Total loss
  122. loss = (
  123. self.weight_vq * loss_vq
  124. + self.weight_aux_mel * loss_aux_mel
  125. + self.weight_reflow * loss_reflow
  126. )
  127. # Log losses
  128. self.log(
  129. "train/loss", loss, on_step=True, on_epoch=False, prog_bar=True, logger=True
  130. )
  131. self.log(
  132. "train/loss_vq",
  133. loss_vq,
  134. on_step=True,
  135. on_epoch=False,
  136. prog_bar=False,
  137. logger=True,
  138. )
  139. self.log(
  140. "train/loss_aux_mel",
  141. loss_aux_mel,
  142. on_step=True,
  143. on_epoch=False,
  144. prog_bar=False,
  145. logger=True,
  146. )
  147. self.log(
  148. "train/loss_reflow",
  149. loss_reflow,
  150. on_step=True,
  151. on_epoch=False,
  152. prog_bar=False,
  153. logger=True,
  154. )
  155. return loss
  156. def validation_step(self, batch: Any, batch_idx: int):
  157. audios, audio_lengths = batch["audios"], batch["audio_lengths"]
  158. audios = audios.float()
  159. audios = audios[:, None, :]
  160. gt_mels = self.mel_transform(audios)
  161. mel_lengths = audio_lengths // self.mel_transform.hop_length
  162. mel_masks = sequence_mask(mel_lengths, gt_mels.shape[2])
  163. mel_masks_float_conv = mel_masks[:, None, :].float()
  164. # Encode
  165. encoded_features = self.encoder(gt_mels) * mel_masks_float_conv
  166. # Quantize
  167. vq_result = self.quantizer(encoded_features)
  168. # VQ Decode
  169. aux_mels = self.aux_decoder(vq_result.z)
  170. loss_aux_mel = F.l1_loss(
  171. aux_mels * mel_masks_float_conv, gt_mels * mel_masks_float_conv
  172. )
  173. self.log(
  174. "val/loss_aux_mel",
  175. loss_aux_mel,
  176. on_step=False,
  177. on_epoch=True,
  178. prog_bar=False,
  179. logger=True,
  180. sync_dist=True,
  181. )
  182. # Reflow inference
  183. t_start = 0.0
  184. infer_step = 10
  185. x_1 = self.norm_spec(aux_mels)
  186. x_0 = torch.randn_like(x_1)
  187. gen_mels = (1 - t_start) * x_0 + t_start * x_1
  188. t = torch.zeros(gt_mels.shape[0], device=gt_mels.device)
  189. dt = (1.0 - t_start) / infer_step
  190. for _ in range(infer_step):
  191. gen_mels += (
  192. self.reflow(
  193. gen_mels,
  194. 1000 * t,
  195. vq_result.z,
  196. x_masks=mel_masks_float_conv,
  197. cond_masks=mel_masks_float_conv,
  198. )
  199. * dt
  200. )
  201. t += dt
  202. gen_mels = self.denorm_spec(gen_mels)
  203. loss_recon_reflow = F.l1_loss(
  204. gen_mels * mel_masks_float_conv, gt_mels * mel_masks_float_conv
  205. )
  206. self.log(
  207. "val/loss_recon_reflow",
  208. loss_recon_reflow,
  209. on_step=False,
  210. on_epoch=True,
  211. prog_bar=False,
  212. logger=True,
  213. sync_dist=True,
  214. )
  215. gen_audios = self.vocoder(gen_mels)
  216. recon_audios = self.vocoder(gt_mels)
  217. aux_audios = self.vocoder(aux_mels)
  218. # only log the first batch
  219. if batch_idx != 0:
  220. return
  221. for idx, (
  222. gt_mel,
  223. reflow_mel,
  224. aux_mel,
  225. audio,
  226. reflow_audio,
  227. aux_audio,
  228. recon_audio,
  229. audio_len,
  230. ) in enumerate(
  231. zip(
  232. gt_mels,
  233. gen_mels,
  234. aux_mels,
  235. audios.float(),
  236. gen_audios.float(),
  237. aux_audios.float(),
  238. recon_audios.float(),
  239. audio_lengths,
  240. )
  241. ):
  242. mel_len = audio_len // self.mel_transform.hop_length
  243. image_mels = plot_mel(
  244. [
  245. gt_mel[:, :mel_len],
  246. reflow_mel[:, :mel_len],
  247. aux_mel[:, :mel_len],
  248. ],
  249. [
  250. "Ground-Truth",
  251. "Reflow",
  252. "Aux",
  253. ],
  254. )
  255. if isinstance(self.logger, WandbLogger):
  256. self.logger.experiment.log(
  257. {
  258. "reconstruction_mel": wandb.Image(image_mels, caption="mels"),
  259. "wavs": [
  260. wandb.Audio(
  261. audio[0, :audio_len],
  262. sample_rate=self.sampling_rate,
  263. caption="gt",
  264. ),
  265. wandb.Audio(
  266. reflow_audio[0, :audio_len],
  267. sample_rate=self.sampling_rate,
  268. caption="reflow",
  269. ),
  270. wandb.Audio(
  271. aux_audio[0, :audio_len],
  272. sample_rate=self.sampling_rate,
  273. caption="aux",
  274. ),
  275. wandb.Audio(
  276. recon_audio[0, :audio_len],
  277. sample_rate=self.sampling_rate,
  278. caption="recon",
  279. ),
  280. ],
  281. },
  282. )
  283. if isinstance(self.logger, TensorBoardLogger):
  284. self.logger.experiment.add_figure(
  285. f"sample-{idx}/mels",
  286. image_mels,
  287. global_step=self.global_step,
  288. )
  289. self.logger.experiment.add_audio(
  290. f"sample-{idx}/wavs/gt",
  291. audio[0, :audio_len],
  292. self.global_step,
  293. sample_rate=self.sampling_rate,
  294. )
  295. self.logger.experiment.add_audio(
  296. f"sample-{idx}/wavs/reflow",
  297. reflow_audio[0, :audio_len],
  298. self.global_step,
  299. sample_rate=self.sampling_rate,
  300. )
  301. self.logger.experiment.add_audio(
  302. f"sample-{idx}/wavs/aux",
  303. aux_audio[0, :audio_len],
  304. self.global_step,
  305. sample_rate=self.sampling_rate,
  306. )
  307. self.logger.experiment.add_audio(
  308. f"sample-{idx}/wavs/recon",
  309. recon_audio[0, :audio_len],
  310. self.global_step,
  311. sample_rate=self.sampling_rate,
  312. )
  313. plt.close(image_mels)