lit_module.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442
  1. import itertools
  2. import math
  3. from typing import Any, Callable
  4. import lightning as L
  5. import torch
  6. import torch.nn.functional as F
  7. import wandb
  8. from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger
  9. from matplotlib import pyplot as plt
  10. from torch import nn
  11. from fish_speech.models.vqgan.modules.discriminator import Discriminator
  12. from fish_speech.models.vqgan.modules.wavenet import WaveNet
  13. from fish_speech.models.vqgan.utils import avg_with_mask, plot_mel, sequence_mask
  14. class VQGAN(L.LightningModule):
  15. def __init__(
  16. self,
  17. optimizer: Callable,
  18. lr_scheduler: Callable,
  19. encoder: WaveNet,
  20. quantizer: nn.Module,
  21. decoder: WaveNet,
  22. discriminator: Discriminator,
  23. vocoder: nn.Module,
  24. encode_mel_transform: nn.Module,
  25. gt_mel_transform: nn.Module,
  26. weight_adv: float = 1.0,
  27. weight_vq: float = 1.0,
  28. weight_mel: float = 1.0,
  29. sampling_rate: int = 44100,
  30. freeze_encoder: bool = False,
  31. ):
  32. super().__init__()
  33. # Model parameters
  34. self.optimizer_builder = optimizer
  35. self.lr_scheduler_builder = lr_scheduler
  36. # Modules
  37. self.encoder = encoder
  38. self.quantizer = quantizer
  39. self.decoder = decoder
  40. self.vocoder = vocoder
  41. self.discriminator = discriminator
  42. self.encode_mel_transform = encode_mel_transform
  43. self.gt_mel_transform = gt_mel_transform
  44. # A simple linear layer to project quality to condition channels
  45. self.quality_projection = nn.Linear(1, 768)
  46. # Freeze vocoder
  47. for param in self.vocoder.parameters():
  48. param.requires_grad = False
  49. # Loss weights
  50. self.weight_adv = weight_adv
  51. self.weight_vq = weight_vq
  52. self.weight_mel = weight_mel
  53. # Other parameters
  54. self.sampling_rate = sampling_rate
  55. # Disable strict loading
  56. self.strict_loading = False
  57. # If encoder is frozen
  58. if freeze_encoder:
  59. for param in self.encoder.parameters():
  60. param.requires_grad = False
  61. for param in self.quantizer.parameters():
  62. param.requires_grad = False
  63. self.automatic_optimization = False
  64. def on_save_checkpoint(self, checkpoint):
  65. # Do not save vocoder
  66. state_dict = checkpoint["state_dict"]
  67. for name in list(state_dict.keys()):
  68. if "vocoder" in name:
  69. state_dict.pop(name)
  70. def configure_optimizers(self):
  71. optimizer_generator = self.optimizer_builder(
  72. itertools.chain(
  73. self.encoder.parameters(),
  74. self.quantizer.parameters(),
  75. self.decoder.parameters(),
  76. self.quality_projection.parameters(),
  77. )
  78. )
  79. optimizer_discriminator = self.optimizer_builder(
  80. self.discriminator.parameters()
  81. )
  82. lr_scheduler_generator = self.lr_scheduler_builder(optimizer_generator)
  83. lr_scheduler_discriminator = self.lr_scheduler_builder(optimizer_discriminator)
  84. return (
  85. {
  86. "optimizer": optimizer_generator,
  87. "lr_scheduler": {
  88. "scheduler": lr_scheduler_generator,
  89. "interval": "step",
  90. "name": "optimizer/generator",
  91. },
  92. },
  93. {
  94. "optimizer": optimizer_discriminator,
  95. "lr_scheduler": {
  96. "scheduler": lr_scheduler_discriminator,
  97. "interval": "step",
  98. "name": "optimizer/discriminator",
  99. },
  100. },
  101. )
  102. def training_step(self, batch, batch_idx):
  103. optim_g, optim_d = self.optimizers()
  104. audios, audio_lengths = batch["audios"], batch["audio_lengths"]
  105. audios = audios.float()
  106. audios = audios[:, None, :]
  107. with torch.no_grad():
  108. encoded_mels = self.encode_mel_transform(audios)
  109. gt_mels = self.gt_mel_transform(audios)
  110. quality = ((gt_mels.mean(-1) > -8).sum(-1) - 90) / 10
  111. quality = quality.unsqueeze(-1)
  112. mel_lengths = audio_lengths // self.gt_mel_transform.hop_length
  113. mel_masks = sequence_mask(mel_lengths, gt_mels.shape[2])
  114. mel_masks_float_conv = mel_masks[:, None, :].float()
  115. gt_mels = gt_mels * mel_masks_float_conv
  116. encoded_mels = encoded_mels * mel_masks_float_conv
  117. # Encode
  118. encoded_features = self.encoder(encoded_mels) * mel_masks_float_conv
  119. # Quantize
  120. vq_result = self.quantizer(encoded_features)
  121. loss_vq = getattr("vq_result", "loss", 0.0)
  122. vq_recon_features = vq_result.z * mel_masks_float_conv
  123. vq_recon_features = (
  124. vq_recon_features + self.quality_projection(quality)[:, :, None]
  125. )
  126. # VQ Decode
  127. gen_mel = (
  128. self.decoder(
  129. torch.randn_like(vq_recon_features) * mel_masks_float_conv,
  130. condition=vq_recon_features,
  131. )
  132. * mel_masks_float_conv
  133. )
  134. # Discriminator
  135. real_logits = self.discriminator(gt_mels)
  136. fake_logits = self.discriminator(gen_mel.detach())
  137. d_mask = F.interpolate(
  138. mel_masks_float_conv, size=(real_logits.shape[2],), mode="nearest"
  139. )
  140. loss_real = avg_with_mask((real_logits - 1) ** 2, d_mask)
  141. loss_fake = avg_with_mask(fake_logits**2, d_mask)
  142. loss_d = loss_real + loss_fake
  143. self.log(
  144. "train/discriminator/loss",
  145. loss_d,
  146. on_step=True,
  147. on_epoch=False,
  148. prog_bar=True,
  149. logger=True,
  150. )
  151. # Discriminator backward
  152. optim_d.zero_grad()
  153. self.manual_backward(loss_d)
  154. self.clip_gradients(
  155. optim_d, gradient_clip_val=1000.0, gradient_clip_algorithm="norm"
  156. )
  157. optim_d.step()
  158. # Mel Loss, applying l1, using a weighted sum
  159. mel_distance = (
  160. gen_mel - gt_mels
  161. ).abs() # * 0.5 + self.ssim(gen_mel, gt_mels) * 0.5
  162. loss_mel_low_freq = avg_with_mask(mel_distance[:, :40, :], mel_masks_float_conv)
  163. loss_mel_mid_freq = avg_with_mask(
  164. mel_distance[:, 40:70, :], mel_masks_float_conv
  165. )
  166. loss_mel_high_freq = avg_with_mask(
  167. mel_distance[:, 70:, :], mel_masks_float_conv
  168. )
  169. loss_mel = (
  170. loss_mel_low_freq * 0.6 + loss_mel_mid_freq * 0.3 + loss_mel_high_freq * 0.1
  171. )
  172. # Adversarial Loss
  173. fake_logits = self.discriminator(gen_mel)
  174. loss_adv = avg_with_mask((fake_logits - 1) ** 2, d_mask)
  175. # Total loss
  176. loss = (
  177. self.weight_vq * loss_vq
  178. + self.weight_mel * loss_mel
  179. + self.weight_adv * loss_adv
  180. )
  181. # Log losses
  182. self.log(
  183. "train/generator/loss",
  184. loss,
  185. on_step=True,
  186. on_epoch=False,
  187. prog_bar=True,
  188. logger=True,
  189. )
  190. self.log(
  191. "train/generator/loss_vq",
  192. loss_vq,
  193. on_step=True,
  194. on_epoch=False,
  195. prog_bar=False,
  196. logger=True,
  197. )
  198. self.log(
  199. "train/generator/loss_mel",
  200. loss_mel,
  201. on_step=True,
  202. on_epoch=False,
  203. prog_bar=False,
  204. logger=True,
  205. )
  206. self.log(
  207. "train/generator/loss_adv",
  208. loss_adv,
  209. on_step=True,
  210. on_epoch=False,
  211. prog_bar=False,
  212. logger=True,
  213. )
  214. # Generator backward
  215. optim_g.zero_grad()
  216. self.manual_backward(loss)
  217. self.clip_gradients(
  218. optim_g, gradient_clip_val=1000.0, gradient_clip_algorithm="norm"
  219. )
  220. optim_g.step()
  221. scheduler_g, scheduler_d = self.lr_schedulers()
  222. scheduler_g.step()
  223. scheduler_d.step()
  224. def validation_step(self, batch: Any, batch_idx: int):
  225. audios, audio_lengths = batch["audios"], batch["audio_lengths"]
  226. audios = audios.float()
  227. audios = audios[:, None, :]
  228. encoded_mels = self.encode_mel_transform(audios)
  229. gt_mels = self.gt_mel_transform(audios)
  230. mel_lengths = audio_lengths // self.gt_mel_transform.hop_length
  231. mel_masks = sequence_mask(mel_lengths, gt_mels.shape[2])
  232. mel_masks_float_conv = mel_masks[:, None, :].float()
  233. gt_mels = gt_mels * mel_masks_float_conv
  234. encoded_mels = encoded_mels * mel_masks_float_conv
  235. # Encode
  236. encoded_features = self.encoder(encoded_mels) * mel_masks_float_conv
  237. # Quantize
  238. vq_recon_features = self.quantizer(encoded_features).z * mel_masks_float_conv
  239. vq_recon_features = (
  240. vq_recon_features
  241. + self.quality_projection(
  242. torch.ones(
  243. vq_recon_features.shape[0], 1, device=vq_recon_features.device
  244. )
  245. * 2
  246. )[:, :, None]
  247. )
  248. # VQ Decode
  249. gen_aux_mels = (
  250. self.decoder(
  251. torch.randn_like(vq_recon_features) * mel_masks_float_conv,
  252. condition=vq_recon_features,
  253. )
  254. * mel_masks_float_conv
  255. )
  256. loss_mel = avg_with_mask((gen_aux_mels - gt_mels).abs(), mel_masks_float_conv)
  257. self.log(
  258. "val/loss_mel",
  259. loss_mel,
  260. on_step=False,
  261. on_epoch=True,
  262. prog_bar=False,
  263. logger=True,
  264. sync_dist=True,
  265. )
  266. recon_audios = self.vocoder(gt_mels)
  267. gen_aux_audios = self.vocoder(gen_aux_mels)
  268. # only log the first batch
  269. if batch_idx != 0:
  270. return
  271. for idx, (
  272. gt_mel,
  273. gen_aux_mel,
  274. audio,
  275. gen_aux_audio,
  276. recon_audio,
  277. audio_len,
  278. ) in enumerate(
  279. zip(
  280. gt_mels,
  281. gen_aux_mels,
  282. audios.cpu().float(),
  283. gen_aux_audios.cpu().float(),
  284. recon_audios.cpu().float(),
  285. audio_lengths,
  286. )
  287. ):
  288. if idx > 4:
  289. break
  290. mel_len = audio_len // self.gt_mel_transform.hop_length
  291. image_mels = plot_mel(
  292. [
  293. gt_mel[:, :mel_len],
  294. gen_aux_mel[:, :mel_len],
  295. ],
  296. [
  297. "Ground-Truth",
  298. "Auxiliary",
  299. ],
  300. )
  301. if isinstance(self.logger, WandbLogger):
  302. self.logger.experiment.log(
  303. {
  304. "reconstruction_mel": wandb.Image(image_mels, caption="mels"),
  305. "wavs": [
  306. wandb.Audio(
  307. audio[0, :audio_len],
  308. sample_rate=self.sampling_rate,
  309. caption="gt",
  310. ),
  311. wandb.Audio(
  312. gen_aux_audio[0, :audio_len],
  313. sample_rate=self.sampling_rate,
  314. caption="aux",
  315. ),
  316. wandb.Audio(
  317. recon_audio[0, :audio_len],
  318. sample_rate=self.sampling_rate,
  319. caption="recon",
  320. ),
  321. ],
  322. },
  323. )
  324. if isinstance(self.logger, TensorBoardLogger):
  325. self.logger.experiment.add_figure(
  326. f"sample-{idx}/mels",
  327. image_mels,
  328. global_step=self.global_step,
  329. )
  330. self.logger.experiment.add_audio(
  331. f"sample-{idx}/wavs/gt",
  332. audio[0, :audio_len],
  333. self.global_step,
  334. sample_rate=self.sampling_rate,
  335. )
  336. self.logger.experiment.add_audio(
  337. f"sample-{idx}/wavs/gen",
  338. gen_aux_audio[0, :audio_len],
  339. self.global_step,
  340. sample_rate=self.sampling_rate,
  341. )
  342. self.logger.experiment.add_audio(
  343. f"sample-{idx}/wavs/recon",
  344. recon_audio[0, :audio_len],
  345. self.global_step,
  346. sample_rate=self.sampling_rate,
  347. )
  348. plt.close(image_mels)
  349. def encode(self, audios, audio_lengths):
  350. audios = audios.float()
  351. mels = self.encode_mel_transform(audios)
  352. mel_lengths = audio_lengths // self.encode_mel_transform.hop_length
  353. mel_masks = sequence_mask(mel_lengths, mels.shape[2])
  354. mel_masks_float_conv = mel_masks[:, None, :].float()
  355. mels = mels * mel_masks_float_conv
  356. # Encode
  357. encoded_features = self.encoder(mels) * mel_masks_float_conv
  358. feature_lengths = mel_lengths // math.prod(self.quantizer.downsample_factor)
  359. return self.quantizer.encode(encoded_features), feature_lengths
  360. def decode(self, indices, feature_lengths, return_audios=False):
  361. factor = math.prod(self.quantizer.downsample_factor)
  362. mel_masks = sequence_mask(feature_lengths * factor, indices.shape[2] * factor)
  363. mel_masks_float_conv = mel_masks[:, None, :].float()
  364. z = self.quantizer.decode(indices) * mel_masks_float_conv
  365. z = (
  366. z
  367. + self.quality_projection(torch.ones(z.shape[0], 1, device=z.device) * 2)[
  368. :, :, None
  369. ]
  370. )
  371. gen_mel = (
  372. self.decoder(
  373. torch.randn_like(z) * mel_masks_float_conv,
  374. condition=z,
  375. )
  376. * mel_masks_float_conv
  377. )
  378. if return_audios:
  379. return self.vocoder(gen_mel)
  380. return gen_mel