@@ -1,8 +1,8 @@
-import math
-from dataclasses import dataclass
+from typing import Optional

import torch
import torch.nn as nn
+from vector_quantize_pytorch import VectorQuantize

from fish_speech.models.vqgan.modules.modules import WN
from fish_speech.models.vqgan.modules.transformer import RelativePositionTransformer
@@ -158,7 +158,6 @@ class SpeakerEncoder(nn.Module):
    ) -> None:
        super().__init__()

-        self.query = nn.Parameter(torch.randn(1, 1, hidden_channels))
        self.in_proj = nn.Sequential(
            nn.Conv1d(in_channels, hidden_channels, 1),
            nn.SiLU(),
@@ -168,17 +167,16 @@ class SpeakerEncoder(nn.Module):
            nn.SiLU(),
            nn.Dropout(p_dropout),
        )
-
-        self.blocks = nn.ModuleList(
-            [
-                nn.MultiheadAttention(
-                    embed_dim=hidden_channels,
-                    num_heads=num_heads,
-                    dropout=p_dropout,
-                    batch_first=True,
-                )
-                for _ in range(num_layers)
-            ]
+        self.encoder = RelativePositionTransformer(
+            in_channels=hidden_channels,
+            out_channels=hidden_channels,
+            hidden_channels=hidden_channels,
+            hidden_channels_ffn=hidden_channels,
+            n_heads=num_heads,
+            n_layers=num_layers,
+            kernel_size=5,
+            dropout=p_dropout,
+            window_size=4,
        )
        self.out_proj = nn.Linear(hidden_channels, out_channels)

@@ -189,22 +187,76 @@ class SpeakerEncoder(nn.Module):
-            - x_lengths: :math:`[B, 1]`
        """

-        x_mask = ~(sequence_mask(mel_lengths, mels.size(2)).bool())
+        x_mask = torch.unsqueeze(sequence_mask(mel_lengths, mels.size(2)), 1).to(
+            mels.dtype
+        )
+        x = self.in_proj(mels) * x_mask
+        x = self.encoder(x, x_mask)
+
+        # Avg Pooling
+        x = x * x_mask
+        x = torch.sum(x, dim=2) / torch.sum(x_mask, dim=2)
+        x = self.out_proj(x)[..., None]
+
+        return x
+
+
+class VQEncoder(nn.Module):
+    def __init__(
+        self,
+        in_channels: int = 1024,
+        vq_channels: int = 1024,
+        codebook_size: int = 2048,
+        downsample: int = 2,
+        kmeans_ckpt: Optional[str] = None,
+    ):
+        super().__init__()
+
+        self.vq = VectorQuantize(
+            dim=vq_channels,
+            codebook_size=codebook_size,
+            threshold_ema_dead_code=2,
+            kmeans_init=False,
+            channel_last=False,
+        )

-        x = self.in_proj(mels).transpose(1, 2)
-        x = torch.cat([self.query.expand(x.shape[0], -1, -1), x], dim=1)
+        self.conv_in = nn.Conv1d(
+            in_channels, vq_channels, kernel_size=downsample, stride=downsample
+        )
+        self.conv_out = nn.Sequential(
+            nn.Upsample(scale_factor=downsample, mode="nearest"),
+            nn.Conv1d(vq_channels, in_channels, kernel_size=1, stride=1),
+        )
+
+        if kmeans_ckpt is not None:
+            self.init_weights(kmeans_ckpt)

-        x_mask = torch.cat(
-            [
-                torch.zeros(x.shape[0], 1, dtype=torch.bool, device=x.device),
-                x_mask,
-            ],
-            dim=1,
+    def init_weights(self, kmeans_ckpt):
+        torch.nn.init.normal_(
+            self.conv_in.weight,
+            mean=1 / (self.conv_in.weight.shape[0] * self.conv_in.weight.shape[-1]),
+            std=1e-2,
        )
+        self.conv_in.bias.data.zero_()

-        for block in self.blocks:
-            x = block(x, x, x, key_padding_mask=x_mask)[0]
+        kmeans_ckpt = "results/hubert-vq-pretrain/kmeans.pt"
+        kmeans_ckpt = torch.load(kmeans_ckpt, map_location="cpu")

-        x = self.out_proj(x[:, :1, :]).mT
+        centroids = kmeans_ckpt["centroids"]
+        bins = kmeans_ckpt["bins"]
+        state_dict = {
+            "_codebook.initted": torch.Tensor([True]),
+            "_codebook.cluster_size": bins,
+            "_codebook.embed": centroids,
+            "_codebook.embed_avg": centroids.clone(),
+        }

-        return x
+        self.vq.load_state_dict(state_dict, strict=True)
+
+    def forward(self, x):
+        # x: [B, T, C]
+        x = self.conv_in(x.mT)
+        q, _, loss = self.vq(x)
+        x = self.conv_out(q).mT
+
+        return x, loss
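
Note (illustrative only, not part of the patch): a minimal shape-check sketch for the new VQEncoder, assuming vector_quantize_pytorch is installed and the VQEncoder class added above is in scope. The module takes channel-last features [B, T, C], transposes to [B, C, T], downsamples time by `downsample` with conv_in, quantizes with VectorQuantize (whose forward returns quantized codes, indices, and a commitment loss, as unpacked in the patch), then conv_out upsamples back, so the output matches the input shape whenever T is divisible by `downsample`.

    # Hypothetical usage sketch; VQEncoder refers to the class defined in this patch.
    import torch

    encoder = VQEncoder(
        in_channels=1024, vq_channels=1024, codebook_size=2048, downsample=2
    )

    features = torch.randn(2, 100, 1024)  # [B, T, C], e.g. HuBERT-style features
    quantized, vq_loss = encoder(features)

    assert quantized.shape == features.shape  # time axis restored by conv_out
    print(vq_loss)  # commitment loss from the vector quantizer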