@@ -0,0 +1,573 @@
+import math
+
+import torch
+from torch import nn
+from torch.nn import Conv1d, Conv2d, ConvTranspose1d
+from torch.nn import functional as F
+from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm
+
+from fish_speech.models.hubert_vq.utils import (
+    convert_pad_shape,
+    get_padding,
+    init_weights,
+)
+
+LRELU_SLOPE = 0.1
+
+
+class VQEncoder(nn.Module):
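+    """Stub encoder wrapping a 6-layer Transformer over 256-dim features.
+
+    Note: only ``self.encoder`` is built here; no ``forward`` is defined yet,
+    so the class is presumably completed (or wrapped) elsewhere.
+    """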
+    def __init__(self, *args, **kwargs) -> None:
+        super().__init__(*args, **kwargs)
+
+        encoder_layer = nn.TransformerEncoderLayer(
+            d_model=256, nhead=4, dim_feedforward=1024, dropout=0.1, activation="gelu"
+        )
+        self.encoder = nn.TransformerEncoder(
+            encoder_layer, num_layers=6, norm=nn.LayerNorm(256)
+        )
+
+
+class RelativeAttention(nn.Module):
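+    """Multi-head self-attention with windowed relative position embeddings,
+    in the style of VITS attention, operating on (batch, seq_len, channels)
+    inputs and using a fused QKV projection.
+    """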
+    def __init__(
+        self,
+        channels,
+        n_heads,
+        p_dropout=0.0,
+        window_size=4,
+        window_heads_share=True,
+        proximal_init=True,
+        proximal_bias=False,
+    ):
+        super().__init__()
+        assert channels % n_heads == 0
+
+        self.channels = channels
+        self.n_heads = n_heads
+        self.p_dropout = p_dropout
+        self.window_size = window_size
+        self.heads_share = window_heads_share
+        self.proximal_init = proximal_init
+        self.proximal_bias = proximal_bias
+
+        self.k_channels = channels // n_heads
+        self.qkv = nn.Linear(channels, channels * 3)
+        self.drop = nn.Dropout(p_dropout)
+
+        if window_size is not None:
+            n_heads_rel = 1 if window_heads_share else n_heads
+            rel_stddev = self.k_channels**-0.5
+            self.emb_rel_k = nn.Parameter(
+                torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
+                * rel_stddev
+            )
+            self.emb_rel_v = nn.Parameter(
+                torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
+                * rel_stddev
+            )
+
+        nn.init.xavier_uniform_(self.qkv.weight)
+
+        if proximal_init:
+            with torch.no_grad():
+                # Sync qk weights
+                self.qkv.weight.data[: self.channels] = self.qkv.weight.data[
+                    self.channels : self.channels * 2
+                ]
+                self.qkv.bias.data[: self.channels] = self.qkv.bias.data[
+                    self.channels : self.channels * 2
+                ]
+
+    def forward(self, x, key_padding_mask=None):
+        # x: (batch, seq_len, channels)
+        batch_size, seq_len, _ = x.size()
+        qkv = (
+            self.qkv(x)
+            .reshape(batch_size, seq_len, 3, self.n_heads, self.k_channels)
+            .permute(2, 0, 3, 1, 4)
+        )
+        query, key, value = torch.unbind(qkv, dim=0)
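+        # query / key / value: (batch, n_heads, seq_len, k_channels)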
+
+        scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))
+
+        if self.window_size is not None:
+            key_relative_embeddings = self._get_relative_embeddings(
+                self.emb_rel_k, seq_len
+            )
+            rel_logits = self._matmul_with_relative_keys(
+                query / math.sqrt(self.k_channels), key_relative_embeddings
+            )
+            scores_local = self._relative_position_to_absolute_position(rel_logits)
+            scores = scores + scores_local
+
+        if self.proximal_bias:
+            scores = scores + self._attention_bias_proximal(seq_len).to(
+                device=scores.device, dtype=scores.dtype
+            )
+
+        # key_padding_mask: (batch, seq_len)
+        if key_padding_mask is not None:
+            assert key_padding_mask.size() == (
+                batch_size,
+                seq_len,
+            ), f"key_padding_mask shape {key_padding_mask.size()} is not (batch_size, seq_len)"
+            assert (
+                key_padding_mask.dtype == torch.bool
+            ), f"key_padding_mask dtype {key_padding_mask.dtype} is not bool"
+
+            key_padding_mask = key_padding_mask.view(batch_size, 1, 1, seq_len).expand(
+                -1, self.n_heads, -1, -1
+            )
+            scores = scores.masked_fill(key_padding_mask, float("-inf"))
+
+        p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s]
+        p_attn = self.drop(p_attn)
+        output = torch.matmul(p_attn, value)
+
+        if self.window_size is not None:
+            relative_weights = self._absolute_position_to_relative_position(p_attn)
+            value_relative_embeddings = self._get_relative_embeddings(
+                self.emb_rel_v, seq_len
+            )
+            output = output + self._matmul_with_relative_values(
+                relative_weights, value_relative_embeddings
+            )
+
+        return output.reshape(batch_size, seq_len, self.n_heads * self.k_channels)
+
+    def _matmul_with_relative_values(self, x, y):
+        """
+        x: [b, h, l, m]
+        y: [h or 1, m, d]
+        ret: [b, h, l, d]
+        """
+        ret = torch.matmul(x, y.unsqueeze(0))
+        return ret
+
+    def _matmul_with_relative_keys(self, x, y):
+        """
+        x: [b, h, l, d]
+        y: [h or 1, m, d]
+        ret: [b, h, l, m]
+        """
+        ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
+        return ret
+
+    def _get_relative_embeddings(self, relative_embeddings, length):
+        max_relative_position = 2 * self.window_size + 1
+        # Pad first before slice to avoid using cond ops.
+        pad_length = max(length - (self.window_size + 1), 0)
+        slice_start_position = max((self.window_size + 1) - length, 0)
+        slice_end_position = slice_start_position + 2 * length - 1
+        if pad_length > 0:
+            padded_relative_embeddings = F.pad(
+                relative_embeddings,
+                convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]),
+            )
+        else:
+            padded_relative_embeddings = relative_embeddings
+        used_relative_embeddings = padded_relative_embeddings[
+            :, slice_start_position:slice_end_position
+        ]
+        return used_relative_embeddings
+
+    def _relative_position_to_absolute_position(self, x):
+        """
+        x: [b, h, l, 2*l-1]
+        ret: [b, h, l, l]
+        """
+        batch, heads, length, _ = x.size()
+        # Concat columns of pad to shift from relative to absolute indexing.
+        x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))
+
+        # Concat extra elements so to add up to shape (len+1, 2*len-1).
+        x_flat = x.view([batch, heads, length * 2 * length])
+        x_flat = F.pad(x_flat, convert_pad_shape([[0, 0], [0, 0], [0, length - 1]]))
+
+        # Reshape and slice out the padded elements.
+        x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[
+            :, :, :length, length - 1 :
+        ]
+        return x_final
+
+    def _absolute_position_to_relative_position(self, x):
+        """
+        x: [b, h, l, l]
+        ret: [b, h, l, 2*l-1]
+        """
+        batch, heads, length, _ = x.size()
+        # pad along column
+        x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]]))
+        x_flat = x.view([batch, heads, length**2 + length * (length - 1)])
+        # add 0's in the beginning that will skew the elements after reshape
+        x_flat = F.pad(x_flat, convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
+        x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
+        return x_final
+
+    def _attention_bias_proximal(self, length):
+        """Bias for self-attention to encourage attention to close positions.
+        Args:
+            length: an integer scalar.
+        Returns:
+            a Tensor with shape [1, 1, length, length]
+        """
+        r = torch.arange(length, dtype=torch.float32)
+        diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
+        return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
+
+
+class ResBlock1(torch.nn.Module):
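+    """HiFi-GAN style residual block: pairs of dilated and plain 1D convolutions
+    with leaky-ReLU activations and residual connections.
+    """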
+    def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
+        super(ResBlock1, self).__init__()
+        self.convs1 = nn.ModuleList(
+            [
+                weight_norm(
+                    Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=dilation[0],
+                        padding=get_padding(kernel_size, dilation[0]),
+                    )
+                ),
+                weight_norm(
+                    Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=dilation[1],
+                        padding=get_padding(kernel_size, dilation[1]),
+                    )
+                ),
+                weight_norm(
+                    Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=dilation[2],
+                        padding=get_padding(kernel_size, dilation[2]),
+                    )
+                ),
+            ]
+        )
+        self.convs1.apply(init_weights)
+
+        self.convs2 = nn.ModuleList(
+            [
+                weight_norm(
+                    Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=1,
+                        padding=get_padding(kernel_size, 1),
+                    )
+                ),
+                weight_norm(
+                    Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=1,
+                        padding=get_padding(kernel_size, 1),
+                    )
+                ),
+                weight_norm(
+                    Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=1,
+                        padding=get_padding(kernel_size, 1),
+                    )
+                ),
+            ]
+        )
+        self.convs2.apply(init_weights)
+
+    def forward(self, x, x_mask=None):
+        for c1, c2 in zip(self.convs1, self.convs2):
+            xt = F.leaky_relu(x, LRELU_SLOPE)
+            if x_mask is not None:
+                xt = xt * x_mask
+            xt = c1(xt)
+            xt = F.leaky_relu(xt, LRELU_SLOPE)
+            if x_mask is not None:
+                xt = xt * x_mask
+            xt = c2(xt)
+            x = xt + x
+        if x_mask is not None:
+            x = x * x_mask
+        return x
+
+    def remove_weight_norm(self):
+        for l in self.convs1:
+            remove_weight_norm(l)
+        for l in self.convs2:
+            remove_weight_norm(l)
+
+
+class ResBlock2(torch.nn.Module):
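+    """Lightweight HiFi-GAN residual block: a single stack of dilated 1D
+    convolutions with leaky-ReLU activations and residual connections.
+    """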
+    def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
+        super(ResBlock2, self).__init__()
+        self.convs = nn.ModuleList(
+            [
+                weight_norm(
+                    Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=dilation[0],
+                        padding=get_padding(kernel_size, dilation[0]),
+                    )
+                ),
+                weight_norm(
+                    Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=dilation[1],
+                        padding=get_padding(kernel_size, dilation[1]),
+                    )
+                ),
+            ]
+        )
+        self.convs.apply(init_weights)
+
+    def forward(self, x, x_mask=None):
+        for c in self.convs:
+            xt = F.leaky_relu(x, LRELU_SLOPE)
+            if x_mask is not None:
+                xt = xt * x_mask
+            xt = c(xt)
+            x = xt + x
+        if x_mask is not None:
+            x = x * x_mask
+        return x
+
+    def remove_weight_norm(self):
+        for l in self.convs:
+            remove_weight_norm(l)
+
+
+class Generator(nn.Module):
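+    """HiFi-GAN style generator: a pre-conv, a stack of transposed-convolution
+    upsamplers, multi-receptive-field (MRF) residual blocks, and a post-conv
+    producing a single-channel waveform in [-1, 1].
+    """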
+    def __init__(
+        self,
+        initial_channel,
+        resblock,
+        resblock_kernel_sizes,
+        resblock_dilation_sizes,
+        upsample_rates,
+        upsample_initial_channel,
+        upsample_kernel_sizes,
+        gin_channels=0,
+    ):
+        super(Generator, self).__init__()
+        self.num_kernels = len(resblock_kernel_sizes)
+        self.num_upsamples = len(upsample_rates)
+        self.conv_pre = Conv1d(
+            initial_channel, upsample_initial_channel, 7, 1, padding=3
+        )
+        resblock = ResBlock1 if resblock == "1" else ResBlock2
+
+        self.ups = nn.ModuleList()
+        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
+            self.ups.append(
+                weight_norm(
+                    ConvTranspose1d(
+                        upsample_initial_channel // (2**i),
+                        upsample_initial_channel // (2 ** (i + 1)),
+                        k,
+                        u,
+                        padding=(k - u) // 2,
+                    )
+                )
+            )
+
+        self.resblocks = nn.ModuleList()
+        for i in range(len(self.ups)):
+            ch = upsample_initial_channel // (2 ** (i + 1))
+            for j, (k, d) in enumerate(
+                zip(resblock_kernel_sizes, resblock_dilation_sizes)
+            ):
+                self.resblocks.append(resblock(ch, k, d))
+
+        self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
+        self.ups.apply(init_weights)
+
+        if gin_channels != 0:
+            self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
+
+    def forward(self, x, g=None):
+        x = self.conv_pre(x)
+        if g is not None:
+            x = x + self.cond(g)
+
+        for i in range(self.num_upsamples):
+            x = F.leaky_relu(x, LRELU_SLOPE)
+            x = self.ups[i](x)
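+            # Multi-receptive-field fusion: average the parallel resblock outputs.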
+            xs = None
+            for j in range(self.num_kernels):
+                if xs is None:
+                    xs = self.resblocks[i * self.num_kernels + j](x)
+                else:
+                    xs += self.resblocks[i * self.num_kernels + j](x)
+            x = xs / self.num_kernels
+        x = F.leaky_relu(x)
+        x = self.conv_post(x)
+        x = torch.tanh(x)
+
+        return x
+
+    def remove_weight_norm(self):
+        print("Removing weight norm...")
+        for l in self.ups:
+            remove_weight_norm(l)
+        for l in self.resblocks:
+            l.remove_weight_norm()
+
+
+class DiscriminatorP(nn.Module):
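+    """Period discriminator (HiFi-GAN MPD member): reflect-pads the waveform to a
+    multiple of ``period``, folds it into a 2D (frames x period) map, and scores it
+    with strided 2D convolutions, returning logits and intermediate feature maps.
+    """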
+    def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
+        super(DiscriminatorP, self).__init__()
+        self.period = period
+        self.use_spectral_norm = use_spectral_norm
+        norm_f = spectral_norm if use_spectral_norm else weight_norm
+        self.convs = nn.ModuleList(
+            [
+                norm_f(
+                    Conv2d(
+                        1,
+                        32,
+                        (kernel_size, 1),
+                        (stride, 1),
+                        padding=(get_padding(kernel_size, 1), 0),
+                    )
+                ),
+                norm_f(
+                    Conv2d(
+                        32,
+                        128,
+                        (kernel_size, 1),
+                        (stride, 1),
+                        padding=(get_padding(kernel_size, 1), 0),
+                    )
+                ),
+                norm_f(
+                    Conv2d(
+                        128,
+                        512,
+                        (kernel_size, 1),
+                        (stride, 1),
+                        padding=(get_padding(kernel_size, 1), 0),
+                    )
+                ),
+                norm_f(
+                    Conv2d(
+                        512,
+                        1024,
+                        (kernel_size, 1),
+                        (stride, 1),
+                        padding=(get_padding(kernel_size, 1), 0),
+                    )
+                ),
+                norm_f(
+                    Conv2d(
+                        1024,
+                        1024,
+                        (kernel_size, 1),
+                        1,
+                        padding=(get_padding(kernel_size, 1), 0),
+                    )
+                ),
+            ]
+        )
+        self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
+
+    def forward(self, x):
+        fmap = []
+
+        # 1d to 2d
+        b, c, t = x.shape
+        if t % self.period != 0: # pad first
+            n_pad = self.period - (t % self.period)
+            x = F.pad(x, (0, n_pad), "reflect")
+            t = t + n_pad
+        x = x.view(b, c, t // self.period, self.period)
+
+        for l in self.convs:
+            x = l(x)
+            x = F.leaky_relu(x, LRELU_SLOPE)
+            fmap.append(x)
+        x = self.conv_post(x)
+        fmap.append(x)
+        x = torch.flatten(x, 1, -1)
+
+        return x, fmap
+
+
+class DiscriminatorS(nn.Module):
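+    """Scale discriminator (HiFi-GAN MSD style): grouped, strided 1D convolutions
+    over the raw waveform, returning logits and intermediate feature maps.
+    """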
+    def __init__(self, use_spectral_norm=False):
+        super(DiscriminatorS, self).__init__()
+        norm_f = spectral_norm if use_spectral_norm else weight_norm
+        self.convs = nn.ModuleList(
+            [
+                norm_f(Conv1d(1, 16, 15, 1, padding=7)),
+                norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
+                norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)),
+                norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
+                norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
+                norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
+            ]
+        )
+        self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
+
+    def forward(self, x):
+        fmap = []
+
+        for l in self.convs:
+            x = l(x)
+            x = F.leaky_relu(x, LRELU_SLOPE)
+            fmap.append(x)
+        x = self.conv_post(x)
+        fmap.append(x)
+        x = torch.flatten(x, 1, -1)
+
+        return x, fmap
+
+
+class EnsembleDiscriminator(nn.Module):
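+    """Combines one DiscriminatorS with DiscriminatorP at periods 2, 3, 5, 7 and 11;
+    forward returns real/generated logits and feature maps from every sub-discriminator.
+    """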
+    def __init__(self, use_spectral_norm=False):
+        super(EnsembleDiscriminator, self).__init__()
+        periods = [2, 3, 5, 7, 11]
+
+        discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
+        discs = discs + [
+            DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods
+        ]
+        self.discriminators = nn.ModuleList(discs)
+
+    def forward(self, y, y_hat):
+        y_d_rs = []
+        y_d_gs = []
+        fmap_rs = []
+        fmap_gs = []
+        for i, d in enumerate(self.discriminators):
+            y_d_r, fmap_r = d(y)
+            y_d_g, fmap_g = d(y_hat)
+            y_d_rs.append(y_d_r)
+            y_d_gs.append(y_d_g)
+            fmap_rs.append(fmap_r)
+            fmap_gs.append(fmap_g)
+
+        return y_d_rs, y_d_gs, fmap_rs, fmap_gs
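+
+
+if __name__ == "__main__":
+    # Minimal smoke test. The hyperparameters below are illustrative only (not taken
+    # from any project config); the upsample rates multiply to a 256x hop size.
+    generator = Generator(
+        initial_channel=256,
+        resblock="1",
+        resblock_kernel_sizes=[3, 7, 11],
+        resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
+        upsample_rates=[8, 8, 2, 2],
+        upsample_initial_channel=512,
+        upsample_kernel_sizes=[16, 16, 4, 4],
+    )
+    wav = generator(torch.randn(1, 256, 50))  # -> (1, 1, 50 * 256)
+
+    discriminator = EnsembleDiscriminator()
+    y_d_rs, y_d_gs, fmap_rs, fmap_gs = discriminator(wav, wav)
+    print(wav.shape, len(y_d_rs), len(fmap_rs))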