lit_module.py

import itertools
from typing import Any, Callable, Optional

import lightning as L
import numpy as np
import torch
import torch.nn.functional as F
import wandb
from diffusers.schedulers import DDIMScheduler, UniPCMultistepScheduler
from diffusers.utils.torch_utils import randn_tensor
from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger
from matplotlib import pyplot as plt
from torch import nn
from tqdm import tqdm

from fish_speech.models.vq_diffusion.convnext_1d import ConvNext1DModel
from fish_speech.models.vqgan.modules.encoders import (
    SpeakerEncoder,
    TextEncoder,
    VQEncoder,
)
from fish_speech.models.vqgan.utils import plot_mel, sequence_mask


class ConvDownSample(nn.Module):
    def __init__(
        self,
        dims: list,
        kernel_sizes: list,
        strides: list,
    ):
        super().__init__()

        self.dims = dims
        self.kernel_sizes = kernel_sizes
        self.strides = strides
        self.total_strides = np.prod(self.strides)

        self.convs = nn.ModuleList(
            [
                nn.ModuleList(
                    [
                        nn.Conv1d(
                            in_channels=self.dims[i],
                            out_channels=self.dims[i + 1],
                            kernel_size=self.kernel_sizes[i],
                            stride=self.strides[i],
                            padding=(self.kernel_sizes[i] - 1) // 2,
                        ),
                        nn.LayerNorm(self.dims[i + 1], elementwise_affine=True),
                        nn.GELU(),
                    ]
                )
                for i in range(len(self.dims) - 1)
            ]
        )
        self.apply(self.init_weights)

    def init_weights(self, m):
        if isinstance(m, nn.Conv1d):
            nn.init.normal_(m.weight, std=0.02)
        elif isinstance(m, nn.LayerNorm):
            nn.init.ones_(m.weight)
            nn.init.zeros_(m.bias)

    def forward(self, x):
        for conv, norm, act in self.convs:
            x = conv(x)
            # LayerNorm expects channels last; mT swaps (channels, time)
            x = norm(x.mT).mT
            x = act(x)

        return x
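

# Usage sketch for ConvDownSample (illustrative shapes, not the project
# config): three stages, each striding the time axis by 2, so
# total_strides == 8.
#
#   down = ConvDownSample(dims=[128, 256, 512, 512],
#                         kernel_sizes=[5, 5, 5],
#                         strides=[2, 2, 2])
#   x = torch.randn(4, 128, 256)  # (batch, channels, time)
#   y = down(x)                   # -> (4, 512, 32), i.e. 256 // 8 frames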


class VQDiffusion(L.LightningModule):
    def __init__(
        self,
        optimizer: Callable,
        lr_scheduler: Callable,
        mel_transform: nn.Module,
        feature_mel_transform: nn.Module,
        vq_encoder: VQEncoder,
        speaker_encoder: SpeakerEncoder,
        text_encoder: TextEncoder,
        decoder: ConvNext1DModel,
        postnet: ConvNext1DModel,
        vocoder: nn.Module,
        hop_length: int = 640,
        sample_rate: int = 32000,
        speaker_use_feats: bool = False,
        downsample: Optional[nn.Module] = None,
    ):
        super().__init__()

        # Optimizer and LR scheduler builders
        self.optimizer_builder = optimizer
        self.lr_scheduler_builder = lr_scheduler

        # Mel transforms and noise scheduler
        self.mel_transform = mel_transform
        self.feature_mel_transform = feature_mel_transform
        self.noise_scheduler = DDIMScheduler(
            num_train_timesteps=1000,
            clip_sample=False,
            beta_end=0.01,
        )

        # Modules
        self.vq_encoder = vq_encoder
        self.speaker_encoder = speaker_encoder
        self.text_encoder = text_encoder
        self.decoder = decoder
        self.postnet = postnet
        self.downsample = downsample
        self.vocoder = vocoder
        self.hop_length = hop_length
        self.sampling_rate = sample_rate
        self.speaker_use_feats = speaker_use_feats

        # Freeze vocoder; it is only used for decoding, never trained here
        for param in self.vocoder.parameters():
            param.requires_grad = False

    def configure_optimizers(self):
        optimizer = self.optimizer_builder(self.parameters())
        lr_scheduler = self.lr_scheduler_builder(optimizer)

        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": lr_scheduler,
                "interval": "step",
            },
        }
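
    # The two builders are plain callables returning an optimizer / scheduler;
    # a minimal sketch (hypothetical values, the real ones come from the
    # training config):
    #
    #   from functools import partial
    #   optimizer = partial(torch.optim.AdamW, lr=1e-4, betas=(0.9, 0.98))
    #   lr_scheduler = partial(torch.optim.lr_scheduler.ExponentialLR, gamma=0.999)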

    def normalize_mels(self, x):
        # x is in range -10.1 to 3.1, normalize to -1 to 1
        x_min, x_max = -10.1, 3.1
        return (x - x_min) / (x_max - x_min) * 2 - 1

    def denormalize_mels(self, x):
        # Inverse of normalize_mels: map -1 to 1 back to -10.1 to 3.1
        x_min, x_max = -10.1, 3.1
        return (x + 1) / 2 * (x_max - x_min) + x_min
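
    # Round-trip check (illustrative): on the stated range the two maps are
    # exact inverses, e.g. normalize_mels(3.1) == 1.0 and
    # denormalize_mels(normalize_mels(x)) == x up to float rounding.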

    def training_step(self, batch, batch_idx):
        audios, audio_lengths = batch["audios"], batch["audio_lengths"]
        # Batch-provided features are recomputed from audio below
        features, feature_lengths = batch["features"], batch["feature_lengths"]

        audios = audios.float()
        # features = features.float().mT
        audios = audios[:, None, :]

        with torch.no_grad():
            gt_mels = self.mel_transform(audios, sample_rate=self.sampling_rate)
            features = self.feature_mel_transform(
                audios, sample_rate=self.sampling_rate
            )

        if self.downsample is not None:
            features = self.downsample(features)

        mel_lengths = audio_lengths // self.hop_length
        feature_lengths = (
            audio_lengths
            / self.sampling_rate
            * self.feature_mel_transform.sample_rate
            / self.feature_mel_transform.hop_length
            / (self.downsample.total_strides if self.downsample is not None else 1)
        ).long()
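
        # Worked example (illustrative numbers, not the project config): with
        # sampling_rate = 32000, a feature transform at 16 kHz with hop length
        # 320, and total_strides = 2, a 4 second clip yields
        # 4 * 16000 / 320 / 2 = 100 feature frames.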

        feature_masks = torch.unsqueeze(
            sequence_mask(feature_lengths, features.shape[2]), 1
        ).to(gt_mels.dtype)
        mel_masks = torch.unsqueeze(sequence_mask(mel_lengths, gt_mels.shape[2]), 1).to(
            gt_mels.dtype
        )

        if self.speaker_use_feats:
            speaker_features = self.speaker_encoder(features, feature_masks)
        else:
            speaker_features = self.speaker_encoder(gt_mels, mel_masks)

        # vq_features is 50 Hz; interpolate to the true mel length
        text_features = self.text_encoder(features, feature_masks)
        text_features, vq_loss = self.vq_encoder(text_features, feature_masks)
        text_features = F.interpolate(
            text_features, size=gt_mels.shape[2], mode="nearest"
        )

        # Scale log-mels to a roughly unit range (2.303 ≈ ln 10)
        normalized_gt_mels = gt_mels / 2.303

        # Predict
        mels = self.decoder(text_features, mel_masks, g=speaker_features)
        t = torch.tensor([0] * mels.shape[0], device=mels.device, dtype=torch.long)
        postnet_mels = self.postnet(mels, t, mel_masks)

        # L1 loss on the masked (valid) frames
        mel_loss = F.l1_loss(
            mels * mel_masks,
            normalized_gt_mels * mel_masks,
        )
        postnet_loss = F.l1_loss(
            postnet_mels * mel_masks,
            normalized_gt_mels * mel_masks,
        )

        self.log(
            "train/mel_loss",
            mel_loss,
            on_step=True,
            on_epoch=False,
            prog_bar=True,
            logger=True,
            sync_dist=True,
        )
        self.log(
            "train/postnet_loss",
            postnet_loss,
            on_step=True,
            on_epoch=False,
            prog_bar=True,
            logger=True,
            sync_dist=True,
        )
        self.log(
            "train/vq_loss",
            vq_loss,
            on_step=True,
            on_epoch=False,
            prog_bar=True,
            logger=True,
            sync_dist=True,
        )

        return vq_loss + mel_loss + postnet_loss

    def validation_step(self, batch: Any, batch_idx: int):
        audios, audio_lengths = batch["audios"], batch["audio_lengths"]
        # Batch-provided features are recomputed from audio below
        features, feature_lengths = batch["features"], batch["feature_lengths"]

        audios = audios.float()
        # features = features.float().mT
        audios = audios[:, None, :]

        gt_mels = self.mel_transform(audios, sample_rate=self.sampling_rate)
        features = self.feature_mel_transform(audios, sample_rate=self.sampling_rate)

        if self.downsample is not None:
            features = self.downsample(features)

        mel_lengths = audio_lengths // self.hop_length
        feature_lengths = (
            audio_lengths
            / self.sampling_rate
            * self.feature_mel_transform.sample_rate
            / self.feature_mel_transform.hop_length
            / (self.downsample.total_strides if self.downsample is not None else 1)
        ).long()

        feature_masks = torch.unsqueeze(
            sequence_mask(feature_lengths, features.shape[2]), 1
        ).to(gt_mels.dtype)
        mel_masks = torch.unsqueeze(sequence_mask(mel_lengths, gt_mels.shape[2]), 1).to(
            gt_mels.dtype
        )

        if self.speaker_use_feats:
            speaker_features = self.speaker_encoder(features, feature_masks)
        else:
            speaker_features = self.speaker_encoder(gt_mels, mel_masks)

        # vq_features is 50 Hz; interpolate to the true mel length
        text_features = self.text_encoder(features, feature_masks)
        text_features, vq_loss = self.vq_encoder(text_features, feature_masks)
        text_features = F.interpolate(
            text_features, size=gt_mels.shape[2], mode="nearest"
        )

        # Begin sampling
        mels = self.decoder(text_features, mel_masks, g=speaker_features)
        t = torch.tensor([0] * mels.shape[0], device=mels.device, dtype=torch.long)
        postnet_mels = self.postnet(mels, t, mel_masks)

        # Undo the ln-10 scaling applied during training
        sampled_mels = postnet_mels * 2.303
        sampled_mels = sampled_mels * mel_masks

        with torch.autocast(device_type=sampled_mels.device.type, enabled=False):
            # Run the vocoder in fp32
            fake_audios = self.vocoder.decode(sampled_mels.float())

        mel_loss = F.l1_loss(gt_mels * mel_masks, sampled_mels * mel_masks)
        self.log(
            "val/mel_loss",
            mel_loss,
            on_step=False,
            on_epoch=True,
            prog_bar=True,
            logger=True,
            sync_dist=True,
        )

        for idx, (
            mel,
            gen_mel,
            audio,
            gen_audio,
            audio_len,
        ) in enumerate(
            zip(
                gt_mels,
                sampled_mels,
                audios,
                fake_audios,
                audio_lengths,
            )
        ):
            mel_len = audio_len // self.hop_length

            image_mels = plot_mel(
                [
                    gen_mel[:, :mel_len],
                    mel[:, :mel_len],
                ],
                [
                    "Generated Spectrogram",
                    "Ground-Truth Spectrogram",
                ],
            )

            if isinstance(self.logger, WandbLogger):
                self.logger.experiment.log(
                    {
                        "reconstruction_mel": wandb.Image(image_mels, caption="mels"),
                        "wavs": [
                            wandb.Audio(
                                audio[0, :audio_len],
                                sample_rate=self.sampling_rate,
                                caption="gt",
                            ),
                            wandb.Audio(
                                gen_audio[0, :audio_len],
                                sample_rate=self.sampling_rate,
                                caption="prediction",
                            ),
                        ],
                    },
                )

            if isinstance(self.logger, TensorBoardLogger):
                self.logger.experiment.add_figure(
                    f"sample-{idx}/mels",
                    image_mels,
                    global_step=self.global_step,
                )
                self.logger.experiment.add_audio(
                    f"sample-{idx}/wavs/gt",
                    audio[0, :audio_len],
                    self.global_step,
                    sample_rate=self.sampling_rate,
                )
                self.logger.experiment.add_audio(
                    f"sample-{idx}/wavs/prediction",
                    gen_audio[0, :audio_len],
                    self.global_step,
                    sample_rate=self.sampling_rate,
                )

            plt.close(image_mels)