lit_module.py

import itertools
from typing import Any, Callable, Optional

import lightning as L
import torch
import torch.nn.functional as F
import wandb
from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger
from matplotlib import pyplot as plt
from torch import nn
from vector_quantize_pytorch import VectorQuantize

from fish_speech.models.vqgan.losses import (
    discriminator_loss,
    feature_loss,
    generator_loss,
    kl_loss,
)
from fish_speech.models.vqgan.modules.decoder import Generator
from fish_speech.models.vqgan.modules.discriminator import EnsembleDiscriminator
from fish_speech.models.vqgan.modules.encoders import (
    ConvDownSampler,
    SpeakerEncoder,
    TextEncoder,
    VQEncoder,
)
from fish_speech.models.vqgan.utils import (
    plot_mel,
    rand_slice_segments,
    sequence_mask,
    slice_segments,
)
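
# This file defines two LightningModules:
#   - VQGAN: the full VQ-GAN, trained in two stages via the freeze_hifigan /
#     freeze_vq flags (stage 1 fits the quantizer, stage 2 the HiFi-GAN path).
#   - VQNaive: a mel -> VQ -> mel autoencoder trained with a single optimizer,
#     which relies on a frozen, pretrained vocoder for validation audio.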

class VQGAN(L.LightningModule):
    def __init__(
        self,
        optimizer: Callable,
        lr_scheduler: Callable,
        downsample: ConvDownSampler,
        vq_encoder: VQEncoder,
        mel_encoder: TextEncoder,
        decoder: TextEncoder,
        generator: Generator,
        discriminator: EnsembleDiscriminator,
        mel_transform: nn.Module,
        segment_size: int = 20480,
        hop_length: int = 640,
        sample_rate: int = 32000,
        freeze_hifigan: bool = False,
        freeze_vq: bool = False,
        speaker_encoder: Optional[SpeakerEncoder] = None,
    ):
        super().__init__()

        # Optimizer and scheduler builders (called in configure_optimizers)
        self.optimizer_builder = optimizer
        self.lr_scheduler_builder = lr_scheduler

        # Generator and discriminators
        self.downsample = downsample
        self.vq_encoder = vq_encoder
        self.mel_encoder = mel_encoder
        self.speaker_encoder = speaker_encoder
        self.decoder = decoder
        self.generator = generator
        self.discriminator = discriminator
        self.mel_transform = mel_transform

        # Crop length for saving memory
        self.segment_size = segment_size
        self.hop_length = hop_length
        self.sampling_rate = sample_rate
        self.freeze_hifigan = freeze_hifigan
        self.freeze_vq = freeze_vq

        # Disable automatic optimization: both optimizers are stepped
        # manually in training_step
        self.automatic_optimization = False

        # Stage 1: train the VQ only
        if self.freeze_hifigan:
            for p in self.discriminator.parameters():
                p.requires_grad = False
            for p in self.generator.parameters():
                p.requires_grad = False

        # Stage 2: train the HiFi-GAN decoder + generator
        if freeze_vq:
            for p in self.vq_encoder.parameters():
                p.requires_grad = False
            for p in self.mel_encoder.parameters():
                p.requires_grad = False
            for p in self.downsample.parameters():
                p.requires_grad = False
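
    # `optimizer` and `lr_scheduler` are builder callables rather than
    # instances. A minimal sketch of what could be passed in (hypothetical
    # values; the actual configs live elsewhere in the repo):
    #
    #   import functools
    #   optimizer = functools.partial(torch.optim.AdamW, lr=2e-4, betas=(0.8, 0.99))
    #   lr_scheduler = functools.partial(
    #       torch.optim.lr_scheduler.ExponentialLR, gamma=0.999
    #   )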

    def configure_optimizers(self):
        # Need two optimizers and two schedulers
        optimizer_generator = self.optimizer_builder(
            itertools.chain(
                self.downsample.parameters(),
                self.vq_encoder.parameters(),
                self.mel_encoder.parameters(),
                self.decoder.parameters(),
                self.generator.parameters(),
            )
        )
        optimizer_discriminator = self.optimizer_builder(
            self.discriminator.parameters()
        )

        lr_scheduler_generator = self.lr_scheduler_builder(optimizer_generator)
        lr_scheduler_discriminator = self.lr_scheduler_builder(optimizer_discriminator)

        return (
            {
                "optimizer": optimizer_generator,
                "lr_scheduler": {
                    "scheduler": lr_scheduler_generator,
                    "interval": "step",
                    "name": "optimizer/generator",
                },
            },
            {
                "optimizer": optimizer_discriminator,
                "lr_scheduler": {
                    "scheduler": lr_scheduler_discriminator,
                    "interval": "step",
                    "name": "optimizer/discriminator",
                },
            },
        )
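
    # Manual GAN step (automatic_optimization is False): the discriminator is
    # updated first on detached generator output, then the generator is
    # updated through the reconstruction + adversarial objective, and the
    # per-step LR schedulers are advanced by hand at the end.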

    def training_step(self, batch, batch_idx):
        optim_g, optim_d = self.optimizers()

        audios, audio_lengths = batch["audios"], batch["audio_lengths"]
        audios = audios.float()
        audios = audios[:, None, :]

        with torch.no_grad():
            features = gt_mels = self.mel_transform(
                audios, sample_rate=self.sampling_rate
            )

        if self.downsample is not None:
            features = self.downsample(features)

        mel_lengths = audio_lengths // self.hop_length
        feature_lengths = (
            audio_lengths
            / self.hop_length
            / (self.downsample.total_strides if self.downsample is not None else 1)
        ).long()

        feature_masks = torch.unsqueeze(
            sequence_mask(feature_lengths, features.shape[2]), 1
        ).to(gt_mels.dtype)
        mel_masks = torch.unsqueeze(
            sequence_mask(mel_lengths, gt_mels.shape[2]), 1
        ).to(gt_mels.dtype)

        # vq_features is 50 Hz; interpolate back to the true mel length
        text_features = self.mel_encoder(features, feature_masks)
        text_features, _, loss_vq = self.vq_encoder(
            text_features, feature_masks, freeze_codebook=self.freeze_vq
        )
        text_features = F.interpolate(
            text_features, size=gt_mels.shape[2], mode="nearest"
        )

        if loss_vq.ndim > 1:
            loss_vq = loss_vq.mean()

        # Sample mels
        speaker_features = (
            self.speaker_encoder(gt_mels, mel_masks)
            if self.speaker_encoder is not None
            else None
        )
        decoded_mels = self.decoder(text_features, mel_masks, g=speaker_features)
        fake_audios = self.generator(decoded_mels)
        y_hat_mels = self.mel_transform(fake_audios.squeeze(1))

        # Random crops keep the discriminator pass within memory budget
        y, ids_slice = rand_slice_segments(audios, audio_lengths, self.segment_size)
        y_hat = slice_segments(fake_audios, ids_slice, self.segment_size)
        assert y.shape == y_hat.shape, f"{y.shape} != {y_hat.shape}"

        # While the HiFi-GAN side is frozen (stage 1), skip the discriminator update
        if not self.freeze_hifigan:
            # Discriminator
            y_d_hat_r, y_d_hat_g, _, _ = self.discriminator(y, y_hat.detach())

            with torch.autocast(device_type=audios.device.type, enabled=False):
                loss_disc_all, _, _ = discriminator_loss(y_d_hat_r, y_d_hat_g)

            self.log(
                "train/discriminator/loss",
                loss_disc_all,
                on_step=True,
                on_epoch=False,
                prog_bar=True,
                logger=True,
                sync_dist=True,
            )

            optim_d.zero_grad()
            self.manual_backward(loss_disc_all)
            self.clip_gradients(
                optim_d, gradient_clip_val=1.0, gradient_clip_algorithm="norm"
            )
            optim_d.step()

        y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = self.discriminator(y, y_hat)

        with torch.autocast(device_type=audios.device.type, enabled=False):
            loss_decoded_mel = F.l1_loss(gt_mels * mel_masks, decoded_mels * mel_masks)
            loss_mel = F.l1_loss(gt_mels * mel_masks, y_hat_mels * mel_masks)
            loss_adv, _ = generator_loss(y_d_hat_g)
            loss_fm = feature_loss(fmap_r, fmap_g)

            if self.freeze_hifigan:
                # Stage 1: pure reconstruction + commitment loss
                loss_gen_all = loss_decoded_mel + loss_vq
            else:
                # The 45x mel weighting follows the HiFi-GAN recipe
                loss_gen_all = loss_mel * 45 + loss_vq * 45 + loss_fm + loss_adv

        self.log(
            "train/generator/loss",
            loss_gen_all,
            on_step=True,
            on_epoch=False,
            prog_bar=True,
            logger=True,
            sync_dist=True,
        )
        self.log(
            "train/generator/loss_decoded_mel",
            loss_decoded_mel,
            on_step=True,
            on_epoch=False,
            prog_bar=False,
            logger=True,
            sync_dist=True,
        )
        self.log(
            "train/generator/loss_mel",
            loss_mel,
            on_step=True,
            on_epoch=False,
            prog_bar=False,
            logger=True,
            sync_dist=True,
        )
        self.log(
            "train/generator/loss_fm",
            loss_fm,
            on_step=True,
            on_epoch=False,
            prog_bar=False,
            logger=True,
            sync_dist=True,
        )
        self.log(
            "train/generator/loss_adv",
            loss_adv,
            on_step=True,
            on_epoch=False,
            prog_bar=False,
            logger=True,
            sync_dist=True,
        )
        self.log(
            "train/generator/loss_vq",
            loss_vq,
            on_step=True,
            on_epoch=False,
            prog_bar=False,
            logger=True,
            sync_dist=True,
        )

        optim_g.zero_grad()
        self.manual_backward(loss_gen_all)
        self.clip_gradients(
            optim_g, gradient_clip_val=1.0, gradient_clip_algorithm="norm"
        )
        optim_g.step()

        # Manual LR scheduler step (Lightning does not step schedulers when
        # automatic optimization is disabled)
        scheduler_g, scheduler_d = self.lr_schedulers()
        scheduler_g.step()
        scheduler_d.step()

    def validation_step(self, batch: Any, batch_idx: int):
        audios, audio_lengths = batch["audios"], batch["audio_lengths"]
        audios = audios.float()
        audios = audios[:, None, :]

        features = gt_mels = self.mel_transform(audios, sample_rate=self.sampling_rate)
        if self.downsample is not None:
            features = self.downsample(features)

        mel_lengths = audio_lengths // self.hop_length
        feature_lengths = (
            audio_lengths
            / self.hop_length
            / (self.downsample.total_strides if self.downsample is not None else 1)
        ).long()

        feature_masks = torch.unsqueeze(
            sequence_mask(feature_lengths, features.shape[2]), 1
        ).to(gt_mels.dtype)
        mel_masks = torch.unsqueeze(
            sequence_mask(mel_lengths, gt_mels.shape[2]), 1
        ).to(gt_mels.dtype)

        # vq_features is 50 Hz; interpolate back to the true mel length
        text_features = self.mel_encoder(features, feature_masks)
        text_features, _, _ = self.vq_encoder(text_features, feature_masks)
        text_features = F.interpolate(
            text_features, size=gt_mels.shape[2], mode="nearest"
        )

        # Sample mels
        speaker_features = (
            self.speaker_encoder(gt_mels, mel_masks)
            if self.speaker_encoder is not None
            else None
        )
        decoded_mels = self.decoder(text_features, mel_masks, g=speaker_features)
        fake_audios = self.generator(decoded_mels)
        fake_mels = self.mel_transform(fake_audios.squeeze(1))

        # The generated mel may differ from the ground truth by a frame or
        # two; crop everything (including the mask) to the common length
        min_mel_length = min(
            decoded_mels.shape[-1], gt_mels.shape[-1], fake_mels.shape[-1]
        )
        decoded_mels = decoded_mels[:, :, :min_mel_length]
        gt_mels = gt_mels[:, :, :min_mel_length]
        fake_mels = fake_mels[:, :, :min_mel_length]
        mel_masks = mel_masks[:, :, :min_mel_length]

        mel_loss = F.l1_loss(gt_mels * mel_masks, fake_mels * mel_masks)
        self.log(
            "val/mel_loss",
            mel_loss,
            on_step=False,
            on_epoch=True,
            prog_bar=True,
            logger=True,
            sync_dist=True,
        )

        for idx, (
            mel,
            gen_mel,
            decode_mel,
            audio,
            gen_audio,
            audio_len,
        ) in enumerate(
            zip(
                gt_mels,
                fake_mels,
                decoded_mels,
                audios.detach().float(),
                fake_audios.detach().float(),
                audio_lengths,
            )
        ):
            mel_len = audio_len // self.hop_length
            image_mels = plot_mel(
                [
                    gen_mel[:, :mel_len],
                    decode_mel[:, :mel_len],
                    mel[:, :mel_len],
                ],
                [
                    "Generated",
                    "Decoded",
                    "Ground-Truth",
                ],
            )

            if isinstance(self.logger, WandbLogger):
                self.logger.experiment.log(
                    {
                        "reconstruction_mel": wandb.Image(image_mels, caption="mels"),
                        "wavs": [
                            wandb.Audio(
                                audio[0, :audio_len],
                                sample_rate=self.sampling_rate,
                                caption="gt",
                            ),
                            wandb.Audio(
                                gen_audio[0, :audio_len],
                                sample_rate=self.sampling_rate,
                                caption="prediction",
                            ),
                        ],
                    },
                )

            if isinstance(self.logger, TensorBoardLogger):
                self.logger.experiment.add_figure(
                    f"sample-{idx}/mels",
                    image_mels,
                    global_step=self.global_step,
                )
                self.logger.experiment.add_audio(
                    f"sample-{idx}/wavs/gt",
                    audio[0, :audio_len],
                    self.global_step,
                    sample_rate=self.sampling_rate,
                )
                self.logger.experiment.add_audio(
                    f"sample-{idx}/wavs/prediction",
                    gen_audio[0, :audio_len],
                    self.global_step,
                    sample_rate=self.sampling_rate,
                )

            plt.close(image_mels)
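
# VQNaive trains only the mel -> VQ -> mel reconstruction path with a single
# optimizer; waveform synthesis for validation samples comes from a frozen,
# pretrained vocoder instead of a jointly trained HiFi-GAN.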

class VQNaive(L.LightningModule):
    def __init__(
        self,
        optimizer: Callable,
        lr_scheduler: Callable,
        downsample: ConvDownSampler,
        vq_encoder: VQEncoder,
        speaker_encoder: SpeakerEncoder,
        mel_encoder: TextEncoder,
        decoder: TextEncoder,
        mel_transform: nn.Module,
        hop_length: int = 640,
        sample_rate: int = 32000,
        vocoder: Optional[Generator] = None,
    ):
        super().__init__()

        # Optimizer and scheduler builders (called in configure_optimizers)
        self.optimizer_builder = optimizer
        self.lr_scheduler_builder = lr_scheduler

        # Encoder / decoder stack
        self.downsample = downsample
        self.vq_encoder = vq_encoder
        self.speaker_encoder = speaker_encoder
        self.mel_encoder = mel_encoder
        self.decoder = decoder
        self.mel_transform = mel_transform

        self.hop_length = hop_length
        self.sampling_rate = sample_rate

        # Frozen vocoder, only used to synthesize validation samples
        self.vocoder = vocoder
        if self.vocoder is not None:
            for p in self.vocoder.parameters():
                p.requires_grad = False

    def configure_optimizers(self):
        optimizer = self.optimizer_builder(self.parameters())
        lr_scheduler = self.lr_scheduler_builder(optimizer)

        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": lr_scheduler,
                "interval": "step",
            },
        }

    def vq_encode(self, audios, audio_lengths):
        with torch.no_grad():
            features = gt_mels = self.mel_transform(
                audios, sample_rate=self.sampling_rate
            )

        if self.downsample is not None:
            features = self.downsample(features)

        mel_lengths = audio_lengths // self.hop_length
        feature_lengths = (
            audio_lengths
            / self.hop_length
            / (self.downsample.total_strides if self.downsample is not None else 1)
        ).long()

        feature_masks = torch.unsqueeze(
            sequence_mask(feature_lengths, features.shape[2]), 1
        ).to(gt_mels.dtype)
        mel_masks = torch.unsqueeze(
            sequence_mask(mel_lengths, gt_mels.shape[2]), 1
        ).to(gt_mels.dtype)

        # vq_features is 50 Hz; the caller interpolates back to mel length
        text_features = self.mel_encoder(features, feature_masks)
        text_features, indices, loss_vq = self.vq_encoder(text_features, feature_masks)

        return mel_masks, gt_mels, text_features, indices, loss_vq
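
    # A minimal sketch of using vq_encode to pull discrete codebook indices
    # out of raw audio (the checkpoint path and shapes are illustrative):
    #
    #   model = VQNaive.load_from_checkpoint("vq_naive.ckpt")
    #   model.eval()
    #   with torch.no_grad():
    #       # audios: [B, 1, T] float waveform, audio_lengths: [B] in samples
    #       _, _, _, indices, _ = model.vq_encode(audios, audio_lengths)
    #   # `indices` holds one codebook entry per downsampled frame and can
    #   # serve as discrete tokens for a downstream language model.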

    def vq_decode(self, text_features, speaker_features, gt_mels, mel_masks):
        # Upsample the 50 Hz VQ features to the mel frame rate, then decode
        text_features = F.interpolate(
            text_features, size=gt_mels.shape[2], mode="nearest"
        )
        decoded_mels = self.decoder(text_features, mel_masks, g=speaker_features)

        return decoded_mels

    def training_step(self, batch, batch_idx):
        audios, audio_lengths = batch["audios"], batch["audio_lengths"]
        audios = audios.float()
        audios = audios[:, None, :]

        mel_masks, gt_mels, text_features, indices, loss_vq = self.vq_encode(
            audios, audio_lengths
        )
        speaker_features = self.speaker_encoder(gt_mels, mel_masks)
        decoded_mels = self.vq_decode(
            text_features, speaker_features, gt_mels, mel_masks
        )

        loss_mel = F.l1_loss(gt_mels * mel_masks, decoded_mels * mel_masks)
        loss = loss_mel + loss_vq

        self.log(
            "train/generator/loss",
            loss,
            on_step=True,
            on_epoch=False,
            prog_bar=True,
            logger=True,
            sync_dist=True,
        )
        self.log(
            "train/loss_mel",
            loss_mel,
            on_step=True,
            on_epoch=False,
            prog_bar=False,
            logger=True,
            sync_dist=True,
        )
        self.log(
            "train/generator/loss_vq",
            loss_vq,
            on_step=True,
            on_epoch=False,
            prog_bar=False,
            logger=True,
            sync_dist=True,
        )

        return loss

    def validation_step(self, batch: Any, batch_idx: int):
        audios, audio_lengths = batch["audios"], batch["audio_lengths"]
        audios = audios.float()
        audios = audios[:, None, :]

        mel_masks, gt_mels, text_features, indices, loss_vq = self.vq_encode(
            audios, audio_lengths
        )
        speaker_features = self.speaker_encoder(gt_mels, mel_masks)
        decoded_mels = self.vq_decode(
            text_features, speaker_features, gt_mels, mel_masks
        )
        # The vocoder is assumed to be provided for validation audio logging
        fake_audios = self.vocoder(decoded_mels)

        mel_loss = F.l1_loss(gt_mels * mel_masks, decoded_mels * mel_masks)
        self.log(
            "val/mel_loss",
            mel_loss,
            on_step=False,
            on_epoch=True,
            prog_bar=True,
            logger=True,
            sync_dist=True,
        )

        for idx, (
            mel,
            decoded_mel,
            audio,
            gen_audio,
            audio_len,
        ) in enumerate(
            zip(
                gt_mels,
                decoded_mels,
                audios.detach().float(),
                fake_audios.detach().float(),
                audio_lengths,
            )
        ):
            mel_len = audio_len // self.hop_length
            image_mels = plot_mel(
                [
                    decoded_mel[:, :mel_len],
                    mel[:, :mel_len],
                ],
                [
                    "Generated",
                    "Ground-Truth",
                ],
            )

            if isinstance(self.logger, WandbLogger):
                self.logger.experiment.log(
                    {
                        "reconstruction_mel": wandb.Image(image_mels, caption="mels"),
                        "wavs": [
                            wandb.Audio(
                                audio[0, :audio_len],
                                sample_rate=self.sampling_rate,
                                caption="gt",
                            ),
                            wandb.Audio(
                                gen_audio[0, :audio_len],
                                sample_rate=self.sampling_rate,
                                caption="prediction",
                            ),
                        ],
                    },
                )

            if isinstance(self.logger, TensorBoardLogger):
                self.logger.experiment.add_figure(
                    f"sample-{idx}/mels",
                    image_mels,
                    global_step=self.global_step,
                )
                self.logger.experiment.add_audio(
                    f"sample-{idx}/wavs/gt",
                    audio[0, :audio_len],
                    self.global_step,
                    sample_rate=self.sampling_rate,
                )
                self.logger.experiment.add_audio(
                    f"sample-{idx}/wavs/prediction",
                    gen_audio[0, :audio_len],
                    self.global_step,
                    sample_rate=self.sampling_rate,
                )

            plt.close(image_mels)