# lit_module.py
  1. import itertools
  2. from typing import Any, Callable, Optional
  3. import lightning as L
  4. import numpy as np
  5. import torch
  6. import torch.nn.functional as F
  7. import wandb
  8. from diffusers.schedulers import DDIMScheduler, UniPCMultistepScheduler
  9. from diffusers.utils.torch_utils import randn_tensor
  10. from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger
  11. from matplotlib import pyplot as plt
  12. from torch import nn
  13. from tqdm import tqdm
  14. from fish_speech.models.vq_diffusion.convnext_1d import ConvNext1DModel
  15. from fish_speech.models.vqgan.modules.encoders import (
  16. SpeakerEncoder,
  17. TextEncoder,
  18. VQEncoder,
  19. )
  20. from fish_speech.models.vqgan.utils import plot_mel, sequence_mask
  21. class ConvDownSample(nn.Module):
  22. def __init__(
  23. self,
  24. dims: list,
  25. kernel_sizes: list,
  26. strides: list,
  27. ):
  28. super().__init__()
  29. self.dims = dims
  30. self.kernel_sizes = kernel_sizes
  31. self.strides = strides
  32. self.total_strides = np.prod(self.strides)
  33. self.convs = nn.ModuleList(
  34. [
  35. nn.ModuleList(
  36. [
  37. nn.Conv1d(
  38. in_channels=self.dims[i],
  39. out_channels=self.dims[i + 1],
  40. kernel_size=self.kernel_sizes[i],
  41. stride=self.strides[i],
  42. padding=(self.kernel_sizes[i] - 1) // 2,
  43. ),
  44. nn.LayerNorm(self.dims[i + 1], elementwise_affine=True),
  45. nn.GELU(),
  46. ]
  47. )
  48. for i in range(len(self.dims) - 1)
  49. ]
  50. )
  51. self.apply(self.init_weights)
  52. def init_weights(self, m):
  53. if isinstance(m, nn.Conv1d):
  54. nn.init.normal_(m.weight, std=0.02)
  55. elif isinstance(m, nn.LayerNorm):
  56. nn.init.ones_(m.weight)
  57. nn.init.zeros_(m.bias)
  58. def forward(self, x):
  59. for conv, norm, act in self.convs:
  60. x = conv(x)
  61. x = norm(x.mT).mT
  62. x = act(x)
  63. return x
