Procházet zdrojové kódy

Init synthesizer trn and test it

Lengyue před 1 rokem
rodič
revize
ec55ce0499

+ 350 - 0
fish_speech/models/vits_decoder/modules/attentions.py

@@ -0,0 +1,350 @@
+import math
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+from torch.nn.utils import remove_weight_norm, weight_norm
+
+from fish_speech.models.vits_decoder.modules import commons
+
+from .modules import LayerNorm
+
+
+class Encoder(nn.Module):
+    def __init__(
+        self,
+        hidden_channels,
+        filter_channels,
+        n_heads,
+        n_layers,
+        kernel_size=1,
+        p_dropout=0.0,
+        window_size=4,
+        isflow=False,
+        gin_channels=0,
+    ):
+        super().__init__()
+        self.hidden_channels = hidden_channels
+        self.filter_channels = filter_channels
+        self.n_heads = n_heads
+        self.n_layers = n_layers
+        self.kernel_size = kernel_size
+        self.p_dropout = p_dropout
+        self.window_size = window_size
+
+        self.drop = nn.Dropout(p_dropout)
+        self.attn_layers = nn.ModuleList()
+        self.norm_layers_1 = nn.ModuleList()
+        self.ffn_layers = nn.ModuleList()
+        self.norm_layers_2 = nn.ModuleList()
+        for i in range(self.n_layers):
+            self.attn_layers.append(
+                MultiHeadAttention(
+                    hidden_channels,
+                    hidden_channels,
+                    n_heads,
+                    p_dropout=p_dropout,
+                    window_size=window_size,
+                )
+            )
+            self.norm_layers_1.append(LayerNorm(hidden_channels))
+            self.ffn_layers.append(
+                FFN(
+                    hidden_channels,
+                    hidden_channels,
+                    filter_channels,
+                    kernel_size,
+                    p_dropout=p_dropout,
+                )
+            )
+            self.norm_layers_2.append(LayerNorm(hidden_channels))
+
+        if isflow:
+            cond_layer = torch.nn.Conv1d(
+                gin_channels, 2 * hidden_channels * n_layers, 1
+            )
+            self.cond_pre = torch.nn.Conv1d(hidden_channels, 2 * hidden_channels, 1)
+            self.cond_layer = weight_norm(cond_layer, "weight")
+            self.gin_channels = gin_channels
+
+    def forward(self, x, x_mask, g=None):
+        attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
+        x = x * x_mask
+        if g is not None:
+            g = self.cond_layer(g)
+
+        for i in range(self.n_layers):
+            if g is not None:
+                x = self.cond_pre(x)
+                cond_offset = i * 2 * self.hidden_channels
+                g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :]
+                x = commons.fused_add_tanh_sigmoid_multiply(
+                    x, g_l, torch.IntTensor([self.hidden_channels])
+                )
+            y = self.attn_layers[i](x, x, attn_mask)
+            y = self.drop(y)
+            x = self.norm_layers_1[i](x + y)
+
+            y = self.ffn_layers[i](x, x_mask)
+            y = self.drop(y)
+            x = self.norm_layers_2[i](x + y)
+        x = x * x_mask
+        return x
+
+
+class MultiHeadAttention(nn.Module):
+    def __init__(
+        self,
+        channels,
+        out_channels,
+        n_heads,
+        p_dropout=0.0,
+        window_size=None,
+        heads_share=True,
+        block_length=None,
+        proximal_bias=False,
+        proximal_init=False,
+    ):
+        super().__init__()
+        assert channels % n_heads == 0
+
+        self.channels = channels
+        self.out_channels = out_channels
+        self.n_heads = n_heads
+        self.p_dropout = p_dropout
+        self.window_size = window_size
+        self.heads_share = heads_share
+        self.block_length = block_length
+        self.proximal_bias = proximal_bias
+        self.proximal_init = proximal_init
+        self.attn = None
+
+        self.k_channels = channels // n_heads
+        self.conv_q = nn.Conv1d(channels, channels, 1)
+        self.conv_k = nn.Conv1d(channels, channels, 1)
+        self.conv_v = nn.Conv1d(channels, channels, 1)
+        self.conv_o = nn.Conv1d(channels, out_channels, 1)
+        self.drop = nn.Dropout(p_dropout)
+
+        if window_size is not None:
+            n_heads_rel = 1 if heads_share else n_heads
+            rel_stddev = self.k_channels**-0.5
+            self.emb_rel_k = nn.Parameter(
+                torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
+                * rel_stddev
+            )
+            self.emb_rel_v = nn.Parameter(
+                torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
+                * rel_stddev
+            )
+
+        nn.init.xavier_uniform_(self.conv_q.weight)
+        nn.init.xavier_uniform_(self.conv_k.weight)
+        nn.init.xavier_uniform_(self.conv_v.weight)
+        if proximal_init:
+            with torch.no_grad():
+                self.conv_k.weight.copy_(self.conv_q.weight)
+                self.conv_k.bias.copy_(self.conv_q.bias)
+
+    def forward(self, x, c, attn_mask=None):
+        q = self.conv_q(x)
+        k = self.conv_k(c)
+        v = self.conv_v(c)
+
+        x, self.attn = self.attention(q, k, v, mask=attn_mask)
+
+        x = self.conv_o(x)
+        return x
+
+    def attention(self, query, key, value, mask=None):
+        # reshape [b, d, t] -> [b, n_h, t, d_k]
+        b, d, t_s, t_t = (*key.size(), query.size(2))
+        query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
+        key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
+        value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
+
+        scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))
+        if self.window_size is not None:
+            assert (
+                t_s == t_t
+            ), "Relative attention is only available for self-attention."
+            key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
+            rel_logits = self._matmul_with_relative_keys(
+                query / math.sqrt(self.k_channels), key_relative_embeddings
+            )
+            scores_local = self._relative_position_to_absolute_position(rel_logits)
+            scores = scores + scores_local
+        if self.proximal_bias:
+            assert t_s == t_t, "Proximal bias is only available for self-attention."
+            scores = scores + self._attention_bias_proximal(t_s).to(
+                device=scores.device, dtype=scores.dtype
+            )
+        if mask is not None:
+            scores = scores.masked_fill(mask == 0, -1e4)
+            if self.block_length is not None:
+                assert (
+                    t_s == t_t
+                ), "Local attention is only available for self-attention."
+                block_mask = (
+                    torch.ones_like(scores)
+                    .triu(-self.block_length)
+                    .tril(self.block_length)
+                )
+                scores = scores.masked_fill(block_mask == 0, -1e4)
+        p_attn = F.softmax(scores, dim=-1)  # [b, n_h, t_t, t_s]
+        p_attn = self.drop(p_attn)
+        output = torch.matmul(p_attn, value)
+        if self.window_size is not None:
+            relative_weights = self._absolute_position_to_relative_position(p_attn)
+            value_relative_embeddings = self._get_relative_embeddings(
+                self.emb_rel_v, t_s
+            )
+            output = output + self._matmul_with_relative_values(
+                relative_weights, value_relative_embeddings
+            )
+        output = (
+            output.transpose(2, 3).contiguous().view(b, d, t_t)
+        )  # [b, n_h, t_t, d_k] -> [b, d, t_t]
+        return output, p_attn
+
+    def _matmul_with_relative_values(self, x, y):
+        """
+        x: [b, h, l, m]
+        y: [h or 1, m, d]
+        ret: [b, h, l, d]
+        """
+        ret = torch.matmul(x, y.unsqueeze(0))
+        return ret
+
+    def _matmul_with_relative_keys(self, x, y):
+        """
+        x: [b, h, l, d]
+        y: [h or 1, m, d]
+        ret: [b, h, l, m]
+        """
+        ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
+        return ret
+
+    def _get_relative_embeddings(self, relative_embeddings, length):
+        max_relative_position = 2 * self.window_size + 1
+        # Pad first before slice to avoid using cond ops.
+        pad_length = max(length - (self.window_size + 1), 0)
+        slice_start_position = max((self.window_size + 1) - length, 0)
+        slice_end_position = slice_start_position + 2 * length - 1
+        if pad_length > 0:
+            padded_relative_embeddings = F.pad(
+                relative_embeddings,
+                commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]),
+            )
+        else:
+            padded_relative_embeddings = relative_embeddings
+        used_relative_embeddings = padded_relative_embeddings[
+            :, slice_start_position:slice_end_position
+        ]
+        return used_relative_embeddings
+
+    def _relative_position_to_absolute_position(self, x):
+        """
+        x: [b, h, l, 2*l-1]
+        ret: [b, h, l, l]
+        """
+        batch, heads, length, _ = x.size()
+        # Concat columns of pad to shift from relative to absolute indexing.
+        x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))
+
+        # Concat extra elements so to add up to shape (len+1, 2*len-1).
+        x_flat = x.view([batch, heads, length * 2 * length])
+        x_flat = F.pad(
+            x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]])
+        )
+
+        # Reshape and slice out the padded elements.
+        x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[
+            :, :, :length, length - 1 :
+        ]
+        return x_final
+
+    def _absolute_position_to_relative_position(self, x):
+        """
+        x: [b, h, l, l]
+        ret: [b, h, l, 2*l-1]
+        """
+        batch, heads, length, _ = x.size()
+        # pad along column
+        x = F.pad(
+            x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]])
+        )
+        x_flat = x.view([batch, heads, length**2 + length * (length - 1)])
+        # add 0's in the beginning that will skew the elements after reshape
+        x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
+        x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
+        return x_final
+
+    def _attention_bias_proximal(self, length):
+        """Bias for self-attention to encourage attention to close positions.
+        Args:
+          length: an integer scalar.
+        Returns:
+          a Tensor with shape [1, 1, length, length]
+        """
+        r = torch.arange(length, dtype=torch.float32)
+        diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
+        return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
+
+
+class FFN(nn.Module):
+    def __init__(
+        self,
+        in_channels,
+        out_channels,
+        filter_channels,
+        kernel_size,
+        p_dropout=0.0,
+        activation=None,
+        causal=False,
+    ):
+        super().__init__()
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.filter_channels = filter_channels
+        self.kernel_size = kernel_size
+        self.p_dropout = p_dropout
+        self.activation = activation
+        self.causal = causal
+
+        if causal:
+            self.padding = self._causal_padding
+        else:
+            self.padding = self._same_padding
+
+        self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size)
+        self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size)
+        self.drop = nn.Dropout(p_dropout)
+
+    def forward(self, x, x_mask):
+        x = self.conv_1(self.padding(x * x_mask))
+        if self.activation == "gelu":
+            x = x * torch.sigmoid(1.702 * x)
+        else:
+            x = torch.relu(x)
+        x = self.drop(x)
+        x = self.conv_2(self.padding(x * x_mask))
+        return x * x_mask
+
+    def _causal_padding(self, x):
+        if self.kernel_size == 1:
+            return x
+        pad_l = self.kernel_size - 1
+        pad_r = 0
+        padding = [[0, 0], [0, 0], [pad_l, pad_r]]
+        x = F.pad(x, commons.convert_pad_shape(padding))
+        return x
+
+    def _same_padding(self, x):
+        if self.kernel_size == 1:
+            return x
+        pad_l = (self.kernel_size - 1) // 2
+        pad_r = self.kernel_size // 2
+        padding = [[0, 0], [0, 0], [pad_l, pad_r]]
+        x = F.pad(x, commons.convert_pad_shape(padding))
+        return x

