import itertools
from typing import Any, Callable, Literal, Optional

import lightning as L
import torch
import torch.nn.functional as F
import wandb
from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger
from matplotlib import pyplot as plt
from torch import nn
from vector_quantize_pytorch import VectorQuantize

from fish_speech.models.vqgan.losses import (
    discriminator_loss,
    feature_loss,
    generator_loss,
    kl_loss,
)
from fish_speech.models.vqgan.modules.decoder import Generator
from fish_speech.models.vqgan.modules.discriminator import EnsembleDiscriminator
from fish_speech.models.vqgan.modules.encoders import (
    ConvDownSampler,
    SpeakerEncoder,
    TextEncoder,
    VQEncoder,
)
from fish_speech.models.vqgan.utils import (
    plot_mel,
    rand_slice_segments,
    sequence_mask,
    slice_segments,
)


class VQGAN(L.LightningModule):
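    """VQ-GAN module for mel reconstruction and waveform synthesis.

    Chains the downsampler, mel encoder, VQ encoder, optional speaker encoder,
    decoder, and HiFi-GAN style generator, with an ensemble discriminator for
    adversarial training. `optimizer` and `lr_scheduler` are builder callables:
    the former is called with an iterable of parameters, the latter with the
    resulting optimizer (e.g. via functools.partial).
    """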
    def __init__(
        self,
        optimizer: Callable,
        lr_scheduler: Callable,
        downsample: ConvDownSampler,
        vq_encoder: VQEncoder,
        mel_encoder: TextEncoder,
        decoder: TextEncoder,
        generator: Generator,
        discriminator: EnsembleDiscriminator,
        mel_transform: nn.Module,
        segment_size: int = 20480,
        hop_length: int = 640,
        sample_rate: int = 32000,
        mode: Literal["pretrain-stage1", "pretrain-stage2", "finetune"] = "finetune",
        speaker_encoder: Optional[SpeakerEncoder] = None,
    ):
        super().__init__()

        # pretrain-stage1: the VQ branch uses the GT mel as its target; HiFi-GAN takes the GT mel as input
        # pretrain-stage2: end-to-end training, with the GT mel as the HiFi-GAN target
        # finetune: end-to-end training with the GT mel as the HiFi-GAN target, but the VQ path is frozen

        # Optimizer and scheduler builders
        self.optimizer_builder = optimizer
        self.lr_scheduler_builder = lr_scheduler

        # Generator and discriminators
        self.downsample = downsample
        self.vq_encoder = vq_encoder
        self.mel_encoder = mel_encoder
        self.speaker_encoder = speaker_encoder
        self.decoder = decoder
        self.generator = generator
        self.discriminator = discriminator
        self.mel_transform = mel_transform

        # Crop length for saving memory
        self.segment_size = segment_size
        self.hop_length = hop_length
        self.sampling_rate = sample_rate
        self.mode = mode

        # Disable automatic optimization: GAN training alternates two manual optimizer steps
        self.automatic_optimization = False

        # Finetune: freeze the VQ path (only the generator side stays trainable)
        if self.mode == "finetune":
            for p in self.vq_encoder.parameters():
                p.requires_grad = False

            for p in self.mel_encoder.parameters():
                p.requires_grad = False

            for p in self.downsample.parameters():
                p.requires_grad = False

    def configure_optimizers(self):
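        """Build separate optimizer/scheduler pairs for the generator and discriminator.

        Lightning returns them in this order, which `training_step` relies on
        when unpacking `self.optimizers()` and `self.lr_schedulers()`.
        """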
        # Everything except the discriminator is trained by the generator optimizer
        components = []

        if self.mode != "finetune":
            components.extend(
                [
                    self.downsample.parameters(),
                    self.vq_encoder.parameters(),
                    self.mel_encoder.parameters(),
                ]
            )

        if self.speaker_encoder is not None:
            components.append(self.speaker_encoder.parameters())

        if self.decoder is not None:
            components.append(self.decoder.parameters())

        components.append(self.generator.parameters())

        optimizer_generator = self.optimizer_builder(itertools.chain(*components))
        optimizer_discriminator = self.optimizer_builder(
            self.discriminator.parameters()
        )

        lr_scheduler_generator = self.lr_scheduler_builder(optimizer_generator)
        lr_scheduler_discriminator = self.lr_scheduler_builder(optimizer_discriminator)

        return (
            {
                "optimizer": optimizer_generator,
                "lr_scheduler": {
                    "scheduler": lr_scheduler_generator,
                    "interval": "step",
                    "name": "optimizer/generator",
                },
            },
            {
                "optimizer": optimizer_discriminator,
                "lr_scheduler": {
                    "scheduler": lr_scheduler_discriminator,
                    "interval": "step",
                    "name": "optimizer/discriminator",
                },
            },
        )

    def training_step(self, batch, batch_idx):
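        """Run one manual optimization step: discriminator first, then generator.

        audio -> GT mel -> (downsample) -> mel encoder -> VQ -> decoder ->
        generator -> fake audio, scored against the ground truth by the
        discriminator and by L1 mel reconstruction losses.
        """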
        optim_g, optim_d = self.optimizers()

        audios, audio_lengths = batch["audios"], batch["audio_lengths"]

        audios = audios.float()
        audios = audios[:, None, :]

        with torch.no_grad():
            features = gt_mels = self.mel_transform(
                audios, sample_rate=self.sampling_rate
            )

        if self.mode == "finetune":
            # Disable gradient computation for the frozen VQ path
            torch.set_grad_enabled(False)
            self.vq_encoder.eval()
            self.mel_encoder.eval()
            self.downsample.eval()

        if self.downsample is not None:
            features = self.downsample(features)

        mel_lengths = audio_lengths // self.hop_length
        feature_lengths = (
            audio_lengths
            / self.hop_length
            / (self.downsample.total_strides if self.downsample is not None else 1)
        ).long()

        feature_masks = torch.unsqueeze(
            sequence_mask(feature_lengths, features.shape[2]), 1
        ).to(gt_mels.dtype)
        mel_masks = torch.unsqueeze(sequence_mask(mel_lengths, gt_mels.shape[2]), 1).to(
            gt_mels.dtype
        )

        # The VQ features run at 50 Hz; interpolate them back to the mel frame rate
        text_features = self.mel_encoder(features, feature_masks)
        text_features, _, loss_vq = self.vq_encoder(text_features, feature_masks)
        text_features = F.interpolate(
            text_features, size=gt_mels.shape[2], mode="nearest"
        )

        if loss_vq.ndim > 1:
            loss_vq = loss_vq.mean()

        if self.mode == "finetune":
            # Re-enable gradient computation for the trainable vocoder path
            torch.set_grad_enabled(True)

        # Decode mels, optionally conditioned on a speaker embedding
        if self.decoder is not None:
            speaker_features = (
                self.speaker_encoder(gt_mels, mel_masks)
                if self.speaker_encoder is not None
                else None
            )
            decoded_mels = self.decoder(text_features, mel_masks, g=speaker_features)
        else:
            decoded_mels = text_features
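
        # In pretrain-stage1 the vocoder is trained on GT mels; otherwise it consumes the decoded mels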
        input_mels = gt_mels if self.mode == "pretrain-stage1" else decoded_mels
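
        # Randomly crop a fixed-size segment so the vocoder/discriminator pass fits in memory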
        if self.segment_size is not None:
            audios, ids_slice = rand_slice_segments(
                audios, audio_lengths, self.segment_size
            )
            input_mels = slice_segments(
                input_mels,
                ids_slice // self.hop_length,
                self.segment_size // self.hop_length,
            )
            sliced_gt_mels = slice_segments(
                gt_mels,
                ids_slice // self.hop_length,
                self.segment_size // self.hop_length,
            )
            gen_mel_masks = slice_segments(
                mel_masks,
                ids_slice // self.hop_length,
                self.segment_size // self.hop_length,
            )
        else:
            sliced_gt_mels = gt_mels
            gen_mel_masks = mel_masks

        fake_audios = self.generator(input_mels)
        fake_audio_mels = self.mel_transform(fake_audios.squeeze(1))

        assert (
            audios.shape == fake_audios.shape
        ), f"{audios.shape} != {fake_audios.shape}"
        # Discriminator update: score real audio against the detached fake audio
        y_d_hat_r, y_d_hat_g, _, _ = self.discriminator(audios, fake_audios.detach())

        with torch.autocast(device_type=audios.device.type, enabled=False):
            loss_disc_all, _, _ = discriminator_loss(y_d_hat_r, y_d_hat_g)

        self.log(
            "train/discriminator/loss",
            loss_disc_all,
            on_step=True,
            on_epoch=False,
            prog_bar=True,
            logger=True,
            sync_dist=True,
        )

        optim_d.zero_grad()
        self.manual_backward(loss_disc_all)
        self.clip_gradients(
            optim_d, gradient_clip_val=1000.0, gradient_clip_algorithm="norm"
        )
        optim_d.step()
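
        # Generator update: adversarial, feature-matching, and mel (and VQ) reconstruction losses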
        y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = self.discriminator(audios, fake_audios)

        with torch.autocast(device_type=audios.device.type, enabled=False):
            loss_decoded_mel = F.l1_loss(gt_mels * mel_masks, decoded_mels * mel_masks)
            loss_mel = F.l1_loss(
                sliced_gt_mels * gen_mel_masks, fake_audio_mels * gen_mel_masks
            )
            loss_adv, _ = generator_loss(y_d_hat_g)
            loss_fm = feature_loss(fmap_r, fmap_g)

            if self.mode == "pretrain-stage1":
                loss_vq_all = loss_decoded_mel + loss_vq
                loss_gen_all = loss_mel * 45 + loss_fm + loss_adv
            else:
                loss_gen_all = loss_mel * 45 + loss_vq * 45 + loss_fm + loss_adv

        self.log(
            "train/generator/loss_gen_all",
            loss_gen_all,
            on_step=True,
            on_epoch=False,
            prog_bar=True,
            logger=True,
            sync_dist=True,
        )

        if self.mode == "pretrain-stage1":
            self.log(
                "train/generator/loss_vq_all",
                loss_vq_all,
                on_step=True,
                on_epoch=False,
                prog_bar=True,
                logger=True,
                sync_dist=True,
            )

        self.log(
            "train/generator/loss_decoded_mel",
            loss_decoded_mel,
            on_step=True,
            on_epoch=False,
            prog_bar=False,
            logger=True,
            sync_dist=True,
        )
        self.log(
            "train/generator/loss_mel",
            loss_mel,
            on_step=True,
            on_epoch=False,
            prog_bar=False,
            logger=True,
            sync_dist=True,
        )
        self.log(
            "train/generator/loss_fm",
            loss_fm,
            on_step=True,
            on_epoch=False,
            prog_bar=False,
            logger=True,
            sync_dist=True,
        )
        self.log(
            "train/generator/loss_adv",
            loss_adv,
            on_step=True,
            on_epoch=False,
            prog_bar=False,
            logger=True,
            sync_dist=True,
        )
        self.log(
            "train/generator/loss_vq",
            loss_vq,
            on_step=True,
            on_epoch=False,
            prog_bar=False,
            logger=True,
            sync_dist=True,
        )

        optim_g.zero_grad()

        # Only backpropagate loss_vq_all in pretrain-stage1
        if self.mode == "pretrain-stage1":
            self.manual_backward(loss_vq_all)

        self.manual_backward(loss_gen_all)
        self.clip_gradients(
            optim_g, gradient_clip_val=1000.0, gradient_clip_algorithm="norm"
        )
        optim_g.step()

        # Manual LR scheduler step (automatic optimization is disabled)
        scheduler_g, scheduler_d = self.lr_schedulers()
        scheduler_g.step()
        scheduler_d.step()

    def validation_step(self, batch: Any, batch_idx: int):
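        """Reconstruct validation audio end-to-end and log the mel L1 loss plus sample plots/audio."""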
        audios, audio_lengths = batch["audios"], batch["audio_lengths"]

        audios = audios.float()
        audios = audios[:, None, :]

        features = gt_mels = self.mel_transform(audios, sample_rate=self.sampling_rate)

        if self.downsample is not None:
            features = self.downsample(features)

        mel_lengths = audio_lengths // self.hop_length
        feature_lengths = (
            audio_lengths
            / self.hop_length
            / (self.downsample.total_strides if self.downsample is not None else 1)
        ).long()

        feature_masks = torch.unsqueeze(
            sequence_mask(feature_lengths, features.shape[2]), 1
        ).to(gt_mels.dtype)
        mel_masks = torch.unsqueeze(sequence_mask(mel_lengths, gt_mels.shape[2]), 1).to(
            gt_mels.dtype
        )

        # The VQ features run at 50 Hz; interpolate them back to the mel frame rate
        text_features = self.mel_encoder(features, feature_masks)
        text_features, _, _ = self.vq_encoder(text_features, feature_masks)
        text_features = F.interpolate(
            text_features, size=gt_mels.shape[2], mode="nearest"
        )

        # Decode mels, optionally conditioned on a speaker embedding
        if self.decoder is not None:
            speaker_features = (
                self.speaker_encoder(gt_mels, mel_masks)
                if self.speaker_encoder is not None
                else None
            )
            decoded_mels = self.decoder(text_features, mel_masks, g=speaker_features)
        else:
            decoded_mels = text_features

        fake_audios = self.generator(decoded_mels)
        fake_mels = self.mel_transform(fake_audios.squeeze(1))

        # Align all mels (and the mask) to a common length before computing the loss
        min_mel_length = min(
            decoded_mels.shape[-1], gt_mels.shape[-1], fake_mels.shape[-1]
        )
        decoded_mels = decoded_mels[:, :, :min_mel_length]
        gt_mels = gt_mels[:, :, :min_mel_length]
        fake_mels = fake_mels[:, :, :min_mel_length]
        mel_masks = mel_masks[:, :, :min_mel_length]

        mel_loss = F.l1_loss(gt_mels * mel_masks, fake_mels * mel_masks)
        self.log(
            "val/mel_loss",
            mel_loss,
            on_step=False,
            on_epoch=True,
            prog_bar=True,
            logger=True,
            sync_dist=True,
        )
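
        # Log mel plots (generated / decoded / GT) and GT vs. predicted audio for each sample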
        for idx, (
            mel,
            gen_mel,
            decode_mel,
            audio,
            gen_audio,
            audio_len,
        ) in enumerate(
            zip(
                gt_mels,
                fake_mels,
                decoded_mels,
                audios.detach().float(),
                fake_audios.detach().float(),
                audio_lengths,
            )
        ):
            mel_len = audio_len // self.hop_length

            image_mels = plot_mel(
                [
                    gen_mel[:, :mel_len],
                    decode_mel[:, :mel_len],
                    mel[:, :mel_len],
                ],
                [
                    "Generated",
                    "Decoded",
                    "Ground-Truth",
                ],
            )

            if isinstance(self.logger, WandbLogger):
                self.logger.experiment.log(
                    {
                        "reconstruction_mel": wandb.Image(image_mels, caption="mels"),
                        "wavs": [
                            # wandb.Audio expects a CPU numpy array
                            wandb.Audio(
                                audio[0, :audio_len].cpu().numpy(),
                                sample_rate=self.sampling_rate,
                                caption="gt",
                            ),
                            wandb.Audio(
                                gen_audio[0, :audio_len].cpu().numpy(),
                                sample_rate=self.sampling_rate,
                                caption="prediction",
                            ),
                        ],
                    },
                )

            if isinstance(self.logger, TensorBoardLogger):
                self.logger.experiment.add_figure(
                    f"sample-{idx}/mels",
                    image_mels,
                    global_step=self.global_step,
                )
                self.logger.experiment.add_audio(
                    f"sample-{idx}/wavs/gt",
                    audio[0, :audio_len],
                    self.global_step,
                    sample_rate=self.sampling_rate,
                )
                self.logger.experiment.add_audio(
                    f"sample-{idx}/wavs/prediction",
                    gen_audio[0, :audio_len],
                    self.global_step,
                    sample_rate=self.sampling_rate,
                )

            plt.close(image_mels)
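

# Usage sketch (hypothetical values; the real sub-modules and their arguments are
# built from the project's training configs):
#
#   from functools import partial
#
#   model = VQGAN(
#       optimizer=partial(torch.optim.AdamW, lr=2e-4, betas=(0.8, 0.99)),
#       lr_scheduler=partial(torch.optim.lr_scheduler.ExponentialLR, gamma=0.999),
#       ...,  # downsample, encoders, decoder, generator, discriminator, mel_transform
#       mode="pretrain-stage1",
#   )
#   trainer = L.Trainer(max_steps=1_000_000)
#   trainer.fit(model, datamodule=...)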