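"""Lightning module for diffusion-based mel generation in fish_speech.

Trains a ConvNext1D denoiser on mel spectrograms conditioned on text/semantic
features and a speaker embedding, using a DDIM noise schedule for training and
UniPC multistep sampling for inference. A frozen vocoder converts sampled mels
back to audio during validation.
"""
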
from typing import Any, Callable

import lightning as L
import torch
import torch.nn.functional as F
import wandb
from diffusers.schedulers import DDIMScheduler, UniPCMultistepScheduler
from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger
from matplotlib import pyplot as plt
from torch import nn
from tqdm import tqdm

from fish_speech.models.vq_diffusion.convnext_1d import ConvNext1DModel
from fish_speech.models.vqgan.modules.encoders import (
    SpeakerEncoder,
    TextEncoder,
    VQEncoder,
)
from fish_speech.models.vqgan.utils import plot_mel, sequence_mask


class VQDiffusion(L.LightningModule):
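    """Diffusion-based mel generator.

    ``text_encoder`` produces frame-level features conditioned on a speaker
    embedding from ``speaker_encoder``; ``denoiser`` is trained with standard
    epsilon-prediction on mels normalized to [-1, 1].
    """
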
    def __init__(
        self,
        optimizer: Callable,
        lr_scheduler: Callable,
        mel_transform: nn.Module,
        vq_encoder: VQEncoder,
        speaker_encoder: SpeakerEncoder,
        text_encoder: TextEncoder,
        denoiser: ConvNext1DModel,
        vocoder: nn.Module,
        hop_length: int = 640,
        sample_rate: int = 32000,
    ):
        super().__init__()

        # Optimizer and LR scheduler factories (built in configure_optimizers)
        self.optimizer_builder = optimizer
        self.lr_scheduler_builder = lr_scheduler

        # Mel transform and diffusion noise schedulers
        self.mel_transform = mel_transform
        self.noise_scheduler_train = DDIMScheduler(num_train_timesteps=1000)
        self.noise_scheduler_infer = UniPCMultistepScheduler(num_train_timesteps=1000)

        # Modules
        self.vq_encoder = vq_encoder
        self.speaker_encoder = speaker_encoder
        self.text_encoder = text_encoder
        self.denoiser = denoiser
        self.vocoder = vocoder
        self.hop_length = hop_length
        self.sampling_rate = sample_rate

        # Freeze the vocoder; it is only used for decoding during validation
        for param in self.vocoder.parameters():
            param.requires_grad = False

    def configure_optimizers(self):
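        # The builders are factory callables (e.g. functools.partial from a
        # Hydra config), so this module stays agnostic to the concrete
        # optimizer/scheduler classes.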
        optimizer = self.optimizer_builder(self.parameters())
        lr_scheduler = self.lr_scheduler_builder(optimizer)

        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": lr_scheduler,
                "interval": "step",
            },
        }
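
    # The denoiser operates on mels normalized to [-1, 1]. The (-10.1, 3.1)
    # bounds are the assumed dynamic range of the log-mels produced by
    # mel_transform; denormalize_mels is the exact inverse map.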
    def normalize_mels(self, x):
        # x is in range -10.1 to 3.1, normalize to -1 to 1
        x_min, x_max = -10.1, 3.1
        return (x - x_min) / (x_max - x_min) * 2 - 1

    def denormalize_mels(self, x):
        x_min, x_max = -10.1, 3.1
        return (x + 1) / 2 * (x_max - x_min) + x_min
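
    # Training: sample a random timestep per example, add the corresponding
    # amount of noise to the normalized mels (forward diffusion), and train
    # the denoiser to recover that noise under a length mask.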
    def training_step(self, batch, batch_idx):
        audios, audio_lengths = batch["audios"], batch["audio_lengths"]
        features, feature_lengths = batch["features"], batch["feature_lengths"]

        audios = audios.float()
        # features = features.float().mT
        audios = audios[:, None, :]

        with torch.no_grad():
            gt_mels = self.mel_transform(audios)

        mel_lengths = audio_lengths // self.hop_length
        feature_masks = torch.unsqueeze(
            sequence_mask(feature_lengths, features.shape[1]), 1
        ).to(gt_mels.dtype)
        mel_masks = torch.unsqueeze(sequence_mask(mel_lengths, gt_mels.shape[2]), 1).to(
            gt_mels.dtype
        )

        speaker_features = self.speaker_encoder(gt_mels, mel_masks)

        # vq_features, vq_loss = self.vq_encoder(features, feature_masks)

        # Text features are at ~50 Hz; upsample them to the mel frame rate
        text_features = self.text_encoder(features, feature_masks, g=speaker_features)
        text_features = F.interpolate(
            text_features, size=gt_mels.shape[2], mode="nearest"
        )

        # Sample noise that we'll add to the mels
        normalized_gt_mels = self.normalize_mels(gt_mels)
        noise = torch.randn_like(normalized_gt_mels)

        # Sample a random timestep for each example
        timesteps = torch.randint(
            0,
            self.noise_scheduler_train.config.num_train_timesteps,
            (normalized_gt_mels.shape[0],),
            device=normalized_gt_mels.device,
        ).long()

        # Add noise to the clean mels according to the noise magnitude at each
        # timestep (this is the forward diffusion process)
        noisy_mels = self.noise_scheduler_train.add_noise(
            normalized_gt_mels, noise, timesteps
        )

        # Predict the added noise
        model_output = self.denoiser(noisy_mels, timesteps, mel_masks, text_features)

        # Masked L1 loss on the predicted noise, averaged over valid
        # time-frequency bins
        noise_loss = (torch.abs(model_output * mel_masks - noise * mel_masks)).sum() / (
            mel_masks.sum() * gt_mels.shape[1]
        )

        self.log(
            "train/noise_loss",
            noise_loss,
            on_step=True,
            on_epoch=False,
            prog_bar=True,
            logger=True,
            sync_dist=True,
        )

        # self.log(
        #     "train/vq_loss",
        #     vq_loss,
        #     on_step=True,
        #     on_epoch=False,
        #     prog_bar=True,
        #     logger=True,
        #     sync_dist=True,
        # )

        return noise_loss  # + vq_loss
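
    # Validation: run the full reverse-diffusion process from pure noise with
    # the UniPC scheduler, decode the sampled mels with the frozen vocoder,
    # and log a mel L1 plus spectrogram/audio samples.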
    def validation_step(self, batch: Any, batch_idx: int):
        audios, audio_lengths = batch["audios"], batch["audio_lengths"]
        features, feature_lengths = batch["features"], batch["feature_lengths"]

        audios = audios.float()
        # features = features.float().mT
        audios = audios[:, None, :]

        gt_mels = self.mel_transform(audios)
        mel_lengths = audio_lengths // self.hop_length
        feature_masks = torch.unsqueeze(
            sequence_mask(feature_lengths, features.shape[1]), 1
        ).to(gt_mels.dtype)
        mel_masks = torch.unsqueeze(sequence_mask(mel_lengths, gt_mels.shape[2]), 1).to(
            gt_mels.dtype
        )

        speaker_features = self.speaker_encoder(gt_mels, mel_masks)

        # vq_features, vq_loss = self.vq_encoder(features, feature_masks)

        # Text features are at ~50 Hz; upsample them to the mel frame rate
        text_features = self.text_encoder(features, feature_masks, g=speaker_features)
        text_features = F.interpolate(
            text_features, size=gt_mels.shape[2], mode="nearest"
        )

        # Begin sampling from pure noise in the normalized mel space
        sampled_mels = torch.randn_like(gt_mels)
        self.noise_scheduler_infer.set_timesteps(100)
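        # NOTE: the usual diffusers convention would also scale the initial
        # noise by scheduler.init_noise_sigma and call scale_model_input per
        # step; both are identity operations for UniPCMultistepScheduler, so
        # they are omitted here.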
        for t in tqdm(self.noise_scheduler_infer.timesteps):
            timesteps = torch.tensor([t], device=sampled_mels.device, dtype=torch.long)

            # 1. Predict the noise residual
            model_output = self.denoiser(
                sampled_mels, timesteps, mel_masks, text_features
            )

            # 2. Compute the previous noisy sample: x_t -> x_{t-1}
            sampled_mels = self.noise_scheduler_infer.step(
                model_output, t, sampled_mels
            ).prev_sample

        sampled_mels = self.denormalize_mels(sampled_mels)

        with torch.autocast(device_type=sampled_mels.device.type, enabled=False):
            # Run the vocoder in fp32
            fake_audios = self.vocoder.decode(sampled_mels.float())

        mel_loss = F.l1_loss(gt_mels, sampled_mels)
        self.log(
            "val/mel_loss",
            mel_loss,
            on_step=False,
            on_epoch=True,
            prog_bar=True,
            logger=True,
            sync_dist=True,
        )
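
        # Log per-sample spectrogram comparisons and audio to whichever
        # logger is attached (W&B or TensorBoard).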
        for idx, (
            mel,
            gen_mel,
            audio,
            gen_audio,
            audio_len,
        ) in enumerate(
            zip(
                gt_mels,
                sampled_mels,
                audios,
                fake_audios,
                audio_lengths,
            )
        ):
            mel_len = audio_len // self.hop_length

            image_mels = plot_mel(
                [
                    gen_mel[:, :mel_len],
                    mel[:, :mel_len],
                ],
                [
                    "Generated Spectrogram",
                    "Ground-Truth Spectrogram",
                ],
            )

            if isinstance(self.logger, WandbLogger):
                self.logger.experiment.log(
                    {
                        "reconstruction_mel": wandb.Image(image_mels, caption="mels"),
                        "wavs": [
                            wandb.Audio(
                                audio[0, :audio_len],
                                sample_rate=self.sampling_rate,
                                caption="gt",
                            ),
                            wandb.Audio(
                                gen_audio[0, :audio_len],
                                sample_rate=self.sampling_rate,
                                caption="prediction",
                            ),
                        ],
                    },
                )

            if isinstance(self.logger, TensorBoardLogger):
                self.logger.experiment.add_figure(
                    f"sample-{idx}/mels",
                    image_mels,
                    global_step=self.global_step,
                )
                self.logger.experiment.add_audio(
                    f"sample-{idx}/wavs/gt",
                    audio[0, :audio_len],
                    self.global_step,
                    sample_rate=self.sampling_rate,
                )
                self.logger.experiment.add_audio(
                    f"sample-{idx}/wavs/prediction",
                    gen_audio[0, :audio_len],
                    self.global_step,
                    sample_rate=self.sampling_rate,
                )

            plt.close(image_mels)
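

if __name__ == "__main__":
    # Minimal sketch (not part of the training pipeline): demonstrates the
    # forward-diffusion step that training_step applies, on dummy tensors in
    # place of real normalized mels. The shapes (batch=2, mel_bins=128,
    # frames=256) are illustrative assumptions, not values from the config.
    scheduler = DDIMScheduler(num_train_timesteps=1000)
    fake_mels = torch.rand(2, 128, 256) * 2 - 1  # pretend mels already in [-1, 1]
    noise = torch.randn_like(fake_mels)
    timesteps = torch.randint(
        0, scheduler.config.num_train_timesteps, (fake_mels.shape[0],)
    ).long()
    noisy_mels = scheduler.add_noise(fake_mels, noise, timesteps)
    print(noisy_mels.shape)  # torch.Size([2, 128, 256])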