Lengyue пре 2 година
родитељ
комит
d229cee8f4

+ 0 - 197
fish_speech/models/vqgan/losses.py

@@ -1,197 +0,0 @@
-import torch
-import torch.nn.functional as F
-from torch import nn
-
-
def feature_loss(fmap_r: list[torch.Tensor], fmap_g: list[torch.Tensor]):
    """L1 feature-matching loss between real and generated feature maps.

    The real-side features are detached so gradients flow only into the
    generator. Returns the summed per-map mean absolute error, scaled by 2.
    """
    total = 0
    for real_feat, gen_feat in zip(fmap_r, fmap_g):
        diff = torch.abs(real_feat.float().detach() - gen_feat.float())
        total = total + diff.mean()

    return total * 2
-
-
def discriminator_loss(
    disc_real_outputs: list[torch.Tensor], disc_generated_outputs: list[torch.Tensor]
):
    """Least-squares GAN discriminator loss.

    Real outputs are pushed toward 1, generated outputs toward 0. Returns the
    total loss plus per-discriminator real/fake losses as Python floats.
    """
    total = 0
    r_losses, g_losses = [], []
    for real_out, fake_out in zip(disc_real_outputs, disc_generated_outputs):
        real_loss = torch.mean((1 - real_out.float()) ** 2)
        fake_loss = torch.mean(fake_out.float() ** 2)
        total = total + real_loss + fake_loss
        r_losses.append(real_loss.item())
        g_losses.append(fake_loss.item())

    return total, r_losses, g_losses
-
-
def generator_loss(disc_outputs: list[torch.Tensor]):
    """Least-squares GAN generator loss.

    Discriminator outputs on generated samples are pushed toward 1. Returns
    the summed loss and the per-discriminator loss tensors.
    """
    total = 0
    gen_losses = []
    for out in disc_outputs:
        per_disc = torch.mean((1 - out.float()) ** 2)
        gen_losses.append(per_disc)
        total = total + per_disc

    return total, gen_losses
-
-
def kl_loss(
    z_p: torch.Tensor,
    logs_q: torch.Tensor,
    m_p: torch.Tensor,
    logs_p: torch.Tensor,
    z_mask: torch.Tensor,
):
    """Masked KL term between posterior q and prior p (diagonal Gaussians).

    z_p, logs_q: [b, h, t_t] — sample drawn from q and q's log-std.
    m_p, logs_p: [b, h, t_t] — prior mean and log-std.
    z_mask marks valid frames; the sum is normalized by the mask total.
    """
    z_p, logs_q = z_p.float(), logs_q.float()
    m_p, logs_p = m_p.float(), logs_p.float()
    z_mask = z_mask.float()

    elementwise = (logs_p - logs_q - 0.5) + 0.5 * ((z_p - m_p) ** 2) * torch.exp(
        -2.0 * logs_p
    )
    return torch.sum(elementwise * z_mask) / torch.sum(z_mask)
-
-
def stft(x, fft_size, hop_size, win_length, window):
    """Perform STFT and convert to magnitude spectrogram.

    Args:
        x (Tensor): Input signal tensor (B, T).
        fft_size (int): FFT size.
        hop_size (int): Hop size.
        win_length (int): Window length.
        window (Tensor): Window tensor of length ``win_length``.
    Returns:
        Tensor: Magnitude spectrogram (B, #frames, fft_size // 2 + 1).
    """
    complex_spec = torch.stft(
        x,
        fft_size,
        hop_size,
        win_length,
        window,
        return_complex=True,
        pad_mode="reflect",
    )
    real_imag = torch.view_as_real(complex_spec)

    # NOTE(kan-bayashi): clamp is needed to avoid nan or inf from sqrt at zero.
    power = real_imag.pow(2).sum(-1)
    magnitude = torch.clamp(power, min=1e-6).sqrt()
    return magnitude.transpose(2, 1)
-
-
class SpectralConvergengeLoss(nn.Module):
    """Spectral convergence loss: ||y_mag - x_mag||_F / ||y_mag||_F."""

    def forward(self, x_mag, y_mag):
        """Compute the loss.

        Args:
            x_mag (Tensor): Predicted magnitude spectrogram (B, #frames, #freq_bins).
            y_mag (Tensor): Groundtruth magnitude spectrogram (B, #frames, #freq_bins).
        Returns:
            Tensor: Scalar spectral convergence loss.
        """  # noqa: E501
        residual_norm = torch.norm(y_mag - x_mag, p="fro")
        target_norm = torch.norm(y_mag, p="fro")
        return residual_norm / target_norm
-
-
class LogSTFTMagnitudeLoss(nn.Module):
    """Log STFT magnitude loss: L1 distance between log-magnitude spectrograms."""

    def forward(self, x_mag, y_mag):
        """Compute the loss.

        Args:
            x_mag (Tensor): Predicted magnitude spectrogram (B, #frames, #freq_bins).
            y_mag (Tensor): Groundtruth magnitude spectrogram (B, #frames, #freq_bins).
        Returns:
            Tensor: Scalar log STFT magnitude loss.
        """  # noqa: E501
        return F.l1_loss(torch.log(y_mag), torch.log(x_mag))
-
-
class STFTLoss(nn.Module):
    """Single-resolution STFT loss (spectral convergence + log magnitude)."""

    def __init__(
        self, fft_size=1024, shift_size=120, win_length=600, window=torch.hann_window
    ):
        """Build the loss at one STFT resolution.

        The window tensor is registered as a buffer so it tracks the module's
        device and dtype.
        """
        super().__init__()

        self.fft_size = fft_size
        self.shift_size = shift_size
        self.win_length = win_length
        self.register_buffer("window", window(win_length))
        self.spectral_convergenge_loss = SpectralConvergengeLoss()
        self.log_stft_magnitude_loss = LogSTFTMagnitudeLoss()

    def forward(self, x, y):
        """Return (sc_loss, mag_loss) for predicted x and groundtruth y, both (B, T)."""
        stft_args = (self.fft_size, self.shift_size, self.win_length, self.window)
        x_mag = stft(x, *stft_args)
        y_mag = stft(y, *stft_args)
        return (
            self.spectral_convergenge_loss(x_mag, y_mag),
            self.log_stft_magnitude_loss(x_mag, y_mag),
        )
-
-
class MultiResolutionSTFTLoss(nn.Module):
    """Multi-resolution STFT loss: per-resolution losses averaged together."""

    def __init__(self, resolutions, window=torch.hann_window):
        """resolutions: iterable of (fft_size, shift_size, win_length) triples."""
        super().__init__()

        self.stft_losses = nn.ModuleList(
            STFTLoss(fft_size, shift, win, window)
            for fft_size, shift, win in resolutions
        )

    def forward(self, x, y):
        """Return averaged (sc_loss, mag_loss) over all resolutions.

        Args:
            x (Tensor): Predicted signal (B, T).
            y (Tensor): Groundtruth signal (B, T).
        """
        sc_total = 0.0
        mag_total = 0.0
        for stft_loss in self.stft_losses:
            sc, mag = stft_loss(x, y)
            sc_total += sc
            mag_total += mag

        count = len(self.stft_losses)
        return sc_total / count, mag_total / count

+ 0 - 349
fish_speech/models/vqgan/modules/attentions.py

@@ -1,349 +0,0 @@
-import math
-
-import torch
-from torch import nn
-from torch.nn import functional as F
-from torch.nn.utils import remove_weight_norm, weight_norm
-
-from fish_speech.models.vqgan.modules import commons
-from fish_speech.models.vqgan.modules.modules import LayerNorm
-
-
class Encoder(nn.Module):
    """Transformer-style encoder over [b, hidden_channels, t] tensors.

    Each of the ``n_layers`` applies windowed self-attention followed by a
    convolutional FFN, with dropout and post-norm residual connections.
    When ``isflow`` is True, extra layers are built that fuse a global
    conditioning tensor ``g`` into the hidden state before every attention
    layer via a gated tanh/sigmoid unit.
    """

    def __init__(
        self,
        hidden_channels,
        filter_channels,
        n_heads,
        n_layers,
        kernel_size=1,
        p_dropout=0.0,
        window_size=4,
        isflow=False,
        gin_channels=0,
    ):
        super().__init__()
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.window_size = window_size

        self.drop = nn.Dropout(p_dropout)
        self.attn_layers = nn.ModuleList()
        self.norm_layers_1 = nn.ModuleList()
        self.ffn_layers = nn.ModuleList()
        self.norm_layers_2 = nn.ModuleList()
        for i in range(self.n_layers):
            self.attn_layers.append(
                MultiHeadAttention(
                    hidden_channels,
                    hidden_channels,
                    n_heads,
                    p_dropout=p_dropout,
                    window_size=window_size,
                )
            )
            self.norm_layers_1.append(LayerNorm(hidden_channels))
            self.ffn_layers.append(
                FFN(
                    hidden_channels,
                    hidden_channels,
                    filter_channels,
                    kernel_size,
                    p_dropout=p_dropout,
                )
            )
            self.norm_layers_2.append(LayerNorm(hidden_channels))

        if isflow:
            # One conv emits all layers' conditioning at once; forward() slices
            # out a (2 * hidden_channels)-wide chunk per layer.
            cond_layer = torch.nn.Conv1d(
                gin_channels, 2 * hidden_channels * n_layers, 1
            )
            self.cond_pre = torch.nn.Conv1d(hidden_channels, 2 * hidden_channels, 1)
            self.cond_layer = weight_norm(cond_layer, "weight")
            self.gin_channels = gin_channels

    def forward(self, x, x_mask, g=None):
        """Encode x ([b, hidden_channels, t]) under x_mask ([b, 1, t]).

        ``g`` is the optional global conditioning tensor; passing it requires
        the module to have been built with ``isflow=True`` (otherwise
        ``self.cond_layer`` does not exist).
        """
        # Outer product of the mask with itself gives the [b, 1, t, t]
        # pairwise attention mask.
        attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
        x = x * x_mask
        if g is not None:
            g = self.cond_layer(g)

        for i in range(self.n_layers):
            if g is not None:
                # Fuse this layer's slice of the conditioning into x through a
                # WaveNet-style gated activation.
                x = self.cond_pre(x)
                cond_offset = i * 2 * self.hidden_channels
                g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :]
                x = commons.fused_add_tanh_sigmoid_multiply(
                    x, g_l, torch.IntTensor([self.hidden_channels])
                )
            # Post-norm residual: attention, then FFN.
            y = self.attn_layers[i](x, x, attn_mask)
            y = self.drop(y)
            x = self.norm_layers_1[i](x + y)

            y = self.ffn_layers[i](x, x_mask)
            y = self.drop(y)
            x = self.norm_layers_2[i](x + y)
        x = x * x_mask
        return x
