lit_module.py

import itertools
from typing import Any, Callable, Optional

import lightning as L
import torch
import torch.nn.functional as F
import wandb
from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger
from matplotlib import pyplot as plt
from torch import nn
from vector_quantize_pytorch import VectorQuantize

from fish_speech.models.vqgan.losses import (
    discriminator_loss,
    feature_loss,
    generator_loss,
    kl_loss,
)
from fish_speech.models.vqgan.modules.decoder import Generator
from fish_speech.models.vqgan.modules.discriminator import EnsembleDiscriminator
from fish_speech.models.vqgan.modules.encoders import (
    ConvDownSampler,
    SpeakerEncoder,
    TextEncoder,
    VQEncoder,
)
from fish_speech.models.vqgan.utils import (
    plot_mel,
    rand_slice_segments,
    sequence_mask,
    slice_segments,
)


class VQGAN(L.LightningModule):
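    """Two-stage VQ-GAN codec for speech.

    Pipeline: audio -> mel -> downsample -> mel_encoder -> vq_encoder ->
    decoder (mel) -> generator (waveform), trained adversarially against an
    EnsembleDiscriminator with manual optimization. `freeze_hifigan`
    restricts training to the VQ path (stage 1); `freeze_vq` freezes the VQ
    path and trains only the decoder/generator (stage 2).
    """
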
    def __init__(
        self,
        optimizer: Callable,
        lr_scheduler: Callable,
        downsample: ConvDownSampler,
        vq_encoder: VQEncoder,
        mel_encoder: TextEncoder,
        decoder: TextEncoder,
        generator: Generator,
        discriminator: EnsembleDiscriminator,
        mel_transform: nn.Module,
        segment_size: int = 20480,
        hop_length: int = 640,
        sample_rate: int = 32000,
        freeze_hifigan: bool = False,
        freeze_vq: bool = False,
        speaker_encoder: Optional[SpeakerEncoder] = None,
    ):
        super().__init__()

        # Optimizer and scheduler factories (called in configure_optimizers)
        self.optimizer_builder = optimizer
        self.lr_scheduler_builder = lr_scheduler

        # Generator and discriminators
        self.downsample = downsample
        self.vq_encoder = vq_encoder
        self.mel_encoder = mel_encoder
        self.speaker_encoder = speaker_encoder
        self.decoder = decoder
        self.generator = generator
        self.discriminator = discriminator
        self.mel_transform = mel_transform

        # Crop length for saving memory
        self.segment_size = segment_size
        self.hop_length = hop_length
        self.sampling_rate = sample_rate
        self.freeze_hifigan = freeze_hifigan

        # Disable automatic optimization (GAN training needs two optimizers)
        self.automatic_optimization = False

        # Stage 1: train the VQ only
        if self.freeze_hifigan:
            for p in self.discriminator.parameters():
                p.requires_grad = False

            for p in self.generator.parameters():
                p.requires_grad = False

        # Stage 2: train the HiFi-GAN decoder + generator
        if freeze_vq:
            for p in self.vq_encoder.parameters():
                p.requires_grad = False

            for p in self.mel_encoder.parameters():
                p.requires_grad = False

            for p in self.downsample.parameters():
                p.requires_grad = False

    def configure_optimizers(self):
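        """Return (generator, discriminator) optimizer/scheduler configs."""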
        # Need two optimizers and two schedulers
        optimizer_generator = self.optimizer_builder(
            itertools.chain(
                self.downsample.parameters(),
                self.vq_encoder.parameters(),
                self.mel_encoder.parameters(),
                self.decoder.parameters(),
                self.generator.parameters(),
            )
        )
        optimizer_discriminator = self.optimizer_builder(
            self.discriminator.parameters()
        )
        lr_scheduler_generator = self.lr_scheduler_builder(optimizer_generator)
        lr_scheduler_discriminator = self.lr_scheduler_builder(optimizer_discriminator)

        return (
            {
                "optimizer": optimizer_generator,
                "lr_scheduler": {
                    "scheduler": lr_scheduler_generator,
                    "interval": "step",
                    "name": "optimizer/generator",
                },
            },
            {
                "optimizer": optimizer_discriminator,
                "lr_scheduler": {
                    "scheduler": lr_scheduler_discriminator,
                    "interval": "step",
                    "name": "optimizer/discriminator",
                },
            },
        )

    def training_step(self, batch, batch_idx):
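        """One manual GAN step: update D on detached fakes, then update G.

        When `freeze_hifigan` is set, the adversarial pass is skipped and only
        the mel-decoder and VQ losses are optimized.
        """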
        optim_g, optim_d = self.optimizers()

        audios, audio_lengths = batch["audios"], batch["audio_lengths"]
        audios = audios.float()
        audios = audios[:, None, :]

        with torch.no_grad():
            features = gt_mels = self.mel_transform(
                audios, sample_rate=self.sampling_rate
            )

        if self.downsample is not None:
            features = self.downsample(features)

        mel_lengths = audio_lengths // self.hop_length
        feature_lengths = (
            audio_lengths
            / self.hop_length
            / (self.downsample.total_strides if self.downsample is not None else 1)
        ).long()

        feature_masks = torch.unsqueeze(
            sequence_mask(feature_lengths, features.shape[2]), 1
        ).to(gt_mels.dtype)
        mel_masks = torch.unsqueeze(sequence_mask(mel_lengths, gt_mels.shape[2]), 1).to(
            gt_mels.dtype
        )

        # The VQ features run at a lower frame rate (50 Hz), so upsample them
        # back to the true mel length
        text_features = self.mel_encoder(features, feature_masks)
        text_features, _, loss_vq = self.vq_encoder(text_features, feature_masks)
        text_features = F.interpolate(
            text_features, size=gt_mels.shape[2], mode="nearest"
        )

        # Decode mels, optionally conditioned on a speaker embedding
        speaker_features = (
            self.speaker_encoder(gt_mels, mel_masks)
            if self.speaker_encoder is not None
            else None
        )
        decoded_mels = self.decoder(text_features, mel_masks, g=speaker_features)
        fake_audios = self.generator(decoded_mels)
        y_hat_mels = self.mel_transform(fake_audios.squeeze(1))
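
        # Randomly crop aligned real/fake segments so the discriminator only
        # ever sees `segment_size` samples per clip (bounds memory)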
        y, ids_slice = rand_slice_segments(audios, audio_lengths, self.segment_size)
        y_hat = slice_segments(fake_audios, ids_slice, self.segment_size)
        assert y.shape == y_hat.shape, f"{y.shape} != {y_hat.shape}"

        # When the HiFi-GAN is frozen we skip the discriminator update entirely
        if not self.freeze_hifigan:
            # Discriminator
            y_d_hat_r, y_d_hat_g, _, _ = self.discriminator(y, y_hat.detach())

            with torch.autocast(device_type=audios.device.type, enabled=False):
                loss_disc_all, _, _ = discriminator_loss(y_d_hat_r, y_d_hat_g)

            self.log(
                "train/discriminator/loss",
                loss_disc_all,
                on_step=True,
                on_epoch=False,
                prog_bar=True,
                logger=True,
                sync_dist=True,
            )

            optim_d.zero_grad()
            self.manual_backward(loss_disc_all)
            self.clip_gradients(
                optim_d, gradient_clip_val=1.0, gradient_clip_algorithm="norm"
            )
            optim_d.step()
        y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = self.discriminator(y, y_hat)

        with torch.autocast(device_type=audios.device.type, enabled=False):
            loss_decoded_mel = F.l1_loss(gt_mels * mel_masks, decoded_mels * mel_masks)
            loss_mel = F.l1_loss(gt_mels * mel_masks, y_hat_mels * mel_masks)
            loss_adv, _ = generator_loss(y_d_hat_g)
            loss_fm = feature_loss(fmap_r, fmap_g)
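
            # Stage 1 optimizes only the mel decoder and codebook; otherwise
            # use HiFi-GAN-style weighting (45x on the mel reconstruction term)
            # plus the feature-matching and adversarial losses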
            if self.freeze_hifigan:
                loss_gen_all = loss_decoded_mel + loss_vq
            else:
                loss_gen_all = loss_mel * 45 + loss_vq * 45 + loss_fm + loss_adv
        self.log(
            "train/generator/loss",
            loss_gen_all,
            on_step=True,
            on_epoch=False,
            prog_bar=True,
            logger=True,
            sync_dist=True,
        )
        self.log(
            "train/generator/loss_decoded_mel",
            loss_decoded_mel,
            on_step=True,
            on_epoch=False,
            prog_bar=False,
            logger=True,
            sync_dist=True,
        )
        self.log(
            "train/generator/loss_mel",
            loss_mel,
            on_step=True,
            on_epoch=False,
            prog_bar=False,
            logger=True,
            sync_dist=True,
        )
        self.log(
            "train/generator/loss_fm",
            loss_fm,
            on_step=True,
            on_epoch=False,
            prog_bar=False,
            logger=True,
            sync_dist=True,
        )
        self.log(
            "train/generator/loss_adv",
            loss_adv,
            on_step=True,
            on_epoch=False,
            prog_bar=False,
            logger=True,
            sync_dist=True,
        )
        self.log(
            "train/generator/loss_vq",
            loss_vq,
            on_step=True,
            on_epoch=False,
            prog_bar=False,
            logger=True,
            sync_dist=True,
        )
        optim_g.zero_grad()
        self.manual_backward(loss_gen_all)
        self.clip_gradients(
            optim_g, gradient_clip_val=1.0, gradient_clip_algorithm="norm"
        )
        optim_g.step()

        # Manual LR scheduling: both schedulers advance once per training step
        scheduler_g, scheduler_d = self.lr_schedulers()
        scheduler_g.step()
        scheduler_d.step()

    def validation_step(self, batch: Any, batch_idx: int):
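        """Reconstruct full utterances; log mel L1 and per-sample plots/audio."""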
        audios, audio_lengths = batch["audios"], batch["audio_lengths"]
        audios = audios.float()
        audios = audios[:, None, :]

        features = gt_mels = self.mel_transform(audios, sample_rate=self.sampling_rate)

        if self.downsample is not None:
            features = self.downsample(features)

        mel_lengths = audio_lengths // self.hop_length
        feature_lengths = (
            audio_lengths
            / self.hop_length
            / (self.downsample.total_strides if self.downsample is not None else 1)
        ).long()

        feature_masks = torch.unsqueeze(
            sequence_mask(feature_lengths, features.shape[2]), 1
        ).to(gt_mels.dtype)
        mel_masks = torch.unsqueeze(sequence_mask(mel_lengths, gt_mels.shape[2]), 1).to(
            gt_mels.dtype
        )

        # The VQ features run at a lower frame rate (50 Hz), so upsample them
        # back to the true mel length
        text_features = self.mel_encoder(features, feature_masks)
        text_features, _, _ = self.vq_encoder(text_features, feature_masks)
        text_features = F.interpolate(
            text_features, size=gt_mels.shape[2], mode="nearest"
        )

        # Decode mels, optionally conditioned on a speaker embedding
        speaker_features = (
            self.speaker_encoder(gt_mels, mel_masks)
            if self.speaker_encoder is not None
            else None
        )
        decoded_mels = self.decoder(text_features, mel_masks, g=speaker_features)
        fake_audios = self.generator(decoded_mels)
        fake_mels = self.mel_transform(fake_audios.squeeze(1))

        # The vocoder may emit a slightly different number of frames, so trim
        # all mels (and the mask) to a common length before comparing
        min_mel_length = min(
            decoded_mels.shape[-1], gt_mels.shape[-1], fake_mels.shape[-1]
        )
        decoded_mels = decoded_mels[:, :, :min_mel_length]
        gt_mels = gt_mels[:, :, :min_mel_length]
        fake_mels = fake_mels[:, :, :min_mel_length]
        mel_masks = mel_masks[:, :, :min_mel_length]

        mel_loss = F.l1_loss(gt_mels * mel_masks, fake_mels * mel_masks)
        self.log(
            "val/mel_loss",
            mel_loss,
            on_step=False,
            on_epoch=True,
            prog_bar=True,
            logger=True,
            sync_dist=True,
        )
        for idx, (
            mel,
            gen_mel,
            decode_mel,
            audio,
            gen_audio,
            audio_len,
        ) in enumerate(
            zip(
                gt_mels,
                fake_mels,
                decoded_mels,
                audios.detach().float(),
                fake_audios.detach().float(),
                audio_lengths,
            )
        ):
            mel_len = audio_len // self.hop_length
            image_mels = plot_mel(
                [
                    gen_mel[:, :mel_len],
                    decode_mel[:, :mel_len],
                    mel[:, :mel_len],
                ],
                [
                    "Generated",
                    "Decoded",
                    "Ground-Truth",
                ],
            )

            if isinstance(self.logger, WandbLogger):
                self.logger.experiment.log(
                    {
                        "reconstruction_mel": wandb.Image(image_mels, caption="mels"),
                        "wavs": [
                            wandb.Audio(
                                audio[0, :audio_len],
                                sample_rate=self.sampling_rate,
                                caption="gt",
                            ),
                            wandb.Audio(
                                gen_audio[0, :audio_len],
                                sample_rate=self.sampling_rate,
                                caption="prediction",
                            ),
                        ],
                    },
                )

            if isinstance(self.logger, TensorBoardLogger):
                self.logger.experiment.add_figure(
                    f"sample-{idx}/mels",
                    image_mels,
                    global_step=self.global_step,
                )
                self.logger.experiment.add_audio(
                    f"sample-{idx}/wavs/gt",
                    audio[0, :audio_len],
                    self.global_step,
                    sample_rate=self.sampling_rate,
                )
                self.logger.experiment.add_audio(
                    f"sample-{idx}/wavs/prediction",
                    gen_audio[0, :audio_len],
                    self.global_step,
                    sample_rate=self.sampling_rate,
                )

            plt.close(image_mels)


class VQNaive(L.LightningModule):
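    """Mel-domain VQ autoencoder without adversarial training.

    Trains downsample -> mel_encoder -> vq_encoder -> decoder with a plain
    L1 mel loss; a frozen, pretrained vocoder is used only to render
    validation audio samples.
    """
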
    def __init__(
        self,
        optimizer: Callable,
        lr_scheduler: Callable,
        downsample: ConvDownSampler,
        vq_encoder: VQEncoder,
        speaker_encoder: SpeakerEncoder,
        mel_encoder: TextEncoder,
        decoder: TextEncoder,
        mel_transform: nn.Module,
        hop_length: int = 640,
        sample_rate: int = 32000,
        vocoder: Optional[Generator] = None,
    ):
        super().__init__()

        # Optimizer and scheduler factories (called in configure_optimizers)
        self.optimizer_builder = optimizer
        self.lr_scheduler_builder = lr_scheduler

        # Model components
        self.downsample = downsample
        self.vq_encoder = vq_encoder
        self.speaker_encoder = speaker_encoder
        self.mel_encoder = mel_encoder
        self.decoder = decoder
        self.mel_transform = mel_transform

        # Mel framing parameters
        self.hop_length = hop_length
        self.sampling_rate = sample_rate

        # Frozen vocoder, used only to render validation audio
        self.vocoder = vocoder
        if self.vocoder is not None:
            for p in self.vocoder.parameters():
                p.requires_grad = False

    def configure_optimizers(self):
        optimizer = self.optimizer_builder(self.parameters())
        lr_scheduler = self.lr_scheduler_builder(optimizer)

        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": lr_scheduler,
                "interval": "step",
            },
        }

    def vq_encode(self, audios, audio_lengths):
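        """Encode audio into quantized features.

        Returns (mel_masks, gt_mels, quantized features, codebook indices,
        VQ commitment loss).
        """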
        with torch.no_grad():
            features = gt_mels = self.mel_transform(
                audios, sample_rate=self.sampling_rate
            )

        if self.downsample is not None:
            features = self.downsample(features)

        mel_lengths = audio_lengths // self.hop_length
        feature_lengths = (
            audio_lengths
            / self.hop_length
            / (self.downsample.total_strides if self.downsample is not None else 1)
        ).long()

        feature_masks = torch.unsqueeze(
            sequence_mask(feature_lengths, features.shape[2]), 1
        ).to(gt_mels.dtype)
        mel_masks = torch.unsqueeze(sequence_mask(mel_lengths, gt_mels.shape[2]), 1).to(
            gt_mels.dtype
        )

        # The VQ features run at a lower frame rate (50 Hz); vq_decode
        # upsamples them back to the true mel length
        text_features = self.mel_encoder(features, feature_masks)
        text_features, indices, loss_vq = self.vq_encoder(text_features, feature_masks)

        return mel_masks, gt_mels, text_features, indices, loss_vq

    def vq_decode(self, text_features, speaker_features, gt_mels, mel_masks):
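        """Upsample quantized features to the mel frame rate (nearest
        neighbour) and decode a mel spectrogram, conditioned on the speaker
        embedding."""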
        text_features = F.interpolate(
            text_features, size=gt_mels.shape[2], mode="nearest"
        )
        decoded_mels = self.decoder(text_features, mel_masks, g=speaker_features)

        return decoded_mels

    def training_step(self, batch, batch_idx):
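        """Single-optimizer step (automatic optimization): L1 mel + VQ loss."""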
        audios, audio_lengths = batch["audios"], batch["audio_lengths"]
        audios = audios.float()
        audios = audios[:, None, :]

        mel_masks, gt_mels, text_features, indices, loss_vq = self.vq_encode(
            audios, audio_lengths
        )
        speaker_features = self.speaker_encoder(gt_mels, mel_masks)
        decoded_mels = self.vq_decode(
            text_features, speaker_features, gt_mels, mel_masks
        )

        loss_mel = F.l1_loss(gt_mels * mel_masks, decoded_mels * mel_masks)
        loss = loss_mel + loss_vq

        self.log(
            "train/generator/loss",
            loss,
            on_step=True,
            on_epoch=False,
            prog_bar=True,
            logger=True,
            sync_dist=True,
        )
        self.log(
            "train/generator/loss_mel",
            loss_mel,
            on_step=True,
            on_epoch=False,
            prog_bar=False,
            logger=True,
            sync_dist=True,
        )
        self.log(
            "train/generator/loss_vq",
            loss_vq,
            on_step=True,
            on_epoch=False,
            prog_bar=False,
            logger=True,
            sync_dist=True,
        )

        return loss

    def validation_step(self, batch: Any, batch_idx: int):
        audios, audio_lengths = batch["audios"], batch["audio_lengths"]
        audios = audios.float()
        audios = audios[:, None, :]

        mel_masks, gt_mels, text_features, indices, loss_vq = self.vq_encode(
            audios, audio_lengths
        )
        speaker_features = self.speaker_encoder(gt_mels, mel_masks)
        decoded_mels = self.vq_decode(
            text_features, speaker_features, gt_mels, mel_masks
        )
        fake_audios = self.vocoder(decoded_mels)

        mel_loss = F.l1_loss(gt_mels * mel_masks, decoded_mels * mel_masks)
        self.log(
            "val/mel_loss",
            mel_loss,
            on_step=False,
            on_epoch=True,
            prog_bar=True,
            logger=True,
            sync_dist=True,
        )

        for idx, (
            mel,
            decoded_mel,
            audio,
            gen_audio,
            audio_len,
        ) in enumerate(
            zip(
                gt_mels,
                decoded_mels,
                audios.detach().float(),
                fake_audios.detach().float(),
                audio_lengths,
            )
        ):
            mel_len = audio_len // self.hop_length
            image_mels = plot_mel(
                [
                    decoded_mel[:, :mel_len],
                    mel[:, :mel_len],
                ],
                [
                    "Generated",
                    "Ground-Truth",
                ],
            )

            if isinstance(self.logger, WandbLogger):
                self.logger.experiment.log(
                    {
                        "reconstruction_mel": wandb.Image(image_mels, caption="mels"),
                        "wavs": [
                            wandb.Audio(
                                audio[0, :audio_len],
                                sample_rate=self.sampling_rate,
                                caption="gt",
                            ),
                            wandb.Audio(
                                gen_audio[0, :audio_len],
                                sample_rate=self.sampling_rate,
                                caption="prediction",
                            ),
                        ],
                    },
                )

            if isinstance(self.logger, TensorBoardLogger):
                self.logger.experiment.add_figure(
                    f"sample-{idx}/mels",
                    image_mels,
                    global_step=self.global_step,
                )
                self.logger.experiment.add_audio(
                    f"sample-{idx}/wavs/gt",
                    audio[0, :audio_len],
                    self.global_step,
                    sample_rate=self.sampling_rate,
                )
                self.logger.experiment.add_audio(
                    f"sample-{idx}/wavs/prediction",
                    gen_audio[0, :audio_len],
                    self.global_step,
                    sample_rate=self.sampling_rate,
                )

            plt.close(image_mels)
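

# Usage sketch (hypothetical wiring; in practice these modules are built from
# config). The `optimizer` and `lr_scheduler` arguments are factories that
# configure_optimizers applies to the parameters/optimizer, e.g.:
#
#   from functools import partial
#   optimizer = partial(torch.optim.AdamW, lr=2e-4, betas=(0.8, 0.99))
#   lr_scheduler = partial(torch.optim.lr_scheduler.ExponentialLR, gamma=0.999)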