class VQDiffusion(L.LightningModule):
    """Mel-spectrogram diffusion model conditioned on VQ text features.

    Training: mel features extracted from audio are passed through a text
    encoder and a VQ bottleneck, combined with a speaker embedding, and used
    to condition a denoiser that predicts the noise added to normalized
    ground-truth mels (DDPM-style epsilon prediction with a DDIM scheduler).

    Validation: 50-step DDIM sampling from pure noise, followed by vocoding
    and logging of spectrograms/audio to W&B or TensorBoard.
    """

    def __init__(
        self,
        optimizer: Callable,
        lr_scheduler: Callable,
        mel_transform: nn.Module,
        feature_mel_transform: nn.Module,
        vq_encoder: VQEncoder,
        speaker_encoder: SpeakerEncoder,
        text_encoder: TextEncoder,
        denoiser: ConvNext1DModel,
        vocoder: nn.Module,
        hop_length: int = 640,
        sample_rate: int = 32000,
        speaker_use_feats: bool = False,
        downsample: Optional[nn.Module] = None,
    ):
        """
        Args:
            optimizer: factory called with ``self.parameters()`` to build the
                optimizer (deferred to ``configure_optimizers``).
            lr_scheduler: factory called with the optimizer to build the
                per-step LR scheduler.
            mel_transform: produces target mels from raw audio.
            feature_mel_transform: produces conditioning features from audio;
                must expose ``sample_rate`` and ``hop_length`` attributes.
            vq_encoder: VQ bottleneck applied to encoded text features;
                returns (quantized features, vq loss).
            speaker_encoder: produces a speaker embedding from mels/features.
            text_encoder: encodes conditioning features before VQ.
            denoiser: noise-prediction network.
            vocoder: frozen mel-to-audio decoder (must expose ``decode``).
            hop_length: hop of ``mel_transform``; used to derive mel lengths
                from audio lengths.
            sample_rate: audio sample rate in Hz.
            speaker_use_feats: if True, the speaker embedding is computed
                from the conditioning features instead of the target mels.
            downsample: optional module (e.g. ConvDownSample) applied to the
                conditioning features; must expose ``total_strides``.
        """
        super().__init__()

        # Model parameters
        self.optimizer_builder = optimizer
        self.lr_scheduler_builder = lr_scheduler

        # Generator and discriminators
        self.mel_transform = mel_transform
        self.feature_mel_transform = feature_mel_transform
        # DDIM with epsilon objective; clip_sample disabled because mels are
        # normalized to [-1, 1] but sampling intermediates may exceed it.
        self.noise_scheduler = DDIMScheduler(
            num_train_timesteps=1000,
            clip_sample=False,
            beta_end=0.01,
        )

        # Modules
        self.vq_encoder = vq_encoder
        self.speaker_encoder = speaker_encoder
        self.text_encoder = text_encoder
        self.denoiser = denoiser
        self.downsample = downsample
        self.vocoder = vocoder
        self.hop_length = hop_length
        self.sampling_rate = sample_rate
        self.speaker_use_feats = speaker_use_feats

        # Freeze vocoder: it is used for logging/eval only, never trained.
        for param in self.vocoder.parameters():
            param.requires_grad = False

    def configure_optimizers(self):
        """Build optimizer + per-step LR scheduler from injected factories."""
        optimizer = self.optimizer_builder(self.parameters())
        lr_scheduler = self.lr_scheduler_builder(optimizer)

        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": lr_scheduler,
                # Step the scheduler every optimizer step, not every epoch.
                "interval": "step",
            },
        }

    def normalize_mels(self, x):
        """Map log-mels from their empirical range to [-1, 1] for diffusion."""
        # x is in range -10.1 to 3.1, normalize to -1 to 1
        x_min, x_max = -10.1, 3.1
        return (x - x_min) / (x_max - x_min) * 2 - 1

    def denormalize_mels(self, x):
        """Inverse of ``normalize_mels``: map [-1, 1] back to log-mel range."""
        x_min, x_max = -10.1, 3.1
        return (x + 1) / 2 * (x_max - x_min) + x_min

    def training_step(self, batch, batch_idx):
        """One diffusion training step; returns noise loss + VQ loss."""
        audios, audio_lengths = batch["audios"], batch["audio_lengths"]
        features, feature_lengths = batch["features"], batch["feature_lengths"]

        audios = audios.float()
        # Add channel dim: (B, T) -> (B, 1, T).
        audios = audios[:, None, :]

        # Mel extraction is a fixed transform; no gradients needed.
        with torch.no_grad():
            gt_mels = self.mel_transform(audios, sample_rate=self.sampling_rate)
            # NOTE(review): batch["features"] is overwritten here — the
            # conditioning features are recomputed from audio, and the
            # dataloader-provided features/feature_lengths are discarded.
            features = self.feature_mel_transform(
                audios, sample_rate=self.sampling_rate
            )

        if self.downsample is not None:
            features = self.downsample(features)

        mel_lengths = audio_lengths // self.hop_length
        # Convert audio lengths to feature-frame lengths, accounting for the
        # feature transform's rate/hop and any extra downsampling.
        feature_lengths = (
            audio_lengths
            / self.sampling_rate
            * self.feature_mel_transform.sample_rate
            / self.feature_mel_transform.hop_length
            / (self.downsample.total_strides if self.downsample is not None else 1)
        ).long()

        # Masks are (B, 1, T) in the mel dtype so they can be multiplied in.
        feature_masks = torch.unsqueeze(
            sequence_mask(feature_lengths, features.shape[2]), 1
        ).to(gt_mels.dtype)
        mel_masks = torch.unsqueeze(sequence_mask(mel_lengths, gt_mels.shape[2]), 1).to(
            gt_mels.dtype
        )

        if self.speaker_use_feats:
            speaker_features = self.speaker_encoder(features, feature_masks)
        else:
            speaker_features = self.speaker_encoder(gt_mels, mel_masks)

        # vq_features is 50 hz, need to convert to true mel size
        text_features = self.text_encoder(features, feature_masks)
        text_features, vq_loss = self.vq_encoder(text_features, feature_masks)
        text_features = F.interpolate(
            text_features, size=gt_mels.shape[2], mode="nearest"
        )
        # Speaker conditioning is added (broadcast) onto the text features.
        text_features = text_features + speaker_features

        # Sample noise that we'll add to the images
        normalized_gt_mels = self.normalize_mels(gt_mels)
        noise = torch.randn_like(normalized_gt_mels)

        # Sample a random timestep for each image
        timesteps = torch.randint(
            0,
            self.noise_scheduler.config.num_train_timesteps,
            (normalized_gt_mels.shape[0],),
            device=normalized_gt_mels.device,
        ).long()

        # Add noise to the clean images according to the noise magnitude at each timestep
        # (this is the forward diffusion process)
        noisy_images = self.noise_scheduler.add_noise(
            normalized_gt_mels, noise, timesteps
        )

        # Predict
        model_output = self.denoiser(noisy_images, timesteps, mel_masks, text_features)

        # Masked L1 loss on the predicted noise, averaged over valid
        # (unmasked) frames times the mel-channel count.
        noise_loss = (torch.abs(model_output * mel_masks - noise * mel_masks)).sum() / (
            mel_masks.sum() * gt_mels.shape[1]
        )

        self.log(
            "train/noise_loss",
            noise_loss,
            on_step=True,
            on_epoch=False,
            prog_bar=True,
            logger=True,
            sync_dist=True,
        )

        self.log(
            "train/vq_loss",
            vq_loss,
            on_step=True,
            on_epoch=False,
            prog_bar=True,
            logger=True,
            sync_dist=True,
        )

        return noise_loss + vq_loss

    def validation_step(self, batch: Any, batch_idx: int):
        """Sample mels via 50-step DDIM, vocode, and log metrics/artifacts."""
        audios, audio_lengths = batch["audios"], batch["audio_lengths"]
        features, feature_lengths = batch["features"], batch["feature_lengths"]

        audios = audios.float()
        audios = audios[:, None, :]

        # Same conditioning pipeline as training_step (no grad needed here;
        # Lightning runs validation under inference mode).
        gt_mels = self.mel_transform(audios, sample_rate=self.sampling_rate)
        features = self.feature_mel_transform(audios, sample_rate=self.sampling_rate)

        if self.downsample is not None:
            features = self.downsample(features)

        mel_lengths = audio_lengths // self.hop_length
        feature_lengths = (
            audio_lengths
            / self.sampling_rate
            * self.feature_mel_transform.sample_rate
            / self.feature_mel_transform.hop_length
            / (self.downsample.total_strides if self.downsample is not None else 1)
        ).long()

        feature_masks = torch.unsqueeze(
            sequence_mask(feature_lengths, features.shape[2]), 1
        ).to(gt_mels.dtype)
        mel_masks = torch.unsqueeze(sequence_mask(mel_lengths, gt_mels.shape[2]), 1).to(
            gt_mels.dtype
        )

        if self.speaker_use_feats:
            speaker_features = self.speaker_encoder(features, feature_masks)
        else:
            speaker_features = self.speaker_encoder(gt_mels, mel_masks)

        # vq_features is 50 hz, need to convert to true mel size
        text_features = self.text_encoder(features, feature_masks)
        text_features, vq_loss = self.vq_encoder(text_features, feature_masks)
        text_features = F.interpolate(
            text_features, size=gt_mels.shape[2], mode="nearest"
        )
        text_features = text_features + speaker_features

        # Begin sampling: start from pure noise in the normalized mel space.
        sampled_mels = torch.randn_like(gt_mels)
        self.noise_scheduler.set_timesteps(50)

        for t in tqdm(self.noise_scheduler.timesteps):
            # NOTE(review): `timesteps` is a single-element tensor even for
            # batch > 1 — presumably the denoiser broadcasts it; confirm.
            timesteps = torch.tensor([t], device=sampled_mels.device, dtype=torch.long)

            # 1. predict noise model_output
            model_output = self.denoiser(
                sampled_mels, timesteps, mel_masks, text_features
            )

            # 2. compute previous image: x_t -> x_t-1
            sampled_mels = self.noise_scheduler.step(
                model_output, t, sampled_mels
            ).prev_sample

        # Back to log-mel range, zeroing padded frames.
        sampled_mels = self.denormalize_mels(sampled_mels)
        sampled_mels = sampled_mels * mel_masks

        with torch.autocast(device_type=sampled_mels.device.type, enabled=False):
            # Run vocoder on fp32
            fake_audios = self.vocoder.decode(sampled_mels.float())

        mel_loss = F.l1_loss(gt_mels * mel_masks, sampled_mels * mel_masks)

        self.log(
            "val/mel_loss",
            mel_loss,
            on_step=False,
            on_epoch=True,
            prog_bar=True,
            logger=True,
            sync_dist=True,
        )

        # Log per-sample spectrogram comparisons and audio clips.
        for idx, (
            mel,
            gen_mel,
            audio,
            gen_audio,
            audio_len,
        ) in enumerate(
            zip(
                gt_mels,
                sampled_mels,
                audios,
                fake_audios,
                audio_lengths,
            )
        ):
            mel_len = audio_len // self.hop_length

            image_mels = plot_mel(
                [
                    gen_mel[:, :mel_len],
                    mel[:, :mel_len],
                ],
                [
                    "Generated Spectrogram",
                    "Ground-Truth Spectrogram",
                ],
            )

            if isinstance(self.logger, WandbLogger):
                self.logger.experiment.log(
                    {
                        "reconstruction_mel": wandb.Image(image_mels, caption="mels"),
                        "wavs": [
                            wandb.Audio(
                                audio[0, :audio_len],
                                sample_rate=self.sampling_rate,
                                caption="gt",
                            ),
                            wandb.Audio(
                                gen_audio[0, :audio_len],
                                sample_rate=self.sampling_rate,
                                caption="prediction",
                            ),
                        ],
                    },
                )

            if isinstance(self.logger, TensorBoardLogger):
                self.logger.experiment.add_figure(
                    f"sample-{idx}/mels",
                    image_mels,
                    global_step=self.global_step,
                )
                self.logger.experiment.add_audio(
                    f"sample-{idx}/wavs/gt",
                    audio[0, :audio_len],
                    self.global_step,
                    sample_rate=self.sampling_rate,
                )
                self.logger.experiment.add_audio(
                    f"sample-{idx}/wavs/prediction",
                    gen_audio[0, :audio_len],
                    self.global_step,
                    sample_rate=self.sampling_rate,
                )

            # Release the matplotlib figure to avoid leaking memory across steps.
            plt.close(image_mels)