-
-
class MultiHeadAttention(nn.Module):
    """Multi-head attention over [b, channels, t] tensors (1x1-conv projections).

    Optional extras, all following the VITS/Transformer-TTS recipe:
      * ``window_size`` — learned relative-position embeddings limited to a
        window of +/- window_size (self-attention only).
      * ``proximal_bias`` — log-distance bias encouraging attention to nearby
        positions (self-attention only).
      * ``block_length`` — restrict attention to a local band of that width.
      * ``proximal_init`` — initialize the key projection to equal the query
        projection.
    The last attention map is kept on ``self.attn`` for inspection.
    """

    def __init__(
        self,
        channels,
        out_channels,
        n_heads,
        p_dropout=0.0,
        window_size=None,
        heads_share=True,
        block_length=None,
        proximal_bias=False,
        proximal_init=False,
    ):
        super().__init__()
        assert channels % n_heads == 0

        self.channels = channels
        self.out_channels = out_channels
        self.n_heads = n_heads
        self.p_dropout = p_dropout
        self.window_size = window_size
        self.heads_share = heads_share
        self.block_length = block_length
        self.proximal_bias = proximal_bias
        self.proximal_init = proximal_init
        self.attn = None  # last attention weights, filled by forward()

        self.k_channels = channels // n_heads
        self.conv_q = nn.Conv1d(channels, channels, 1)
        self.conv_k = nn.Conv1d(channels, channels, 1)
        self.conv_v = nn.Conv1d(channels, channels, 1)
        self.conv_o = nn.Conv1d(channels, out_channels, 1)
        self.drop = nn.Dropout(p_dropout)

        if window_size is not None:
            # Relative-position embeddings for keys and values, one row per
            # relative offset in [-window_size, window_size]; optionally shared
            # across heads.
            n_heads_rel = 1 if heads_share else n_heads
            rel_stddev = self.k_channels**-0.5
            self.emb_rel_k = nn.Parameter(
                torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
                * rel_stddev
            )
            self.emb_rel_v = nn.Parameter(
                torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
                * rel_stddev
            )

        nn.init.xavier_uniform_(self.conv_q.weight)
        nn.init.xavier_uniform_(self.conv_k.weight)
        nn.init.xavier_uniform_(self.conv_v.weight)
        if proximal_init:
            with torch.no_grad():
                self.conv_k.weight.copy_(self.conv_q.weight)
                self.conv_k.bias.copy_(self.conv_q.bias)

    def forward(self, x, c, attn_mask=None):
        """Attend from x (queries, [b, channels, t_t]) to c (keys/values,
        [b, channels, t_s]); self-attention when x is c."""
        q = self.conv_q(x)
        k = self.conv_k(c)
        v = self.conv_v(c)

        x, self.attn = self.attention(q, k, v, mask=attn_mask)

        x = self.conv_o(x)
        return x

    def attention(self, query, key, value, mask=None):
        """Scaled dot-product attention; returns (output [b, d, t_t], weights)."""
        # reshape [b, d, t] -> [b, n_h, t, d_k]
        b, d, t_s, t_t = (*key.size(), query.size(2))
        query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
        key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
        value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)

        scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))
        if self.window_size is not None:
            assert (
                t_s == t_t
            ), "Relative attention is only available for self-attention."
            # Add relative-position logits to the content logits.
            key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
            rel_logits = self._matmul_with_relative_keys(
                query / math.sqrt(self.k_channels), key_relative_embeddings
            )
            scores_local = self._relative_position_to_absolute_position(rel_logits)
            scores = scores + scores_local
        if self.proximal_bias:
            assert t_s == t_t, "Proximal bias is only available for self-attention."
            scores = scores + self._attention_bias_proximal(t_s).to(
                device=scores.device, dtype=scores.dtype
            )
        if mask is not None:
            # -1e4 (not -inf) keeps the fill finite in fp16.
            scores = scores.masked_fill(mask == 0, -1e4)
            if self.block_length is not None:
                assert (
                    t_s == t_t
                ), "Local attention is only available for self-attention."
                block_mask = (
                    torch.ones_like(scores)
                    .triu(-self.block_length)
                    .tril(self.block_length)
                )
                scores = scores.masked_fill(block_mask == 0, -1e4)
        p_attn = F.softmax(scores, dim=-1)  # [b, n_h, t_t, t_s]
        p_attn = self.drop(p_attn)
        output = torch.matmul(p_attn, value)
        if self.window_size is not None:
            # Add the relative-position contribution to the value aggregation.
            relative_weights = self._absolute_position_to_relative_position(p_attn)
            value_relative_embeddings = self._get_relative_embeddings(
                self.emb_rel_v, t_s
            )
            output = output + self._matmul_with_relative_values(
                relative_weights, value_relative_embeddings
            )
        output = (
            output.transpose(2, 3).contiguous().view(b, d, t_t)
        )  # [b, n_h, t_t, d_k] -> [b, d, t_t]
        return output, p_attn

    def _matmul_with_relative_values(self, x, y):
        """
        x: [b, h, l, m]
        y: [h or 1, m, d]
        ret: [b, h, l, d]
        """
        ret = torch.matmul(x, y.unsqueeze(0))
        return ret

    def _matmul_with_relative_keys(self, x, y):
        """
        x: [b, h, l, d]
        y: [h or 1, m, d]
        ret: [b, h, l, m]
        """
        ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
        return ret

    def _get_relative_embeddings(self, relative_embeddings, length):
        """Pad/slice the learned [-w, w] embeddings to the 2*length-1 offsets
        actually needed for a sequence of ``length``."""
        max_relative_position = 2 * self.window_size + 1
        # Pad first before slice to avoid using cond ops.
        pad_length = max(length - (self.window_size + 1), 0)
        slice_start_position = max((self.window_size + 1) - length, 0)
        slice_end_position = slice_start_position + 2 * length - 1
        if pad_length > 0:
            padded_relative_embeddings = F.pad(
                relative_embeddings,
                commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]),
            )
        else:
            padded_relative_embeddings = relative_embeddings
        used_relative_embeddings = padded_relative_embeddings[
            :, slice_start_position:slice_end_position
        ]
        return used_relative_embeddings

    def _relative_position_to_absolute_position(self, x):
        """
        x: [b, h, l, 2*l-1]
        ret: [b, h, l, l]

        Standard pad-flatten-reshape trick that converts logits indexed by
        relative offset into logits indexed by absolute key position.
        """
        batch, heads, length, _ = x.size()
        # Concat columns of pad to shift from relative to absolute indexing.
        x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))

        # Concat extra elements so to add up to shape (len+1, 2*len-1).
        x_flat = x.view([batch, heads, length * 2 * length])
        x_flat = F.pad(
            x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]])
        )

        # Reshape and slice out the padded elements.
        x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[
            :, :, :length, length - 1 :
        ]
        return x_final

    def _absolute_position_to_relative_position(self, x):
        """
        x: [b, h, l, l]
        ret: [b, h, l, 2*l-1]

        Inverse of _relative_position_to_absolute_position.
        """
        batch, heads, length, _ = x.size()
        # pad along column
        x = F.pad(
            x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]])
        )
        x_flat = x.view([batch, heads, length**2 + length * (length - 1)])
        # add 0's in the beginning that will skew the elements after reshape
        x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
        x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
        return x_final

    def _attention_bias_proximal(self, length):
        """Bias for self-attention to encourage attention to close positions.
        Args:
          length: an integer scalar.
        Returns:
          a Tensor with shape [1, 1, length, length]
        """
        r = torch.arange(length, dtype=torch.float32)
        diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
        return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
