# lit_module.py
import itertools
from typing import Any, Callable

import lightning as L
import torch
import torch.nn.functional as F
import wandb
from diffusers.schedulers import DDIMScheduler, UniPCMultistepScheduler
from diffusers.utils.torch_utils import randn_tensor
from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger
from matplotlib import pyplot as plt
from torch import nn
from tqdm import tqdm
from transformers import HubertModel

from fish_speech.models.vq_diffusion.convnext_1d import ConvNext1DModel
from fish_speech.models.vqgan.modules.encoders import (
    SpeakerEncoder,
    TextEncoder,
    VQEncoder,
)
from fish_speech.models.vqgan.utils import plot_mel, sequence_mask


class VQDiffusion(L.LightningModule):
    def __init__(
        self,
        optimizer: Callable,
        lr_scheduler: Callable,
        mel_transform: nn.Module,
        feature_mel_transform: nn.Module,
        vq_encoder: VQEncoder,
        speaker_encoder: SpeakerEncoder,
        text_encoder: TextEncoder,
        denoiser: ConvNext1DModel,
        vocoder: nn.Module,
        hop_length: int = 640,
        sample_rate: int = 32000,
        speaker_use_feats: bool = False,
    ):
        super().__init__()

        # Optimizer and LR-scheduler builders, called in configure_optimizers
        self.optimizer_builder = optimizer
        self.lr_scheduler_builder = lr_scheduler

        # Mel transforms and diffusion noise schedulers
        self.mel_transform = mel_transform
        self.feature_mel_transform = feature_mel_transform
        self.noise_scheduler_train = DDIMScheduler(num_train_timesteps=1000)
        self.noise_scheduler_infer = UniPCMultistepScheduler(num_train_timesteps=1000)

        # Modules
        self.vq_encoder = vq_encoder
        self.speaker_encoder = speaker_encoder
        self.text_encoder = text_encoder
        self.denoiser = denoiser
        self.vocoder = vocoder
        self.hop_length = hop_length
        self.sampling_rate = sample_rate
        self.speaker_use_feats = speaker_use_feats

        # Freeze the vocoder; it is only used to render audio at validation time
        for param in self.vocoder.parameters():
            param.requires_grad = False
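
    # Rate bookkeeping, derived from the defaults above: hop_length=640 at
    # sample_rate=32000 yields 50 mel frames per second, while the feature
    # stream in the steps below runs at half that rate and is upsampled back
    # to mel resolution before denoising.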

    def configure_optimizers(self):
        optimizer = self.optimizer_builder(self.parameters())
        lr_scheduler = self.lr_scheduler_builder(optimizer)

        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": lr_scheduler,
                "interval": "step",
            },
        }
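
    # Hypothetical usage sketch (values are illustrative, not from this file):
    # the two builders are partially-applied constructors, e.g.
    #   optimizer=functools.partial(torch.optim.AdamW, lr=1e-4)
    #   lr_scheduler=functools.partial(torch.optim.lr_scheduler.ExponentialLR, gamma=0.999)
    # so configure_optimizers can finish construction with the module parameters.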

    def normalize_mels(self, x):
        # Log-mels are assumed to lie in [-10.1, 3.1]; map linearly onto [-1, 1]
        x_min, x_max = -10.1, 3.1
        return (x - x_min) / (x_max - x_min) * 2 - 1

    def denormalize_mels(self, x):
        # Inverse of normalize_mels: map [-1, 1] back onto [-10.1, 3.1]
        x_min, x_max = -10.1, 3.1
        return (x + 1) / 2 * (x_max - x_min) + x_min
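
    # Worked example: normalize_mels maps the assumed log-mel range [-10.1, 3.1]
    # linearly onto [-1, 1], so normalize_mels(-10.1) == -1.0,
    # normalize_mels(3.1) == 1.0, and denormalize_mels(normalize_mels(x)) == x.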

    def training_step(self, batch, batch_idx):
        audios, audio_lengths = batch["audios"], batch["audio_lengths"]
        features, feature_lengths = batch["features"], batch["feature_lengths"]

        audios = audios.float()
        # features = features.float().mT
        audios = audios[:, None, :]

        with torch.no_grad():
            gt_mels = self.mel_transform(audios, sample_rate=self.sampling_rate)
            features = self.feature_mel_transform(
                audios, sample_rate=self.sampling_rate
            )

        mel_lengths = audio_lengths // self.hop_length
        feature_lengths = audio_lengths // self.hop_length // 2

        feature_masks = torch.unsqueeze(
            sequence_mask(feature_lengths, features.shape[2]), 1
        ).to(gt_mels.dtype)
        mel_masks = torch.unsqueeze(sequence_mask(mel_lengths, gt_mels.shape[2]), 1).to(
            gt_mels.dtype
        )

        if self.speaker_use_feats:
            speaker_features = self.speaker_encoder(features, feature_masks)
        else:
            speaker_features = self.speaker_encoder(gt_mels, mel_masks)

        # The VQ features run at a lower rate than the mels; upsample them to
        # the true mel resolution
        text_features = self.text_encoder(features, feature_masks)
        text_features, vq_loss = self.vq_encoder(text_features, feature_masks)
        text_features = F.interpolate(
            text_features, size=gt_mels.shape[2], mode="nearest"
        )
        text_features = text_features + speaker_features

        # Sample the noise to add to the mel spectrograms
        normalized_gt_mels = self.normalize_mels(gt_mels)
        noise = torch.randn_like(normalized_gt_mels)

        # Sample a random timestep for each item in the batch
        timesteps = torch.randint(
            0,
            self.noise_scheduler_train.config.num_train_timesteps,
            (normalized_gt_mels.shape[0],),
            device=normalized_gt_mels.device,
        ).long()

        # Add noise to the clean mels according to the noise magnitude at each
        # timestep (this is the forward diffusion process)
        noisy_mels = self.noise_scheduler_train.add_noise(
            normalized_gt_mels, noise, timesteps
        )

        # Predict the added noise
        model_output = self.denoiser(noisy_mels, timesteps, mel_masks, text_features)

        # Masked L1 loss between the predicted and the true noise
        noise_loss = (torch.abs(model_output * mel_masks - noise * mel_masks)).sum() / (
            mel_masks.sum() * gt_mels.shape[1]
        )

        self.log(
            "train/noise_loss",
            noise_loss,
            on_step=True,
            on_epoch=False,
            prog_bar=True,
            logger=True,
            sync_dist=True,
        )

        self.log(
            "train/vq_loss",
            vq_loss,
            on_step=True,
            on_epoch=False,
            prog_bar=True,
            logger=True,
            sync_dist=True,
        )

        return noise_loss + vq_loss
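
    # Note: noise_loss compares model_output directly against the sampled noise,
    # i.e. the denoiser is trained as an epsilon-predictor, which matches the
    # DDIMScheduler default (prediction_type="epsilon").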

    def validation_step(self, batch: Any, batch_idx: int):
        audios, audio_lengths = batch["audios"], batch["audio_lengths"]
        features, feature_lengths = batch["features"], batch["feature_lengths"]

        audios = audios.float()
        # features = features.float().mT
        audios = audios[:, None, :]

        gt_mels = self.mel_transform(audios, sample_rate=self.sampling_rate)
        features = self.feature_mel_transform(audios, sample_rate=self.sampling_rate)

        mel_lengths = audio_lengths // self.hop_length
        feature_lengths = audio_lengths // self.hop_length // 2

        feature_masks = torch.unsqueeze(
            sequence_mask(feature_lengths, features.shape[2]), 1
        ).to(gt_mels.dtype)
        mel_masks = torch.unsqueeze(sequence_mask(mel_lengths, gt_mels.shape[2]), 1).to(
            gt_mels.dtype
        )

        if self.speaker_use_feats:
            speaker_features = self.speaker_encoder(features, feature_masks)
        else:
            speaker_features = self.speaker_encoder(gt_mels, mel_masks)

        # The VQ features run at a lower rate than the mels; upsample them to
        # the true mel resolution
        text_features = self.text_encoder(features, feature_masks)
        text_features, vq_loss = self.vq_encoder(text_features, feature_masks)
        text_features = F.interpolate(
            text_features, size=gt_mels.shape[2], mode="nearest"
        )
        text_features = text_features + speaker_features

        # Begin sampling: start from pure noise in the normalized mel space and
        # iteratively denoise with the inference scheduler
        sampled_mels = torch.randn_like(gt_mels)
        self.noise_scheduler_infer.set_timesteps(100)

        for t in tqdm(self.noise_scheduler_infer.timesteps):
            timesteps = torch.tensor([t], device=sampled_mels.device, dtype=torch.long)

            # 1. predict noise model_output
            model_output = self.denoiser(
                sampled_mels, timesteps, mel_masks, text_features
            )

            # 2. compute the previous sample: x_t -> x_{t-1}
            sampled_mels = self.noise_scheduler_infer.step(
                model_output, t, sampled_mels
            ).prev_sample

        sampled_mels = self.denormalize_mels(sampled_mels)

        with torch.autocast(device_type=sampled_mels.device.type, enabled=False):
            # Run the vocoder in fp32
            fake_audios = self.vocoder.decode(sampled_mels.float())

        mel_loss = F.l1_loss(gt_mels * mel_masks, sampled_mels * mel_masks)

        self.log(
            "val/mel_loss",
            mel_loss,
            on_step=False,
            on_epoch=True,
            prog_bar=True,
            logger=True,
            sync_dist=True,
        )
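
        # Log a side-by-side mel comparison plus ground-truth and generated audio
        # for each item in the batch, to whichever logger is active.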
        for idx, (
            mel,
            gen_mel,
            audio,
            gen_audio,
            audio_len,
        ) in enumerate(
            zip(
                gt_mels,
                sampled_mels,
                audios,
                fake_audios,
                audio_lengths,
            )
        ):
            mel_len = audio_len // self.hop_length

            image_mels = plot_mel(
                [
                    gen_mel[:, :mel_len],
                    mel[:, :mel_len],
                ],
                [
                    "Generated Spectrogram",
                    "Ground-Truth Spectrogram",
                ],
            )

            if isinstance(self.logger, WandbLogger):
                self.logger.experiment.log(
                    {
                        "reconstruction_mel": wandb.Image(image_mels, caption="mels"),
                        "wavs": [
                            wandb.Audio(
                                audio[0, :audio_len],
                                sample_rate=self.sampling_rate,
                                caption="gt",
                            ),
                            wandb.Audio(
                                gen_audio[0, :audio_len],
                                sample_rate=self.sampling_rate,
                                caption="prediction",
                            ),
                        ],
                    },
                )

            if isinstance(self.logger, TensorBoardLogger):
                self.logger.experiment.add_figure(
                    f"sample-{idx}/mels",
                    image_mels,
                    global_step=self.global_step,
                )
                self.logger.experiment.add_audio(
                    f"sample-{idx}/wavs/gt",
                    audio[0, :audio_len],
                    self.global_step,
                    sample_rate=self.sampling_rate,
                )
                self.logger.experiment.add_audio(
                    f"sample-{idx}/wavs/prediction",
                    gen_audio[0, :audio_len],
                    self.global_step,
                    sample_rate=self.sampling_rate,
                )

            plt.close(image_mels)
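

# A minimal, self-contained sketch of the diffusion mechanics used above, with
# hypothetical shapes and a zero "denoiser" standing in for self.denoiser. It
# only demonstrates the scheduler APIs (add_noise for the forward process,
# step(...).prev_sample for the reverse process), not this model.
if __name__ == "__main__":
    train_sched = DDIMScheduler(num_train_timesteps=1000)
    infer_sched = UniPCMultistepScheduler(num_train_timesteps=1000)

    x0 = torch.randn(2, 128, 64)  # (batch, mel bins, frames), hypothetical sizes
    noise = torch.randn_like(x0)
    t = torch.randint(0, train_sched.config.num_train_timesteps, (2,))

    # Forward diffusion, as in training_step
    xt = train_sched.add_noise(x0, noise, t)

    # Reverse diffusion, as in validation_step
    infer_sched.set_timesteps(10)
    sample = torch.randn_like(x0)
    for step_t in infer_sched.timesteps:
        pred_noise = torch.zeros_like(sample)  # stand-in for the denoiser output
        sample = infer_sched.step(pred_noise, step_t, sample).prev_sample

    print(xt.shape, sample.shape)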