# lit_module.py

import itertools
from dataclasses import dataclass
from typing import Any, Callable, Optional

import lightning as L
import torch
import torch.nn.functional as F
import wandb
from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger
from matplotlib import pyplot as plt
from torch import nn
from torch.utils.checkpoint import checkpoint as gradient_checkpoint

from fish_speech.models.vqgan.losses import (
    discriminator_loss,
    feature_loss,
    generator_loss,
)
from fish_speech.models.vqgan.modules.convnext import ConvNeXt
from fish_speech.models.vqgan.modules.encoders import VQEncoder
from fish_speech.models.vqgan.utils import plot_mel, sequence_mask


@dataclass
class VQEncodeResult:
    features: torch.Tensor
    indices: torch.Tensor
    loss: torch.Tensor
    feature_lengths: torch.Tensor


@dataclass
class VQDecodeResult:
    mels: torch.Tensor
    audios: Optional[torch.Tensor] = None


class VQGAN(L.LightningModule):
    def __init__(
        self,
        optimizer: Callable,
        lr_scheduler: Callable,
        encoder: ConvNeXt,
        vq: VQEncoder,
        decoder: ConvNeXt,
        generator: nn.Module,
        discriminator: ConvNeXt,
        mel_transform: nn.Module,
        hop_length: int = 640,
        sample_rate: int = 32000,
        freeze_discriminator: bool = False,
    ):
        super().__init__()

        # pretrain: the VQ branch uses the ground-truth mel as its target,
        # while the HiFi-GAN vocoder takes the ground-truth mel as input
        # finetune: end-to-end training; the ground-truth mel remains the
        # HiFi-GAN target, but the VQ branch is frozen

        # Optimizer and scheduler builders (called in configure_optimizers)
        self.optimizer_builder = optimizer
        self.lr_scheduler_builder = lr_scheduler

        # Generator and discriminator
        self.encoder = encoder
        self.vq = vq
        self.decoder = decoder
        self.generator = generator
        self.discriminator = discriminator
        self.mel_transform = mel_transform
        self.freeze_discriminator = freeze_discriminator

        # Audio framing parameters
        self.hop_length = hop_length
        self.sampling_rate = sample_rate

        # Disable automatic optimization: GAN training alternates two optimizers
        self.automatic_optimization = False

        if self.freeze_discriminator:
            for p in self.discriminator.parameters():
                p.requires_grad = False

            # Freeze the vocoder as well
            for p in self.generator.parameters():
                p.requires_grad = False
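
    # Note: when `freeze_discriminator` is set, both the mel discriminator and
    # the HiFi-GAN vocoder are frozen, so only the encoder / VQ / decoder
    # branch receives gradients (the vocoder is not included in either
    # optimizer below in any case).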

    def configure_optimizers(self):
        # Two optimizer/scheduler pairs: one for the mel auto-encoder
        # (encoder + VQ + decoder) and one for the discriminator
        optimizer_generator = self.optimizer_builder(
            itertools.chain(
                self.encoder.parameters(),
                self.vq.parameters(),
                self.decoder.parameters(),
            )
        )
        optimizer_discriminator = self.optimizer_builder(
            self.discriminator.parameters()
        )

        lr_scheduler_generator = self.lr_scheduler_builder(optimizer_generator)
        lr_scheduler_discriminator = self.lr_scheduler_builder(optimizer_discriminator)

        return (
            {
                "optimizer": optimizer_generator,
                "lr_scheduler": {
                    "scheduler": lr_scheduler_generator,
                    "interval": "step",
                    "name": "optimizer/generator",
                },
            },
            {
                "optimizer": optimizer_discriminator,
                "lr_scheduler": {
                    "scheduler": lr_scheduler_discriminator,
                    "interval": "step",
                    "name": "optimizer/discriminator",
                },
            },
        )
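
    # A minimal sketch of how the `optimizer` / `lr_scheduler` builder
    # callables consumed above can be supplied (the project wires these up
    # via config; the optimizer class and hyperparameters below are
    # illustrative assumptions, not the repository's defaults):
    #
    #     from functools import partial
    #
    #     model = VQGAN(
    #         optimizer=partial(torch.optim.AdamW, lr=2e-4, betas=(0.8, 0.99)),
    #         lr_scheduler=partial(
    #             torch.optim.lr_scheduler.ExponentialLR, gamma=0.999
    #         ),
    #         ...,  # encoder, vq, decoder, generator, discriminator, mel_transform
    #     )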

    def training_step(self, batch, batch_idx):
        optim_g, optim_d = self.optimizers()

        audios, audio_lengths = batch["audios"], batch["audio_lengths"]
        audios = audios.float()
        audios = audios[:, None, :]

        with torch.no_grad():
            gt_mels = self.mel_transform(audios, sample_rate=self.sampling_rate)

        mel_lengths = audio_lengths // self.hop_length
        mel_masks = torch.unsqueeze(sequence_mask(mel_lengths, gt_mels.shape[2]), 1).to(
            gt_mels.dtype
        )

        vq_result = self.encode(audios, audio_lengths)
        loss_vq = vq_result.loss

        if loss_vq.ndim > 1:
            loss_vq = loss_vq.mean()

        decoded_mels = self.decode(
            indices=None,
            features=vq_result.features,
            audio_lengths=audio_lengths,
        ).mels

        # Run the vocoder under no_grad purely as a shape sanity check;
        # its output does not enter the training loss
        with torch.no_grad():
            with torch.autocast(device_type=audios.device.type, enabled=False):
                fake_audios = self.generator(decoded_mels.float())

        assert (
            audios.shape == fake_audios.shape
        ), f"{audios.shape} != {fake_audios.shape}"

        # Discriminator update
        if not self.freeze_discriminator:
            scores = self.discriminator(gt_mels)
            score_fakes = self.discriminator(decoded_mels.detach())

            with torch.autocast(device_type=audios.device.type, enabled=False):
                loss_disc, _, _ = discriminator_loss([scores], [score_fakes])

            self.log(
                "train/discriminator/loss",
                loss_disc,
                on_step=True,
                on_epoch=False,
                prog_bar=False,
                logger=True,
                sync_dist=True,
            )

            optim_d.zero_grad()
            self.manual_backward(loss_disc)
            self.clip_gradients(
                optim_d, gradient_clip_val=1000.0, gradient_clip_algorithm="norm"
            )
            optim_d.step()

        # Adversarial loss
        score_fakes = self.discriminator(decoded_mels)

        with torch.autocast(device_type=audios.device.type, enabled=False):
            loss_adv, _ = generator_loss([score_fakes])

        self.log(
            "train/generator/adv",
            loss_adv,
            on_step=True,
            on_epoch=False,
            prog_bar=False,
            logger=True,
            sync_dist=True,
        )

        # Feature matching loss
        score_gts = self.discriminator(gt_mels)

        with torch.autocast(device_type=audios.device.type, enabled=False):
            loss_fm = feature_loss([score_gts], [score_fakes])

        self.log(
            "train/generator/adv_fm",
            loss_fm,
            on_step=True,
            on_epoch=False,
            prog_bar=False,
            logger=True,
            sync_dist=True,
        )

        # Masked L1 mel reconstruction loss
        with torch.autocast(device_type=audios.device.type, enabled=False):
            loss_mel = F.l1_loss(gt_mels * mel_masks, decoded_mels * mel_masks)

        self.log(
            "train/generator/loss_mel",
            loss_mel,
            on_step=True,
            on_epoch=False,
            prog_bar=False,
            logger=True,
            sync_dist=True,
        )
        self.log(
            "train/generator/loss_vq",
            loss_vq,
            on_step=True,
            on_epoch=False,
            prog_bar=False,
            logger=True,
            sync_dist=True,
        )

        loss = loss_mel * 20 + loss_vq + loss_adv + loss_fm
        self.log(
            "train/generator/loss",
            loss,
            on_step=True,
            on_epoch=False,
            prog_bar=False,
            logger=True,
            sync_dist=True,
        )

        # Generator update
        optim_g.zero_grad()
        self.manual_backward(loss)
        self.clip_gradients(
            optim_g, gradient_clip_val=1000.0, gradient_clip_algorithm="norm"
        )
        optim_g.step()

        # Step the LR schedulers manually (automatic optimization is disabled)
        scheduler_g, scheduler_d = self.lr_schedulers()
        scheduler_g.step()
        scheduler_d.step()
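
    # For reference, a batch as consumed by `training_step` above and
    # `validation_step` below looks roughly like this (a plausible collate
    # output; field names match the code, sizes are illustrative):
    #
    #     batch = {
    #         "audios": torch.randn(4, 64000),  # [B, T] raw waveforms
    #         "audio_lengths": torch.tensor([64000, 48000, 32000, 64000]),
    #     }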

    def validation_step(self, batch: Any, batch_idx: int):
        audios, audio_lengths = batch["audios"], batch["audio_lengths"]
        audios = audios.float()
        audios = audios[:, None, :]

        gt_mels = self.mel_transform(audios, sample_rate=self.sampling_rate)
        mel_lengths = audio_lengths // self.hop_length
        mel_masks = torch.unsqueeze(sequence_mask(mel_lengths, gt_mels.shape[2]), 1).to(
            gt_mels.dtype
        )

        vq_result = self.encode(audios, audio_lengths)
        decoded_mels = self.decode(
            indices=vq_result.indices,
            audio_lengths=audio_lengths,
        ).mels

        fake_audios = self.generator(decoded_mels)
        fake_mels = self.mel_transform(
            fake_audios.squeeze(1), sample_rate=self.sampling_rate
        )

        # Align mels (and the mask) to the shortest sequence before comparing
        min_mel_length = min(
            decoded_mels.shape[-1], gt_mels.shape[-1], fake_mels.shape[-1]
        )
        decoded_mels = decoded_mels[:, :, :min_mel_length]
        gt_mels = gt_mels[:, :, :min_mel_length]
        fake_mels = fake_mels[:, :, :min_mel_length]
        mel_masks = mel_masks[:, :, :min_mel_length]

        mel_loss = F.l1_loss(gt_mels * mel_masks, fake_mels * mel_masks)
        self.log(
            "val/mel_loss",
            mel_loss,
            on_step=False,
            on_epoch=True,
            prog_bar=True,
            logger=True,
            sync_dist=True,
        )

        for idx, (
            mel,
            gen_mel,
            decode_mel,
            audio,
            gen_audio,
            audio_len,
        ) in enumerate(
            zip(
                gt_mels,
                fake_mels,
                decoded_mels,
                audios.detach().float(),
                fake_audios.detach().float(),
                audio_lengths,
            )
        ):
            mel_len = audio_len // self.hop_length

            image_mels = plot_mel(
                [
                    gen_mel[:, :mel_len],
                    decode_mel[:, :mel_len],
                    mel[:, :mel_len],
                ],
                [
                    "Generated",
                    "Decoded",
                    "Ground-Truth",
                ],
            )

            if isinstance(self.logger, WandbLogger):
                self.logger.experiment.log(
                    {
                        "reconstruction_mel": wandb.Image(image_mels, caption="mels"),
                        "wavs": [
                            wandb.Audio(
                                audio[0, :audio_len],
                                sample_rate=self.sampling_rate,
                                caption="gt",
                            ),
                            wandb.Audio(
                                gen_audio[0, :audio_len],
                                sample_rate=self.sampling_rate,
                                caption="prediction",
                            ),
                        ],
                    },
                )

            if isinstance(self.logger, TensorBoardLogger):
                self.logger.experiment.add_figure(
                    f"sample-{idx}/mels",
                    image_mels,
                    global_step=self.global_step,
                )
                self.logger.experiment.add_audio(
                    f"sample-{idx}/wavs/gt",
                    audio[0, :audio_len],
                    self.global_step,
                    sample_rate=self.sampling_rate,
                )
                self.logger.experiment.add_audio(
                    f"sample-{idx}/wavs/prediction",
                    gen_audio[0, :audio_len],
                    self.global_step,
                    sample_rate=self.sampling_rate,
                )

            plt.close(image_mels)

    def encode(self, audios, audio_lengths=None):
        if audio_lengths is None:
            audio_lengths = torch.tensor(
                [audios.shape[-1]] * audios.shape[0],
                device=audios.device,
                dtype=torch.long,
            )

        with torch.no_grad():
            features = self.mel_transform(audios, sample_rate=self.sampling_rate)

        # Lengths are at mel-frame rate; VQ downsampling is applied inside self.vq
        feature_lengths = audio_lengths // self.hop_length
        feature_masks = torch.unsqueeze(
            sequence_mask(feature_lengths, features.shape[2]), 1
        ).to(features.dtype)

        # Gradient-checkpoint the encoder to save activation memory
        features = (
            gradient_checkpoint(
                self.encoder, features, feature_masks, use_reentrant=False
            )
            * feature_masks
        )

        vq_features, indices, loss = self.vq(features, feature_masks)

        return VQEncodeResult(
            features=vq_features,
            indices=indices,
            loss=loss,
            feature_lengths=feature_lengths,
        )
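
    # Example (a hedged sketch): turning a batch of waveforms into discrete
    # VQ codes with `encode` above. `audios` is [B, 1, T] as in
    # `training_step`; shapes and the `model` name are illustrative:
    #
    #     with torch.no_grad():
    #         result = model.encode(audios, audio_lengths)
    #     indices = result.indices          # discrete codebook indices
    #     lengths = result.feature_lengths  # valid frame counts (mel rate)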

    def calculate_audio_lengths(self, feature_lengths):
        return feature_lengths * self.hop_length * self.vq.downsample

    def decode(
        self,
        indices=None,
        features=None,
        audio_lengths=None,
        feature_lengths=None,
        return_audios=False,
    ):
        assert (
            indices is not None or features is not None
        ), "indices or features must be provided"
        assert (
            feature_lengths is not None or audio_lengths is not None
        ), "feature_lengths or audio_lengths must be provided"

        if audio_lengths is None:
            audio_lengths = self.calculate_audio_lengths(feature_lengths)

        mel_lengths = audio_lengths // self.hop_length
        mel_masks = torch.unsqueeze(
            sequence_mask(mel_lengths, torch.max(mel_lengths)), 1
        ).float()

        if indices is not None:
            features = self.vq.decode(indices)

        # Decode quantized features back to mel frames, masking padded frames
        # to mirror the encoder path
        decoded = (
            gradient_checkpoint(self.decoder, features, mel_masks, use_reentrant=False)
            * mel_masks
        )

        return VQDecodeResult(
            mels=decoded,
            audios=self.generator(decoded) if return_audios else None,
        )
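

# Putting it together: a hedged sketch of a full reconstruction round trip
# (encode -> decode -> vocoder), mirroring what `validation_step` does.
# `model` is a constructed VQGAN instance; everything else refers to the
# methods defined above:
#
#     model.eval()
#     with torch.no_grad():
#         enc = model.encode(audios, audio_lengths)
#         dec = model.decode(
#             indices=enc.indices,
#             audio_lengths=audio_lengths,
#             return_audios=True,
#         )
#     reconstructed = dec.audios  # [B, 1, T'] waveform from the generator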