-
-
class FFN(nn.Module):
    """Position-wise feed-forward block of two masked 1D convolutions.

    Supports causal (left-only) or "same" time padding and an optional
    sigmoid-based GELU approximation when ``activation == "gelu"``
    (ReLU otherwise).
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        filter_channels,
        kernel_size,
        p_dropout=0.0,
        activation=None,
        causal=False,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.filter_channels = filter_channels
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.activation = activation
        self.causal = causal

        # Causal mode pads only on the left so no output sees future frames.
        self.padding = self._causal_padding if causal else self._same_padding

        self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size)
        self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size)
        self.drop = nn.Dropout(p_dropout)

    def forward(self, x, x_mask):
        """x: [b, in_channels, t]; x_mask: [b, 1, t] -> [b, out_channels, t]."""
        hidden = self.conv_1(self.padding(x * x_mask))
        if self.activation == "gelu":
            # Sigmoid approximation of GELU.
            hidden = hidden * torch.sigmoid(1.702 * hidden)
        else:
            hidden = torch.relu(hidden)
        hidden = self.drop(hidden)
        hidden = self.conv_2(self.padding(hidden * x_mask))
        return hidden * x_mask

    def _causal_padding(self, x):
        # Left-pad time by kernel_size - 1.
        if self.kernel_size == 1:
            return x
        return F.pad(x, (self.kernel_size - 1, 0))

    def _same_padding(self, x):
        # Near-symmetric time padding (right-heavy for even kernels).
        if self.kernel_size == 1:
            return x
        return F.pad(x, ((self.kernel_size - 1) // 2, self.kernel_size // 2))

+ 0 - 192
fish_speech/models/vqgan/modules/commons.py

@@ -1,192 +0,0 @@
-import math
-
-import numpy as np
-import torch
-from torch import nn
-from torch.nn import functional as F
-
-
def init_weights(m, mean=0.0, std=0.01):
    """Re-initialize conv-layer weights from N(mean, std); other modules untouched."""
    if "Conv" in type(m).__name__:
        m.weight.data.normal_(mean, std)
-
-
def get_padding(kernel_size, dilation=1):
    """Return the padding that keeps a dilated 1D conv output the same length.

    Equals (effective_kernel - 1) // 2 where the effective kernel size of a
    dilated conv is dilation * (kernel_size - 1) + 1.
    """
    # Integer floor-division replaces the original int(x / 2) float round-trip.
    return (kernel_size * dilation - dilation) // 2
-
-
def convert_pad_shape(pad_shape):
    """Flatten a per-dimension [[before, after], ...] list (outermost dim
    first) into the innermost-first flat list that F.pad expects."""
    return [amount for pair in pad_shape[::-1] for amount in pair]
-
-
def intersperse(lst, item):
    """Return a new list with ``item`` before, between, and after every
    element of ``lst`` (result length 2 * len(lst) + 1)."""
    result = []
    for element in lst:
        result.append(item)
        result.append(element)
    result.append(item)
    return result
-
-
def kl_divergence(m_p, logs_p, m_q, logs_q):
    """Elementwise KL(P||Q) between diagonal Gaussians given means and log-stds."""
    inv_var_q = torch.exp(-2.0 * logs_q)
    quadratic = (torch.exp(2.0 * logs_p) + (m_p - m_q) ** 2) * inv_var_q
    return (logs_q - logs_p) - 0.5 + 0.5 * quadratic
-
-
def rand_gumbel(shape):
    """Sample from the Gumbel distribution.

    The uniform draw is squeezed into (0.00001, 0.99999) so the double log
    never overflows.
    """
    safe_uniform = torch.rand(shape) * 0.99998 + 0.00001
    gumbel = -torch.log(-torch.log(safe_uniform))
    return gumbel
-
-
def rand_gumbel_like(x):
    """Gumbel sample matching the shape, dtype, and device of ``x``."""
    return rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device)
-
-
def slice_segments(x, ids_str, segment_size=4):
    """Gather a fixed-size time slice per batch element.

    x: [b, d, t]; ids_str: per-batch start indices.
    Returns [b, d, segment_size].
    """
    ret = torch.zeros_like(x[:, :, :segment_size])
    for batch_idx in range(x.size(0)):
        start = ids_str[batch_idx]
        ret[batch_idx] = x[batch_idx, :, start : start + segment_size]
    return ret
-
-
def rand_slice_segments(x, x_lengths=None, segment_size=4):
    """Pick a random valid start per batch element and slice a segment.

    x: [b, d, t]; x_lengths bounds the valid region (defaults to t).
    Returns (segments [b, d, segment_size], start indices [b]).
    """
    b, _, t = x.size()
    if x_lengths is None:
        x_lengths = t
    max_start = x_lengths - segment_size + 1
    ids_str = (torch.rand([b]).to(device=x.device) * max_start).to(dtype=torch.long)
    return slice_segments(x, ids_str, segment_size), ids_str
-
-
def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4):
    """Transformer-style sinusoidal positional signal, shape [1, channels, length].

    The first channels // 2 rows are sines, the next channels // 2 are
    cosines, over geometrically spaced timescales; an odd channel count gets
    one zero-padded row.
    """
    positions = torch.arange(length, dtype=torch.float)
    num_timescales = channels // 2
    log_increment = math.log(float(max_timescale) / float(min_timescale)) / (
        num_timescales - 1
    )
    inv_timescales = min_timescale * torch.exp(
        -log_increment * torch.arange(num_timescales, dtype=torch.float)
    )
    scaled_time = positions.unsqueeze(0) * inv_timescales.unsqueeze(1)
    signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0)
    signal = F.pad(signal, [0, 0, 0, channels % 2])
    return signal.view(1, channels, length)
-
-
def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4):
    """Add a sinusoidal positional signal to x ([b, channels, length])."""
    _, channels, length = x.size()
    signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
    return x + signal.to(dtype=x.dtype, device=x.device)
-
-
def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1):
    """Concatenate a sinusoidal positional signal onto x along ``axis``."""
    _, channels, length = x.size()
    signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
    return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis)
-
-
def subsequent_mask(length):
    """Lower-triangular causal mask of shape [1, 1, length, length]."""
    return torch.tril(torch.ones(length, length))[None, None]
-
-
@torch.jit.script
def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
    # WaveNet-style gated activation: after summing the two inputs, the first
    # n_channels channels gate through tanh and the remaining channels through
    # sigmoid, and the two halves are multiplied elementwise.
    # n_channels is passed as a 1-element int tensor (indexed below) so the
    # function stays TorchScript-compatible.
    n_channels_int = n_channels[0]
    in_act = input_a + input_b
    t_act = torch.tanh(in_act[:, :n_channels_int, :])
    s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
    acts = t_act * s_act
    return acts
-
-
def convert_pad_shape(pad_shape):
    """Flatten [[before, after], ...] (outermost dim first) into the flat,
    innermost-first list that F.pad expects.

    NOTE(review): exact duplicate of convert_pad_shape defined earlier in this
    module; this second definition shadows the first.
    """
    flat = []
    for pair in pad_shape[::-1]:
        flat.extend(pair)
    return flat
-
-
def shift_1d(x):
    """Shift x ([b, c, t]) right by one time step, zero-filling position 0."""
    # Left-pad the time axis by one, then drop the final frame.
    return F.pad(x, (1, 0))[:, :, :-1]
-
-
def sequence_mask(length, max_length=None):
    """Boolean mask [b, max_length], True where position < length[b]."""
    if max_length is None:
        max_length = length.max()
    positions = torch.arange(max_length, dtype=length.dtype, device=length.device)
    return positions.unsqueeze(0) < length.unsqueeze(1)
-
-
def generate_path(duration, mask):
    """
    duration: [b, 1, t_x]
    mask: [b, 1, t_y, t_x]

    Expand per-token durations into a hard monotonic alignment path of shape
    [b, 1, t_y, t_x]: for each source position x, the path is 1 over the t_y
    frames covered by that token's cumulative-duration interval.
    """
    device = duration.device

    b, _, t_y, t_x = mask.shape
    # Cumulative duration marks each token's end frame.
    cum_duration = torch.cumsum(duration, -1)

    cum_duration_flat = cum_duration.view(b * t_x)
    path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
    path = path.view(b, t_x, t_y)
    # Differencing consecutive prefix masks leaves 1s only in each token's own
    # frame interval.
    path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
    path = path.unsqueeze(1).transpose(2, 3) * mask
    return path
-
-
def clip_grad_value_(parameters, clip_value, norm_type=2):
    """Clamp gradients elementwise to [-clip_value, clip_value] in place.

    Returns the total gradient norm computed before clipping; pass
    clip_value=None to only compute the norm.
    """
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    with_grads = [p for p in parameters if p.grad is not None]
    norm_type = float(norm_type)
    if clip_value is not None:
        clip_value = float(clip_value)

    total = 0
    for p in with_grads:
        total += p.grad.data.norm(norm_type).item() ** norm_type
        if clip_value is not None:
            p.grad.data.clamp_(min=-clip_value, max=clip_value)
    return total ** (1.0 / norm_type)
-
-
def squeeze(x, x_mask=None, n_sqz=2):
    """Fold time into channels: [b, c, t] -> [b, c * n_sqz, t // n_sqz].

    Trailing frames that do not fill a full group are dropped; the mask is
    subsampled to match (or created as all-ones when absent).
    """
    b, c, t = x.size()

    t_trim = (t // n_sqz) * n_sqz
    grouped = x[:, :, :t_trim].view(b, c, t_trim // n_sqz, n_sqz)
    x_sqz = (
        grouped.permute(0, 3, 1, 2).contiguous().view(b, c * n_sqz, t_trim // n_sqz)
    )

    if x_mask is not None:
        x_mask = x_mask[:, :, n_sqz - 1 :: n_sqz]
    else:
        x_mask = torch.ones(b, 1, t_trim // n_sqz).to(device=x.device, dtype=x.dtype)
    return x_sqz * x_mask, x_mask
-
-
def unsqueeze(x, x_mask=None, n_sqz=2):
    """Inverse of squeeze: [b, c, t] -> [b, c // n_sqz, t * n_sqz]."""
    b, c, t = x.size()

    expanded = x.view(b, n_sqz, c // n_sqz, t)
    x_unsqz = (
        expanded.permute(0, 2, 3, 1).contiguous().view(b, c // n_sqz, t * n_sqz)
    )

    if x_mask is not None:
        x_mask = x_mask.unsqueeze(-1).repeat(1, 1, 1, n_sqz).view(b, 1, t * n_sqz)
    else:
        x_mask = torch.ones(b, 1, t * n_sqz).to(device=x.device, dtype=x.dtype)
    return x_unsqz * x_mask, x_mask

+ 0 - 693
fish_speech/models/vqgan/modules/models.py

@@ -1,693 +0,0 @@
-import copy
-import math
-
-import torch
-from torch import nn
-from torch.cuda.amp import autocast
-from torch.nn import AvgPool1d, Conv1d, Conv2d, ConvTranspose1d
-from torch.nn import functional as F
-from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm
-
-from fish_speech.models.vqgan.modules import attentions, commons, modules
-from fish_speech.models.vqgan.modules.commons import get_padding, init_weights
-from fish_speech.models.vqgan.modules.rvq import DownsampleResidualVectorQuantizer
-
-
class FeatureEncoder(nn.Module):
    """Spectrogram encoder with a residual-VQ bottleneck.

    Pipeline: project the input spectrogram, encode with a WaveNet (WN) stack,
    quantize with a downsampling residual vector quantizer, then decode with a
    conditioned WN stack. A small auxiliary decoder reconstructs a spectrogram
    from the decoded features (used as an auxiliary training target), and a
    final 1x1 conv emits mean/log-std statistics.
    """

    def __init__(
        self,
        spec_channels,
        out_channels,
        hidden_channels,
        n_layers,
        kernel_size,
        p_dropout,
        codebook_size=1024,
        num_codebooks=2,
        gin_channels=0,
        aux_spec_channels=None,
    ):
        super().__init__()
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout

        # Auxiliary reconstruction defaults to the input spectrogram's channels.
        if aux_spec_channels is None:
            aux_spec_channels = spec_channels

        self.spec_proj = nn.Conv1d(spec_channels, hidden_channels, 1)

        self.encoder = modules.WN(
            hidden_channels=hidden_channels,
            kernel_size=kernel_size,
            dilation_rate=1,
            n_layers=n_layers // 2,
        )

        self.vq = DownsampleResidualVectorQuantizer(
            input_dim=hidden_channels,
            n_codebooks=num_codebooks,
            codebook_size=codebook_size,
            codebook_dim=hidden_channels,
            min_quantizers=num_codebooks,
            downsample_factor=(2,),
        )

        self.decoder = modules.WN(
            hidden_channels=hidden_channels,
            kernel_size=kernel_size,
            dilation_rate=1,
            n_layers=n_layers // 2,
            gin_channels=gin_channels,
        )

        self.aux_decoder = modules.WN(
            hidden_channels=hidden_channels,
            kernel_size=kernel_size,
            dilation_rate=1,
            n_layers=4,
        )
        self.aux_proj = nn.Conv1d(hidden_channels, aux_spec_channels, 1)

        self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)

    def forward(self, y, y_lengths, ge):
        """Encode spectrogram y ([b, spec_channels, t]) with lengths y_lengths.

        ``ge`` is the global conditioning for the decoder WN (presumably a
        speaker embedding — confirm against caller).

        Returns (features, mean, log-std, mask, VQ loss, aux spectrogram).
        """
        y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(
            y.dtype
        )

        y = self.spec_proj(y * y_mask) * y_mask
        y = self.encoder(y, y_mask) * y_mask
        z, indices, loss_vq = self.vq(y)
        y = self.decoder(z, y_mask, g=ge) * y_mask
        decoded_aux_mel = self.aux_decoder(y, y_mask)
        decoded_aux_mel = self.aux_proj(decoded_aux_mel) * y_mask

        stats = self.proj(y) * y_mask
        m, logs = torch.split(stats, self.out_channels, dim=1)
        return y, m, logs, y_mask, loss_vq, decoded_aux_mel
-
-
class ResidualCouplingBlock(nn.Module):
    """Stack of affine residual coupling flows (VITS-style).

    Each of the ``n_flows`` steps is a mean-only ResidualCouplingLayer
    followed by a channel Flip so both halves get transformed over the stack.
    """

    def __init__(
        self,
        channels,
        hidden_channels,
        kernel_size,
        dilation_rate,
        n_layers,
        n_flows=4,
        gin_channels=0,
    ):
        super().__init__()
        self.channels = channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.n_flows = n_flows
        self.gin_channels = gin_channels

        self.flows = nn.ModuleList()
        for i in range(n_flows):
            self.flows.append(
                modules.ResidualCouplingLayer(
                    channels,
                    hidden_channels,
                    kernel_size,
                    dilation_rate,
                    n_layers,
                    gin_channels=gin_channels,
                    mean_only=True,
                )
            )
            self.flows.append(modules.Flip())

    def forward(self, x, x_mask, g=None, reverse=False):
        """Run the flow stack forward, or in reverse (inverted, reversed order).

        The forward direction discards each layer's second return value
        (mean-only couplings have zero log-determinant contribution here);
        the reverse direction returns a single tensor per layer.
        """
        if not reverse:
            for flow in self.flows:
                x, _ = flow(x, x_mask, g=g, reverse=reverse)
        else:
            for flow in reversed(self.flows):
                x = flow(x, x_mask, g=g, reverse=reverse)
        return x
-
-
class PosteriorEncoder(nn.Module):
    """WaveNet posterior encoder: spectrogram -> latent Gaussian sample.

    Projects the input, runs a conditioned WN stack, and emits mean/log-std
    from which a reparameterized sample z is drawn.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        hidden_channels,
        kernel_size,
        dilation_rate,
        n_layers,
        gin_channels=0,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.gin_channels = gin_channels

        self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
        self.enc = modules.WN(
            hidden_channels,
            kernel_size,
            dilation_rate,
            n_layers,
            gin_channels=gin_channels,
        )
        self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)

    def forward(self, x, x_lengths, g=None):
        """x: [b, in_channels, t]; returns (z, mean, log-std, mask).

        NOTE(review): g is detached unconditionally, so despite the default
        this method requires g to be a tensor (g=None would raise) — confirm
        callers always pass conditioning.
        """
        # Detach so gradients do not flow back into the conditioning network.
        g = g.detach()
        x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
            x.dtype
        )
        x = self.pre(x) * x_mask
        x = self.enc(x, x_mask, g=g)
        stats = self.proj(x) * x_mask
        m, logs = torch.split(stats, self.out_channels, dim=1)
        # Reparameterization trick: z = m + eps * std, masked to valid frames.
        z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
        return z, m, logs, x_mask
-
-
class WNEncoder(nn.Module):
    """Deterministic WaveNet encoder (no sampling, unlike PosteriorEncoder).

    Projects the input, runs a conditioned WN stack, projects to
    ``out_channels``, and layer-normalizes the result.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        hidden_channels,
        kernel_size,
        dilation_rate,
        n_layers,
        gin_channels=0,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.gin_channels = gin_channels

        self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
        self.enc = modules.WN(
            hidden_channels,
            kernel_size,
            dilation_rate,
            n_layers,
            gin_channels=gin_channels,
        )
        self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
        self.norm = modules.LayerNorm(out_channels)

    def forward(self, x, x_lengths, g=None):
        """x: [b, in_channels, t]; returns normalized features [b, out_channels, t]."""
        x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
            x.dtype
        )
        x = self.pre(x) * x_mask
        x = self.enc(x, x_mask, g=g)
        out = self.proj(x) * x_mask
        out = self.norm(out)
        return out
-
-
-class Generator(torch.nn.Module):
-    def __init__(
-        self,
-        initial_channel,
-        resblock,
-        resblock_kernel_sizes,
-        resblock_dilation_sizes,
-        upsample_rates,
-        upsample_initial_channel,
-        upsample_kernel_sizes,
-        gin_channels=0,
-    ):
-        super(Generator, self).__init__()
-        self.num_kernels = len(resblock_kernel_sizes)
-        self.num_upsamples = len(upsample_rates)
-        self.conv_pre = Conv1d(
-            initial_channel, upsample_initial_channel, 7, 1, padding=3
-        )
-        resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2
-
-        self.ups = nn.ModuleList()
-        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
-            self.ups.append(
-                weight_norm(
-                    ConvTranspose1d(
-                        upsample_initial_channel // (2**i),
-                        upsample_initial_channel // (2 ** (i + 1)),
-                        k,
-                        u,
-                        padding=(k - u) // 2,
-                    )
-                )
-            )
-
-        self.resblocks = nn.ModuleList()
-        for i in range(len(self.ups)):
-            ch = upsample_initial_channel // (2 ** (i + 1))
-            for j, (k, d) in enumerate(
-                zip(resblock_kernel_sizes, resblock_dilation_sizes)
-            ):
-                self.resblocks.append(resblock(ch, k, d))
-
-        self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
-        self.ups.apply(init_weights)
-
-        if gin_channels != 0:
-            self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
-
-    def forward(self, x, g=None):
-        x = self.conv_pre(x)
-        if g is not None:
-            x = x + self.cond(g)
-
-        for i in range(self.num_upsamples):
-            x = F.leaky_relu(x, modules.LRELU_SLOPE)
-            x = self.ups[i](x)
-            xs = None
-            for j in range(self.num_kernels):
-                if xs is None:
-                    xs = self.resblocks[i * self.num_kernels + j](x)
-                else:
-                    xs += self.resblocks[i * self.num_kernels + j](x)
-            x = xs / self.num_kernels
-        x = F.leaky_relu(x)
-        x = self.conv_post(x)
-        x = torch.tanh(x)
-
-        return x
-
-    def remove_weight_norm(self):
-        print("Removing weight norm...")
-        for l in self.ups:
-            remove_weight_norm(l)
-        for l in self.resblocks:
-            l.remove_weight_norm()
-
-
-class DiscriminatorP(torch.nn.Module):
-    def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
-        super(DiscriminatorP, self).__init__()
-        self.period = period
-        self.use_spectral_norm = use_spectral_norm
-        norm_f = weight_norm if use_spectral_norm == False else spectral_norm
-        self.convs = nn.ModuleList(
-            [
-                norm_f(
-                    Conv2d(
-                        1,
-                        32,
-                        (kernel_size, 1),
-                        (stride, 1),
-                        padding=(get_padding(kernel_size, 1), 0),
-                    )
-                ),
-                norm_f(
-                    Conv2d(
-                        32,
-                        128,
-                        (kernel_size, 1),
-                        (stride, 1),
-                        padding=(get_padding(kernel_size, 1), 0),
-                    )
-                ),
-                norm_f(
-                    Conv2d(
-                        128,
-                        512,
-                        (kernel_size, 1),
-                        (stride, 1),
-                        padding=(get_padding(kernel_size, 1), 0),
-                    )
-                ),
-                norm_f(
-                    Conv2d(
-                        512,
-                        1024,
-                        (kernel_size, 1),
-                        (stride, 1),
-                        padding=(get_padding(kernel_size, 1), 0),
-                    )
-                ),
-                norm_f(
-                    Conv2d(
-                        1024,
-                        1024,
-                        (kernel_size, 1),
-                        1,
-                        padding=(get_padding(kernel_size, 1), 0),
-                    )
-                ),
-            ]
-        )
-        self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
-
-    def forward(self, x):
-        fmap = []
-
-        # 1d to 2d
-        b, c, t = x.shape
-        if t % self.period != 0:  # pad first
-            n_pad = self.period - (t % self.period)
-            x = F.pad(x, (0, n_pad), "reflect")
-            t = t + n_pad
-        x = x.view(b, c, t // self.period, self.period)
-
-        for l in self.convs:
-            x = l(x)
-            x = F.leaky_relu(x, modules.LRELU_SLOPE)
-            fmap.append(x)
-        x = self.conv_post(x)
-        fmap.append(x)
-        x = torch.flatten(x, 1, -1)
-
-        return x, fmap
-
-
-class DiscriminatorS(torch.nn.Module):
-    def __init__(self, use_spectral_norm=False):
-        super(DiscriminatorS, self).__init__()
-        norm_f = weight_norm if use_spectral_norm == False else spectral_norm
-        self.convs = nn.ModuleList(
-            [
-                norm_f(Conv1d(1, 16, 15, 1, padding=7)),
-                norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
-                norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)),
-                norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
-                norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
-                norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
-            ]
-        )
-        self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
-
-    def forward(self, x):
-        fmap = []
-
-        for l in self.convs:
-            x = l(x)
-            x = F.leaky_relu(x, modules.LRELU_SLOPE)
-            fmap.append(x)
-        x = self.conv_post(x)
-        fmap.append(x)
-        x = torch.flatten(x, 1, -1)
-
-        return x, fmap
-
-
-class EnsembledDiscriminator(torch.nn.Module):
-    def __init__(self, periods=(2, 3, 5, 7, 11), use_spectral_norm=False):
-        super(EnsembledDiscriminator, self).__init__()
-        discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
-        discs = discs + [
-            DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods
-        ]
-        self.discriminators = nn.ModuleList(discs)
-
-    def forward(self, y, y_hat):
-        y_d_rs = []
-        y_d_gs = []
-        fmap_rs = []
-        fmap_gs = []
-        for i, d in enumerate(self.discriminators):
-            y_d_r, fmap_r = d(y)
-            y_d_g, fmap_g = d(y_hat)
-            y_d_rs.append(y_d_r)
-            y_d_gs.append(y_d_g)
-            fmap_rs.append(fmap_r)
-            fmap_gs.append(fmap_g)
-
-        return y_d_rs, y_d_gs, fmap_rs, fmap_gs
-
-
-class SynthesizerTrn(nn.Module):
-    """
-    Synthesizer for Training
-    """
-
-    def __init__(
-        self,
-        *,
-        spec_channels,
-        segment_size,
-        inter_channels,
-        prior_hidden_channels,
-        prior_n_layers,
-        posterior_hidden_channels,
-        posterior_n_layers,
-        kernel_size,
-        p_dropout,
-        resblock,
-        resblock_kernel_sizes,
-        resblock_dilation_sizes,
-        upsample_rates,
-        upsample_initial_channel,
-        upsample_kernel_sizes,
-        gin_channels=0,
-        freeze_quantizer=False,
-        codebook_size=1024,
-        num_codebooks=2,
-        freeze_decoder=False,
-        freeze_posterior_encoder=False,
-        aux_spec_channels=None,
-    ):
-        super().__init__()
-        self.spec_channels = spec_channels
-        self.inter_channels = inter_channels
-        self.prior_hidden_channels = prior_hidden_channels
-        self.prior_n_layers = prior_n_layers
-        self.posterior_hidden_channels = posterior_hidden_channels
-        self.posterior_n_layers = posterior_n_layers
-        self.kernel_size = kernel_size
-        self.p_dropout = p_dropout
-        self.resblock = resblock
-        self.resblock_kernel_sizes = resblock_kernel_sizes
-        self.resblock_dilation_sizes = resblock_dilation_sizes
-        self.upsample_rates = upsample_rates
-        self.upsample_initial_channel = upsample_initial_channel
-        self.upsample_kernel_sizes = upsample_kernel_sizes
-        self.segment_size = segment_size
-        self.gin_channels = gin_channels
-
-        self.enc_p = FeatureEncoder(
-            spec_channels=spec_channels,
-            out_channels=inter_channels,
-            hidden_channels=prior_hidden_channels,
-            n_layers=prior_n_layers,
-            kernel_size=kernel_size,
-            p_dropout=p_dropout,
-            codebook_size=codebook_size,
-            num_codebooks=num_codebooks,
-            gin_channels=gin_channels,
-            aux_spec_channels=aux_spec_channels,
-        )
-        self.dec = Generator(
-            initial_channel=inter_channels,
-            resblock=resblock,
-            resblock_kernel_sizes=resblock_kernel_sizes,
-            resblock_dilation_sizes=resblock_dilation_sizes,
-            upsample_rates=upsample_rates,
-            upsample_initial_channel=upsample_initial_channel,
-            upsample_kernel_sizes=upsample_kernel_sizes,
-            gin_channels=gin_channels,
-        )
-        self.enc_q = PosteriorEncoder(
-            in_channels=spec_channels,
-            out_channels=inter_channels,
-            hidden_channels=posterior_hidden_channels,
-            kernel_size=5,
-            dilation_rate=1,
-            n_layers=posterior_n_layers,
-            gin_channels=gin_channels,
-        )
-        self.flow = ResidualCouplingBlock(
-            inter_channels,
-            posterior_hidden_channels,
-            5,
-            1,
-            4,
-            gin_channels=gin_channels,
-        )
-
-        self.ref_enc = modules.MelStyleEncoder(
-            spec_channels, style_vector_dim=gin_channels
-        )
-
-        if freeze_quantizer:
-            self.enc_p.spec_proj.requires_grad_(False)
-            self.enc_p.encoder.requires_grad_(False)
-            self.enc_p.vq.requires_grad_(False)
-
-        if freeze_decoder:
-            self.dec.requires_grad_(False)
-
-        if freeze_posterior_encoder:
-            self.enc_q.requires_grad_(False)
-
-    def forward(self, y, y_lengths):
-        y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(
-            y.dtype
-        )
-        ge = self.ref_enc(y * y_mask, y_mask)
-
-        x, m_p, logs_p, y_mask, quantized, decoded_aux_mel = self.enc_p(
-            y, y_lengths, ge
-        )
-        z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=ge)
-        z_p = self.flow(z, y_mask, g=ge)
-
-        z_slice, ids_slice = commons.rand_slice_segments(
-            z, y_lengths, self.segment_size
-        )
-        o = self.dec(z_slice, g=ge)
-
-        return (
-            o,
-            ids_slice,
-            y_mask,
-            y_mask,
-            (z, z_p, m_p, logs_p, m_q, logs_q),
-            quantized,
-            decoded_aux_mel,
-        )
-
-    def infer(self, y, y_lengths, noise_scale=0.5):
-        y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(
-            y.dtype
-        )
-        ge = self.ref_enc(y * y_mask, y_mask)
-        x, m_p, logs_p, y_mask, _, _ = self.enc_p(y, y_lengths, ge)
-        z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
-
-        z = self.flow(z_p, y_mask, g=ge, reverse=True)
-
-        o = self.dec((z * y_mask)[:, :, :], g=ge)
-        return o, y_mask, (z, z_p, m_p, logs_p)
-
-    def infer_posterior(self, y, y_lengths):
-        y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(
-            y.dtype
-        )
-        ge = self.ref_enc(y * y_mask, y_mask)
-        z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=ge)
-        o = self.dec(z * y_mask, g=ge)
-        return o, y_mask, (z, m_q, logs_q)
-
-    # @torch.no_grad()
-    # def decode(self, codes, text, refer, noise_scale=0.5):
-    #     refer_lengths = torch.LongTensor([refer.size(2)]).to(refer.device)
-    #     refer_mask = torch.unsqueeze(
-    #         commons.sequence_mask(refer_lengths, refer.size(2)), 1
-    #     ).to(refer.dtype)
-    #     ge = self.ref_enc(refer * refer_mask, refer_mask)
-
-    #     y_lengths = torch.LongTensor([codes.size(2) * 2]).to(codes.device)
-    #     text_lengths = torch.LongTensor([text.size(-1)]).to(text.device)
-
-    #     quantized = self.quantizer.decode(codes)
-    #     if self.semantic_frame_rate == "25hz":
-    #         quantized = F.interpolate(
-    #             quantized, size=int(quantized.shape[-1] * 2), mode="nearest"
-    #         )
-
-    #     x, m_p, logs_p, y_mask = self.enc_p(
-    #         quantized, y_lengths, text, text_lengths, ge
-    #     )
-    #     z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
-
-    #     z = self.flow(z_p, y_mask, g=ge, reverse=True)
-
-    #     o = self.dec((z * y_mask)[:, :, :], g=ge)
-    #     return o
-
-    # def extract_latent(self, x):
-    #     ssl = self.ssl_proj(x)
-    #     quantized, codes, commit_loss, quantized_list = self.quantizer(ssl)
-    #     return codes.transpose(0, 1)
-
-
-if __name__ == "__main__":
-    model = SynthesizerTrn(
-        spec_channels=1025,
-        segment_size=20480,
-        inter_channels=192,
-        prior_hidden_channels=384,
-        posterior_hidden_channels=192,
-        prior_n_layers=16,
-        posterior_n_layers=16,
-        kernel_size=3,
-        p_dropout=0.1,
-        resblock="1",
-        resblock_kernel_sizes=[3, 7, 11],
-        resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
-        upsample_rates=[10, 8, 2, 2, 2],
-        upsample_initial_channel=512,
-        upsample_kernel_sizes=[16, 16, 8, 2, 2],
-        gin_channels=512,
-        freeze_quantizer=True,
-    )
-
-    state_dict_g = torch.load("checkpoints/gpt_sovits_g_488k.pth", map_location="cpu")
-    state_dict_d = torch.load("checkpoints/gpt_sovits_d_488k.pth", map_location="cpu")
-    keys = set(model.state_dict().keys())
-    state_dict_g = {
-        k: v for k, v in state_dict_g.items() if k in keys and "enc_p" not in k
-    }
-
-    new_state = {}
-    for k, v in state_dict_g.items():
-        new_state["generator." + k] = v
-
-    for k, v in state_dict_d.items():
-        new_state["discriminator." + k] = v
-
-    torch.save(new_state, "checkpoints/gpt_sovits_488k.pth")
-    exit()
-
-    # print(EnsembledDiscriminator().load_state_dict(state_dict_d, strict=False))
-    print(model.load_state_dict(state_dict_g, strict=False))
-
-    # y = torch.randn(3, 1025, 20480)
-    # y_lengths = torch.tensor([20480, 19000, 18000])
-
-    import librosa
-    import soundfile as sf
-
-    from fish_speech.models.vqgan.spectrogram import LinearSpectrogram
-
-    spec = LinearSpectrogram(
-        n_fft=2048, win_length=2048, hop_length=640, mode="pow2_sqrt"
-    )
-
-    audio, _ = librosa.load(
-        "/***REMOVED***/workspace/llm-multimodal-test/data/Rail_ZH/星/dbc16cc114ca1700.wav",
-        sr=32000,
-    )
-
-    y = spec(torch.tensor(audio).unsqueeze(0))
-    y_lengths = torch.tensor([y.size(2)])
-
-    o, ids_slice, y_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q), quantized = model(
-        y, y_lengths
-    )
-    print(o.shape)
-
-    o, y_mask, (z, z_p, m_p, logs_p) = model.infer(y, y_lengths)
-    print(o.shape)
-
-    o, y_mask, (z, m_q, logs_q) = model.infer_posterior(y, y_lengths)
-    print(o.shape)
-
-    o = o.squeeze(0).T.detach().cpu().numpy()
-    sf.write("test.wav", o, 32000)

