import itertools
from typing import Any, Callable

import lightning as L
import torch
import torch.nn.functional as F
import wandb
from diffusers.schedulers import DDIMScheduler, UniPCMultistepScheduler
from diffusers.utils.torch_utils import randn_tensor
from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger
from matplotlib import pyplot as plt
from torch import nn
from tqdm import tqdm
from transformers import HubertModel

from fish_speech.models.vq_diffusion.convnext_1d import ConvNext1DModel
from fish_speech.models.vqgan.modules.encoders import (
    SpeakerEncoder,
    TextEncoder,
    VQEncoder,
)
from fish_speech.models.vqgan.utils import plot_mel, sequence_mask


class VQDiffusion(L.LightningModule):
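    """VQ-conditioned mel-spectrogram diffusion model.

    A ConvNeXt-based denoiser is trained to predict the noise added to
    normalized mel spectrograms, conditioned on VQ-quantized acoustic
    features and a speaker embedding; sampled mels are decoded to audio
    with a frozen vocoder. `optimizer` and `lr_scheduler` are builder
    callables (e.g. partially applied constructors from the training
    config).
    """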

    def __init__(
        self,
        optimizer: Callable,
        lr_scheduler: Callable,
        mel_transform: nn.Module,
        feature_mel_transform: nn.Module,
        vq_encoder: VQEncoder,
        speaker_encoder: SpeakerEncoder,
        text_encoder: TextEncoder,
        denoiser: ConvNext1DModel,
        vocoder: nn.Module,
        hop_length: int = 640,
        sample_rate: int = 32000,
    ):
        super().__init__()

        # Optimizer and LR scheduler builders
        self.optimizer_builder = optimizer
        self.lr_scheduler_builder = lr_scheduler

        # Mel transforms and diffusion noise schedulers
        self.mel_transform = mel_transform
        self.feature_mel_transform = feature_mel_transform
        self.noise_scheduler_train = DDIMScheduler(num_train_timesteps=1000)
        self.noise_scheduler_infer = UniPCMultistepScheduler(num_train_timesteps=1000)

        # Modules
        self.vq_encoder = vq_encoder
        self.speaker_encoder = speaker_encoder
        self.text_encoder = text_encoder
        self.denoiser = denoiser
        self.vocoder = vocoder
        self.hop_length = hop_length
        self.sampling_rate = sample_rate

        # The vocoder is only used for decoding; freeze its weights
        for param in self.vocoder.parameters():
            param.requires_grad = False

    def configure_optimizers(self):
        optimizer = self.optimizer_builder(self.parameters())
        lr_scheduler = self.lr_scheduler_builder(optimizer)

        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": lr_scheduler,
                # Step the scheduler every optimizer step, not every epoch
                "interval": "step",
            },
        }
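
    # A minimal sketch of the expected builders (the real ones come from the
    # training config; the hyper-parameters here are illustrative only):
    #
    #   from functools import partial
    #   optimizer_builder = partial(torch.optim.AdamW, lr=1e-4)
    #   lr_scheduler_builder = partial(
    #       torch.optim.lr_scheduler.ExponentialLR, gamma=0.999
    #   )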

    def normalize_mels(self, x):
        # x is in the range [-10.1, 3.1]; rescale to [-1, 1] for diffusion
        x_min, x_max = -10.1, 3.1
        return (x - x_min) / (x_max - x_min) * 2 - 1

    def denormalize_mels(self, x):
        # Inverse of normalize_mels: map [-1, 1] back to [-10.1, 3.1]
        x_min, x_max = -10.1, 3.1
        return (x + 1) / 2 * (x_max - x_min) + x_min
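    # Round-trip sanity check: normalize_mels(-10.1) == -1.0,
    # normalize_mels(3.1) == 1.0, and denormalize_mels(normalize_mels(x)) == x.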

    def training_step(self, batch, batch_idx):
        audios, audio_lengths = batch["audios"], batch["audio_lengths"]
        features, feature_lengths = batch["features"], batch["feature_lengths"]

        audios = audios.float()
        # features = features.float().mT
        audios = audios[:, None, :]

        with torch.no_grad():
            gt_mels = self.mel_transform(audios, sample_rate=self.sampling_rate)
            features = self.feature_mel_transform(
                audios, sample_rate=self.sampling_rate
            )

        mel_lengths = audio_lengths // self.hop_length
        feature_lengths = audio_lengths // self.hop_length // 2
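        # With the defaults (hop_length=640 at 32 kHz) mels run at 50 frames/s;
        # the conditioning features are at half the mel frame rate.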

        feature_masks = torch.unsqueeze(
            sequence_mask(feature_lengths, features.shape[2]), 1
        ).to(gt_mels.dtype)
        mel_masks = torch.unsqueeze(sequence_mask(mel_lengths, gt_mels.shape[2]), 1).to(
            gt_mels.dtype
        )

        speaker_features = self.speaker_encoder(gt_mels, mel_masks)

        # vq_features is 50 Hz; interpolate to the mel frame count
        text_features = self.text_encoder(features, feature_masks)
        text_features, vq_loss = self.vq_encoder(text_features, feature_masks)
        text_features = F.interpolate(
            text_features, size=gt_mels.shape[2], mode="nearest"
        )
        text_features = text_features + speaker_features
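        # (speaker_features is broadcast along the time axis here, assuming the
        # speaker encoder returns a single embedding per utterance)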

        # Sample the noise to add to the normalized mel spectrograms
        normalized_gt_mels = self.normalize_mels(gt_mels)
        noise = torch.randn_like(normalized_gt_mels)

        # Sample a random timestep for each spectrogram in the batch
        timesteps = torch.randint(
            0,
            self.noise_scheduler_train.config.num_train_timesteps,
            (normalized_gt_mels.shape[0],),
            device=normalized_gt_mels.device,
        ).long()

        # Add noise to the clean mels according to the noise magnitude at each
        # timestep (the forward diffusion process)
        noisy_mels = self.noise_scheduler_train.add_noise(
            normalized_gt_mels, noise, timesteps
        )
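        # Under the variance-preserving schedule used here, add_noise computes
        # x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise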

        # Predict the added noise
        model_output = self.denoiser(noisy_mels, timesteps, mel_masks, text_features)

        # Masked L1 loss on the predicted noise, averaged over valid elements
        noise_loss = (torch.abs(model_output * mel_masks - noise * mel_masks)).sum() / (
            mel_masks.sum() * gt_mels.shape[1]
        )

        self.log(
            "train/noise_loss",
            noise_loss,
            on_step=True,
            on_epoch=False,
            prog_bar=True,
            logger=True,
            sync_dist=True,
        )
        self.log(
            "train/vq_loss",
            vq_loss,
            on_step=True,
            on_epoch=False,
            prog_bar=True,
            logger=True,
            sync_dist=True,
        )

        return noise_loss + vq_loss

    def validation_step(self, batch: Any, batch_idx: int):
        audios, audio_lengths = batch["audios"], batch["audio_lengths"]
        features, feature_lengths = batch["features"], batch["feature_lengths"]

        audios = audios.float()
        # features = features.float().mT
        audios = audios[:, None, :]

        gt_mels = self.mel_transform(audios, sample_rate=self.sampling_rate)
        features = self.feature_mel_transform(audios, sample_rate=self.sampling_rate)

        mel_lengths = audio_lengths // self.hop_length
        feature_lengths = audio_lengths // self.hop_length // 2

        feature_masks = torch.unsqueeze(
            sequence_mask(feature_lengths, features.shape[2]), 1
        ).to(gt_mels.dtype)
        mel_masks = torch.unsqueeze(sequence_mask(mel_lengths, gt_mels.shape[2]), 1).to(
            gt_mels.dtype
        )

        speaker_features = self.speaker_encoder(gt_mels, mel_masks)

        # vq_features is 50 Hz; interpolate to the mel frame count
        text_features = self.text_encoder(features, feature_masks)
        text_features, vq_loss = self.vq_encoder(text_features, feature_masks)
        text_features = F.interpolate(
            text_features, size=gt_mels.shape[2], mode="nearest"
        )
        text_features = text_features + speaker_features

        # Begin sampling
        sampled_mels = torch.randn_like(gt_mels)
        self.noise_scheduler_infer.set_timesteps(100)
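
        # Reverse diffusion: start from pure Gaussian noise and iteratively
        # denoise. UniPC is a fast multistep solver, so 100 inference steps
        # stand in for the 1000-step training schedule.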
        for t in tqdm(self.noise_scheduler_infer.timesteps):
            timesteps = torch.tensor([t], device=sampled_mels.device, dtype=torch.long)

            # 1. Predict the noise residual
            model_output = self.denoiser(
                sampled_mels, timesteps, mel_masks, text_features
            )

            # 2. Compute the previous noisy sample: x_t -> x_{t-1}
            sampled_mels = self.noise_scheduler_infer.step(
                model_output, t, sampled_mels
            ).prev_sample

        sampled_mels = self.denormalize_mels(sampled_mels)

        with torch.autocast(device_type=sampled_mels.device.type, enabled=False):
            # Run the vocoder in fp32
            fake_audios = self.vocoder.decode(sampled_mels.float())

        mel_loss = F.l1_loss(gt_mels, sampled_mels)

        self.log(
            "val/mel_loss",
            mel_loss,
            on_step=False,
            on_epoch=True,
            prog_bar=True,
            logger=True,
            sync_dist=True,
        )

        for idx, (
            mel,
            gen_mel,
            audio,
            gen_audio,
            audio_len,
        ) in enumerate(
            zip(
                gt_mels,
                sampled_mels,
                audios,
                fake_audios,
                audio_lengths,
            )
        ):
            mel_len = audio_len // self.hop_length

            image_mels = plot_mel(
                [
                    gen_mel[:, :mel_len],
                    mel[:, :mel_len],
                ],
                [
                    "Generated Spectrogram",
                    "Ground-Truth Spectrogram",
                ],
            )

            if isinstance(self.logger, WandbLogger):
                self.logger.experiment.log(
                    {
                        "reconstruction_mel": wandb.Image(image_mels, caption="mels"),
                        "wavs": [
                            wandb.Audio(
                                audio[0, :audio_len],
                                sample_rate=self.sampling_rate,
                                caption="gt",
                            ),
                            wandb.Audio(
                                gen_audio[0, :audio_len],
                                sample_rate=self.sampling_rate,
                                caption="prediction",
                            ),
                        ],
                    },
                )

            if isinstance(self.logger, TensorBoardLogger):
                self.logger.experiment.add_figure(
                    f"sample-{idx}/mels",
                    image_mels,
                    global_step=self.global_step,
                )
                self.logger.experiment.add_audio(
                    f"sample-{idx}/wavs/gt",
                    audio[0, :audio_len],
                    self.global_step,
                    sample_rate=self.sampling_rate,
                )
                self.logger.experiment.add_audio(
                    f"sample-{idx}/wavs/prediction",
                    gen_audio[0, :audio_len],
                    self.global_step,
                    sample_rate=self.sampling_rate,
                )

            plt.close(image_mels)