@@ -1,10 +1,12 @@
 import math
+from dataclasses import dataclass

 import torch
 from torch import nn
 from torch.nn import Conv1d, Conv2d, ConvTranspose1d
 from torch.nn import functional as F
 from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm
+from vector_quantize_pytorch import VectorQuantize

 from fish_speech.models.hubert_vq.utils import (
     convert_pad_shape,
@@ -15,17 +17,251 @@ from fish_speech.models.hubert_vq.utils import (
 LRELU_SLOPE = 0.1


+@dataclass
+class VQEncoderOutput:
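+    # loss: commitment loss returned by the vector quantiser
+    # features: encoded output, shaped (batch, channels, seq_len)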
+    loss: torch.Tensor
+    features: torch.Tensor
+
+
 class VQEncoder(nn.Module):
-    def __init__(self, *args, **kwargs) -> None:
-        super().__init__(*args, **kwargs)
+    def __init__(
+        self,
+        in_channels: int = 1024,
+        channels: int = 192,
+        num_heads: int = 2,
+        num_feature_layers: int = 2,
+        num_speaker_layers: int = 4,
+        num_mixin_layers: int = 4,
+        input_downsample: bool = True,
+        code_book_size: int = 2048,
+        freeze_vq: bool = False,
+    ):
+        super().__init__()
+
+        # Feature Encoder
+        down_sample = 2 if input_downsample else 1
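+        # When downsampling, adjacent frame pairs are concatenated along the channel
+        # axis, so the input projections below see in_channels * down_sample features.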
+
+        self.vq_in = nn.Linear(in_channels * down_sample, in_channels)
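+        # Vector quantiser; codes whose EMA usage falls below the threshold are
+        # treated as dead and replaced.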
+        self.vq = VectorQuantize(
+            dim=in_channels,
+            codebook_size=code_book_size,
+            threshold_ema_dead_code=2,
+        )
+
+        self.feature_in = nn.Linear(in_channels, channels)
+        self.feature_blocks = nn.ModuleList(
+            [
+                TransformerBlock(
+                    channels,
+                    num_heads,
+                    window_size=4,
+                    window_heads_share=True,
+                    proximal_init=True,
+                    proximal_bias=False,
+                    use_relative_attn=True,
+                )
+                for _ in range(num_feature_layers)
+            ]
+        )

-        encoder_layer = nn.TransformerEncoderLayer(
-            d_model=256, nhead=4, dim_feedforward=1024, dropout=0.1, activation="gelu"
+        # Speaker Encoder
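+        # A learnable query token is prepended to the sequence and attends over the
+        # whole utterance; its output is used as a global speaker embedding.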
+        self.speaker_query = nn.Parameter(torch.randn(1, 1, channels))
+        self.speaker_in = nn.Linear(in_channels * down_sample, channels)
+        self.speaker_blocks = nn.ModuleList(
+            [
+                TransformerBlock(
+                    channels,
+                    num_heads,
+                    use_relative_attn=False,
+                )
+                for _ in range(num_speaker_layers)
+            ]
         )
-        self.encoder = nn.TransformerEncoder(
-            encoder_layer, num_layers=6, norm=nn.LayerNorm(256)
+
+        # Final Mixer
+        self.mixer_in = nn.ModuleList(
+            [
+                TransformerBlock(
+                    channels,
+                    num_heads,
+                    window_size=4,
+                    window_heads_share=True,
+                    proximal_init=True,
+                    proximal_bias=False,
+                    use_relative_attn=True,
+                )
+                for _ in range(num_mixin_layers)
+            ]
         )

+        self.input_downsample = input_downsample
+
+        if freeze_vq:
+            for p in self.vq.parameters():
+                p.requires_grad = False
+
+            for p in self.vq_in.parameters():
+                p.requires_grad = False
+
+    def forward(self, x, key_padding_mask=None):
+        # (batch, seq_len, channels)
+
+        if self.input_downsample and key_padding_mask is not None:
+            key_padding_mask = key_padding_mask[:, ::2]
+
+        # Merge Channels
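+        # (concatenate even/odd frames along channels, truncating to a common length)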
+        if self.input_downsample:
+            feature_0, feature_1 = x[:, ::2], x[:, 1::2]
+            min_len = min(feature_0.size(1), feature_1.size(1))
+            x = torch.cat([feature_0[:, :min_len], feature_1[:, :min_len]], dim=2)
+
+        # Encode Features
+        features = self.vq_in(x)
+        assert key_padding_mask.size(1) == features.size(
+            1
+        ), f"key_padding_mask shape {key_padding_mask.size()} is not (batch_size, seq_len)"
+
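+        # The padding mask is inverted so that True marks valid (non-padded) frames
+        # for the quantiser, which returns (quantized, code indices, commitment loss).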
+        features, _, loss = self.vq(features, mask=~key_padding_mask)
+
+        features = self.feature_in(features)
+        for block in self.feature_blocks:
+            features = block(features, key_padding_mask=key_padding_mask)
+
+        # Encode Speaker
+        speaker = self.speaker_in(x)
+        speaker = torch.cat(
+            [self.speaker_query.expand(speaker.shape[0], -1, -1), speaker], dim=1
+        )
+        for block in self.speaker_blocks:
+            speaker = block(speaker)
+
+        # Mix
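+        # (broadcast-add the pooled speaker token to every content frame)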
+        x = features + speaker[:, :1]
+        for block in self.mixer_in:
+            x = block(x, key_padding_mask=key_padding_mask)
+
+        return VQEncoderOutput(
+            loss=loss,
+            features=x.transpose(1, 2),
+        )
+
+
+class TransformerBlock(nn.Module):
+    def __init__(
+        self,
+        channels,
+        n_heads,
+        mlp_ratio=4 * 2 / 3,
+        p_dropout=0.0,
+        window_size=4,
+        window_heads_share=True,
+        proximal_init=True,
+        proximal_bias=False,
+        use_relative_attn=True,
+    ):
+        super().__init__()
+
+        self.attn_norm = RMSNorm(channels)
+
+        if use_relative_attn:
+            self.attn = RelativeAttention(
+                channels,
+                n_heads,
+                p_dropout,
+                window_size,
+                window_heads_share,
+                proximal_init,
+                proximal_bias,
+            )
+        else:
+            self.attn = nn.MultiheadAttention(
+                embed_dim=channels,
+                num_heads=n_heads,
+                dropout=p_dropout,
+                batch_first=True,
+            )
+
+        self.mlp_norm = RMSNorm(channels)
+        self.mlp = SwiGLU(channels, int(channels * mlp_ratio), channels, drop=p_dropout)
+
+    def forward(self, x, key_padding_mask=None):
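+        # Pre-norm residual block: normalise, attend, add residual; same for the MLP.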
+        norm_x = self.attn_norm(x)
+
+        if isinstance(self.attn, RelativeAttention):
+            attn = self.attn(norm_x, key_padding_mask=key_padding_mask)
+        else:
+            attn, _ = self.attn(
+                norm_x, norm_x, norm_x, key_padding_mask=key_padding_mask
+            )
+
+        x = x + attn
+        x = x + self.mlp(self.mlp_norm(x))
+
+        return x
+
+
+class SwiGLU(nn.Module):
+    """
+    Swish-Gated Linear Unit (SwiGLU) activation function
+    """
+
+    def __init__(
+        self,
+        in_features,
+        hidden_features=None,
+        out_features=None,
+        bias=True,
+        drop=0.0,
+    ):
+        super().__init__()
+        out_features = out_features or in_features
+        hidden_features = hidden_features or in_features
+        assert hidden_features % 2 == 0
+
+        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
+        self.act = nn.SiLU()
+        self.drop1 = nn.Dropout(drop)
+        self.norm = RMSNorm(hidden_features // 2)
+        self.fc2 = nn.Linear(hidden_features // 2, out_features, bias=bias)
+        self.drop2 = nn.Dropout(drop)
+
+    def init_weights(self):
+        # override init of fc1 w/ gate portion set to weight near zero, bias=1
+        fc1_mid = self.fc1.bias.shape[0] // 2
+        nn.init.ones_(self.fc1.bias[fc1_mid:])
+        nn.init.normal_(self.fc1.weight[fc1_mid:], std=1e-6)
+
+    def forward(self, x):
+        x = self.fc1(x)
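+        # fc1 output is split below into a value half (x1) and a SiLU-gated half (x2).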
+        x1, x2 = x.chunk(2, dim=-1)
+
+        x = x1 * self.act(x2)
+        x = self.drop1(x)
+        x = self.norm(x)
+        x = self.fc2(x)
+        x = self.drop2(x)
+
+        return x
+
+
+class RMSNorm(nn.Module):
+    def __init__(self, hidden_size, eps=1e-6):
+        """
+        LlamaRMSNorm is equivalent to T5LayerNorm
+        """
+        super().__init__()
+
+        self.weight = nn.Parameter(torch.ones(hidden_size))
+        self.variance_epsilon = eps
+
+    def forward(self, hidden_states):
+        input_dtype = hidden_states.dtype
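+        # Normalise in float32 for numerical stability, then cast back to the input dtype.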
+        hidden_states = hidden_states.to(torch.float32)
+        variance = hidden_states.pow(2).mean(-1, keepdim=True)
+        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+
+        return self.weight * hidden_states.to(input_dtype)
+

 class RelativeAttention(nn.Module):
     def __init__(
@@ -117,11 +353,8 @@
             key_padding_mask = key_padding_mask.view(batch_size, 1, 1, seq_len).expand(
                 -1, self.n_heads, -1, -1
             )
-            print(key_padding_mask.shape, scores.shape)
             scores = scores.masked_fill(key_padding_mask, float("-inf"))

-        print(scores[0, 0])
-
         p_attn = F.softmax(scores, dim=-1)  # [b, n_h, t_t, t_s]
         p_attn = self.drop(p_attn)
         output = torch.matmul(p_attn, value)
@@ -571,3 +804,13 @@ class EnsembleDiscriminator(nn.Module):
             fmap_gs.append(fmap_g)

         return y_d_rs, y_d_gs, fmap_rs, fmap_gs
+
+
+if __name__ == "__main__":
+    vq = VQEncoder()
+    x = torch.randn(1, 90, 1024)
+    key_padding_mask = torch.zeros(1, 90).bool()
+    key_padding_mask[:, 67:] = True
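+    # True marks padded positions: frames 67..89 are treated as padding.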
+
+    output = vq(x, key_padding_mask=key_padding_mask)
+    print(output)