lit_module.py
import itertools
from typing import Any, Callable

import lightning as L
import torch
import torch.nn.functional as F
import wandb
from diffusers.schedulers import DDIMScheduler, UniPCMultistepScheduler
from diffusers.utils.torch_utils import randn_tensor
from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger
from matplotlib import pyplot as plt
from torch import nn

from fish_speech.models.vq_diffusion.unet1d import Unet1DDenoiser
from fish_speech.models.vqgan.modules.encoders import (
    SpeakerEncoder,
    TextEncoder,
    VQEncoder,
)
from fish_speech.models.vqgan.utils import plot_mel, sequence_mask


class VQDiffusion(L.LightningModule):
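    """Diffusion-based mel-spectrogram generator.

    A U-Net denoiser is trained to recover clean (normalized) mel
    spectrograms from noised ones, conditioned on VQ-encoded text features
    and a speaker embedding; a frozen vocoder turns sampled mels into audio.
    """
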
    def __init__(
        self,
        optimizer: Callable,
        lr_scheduler: Callable,
        mel_transform: nn.Module,
        vq_encoder: VQEncoder,
        speaker_encoder: SpeakerEncoder,
        text_encoder: TextEncoder,
        denoiser: Unet1DDenoiser,
        vocoder: nn.Module,
        hop_length: int = 640,
        sample_rate: int = 32000,
    ):
        super().__init__()

        # Optimizer and LR scheduler builders (instantiated in configure_optimizers)
        self.optimizer_builder = optimizer
        self.lr_scheduler_builder = lr_scheduler

        # Mel transform and noise schedulers. The denoiser is trained to
        # predict the clean sample (see the loss in training_step), so the
        # inference scheduler must use "sample" prediction, not the default
        # "epsilon".
        self.mel_transform = mel_transform
        self.noise_scheduler_train = DDIMScheduler(num_train_timesteps=1000)
        self.noise_scheduler_infer = UniPCMultistepScheduler(
            num_train_timesteps=1000, prediction_type="sample"
        )

        # Modules
        self.vq_encoder = vq_encoder
        self.speaker_encoder = speaker_encoder
        self.text_encoder = text_encoder
        self.denoiser = denoiser
        self.vocoder = vocoder
        self.hop_length = hop_length
        self.sampling_rate = sample_rate

        # Freeze vocoder
        for param in self.vocoder.parameters():
            param.requires_grad = False

    def configure_optimizers(self):
        optimizer = self.optimizer_builder(self.parameters())
        lr_scheduler = self.lr_scheduler_builder(optimizer)

        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": lr_scheduler,
                "interval": "step",
            },
        }
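
    # Diffusion assumes roughly unit-scale inputs, so log-mels (empirically in
    # [-10.1, 3.1] for this mel transform) are mapped to [-1, 1] and back.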
    def normalize_mels(self, x):
        x_min, x_max = -10.1, 3.1
        return (x - x_min) / (x_max - x_min) * 2 - 1

    def denormalize_mels(self, x):
        x_min, x_max = -10.1, 3.1
        return (x + 1) / 2 * (x_max - x_min) + x_min

    def training_step(self, batch, batch_idx):
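        """One diffusion training step: noise the ground-truth mels and train
        the denoiser to recover them, conditioned on text and speaker."""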
        audios, audio_lengths = batch["audios"], batch["audio_lengths"]
        features, feature_lengths = batch["features"], batch["feature_lengths"]

        audios = audios.float()
        features = features.float().mT
        audios = audios[:, None, :]

        with torch.no_grad():
            gt_mels = self.mel_transform(audios)

        mel_lengths = audio_lengths // self.hop_length
        feature_masks = torch.unsqueeze(
            sequence_mask(feature_lengths, features.shape[2]), 1
        ).to(gt_mels.dtype)
        mel_masks = torch.unsqueeze(sequence_mask(mel_lengths, gt_mels.shape[2]), 1).to(
            gt_mels.dtype
        )

        speaker_features = self.speaker_encoder(gt_mels, mel_masks)
        vq_features, vq_loss = self.vq_encoder(features, feature_masks)

        # vq_features is at 50 Hz; upsample to the mel frame rate
        vq_features = F.interpolate(vq_features, size=gt_mels.shape[2], mode="nearest")
        text_features = self.text_encoder(vq_features, mel_masks, g=speaker_features)

        # Sample noise to add to the mels
        normalized_gt_mels = self.normalize_mels(gt_mels)
        noise = torch.randn_like(normalized_gt_mels)

        # Sample a random timestep for each mel in the batch
        timesteps = torch.randint(
            0,
            self.noise_scheduler_train.config.num_train_timesteps,
            (normalized_gt_mels.shape[0],),
            device=normalized_gt_mels.device,
        ).long()

        # Add noise to the clean mels according to the noise magnitude at each
        # timestep (the forward diffusion process)
        noisy_mels = self.noise_scheduler_train.add_noise(
            normalized_gt_mels, noise, timesteps
        )

        # Predict the clean mels from the noisy ones
        model_output = self.denoiser(noisy_mels, timesteps, mel_masks, text_features)

        # Masked L1 loss between the prediction and the clean normalized mels
        noise_loss = torch.abs(
            model_output * mel_masks - normalized_gt_mels * mel_masks
        ).sum() / (mel_masks.sum() * gt_mels.shape[1])

        self.log(
            "train/noise_loss",
            noise_loss,
            on_step=True,
            on_epoch=False,
            prog_bar=True,
            logger=True,
            sync_dist=True,
        )
        self.log(
            "train/vq_loss",
            vq_loss,
            on_step=True,
            on_epoch=False,
            prog_bar=True,
            logger=True,
            sync_dist=True,
        )

        return noise_loss + vq_loss

    def validation_step(self, batch: Any, batch_idx: int):
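        """Run the full reverse-diffusion sampling loop, vocode the result,
        and log mel/audio comparisons against the ground truth."""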
        audios, audio_lengths = batch["audios"], batch["audio_lengths"]
        features, feature_lengths = batch["features"], batch["feature_lengths"]

        audios = audios.float()
        features = features.float().mT
        audios = audios[:, None, :]

        gt_mels = self.mel_transform(audios)
        mel_lengths = audio_lengths // self.hop_length
        feature_masks = torch.unsqueeze(
            sequence_mask(feature_lengths, features.shape[2]), 1
        ).to(gt_mels.dtype)
        mel_masks = torch.unsqueeze(sequence_mask(mel_lengths, gt_mels.shape[2]), 1).to(
            gt_mels.dtype
        )

        speaker_features = self.speaker_encoder(gt_mels, mel_masks)
        vq_features, _ = self.vq_encoder(features, feature_masks)

        # vq_features is at 50 Hz; upsample to the mel frame rate
        vq_features = F.interpolate(vq_features, size=gt_mels.shape[2], mode="nearest")
        text_features = self.text_encoder(vq_features, mel_masks, g=speaker_features)

        # Begin sampling: start from Gaussian noise in the normalized mel domain
        sampled_mels = torch.randn_like(gt_mels)
        self.noise_scheduler_infer.set_timesteps(100)

        for t in self.noise_scheduler_infer.timesteps:
            timesteps = torch.tensor([t], device=sampled_mels.device, dtype=torch.long)

            # 1. Predict the denoised mels
            model_output = self.denoiser(
                sampled_mels, timesteps, mel_masks, text_features
            )

            # 2. Compute the previous sample: x_t -> x_t-1
            sampled_mels = self.noise_scheduler_infer.step(
                model_output, t, sampled_mels
            ).prev_sample
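
        # Map the sample back from [-1, 1] to the log-mel range before vocoding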
        sampled_mels = self.denormalize_mels(sampled_mels)

        with torch.autocast(device_type=sampled_mels.device.type, enabled=False):
            # Run the vocoder in fp32
            fake_audios = self.vocoder.decode(sampled_mels.float())

        mel_loss = F.l1_loss(gt_mels, sampled_mels)
        self.log(
            "val/mel_loss",
            mel_loss,
            on_step=False,
            on_epoch=True,
            prog_bar=True,
            logger=True,
            sync_dist=True,
        )

        for idx, (
            mel,
            gen_mel,
            audio,
            gen_audio,
            audio_len,
        ) in enumerate(
            zip(
                gt_mels,
                sampled_mels,
                audios,
                fake_audios,
                audio_lengths,
            )
        ):
            mel_len = audio_len // self.hop_length
            image_mels = plot_mel(
                [
                    gen_mel[:, :mel_len],
                    mel[:, :mel_len],
                ],
                [
                    "Generated Spectrogram",
                    "Ground-Truth Spectrogram",
                ],
            )

            if isinstance(self.logger, WandbLogger):
                self.logger.experiment.log(
                    {
                        "reconstruction_mel": wandb.Image(image_mels, caption="mels"),
                        "wavs": [
                            # wandb.Audio expects host-side numpy data,
                            # not GPU tensors
                            wandb.Audio(
                                audio[0, :audio_len].cpu().numpy(),
                                sample_rate=self.sampling_rate,
                                caption="gt",
                            ),
                            wandb.Audio(
                                gen_audio[0, :audio_len].cpu().numpy(),
                                sample_rate=self.sampling_rate,
                                caption="prediction",
                            ),
                        ],
                    },
                )

            if isinstance(self.logger, TensorBoardLogger):
                self.logger.experiment.add_figure(
                    f"sample-{idx}/mels",
                    image_mels,
                    global_step=self.global_step,
                )
                self.logger.experiment.add_audio(
                    f"sample-{idx}/wavs/gt",
                    audio[0, :audio_len],
                    self.global_step,
                    sample_rate=self.sampling_rate,
                )
                self.logger.experiment.add_audio(
                    f"sample-{idx}/wavs/prediction",
                    gen_audio[0, :audio_len],
                    self.global_step,
                    sample_rate=self.sampling_rate,
                )

            plt.close(image_mels)
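

# Usage sketch (illustrative only): the builder callables and sub-modules are
# normally supplied by the project's configs; the constructor arguments shown
# for them here are assumptions, not the actual signatures.
#
#     model = VQDiffusion(
#         optimizer=lambda params: torch.optim.AdamW(params, lr=1e-4),
#         lr_scheduler=lambda opt: torch.optim.lr_scheduler.ConstantLR(opt),
#         mel_transform=mel_transform,  # e.g. a log-mel nn.Module
#         vq_encoder=vq_encoder,
#         speaker_encoder=speaker_encoder,
#         text_encoder=text_encoder,
#         denoiser=denoiser,
#         vocoder=vocoder,
#     )
#     trainer = L.Trainer(max_steps=100_000)
#     trainer.fit(model, datamodule=datamodule)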