import itertools
from typing import Any, Callable, Optional

import lightning as L
import numpy as np
import torch
import torch.nn.functional as F
import wandb
from diffusers.schedulers import DDIMScheduler, UniPCMultistepScheduler
from diffusers.utils.torch_utils import randn_tensor
from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger
from matplotlib import pyplot as plt
from torch import nn
from tqdm import tqdm

from fish_speech.models.vq_diffusion.convnext_1d import ConvNext1DModel
from fish_speech.models.vqgan.modules.encoders import (
    SpeakerEncoder,
    TextEncoder,
    VQEncoder,
)
from fish_speech.models.vqgan.utils import plot_mel, sequence_mask


class ConvDownSample(nn.Module):
    def __init__(
        self,
        dims: list,
        kernel_sizes: list,
        strides: list,
    ):
        super().__init__()

        self.dims = dims
        self.kernel_sizes = kernel_sizes
        self.strides = strides
        self.total_strides = np.prod(self.strides)
        self.convs = nn.ModuleList(
            [
                nn.ModuleList(
                    [
                        nn.Conv1d(
                            in_channels=self.dims[i],
                            out_channels=self.dims[i + 1],
                            kernel_size=self.kernel_sizes[i],
                            stride=self.strides[i],
                            padding=(self.kernel_sizes[i] - 1) // 2,
                        ),
                        nn.LayerNorm(self.dims[i + 1], elementwise_affine=True),
                        nn.GELU(),
                    ]
                )
                for i in range(len(self.dims) - 1)
            ]
        )
        self.apply(self.init_weights)

    def init_weights(self, m):
        if isinstance(m, nn.Conv1d):
            nn.init.normal_(m.weight, std=0.02)
        elif isinstance(m, nn.LayerNorm):
            nn.init.ones_(m.weight)
            nn.init.zeros_(m.bias)

    def forward(self, x):
        for conv, norm, act in self.convs:
            x = conv(x)
            x = norm(x.mT).mT
            x = act(x)

        return x
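

# Illustrative note (not from the original file): ConvDownSample shrinks the
# time axis by total_strides == prod(strides). For example, with a made-up
# configuration dims=[80, 128, 256], kernel_sizes=[5, 5], strides=[2, 2],
# an input of shape (B, 80, T) maps to (B, 256, ceil(T / 4)): each layer with
# odd kernel k, stride s, and padding (k - 1) // 2 outputs ceil(T / s) frames,
# and total_strides == 4.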


class VQDiffusion(L.LightningModule):
    def __init__(
        self,
        optimizer: Callable,
        lr_scheduler: Callable,
        mel_transform: nn.Module,
        feature_mel_transform: nn.Module,
        vq_encoder: VQEncoder,
        speaker_encoder: SpeakerEncoder,
        text_encoder: TextEncoder,
        denoiser: ConvNext1DModel,
        vocoder: nn.Module,
        hop_length: int = 640,
        sample_rate: int = 32000,
        speaker_use_feats: bool = False,
        downsample: Optional[nn.Module] = None,
    ):
        super().__init__()

        # Optimizer and LR scheduler builders
        self.optimizer_builder = optimizer
        self.lr_scheduler_builder = lr_scheduler

        # Mel transforms and diffusion noise schedulers
        self.mel_transform = mel_transform
        self.feature_mel_transform = feature_mel_transform
        self.noise_scheduler_train = DDIMScheduler(num_train_timesteps=1000)
        self.noise_scheduler_infer = UniPCMultistepScheduler(num_train_timesteps=1000)

        # Modules
        self.vq_encoder = vq_encoder
        self.speaker_encoder = speaker_encoder
        self.text_encoder = text_encoder
        self.denoiser = denoiser
        self.downsample = downsample
        self.vocoder = vocoder
        self.hop_length = hop_length
        self.sampling_rate = sample_rate
        self.speaker_use_feats = speaker_use_feats

        # Freeze vocoder
        for param in self.vocoder.parameters():
            param.requires_grad = False

    def configure_optimizers(self):
        optimizer = self.optimizer_builder(self.parameters())
        lr_scheduler = self.lr_scheduler_builder(optimizer)

        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": lr_scheduler,
                "interval": "step",
            },
        }
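
    # "interval": "step" makes Lightning advance the LR scheduler after every
    # optimizer step rather than once per epoch.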

    def normalize_mels(self, x):
        # x is in range -10.1 to 3.1, normalize to -1 to 1
        x_min, x_max = -10.1, 3.1
        return (x - x_min) / (x_max - x_min) * 2 - 1

    def denormalize_mels(self, x):
        x_min, x_max = -10.1, 3.1
        return (x + 1) / 2 * (x_max - x_min) + x_min
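
    # Worked example: x = 3.1 (the assumed log-mel maximum above) normalizes
    # to (3.1 + 10.1) / 13.2 * 2 - 1 = 1.0, and denormalize_mels(1.0) returns
    # 3.1, so the two maps are exact inverses on [-10.1, 3.1].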

    def training_step(self, batch, batch_idx):
        audios, audio_lengths = batch["audios"], batch["audio_lengths"]
        features, feature_lengths = batch["features"], batch["feature_lengths"]

        audios = audios.float()
        # features = features.float().mT
        audios = audios[:, None, :]

        with torch.no_grad():
            gt_mels = self.mel_transform(audios, sample_rate=self.sampling_rate)
            features = self.feature_mel_transform(
                audios, sample_rate=self.sampling_rate
            )

        # Keep the (trainable) downsampler outside no_grad
        if self.downsample is not None:
            features = self.downsample(features)

        mel_lengths = audio_lengths // self.hop_length
        feature_lengths = (
            audio_lengths
            / self.sampling_rate
            * self.feature_mel_transform.sample_rate
            / self.feature_mel_transform.hop_length
            / (self.downsample.total_strides if self.downsample is not None else 1)
        ).long()
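
        # Worked example for the conversion above (illustrative numbers only):
        # with 32 kHz audio, a 16 kHz feature extractor with hop length 320
        # (i.e. 50 feature frames per second), and no downsampler, one second
        # of audio yields 32000 / 32000 * 16000 / 320 = 50 feature frames.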

        feature_masks = torch.unsqueeze(
            sequence_mask(feature_lengths, features.shape[2]), 1
        ).to(gt_mels.dtype)
        mel_masks = torch.unsqueeze(sequence_mask(mel_lengths, gt_mels.shape[2]), 1).to(
            gt_mels.dtype
        )

        if self.speaker_use_feats:
            speaker_features = self.speaker_encoder(features, feature_masks)
        else:
            speaker_features = self.speaker_encoder(gt_mels, mel_masks)

        # vq_features is 50 Hz; interpolate to the true mel length
        text_features = self.text_encoder(features, feature_masks)
        text_features, vq_loss = self.vq_encoder(text_features, feature_masks)
        text_features = F.interpolate(
            text_features, size=gt_mels.shape[2], mode="nearest"
        )
        text_features = text_features + speaker_features

        # Sample noise that we'll add to the mel spectrograms
        normalized_gt_mels = self.normalize_mels(gt_mels)
        noise = torch.randn_like(normalized_gt_mels)

        # Sample a random timestep for each item in the batch
        timesteps = torch.randint(
            0,
            self.noise_scheduler_train.config.num_train_timesteps,
            (normalized_gt_mels.shape[0],),
            device=normalized_gt_mels.device,
        ).long()

        # Add noise to the clean mels according to the noise magnitude at each
        # timestep (this is the forward diffusion process)
        noisy_mels = self.noise_scheduler_train.add_noise(
            normalized_gt_mels, noise, timesteps
        )
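
        # For reference, add_noise computes
        #   x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps,
        # so the denoiser below is trained to recover eps from x_t.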

        # Predict the noise residual
        model_output = self.denoiser(noisy_mels, timesteps, mel_masks, text_features)

        # Masked L1 loss on the predicted noise
        noise_loss = (torch.abs(model_output * mel_masks - noise * mel_masks)).sum() / (
            mel_masks.sum() * gt_mels.shape[1]
        )

        self.log(
            "train/noise_loss",
            noise_loss,
            on_step=True,
            on_epoch=False,
            prog_bar=True,
            logger=True,
            sync_dist=True,
        )

        self.log(
            "train/vq_loss",
            vq_loss,
            on_step=True,
            on_epoch=False,
            prog_bar=True,
            logger=True,
            sync_dist=True,
        )

        return noise_loss + vq_loss

    def validation_step(self, batch: Any, batch_idx: int):
        audios, audio_lengths = batch["audios"], batch["audio_lengths"]
        features, feature_lengths = batch["features"], batch["feature_lengths"]

        audios = audios.float()
        # features = features.float().mT
        audios = audios[:, None, :]

        gt_mels = self.mel_transform(audios, sample_rate=self.sampling_rate)
        features = self.feature_mel_transform(audios, sample_rate=self.sampling_rate)

        if self.downsample is not None:
            features = self.downsample(features)

        mel_lengths = audio_lengths // self.hop_length
        feature_lengths = (
            audio_lengths
            / self.sampling_rate
            * self.feature_mel_transform.sample_rate
            / self.feature_mel_transform.hop_length
            / (self.downsample.total_strides if self.downsample is not None else 1)
        ).long()

        feature_masks = torch.unsqueeze(
            sequence_mask(feature_lengths, features.shape[2]), 1
        ).to(gt_mels.dtype)
        mel_masks = torch.unsqueeze(sequence_mask(mel_lengths, gt_mels.shape[2]), 1).to(
            gt_mels.dtype
        )

        if self.speaker_use_feats:
            speaker_features = self.speaker_encoder(features, feature_masks)
        else:
            speaker_features = self.speaker_encoder(gt_mels, mel_masks)

        # vq_features is 50 Hz; interpolate to the true mel length
        text_features = self.text_encoder(features, feature_masks)
        text_features, vq_loss = self.vq_encoder(text_features, feature_masks)
        text_features = F.interpolate(
            text_features, size=gt_mels.shape[2], mode="nearest"
        )
        text_features = text_features + speaker_features

        # Begin sampling
        sampled_mels = torch.randn_like(gt_mels)
        self.noise_scheduler_infer.set_timesteps(100)

        for t in tqdm(self.noise_scheduler_infer.timesteps):
            timesteps = torch.tensor([t], device=sampled_mels.device, dtype=torch.long)

            # 1. predict the noise residual
            model_output = self.denoiser(
                sampled_mels, timesteps, mel_masks, text_features
            )

            # 2. compute the previous noisy sample: x_t -> x_{t-1}
            sampled_mels = self.noise_scheduler_infer.step(
                model_output, t, sampled_mels
            ).prev_sample

        sampled_mels = self.denormalize_mels(sampled_mels)
        sampled_mels = sampled_mels * mel_masks

        with torch.autocast(device_type=sampled_mels.device.type, enabled=False):
            # Run the vocoder in fp32
            fake_audios = self.vocoder.decode(sampled_mels.float())

        mel_loss = F.l1_loss(gt_mels * mel_masks, sampled_mels * mel_masks)
        self.log(
            "val/mel_loss",
            mel_loss,
            on_step=False,
            on_epoch=True,
            prog_bar=True,
            logger=True,
            sync_dist=True,
        )

        for idx, (
            mel,
            gen_mel,
            audio,
            gen_audio,
            audio_len,
        ) in enumerate(
            zip(
                gt_mels,
                sampled_mels,
                audios,
                fake_audios,
                audio_lengths,
            )
        ):
            mel_len = audio_len // self.hop_length

            image_mels = plot_mel(
                [
                    gen_mel[:, :mel_len],
                    mel[:, :mel_len],
                ],
                [
                    "Generated Spectrogram",
                    "Ground-Truth Spectrogram",
                ],
            )

            if isinstance(self.logger, WandbLogger):
                self.logger.experiment.log(
                    {
                        "reconstruction_mel": wandb.Image(image_mels, caption="mels"),
                        "wavs": [
                            wandb.Audio(
                                audio[0, :audio_len],
                                sample_rate=self.sampling_rate,
                                caption="gt",
                            ),
                            wandb.Audio(
                                gen_audio[0, :audio_len],
                                sample_rate=self.sampling_rate,
                                caption="prediction",
                            ),
                        ],
                    },
                )

            if isinstance(self.logger, TensorBoardLogger):
                self.logger.experiment.add_figure(
                    f"sample-{idx}/mels",
                    image_mels,
                    global_step=self.global_step,
                )
                self.logger.experiment.add_audio(
                    f"sample-{idx}/wavs/gt",
                    audio[0, :audio_len],
                    self.global_step,
                    sample_rate=self.sampling_rate,
                )
                self.logger.experiment.add_audio(
                    f"sample-{idx}/wavs/prediction",
                    gen_audio[0, :audio_len],
                    self.global_step,
                    sample_rate=self.sampling_rate,
                )

            plt.close(image_mels)
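

# --- Illustrative sketch (not part of the original module) ---
# A minimal, self-contained version of the sampling loop in validation_step:
# start from Gaussian noise and iteratively denoise with the UniPC scheduler.
# The zero-predicting stand-in below replaces the real conditional denoiser,
# and the tensor sizes are made up.
if __name__ == "__main__":
    scheduler = UniPCMultistepScheduler(num_train_timesteps=1000)
    scheduler.set_timesteps(100)

    sample = torch.randn(1, 160, 256)  # (batch, mel bins, frames)

    for t in scheduler.timesteps:
        noise_pred = torch.zeros_like(sample)  # stand-in for self.denoiser(...)
        sample = scheduler.step(noise_pred, t, sample).prev_sample

    print(sample.shape)  # torch.Size([1, 160, 256])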