
Add transformer flow

Lengyue, 2 years ago
commit 3c37033972

+ 3 - 1
fish_speech/configs/hubert_vq.yaml

@@ -60,7 +60,9 @@ model:
     filter_channels: 768
     n_heads: 4
     n_layers: 8
-    n_layers_q: 16
+    n_layers_q: 8
+    n_layers_flow: 8
+    n_flows: 4
     n_layers_spk: 6
     kernel_size: 3
     p_dropout: 0.1
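
Note: this config change halves the posterior encoder depth (n_layers_q: 16 → 8) and spends the freed budget on the new flow: n_flows: 4 transformer coupling steps, each wrapping an n_layers_flow = 8 layer encoder. Both new keys feed straight into the TransformerCouplingBlock constructed in models.py below.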

+ 10 - 15
fish_speech/models/vqgan/lit_module.py

@@ -14,7 +14,7 @@ from fish_speech.models.vqgan.losses import (
     discriminator_loss,
     feature_loss,
     generator_loss,
-    kl_loss_normal,
+    kl_loss,
 )
 from fish_speech.models.vqgan.modules.discriminator import EnsembleDiscriminator
 from fish_speech.models.vqgan.modules.models import SynthesizerTrn
@@ -102,8 +102,8 @@ class VQGAN(L.LightningModule):
             ids_slice,
             x_mask,
             y_mask,
-            (z_q_audio, z_p),
-            (m_p_text, logs_p_text),
+            (z_q, z_p),
+            (m_p, logs_p),
             (m_q, logs_q),
         ) = self.generator(features, feature_lengths, gt_mels)
 
@@ -140,20 +140,15 @@ class VQGAN(L.LightningModule):
             loss_mel = F.l1_loss(y_mel, y_hat_mel)
             loss_adv, _ = generator_loss(y_d_hat_g)
             loss_fm = feature_loss(fmap_r, fmap_g)
-            # x_mask,
-            # y_mask,
-            # (z_q_audio, z_p),
-            # (m_p_text, logs_p_text),
-            # (m_q, logs_q),
-            loss_kl = kl_loss_normal(
-                m_q,
-                logs_q,
-                m_p_text,
-                logs_p_text,
-                x_mask,
+            loss_kl = kl_loss(
+                z_p=z_p,
+                logs_q=logs_q,
+                m_p=m_p,
+                logs_p=logs_p,
+                z_mask=x_mask,
             )
 
-            loss_gen_all = loss_mel * 45 + loss_fm + loss_adv + loss_kl * 0.05
+            loss_gen_all = loss_mel * 45 + loss_fm + loss_adv + loss_kl * 1
 
         self.log(
             "train/generator/loss",

+ 0 - 24
fish_speech/models/vqgan/losses.py

@@ -66,27 +66,3 @@ def kl_loss(
     kl = torch.sum(kl * z_mask)
     l = kl / torch.sum(z_mask)
     return l
-
-
-def kl_loss_normal(
-    m_q: torch.Tensor,
-    logs_q: torch.Tensor,
-    m_p: torch.Tensor,
-    logs_p: torch.Tensor,
-    z_mask: torch.Tensor,
-):
-    """
-    z_p, logs_q: [b, h, t_t]
-    m_p, logs_p: [b, h, t_t]
-    """
-    m_q = m_q.float()
-    logs_q = logs_q.float()
-    m_p = m_p.float()
-    logs_p = logs_p.float()
-    z_mask = z_mask.float()
-
-    kl = logs_p - logs_q - 0.5
-    kl += 0.5 * (torch.exp(2.0 * logs_q) + (m_q - m_p) ** 2) * torch.exp(-2.0 * logs_p)
-    kl = torch.sum(kl * z_mask)
-    l = kl / torch.sum(z_mask)
-    return l
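
Only the tail of the surviving kl_loss is visible in this hunk. For reference, a sketch of the whole function, assuming the file matches the upstream VITS implementation this code follows (the last three lines below are exactly the context shown above):

    import torch

    def kl_loss(
        z_p: torch.Tensor,     # flow-mapped posterior sample, [b, h, t]
        logs_q: torch.Tensor,  # posterior log-std, [b, h, t]
        m_p: torch.Tensor,     # prior mean, [b, h, t]
        logs_p: torch.Tensor,  # prior log-std, [b, h, t]
        z_mask: torch.Tensor,  # validity mask, [b, 1, t]
    ):
        z_p, logs_q = z_p.float(), logs_q.float()
        m_p, logs_p = m_p.float(), logs_p.float()
        z_mask = z_mask.float()

        # KL(q || p) evaluated at the sample z_p ~ q, per element
        kl = logs_p - logs_q - 0.5
        kl += 0.5 * (z_p - m_p) ** 2 * torch.exp(-2.0 * logs_p)
        kl = torch.sum(kl * z_mask)
        l = kl / torch.sum(z_mask)
        return l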

+ 11 - 8
fish_speech/models/vqgan/modules/encoders.py

@@ -22,7 +22,6 @@ class TextEncoder(nn.Module):
         kernel_size: int,
         dropout: float,
         gin_channels=0,
-        lang_channels=0,
         speaker_cond_layer=0,
     ):
         """Text Encoder for VITS model.
@@ -37,7 +36,6 @@ class TextEncoder(nn.Module):
             kernel_size (int): Kernel size for the FFN layers in Transformer network.
             dropout (float): Dropout rate for the Transformer layers.
             gin_channels (int, optional): Number of channels for speaker embedding. Defaults to 0.
-            lang_channels (int, optional): Number of channels for language embedding. Defaults to 0.
         """
         super().__init__()
         self.out_channels = out_channels
@@ -58,7 +56,6 @@ class TextEncoder(nn.Module):
             dropout=dropout,
             window_size=4,
             gin_channels=gin_channels,
-            lang_channels=lang_channels,
             speaker_cond_layer=speaker_cond_layer,
         )
         self.proj = nn.Linear(hidden_channels, out_channels * 2)
@@ -68,7 +65,7 @@ class TextEncoder(nn.Module):
         x: torch.Tensor,
         x_lengths: torch.Tensor,
         g: torch.Tensor = None,
-        lang: torch.Tensor = None,
+        noise_scale: float = 1,
     ):
         """
         Shapes:
@@ -79,11 +76,11 @@ class TextEncoder(nn.Module):
         x = self.emb(x).mT  # * math.sqrt(self.hidden_channels)  # [b, h, t]
         x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
 
-        x = self.encoder(x, x_mask, g=g, lang=lang)
+        x = self.encoder(x, x_mask, g=g)
         stats = self.proj(x.mT).mT * x_mask
 
         m, logs = torch.split(stats, self.out_channels, dim=1)
-        z = m + torch.randn_like(m) * torch.exp(logs) * x_mask
+        z = m + torch.randn_like(m) * torch.exp(logs) * x_mask * noise_scale
         return z, m, logs, x, x_mask
 
 
@@ -126,7 +123,13 @@ class PosteriorEncoder(nn.Module):
         )
         self.proj = nn.Linear(hidden_channels, out_channels * 2)
 
-    def forward(self, x: torch.Tensor, x_lengths: torch.Tensor, g=None):
+    def forward(
+        self,
+        x: torch.Tensor,
+        x_lengths: torch.Tensor,
+        g: torch.Tensor,
+        noise_scale: float = 1,
+    ):
         """
         Shapes:
             - x: :math:`[B, C, T]`
@@ -138,7 +141,7 @@ class PosteriorEncoder(nn.Module):
         x = self.enc(x, x_mask, g=g)
         stats = self.proj(x.mT).mT * x_mask
         m, logs = torch.split(stats, self.out_channels, dim=1)
-        z = m + torch.randn_like(m) * torch.exp(logs) * x_mask
+        z = m + torch.randn_like(m) * torch.exp(logs) * x_mask * noise_scale
         return z, m, logs, x_mask
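
Both TextEncoder and PosteriorEncoder now accept a noise_scale that scales only the stochastic half of the reparameterized sample z = m + ε · exp(logs) · mask · noise_scale. At the default of 1 training is unchanged; at inference a smaller value (0.35 in models.py below) samples closer to the predicted mean, trading diversity for stability. A standalone illustration, with made-up shapes that touch nothing in the repo:

    import torch

    # Hypothetical shapes: [batch, channels, frames]
    m = torch.zeros(1, 192, 50)     # predicted mean
    logs = torch.zeros(1, 192, 50)  # predicted log-std (std = 1 here)
    x_mask = torch.ones(1, 1, 50)

    for noise_scale in (1.0, 0.35, 0.0):
        z = m + torch.randn_like(m) * torch.exp(logs) * x_mask * noise_scale
        print(noise_scale, round(z.std().item(), 3))  # sample std tracks noise_scale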
 
 

+ 194 - 0
fish_speech/models/vqgan/modules/flow.py

@@ -0,0 +1,194 @@
+import torch
+from torch import nn
+
+from fish_speech.models.vqgan.modules.modules import Flip
+from fish_speech.models.vqgan.modules.normalization import LayerNorm
+from fish_speech.models.vqgan.modules.transformer import FFN, MultiHeadAttention
+
+
+class TransformerCouplingBlock(nn.Module):
+    def __init__(
+        self,
+        channels,
+        hidden_channels,
+        filter_channels,
+        n_heads,
+        n_layers,
+        kernel_size,
+        p_dropout,
+        n_flows=4,
+        gin_channels=0,
+    ):
+        super().__init__()
+        self.channels = channels
+        self.hidden_channels = hidden_channels
+        self.kernel_size = kernel_size
+        self.n_layers = n_layers
+        self.n_flows = n_flows
+        self.gin_channels = gin_channels
+
+        self.flows = nn.ModuleList()
+
+        for i in range(n_flows):
+            self.flows.append(
+                TransformerCouplingLayer(
+                    channels,
+                    hidden_channels,
+                    kernel_size,
+                    n_layers,
+                    n_heads,
+                    p_dropout,
+                    filter_channels,
+                    mean_only=True,
+                    gin_channels=self.gin_channels,
+                )
+            )
+            self.flows.append(Flip())
+
+    def forward(self, x, x_mask, g=None, reverse=False):
+        if not reverse:
+            for flow in self.flows:
+                x, _ = flow(x, x_mask, g=g, reverse=reverse)
+        else:
+            for flow in reversed(self.flows):
+                x = flow(x, x_mask, g=g, reverse=reverse)
+        return x
+
+
+class TransformerCouplingLayer(nn.Module):
+    def __init__(
+        self,
+        channels,
+        hidden_channels,
+        kernel_size,
+        n_layers,
+        n_heads,
+        p_dropout=0,
+        filter_channels=0,
+        mean_only=False,
+        gin_channels=0,
+    ):
+        super().__init__()
+
+        assert channels % 2 == 0, "channels should be divisible by 2"
+
+        self.channels = channels
+        self.hidden_channels = hidden_channels
+        self.kernel_size = kernel_size
+        self.n_layers = n_layers
+        self.half_channels = channels // 2
+        self.mean_only = mean_only
+
+        self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
+        self.enc = Encoder(
+            hidden_channels,
+            filter_channels,
+            n_heads,
+            n_layers,
+            kernel_size,
+            p_dropout,
+            gin_channels=gin_channels,
+        )
+        self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
+        self.post.weight.data.zero_()
+        self.post.bias.data.zero_()
+
+    def forward(self, x, x_mask, g=None, reverse=False):
+        x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
+        h = self.pre(x0) * x_mask
+        h = self.enc(h, x_mask, g=g)
+        stats = self.post(h) * x_mask
+        if not self.mean_only:
+            m, logs = torch.split(stats, [self.half_channels] * 2, 1)
+        else:
+            m = stats
+            logs = torch.zeros_like(m)
+
+        if not reverse:
+            x1 = m + x1 * torch.exp(logs) * x_mask
+            x = torch.cat([x0, x1], 1)
+            logdet = torch.sum(logs, [1, 2])
+            return x, logdet
+        else:
+            x1 = (x1 - m) * torch.exp(-logs) * x_mask
+            x = torch.cat([x0, x1], 1)
+            return x
+
+
+class Encoder(nn.Module):
+    def __init__(
+        self,
+        hidden_channels,
+        filter_channels,
+        n_heads,
+        n_layers,
+        kernel_size=1,
+        p_dropout=0.0,
+        window_size=4,
+        gin_channels=512,
+        cond_layer_idx=2,
+    ):
+        super().__init__()
+        self.hidden_channels = hidden_channels
+        self.filter_channels = filter_channels
+        self.n_heads = n_heads
+        self.n_layers = n_layers
+        self.kernel_size = kernel_size
+        self.p_dropout = p_dropout
+        self.window_size = window_size
+
+        self.spk_emb_linear = nn.Linear(gin_channels, self.hidden_channels)
+        self.cond_layer_idx = cond_layer_idx
+
+        assert (
+            self.cond_layer_idx < self.n_layers
+        ), "cond_layer_idx should be less than n_layers"
+
+        self.drop = nn.Dropout(p_dropout)
+        self.attn_layers = nn.ModuleList()
+        self.norm_layers_1 = nn.ModuleList()
+        self.ffn_layers = nn.ModuleList()
+        self.norm_layers_2 = nn.ModuleList()
+        for i in range(self.n_layers):
+            self.attn_layers.append(
+                MultiHeadAttention(
+                    hidden_channels,
+                    hidden_channels,
+                    n_heads,
+                    p_dropout=p_dropout,
+                    window_size=window_size,
+                )
+            )
+            self.norm_layers_1.append(LayerNorm(hidden_channels))
+            self.ffn_layers.append(
+                FFN(
+                    hidden_channels,
+                    hidden_channels,
+                    filter_channels,
+                    kernel_size,
+                    p_dropout=p_dropout,
+                )
+            )
+            self.norm_layers_2.append(LayerNorm(hidden_channels))
+
+    def forward(self, x, x_mask, g=None):
+        attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
+        x = x * x_mask
+
+        for i in range(self.n_layers):
+            if i == self.cond_layer_idx and g is not None:
+                g = self.spk_emb_linear(g.transpose(1, 2))
+                g = g.transpose(1, 2)
+                x = x + g
+                x = x * x_mask
+            y = self.attn_layers[i](x, x, attn_mask)
+            y = self.drop(y)
+            x = self.norm_layers_1[i](x + y)
+
+            y = self.ffn_layers[i](x, x_mask)
+            y = self.drop(y)
+            x = self.norm_layers_2[i](x + y)
+
+        x = x * x_mask
+
+        return x
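
A quick way to sanity-check the new block: the mean-only couplings are invertible by construction (and start at the identity thanks to the zero-initialized post projections), so a forward pass followed by a reverse pass must reconstruct the input. A hedged sketch; the channel widths (192 / 512) are placeholder assumptions, while kernel_size=5, n_flows=4 and n_layers=8 mirror this commit's config and models.py:

    import torch
    from fish_speech.models.vqgan.modules.flow import TransformerCouplingBlock

    flow = TransformerCouplingBlock(
        channels=192,          # inter_channels (assumed)
        hidden_channels=192,   # assumed
        filter_channels=768,
        n_heads=4,
        n_layers=8,            # n_layers_flow
        kernel_size=5,
        p_dropout=0.0,
        n_flows=4,
        gin_channels=512,      # assumed speaker-embedding width
    ).eval()

    x = torch.randn(1, 192, 50)
    x_mask = torch.ones(1, 1, 50)
    g = torch.randn(1, 512, 1)  # speaker embedding, broadcast over time

    with torch.no_grad():
        z = flow(x, x_mask, g=g, reverse=False)     # posterior -> prior space
        x_rec = flow(z, x_mask, g=g, reverse=True)  # prior -> posterior space

    assert torch.allclose(x, x_rec, atol=1e-4)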

+ 28 - 6
fish_speech/models/vqgan/modules/models.py

@@ -7,6 +7,7 @@ from fish_speech.models.vqgan.modules.encoders import (
     SpeakerEncoder,
     TextEncoder,
 )
+from fish_speech.models.vqgan.modules.flow import TransformerCouplingBlock
 from fish_speech.models.vqgan.utils import rand_slice_segments
 
 
@@ -17,6 +18,7 @@ class SynthesizerTrn(nn.Module):
 
     def __init__(
         self,
+        *,
         in_channels,
         spec_channels,
         segment_size,
@@ -25,8 +27,10 @@ class SynthesizerTrn(nn.Module):
         filter_channels,
         n_heads,
         n_layers,
+        n_flows,
         n_layers_q,
         n_layers_spk,
+        n_layers_flow,
         kernel_size,
         p_dropout,
         speaker_cond_layer,
@@ -62,6 +66,17 @@ class SynthesizerTrn(nn.Module):
             num_layers=n_layers_spk,
             p_dropout=p_dropout,
         )
+        self.flow = TransformerCouplingBlock(
+            channels=inter_channels,
+            hidden_channels=hidden_channels,
+            filter_channels=filter_channels,
+            n_heads=n_heads,
+            n_layers=n_layers_flow,
+            kernel_size=5,
+            p_dropout=p_dropout,
+            n_flows=n_flows,
+            gin_channels=gin_channels,
+        )
         self.enc_q = PosteriorEncoder(
             spec_channels,
             inter_channels,
@@ -84,8 +99,11 @@ class SynthesizerTrn(nn.Module):
 
     def forward(self, x, x_lengths, y):
         g = self.enc_spk(y, x_lengths)
-        z_p, m_p, logs_p, _, x_mask = self.enc_p(x, x_lengths, g=g)
+
+        _, m_p, logs_p, _, x_mask = self.enc_p(x, x_lengths, g=g)
         z_q, m_q, logs_q, y_mask = self.enc_q(y, x_lengths, g=g)
+        z_p = self.flow(z_q, y_mask, g=g, reverse=False)
+
         z_slice, ids_slice = rand_slice_segments(z_q, x_lengths, self.segment_size)
         o = self.dec(z_slice, g=g)
 
@@ -99,17 +117,21 @@ class SynthesizerTrn(nn.Module):
             (m_q, logs_q),
         )
 
-    def infer(self, x, x_lengths, y, max_len=None):
+    def infer(self, x, x_lengths, y, max_len=None, noise_scale=0.35):
         g = self.enc_spk(y, x_lengths)
-        z_p, m_p, logs_p, h_text, x_mask = self.enc_p(x, x_lengths, g=g)
-        # z_p_audio, m_p_audio, logs_p_audio = self.flow(z_p_text, m_p_text, logs_p_text, x_mask, g=g, reverse=True)
+        z_p, m_p, logs_p, h_text, x_mask = self.enc_p(
+            x, x_lengths, g=g, noise_scale=noise_scale
+        )
+        z_p = self.flow(z_p, x_mask, g=g, reverse=True)
 
         o = self.dec((z_p * x_mask)[:, :, :max_len], g=g)
         return o
 
-    def reconstruct(self, x, x_lengths, max_len=None):
+    def reconstruct(self, x, x_lengths, max_len=None, noise_scale=0.35):
         g = self.enc_spk(x, x_lengths)
-        z_q, m_q, logs_q, x_mask = self.enc_q(x, x_lengths, g=g)
+        z_q, m_q, logs_q, x_mask = self.enc_q(
+            x, x_lengths, g=g, noise_scale=noise_scale
+        )
         o = self.dec((z_q * x_mask)[:, :, :max_len], g=g)
 
         return o
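
Net effect on the two paths: during training the posterior sample is pushed through the flow before the KL term, while inference samples the text prior with reduced noise and pulls it back through the inverse flow before decoding. Note the decoder still trains on slices of z_q, so the flow shapes training only through the KL term:

    \text{train:}\;\; z_q \sim \mathcal{N}(\mu_q, \sigma_q^2),\;\; z_p = f_\theta(z_q)\ \text{enters}\ \mathcal{L}_{\mathrm{KL}},\;\; o = \mathrm{dec}(\mathrm{slice}(z_q))
    \text{infer:}\;\; z_p \sim \mathcal{N}(\mu_p, (s\,\sigma_p)^2),\; s = 0.35,\;\; o = \mathrm{dec}(f_\theta^{-1}(z_p))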

+ 6 - 2
fish_speech/models/vqgan/modules/transformer.py

@@ -23,10 +23,13 @@ class RelativePositionTransformer(nn.Module):
         dropout=0.0,
         window_size=4,
         gin_channels=0,
-        lang_channels=0,
         speaker_cond_layer=0,
     ):
         super().__init__()
+        assert (
+            out_channels == hidden_channels
+        ), "out_channels must be equal to hidden_channels"
+
         self.n_layers = n_layers
         self.speaker_cond_layer = speaker_cond_layer
 
@@ -56,6 +59,7 @@ class RelativePositionTransformer(nn.Module):
                 )
             )
             self.norm_layers_2.append(LayerNorm(hidden_channels))
+
         if gin_channels != 0:
             self.cond = nn.Linear(gin_channels, hidden_channels)
 
@@ -64,7 +68,6 @@ class RelativePositionTransformer(nn.Module):
         x: torch.Tensor,
         x_mask: torch.Tensor,
         g: torch.Tensor = None,
-        lang: torch.Tensor = None,
     ):
         attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
         x = x * x_mask
@@ -75,6 +78,7 @@ class RelativePositionTransformer(nn.Module):
                 # ! g = torch.detach(g)
                 x = x + self.cond(g.mT).mT
                 x = x * x_mask
+
             y = self.attn_layers[i](x, x, attn_mask)
             y = self.drop(y)
             x = self.norm_layers_1[i](x + y)
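
Finally, the transformer cleanup drops the now-unused language conditioning (lang_channels / lang) and adds a guard: presumably because the residual stream keeps a single width end to end, out_channels must equal hidden_channels, and the new assert fails fast on a mismatched config instead of surfacing as a shape error deep in the stack.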