+ 0 - 672
fish_speech/models/vqgan/modules/modules.py

@@ -1,672 +0,0 @@
-import numpy as np
-import torch
-from torch import nn
-from torch.nn import Conv1d
-from torch.nn import functional as F
-from torch.nn.utils import remove_weight_norm, weight_norm
-
-from fish_speech.models.vqgan.modules.commons import (
-    fused_add_tanh_sigmoid_multiply,
-    get_padding,
-    init_weights,
-)
-
-LRELU_SLOPE = 0.1
-
-
-class LayerNorm(nn.Module):
-    def __init__(self, channels, eps=1e-5):
-        super().__init__()
-        self.channels = channels
-        self.eps = eps
-
-        self.gamma = nn.Parameter(torch.ones(channels))
-        self.beta = nn.Parameter(torch.zeros(channels))
-
-    def forward(self, x):
-        x = x.transpose(1, -1)
-        x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
-        return x.transpose(1, -1)
-
-
-class ConvReluNorm(nn.Module):
-    def __init__(
-        self,
-        in_channels,
-        hidden_channels,
-        out_channels,
-        kernel_size,
-        n_layers,
-        p_dropout,
-    ):
-        super().__init__()
-        self.in_channels = in_channels
-        self.hidden_channels = hidden_channels
-        self.out_channels = out_channels
-        self.kernel_size = kernel_size
-        self.n_layers = n_layers
-        self.p_dropout = p_dropout
-        assert n_layers > 1, "Number of layers should be larger than 0."
-
-        self.conv_layers = nn.ModuleList()
-        self.norm_layers = nn.ModuleList()
-        self.conv_layers.append(
-            nn.Conv1d(
-                in_channels, hidden_channels, kernel_size, padding=kernel_size // 2
-            )
-        )
-        self.norm_layers.append(LayerNorm(hidden_channels))
-        self.relu_drop = nn.Sequential(nn.ReLU(), nn.Dropout(p_dropout))
-        for _ in range(n_layers - 1):
-            self.conv_layers.append(
-                nn.Conv1d(
-                    hidden_channels,
-                    hidden_channels,
-                    kernel_size,
-                    padding=kernel_size // 2,
-                )
-            )
-            self.norm_layers.append(LayerNorm(hidden_channels))
-        self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
-        self.proj.weight.data.zero_()
-        self.proj.bias.data.zero_()
-
-    def forward(self, x, x_mask):
-        x_org = x
-        for i in range(self.n_layers):
-            x = self.conv_layers[i](x * x_mask)
-            x = self.norm_layers[i](x)
-            x = self.relu_drop(x)
-        x = x_org + self.proj(x)
-        return x * x_mask
-
-
-class DDSConv(nn.Module):
-    """
-    Dialted and Depth-Separable Convolution
-    """
-
-    def __init__(self, channels, kernel_size, n_layers, p_dropout=0.0):
-        super().__init__()
-        self.channels = channels
-        self.kernel_size = kernel_size
-        self.n_layers = n_layers
-        self.p_dropout = p_dropout
-
-        self.drop = nn.Dropout(p_dropout)
-        self.convs_sep = nn.ModuleList()
-        self.convs_1x1 = nn.ModuleList()
-        self.norms_1 = nn.ModuleList()
-        self.norms_2 = nn.ModuleList()
-        for i in range(n_layers):
-            dilation = kernel_size**i
-            padding = (kernel_size * dilation - dilation) // 2
-            self.convs_sep.append(
-                nn.Conv1d(
-                    channels,
-                    channels,
-                    kernel_size,
-                    groups=channels,
-                    dilation=dilation,
-                    padding=padding,
-                )
-            )
-            self.convs_1x1.append(nn.Conv1d(channels, channels, 1))
-            self.norms_1.append(LayerNorm(channels))
-            self.norms_2.append(LayerNorm(channels))
-
-    def forward(self, x, x_mask, g=None):
-        if g is not None:
-            x = x + g
-        for i in range(self.n_layers):
-            y = self.convs_sep[i](x * x_mask)
-            y = self.norms_1[i](y)
-            y = F.gelu(y)
-            y = self.convs_1x1[i](y)
-            y = self.norms_2[i](y)
-            y = F.gelu(y)
-            y = self.drop(y)
-            x = x + y
-        return x * x_mask
-
-
-class WN(torch.nn.Module):
-    def __init__(
-        self,
-        hidden_channels,
-        kernel_size,
-        dilation_rate,
-        n_layers,
-        gin_channels=0,
-        p_dropout=0,
-    ):
-        super(WN, self).__init__()
-        assert kernel_size % 2 == 1
-        self.hidden_channels = hidden_channels
-        self.kernel_size = (kernel_size,)
-        self.dilation_rate = dilation_rate
-        self.n_layers = n_layers
-        self.gin_channels = gin_channels
-        self.p_dropout = p_dropout
-
-        self.in_layers = torch.nn.ModuleList()
-        self.res_skip_layers = torch.nn.ModuleList()
-        self.drop = nn.Dropout(p_dropout)
-
-        if gin_channels != 0:
-            cond_layer = torch.nn.Conv1d(
-                gin_channels, 2 * hidden_channels * n_layers, 1
-            )
-            self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name="weight")
-
-        for i in range(n_layers):
-            dilation = dilation_rate**i
-            padding = int((kernel_size * dilation - dilation) / 2)
-            in_layer = torch.nn.Conv1d(
-                hidden_channels,
-                2 * hidden_channels,
-                kernel_size,
-                dilation=dilation,
-                padding=padding,
-            )
-            in_layer = torch.nn.utils.weight_norm(in_layer, name="weight")
-            self.in_layers.append(in_layer)
-
-            # last one is not necessary
-            if i < n_layers - 1:
-                res_skip_channels = 2 * hidden_channels
-            else:
-                res_skip_channels = hidden_channels
-
-            res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
-            res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name="weight")
-            self.res_skip_layers.append(res_skip_layer)
-
-    def forward(self, x, x_mask, g=None, **kwargs):
-        output = torch.zeros_like(x)
-        n_channels_tensor = torch.IntTensor([self.hidden_channels])
-
-        if g is not None:
-            g = self.cond_layer(g)
-
-        for i in range(self.n_layers):
-            x_in = self.in_layers[i](x)
-            if g is not None:
-                cond_offset = i * 2 * self.hidden_channels
-                g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :]
-            else:
-                g_l = torch.zeros_like(x_in)
-
-            acts = fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor)
-            acts = self.drop(acts)
-
-            res_skip_acts = self.res_skip_layers[i](acts)
-            if i < self.n_layers - 1:
-                res_acts = res_skip_acts[:, : self.hidden_channels, :]
-                x = (x + res_acts) * x_mask
-                output = output + res_skip_acts[:, self.hidden_channels :, :]
-            else:
-                output = output + res_skip_acts
-        return output * x_mask
-
-    def remove_weight_norm(self):
-        if self.gin_channels != 0:
-            torch.nn.utils.remove_weight_norm(self.cond_layer)
-        for l in self.in_layers:
-            torch.nn.utils.remove_weight_norm(l)
-        for l in self.res_skip_layers:
-            torch.nn.utils.remove_weight_norm(l)
-
-
-class ResBlock1(torch.nn.Module):
-    def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
-        super(ResBlock1, self).__init__()
-        self.convs1 = nn.ModuleList(
-            [
-                weight_norm(
-                    Conv1d(
-                        channels,
-                        channels,
-                        kernel_size,
-                        1,
-                        dilation=dilation[0],
-                        padding=get_padding(kernel_size, dilation[0]),
-                    )
-                ),
-                weight_norm(
-                    Conv1d(
-                        channels,
-                        channels,
-                        kernel_size,
-                        1,
-                        dilation=dilation[1],
-                        padding=get_padding(kernel_size, dilation[1]),
-                    )
-                ),
-                weight_norm(
-                    Conv1d(
-                        channels,
-                        channels,
-                        kernel_size,
-                        1,
-                        dilation=dilation[2],
-                        padding=get_padding(kernel_size, dilation[2]),
-                    )
-                ),
-            ]
-        )
-        self.convs1.apply(init_weights)
-
-        self.convs2 = nn.ModuleList(
-            [
-                weight_norm(
-                    Conv1d(
-                        channels,
-                        channels,
-                        kernel_size,
-                        1,
-                        dilation=1,
-                        padding=get_padding(kernel_size, 1),
-                    )
-                ),
-                weight_norm(
-                    Conv1d(
-                        channels,
-                        channels,
-                        kernel_size,
-                        1,
-                        dilation=1,
-                        padding=get_padding(kernel_size, 1),
-                    )
-                ),
-                weight_norm(
-                    Conv1d(
-                        channels,
-                        channels,
-                        kernel_size,
-                        1,
-                        dilation=1,
-                        padding=get_padding(kernel_size, 1),
-                    )
-                ),
-            ]
-        )
-        self.convs2.apply(init_weights)
-
-    def forward(self, x, x_mask=None):
-        for c1, c2 in zip(self.convs1, self.convs2):
-            xt = F.leaky_relu(x, LRELU_SLOPE)
-            if x_mask is not None:
-                xt = xt * x_mask
-            xt = c1(xt)
-            xt = F.leaky_relu(xt, LRELU_SLOPE)
-            if x_mask is not None:
-                xt = xt * x_mask
-            xt = c2(xt)
-            x = xt + x
-        if x_mask is not None:
-            x = x * x_mask
-        return x
-
-    def remove_weight_norm(self):
-        for l in self.convs1:
-            remove_weight_norm(l)
-        for l in self.convs2:
-            remove_weight_norm(l)
-
-
-class ResBlock2(torch.nn.Module):
-    def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
-        super(ResBlock2, self).__init__()
-        self.convs = nn.ModuleList(
-            [
-                weight_norm(
-                    Conv1d(
-                        channels,
-                        channels,
-                        kernel_size,
-                        1,
-                        dilation=dilation[0],
-                        padding=get_padding(kernel_size, dilation[0]),
-                    )
-                ),
-                weight_norm(
-                    Conv1d(
-                        channels,
-                        channels,
-                        kernel_size,
-                        1,
-                        dilation=dilation[1],
-                        padding=get_padding(kernel_size, dilation[1]),
-                    )
-                ),
-            ]
-        )
-        self.convs.apply(init_weights)
-
-    def forward(self, x, x_mask=None):
-        for c in self.convs:
-            xt = F.leaky_relu(x, LRELU_SLOPE)
-            if x_mask is not None:
-                xt = xt * x_mask
-            xt = c(xt)
-            x = xt + x
-        if x_mask is not None:
-            x = x * x_mask
-        return x
-
-    def remove_weight_norm(self):
-        for l in self.convs:
-            remove_weight_norm(l)
-
-
-class Flip(nn.Module):
-    def forward(self, x, *args, reverse=False, **kwargs):
-        x = torch.flip(x, [1])
-        if not reverse:
-            logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device)
-            return x, logdet
-        else:
-            return x
-
-
-class ResidualCouplingLayer(nn.Module):
-    def __init__(
-        self,
-        channels,
-        hidden_channels,
-        kernel_size,
-        dilation_rate,
-        n_layers,
-        p_dropout=0,
-        gin_channels=0,
-        mean_only=False,
-    ):
-        assert channels % 2 == 0, "channels should be divisible by 2"
-        super().__init__()
-        self.channels = channels
-        self.hidden_channels = hidden_channels
-        self.kernel_size = kernel_size
-        self.dilation_rate = dilation_rate
-        self.n_layers = n_layers
-        self.half_channels = channels // 2
-        self.mean_only = mean_only
-
-        self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
-        self.enc = WN(
-            hidden_channels,
-            kernel_size,
-            dilation_rate,
-            n_layers,
-            p_dropout=p_dropout,
-            gin_channels=gin_channels,
-        )
-        self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
-        self.post.weight.data.zero_()
-        self.post.bias.data.zero_()
-
-    def forward(self, x, x_mask, g=None, reverse=False):
-        x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
-        h = self.pre(x0) * x_mask
-        h = self.enc(h, x_mask, g=g)
-        stats = self.post(h) * x_mask
-        if not self.mean_only:
-            m, logs = torch.split(stats, [self.half_channels] * 2, 1)
-        else:
-            m = stats
-            logs = torch.zeros_like(m)
-
-        if not reverse:
-            x1 = m + x1 * torch.exp(logs) * x_mask
-            x = torch.cat([x0, x1], 1)
-            logdet = torch.sum(logs, [1, 2])
-            return x, logdet
-        else:
-            x1 = (x1 - m) * torch.exp(-logs) * x_mask
-            x = torch.cat([x0, x1], 1)
-            return x
-
-
-class LinearNorm(nn.Module):
-    def __init__(
-        self,
-        in_channels,
-        out_channels,
-        bias=True,
-        spectral_norm=False,
-    ):
-        super(LinearNorm, self).__init__()
-        self.fc = nn.Linear(in_channels, out_channels, bias)
-
-        if spectral_norm:
-            self.fc = nn.utils.spectral_norm(self.fc)
-
-    def forward(self, input):
-        out = self.fc(input)
-        return out
-
-
-class Mish(nn.Module):
-    def __init__(self):
-        super(Mish, self).__init__()
-
-    def forward(self, x):
-        return x * torch.tanh(F.softplus(x))
-
-
-class Conv1dGLU(nn.Module):
-    """
-    Conv1d + GLU(Gated Linear Unit) with residual connection.
-    For GLU refer to https://arxiv.org/abs/1612.08083 paper.
-    """
-
-    def __init__(self, in_channels, out_channels, kernel_size, dropout):
-        super(Conv1dGLU, self).__init__()
-        self.out_channels = out_channels
-        self.conv1 = ConvNorm(in_channels, 2 * out_channels, kernel_size=kernel_size)
-        self.dropout = nn.Dropout(dropout)
-
-    def forward(self, x):
-        residual = x
-        x = self.conv1(x)
-        x1, x2 = torch.split(x, split_size_or_sections=self.out_channels, dim=1)
-        x = x1 * torch.sigmoid(x2)
-        x = residual + self.dropout(x)
-        return x
-
-
-class ConvNorm(nn.Module):
-    def __init__(
-        self,
-        in_channels,
-        out_channels,
-        kernel_size=1,
-        stride=1,
-        padding=None,
-        dilation=1,
-        bias=True,
-        spectral_norm=False,
-    ):
-        super(ConvNorm, self).__init__()
-
-        if padding is None:
-            assert kernel_size % 2 == 1
-            padding = int(dilation * (kernel_size - 1) / 2)
-
-        self.conv = torch.nn.Conv1d(
-            in_channels,
-            out_channels,
-            kernel_size=kernel_size,
-            stride=stride,
-            padding=padding,
-            dilation=dilation,
-            bias=bias,
-        )
-
-        if spectral_norm:
-            self.conv = nn.utils.spectral_norm(self.conv)
-
-    def forward(self, input):
-        out = self.conv(input)
-        return out
-
-
-class MultiHeadAttention(nn.Module):
-    """Multi-Head Attention module"""
-
-    def __init__(self, n_head, d_model, d_k, d_v, dropout=0.0, spectral_norm=False):
-        super().__init__()
-
-        self.n_head = n_head
-        self.d_k = d_k
-        self.d_v = d_v
-
-        self.w_qs = nn.Linear(d_model, n_head * d_k)
-        self.w_ks = nn.Linear(d_model, n_head * d_k)
-        self.w_vs = nn.Linear(d_model, n_head * d_v)
-
-        self.attention = ScaledDotProductAttention(
-            temperature=np.power(d_model, 0.5), dropout=dropout
-        )
-
-        self.fc = nn.Linear(n_head * d_v, d_model)
-        self.dropout = nn.Dropout(dropout)
-
-        if spectral_norm:
-            self.w_qs = nn.utils.spectral_norm(self.w_qs)
-            self.w_ks = nn.utils.spectral_norm(self.w_ks)
-            self.w_vs = nn.utils.spectral_norm(self.w_vs)
-            self.fc = nn.utils.spectral_norm(self.fc)
-
-    def forward(self, x, mask=None):
-        d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
-        sz_b, len_x, _ = x.size()
-
-        residual = x
-
-        q = self.w_qs(x).view(sz_b, len_x, n_head, d_k)
-        k = self.w_ks(x).view(sz_b, len_x, n_head, d_k)
-        v = self.w_vs(x).view(sz_b, len_x, n_head, d_v)
-        q = q.permute(2, 0, 1, 3).contiguous().view(-1, len_x, d_k)  # (n*b) x lq x dk
-        k = k.permute(2, 0, 1, 3).contiguous().view(-1, len_x, d_k)  # (n*b) x lk x dk
-        v = v.permute(2, 0, 1, 3).contiguous().view(-1, len_x, d_v)  # (n*b) x lv x dv
-
-        if mask is not None:
-            slf_mask = mask.repeat(n_head, 1, 1)  # (n*b) x .. x ..
-        else:
-            slf_mask = None
-        output, attn = self.attention(q, k, v, mask=slf_mask)
-
-        output = output.view(n_head, sz_b, len_x, d_v)
-        output = (
-            output.permute(1, 2, 0, 3).contiguous().view(sz_b, len_x, -1)
-        )  # b x lq x (n*dv)
-
-        output = self.fc(output)
-
-        output = self.dropout(output) + residual
-        return output, attn
-
-
-class ScaledDotProductAttention(nn.Module):
-    """Scaled Dot-Product Attention"""
-
-    def __init__(self, temperature, dropout):
-        super().__init__()
-        self.temperature = temperature
-        self.softmax = nn.Softmax(dim=2)
-        self.dropout = nn.Dropout(dropout)
-
-    def forward(self, q, k, v, mask=None):
-        attn = torch.bmm(q, k.transpose(1, 2))
-        attn = attn / self.temperature
-
-        if mask is not None:
-            attn = attn.masked_fill(mask, -np.inf)
-
-        attn = self.softmax(attn)
-        p_attn = self.dropout(attn)
-
-        output = torch.bmm(p_attn, v)
-        return output, attn
-
-
-class MelStyleEncoder(nn.Module):
-    """MelStyleEncoder"""
-
-    def __init__(
-        self,
-        n_mel_channels=80,
-        style_hidden=128,
-        style_vector_dim=256,
-        style_kernel_size=5,
-        style_head=2,
-        dropout=0.1,
-    ):
-        super(MelStyleEncoder, self).__init__()
-        self.in_dim = n_mel_channels
-        self.hidden_dim = style_hidden
-        self.out_dim = style_vector_dim
-        self.kernel_size = style_kernel_size
-        self.n_head = style_head
-        self.dropout = dropout
-
-        self.spectral = nn.Sequential(
-            LinearNorm(self.in_dim, self.hidden_dim),
-            Mish(),
-            nn.Dropout(self.dropout),
-            LinearNorm(self.hidden_dim, self.hidden_dim),
-            Mish(),
-            nn.Dropout(self.dropout),
-        )
-
-        self.temporal = nn.Sequential(
-            Conv1dGLU(self.hidden_dim, self.hidden_dim, self.kernel_size, self.dropout),
-            Conv1dGLU(self.hidden_dim, self.hidden_dim, self.kernel_size, self.dropout),
-        )
-
-        self.slf_attn = MultiHeadAttention(
-            self.n_head,
-            self.hidden_dim,
-            self.hidden_dim // self.n_head,
-            self.hidden_dim // self.n_head,
-            self.dropout,
-        )
-
-        self.fc = LinearNorm(self.hidden_dim, self.out_dim)
-
-    def temporal_avg_pool(self, x, mask=None):
-        if mask is None:
-            out = torch.mean(x, dim=1)
-        else:
-            len_ = (~mask).sum(dim=1).unsqueeze(1)
-            x = x.masked_fill(mask.unsqueeze(-1), 0)
-            x = x.sum(dim=1)
-            out = torch.div(x, len_)
-        return out
-
-    def forward(self, x, mask=None):
-        x = x.transpose(1, 2)
-        if mask is not None:
-            mask = (mask.int() == 0).squeeze(1)
-        max_len = x.shape[1]
-        slf_attn_mask = (
-            mask.unsqueeze(1).expand(-1, max_len, -1) if mask is not None else None
-        )
-
-        # spectral
-        x = self.spectral(x)
-        # temporal
-        x = x.transpose(1, 2)
-        x = self.temporal(x)
-        x = x.transpose(1, 2)
-        # self-attention
-        if mask is not None:
-            x = x.masked_fill(mask.unsqueeze(-1), 0)
-        x, _ = self.slf_attn(x, mask=slf_attn_mask)
-        # fc
-        x = self.fc(x)
-        # temoral average pooling
-        w = self.temporal_avg_pool(x, mask=mask)
-
-        return w.unsqueeze(-1)

