# lit_module.py
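"""Lightning modules for fish-speech's VQ-GAN.

`VQGAN` is the full model: a VQ bottleneck plus a HiFi-GAN style
generator/discriminator pair, trained with manual optimization.
`VQNaive` is a mel-to-mel VQ autoencoder trained on reconstruction
loss only, with a frozen vocoder used to render validation audio.
"""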

import itertools
from typing import Any, Callable, Optional

import lightning as L
import torch
import torch.nn.functional as F
import wandb
from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger
from matplotlib import pyplot as plt
from torch import nn
from vector_quantize_pytorch import VectorQuantize

from fish_speech.models.vqgan.losses import (
    discriminator_loss,
    feature_loss,
    generator_loss,
    kl_loss,
)
from fish_speech.models.vqgan.modules.decoder import Generator
from fish_speech.models.vqgan.modules.discriminator import EnsembleDiscriminator
from fish_speech.models.vqgan.modules.encoders import (
    ConvDownSampler,
    SpeakerEncoder,
    TextEncoder,
    VQEncoder,
)
from fish_speech.models.vqgan.utils import (
    plot_mel,
    rand_slice_segments,
    sequence_mask,
    slice_segments,
)
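
# `optimizer` and `lr_scheduler` below are factory callables, not instances:
# the first is called with an iterable of parameters, the second with the
# built optimizer. A minimal sketch of the wiring (the actual values come
# from the training config, not from this file):
#
#   from functools import partial
#   optimizer = partial(torch.optim.AdamW, lr=2e-4, betas=(0.8, 0.99))
#   lr_scheduler = partial(torch.optim.lr_scheduler.ExponentialLR, gamma=0.999)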
class VQGAN(L.LightningModule):
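    """VQ-GAN reconstruction model trained in two stages.

    Stage 1 (`freeze_hifigan=True`): only the VQ bottleneck
    (downsample -> mel_encoder -> vq_encoder -> decoder) is trained,
    on a decoded-mel L1 loss plus the VQ loss.
    Stage 2 (`freeze_vq=True`): the VQ stack is frozen and the
    decoder plus HiFi-GAN generator/discriminator are trained
    adversarially.

    Uses manual optimization: one optimizer for the generator stack
    and one for the discriminator (see `configure_optimizers`).
    """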

    def __init__(
        self,
        optimizer: Callable,
        lr_scheduler: Callable,
        downsample: ConvDownSampler,
        vq_encoder: VQEncoder,
        mel_encoder: TextEncoder,
        decoder: TextEncoder,
        generator: Generator,
        discriminator: EnsembleDiscriminator,
        mel_transform: nn.Module,
        segment_size: int = 20480,
        hop_length: int = 640,
        sample_rate: int = 32000,
        freeze_hifigan: bool = False,
        freeze_vq: bool = False,
        speaker_encoder: Optional[SpeakerEncoder] = None,
    ):
        super().__init__()

        # Model parameters
        self.optimizer_builder = optimizer
        self.lr_scheduler_builder = lr_scheduler

        # Generator and discriminators
        self.downsample = downsample
        self.vq_encoder = vq_encoder
        self.mel_encoder = mel_encoder
        self.speaker_encoder = speaker_encoder
        self.decoder = decoder
        self.generator = generator
        self.discriminator = discriminator
        self.mel_transform = mel_transform

        # Crop length for saving memory
        self.segment_size = segment_size
        self.hop_length = hop_length
        self.sampling_rate = sample_rate
        self.freeze_hifigan = freeze_hifigan
        self.freeze_vq = freeze_vq

        # Disable automatic optimization
        self.automatic_optimization = False

        # Stage 1: train the VQ only
        if self.freeze_hifigan:
            for p in self.discriminator.parameters():
                p.requires_grad = False

            for p in self.generator.parameters():
                p.requires_grad = False

        # Stage 2: train the HiFi-GAN + decoder + generator
        if freeze_vq:
            for p in self.vq_encoder.parameters():
                p.requires_grad = False

            for p in self.mel_encoder.parameters():
                p.requires_grad = False

            for p in self.downsample.parameters():
                p.requires_grad = False

    def configure_optimizers(self):
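        """Build two optimizer/scheduler pairs: one for the generator
        stack (chaining only the parameter groups that are not frozen
        in the current stage) and one for the discriminator."""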
        components = []

        if not self.freeze_vq:
            components.extend(
                [
                    self.downsample.parameters(),
                    self.vq_encoder.parameters(),
                    self.mel_encoder.parameters(),
                ]
            )

        if self.speaker_encoder is not None:
            components.append(self.speaker_encoder.parameters())

        components.append(self.decoder.parameters())

        if not self.freeze_hifigan:
            components.append(self.generator.parameters())

        optimizer_generator = self.optimizer_builder(itertools.chain(*components))
        optimizer_discriminator = self.optimizer_builder(
            self.discriminator.parameters()
        )

        lr_scheduler_generator = self.lr_scheduler_builder(optimizer_generator)
        lr_scheduler_discriminator = self.lr_scheduler_builder(optimizer_discriminator)

        return (
            {
                "optimizer": optimizer_generator,
                "lr_scheduler": {
                    "scheduler": lr_scheduler_generator,
                    "interval": "step",
                    "name": "optimizer/generator",
                },
            },
            {
                "optimizer": optimizer_discriminator,
                "lr_scheduler": {
                    "scheduler": lr_scheduler_discriminator,
                    "interval": "step",
                    "name": "optimizer/discriminator",
                },
            },
        )

    def training_step(self, batch, batch_idx):
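        """Run one manual-optimization training step.

        `batch["audios"]` is a padded [B, T] float waveform batch and
        `batch["audio_lengths"]` holds each clip's valid length. The
        discriminator is stepped first (skipped when `freeze_hifigan`
        is set), then the generator stack; both LR schedulers are
        stepped manually at the end.
        """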
        optim_g, optim_d = self.optimizers()

        audios, audio_lengths = batch["audios"], batch["audio_lengths"]

        audios = audios.float()
        audios = audios[:, None, :]

        with torch.no_grad():
            features = gt_mels = self.mel_transform(
                audios, sample_rate=self.sampling_rate
            )

        if self.freeze_vq:
            # Disable gradient computation for VQ
            torch.set_grad_enabled(False)
            self.vq_encoder.eval()
            self.mel_encoder.eval()
            self.downsample.eval()

        if self.downsample is not None:
            features = self.downsample(features)

        mel_lengths = audio_lengths // self.hop_length
        feature_lengths = (
            audio_lengths
            / self.hop_length
            / (self.downsample.total_strides if self.downsample is not None else 1)
        ).long()

        feature_masks = torch.unsqueeze(
            sequence_mask(feature_lengths, features.shape[2]), 1
        ).to(gt_mels.dtype)
        mel_masks = torch.unsqueeze(sequence_mask(mel_lengths, gt_mels.shape[2]), 1).to(
            gt_mels.dtype
        )
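        # Both masks are [B, 1, T], with 1.0 on valid frames and 0.0 on padding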

        # vq_features is 50 Hz, need to convert to true mel size
        text_features = self.mel_encoder(features, feature_masks)
        text_features, _, loss_vq = self.vq_encoder(
            text_features, feature_masks, freeze_codebook=self.freeze_vq
        )
        text_features = F.interpolate(
            text_features, size=gt_mels.shape[2], mode="nearest"
        )

        if loss_vq.ndim > 1:
            loss_vq = loss_vq.mean()

        if self.freeze_vq:
            # Enable gradient computation
            torch.set_grad_enabled(True)

        # Sample mels
        speaker_features = (
            self.speaker_encoder(gt_mels, mel_masks)
            if self.speaker_encoder is not None
            else None
        )
        decoded_mels = self.decoder(text_features, mel_masks, g=speaker_features)
        fake_audios = self.generator(decoded_mels)
        y_hat_mels = self.mel_transform(fake_audios.squeeze(1))

        y, ids_slice = rand_slice_segments(audios, audio_lengths, self.segment_size)
        y_hat = slice_segments(fake_audios, ids_slice, self.segment_size)

        assert y.shape == y_hat.shape, f"{y.shape} != {y_hat.shape}"

        # When the HiFi-GAN stack is frozen, the discriminator isn't updated,
        # so its forward/backward pass is skipped entirely
        if not self.freeze_hifigan:
            # Discriminator
            y_d_hat_r, y_d_hat_g, _, _ = self.discriminator(y, y_hat.detach())

            with torch.autocast(device_type=audios.device.type, enabled=False):
                loss_disc_all, _, _ = discriminator_loss(y_d_hat_r, y_d_hat_g)

            self.log(
                "train/discriminator/loss",
                loss_disc_all,
                on_step=True,
                on_epoch=False,
                prog_bar=True,
                logger=True,
                sync_dist=True,
            )

            optim_d.zero_grad()
            self.manual_backward(loss_disc_all)
            self.clip_gradients(
                optim_d, gradient_clip_val=1.0, gradient_clip_algorithm="norm"
            )
            optim_d.step()

        y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = self.discriminator(y, y_hat)

        with torch.autocast(device_type=audios.device.type, enabled=False):
            loss_decoded_mel = F.l1_loss(gt_mels * mel_masks, decoded_mels * mel_masks)
            loss_mel = F.l1_loss(gt_mels * mel_masks, y_hat_mels * mel_masks)
            loss_adv, _ = generator_loss(y_d_hat_g)
            loss_fm = feature_loss(fmap_r, fmap_g)

            if self.freeze_hifigan:
                # Stage 1: only the decoded-mel and VQ losses apply
                loss_gen_all = loss_decoded_mel + loss_vq
            else:
                # The 45x mel weighting follows the HiFi-GAN recipe
                loss_gen_all = loss_mel * 45 + loss_vq * 45 + loss_fm + loss_adv

        self.log(
            "train/generator/loss",
            loss_gen_all,
            on_step=True,
            on_epoch=False,
            prog_bar=True,
            logger=True,
            sync_dist=True,
        )
        self.log(
            "train/generator/loss_decoded_mel",
            loss_decoded_mel,
            on_step=True,
            on_epoch=False,
            prog_bar=False,
            logger=True,
            sync_dist=True,
        )
        self.log(
            "train/generator/loss_mel",
            loss_mel,
            on_step=True,
            on_epoch=False,
            prog_bar=False,
            logger=True,
            sync_dist=True,
        )
        self.log(
            "train/generator/loss_fm",
            loss_fm,
            on_step=True,
            on_epoch=False,
            prog_bar=False,
            logger=True,
            sync_dist=True,
        )
        self.log(
            "train/generator/loss_adv",
            loss_adv,
            on_step=True,
            on_epoch=False,
            prog_bar=False,
            logger=True,
            sync_dist=True,
        )
        self.log(
            "train/generator/loss_vq",
            loss_vq,
            on_step=True,
            on_epoch=False,
            prog_bar=False,
            logger=True,
            sync_dist=True,
        )

        optim_g.zero_grad()
        self.manual_backward(loss_gen_all)
        self.clip_gradients(
            optim_g, gradient_clip_val=1.0, gradient_clip_algorithm="norm"
        )
        optim_g.step()

        # Manual LR scheduler
        scheduler_g, scheduler_d = self.lr_schedulers()
        scheduler_g.step()
        scheduler_d.step()

    def validation_step(self, batch: Any, batch_idx: int):
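        """Reconstruct the batch, log the masked mel L1 loss, and send
        mel-spectrogram figures plus ground-truth/predicted audio to
        W&B and/or TensorBoard."""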
        audios, audio_lengths = batch["audios"], batch["audio_lengths"]

        audios = audios.float()
        audios = audios[:, None, :]

        features = gt_mels = self.mel_transform(audios, sample_rate=self.sampling_rate)

        if self.downsample is not None:
            features = self.downsample(features)

        mel_lengths = audio_lengths // self.hop_length
        feature_lengths = (
            audio_lengths
            / self.hop_length
            / (self.downsample.total_strides if self.downsample is not None else 1)
        ).long()

        feature_masks = torch.unsqueeze(
            sequence_mask(feature_lengths, features.shape[2]), 1
        ).to(gt_mels.dtype)
        mel_masks = torch.unsqueeze(sequence_mask(mel_lengths, gt_mels.shape[2]), 1).to(
            gt_mels.dtype
        )

        # vq_features is 50 Hz, need to convert to true mel size
        text_features = self.mel_encoder(features, feature_masks)
        text_features, _, _ = self.vq_encoder(text_features, feature_masks)
        text_features = F.interpolate(
            text_features, size=gt_mels.shape[2], mode="nearest"
        )

        # Sample mels
        speaker_features = (
            self.speaker_encoder(gt_mels, mel_masks)
            if self.speaker_encoder is not None
            else None
        )
        decoded_mels = self.decoder(text_features, mel_masks, g=speaker_features)
        fake_audios = self.generator(decoded_mels)
        fake_mels = self.mel_transform(fake_audios.squeeze(1))

        min_mel_length = min(
            decoded_mels.shape[-1], gt_mels.shape[-1], fake_mels.shape[-1]
        )
        decoded_mels = decoded_mels[:, :, :min_mel_length]
        gt_mels = gt_mels[:, :, :min_mel_length]
        fake_mels = fake_mels[:, :, :min_mel_length]
        # Truncate the mask as well so it stays aligned with the cropped mels
        mel_masks = mel_masks[:, :, :min_mel_length]

        mel_loss = F.l1_loss(gt_mels * mel_masks, fake_mels * mel_masks)
        self.log(
            "val/mel_loss",
            mel_loss,
            on_step=False,
            on_epoch=True,
            prog_bar=True,
            logger=True,
            sync_dist=True,
        )

        for idx, (
            mel,
            gen_mel,
            decode_mel,
            audio,
            gen_audio,
            audio_len,
        ) in enumerate(
            zip(
                gt_mels,
                fake_mels,
                decoded_mels,
                audios.detach().float(),
                fake_audios.detach().float(),
                audio_lengths,
            )
        ):
            mel_len = audio_len // self.hop_length
            image_mels = plot_mel(
                [
                    gen_mel[:, :mel_len],
                    decode_mel[:, :mel_len],
                    mel[:, :mel_len],
                ],
                [
                    "Generated",
                    "Decoded",
                    "Ground-Truth",
                ],
            )

            if isinstance(self.logger, WandbLogger):
                self.logger.experiment.log(
                    {
                        "reconstruction_mel": wandb.Image(image_mels, caption="mels"),
                        "wavs": [
                            wandb.Audio(
                                audio[0, :audio_len],
                                sample_rate=self.sampling_rate,
                                caption="gt",
                            ),
                            wandb.Audio(
                                gen_audio[0, :audio_len],
                                sample_rate=self.sampling_rate,
                                caption="prediction",
                            ),
                        ],
                    },
                )

            if isinstance(self.logger, TensorBoardLogger):
                self.logger.experiment.add_figure(
                    f"sample-{idx}/mels",
                    image_mels,
                    global_step=self.global_step,
                )
                self.logger.experiment.add_audio(
                    f"sample-{idx}/wavs/gt",
                    audio[0, :audio_len],
                    self.global_step,
                    sample_rate=self.sampling_rate,
                )
                self.logger.experiment.add_audio(
                    f"sample-{idx}/wavs/prediction",
                    gen_audio[0, :audio_len],
                    self.global_step,
                    sample_rate=self.sampling_rate,
                )

            plt.close(image_mels)


class VQNaive(L.LightningModule):
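    """Mel-to-mel VQ autoencoder without adversarial training.

    Trains on an L1 mel reconstruction loss plus the VQ loss only; a
    frozen pretrained `vocoder` is used at validation time to render
    audio samples for logging.
    """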

    def __init__(
        self,
        optimizer: Callable,
        lr_scheduler: Callable,
        downsample: ConvDownSampler,
        vq_encoder: VQEncoder,
        speaker_encoder: SpeakerEncoder,
        mel_encoder: TextEncoder,
        decoder: TextEncoder,
        mel_transform: nn.Module,
        hop_length: int = 640,
        sample_rate: int = 32000,
        vocoder: Optional[Generator] = None,
    ):
        super().__init__()

        # Model parameters
        self.optimizer_builder = optimizer
        self.lr_scheduler_builder = lr_scheduler

        # Generator and discriminators
        self.downsample = downsample
        self.vq_encoder = vq_encoder
        self.speaker_encoder = speaker_encoder
        self.mel_encoder = mel_encoder
        self.decoder = decoder
        self.mel_transform = mel_transform

        # Crop length for saving memory
        self.hop_length = hop_length
        self.sampling_rate = sample_rate

        # Vocoder (kept frozen; only used to render validation audio)
        self.vocoder = vocoder
        if self.vocoder is not None:
            for p in self.vocoder.parameters():
                p.requires_grad = False

    def configure_optimizers(self):
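        """A single optimizer/scheduler over all parameters; the frozen
        vocoder's parameters receive no gradients and are never updated."""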
        optimizer = self.optimizer_builder(self.parameters())
        lr_scheduler = self.lr_scheduler_builder(optimizer)

        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": lr_scheduler,
                "interval": "step",
            },
        }

    def vq_encode(self, audios, audio_lengths):
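        """Encode audio into quantized features.

        Returns `(mel_masks, gt_mels, text_features, indices, loss_vq)`,
        where `indices` are the codebook entries chosen by the VQ layer.
        """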
        with torch.no_grad():
            features = gt_mels = self.mel_transform(
                audios, sample_rate=self.sampling_rate
            )

        if self.downsample is not None:
            features = self.downsample(features)

        mel_lengths = audio_lengths // self.hop_length
        feature_lengths = (
            audio_lengths
            / self.hop_length
            / (self.downsample.total_strides if self.downsample is not None else 1)
        ).long()

        feature_masks = torch.unsqueeze(
            sequence_mask(feature_lengths, features.shape[2]), 1
        ).to(gt_mels.dtype)
        mel_masks = torch.unsqueeze(sequence_mask(mel_lengths, gt_mels.shape[2]), 1).to(
            gt_mels.dtype
        )

        # vq_features is 50 Hz, need to convert to true mel size
        text_features = self.mel_encoder(features, feature_masks)
        text_features, indices, loss_vq = self.vq_encoder(text_features, feature_masks)

        return mel_masks, gt_mels, text_features, indices, loss_vq

    def vq_decode(self, text_features, speaker_features, gt_mels, mel_masks):
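        """Upsample quantized features to the mel resolution and decode
        them into a mel spectrogram, conditioned on speaker features."""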
        text_features = F.interpolate(
            text_features, size=gt_mels.shape[2], mode="nearest"
        )
        decoded_mels = self.decoder(text_features, mel_masks, g=speaker_features)

        return decoded_mels

    def training_step(self, batch, batch_idx):
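        """One training step: encode, decode, and optimize the L1 mel
        reconstruction loss plus the VQ loss."""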
        audios, audio_lengths = batch["audios"], batch["audio_lengths"]

        audios = audios.float()
        audios = audios[:, None, :]

        mel_masks, gt_mels, text_features, indices, loss_vq = self.vq_encode(
            audios, audio_lengths
        )
        speaker_features = self.speaker_encoder(gt_mels, mel_masks)
        decoded_mels = self.vq_decode(
            text_features, speaker_features, gt_mels, mel_masks
        )

        loss_mel = F.l1_loss(gt_mels * mel_masks, decoded_mels * mel_masks)
        loss = loss_mel + loss_vq

        self.log(
            "train/generator/loss",
            loss,
            on_step=True,
            on_epoch=False,
            prog_bar=True,
            logger=True,
            sync_dist=True,
        )
        self.log(
            "train/loss_mel",
            loss_mel,
            on_step=True,
            on_epoch=False,
            prog_bar=False,
            logger=True,
            sync_dist=True,
        )
        self.log(
            "train/generator/loss_vq",
            loss_vq,
            on_step=True,
            on_epoch=False,
            prog_bar=False,
            logger=True,
            sync_dist=True,
        )

        return loss

    def validation_step(self, batch: Any, batch_idx: int):
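        """Validation: log the masked mel L1 loss and render samples
        through the frozen vocoder (assumes `vocoder` was provided at
        init)."""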
        audios, audio_lengths = batch["audios"], batch["audio_lengths"]

        audios = audios.float()
        audios = audios[:, None, :]

        mel_masks, gt_mels, text_features, indices, loss_vq = self.vq_encode(
            audios, audio_lengths
        )
        speaker_features = self.speaker_encoder(gt_mels, mel_masks)
        decoded_mels = self.vq_decode(
            text_features, speaker_features, gt_mels, mel_masks
        )
        fake_audios = self.vocoder(decoded_mels)

        mel_loss = F.l1_loss(gt_mels * mel_masks, decoded_mels * mel_masks)
        self.log(
            "val/mel_loss",
            mel_loss,
            on_step=False,
            on_epoch=True,
            prog_bar=True,
            logger=True,
            sync_dist=True,
        )

        for idx, (
            mel,
            decoded_mel,
            audio,
            gen_audio,
            audio_len,
        ) in enumerate(
            zip(
                gt_mels,
                decoded_mels,
                audios.detach().float(),
                fake_audios.detach().float(),
                audio_lengths,
            )
        ):
            mel_len = audio_len // self.hop_length
            image_mels = plot_mel(
                [
                    decoded_mel[:, :mel_len],
                    mel[:, :mel_len],
                ],
                [
                    "Generated",
                    "Ground-Truth",
                ],
            )

            if isinstance(self.logger, WandbLogger):
                self.logger.experiment.log(
                    {
                        "reconstruction_mel": wandb.Image(image_mels, caption="mels"),
                        "wavs": [
                            wandb.Audio(
                                audio[0, :audio_len],
                                sample_rate=self.sampling_rate,
                                caption="gt",
                            ),
                            wandb.Audio(
                                gen_audio[0, :audio_len],
                                sample_rate=self.sampling_rate,
                                caption="prediction",
                            ),
                        ],
                    },
                )

            if isinstance(self.logger, TensorBoardLogger):
                self.logger.experiment.add_figure(
                    f"sample-{idx}/mels",
                    image_mels,
                    global_step=self.global_step,
                )
                self.logger.experiment.add_audio(
                    f"sample-{idx}/wavs/gt",
                    audio[0, :audio_len],
                    self.global_step,
                    sample_rate=self.sampling_rate,
                )
                self.logger.experiment.add_audio(
                    f"sample-{idx}/wavs/prediction",
                    gen_audio[0, :audio_len],
                    self.global_step,
                    sample_rate=self.sampling_rate,
                )

            plt.close(image_mels)