
Add auto upsample to vqgan

Lengyue 2 years ago
parent
commit
a0f1468be8

+ 11 - 1
fish_speech/configs/vqgan_pretrain.yaml

@@ -87,7 +87,17 @@ model:
     _target_: fish_speech.models.vqgan.modules.firefly.FireflyBase
     ckpt_path: null # You may download the pretrained vocoder and set the path here
 
-  mel_transform:
+  encode_mel_transform:
+    _target_: fish_speech.models.vqgan.spectrogram.LogMelSpectrogram
+    sample_rate: ${sample_rate}
+    n_fft: ${n_fft}
+    hop_length: ${hop_length}
+    win_length: ${win_length}
+    n_mels: ${num_mels}
+    f_min: 0.0
+    f_max: 8000.0
+
+  gt_mel_transform:
     _target_: fish_speech.models.vqgan.spectrogram.LogMelSpectrogram
     sample_rate: ${sample_rate}
     n_fft: ${n_fft}
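
The config change above splits the single mel_transform into two: encode_mel_transform caps the filterbank at f_max 8000 Hz, so the encoder always sees band-limited features, while gt_mel_transform keeps the full-band settings for the reconstruction target. A minimal sketch of the distinction, assuming LogMelSpectrogram wraps torchaudio.transforms.MelSpectrogram and using placeholder values for the interpolated config fields:

import torch
import torchaudio

sample_rate, n_fft, hop_length, n_mels = 44100, 2048, 512, 128  # placeholders

# Encoder input: filterbank capped at 8 kHz, mirroring encode_mel_transform.
encode_mel = torchaudio.transforms.MelSpectrogram(
    sample_rate=sample_rate, n_fft=n_fft, hop_length=hop_length,
    n_mels=n_mels, f_min=0.0, f_max=8000.0,
)

# Reconstruction target: f_max defaults to sample_rate / 2 (full band),
# mirroring gt_mel_transform, which keeps the original settings.
gt_mel = torchaudio.transforms.MelSpectrogram(
    sample_rate=sample_rate, n_fft=n_fft, hop_length=hop_length, n_mels=n_mels,
)

audio = torch.randn(1, sample_rate)  # one second of dummy audio
print(encode_mel(audio).shape, gt_mel(audio).shape)  # same shapes, different filterbanks

Training the decoder to map band-limited encodings onto full-band targets is what makes the "auto upsample" of the commit title work.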

+ 46 - 22
fish_speech/models/vqgan/lit_module.py

@@ -25,7 +25,8 @@ class VQGAN(L.LightningModule):
         decoder: WaveNet,
         discriminator: Discriminator,
         vocoder: nn.Module,
-        mel_transform: nn.Module,
+        encode_mel_transform: nn.Module,
+        gt_mel_transform: nn.Module,
         weight_adv: float = 1.0,
         weight_vq: float = 1.0,
         weight_mel: float = 1.0,
@@ -44,7 +45,11 @@ class VQGAN(L.LightningModule):
         self.decoder = decoder
         self.vocoder = vocoder
         self.discriminator = discriminator
-        self.mel_transform = mel_transform
+        self.encode_mel_transform = encode_mel_transform
+        self.gt_mel_transform = gt_mel_transform
+
+        # A simple linear layer to project quality to condition channels
+        self.quality_projection = nn.Linear(1, 768)
 
         # Freeze vocoder
         for param in self.vocoder.parameters():
@@ -84,6 +89,7 @@ class VQGAN(L.LightningModule):
                 self.encoder.parameters(),
                 self.quantizer.parameters(),
                 self.decoder.parameters(),
+                self.quality_projection.parameters(),
             )
         )
         optimizer_discriminator = self.optimizer_builder(
@@ -121,20 +127,27 @@ class VQGAN(L.LightningModule):
         audios = audios[:, None, :]
 
         with torch.no_grad():
-            gt_mels = self.mel_transform(audios)
+            encoded_mels = self.encode_mel_transform(audios)
+            gt_mels = self.gt_mel_transform(audios)
+            quality = ((gt_mels.mean(-1) > -8).sum(-1) - 90) / 10
+            quality = quality.unsqueeze(-1)
 
-        mel_lengths = audio_lengths // self.mel_transform.hop_length
+        mel_lengths = audio_lengths // self.gt_mel_transform.hop_length
         mel_masks = sequence_mask(mel_lengths, gt_mels.shape[2])
         mel_masks_float_conv = mel_masks[:, None, :].float()
         gt_mels = gt_mels * mel_masks_float_conv
+        encoded_mels = encoded_mels * mel_masks_float_conv
 
         # Encode
-        encoded_features = self.encoder(gt_mels) * mel_masks_float_conv
+        encoded_features = self.encoder(encoded_mels) * mel_masks_float_conv
 
         # Quantize
         vq_result = self.quantizer(encoded_features)
         loss_vq = getattr("vq_result", "loss", 0.0)
         vq_recon_features = vq_result.z * mel_masks_float_conv
+        vq_recon_features = (
+            vq_recon_features + self.quality_projection(quality)[:, :, None]
+        )
 
         # VQ Decode
         gen_mel = (
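
The quality signal computed in the hunk above is a crude bandwidth estimate: for each item in the batch, it averages the log-mels over time and counts how many mel bins stay above -8, then recenters by 90 and scales by 10. Band-limited recordings light up fewer high bins, so the count tracks the source bandwidth. A toy sketch of the arithmetic, with assumed shapes (batch, n_mels, frames):

import torch

gt_mels = torch.randn(2, 128, 100) * 3 - 6         # fake log-mels in a plausible range

active_bins = (gt_mels.mean(-1) > -8).sum(-1)      # (batch,): bins with mean log-energy > -8
quality = ((active_bins - 90) / 10).unsqueeze(-1)  # (batch, 1); e.g. 110 bins -> 2.0

# nn.Linear(1, 768) lifts the scalar to the feature width, and [:, :, None]
# broadcasts it over time so it can be added to the quantized features.
quality_projection = torch.nn.Linear(1, 768)
cond = quality_projection(quality)[:, :, None]     # (batch, 768, 1)

Because the condition is additive, the decoder can later be steered by feeding in any desired quality value, which is exactly what the validation and decode paths below do.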
@@ -233,14 +246,6 @@ class VQGAN(L.LightningModule):
             prog_bar=False,
             logger=True,
         )
-        self.log(
-            "train/generator/loss_speaker_id",
-            loss_speaker_id,
-            on_step=True,
-            on_epoch=False,
-            prog_bar=False,
-            logger=True,
-        )
 
         # Generator backward
         optim_g.zero_grad()
@@ -260,17 +265,29 @@ class VQGAN(L.LightningModule):
         audios = audios.float()
         audios = audios[:, None, :]
 
-        gt_mels = self.mel_transform(audios)
-        mel_lengths = audio_lengths // self.mel_transform.hop_length
+        encoded_mels = self.encode_mel_transform(audios)
+        gt_mels = self.gt_mel_transform(audios)
+
+        mel_lengths = audio_lengths // self.gt_mel_transform.hop_length
         mel_masks = sequence_mask(mel_lengths, gt_mels.shape[2])
         mel_masks_float_conv = mel_masks[:, None, :].float()
         gt_mels = gt_mels * mel_masks_float_conv
+        encoded_mels = encoded_mels * mel_masks_float_conv
 
         # Encode
-        encoded_features = self.encoder(gt_mels) * mel_masks_float_conv
+        encoded_features = self.encoder(encoded_mels) * mel_masks_float_conv
 
         # Quantize
         vq_recon_features = self.quantizer(encoded_features).z * mel_masks_float_conv
+        vq_recon_features = (
+            vq_recon_features
+            + self.quality_projection(
+                torch.ones(
+                    vq_recon_features.shape[0], 1, device=vq_recon_features.device
+                )
+                * 2
+            )[:, :, None]
+        )
 
         # VQ Decode
         gen_aux_mels = (
@@ -319,7 +336,7 @@ class VQGAN(L.LightningModule):
             if idx > 4:
                 break
 
-            mel_len = audio_len // self.mel_transform.hop_length
+            mel_len = audio_len // self.gt_mel_transform.hop_length
 
             image_mels = plot_mel(
                 [
@@ -386,14 +403,14 @@ class VQGAN(L.LightningModule):
     def encode(self, audios, audio_lengths):
         audios = audios.float()
 
-        gt_mels = self.mel_transform(audios)
-        mel_lengths = audio_lengths // self.mel_transform.hop_length
-        mel_masks = sequence_mask(mel_lengths, gt_mels.shape[2])
+        mels = self.encode_mel_transform(audios)
+        mel_lengths = audio_lengths // self.encode_mel_transform.hop_length
+        mel_masks = sequence_mask(mel_lengths, mels.shape[2])
         mel_masks_float_conv = mel_masks[:, None, :].float()
-        gt_mels = gt_mels * mel_masks_float_conv
+        mels = mels * mel_masks_float_conv
 
         # Encode
-        encoded_features = self.encoder(gt_mels) * mel_masks_float_conv
+        encoded_features = self.encoder(mels) * mel_masks_float_conv
         feature_lengths = mel_lengths // math.prod(self.quantizer.downsample_factor)
 
         return self.quantizer.encode(encoded_features), feature_lengths
@@ -404,6 +421,13 @@ class VQGAN(L.LightningModule):
         mel_masks_float_conv = mel_masks[:, None, :].float()
 
         z = self.quantizer.decode(indices) * mel_masks_float_conv
+        z = (
+            z
+            + self.quality_projection(torch.ones(z.shape[0], 1, device=z.device) * 2)[
+                :, :, None
+            ]
+        )
+
         gen_mel = (
             self.decoder(
                 torch.randn_like(z) * mel_masks_float_conv,
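
With encode() now built on encode_mel_transform and decode() unconditionally adding the quality-2 projection (a value corresponding to most mel bins being active, i.e. full-bandwidth ground truth), any input, including band-limited audio, is reconstructed as full-band mels. A hypothetical round trip; this assumes a trained VQGAN instance named model, and since decode()'s full signature is truncated in the hunk above, the call below is a sketch:

import torch

audios = torch.randn(1, 1, 44100)     # (batch, 1, samples), dummy one-second clip
audio_lengths = torch.tensor([44100])

with torch.no_grad():
    # encode() extracts band-limited mels internally, so the indices are
    # insensitive to missing high-frequency content in the source.
    indices, feature_lengths = model.encode(audios, audio_lengths)

    # decode() pins the quality condition to 2, asking the decoder for
    # full-band output: the "auto upsample" behaviour.
    gen_mel = model.decode(indices, feature_lengths)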

+ 113 - 0
fish_speech/models/vqgan/modules/reference.py

@@ -0,0 +1,113 @@
+from typing import Optional
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from .wavenet import WaveNet
+
+
+class ReferenceEncoder(WaveNet):
+    def __init__(
+        self,
+        input_channels: Optional[int] = None,
+        output_channels: Optional[int] = None,
+        residual_channels: int = 512,
+        residual_layers: int = 20,
+        dilation_cycle: Optional[int] = 4,
+        num_heads: int = 8,
+        latent_len: int = 4,
+    ):
+        super().__init__(
+            input_channels=input_channels,
+            residual_channels=residual_channels,
+            residual_layers=residual_layers,
+            dilation_cycle=dilation_cycle,
+        )
+
+        self.head_dim = residual_channels // num_heads
+        self.num_heads = num_heads
+
+        self.latent_len = latent_len
+        self.latent = nn.Parameter(torch.zeros(1, self.latent_len, residual_channels))
+
+        self.q = nn.Linear(residual_channels, residual_channels, bias=True)
+        self.kv = nn.Linear(residual_channels, residual_channels * 2, bias=True)
+        self.q_norm = nn.LayerNorm(self.head_dim)
+        self.k_norm = nn.LayerNorm(self.head_dim)
+        self.proj = nn.Linear(residual_channels, residual_channels)
+        self.proj_drop = nn.Dropout(0.1)
+
+        self.norm = nn.LayerNorm(residual_channels)
+        self.mlp = nn.Sequential(
+            nn.Linear(residual_channels, residual_channels * 4),
+            nn.SiLU(),
+            nn.Linear(residual_channels * 4, residual_channels),
+        )
+        self.output_projection_attn = nn.Linear(residual_channels, output_channels)
+
+        torch.nn.init.trunc_normal_(self.latent, std=0.02)
+        self.apply(self.init_weights)
+
+    def init_weights(self, m):
+        if isinstance(m, nn.Linear):
+            torch.nn.init.trunc_normal_(m.weight, std=0.02)
+            if m.bias is not None:
+                torch.nn.init.constant_(m.bias, 0)
+
+    def forward(self, x, attn_mask=None):
+        x = super().forward(x).mT
+        B, N, C = x.shape
+
+        # Calculate mask
+        if attn_mask is not None:
+            assert attn_mask.shape == (B, N) and attn_mask.dtype == torch.bool
+
+            attn_mask = attn_mask[:, None, None, :].expand(
+                B, self.num_heads, self.latent_len, N
+            )
+
+        q_latent = self.latent.expand(B, -1, -1)
+        q = (
+            self.q(q_latent)
+            .reshape(B, self.latent_len, self.num_heads, self.head_dim)
+            .transpose(1, 2)
+        )
+
+        kv = (
+            self.kv(x)
+            .reshape(B, N, 2, self.num_heads, self.head_dim)
+            .permute(2, 0, 3, 1, 4)
+        )
+        k, v = kv.unbind(0)
+
+        q, k = self.q_norm(q), self.k_norm(k)
+        x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
+
+        x = x.transpose(1, 2).reshape(B, self.latent_len, C)
+        x = self.proj(x)
+        x = self.proj_drop(x)
+
+        x = x + self.mlp(self.norm(x))
+        x = self.output_projection_attn(x)
+        x = x.mean(1)
+
+        return x
+
+
+if __name__ == "__main__":
+    with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
+        model = ReferenceEncoder(
+            input_channels=128,
+            output_channels=64,
+            residual_channels=384,
+            residual_layers=20,
+            dilation_cycle=4,
+            num_heads=8,
+        )
+        x = torch.randn(4, 128, 64)
+        mask = torch.ones(4, 64, dtype=torch.bool)
+        y = model(x, mask)
+        print(y.shape)
+        loss = F.mse_loss(y, torch.randn(4, 64))
+        loss.backward()
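
fish_speech/models/vqgan/modules/reference.py is new in this commit and, as far as this diff shows, not yet wired into lit_module.py. It pools a variable-length WaveNet feature sequence into one fixed-size reference vector: latent_len learned query tokens cross-attend to the frames (an attention-pooling pattern similar to Perceiver-style latent queries), pass through a small MLP, and are mean-pooled after the output projection. A usage sketch with real padding, which the __main__ block above (all-True mask) does not exercise:

import torch

model = ReferenceEncoder(
    input_channels=128, output_channels=64,
    residual_channels=384, residual_layers=20,
    dilation_cycle=4, num_heads=8,
)

mels = torch.randn(2, 128, 100)     # (batch, n_mels, frames)
lengths = torch.tensor([100, 37])   # second item is mostly padding

# forward() expects a (batch, frames) bool mask: True = attend, False = padding.
mask = torch.arange(100)[None, :] < lengths[:, None]
ref_embedding = model(mels, attn_mask=mask)  # (batch, output_channels)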