+ 190 - 0
fish_speech/models/vits_decoder/modules/commons.py

@@ -0,0 +1,190 @@
+import math
+
+import torch
+from torch.nn import functional as F
+
+
+def init_weights(m, mean=0.0, std=0.01):
+    classname = m.__class__.__name__
+    if classname.find("Conv") != -1:
+        m.weight.data.normal_(mean, std)
+
+
+def get_padding(kernel_size, dilation=1):
+    return int((kernel_size * dilation - dilation) / 2)
+
+
+def convert_pad_shape(pad_shape):
+    l = pad_shape[::-1]
+    pad_shape = [item for sublist in l for item in sublist]
+    return pad_shape
+
+
+def intersperse(lst, item):
+    result = [item] * (len(lst) * 2 + 1)
+    result[1::2] = lst
+    return result
+
+
+def kl_divergence(m_p, logs_p, m_q, logs_q):
+    """KL(P||Q)"""
+    kl = (logs_q - logs_p) - 0.5
+    kl += (
+        0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q)
+    )
+    return kl
+
+
+def rand_gumbel(shape):
+    """Sample from the Gumbel distribution, protect from overflows."""
+    uniform_samples = torch.rand(shape) * 0.99998 + 0.00001
+    return -torch.log(-torch.log(uniform_samples))
+
+
+def rand_gumbel_like(x):
+    g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device)
+    return g
+
+
+def slice_segments(x, ids_str, segment_size=4):
+    ret = torch.zeros_like(x[:, :, :segment_size])
+    for i in range(x.size(0)):
+        idx_str = ids_str[i]
+        idx_end = idx_str + segment_size
+        ret[i] = x[i, :, idx_str:idx_end]
+    return ret
+
+
+def rand_slice_segments(x, x_lengths=None, segment_size=4):
+    b, d, t = x.size()
+    if x_lengths is None:
+        x_lengths = t
+    ids_str_max = x_lengths - segment_size + 1
+    ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
+    ret = slice_segments(x, ids_str, segment_size)
+    return ret, ids_str
+
+
+def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4):
+    position = torch.arange(length, dtype=torch.float)
+    num_timescales = channels // 2
+    log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / (
+        num_timescales - 1
+    )
+    inv_timescales = min_timescale * torch.exp(
+        torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment
+    )
+    scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1)
+    signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0)
+    signal = F.pad(signal, [0, 0, 0, channels % 2])
+    signal = signal.view(1, channels, length)
+    return signal
+
+
+def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4):
+    b, channels, length = x.size()
+    signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
+    return x + signal.to(dtype=x.dtype, device=x.device)
+
+
+def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1):
+    b, channels, length = x.size()
+    signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
+    return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis)
+
+
+def subsequent_mask(length):
+    mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)
+    return mask
+
+
+@torch.jit.script
+def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
+    n_channels_int = n_channels[0]
+    in_act = input_a + input_b
+    t_act = torch.tanh(in_act[:, :n_channels_int, :])
+    s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
+    acts = t_act * s_act
+    return acts
+
+
+def convert_pad_shape(pad_shape):
+    l = pad_shape[::-1]
+    pad_shape = [item for sublist in l for item in sublist]
+    return pad_shape
+
+
+def shift_1d(x):
+    x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1]
+    return x
+
+
+def sequence_mask(length, max_length=None):
+    if max_length is None:
+        max_length = length.max()
+    x = torch.arange(max_length, dtype=length.dtype, device=length.device)
+    return x.unsqueeze(0) < length.unsqueeze(1)
+
+
+def generate_path(duration, mask):
+    """
+    duration: [b, 1, t_x]
+    mask: [b, 1, t_y, t_x]
+    """
+    device = duration.device
+
+    b, _, t_y, t_x = mask.shape
+    cum_duration = torch.cumsum(duration, -1)
+
+    cum_duration_flat = cum_duration.view(b * t_x)
+    path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
+    path = path.view(b, t_x, t_y)
+    path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
+    path = path.unsqueeze(1).transpose(2, 3) * mask
+    return path
+
+
+def clip_grad_value_(parameters, clip_value, norm_type=2):
+    if isinstance(parameters, torch.Tensor):
+        parameters = [parameters]
+    parameters = list(filter(lambda p: p.grad is not None, parameters))
+    norm_type = float(norm_type)
+    if clip_value is not None:
+        clip_value = float(clip_value)
+
+    total_norm = 0
+    for p in parameters:
+        param_norm = p.grad.data.norm(norm_type)
+        total_norm += param_norm.item() ** norm_type
+        if clip_value is not None:
+            p.grad.data.clamp_(min=-clip_value, max=clip_value)
+    total_norm = total_norm ** (1.0 / norm_type)
+    return total_norm
+
+
+def squeeze(x, x_mask=None, n_sqz=2):
+    b, c, t = x.size()
+
+    t = (t // n_sqz) * n_sqz
+    x = x[:, :, :t]
+    x_sqz = x.view(b, c, t // n_sqz, n_sqz)
+    x_sqz = x_sqz.permute(0, 3, 1, 2).contiguous().view(b, c * n_sqz, t // n_sqz)
+
+    if x_mask is not None:
+        x_mask = x_mask[:, :, n_sqz - 1 :: n_sqz]
+    else:
+        x_mask = torch.ones(b, 1, t // n_sqz).to(device=x.device, dtype=x.dtype)
+    return x_sqz * x_mask, x_mask
+
+
+def unsqueeze(x, x_mask=None, n_sqz=2):
+    b, c, t = x.size()
+
+    x_unsqz = x.view(b, n_sqz, c // n_sqz, t)
+    x_unsqz = x_unsqz.permute(0, 2, 3, 1).contiguous().view(b, c // n_sqz, t * n_sqz)
+
+    if x_mask is not None:
+        x_mask = x_mask.unsqueeze(-1).repeat(1, 1, 1, n_sqz).view(b, 1, t * n_sqz)
+    else:
+        x_mask = torch.ones(b, 1, t * n_sqz).to(device=x.device, dtype=x.dtype)
+    return x_unsqz * x_mask, x_mask

+ 660 - 0
fish_speech/models/vits_decoder/modules/models.py

@@ -0,0 +1,660 @@
+import copy
+import math
+
+import torch
+from torch import nn
+from torch.cuda.amp import autocast
+from torch.nn import AvgPool1d, Conv1d, Conv2d, ConvTranspose1d
+from torch.nn import functional as F
+from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm
+
+from fish_speech.models.vits_decoder.modules import attentions, commons, modules
+
+from .commons import get_padding, init_weights
+from .mrte import MRTE
+from .vq_encoder import VQEncoder
+
+
+class TextEncoder(nn.Module):
+    def __init__(
+        self,
+        out_channels,
+        hidden_channels,
+        filter_channels,
+        n_heads,
+        n_layers,
+        kernel_size,
+        p_dropout,
+        latent_channels=192,
+    ):
+        super().__init__()
+        self.out_channels = out_channels
+        self.hidden_channels = hidden_channels
+        self.filter_channels = filter_channels
+        self.n_heads = n_heads
+        self.n_layers = n_layers
+        self.kernel_size = kernel_size
+        self.p_dropout = p_dropout
+        self.latent_channels = latent_channels
+
+        self.ssl_proj = nn.Conv1d(768, hidden_channels, 1)
+
+        self.encoder_ssl = attentions.Encoder(
+            hidden_channels,
+            filter_channels,
+            n_heads,
+            n_layers // 2,
+            kernel_size,
+            p_dropout,
+        )
+
+        self.encoder_text = attentions.Encoder(
+            hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
+        )
+        self.text_embedding = nn.Embedding(
+            322, hidden_channels
+        )  # We only use 264, but to make the weight happy
+
+        self.mrte = MRTE()
+
+        self.encoder2 = attentions.Encoder(
+            hidden_channels,
+            filter_channels,
+            n_heads,
+            n_layers // 2,
+            kernel_size,
+            p_dropout,
+        )
+
+        self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
+
+    def forward(self, y, y_lengths, text, text_lengths, ge, test=None):
+        y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(
+            y.dtype
+        )
+
+        y = self.ssl_proj(y * y_mask) * y_mask
+
+        y = self.encoder_ssl(y * y_mask, y_mask)
+
+        text_mask = torch.unsqueeze(
+            commons.sequence_mask(text_lengths, text.size(1)), 1
+        ).to(y.dtype)
+        if test == 1:
+            text[:, :] = 0
+        text = self.text_embedding(text).transpose(1, 2)
+        text = self.encoder_text(text * text_mask, text_mask)
+        y = self.mrte(y, y_mask, text, text_mask, ge)
+
+        y = self.encoder2(y * y_mask, y_mask)
+
+        stats = self.proj(y) * y_mask
+        m, logs = torch.split(stats, self.out_channels, dim=1)
+        return y, m, logs, y_mask
+
+    def extract_latent(self, x):
+        x = self.ssl_proj(x)
+        quantized, codes, commit_loss, quantized_list = self.quantizer(x)
+        return codes.transpose(0, 1)
+
+    def decode_latent(self, codes, y_mask, refer, refer_mask, ge):
+        quantized = self.quantizer.decode(codes)
+
+        y = self.vq_proj(quantized) * y_mask
+        y = self.encoder_ssl(y * y_mask, y_mask)
+
+        y = self.mrte(y, y_mask, refer, refer_mask, ge)
+
+        y = self.encoder2(y * y_mask, y_mask)
+
+        stats = self.proj(y) * y_mask
+        m, logs = torch.split(stats, self.out_channels, dim=1)
+        return y, m, logs, y_mask, quantized
+
+
+class ResidualCouplingBlock(nn.Module):
+    def __init__(
+        self,
+        channels,
+        hidden_channels,
+        kernel_size,
+        dilation_rate,
+        n_layers,
+        n_flows=4,
+        gin_channels=0,
+    ):
+        super().__init__()
+        self.channels = channels
+        self.hidden_channels = hidden_channels
+        self.kernel_size = kernel_size
+        self.dilation_rate = dilation_rate
+        self.n_layers = n_layers
+        self.n_flows = n_flows
+        self.gin_channels = gin_channels
+
+        self.flows = nn.ModuleList()
+        for i in range(n_flows):
+            self.flows.append(
+                modules.ResidualCouplingLayer(
+                    channels,
+                    hidden_channels,
+                    kernel_size,
+                    dilation_rate,
+                    n_layers,
+                    gin_channels=gin_channels,
+                    mean_only=True,
+                )
+            )
+            self.flows.append(modules.Flip())
+
+    def forward(self, x, x_mask, g=None, reverse=False):
+        if not reverse:
+            for flow in self.flows:
+                x, _ = flow(x, x_mask, g=g, reverse=reverse)
+        else:
+            for flow in reversed(self.flows):
+                x = flow(x, x_mask, g=g, reverse=reverse)
+        return x
+
+
+class PosteriorEncoder(nn.Module):
+    def __init__(
+        self,
+        in_channels,
+        out_channels,
+        hidden_channels,
+        kernel_size,
+        dilation_rate,
+        n_layers,
+        gin_channels=0,
+    ):
+        super().__init__()
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.hidden_channels = hidden_channels
+        self.kernel_size = kernel_size
+        self.dilation_rate = dilation_rate
+        self.n_layers = n_layers
+        self.gin_channels = gin_channels
+
+        self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
+        self.enc = modules.WN(
+            hidden_channels,
+            kernel_size,
+            dilation_rate,
+            n_layers,
+            gin_channels=gin_channels,
+        )
+        self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
+
+    def forward(self, x, x_lengths, g=None):
+        if g != None:
+            g = g.detach()
+        x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
+            x.dtype
+        )
+        x = self.pre(x) * x_mask
+        x = self.enc(x, x_mask, g=g)
+        stats = self.proj(x) * x_mask
+        m, logs = torch.split(stats, self.out_channels, dim=1)
+        z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
+        return z, m, logs, x_mask
+
+
+class Generator(torch.nn.Module):
+    def __init__(
+        self,
+        initial_channel,
+        resblock,
+        resblock_kernel_sizes,
+        resblock_dilation_sizes,
+        upsample_rates,
+        upsample_initial_channel,
+        upsample_kernel_sizes,
+        gin_channels=0,
+    ):
+        super(Generator, self).__init__()
+        self.num_kernels = len(resblock_kernel_sizes)
+        self.num_upsamples = len(upsample_rates)
+        self.conv_pre = Conv1d(
+            initial_channel, upsample_initial_channel, 7, 1, padding=3
+        )
+        resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2
+
+        self.ups = nn.ModuleList()
+        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
+            self.ups.append(
+                weight_norm(
+                    ConvTranspose1d(
+                        upsample_initial_channel // (2**i),
+                        upsample_initial_channel // (2 ** (i + 1)),
+                        k,
+                        u,
+                        padding=(k - u) // 2,
+                    )
+                )
+            )
+
+        self.resblocks = nn.ModuleList()
+        for i in range(len(self.ups)):
+            ch = upsample_initial_channel // (2 ** (i + 1))
+            for j, (k, d) in enumerate(
+                zip(resblock_kernel_sizes, resblock_dilation_sizes)
+            ):
+                self.resblocks.append(resblock(ch, k, d))
+
+        self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
+        self.ups.apply(init_weights)
+
+        if gin_channels != 0:
+            self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
+
+    def forward(self, x, g=None):
+        x = self.conv_pre(x)
+        if g is not None:
+            x = x + self.cond(g)
+
+        for i in range(self.num_upsamples):
+            x = F.leaky_relu(x, modules.LRELU_SLOPE)
+            x = self.ups[i](x)
+            xs = None
+            for j in range(self.num_kernels):
+                if xs is None:
+                    xs = self.resblocks[i * self.num_kernels + j](x)
+                else:
+                    xs += self.resblocks[i * self.num_kernels + j](x)
+            x = xs / self.num_kernels
+        x = F.leaky_relu(x)
+        x = self.conv_post(x)
+        x = torch.tanh(x)
+
+        return x
+
+    def remove_weight_norm(self):
+        print("Removing weight norm...")
+        for l in self.ups:
+            remove_weight_norm(l)
+        for l in self.resblocks:
+            l.remove_weight_norm()
+
+
+class DiscriminatorP(torch.nn.Module):
+    def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
+        super(DiscriminatorP, self).__init__()
+        self.period = period
+        self.use_spectral_norm = use_spectral_norm
+        norm_f = weight_norm if use_spectral_norm == False else spectral_norm
+        self.convs = nn.ModuleList(
+            [
+                norm_f(
+                    Conv2d(
+                        1,
+                        32,
+                        (kernel_size, 1),
+                        (stride, 1),
+                        padding=(get_padding(kernel_size, 1), 0),
+                    )
+                ),
+                norm_f(
+                    Conv2d(
+                        32,
+                        128,
+                        (kernel_size, 1),
+                        (stride, 1),
+                        padding=(get_padding(kernel_size, 1), 0),
+                    )
+                ),
+                norm_f(
+                    Conv2d(
+                        128,
+                        512,
+                        (kernel_size, 1),
+                        (stride, 1),
+                        padding=(get_padding(kernel_size, 1), 0),
+                    )
+                ),
+                norm_f(
+                    Conv2d(
+                        512,
+                        1024,
+                        (kernel_size, 1),
+                        (stride, 1),
+                        padding=(get_padding(kernel_size, 1), 0),
+                    )
+                ),
+                norm_f(
+                    Conv2d(
+                        1024,
+                        1024,
+                        (kernel_size, 1),
+                        1,
+                        padding=(get_padding(kernel_size, 1), 0),
+                    )
+                ),
+            ]
+        )
+        self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
+
+    def forward(self, x):
+        fmap = []
+
+        # 1d to 2d
+        b, c, t = x.shape
+        if t % self.period != 0:  # pad first
+            n_pad = self.period - (t % self.period)
+            x = F.pad(x, (0, n_pad), "reflect")
+            t = t + n_pad
+        x = x.view(b, c, t // self.period, self.period)
+
+        for l in self.convs:
+            x = l(x)
+            x = F.leaky_relu(x, modules.LRELU_SLOPE)
+            fmap.append(x)
+        x = self.conv_post(x)
+        fmap.append(x)
+        x = torch.flatten(x, 1, -1)
+
+        return x, fmap
+
+
+class DiscriminatorS(torch.nn.Module):
+    def __init__(self, use_spectral_norm=False):
+        super(DiscriminatorS, self).__init__()
+        norm_f = weight_norm if use_spectral_norm == False else spectral_norm
+        self.convs = nn.ModuleList(
+            [
+                norm_f(Conv1d(1, 16, 15, 1, padding=7)),
+                norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
+                norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)),
+                norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
+                norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
+                norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
+            ]
+        )
+        self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
+
+    def forward(self, x):
+        fmap = []
+
+        for l in self.convs:
+            x = l(x)
+            x = F.leaky_relu(x, modules.LRELU_SLOPE)
+            fmap.append(x)
+        x = self.conv_post(x)
+        fmap.append(x)
+        x = torch.flatten(x, 1, -1)
+
+        return x, fmap
+
+
+class EnsembledDiscriminator(torch.nn.Module):
+    def __init__(self, use_spectral_norm=False):
+        super().__init__()
+        periods = [2, 3, 5, 7, 11]
+
+        discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
+        discs = discs + [
+            DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods
+        ]
+        self.discriminators = nn.ModuleList(discs)
+
+    def forward(self, y, y_hat):
+        y_d_rs = []
+        y_d_gs = []
+        fmap_rs = []
+        fmap_gs = []
+        for i, d in enumerate(self.discriminators):
+            y_d_r, fmap_r = d(y)
+            y_d_g, fmap_g = d(y_hat)
+            y_d_rs.append(y_d_r)
+            y_d_gs.append(y_d_g)
+            fmap_rs.append(fmap_r)
+            fmap_gs.append(fmap_g)
+
+        return y_d_rs, y_d_gs, fmap_rs, fmap_gs
+
+
+class SynthesizerTrn(nn.Module):
+    """
+    Synthesizer for Training
+    """
+
+    def __init__(
+        self,
+        *,
+        spec_channels,
+        segment_size,
+        inter_channels,
+        hidden_channels,
+        filter_channels,
+        n_heads,
+        n_layers,
+        kernel_size,
+        p_dropout,
+        resblock,
+        resblock_kernel_sizes,
+        resblock_dilation_sizes,
+        upsample_rates,
+        upsample_initial_channel,
+        upsample_kernel_sizes,
+        gin_channels=0,
+    ):
+        super().__init__()
+
+        self.spec_channels = spec_channels
+        self.inter_channels = inter_channels
+        self.hidden_channels = hidden_channels
+        self.filter_channels = filter_channels
+        self.n_heads = n_heads
+        self.n_layers = n_layers
+        self.kernel_size = kernel_size
+        self.p_dropout = p_dropout
+        self.resblock = resblock
+        self.resblock_kernel_sizes = resblock_kernel_sizes
+        self.resblock_dilation_sizes = resblock_dilation_sizes
+        self.upsample_rates = upsample_rates
+        self.upsample_initial_channel = upsample_initial_channel
+        self.upsample_kernel_sizes = upsample_kernel_sizes
+        self.segment_size = segment_size
+        self.gin_channels = gin_channels
+
+        self.enc_p = TextEncoder(
+            inter_channels,
+            hidden_channels,
+            filter_channels,
+            n_heads,
+            n_layers,
+            kernel_size,
+            p_dropout,
+        )
+        self.dec = Generator(
+            inter_channels,
+            resblock,
+            resblock_kernel_sizes,
+            resblock_dilation_sizes,
+            upsample_rates,
+            upsample_initial_channel,
+            upsample_kernel_sizes,
+            gin_channels=gin_channels,
+        )
+        self.enc_q = PosteriorEncoder(
+            spec_channels,
+            inter_channels,
+            hidden_channels,
+            5,
+            1,
+            16,
+            gin_channels=gin_channels,
+        )
+        self.flow = ResidualCouplingBlock(
+            inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels
+        )
+
+        self.ref_enc = modules.MelStyleEncoder(
+            spec_channels, style_vector_dim=gin_channels
+        )
+
+        self.vq = VQEncoder()
+        for param in self.vq.parameters():
+            param.requires_grad = False
+
+    def forward(self, audio, audio_lengths, y, y_lengths, text, text_lengths):
+        y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(
+            y.dtype
+        )
+        ge = self.ref_enc(y * y_mask, y_mask)
+        quantized = self.vq(audio, audio_lengths, sr=32000)
+
+        quantized = F.interpolate(quantized, size=int(y.shape[-1]), mode="nearest")
+
+        x, m_p, logs_p, y_mask = self.enc_p(
+            quantized, y_lengths, text, text_lengths, ge
+        )
+        z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=ge)
+        z_p = self.flow(z, y_mask, g=ge)
+
+        z_slice, ids_slice = commons.rand_slice_segments(
+            z, y_lengths, self.segment_size
+        )
+        o = self.dec(z_slice, g=ge)
+        return (
+            o,
+            ids_slice,
+            y_mask,
+            y_mask,
+            (z, z_p, m_p, logs_p, m_q, logs_q),
+        )
+
+    def infer(
+        self,
+        audio,
+        audio_lengths,
+        y,
+        y_lengths,
+        text,
+        text_lengths,
+        test=None,
+        noise_scale=0.5,
+    ):
+        # y_lengths = audio_lengths // 640 # 640 is the hop size
+        y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(
+            y.dtype
+        )
+        ge = self.ref_enc(y * y_mask, y_mask)
+
+        quantized = self.vq(audio, audio_lengths, sr=32000)
+        print(quantized.size())
+        quantized = F.interpolate(
+            quantized, size=int(audio.shape[-1] // 640), mode="nearest"
+        )
+        print(quantized.size())
+
+        x, m_p, logs_p, y_mask = self.enc_p(
+            quantized, audio_lengths, text, text_lengths, ge, test=test
+        )
+        z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
+
+        z = self.flow(z_p, y_mask, g=ge, reverse=True)
+
+        o = self.dec((z * y_mask)[:, :, :], g=ge)
+        return o, y_mask, (z, z_p, m_p, logs_p)
+
+    @torch.no_grad()
+    def decode(self, codes, text, refer, noise_scale=0.5):
+        ge = None
+        if refer is not None:
+            refer_lengths = torch.LongTensor([refer.size(2)]).to(refer.device)
+            refer_mask = torch.unsqueeze(
+                commons.sequence_mask(refer_lengths, refer.size(2)), 1
+            ).to(refer.dtype)
+            ge = self.ref_enc(refer * refer_mask, refer_mask)
+
+        y_lengths = torch.LongTensor([codes.size(2) * 2]).to(codes.device)
+        text_lengths = torch.LongTensor([text.size(-1)]).to(text.device)
+
+        quantized = self.quantizer.decode(codes)
+        quantized = F.interpolate(
+            quantized, size=int(quantized.shape[-1] * 2), mode="nearest"
+        )
+
+        x, m_p, logs_p, y_mask = self.enc_p(
+            quantized, y_lengths, text, text_lengths, ge
+        )
+        z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
+
+        z = self.flow(z_p, y_mask, g=ge, reverse=True)
+
+        o = self.dec((z * y_mask)[:, :, :], g=ge)
+        return o
+
+    def extract_latent(self, x):
+        ssl = self.ssl_proj(x)
+        quantized, codes, commit_loss, quantized_list = self.quantizer(ssl)
+        return codes.transpose(0, 1)
+
+
+if __name__ == "__main__":
+    import librosa
+    from transformers import AutoTokenizer
+
+    from fish_speech.models.vqgan.spectrogram import LinearSpectrogram
+
+    model = SynthesizerTrn(
+        spec_channels=1025,
+        segment_size=20480 // 640,
+        inter_channels=192,
+        hidden_channels=192,
+        filter_channels=768,
+        n_heads=2,
+        n_layers=6,
+        kernel_size=3,
+        p_dropout=0.1,
+        resblock="1",
+        resblock_kernel_sizes=[3, 7, 11],
+        resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
+        upsample_rates=[10, 8, 2, 2, 2],
+        upsample_initial_channel=512,
+        upsample_kernel_sizes=[16, 16, 8, 2, 2],
+        gin_channels=512,
+    )
+
+    ckpt = "./checkpoints/s2_big2k1_158000.pth"
+    # Try to load the model
+    print(f"Loading model from {ckpt}")
+    checkpoint = torch.load(ckpt, map_location="cpu", weights_only=True)
+    print(model.load_state_dict(checkpoint, strict=False))
+
+    # Test
+
+    ref_audio = librosa.load(
+        "data/source/Genshin/Chinese/五郎/vo_DQAQ010_15_gorou_07.wav", sr=32000
+    )[0]
+    input_audio = librosa.load(
+        "data/source/Genshin/Chinese/空/vo_FDAQ003_46_hero_02.wav", sr=32000
+    )[0]
+    # ref_audio = input_audio
+    text = "(现在看来花瓶里的水并不是用来隐藏水迹,而是在莉莉安与考威尔的争斗中不小心撞破的…)"
+    encoded_text = AutoTokenizer.from_pretrained("fishaudio/fish-speech-1")
+    spec = LinearSpectrogram(n_fft=2048, hop_length=640, win_length=2048)
+
+    ref_audio = torch.tensor(ref_audio).unsqueeze(0).unsqueeze(0)
+    ref_spec = spec(ref_audio)
+
+    input_audio = torch.tensor(input_audio).unsqueeze(0).unsqueeze(0)
+    text = encoded_text(text, return_tensors="pt")["input_ids"]
+    print(ref_audio.size(), ref_spec.size(), input_audio.size(), text.size())
+
+    o, y_mask, (z, z_p, m_p, logs_p) = model.infer(
+        input_audio,
+        torch.LongTensor([input_audio.size(2)]),
+        ref_spec,
+        torch.LongTensor([ref_spec.size(2)]),
+        text,
+        torch.LongTensor([text.size(1)]),
+    )
+    print(o.size(), y_mask.size(), z.size(), z_p.size(), m_p.size(), logs_p.size())
+
+    # Save output
+    import soundfile as sf
+
+    sf.write("output.wav", o.squeeze().detach().numpy(), 32000)

+ 619 - 0
fish_speech/models/vits_decoder/modules/modules.py

@@ -0,0 +1,619 @@
+import numpy as np
+import torch
+from torch import nn
+from torch.nn import Conv1d
+from torch.nn import functional as F
+from torch.nn.utils import remove_weight_norm, weight_norm
+
+from .commons import fused_add_tanh_sigmoid_multiply, get_padding, init_weights
+
+LRELU_SLOPE = 0.1
+
+
+class LayerNorm(nn.Module):
+    def __init__(self, channels, eps=1e-5):
+        super().__init__()
+        self.channels = channels
+        self.eps = eps
+
+        self.gamma = nn.Parameter(torch.ones(channels))
+        self.beta = nn.Parameter(torch.zeros(channels))
+
+    def forward(self, x):
+        x = x.transpose(1, -1)
+        x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
+        return x.transpose(1, -1)
+
+
+class ConvReluNorm(nn.Module):
+    def __init__(
+        self,
+        in_channels,
+        hidden_channels,
+        out_channels,
+        kernel_size,
+        n_layers,
+        p_dropout,
+    ):
+        super().__init__()
+        self.in_channels = in_channels
+        self.hidden_channels = hidden_channels
+        self.out_channels = out_channels
+        self.kernel_size = kernel_size
+        self.n_layers = n_layers
+        self.p_dropout = p_dropout
+        assert n_layers > 1, "Number of layers should be larger than 0."
+
+        self.conv_layers = nn.ModuleList()
+        self.norm_layers = nn.ModuleList()
+        self.conv_layers.append(
+            nn.Conv1d(
+                in_channels, hidden_channels, kernel_size, padding=kernel_size // 2
+            )
+        )
+        self.norm_layers.append(LayerNorm(hidden_channels))
+        self.relu_drop = nn.Sequential(nn.ReLU(), nn.Dropout(p_dropout))
+        for _ in range(n_layers - 1):
+            self.conv_layers.append(
+                nn.Conv1d(
+                    hidden_channels,
+                    hidden_channels,
+                    kernel_size,
+                    padding=kernel_size // 2,
+                )
+            )
+            self.norm_layers.append(LayerNorm(hidden_channels))
+        self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
+        self.proj.weight.data.zero_()
+        self.proj.bias.data.zero_()
+
+    def forward(self, x, x_mask):
+        x_org = x
+        for i in range(self.n_layers):
+            x = self.conv_layers[i](x * x_mask)
+            x = self.norm_layers[i](x)
+            x = self.relu_drop(x)
+        x = x_org + self.proj(x)
+        return x * x_mask
+
+
+class WN(torch.nn.Module):
+    def __init__(
+        self,
+        hidden_channels,
+        kernel_size,
+        dilation_rate,
+        n_layers,
+        gin_channels=0,
+        p_dropout=0,
+    ):
+        super(WN, self).__init__()
+        assert kernel_size % 2 == 1
+        self.hidden_channels = hidden_channels
+        self.kernel_size = (kernel_size,)
+        self.dilation_rate = dilation_rate
+        self.n_layers = n_layers
+        self.gin_channels = gin_channels
+        self.p_dropout = p_dropout
+
+        self.in_layers = torch.nn.ModuleList()
+        self.res_skip_layers = torch.nn.ModuleList()
+        self.drop = nn.Dropout(p_dropout)
+
+        if gin_channels != 0:
+            cond_layer = torch.nn.Conv1d(
+                gin_channels, 2 * hidden_channels * n_layers, 1
+            )
+            self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name="weight")
+
+        for i in range(n_layers):
+            dilation = dilation_rate**i
+            padding = int((kernel_size * dilation - dilation) / 2)
+            in_layer = torch.nn.Conv1d(
+                hidden_channels,
+                2 * hidden_channels,
+                kernel_size,
+                dilation=dilation,
+                padding=padding,
+            )
+            in_layer = torch.nn.utils.weight_norm(in_layer, name="weight")
+            self.in_layers.append(in_layer)
+
+            # last one is not necessary
+            if i < n_layers - 1:
+                res_skip_channels = 2 * hidden_channels
+            else:
+                res_skip_channels = hidden_channels
+
+            res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
+            res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name="weight")
+            self.res_skip_layers.append(res_skip_layer)
+
+    def forward(self, x, x_mask, g=None, **kwargs):
+        output = torch.zeros_like(x)
+        n_channels_tensor = torch.IntTensor([self.hidden_channels])
+
+        if g is not None:
+            g = self.cond_layer(g)
+
+        for i in range(self.n_layers):
+            x_in = self.in_layers[i](x)
+            if g is not None:
+                cond_offset = i * 2 * self.hidden_channels
+                g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :]
+            else:
+                g_l = torch.zeros_like(x_in)
+
+            acts = fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor)
+            acts = self.drop(acts)
+
+            res_skip_acts = self.res_skip_layers[i](acts)
+            if i < self.n_layers - 1:
+                res_acts = res_skip_acts[:, : self.hidden_channels, :]
+                x = (x + res_acts) * x_mask
+                output = output + res_skip_acts[:, self.hidden_channels :, :]
+            else:
+                output = output + res_skip_acts
+        return output * x_mask
+
+    def remove_weight_norm(self):
+        if self.gin_channels != 0:
+            torch.nn.utils.remove_weight_norm(self.cond_layer)
+        for l in self.in_layers:
+            torch.nn.utils.remove_weight_norm(l)
+        for l in self.res_skip_layers:
+            torch.nn.utils.remove_weight_norm(l)
+
+
+class ResBlock1(torch.nn.Module):
+    def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
+        super(ResBlock1, self).__init__()
+        self.convs1 = nn.ModuleList(
+            [
+                weight_norm(
+                    Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=dilation[0],
+                        padding=get_padding(kernel_size, dilation[0]),
+                    )
+                ),
+                weight_norm(
+                    Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=dilation[1],
+                        padding=get_padding(kernel_size, dilation[1]),
+                    )
+                ),
+                weight_norm(
+                    Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=dilation[2],
+                        padding=get_padding(kernel_size, dilation[2]),
+                    )
+                ),
+            ]
+        )
+        self.convs1.apply(init_weights)
+
+        self.convs2 = nn.ModuleList(
+            [
+                weight_norm(
+                    Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=1,
+                        padding=get_padding(kernel_size, 1),
+                    )
+                ),
+                weight_norm(
+                    Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=1,
+                        padding=get_padding(kernel_size, 1),
+                    )
+                ),
+                weight_norm(
+                    Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=1,
+                        padding=get_padding(kernel_size, 1),
+                    )
+                ),
+            ]
+        )
+        self.convs2.apply(init_weights)
+
+    def forward(self, x, x_mask=None):
+        for c1, c2 in zip(self.convs1, self.convs2):
+            xt = F.leaky_relu(x, LRELU_SLOPE)
+            if x_mask is not None:
+                xt = xt * x_mask
+            xt = c1(xt)
+            xt = F.leaky_relu(xt, LRELU_SLOPE)
+            if x_mask is not None:
+                xt = xt * x_mask
+            xt = c2(xt)
+            x = xt + x
+        if x_mask is not None:
+            x = x * x_mask
+        return x
+
+    def remove_weight_norm(self):
+        for l in self.convs1:
+            remove_weight_norm(l)
+        for l in self.convs2:
+            remove_weight_norm(l)
+
+
+class ResBlock2(torch.nn.Module):
+    def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
+        super(ResBlock2, self).__init__()
+        self.convs = nn.ModuleList(
+            [
+                weight_norm(
+                    Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=dilation[0],
+                        padding=get_padding(kernel_size, dilation[0]),
+                    )
+                ),
+                weight_norm(
+                    Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=dilation[1],
+                        padding=get_padding(kernel_size, dilation[1]),
+                    )
+                ),
+            ]
+        )
+        self.convs.apply(init_weights)
+
+    def forward(self, x, x_mask=None):
+        for c in self.convs:
+            xt = F.leaky_relu(x, LRELU_SLOPE)
+            if x_mask is not None:
+                xt = xt * x_mask
+            xt = c(xt)
+            x = xt + x
+        if x_mask is not None:
+            x = x * x_mask
+        return x
+
+    def remove_weight_norm(self):
+        for l in self.convs:
+            remove_weight_norm(l)
+
+
+class Flip(nn.Module):
+    def forward(self, x, *args, reverse=False, **kwargs):
+        x = torch.flip(x, [1])
+        if not reverse:
+            logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device)
+            return x, logdet
+        else:
+            return x
+
+
+class ResidualCouplingLayer(nn.Module):
+    def __init__(
+        self,
+        channels,
+        hidden_channels,
+        kernel_size,
+        dilation_rate,
+        n_layers,
+        p_dropout=0,
+        gin_channels=0,
+        mean_only=False,
+    ):
+        assert channels % 2 == 0, "channels should be divisible by 2"
+        super().__init__()
+        self.channels = channels
+        self.hidden_channels = hidden_channels
+        self.kernel_size = kernel_size
+        self.dilation_rate = dilation_rate
+        self.n_layers = n_layers
+        self.half_channels = channels // 2
+        self.mean_only = mean_only
+
+        self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
+        self.enc = WN(
+            hidden_channels,
+            kernel_size,
+            dilation_rate,
+            n_layers,
+            p_dropout=p_dropout,
+            gin_channels=gin_channels,
+        )
+        self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
+        self.post.weight.data.zero_()
+        self.post.bias.data.zero_()
+
+    def forward(self, x, x_mask, g=None, reverse=False):
+        x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
+        h = self.pre(x0) * x_mask
+        h = self.enc(h, x_mask, g=g)
+        stats = self.post(h) * x_mask
+        if not self.mean_only:
+            m, logs = torch.split(stats, [self.half_channels] * 2, 1)
+        else:
+            m = stats
+            logs = torch.zeros_like(m)
+
+        if not reverse:
+            x1 = m + x1 * torch.exp(logs) * x_mask
+            x = torch.cat([x0, x1], 1)
+            logdet = torch.sum(logs, [1, 2])
+            return x, logdet
+        else:
+            x1 = (x1 - m) * torch.exp(-logs) * x_mask
+            x = torch.cat([x0, x1], 1)
+            return x
+
+
+class LinearNorm(nn.Module):
+    def __init__(
+        self,
+        in_channels,
+        out_channels,
+        bias=True,
+        spectral_norm=False,
+    ):
+        super(LinearNorm, self).__init__()
+        self.fc = nn.Linear(in_channels, out_channels, bias)
+
+        if spectral_norm:
+            self.fc = nn.utils.spectral_norm(self.fc)
+
+    def forward(self, input):
+        out = self.fc(input)
+        return out
+
+
+class Mish(nn.Module):
+    def __init__(self):
+        super(Mish, self).__init__()
+
+    def forward(self, x):
+        return x * torch.tanh(F.softplus(x))
+
+
+class Conv1dGLU(nn.Module):
+    """
+    Conv1d + GLU(Gated Linear Unit) with residual connection.
+    For GLU refer to https://arxiv.org/abs/1612.08083 paper.
+    """
+
+    def __init__(self, in_channels, out_channels, kernel_size, dropout):
+        super(Conv1dGLU, self).__init__()
+        self.out_channels = out_channels
+        self.conv1 = ConvNorm(in_channels, 2 * out_channels, kernel_size=kernel_size)
+        self.dropout = nn.Dropout(dropout)
+
+    def forward(self, x):
+        residual = x
+        x = self.conv1(x)
+        x1, x2 = torch.split(x, split_size_or_sections=self.out_channels, dim=1)
+        x = x1 * torch.sigmoid(x2)
+        x = residual + self.dropout(x)
+        return x
+
+
+class ConvNorm(nn.Module):
+    def __init__(
+        self,
+        in_channels,
+        out_channels,
+        kernel_size=1,
+        stride=1,
+        padding=None,
+        dilation=1,
+        bias=True,
+        spectral_norm=False,
+    ):
+        super(ConvNorm, self).__init__()
+
+        if padding is None:
+            assert kernel_size % 2 == 1
+            padding = int(dilation * (kernel_size - 1) / 2)
+
+        self.conv = torch.nn.Conv1d(
+            in_channels,
+            out_channels,
+            kernel_size=kernel_size,
+            stride=stride,
+            padding=padding,
+            dilation=dilation,
+            bias=bias,
+        )
+
+        if spectral_norm:
+            self.conv = nn.utils.spectral_norm(self.conv)
+
+    def forward(self, input):
+        out = self.conv(input)
+        return out
+
+
+class MultiHeadAttention(nn.Module):
+    """Multi-Head Attention module"""
+
+    def __init__(self, n_head, d_model, d_k, d_v, dropout=0.0, spectral_norm=False):
+        super().__init__()
+
+        self.n_head = n_head
+        self.d_k = d_k
+        self.d_v = d_v
+
+        self.w_qs = nn.Linear(d_model, n_head * d_k)
+        self.w_ks = nn.Linear(d_model, n_head * d_k)
+        self.w_vs = nn.Linear(d_model, n_head * d_v)
+
+        self.attention = ScaledDotProductAttention(
+            temperature=np.power(d_model, 0.5), dropout=dropout
+        )
+
+        self.fc = nn.Linear(n_head * d_v, d_model)
+        self.dropout = nn.Dropout(dropout)
+
+        if spectral_norm:
+            self.w_qs = nn.utils.spectral_norm(self.w_qs)
+            self.w_ks = nn.utils.spectral_norm(self.w_ks)
+            self.w_vs = nn.utils.spectral_norm(self.w_vs)
+            self.fc = nn.utils.spectral_norm(self.fc)
+
+    def forward(self, x, mask=None):
+        d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
+        sz_b, len_x, _ = x.size()
+
+        residual = x
+
+        q = self.w_qs(x).view(sz_b, len_x, n_head, d_k)
+        k = self.w_ks(x).view(sz_b, len_x, n_head, d_k)
+        v = self.w_vs(x).view(sz_b, len_x, n_head, d_v)
+        q = q.permute(2, 0, 1, 3).contiguous().view(-1, len_x, d_k)  # (n*b) x lq x dk
+        k = k.permute(2, 0, 1, 3).contiguous().view(-1, len_x, d_k)  # (n*b) x lk x dk
+        v = v.permute(2, 0, 1, 3).contiguous().view(-1, len_x, d_v)  # (n*b) x lv x dv
+
+        if mask is not None:
+            slf_mask = mask.repeat(n_head, 1, 1)  # (n*b) x .. x ..
+        else:
+            slf_mask = None
+        output, attn = self.attention(q, k, v, mask=slf_mask)
+
+        output = output.view(n_head, sz_b, len_x, d_v)
+        output = (
+            output.permute(1, 2, 0, 3).contiguous().view(sz_b, len_x, -1)
+        )  # b x lq x (n*dv)
+
+        output = self.fc(output)
+
+        output = self.dropout(output) + residual
+        return output, attn
+
+
+class ScaledDotProductAttention(nn.Module):
+    """Scaled Dot-Product Attention"""
+
+    def __init__(self, temperature, dropout):
+        super().__init__()
+        self.temperature = temperature
+        self.softmax = nn.Softmax(dim=2)
+        self.dropout = nn.Dropout(dropout)
+
+    def forward(self, q, k, v, mask=None):
+        attn = torch.bmm(q, k.transpose(1, 2))
+        attn = attn / self.temperature
+
+        if mask is not None:
+            attn = attn.masked_fill(mask, -np.inf)
+
+        attn = self.softmax(attn)
+        p_attn = self.dropout(attn)
+
+        output = torch.bmm(p_attn, v)
+        return output, attn
+
+
+class MelStyleEncoder(nn.Module):
+    """MelStyleEncoder"""
+
+    def __init__(
+        self,
+        n_mel_channels=80,
+        style_hidden=128,
+        style_vector_dim=256,
+        style_kernel_size=5,
+        style_head=2,
+        dropout=0.1,
+    ):
+        super(MelStyleEncoder, self).__init__()
+        self.in_dim = n_mel_channels
+        self.hidden_dim = style_hidden
+        self.out_dim = style_vector_dim
+        self.kernel_size = style_kernel_size
+        self.n_head = style_head
+        self.dropout = dropout
+
+        self.spectral = nn.Sequential(
+            LinearNorm(self.in_dim, self.hidden_dim),
+            Mish(),
+            nn.Dropout(self.dropout),
+            LinearNorm(self.hidden_dim, self.hidden_dim),
+            Mish(),
+            nn.Dropout(self.dropout),
+        )
+
+        self.temporal = nn.Sequential(
+            Conv1dGLU(self.hidden_dim, self.hidden_dim, self.kernel_size, self.dropout),
+            Conv1dGLU(self.hidden_dim, self.hidden_dim, self.kernel_size, self.dropout),
+        )
+
+        self.slf_attn = MultiHeadAttention(
+            self.n_head,
+            self.hidden_dim,
+            self.hidden_dim // self.n_head,
+            self.hidden_dim // self.n_head,
+            self.dropout,
+        )
+
+        self.fc = LinearNorm(self.hidden_dim, self.out_dim)
+
+    def temporal_avg_pool(self, x, mask=None):
+        if mask is None:
+            out = torch.mean(x, dim=1)
+        else:
+            len_ = (~mask).sum(dim=1).unsqueeze(1)
+            x = x.masked_fill(mask.unsqueeze(-1), 0)
+            x = x.sum(dim=1)
+            out = torch.div(x, len_)
+        return out
+
+    def forward(self, x, mask=None):
+        x = x.transpose(1, 2)
+        if mask is not None:
+            mask = (mask.int() == 0).squeeze(1)
+        max_len = x.shape[1]
+        slf_attn_mask = (
+            mask.unsqueeze(1).expand(-1, max_len, -1) if mask is not None else None
+        )
+
+        # spectral
+        x = self.spectral(x)
+        # temporal
+        x = x.transpose(1, 2)
+        x = self.temporal(x)
+        x = x.transpose(1, 2)
+        # self-attention
+        if mask is not None:
+            x = x.masked_fill(mask.unsqueeze(-1), 0)
+        x, _ = self.slf_attn(x, mask=slf_attn_mask)
+        # fc
+        x = self.fc(x)
+        # temoral average pooling
+        w = self.temporal_avg_pool(x, mask=mask)
+
+        return w.unsqueeze(-1)

