@@ -1,7 +1,8 @@
 import itertools
-from typing import Any, Callable
+from typing import Any, Callable, Optional
 
 import lightning as L
+import numpy as np
 import torch
 import torch.nn.functional as F
 import wandb
@@ -11,7 +12,6 @@ from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger
 from matplotlib import pyplot as plt
 from torch import nn
 from tqdm import tqdm
-from transformers import HubertModel
 
 from fish_speech.models.vq_diffusion.convnext_1d import ConvNext1DModel
 from fish_speech.models.vqgan.modules.encoders import (
@@ -22,6 +22,57 @@ from fish_speech.models.vqgan.modules.encoders import (
 from fish_speech.models.vqgan.utils import plot_mel, sequence_mask
 
 
+class ConvDownSample(nn.Module):
+    def __init__(
+        self,
+        dims: list,
+        kernel_sizes: list,
+        strides: list,
+    ):
+        super().__init__()
+
+        self.dims = dims
+        self.kernel_sizes = kernel_sizes
+        self.strides = strides
+        self.total_strides = np.prod(self.strides)
+
+        self.convs = nn.ModuleList(
+            [
+                nn.ModuleList(
+                    [
+                        nn.Conv1d(
+                            in_channels=self.dims[i],
+                            out_channels=self.dims[i + 1],
+                            kernel_size=self.kernel_sizes[i],
+                            stride=self.strides[i],
+                            padding=(self.kernel_sizes[i] - 1) // 2,
+                        ),
+                        nn.LayerNorm(self.dims[i + 1], elementwise_affine=True),
+                        nn.GELU(),
+                    ]
+                )
+                for i in range(len(self.dims) - 1)
+            ]
+        )
+
+        self.apply(self.init_weights)
+
+    def init_weights(self, m):
+        if isinstance(m, nn.Conv1d):
+            nn.init.normal_(m.weight, std=0.02)
+        elif isinstance(m, nn.LayerNorm):
+            nn.init.ones_(m.weight)
+            nn.init.zeros_(m.bias)
+
+    def forward(self, x):
+        for conv, norm, act in self.convs:
+            x = conv(x)
+            x = norm(x.mT).mT
+            x = act(x)
+
+        return x
+
+
 class VQDiffusion(L.LightningModule):
     def __init__(
         self,
@@ -37,6 +88,7 @@ class VQDiffusion(L.LightningModule):
         hop_length: int = 640,
         sample_rate: int = 32000,
         speaker_use_feats: bool = False,
+        downsample: Optional[nn.Module] = None,
     ):
         super().__init__()
 
@@ -55,6 +107,7 @@ class VQDiffusion(L.LightningModule):
         self.speaker_encoder = speaker_encoder
         self.text_encoder = text_encoder
         self.denoiser = denoiser
+        self.downsample = downsample
 
         self.vocoder = vocoder
         self.hop_length = hop_length
@@ -100,8 +153,18 @@ class VQDiffusion(L.LightningModule):
             audios, sample_rate=self.sampling_rate
         )
 
+        if self.downsample is not None:
+            features = self.downsample(features)
+
         mel_lengths = audio_lengths // self.hop_length
-        feature_lengths = audio_lengths // self.hop_length // 2
+        feature_lengths = (
+            audio_lengths
+            / self.sampling_rate
+            * self.feature_mel_transform.sample_rate
+            / self.feature_mel_transform.hop_length
+            / (self.downsample.total_strides if self.downsample is not None else 1)
+        ).long()
+
         feature_masks = torch.unsqueeze(
             sequence_mask(feature_lengths, features.shape[2]), 1
         ).to(gt_mels.dtype)
@@ -181,8 +244,18 @@ class VQDiffusion(L.LightningModule):
         gt_mels = self.mel_transform(audios, sample_rate=self.sampling_rate)
         features = self.feature_mel_transform(audios, sample_rate=self.sampling_rate)
 
+        if self.downsample is not None:
+            features = self.downsample(features)
+
         mel_lengths = audio_lengths // self.hop_length
-        feature_lengths = audio_lengths // self.hop_length // 2
+        feature_lengths = (
+            audio_lengths
+            / self.sampling_rate
+            * self.feature_mel_transform.sample_rate
+            / self.feature_mel_transform.hop_length
+            / (self.downsample.total_strides if self.downsample is not None else 1)
+        ).long()
+
         feature_masks = torch.unsqueeze(
             sequence_mask(feature_lengths, features.shape[2]), 1
         ).to(gt_mels.dtype)
@@ -222,6 +295,7 @@ class VQDiffusion(L.LightningModule):
         ).prev_sample
 
         sampled_mels = self.denormalize_mels(sampled_mels)
+        sampled_mels = sampled_mels * mel_masks
 
         with torch.autocast(device_type=sampled_mels.device.type, enabled=False):
             # Run vocoder on fp32