@@ -8,7 +8,12 @@ from torch.nn import Conv1d, Conv2d, ConvTranspose1d
 from torch.nn import functional as F
 from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm
 
-from fish_speech.models.vqgan.utils import convert_pad_shape, get_padding, init_weights
+from fish_speech.models.vqgan.utils import (
+    convert_pad_shape,
+    fused_add_tanh_sigmoid_multiply,
+    get_padding,
+    init_weights,
+)
 
 LRELU_SLOPE = 0.1
 
@@ -16,19 +21,18 @@ LRELU_SLOPE = 0.1
 @dataclass
 class VQEncoderOutput:
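+    # `mean` and `logs` are the mean and log standard deviation of a
+    # diagonal Gaussian over the latent, mirroring PosteriorEncoderOutput below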
     loss: torch.Tensor
-    features: torch.Tensor
+    mean: torch.Tensor
+    logs: torch.Tensor
 
 
 class VQEncoder(nn.Module):
     def __init__(
         self,
         in_channels: int = 1024,
-        channels: int = 192,
-        num_mels: int = 128,
+        channels: int = 384,
+        out_channels: int = 192,
         num_heads: int = 2,
-        num_feature_layers: int = 2,
-        num_speaker_layers: int = 4,
-        num_mixin_layers: int = 4,
+        num_layers: int = 8,
         input_downsample: bool = True,
         code_book_size: int = 2048,
         freeze_vq: bool = False,
@@ -38,7 +42,7 @@ class VQEncoder(nn.Module):
         # Feature Encoder
         down_sample = 2 if input_downsample else 1
 
-        self.vq_in = nn.Conv1d(
+        self.in_proj = nn.Conv1d(
             in_channels, in_channels, kernel_size=down_sample, stride=down_sample
         )
         self.vq = VectorQuantization(
@@ -49,38 +53,14 @@ class VQEncoder(nn.Module):
             kmeans_iters=50,
         )
 
-        self.feature_in = nn.Linear(in_channels, channels)
-        self.feature_blocks = nn.ModuleList(
-            [
-                TransformerBlock(
-                    channels,
-                    num_heads,
-                    window_size=4,
-                    window_heads_share=True,
-                    proximal_init=True,
-                    proximal_bias=False,
-                    use_relative_attn=True,
-                )
-                for _ in range(num_feature_layers)
-            ]
+        # Init weights of in_proj to mimic the effect of avg pooling
+        torch.nn.init.normal_(
+            self.in_proj.weight, mean=1 / (down_sample * in_channels), std=0.01
         )
+        self.in_proj.bias.data.zero_()
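+        # with every tap near 1 / (down_sample * in_channels), the conv starts
+        # out as (roughly) a mean over both the window and the channel dimension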
 
-        # Speaker Encoder
-        self.speaker_query = nn.Parameter(torch.randn(1, 1, channels))
-        self.speaker_in = nn.Linear(num_mels, channels)
-        self.speaker_blocks = nn.ModuleList(
-            [
-                TransformerBlock(
-                    channels,
-                    num_heads,
-                    use_relative_attn=False,
-                )
-                for _ in range(num_speaker_layers)
-            ]
-        )
-
-        # Final Mixer
-        self.mixer_blocks = nn.ModuleList(
+        self.feature_in = nn.Linear(in_channels, channels)
+        self.blocks = nn.ModuleList(
             [
                 TransformerBlock(
                     channels,
@@ -91,10 +71,12 @@ class VQEncoder(nn.Module):
                     proximal_bias=False,
                     use_relative_attn=True,
                 )
-                for _ in range(num_mixin_layers)
+                for _ in range(num_layers)
             ]
         )
 
+        self.out_proj = nn.Linear(channels, out_channels * 2)
+
         self.input_downsample = input_downsample
 
         if freeze_vq:
@@ -104,22 +86,15 @@ class VQEncoder(nn.Module):
-            for p in self.vq_in.parameters():
+            for p in self.in_proj.parameters():
                 p.requires_grad = False
 
-    def forward(
-        self, x, mels, key_padding_mask=None, mels_key_padding_mask=None
-    ) -> VQEncoderOutput:
+    def forward(self, x, key_padding_mask=None) -> VQEncoderOutput:
         # x: (batch, seq_len, channels)
-        # mels: (batch, seq_len, 128)
 
         assert key_padding_mask.size(1) == x.size(
             1
-        ), f"key_padding_mask shape {key_padding_mask.size()} does not match features shape {features.size()}"
-
-        assert mels_key_padding_mask.size(1) == mels.size(
-            1
-        ), f"mels_key_padding_mask shape {mels_key_padding_mask.size()} does not match mels shape {mels.size()}"
+        ), f"key_padding_mask shape {key_padding_mask.size()} does not match features shape {x.size()}"
 
         # Encode Features
-        features = self.vq_in(x.transpose(1, 2))
+        features = self.in_proj(x.transpose(1, 2))
         features, _, loss = self.vq(features)
         features = features.transpose(1, 2)
 
@@ -136,35 +111,204 @@ class VQEncoder(nn.Module):
         key_padding_mask = key_padding_mask[:, :min_len]
 
         features = self.feature_in(features)
-        for block in self.feature_blocks:
+        for block in self.blocks:
             features = block(features, key_padding_mask=key_padding_mask)
 
-        # Encode Speaker
-        speaker = self.speaker_in(mels)
-        speaker = torch.cat(
-            [self.speaker_query.expand(speaker.shape[0], -1, -1), speaker], dim=1
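+        # Project to 2 * out_channels, zero out padded frames, then split the
+        # channels into the mean and log standard deviation of the latent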
+        stats = self.out_proj(features).transpose(1, 2)
+        stats = torch.masked_fill(stats, key_padding_mask.unsqueeze(1), 0)
+        mean, logs = torch.chunk(stats, 2, dim=1)
+
+        return VQEncoderOutput(
+            loss=loss,
+            mean=mean,
+            logs=logs,
+        )
+
+
+class WaveNet(nn.Module):
+    def __init__(
+        self,
+        hidden_channels,
+        kernel_size,
+        dilation_rate,
+        n_layers,
+        gin_channels=0,
+        p_dropout=0,
+    ):
+        super().__init__()
+        assert kernel_size % 2 == 1
+        self.hidden_channels = hidden_channels
+        self.kernel_size = kernel_size
+        self.dilation_rate = dilation_rate
+        self.n_layers = n_layers
+        self.gin_channels = gin_channels
+        self.p_dropout = p_dropout
+
+        self.in_layers = nn.ModuleList()
+        self.res_skip_layers = nn.ModuleList()
+        self.drop = nn.Dropout(p_dropout)
+
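+        # a single 1x1 conv produces the conditioning for every layer at once;
+        # forward() slices out each layer's 2 * hidden_channels chunk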
+        if gin_channels != 0:
+            self.cond_layer = weight_norm(
+                nn.Conv1d(gin_channels, 2 * hidden_channels * n_layers, 1)
+            )
+
+        for i in range(n_layers):
+            dilation = dilation_rate**i
+            padding = int((kernel_size * dilation - dilation) / 2)
+            in_layer = weight_norm(
+                nn.Conv1d(
+                    hidden_channels,
+                    2 * hidden_channels,
+                    kernel_size,
+                    dilation=dilation,
+                    padding=padding,
+                )
+            )
+            self.in_layers.append(in_layer)
+
+            # the last layer needs no residual output, only the skip half
+            if i < n_layers - 1:
+                res_skip_channels = 2 * hidden_channels
+            else:
+                res_skip_channels = hidden_channels
+
+            res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
+            res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name="weight")
+            self.res_skip_layers.append(res_skip_layer)
+
+    def forward(self, x, x_mask, g=None):
+        output = torch.zeros_like(x)
+        n_channels_tensor = torch.IntTensor([self.hidden_channels])
+
+        if g is not None:
+            g = self.cond_layer(g)
+
+        for i in range(self.n_layers):
+            x_in = self.in_layers[i](x)
+            if g is not None:
+                cond_offset = i * 2 * self.hidden_channels
+                g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :]
+            else:
+                g_l = torch.zeros_like(x_in)
+
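+            # WaveNet gated activation: tanh(a) * sigmoid(b), with the
+            # conditioning g_l added to x_in before the nonlinearities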
+            acts = fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor)
+            acts = self.drop(acts)
+
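+            # first half of the channels feeds the residual path back into x,
+            # the second half is accumulated into the skip output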
+            res_skip_acts = self.res_skip_layers[i](acts)
+            if i < self.n_layers - 1:
+                res_acts = res_skip_acts[:, : self.hidden_channels, :]
+                x = (x + res_acts) * x_mask
+                output = output + res_skip_acts[:, self.hidden_channels :, :]
+            else:
+                output = output + res_skip_acts
+
+        return output * x_mask
+
+    def remove_weight_norm(self):
+        if self.gin_channels != 0:
+            torch.nn.utils.remove_weight_norm(self.cond_layer)
+        for l in self.in_layers:
+            torch.nn.utils.remove_weight_norm(l)
+        for l in self.res_skip_layers:
+            torch.nn.utils.remove_weight_norm(l)
+
+
+@dataclass
+class PosteriorEncoderOutput:
+    z: torch.Tensor
+    mean: torch.Tensor
+    logs: torch.Tensor
+
+
+class PosteriorEncoder(nn.Module):
+    def __init__(
+        self,
+        in_channels: int = 1024,
+        out_channels: int = 192,
+        hidden_channels: int = 192,
+        kernel_size: int = 5,
+        dilation_rate: int = 1,
+        n_layers: int = 16,
+        gin_channels: int = 512,
+    ):
+        super().__init__()
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.hidden_channels = hidden_channels
+        self.kernel_size = kernel_size
+        self.dilation_rate = dilation_rate
+        self.n_layers = n_layers
+        self.gin_channels = gin_channels
+
+        self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
+        self.enc = WaveNet(
+            hidden_channels,
+            kernel_size,
+            dilation_rate,
+            n_layers,
+            gin_channels=gin_channels,
+        )
+        self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
+
+    def forward(self, x, x_mask, g=None):
+        # stop gradients from flowing into the speaker condition
+        if g is not None:
+            g = g.detach()
+        x = self.pre(x) * x_mask
+        x = self.enc(x, x_mask, g=g)
+        stats = self.proj(x) * x_mask
+        m, logs = torch.split(stats, self.out_channels, dim=1)
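+        # reparameterization: z = mean + eps * std, with eps ~ N(0, I)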
+        z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
+
+        return PosteriorEncoderOutput(
+            z=z,
+            mean=m,
+            logs=logs,
         )
+
+
+class SpeakerEncoder(nn.Module):
+    def __init__(
+        self,
+        in_channels: int = 128,
+        channels: int = 192,
+        out_channels: int = 512,
+        num_heads: int = 2,
+        num_layers: int = 4,
+    ) -> None:
+        super().__init__()
+
+        self.query = nn.Parameter(torch.randn(1, 1, channels))
+        self.in_proj = nn.Linear(in_channels, channels)
+        self.blocks = nn.ModuleList(
+            [
+                TransformerBlock(
+                    channels,
+                    num_heads,
+                    use_relative_attn=False,
+                )
+                for _ in range(num_layers)
+            ]
+        )
+        self.out_proj = nn.Linear(channels, out_channels)
+
+    def forward(self, mels, mels_key_padding_mask=None):
+        x = self.in_proj(mels)
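+        # prepend a learnable query token; the transformer output at this
+        # position serves as the pooled, utterance-level speaker embedding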
+        x = torch.cat([self.query.expand(x.shape[0], -1, -1), x], dim=1)
+
         mels_key_padding_mask = torch.cat(
             [
-                torch.ones(
-                    speaker.shape[0], 1, dtype=torch.bool, device=speaker.device
-                ),
+                torch.ones(x.shape[0], 1, dtype=torch.bool, device=x.device),
                 mels_key_padding_mask,
             ],
             dim=1,
         )
-        for block in self.speaker_blocks:
-            speaker = block(speaker, key_padding_mask=mels_key_padding_mask)
+        for block in self.blocks:
+            x = block(x, key_padding_mask=mels_key_padding_mask)
 
-        # Mix
-        x = features + speaker[:, :1]
-        for block in self.mixer_blocks:
-            x = block(x, key_padding_mask=key_padding_mask)
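+        # keep only the query token's output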
+        x = x[:, :1]
+        x = self.out_proj(x)
 
-        return VQEncoderOutput(
-            loss=loss,
-            features=x.transpose(1, 2),
-        )
+        return x.transpose(1, 2)
 
 
 class TransformerBlock(nn.Module):