+ 60 - 0
fish_speech/models/vits_decoder/modules/mrte.py

@@ -0,0 +1,60 @@
+import torch
+from torch import nn
+from torch.nn.utils import remove_weight_norm, weight_norm
+
+from fish_speech.models.vits_decoder.modules.attentions import MultiHeadAttention
+
+
+class MRTE(nn.Module):
+    def __init__(
+        self,
+        content_enc_channels=192,
+        hidden_size=512,
+        out_channels=192,
+        kernel_size=5,
+        n_heads=4,
+        ge_layer=2,
+    ):
+        super(MRTE, self).__init__()
+        self.cross_attention = MultiHeadAttention(hidden_size, hidden_size, n_heads)
+        self.c_pre = nn.Conv1d(content_enc_channels, hidden_size, 1)
+        self.text_pre = nn.Conv1d(content_enc_channels, hidden_size, 1)
+        self.c_post = nn.Conv1d(hidden_size, out_channels, 1)
+
+    def forward(self, ssl_enc, ssl_mask, text, text_mask, ge, test=None):
+        if ge == None:
+            ge = 0
+        attn_mask = text_mask.unsqueeze(2) * ssl_mask.unsqueeze(-1)
+
+        ssl_enc = self.c_pre(ssl_enc * ssl_mask)
+        text_enc = self.text_pre(text * text_mask)
+        if test != None:
+            if test == 0:
+                x = (
+                    self.cross_attention(
+                        ssl_enc * ssl_mask, text_enc * text_mask, attn_mask
+                    )
+                    + ssl_enc
+                    + ge
+                )
+            elif test == 1:
+                x = ssl_enc + ge
+            elif test == 2:
+                x = (
+                    self.cross_attention(
+                        ssl_enc * 0 * ssl_mask, text_enc * text_mask, attn_mask
+                    )
+                    + ge
+                )
+            else:
+                raise ValueError("test should be 0,1,2")
+        else:
+            x = (
+                self.cross_attention(
+                    ssl_enc * ssl_mask, text_enc * text_mask, attn_mask
+                )
+                + ssl_enc
+                + ge
+            )
+        x = self.c_post(x * ssl_mask)
+        return x

