# lit_module.py

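"""Lightning modules for VQGAN training.

Defines two LightningModules: VQGAN, the full adversarial setup with separate
generator/discriminator optimizers, and VQNaive, a mel-reconstruction-only
variant that reuses a frozen vocoder to synthesize validation audio.
"""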
import itertools
from typing import Any, Callable

import lightning as L
import torch
import torch.nn.functional as F
import wandb
from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger
from matplotlib import pyplot as plt
from torch import nn
from vector_quantize_pytorch import VectorQuantize

from fish_speech.models.vqgan.losses import (
    discriminator_loss,
    feature_loss,
    generator_loss,
    kl_loss,
)
from fish_speech.models.vqgan.modules.decoder import Generator
from fish_speech.models.vqgan.modules.discriminator import EnsembleDiscriminator
from fish_speech.models.vqgan.modules.encoders import (
    ConvDownSampler,
    SpeakerEncoder,
    TextEncoder,
    VQEncoder,
)
from fish_speech.models.vqgan.utils import (
    plot_mel,
    rand_slice_segments,
    sequence_mask,
    slice_segments,
)


class VQGAN(L.LightningModule):
    def __init__(
        self,
        optimizer: Callable,
        lr_scheduler: Callable,
        downsample: ConvDownSampler,
        vq_encoder: VQEncoder,
        speaker_encoder: SpeakerEncoder,
        text_encoder: TextEncoder,
        decoder: TextEncoder,
        generator: Generator,
        discriminator: EnsembleDiscriminator,
        mel_transform: nn.Module,
        feature_mel_transform: nn.Module,
        segment_size: int = 20480,
        hop_length: int = 640,
        sample_rate: int = 32000,
        freeze_hifigan: bool = False,
        freeze_vq: bool = False,
    ):
        super().__init__()

        # Model parameters
        self.optimizer_builder = optimizer
        self.lr_scheduler_builder = lr_scheduler

        # Generator and discriminators
        self.downsample = downsample
        self.vq_encoder = vq_encoder
        self.speaker_encoder = speaker_encoder
        self.text_encoder = text_encoder
        self.decoder = decoder
        self.generator = generator
        self.discriminator = discriminator
        self.mel_transform = mel_transform
        self.feature_mel_transform = feature_mel_transform

        # Crop length for saving memory
        self.segment_size = segment_size
        self.hop_length = hop_length
        self.sampling_rate = sample_rate
        self.freeze_hifigan = freeze_hifigan

        # Disable automatic optimization
        self.automatic_optimization = False
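        # GAN training interleaves discriminator and generator updates, so both
        # optimizers are stepped by hand in training_step.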

        # Stage 1: Train the VQ only
        if self.freeze_hifigan:
            for p in self.discriminator.parameters():
                p.requires_grad = False

            for p in self.generator.parameters():
                p.requires_grad = False

        # Stage 2: Train the HiFi-GAN + Decoder + Generator
        if freeze_vq:
            for p in self.vq_encoder.parameters():
                p.requires_grad = False

            for p in self.text_encoder.parameters():
                p.requires_grad = False

            for p in self.downsample.parameters():
                p.requires_grad = False

    def configure_optimizers(self):
        # Need two optimizers and two schedulers
        optimizer_generator = self.optimizer_builder(
            itertools.chain(
                self.downsample.parameters(),
                self.vq_encoder.parameters(),
                self.speaker_encoder.parameters(),
                self.text_encoder.parameters(),
                self.decoder.parameters(),
                self.generator.parameters(),
            )
        )
        optimizer_discriminator = self.optimizer_builder(
            self.discriminator.parameters()
        )

        lr_scheduler_generator = self.lr_scheduler_builder(optimizer_generator)
        lr_scheduler_discriminator = self.lr_scheduler_builder(optimizer_discriminator)

        return (
            {
                "optimizer": optimizer_generator,
                "lr_scheduler": {
                    "scheduler": lr_scheduler_generator,
                    "interval": "step",
                    "name": "optimizer/generator",
                },
            },
            {
                "optimizer": optimizer_discriminator,
                "lr_scheduler": {
                    "scheduler": lr_scheduler_discriminator,
                    "interval": "step",
                    "name": "optimizer/discriminator",
                },
            },
        )
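
    # `optimizer` and `lr_scheduler` are builder callables that receive the
    # parameters / optimizer respectively; e.g. (a hypothetical configuration,
    # not the project's actual defaults):
    #   optimizer = functools.partial(torch.optim.AdamW, lr=2e-4)
    #   lr_scheduler = functools.partial(
    #       torch.optim.lr_scheduler.ExponentialLR, gamma=0.999
    #   )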

    def training_step(self, batch, batch_idx):
        optim_g, optim_d = self.optimizers()

        audios, audio_lengths = batch["audios"], batch["audio_lengths"]

        audios = audios.float()
        audios = audios[:, None, :]

        with torch.no_grad():
            gt_mels = self.mel_transform(audios, sample_rate=self.sampling_rate)
            features = self.feature_mel_transform(
                audios, sample_rate=self.sampling_rate
            )

        # The downsampler is trainable (it is part of the generator optimizer),
        # so it runs outside the no_grad block above.
        if self.downsample is not None:
            features = self.downsample(features)

        mel_lengths = audio_lengths // self.hop_length
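
        # Convert audio-sample lengths to feature-frame lengths:
        #   samples / sampling_rate                    -> seconds
        #   seconds * feature sample_rate / hop_length -> feature frames
        #   frames / downsampler total_strides         -> downsampled frames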
        feature_lengths = (
            audio_lengths
            / self.sampling_rate
            * self.feature_mel_transform.sample_rate
            / self.feature_mel_transform.hop_length
            / (self.downsample.total_strides if self.downsample is not None else 1)
        ).long()
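
        # Masking convention assumed throughout (a hypothetical sketch of
        # fish_speech.models.vqgan.utils.sequence_mask, for reference only):
        #   def sequence_mask(lengths, max_len):
        #       idx = torch.arange(max_len, device=lengths.device)
        #       return idx[None, :] < lengths[:, None]  # (B, max_len) bools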
        feature_masks = torch.unsqueeze(
            sequence_mask(feature_lengths, features.shape[2]), 1
        ).to(gt_mels.dtype)
        mel_masks = torch.unsqueeze(
            sequence_mask(mel_lengths, gt_mels.shape[2]), 1
        ).to(gt_mels.dtype)

        speaker_features = self.speaker_encoder(features, feature_masks)

        # vq_features is 50 Hz; interpolate to the true mel frame count
        text_features = self.text_encoder(features, feature_masks)
        text_features, loss_vq = self.vq_encoder(text_features, feature_masks)
        text_features = F.interpolate(
            text_features, size=gt_mels.shape[2], mode="nearest"
        )

        # Sample mels
        decoded_mels = self.decoder(text_features, mel_masks, g=speaker_features)
        fake_audios = self.generator(decoded_mels)
        y_hat_mels = self.mel_transform(fake_audios.squeeze(1))

        y, ids_slice = rand_slice_segments(audios, audio_lengths, self.segment_size)
        y_hat = slice_segments(fake_audios, ids_slice, self.segment_size)
        assert y.shape == y_hat.shape, f"{y.shape} != {y_hat.shape}"
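
        # Only this random crop reaches the discriminator; slicing to
        # segment_size keeps the adversarial pass memory-bounded.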

        # When HiFi-GAN is frozen we don't update the discriminator, so we skip
        # its backward pass entirely
        if self.freeze_hifigan is False:
            # Discriminator
            y_d_hat_r, y_d_hat_g, _, _ = self.discriminator(y, y_hat.detach())

            with torch.autocast(device_type=audios.device.type, enabled=False):
                loss_disc_all, _, _ = discriminator_loss(y_d_hat_r, y_d_hat_g)

            self.log(
                "train/discriminator/loss",
                loss_disc_all,
                on_step=True,
                on_epoch=False,
                prog_bar=True,
                logger=True,
                sync_dist=True,
            )

            optim_d.zero_grad()
            self.manual_backward(loss_disc_all)
            self.clip_gradients(
                optim_d, gradient_clip_val=1.0, gradient_clip_algorithm="norm"
            )
            optim_d.step()

        y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = self.discriminator(y, y_hat)

        with torch.autocast(device_type=audios.device.type, enabled=False):
            loss_decoded_mel = F.l1_loss(gt_mels * mel_masks, decoded_mels * mel_masks)
            loss_mel = F.l1_loss(gt_mels * mel_masks, y_hat_mels * mel_masks)
            loss_adv, _ = generator_loss(y_d_hat_g)
            loss_fm = feature_loss(fmap_r, fmap_g)
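
            # The 45x factor below matches the mel-loss weight used in HiFi-GAN
            # (lambda_mel = 45); the VQ loss is weighted identically.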
            if self.freeze_hifigan is True:
                loss_gen_all = loss_decoded_mel + loss_vq
            else:
                loss_gen_all = loss_mel * 45 + loss_vq * 45 + loss_fm + loss_adv

        self.log(
            "train/generator/loss",
            loss_gen_all,
            on_step=True,
            on_epoch=False,
            prog_bar=True,
            logger=True,
            sync_dist=True,
        )
        self.log(
            "train/generator/loss_decoded_mel",
            loss_decoded_mel,
            on_step=True,
            on_epoch=False,
            prog_bar=False,
            logger=True,
            sync_dist=True,
        )
        self.log(
            "train/generator/loss_mel",
            loss_mel,
            on_step=True,
            on_epoch=False,
            prog_bar=False,
            logger=True,
            sync_dist=True,
        )
        self.log(
            "train/generator/loss_fm",
            loss_fm,
            on_step=True,
            on_epoch=False,
            prog_bar=False,
            logger=True,
            sync_dist=True,
        )
        self.log(
            "train/generator/loss_adv",
            loss_adv,
            on_step=True,
            on_epoch=False,
            prog_bar=False,
            logger=True,
            sync_dist=True,
        )
        self.log(
            "train/generator/loss_vq",
            loss_vq,
            on_step=True,
            on_epoch=False,
            prog_bar=False,
            logger=True,
            sync_dist=True,
        )

        optim_g.zero_grad()
        self.manual_backward(loss_gen_all)
        self.clip_gradients(
            optim_g, gradient_clip_val=1.0, gradient_clip_algorithm="norm"
        )
        optim_g.step()

        # Manual LR scheduling: with automatic optimization disabled, Lightning
        # will not step the schedulers itself, so advance both every step.
        scheduler_g, scheduler_d = self.lr_schedulers()
        scheduler_g.step()
        scheduler_d.step()

    def validation_step(self, batch: Any, batch_idx: int):
        audios, audio_lengths = batch["audios"], batch["audio_lengths"]

        audios = audios.float()
        audios = audios[:, None, :]

        gt_mels = self.mel_transform(audios, sample_rate=self.sampling_rate)
        features = self.feature_mel_transform(audios, sample_rate=self.sampling_rate)

        if self.downsample is not None:
            features = self.downsample(features)

        mel_lengths = audio_lengths // self.hop_length
        feature_lengths = (
            audio_lengths
            / self.sampling_rate
            * self.feature_mel_transform.sample_rate
            / self.feature_mel_transform.hop_length
            / (self.downsample.total_strides if self.downsample is not None else 1)
        ).long()

        feature_masks = torch.unsqueeze(
            sequence_mask(feature_lengths, features.shape[2]), 1
        ).to(gt_mels.dtype)
        mel_masks = torch.unsqueeze(
            sequence_mask(mel_lengths, gt_mels.shape[2]), 1
        ).to(gt_mels.dtype)

        speaker_features = self.speaker_encoder(gt_mels, mel_masks)

        # vq_features is 50 Hz; interpolate to the true mel frame count
        text_features = self.text_encoder(features, feature_masks)
        text_features, _ = self.vq_encoder(text_features, feature_masks)
        text_features = F.interpolate(
            text_features, size=gt_mels.shape[2], mode="nearest"
        )

        # Sample mels
        decoded_mels = self.decoder(text_features, mel_masks, g=speaker_features)
        fake_audios = self.generator(decoded_mels)
        fake_mels = self.mel_transform(fake_audios.squeeze(1))
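
        # The generator's output length may not match the input exactly, so the
        # three mel sets can differ by a frame or two; trim to the shortest.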
        min_mel_length = min(
            decoded_mels.shape[-1], gt_mels.shape[-1], fake_mels.shape[-1]
        )
        decoded_mels = decoded_mels[:, :, :min_mel_length]
        gt_mels = gt_mels[:, :, :min_mel_length]
        fake_mels = fake_mels[:, :, :min_mel_length]

        mel_loss = F.l1_loss(gt_mels * mel_masks, fake_mels * mel_masks)
        self.log(
            "val/mel_loss",
            mel_loss,
            on_step=False,
            on_epoch=True,
            prog_bar=True,
            logger=True,
            sync_dist=True,
        )

        for idx, (
            mel,
            gen_mel,
            decode_mel,
            audio,
            gen_audio,
            audio_len,
        ) in enumerate(
            zip(
                gt_mels,
                fake_mels,
                decoded_mels,
                audios.detach().float(),
                fake_audios.detach().float(),
                audio_lengths,
            )
        ):
            mel_len = audio_len // self.hop_length

            image_mels = plot_mel(
                [
                    gen_mel[:, :mel_len],
                    decode_mel[:, :mel_len],
                    mel[:, :mel_len],
                ],
                [
                    "Generated",
                    "Decoded",
                    "Ground-Truth",
                ],
            )

            if isinstance(self.logger, WandbLogger):
                self.logger.experiment.log(
                    {
                        "reconstruction_mel": wandb.Image(image_mels, caption="mels"),
                        "wavs": [
                            wandb.Audio(
                                audio[0, :audio_len],
                                sample_rate=self.sampling_rate,
                                caption="gt",
                            ),
                            wandb.Audio(
                                gen_audio[0, :audio_len],
                                sample_rate=self.sampling_rate,
                                caption="prediction",
                            ),
                        ],
                    },
                )

            if isinstance(self.logger, TensorBoardLogger):
                self.logger.experiment.add_figure(
                    f"sample-{idx}/mels",
                    image_mels,
                    global_step=self.global_step,
                )
                self.logger.experiment.add_audio(
                    f"sample-{idx}/wavs/gt",
                    audio[0, :audio_len],
                    self.global_step,
                    sample_rate=self.sampling_rate,
                )
                self.logger.experiment.add_audio(
                    f"sample-{idx}/wavs/prediction",
                    gen_audio[0, :audio_len],
                    self.global_step,
                    sample_rate=self.sampling_rate,
                )

            plt.close(image_mels)


class VQNaive(L.LightningModule):
    def __init__(
        self,
        optimizer: Callable,
        lr_scheduler: Callable,
        downsample: ConvDownSampler,
        vq_encoder: VQEncoder,
        speaker_encoder: SpeakerEncoder,
        mel_encoder: TextEncoder,
        decoder: TextEncoder,
        mel_transform: nn.Module,
        hop_length: int = 640,
        sample_rate: int = 32000,
        vocoder: Generator | None = None,
    ):
        super().__init__()

        # Model parameters
        self.optimizer_builder = optimizer
        self.lr_scheduler_builder = lr_scheduler

        # Encoders and decoder
        self.downsample = downsample
        self.vq_encoder = vq_encoder
        self.speaker_encoder = speaker_encoder
        self.mel_encoder = mel_encoder
        self.decoder = decoder
        self.mel_transform = mel_transform

        # Audio framing parameters
        self.hop_length = hop_length
        self.sampling_rate = sample_rate

        # Vocoder (kept frozen; only used to synthesize validation audio)
        self.vocoder = vocoder
        if self.vocoder is not None:
            for p in self.vocoder.parameters():
                p.requires_grad = False

    def configure_optimizers(self):
        optimizer = self.optimizer_builder(self.parameters())
        lr_scheduler = self.lr_scheduler_builder(optimizer)

        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": lr_scheduler,
                "interval": "step",
            },
        }

    def training_step(self, batch, batch_idx):
        audios, audio_lengths = batch["audios"], batch["audio_lengths"]

        audios = audios.float()
        audios = audios[:, None, :]

        with torch.no_grad():
            features = gt_mels = self.mel_transform(
                audios, sample_rate=self.sampling_rate
            )

        # The downsampler is trainable, so it runs outside no_grad
        if self.downsample is not None:
            features = self.downsample(features)

        mel_lengths = audio_lengths // self.hop_length
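
        # features here share gt_mels' frame rate (they come from the same mel
        # transform), so only the downsampler's stride changes the length.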
        feature_lengths = (
            audio_lengths
            / self.hop_length
            / (self.downsample.total_strides if self.downsample is not None else 1)
        ).long()

        feature_masks = torch.unsqueeze(
            sequence_mask(feature_lengths, features.shape[2]), 1
        ).to(gt_mels.dtype)
        mel_masks = torch.unsqueeze(
            sequence_mask(mel_lengths, gt_mels.shape[2]), 1
        ).to(gt_mels.dtype)

        speaker_features = self.speaker_encoder(gt_mels, mel_masks)

        # vq_features is 50 Hz; interpolate to the true mel frame count
        text_features = self.mel_encoder(features, feature_masks)
        text_features, loss_vq = self.vq_encoder(text_features, feature_masks)
        text_features = F.interpolate(
            text_features, size=gt_mels.shape[2], mode="nearest"
        )

        # Sample mels
        decoded_mels = self.decoder(text_features, mel_masks, g=speaker_features)

        loss_mel = F.l1_loss(gt_mels * mel_masks, decoded_mels * mel_masks)
        loss = loss_mel + loss_vq

        self.log(
            "train/generator/loss",
            loss,
            on_step=True,
            on_epoch=False,
            prog_bar=True,
            logger=True,
            sync_dist=True,
        )
        self.log(
            "train/loss_mel",
            loss_mel,
            on_step=True,
            on_epoch=False,
            prog_bar=False,
            logger=True,
            sync_dist=True,
        )
        self.log(
            "train/generator/loss_vq",
            loss_vq,
            on_step=True,
            on_epoch=False,
            prog_bar=False,
            logger=True,
            sync_dist=True,
        )

        return loss

    def validation_step(self, batch: Any, batch_idx: int):
        audios, audio_lengths = batch["audios"], batch["audio_lengths"]

        audios = audios.float()
        audios = audios[:, None, :]

        features = gt_mels = self.mel_transform(audios, sample_rate=self.sampling_rate)

        if self.downsample is not None:
            features = self.downsample(features)

        mel_lengths = audio_lengths // self.hop_length
        feature_lengths = (
            audio_lengths
            / self.hop_length
            / (self.downsample.total_strides if self.downsample is not None else 1)
        ).long()

        feature_masks = torch.unsqueeze(
            sequence_mask(feature_lengths, features.shape[2]), 1
        ).to(gt_mels.dtype)
        mel_masks = torch.unsqueeze(
            sequence_mask(mel_lengths, gt_mels.shape[2]), 1
        ).to(gt_mels.dtype)

        speaker_features = self.speaker_encoder(gt_mels, mel_masks)

        # vq_features is 50 Hz; interpolate to the true mel frame count
        text_features = self.mel_encoder(features, feature_masks)
        text_features, _ = self.vq_encoder(text_features, feature_masks)
        text_features = F.interpolate(
            text_features, size=gt_mels.shape[2], mode="nearest"
        )

        # Sample mels
        decoded_mels = self.decoder(text_features, mel_masks, g=speaker_features)
        fake_audios = self.vocoder(decoded_mels)

        mel_loss = F.l1_loss(gt_mels * mel_masks, decoded_mels * mel_masks)
        self.log(
            "val/mel_loss",
            mel_loss,
            on_step=False,
            on_epoch=True,
            prog_bar=True,
            logger=True,
            sync_dist=True,
        )

        for idx, (
            mel,
            decoded_mel,
            audio,
            gen_audio,
            audio_len,
        ) in enumerate(
            zip(
                gt_mels,
                decoded_mels,
                audios.detach().float(),
                fake_audios.detach().float(),
                audio_lengths,
            )
        ):
            mel_len = audio_len // self.hop_length

            image_mels = plot_mel(
                [
                    decoded_mel[:, :mel_len],
                    mel[:, :mel_len],
                ],
                [
                    "Generated",
                    "Ground-Truth",
                ],
            )

            if isinstance(self.logger, WandbLogger):
                self.logger.experiment.log(
                    {
                        "reconstruction_mel": wandb.Image(image_mels, caption="mels"),
                        "wavs": [
                            wandb.Audio(
                                audio[0, :audio_len],
                                sample_rate=self.sampling_rate,
                                caption="gt",
                            ),
                            wandb.Audio(
                                gen_audio[0, :audio_len],
                                sample_rate=self.sampling_rate,
                                caption="prediction",
                            ),
                        ],
                    },
                )

            if isinstance(self.logger, TensorBoardLogger):
                self.logger.experiment.add_figure(
                    f"sample-{idx}/mels",
                    image_mels,
                    global_step=self.global_step,
                )
                self.logger.experiment.add_audio(
                    f"sample-{idx}/wavs/gt",
                    audio[0, :audio_len],
                    self.global_step,
                    sample_rate=self.sampling_rate,
                )
                self.logger.experiment.add_audio(
                    f"sample-{idx}/wavs/prediction",
                    gen_audio[0, :audio_len],
                    self.global_step,
                    sample_rate=self.sampling_rate,
                )

            plt.close(image_mels)
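

# A minimal sketch of the two-stage recipe implied by the freeze flags above
# (hypothetical driver code; `...` stands for the remaining constructor
# arguments, which are built elsewhere, e.g. from the training config):
#
#   stage1 = VQGAN(..., freeze_hifigan=True)  # Stage 1: VQ path, mel loss only
#   stage2 = VQGAN(..., freeze_vq=True)       # Stage 2: decoder + HiFi-GAN
#   L.Trainer().fit(stage2, datamodule=...)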