+ 0 - 124
fish_speech/models/vqgan/modules/rvq.py

@@ -1,124 +0,0 @@
-from dataclasses import dataclass
-from typing import Union
-
-import numpy as np
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from einops import rearrange
-from torch.nn.utils import weight_norm
-from vector_quantize_pytorch import LFQ, ResidualVQ
-
-
-class DownsampleResidualVectorQuantizer(nn.Module):
-    """
-    Downsampled version of ResidualVectorQuantize
-    """
-
-    def __init__(
-        self,
-        input_dim: int = 512,
-        n_codebooks: int = 9,
-        codebook_size: int = 1024,
-        codebook_dim: Union[int, list] = 8,
-        quantizer_dropout: float = 0.0,
-        min_quantizers: int = 4,
-        downsample_factor: tuple[int] = (2, 2),
-        downsample_dims: tuple[int] | None = None,
-    ):
-        super().__init__()
-        if downsample_dims is None:
-            downsample_dims = [input_dim for _ in range(len(downsample_factor))]
-
-        all_dims = (input_dim,) + tuple(downsample_dims)
-
-        # self.vq = ResidualVQ(
-        #     dim=all_dims[-1],
-        #     num_quantizers=n_codebooks,
-        #     codebook_dim=codebook_dim,
-        #     threshold_ema_dead_code=2,
-        #     codebook_size=codebook_size,
-        #     kmeans_init=False,
-        # )
-
-        self.vq = LFQ(
-            dim=all_dims[-1],
-            codebook_size=2**14,
-            entropy_loss_weight=0.1,
-            diversity_gamma=1.0,
-        )
-
-        self.downsample_factor = downsample_factor
-        self.downsample_dims = downsample_dims
-
-        self.downsample = nn.Sequential(
-            *[
-                nn.Conv1d(
-                    all_dims[idx],
-                    all_dims[idx + 1],
-                    kernel_size=factor,
-                    stride=factor,
-                )
-                for idx, factor in enumerate(downsample_factor)
-            ]
-        )
-
-        self.upsample = nn.Sequential(
-            *[
-                nn.ConvTranspose1d(
-                    all_dims[idx + 1],
-                    all_dims[idx],
-                    kernel_size=factor,
-                    stride=factor,
-                )
-                for idx, factor in reversed(list(enumerate(downsample_factor)))
-            ]
-        )
-
-    def forward(self, z):
-        original_shape = z.shape
-        z = self.downsample(z)
-        z, indices, loss = self.vq(z.mT)
-        z = self.upsample(z.mT)
-        loss = loss.mean()
-
-        # Pad or crop z to match original shape
-        diff = original_shape[-1] - z.shape[-1]
-        left = diff // 2
-        right = diff - left
-
-        if diff > 0:
-            z = F.pad(z, (left, right))
-        elif diff < 0:
-            z = z[..., left:-right]
-
-        return z, indices, loss
-
-    # def from_codes(self, codes: torch.Tensor):
-    #     z_q, z_p, codes = super().from_codes(codes)
-    #     z_q = self.upsample(z_q)
-    #     return z_q, z_p, codes
-
-    # def from_latents(self, latents: torch.Tensor):
-    #     z_q, z_p, codes = super().from_latents(latents)
-    #     z_q = self.upsample(z_q)
-    #     return z_q, z_p, codes
-
-
-if __name__ == "__main__":
-    rvq = DownsampleResidualVectorQuantizer(
-        quantizer_dropout=1.0,
-        min_quantizers=1,
-        codebook_size=256,
-        downsample_factor=(2, 2),
-    )
-    x = torch.randn(16, 512, 80)
-
-    result = rvq(x)
-    print(result.latents.shape, result.codes.shape, result.z.shape)
-
-    y = rvq.from_codes(result.codes)
-    print(y[0].shape)
-
-    y = rvq.from_latents(result.latents)
-    print(y[0].shape)