+ 69 - 0
fish_speech/models/vits_decoder/modules/vq_encoder.py

@@ -0,0 +1,69 @@
+import torch
+from torch import nn
+
+from fish_speech.models.vqgan.modules.fsq import DownsampleFiniteScalarQuantize
+from fish_speech.models.vqgan.modules.wavenet import WaveNet
+from fish_speech.models.vqgan.spectrogram import LogMelSpectrogram
+
+
+class VQEncoder(nn.Module):
+    def __init__(
+        self,
+    ):
+        super().__init__()
+
+        self.encoder = WaveNet(
+            input_channels=128,
+            residual_channels=768,
+            residual_layers=20,
+            dilation_cycle=4,
+        )
+
+        self.quantizer = DownsampleFiniteScalarQuantize(
+            input_dim=768, n_codebooks=1, n_groups=2, levels=[8, 5, 5, 5]
+        )
+
+        self.spec = LogMelSpectrogram(
+            sample_rate=44100,
+            n_fft=2048,
+            win_length=2048,
+            hop_length=512,
+            n_mels=128,
+            f_min=0.0,
+            f_max=8000.0,
+        )
+
+        self.eval()
+        e = self.load_state_dict(
+            torch.load("checkpoints/vq-gan-group-fsq-2x1024.pth", map_location="cpu"),
+            strict=False,
+        )
+
+        assert len(e.missing_keys) == 0, e.missing_keys
+        assert all(
+            k.startswith("decoder.")
+            or k.startswith("quality_projection.")
+            or k.startswith("discriminator.")
+            for k in e.unexpected_keys
+        ), e.unexpected_keys
+
+    @torch.no_grad()
+    def forward(self, audios, audio_lengths, use_decoder=False, sr=None):
+        mel_spec = self.spec(audios, sample_rate=sr)
+
+        if sr is not None:
+            audio_lengths = audio_lengths * 44100 // sr
+
+        mel_lengths = audio_lengths // self.spec.hop_length
+        mel_masks = (
+            torch.arange(mel_spec.shape[2], device=mel_spec.device)
+            < mel_lengths[:, None]
+        )
+        mel_masks_float_conv = mel_masks[:, None, :].float()
+        mels = mel_spec * mel_masks_float_conv
+
+        # Encode
+        encoded_features = self.encoder(mels) * mel_masks_float_conv
+        encoded_features = self.quantizer(encoded_features).z * mel_masks_float_conv
+
+        return encoded_features