|
@@ -2,6 +2,7 @@ from typing import Optional
|
|
|
|
|
|
|
|
import torch
|
|
import torch
|
|
|
import torch.nn as nn
|
|
import torch.nn as nn
|
|
|
|
|
+import torch.nn.functional as F
|
|
|
from vector_quantize_pytorch import VectorQuantize
|
|
from vector_quantize_pytorch import VectorQuantize
|
|
|
|
|
|
|
|
from fish_speech.models.vqgan.modules.modules import WN
|
|
from fish_speech.models.vqgan.modules.modules import WN
|
|
@@ -13,7 +14,7 @@ from fish_speech.models.vqgan.utils import sequence_mask
|
|
|
class TextEncoder(nn.Module):
|
|
class TextEncoder(nn.Module):
|
|
|
def __init__(
|
|
def __init__(
|
|
|
self,
|
|
self,
|
|
|
- n_vocab: int,
|
|
|
|
|
|
|
+ in_channels: int,
|
|
|
out_channels: int,
|
|
out_channels: int,
|
|
|
hidden_channels: int,
|
|
hidden_channels: int,
|
|
|
hidden_channels_ffn: int,
|
|
hidden_channels_ffn: int,
|
|
@@ -23,11 +24,12 @@ class TextEncoder(nn.Module):
|
|
|
dropout: float,
|
|
dropout: float,
|
|
|
gin_channels=0,
|
|
gin_channels=0,
|
|
|
speaker_cond_layer=0,
|
|
speaker_cond_layer=0,
|
|
|
|
|
+ use_vae=True,
|
|
|
):
|
|
):
|
|
|
"""Text Encoder for VITS model.
|
|
"""Text Encoder for VITS model.
|
|
|
|
|
|
|
|
Args:
|
|
Args:
|
|
|
- n_vocab (int): Number of characters for the embedding layer.
|
|
|
|
|
|
|
+ in_channels (int): Number of characters for the embedding layer.
|
|
|
out_channels (int): Number of channels for the output.
|
|
out_channels (int): Number of channels for the output.
|
|
|
hidden_channels (int): Number of channels for the hidden layers.
|
|
hidden_channels (int): Number of channels for the hidden layers.
|
|
|
hidden_channels_ffn (int): Number of channels for the convolutional layers.
|
|
hidden_channels_ffn (int): Number of channels for the convolutional layers.
|
|
@@ -41,9 +43,7 @@ class TextEncoder(nn.Module):
|
|
|
self.out_channels = out_channels
|
|
self.out_channels = out_channels
|
|
|
self.hidden_channels = hidden_channels
|
|
self.hidden_channels = hidden_channels
|
|
|
|
|
|
|
|
- # self.emb = nn.Linear(n_vocab, hidden_channels)
|
|
|
|
|
- self.emb = nn.Linear(n_vocab, hidden_channels, 1)
|
|
|
|
|
- # nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5)
|
|
|
|
|
|
|
+ self.proj_in = nn.Conv1d(in_channels, hidden_channels, 1)
|
|
|
|
|
|
|
|
self.encoder = RelativePositionTransformer(
|
|
self.encoder = RelativePositionTransformer(
|
|
|
in_channels=hidden_channels,
|
|
in_channels=hidden_channels,
|
|
@@ -58,12 +58,15 @@ class TextEncoder(nn.Module):
|
|
|
gin_channels=gin_channels,
|
|
gin_channels=gin_channels,
|
|
|
speaker_cond_layer=speaker_cond_layer,
|
|
speaker_cond_layer=speaker_cond_layer,
|
|
|
)
|
|
)
|
|
|
- self.proj = nn.Linear(hidden_channels, out_channels * 2)
|
|
|
|
|
|
|
+ self.proj_out = nn.Conv1d(
|
|
|
|
|
+ hidden_channels, out_channels * 2 if use_vae else out_channels, 1
|
|
|
|
|
+ )
|
|
|
|
|
+ self.use_vae = use_vae
|
|
|
|
|
|
|
|
def forward(
|
|
def forward(
|
|
|
self,
|
|
self,
|
|
|
x: torch.Tensor,
|
|
x: torch.Tensor,
|
|
|
- x_lengths: torch.Tensor,
|
|
|
|
|
|
|
+ x_mask: torch.Tensor,
|
|
|
g: torch.Tensor = None,
|
|
g: torch.Tensor = None,
|
|
|
noise_scale: float = 1,
|
|
noise_scale: float = 1,
|
|
|
):
|
|
):
|
|
@@ -72,14 +75,14 @@ class TextEncoder(nn.Module):
|
|
|
- x: :math:`[B, T]`
|
|
- x: :math:`[B, T]`
|
|
|
- x_length: :math:`[B]`
|
|
- x_length: :math:`[B]`
|
|
|
"""
|
|
"""
|
|
|
- # x = self.emb(x).mT * math.sqrt(self.hidden_channels) # [b, h, t]
|
|
|
|
|
- x = self.emb(x).mT # * math.sqrt(self.hidden_channels) # [b, h, t]
|
|
|
|
|
- x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
|
|
|
|
|
-
|
|
|
|
|
|
|
+ x = self.proj_in(x) * x_mask
|
|
|
x = self.encoder(x, x_mask, g=g)
|
|
x = self.encoder(x, x_mask, g=g)
|
|
|
- stats = self.proj(x.mT).mT * x_mask
|
|
|
|
|
|
|
+ x = self.proj_out(x) * x_mask
|
|
|
|
|
|
|
|
- m, logs = torch.split(stats, self.out_channels, dim=1)
|
|
|
|
|
|
|
+ if self.use_vae is False:
|
|
|
|
|
+ return x
|
|
|
|
|
+
|
|
|
|
|
+ m, logs = torch.split(x, self.out_channels, dim=1)
|
|
|
z = m + torch.randn_like(m) * torch.exp(logs) * x_mask * noise_scale
|
|
z = m + torch.randn_like(m) * torch.exp(logs) * x_mask * noise_scale
|
|
|
return z, m, logs, x, x_mask
|
|
return z, m, logs, x, x_mask
|
|
|
|
|
|
|
@@ -113,7 +116,7 @@ class PosteriorEncoder(nn.Module):
|
|
|
super().__init__()
|
|
super().__init__()
|
|
|
self.out_channels = out_channels
|
|
self.out_channels = out_channels
|
|
|
|
|
|
|
|
- self.pre = nn.Linear(in_channels, hidden_channels)
|
|
|
|
|
|
|
+ self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
|
|
|
self.enc = WN(
|
|
self.enc = WN(
|
|
|
hidden_channels,
|
|
hidden_channels,
|
|
|
kernel_size,
|
|
kernel_size,
|
|
@@ -121,7 +124,7 @@ class PosteriorEncoder(nn.Module):
|
|
|
n_layers,
|
|
n_layers,
|
|
|
gin_channels=gin_channels,
|
|
gin_channels=gin_channels,
|
|
|
)
|
|
)
|
|
|
- self.proj = nn.Linear(hidden_channels, out_channels * 2)
|
|
|
|
|
|
|
+ self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
|
|
|
|
|
|
|
|
def forward(
|
|
def forward(
|
|
|
self,
|
|
self,
|
|
@@ -137,9 +140,9 @@ class PosteriorEncoder(nn.Module):
|
|
|
- g: :math:`[B, C, 1]`
|
|
- g: :math:`[B, C, 1]`
|
|
|
"""
|
|
"""
|
|
|
x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
|
|
x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
|
|
|
- x = self.pre(x.mT).mT * x_mask
|
|
|
|
|
|
|
+ x = self.pre(x) * x_mask
|
|
|
x = self.enc(x, x_mask, g=g)
|
|
x = self.enc(x, x_mask, g=g)
|
|
|
- stats = self.proj(x.mT).mT * x_mask
|
|
|
|
|
|
|
+ stats = self.proj(x) * x_mask
|
|
|
m, logs = torch.split(stats, self.out_channels, dim=1)
|
|
m, logs = torch.split(stats, self.out_channels, dim=1)
|
|
|
z = m + torch.randn_like(m) * torch.exp(logs) * x_mask * noise_scale
|
|
z = m + torch.randn_like(m) * torch.exp(logs) * x_mask * noise_scale
|
|
|
return z, m, logs, x_mask
|
|
return z, m, logs, x_mask
|
|
@@ -180,22 +183,19 @@ class SpeakerEncoder(nn.Module):
|
|
|
)
|
|
)
|
|
|
self.out_proj = nn.Linear(hidden_channels, out_channels)
|
|
self.out_proj = nn.Linear(hidden_channels, out_channels)
|
|
|
|
|
|
|
|
- def forward(self, mels, mel_lengths: torch.Tensor):
|
|
|
|
|
|
|
+ def forward(self, mels, mel_masks: torch.Tensor):
|
|
|
"""
|
|
"""
|
|
|
Shapes:
|
|
Shapes:
|
|
|
- x: :math:`[B, C, T]`
|
|
- x: :math:`[B, C, T]`
|
|
|
- x_lengths: :math:`[B, 1]`
|
|
- x_lengths: :math:`[B, 1]`
|
|
|
"""
|
|
"""
|
|
|
|
|
|
|
|
- x_mask = torch.unsqueeze(sequence_mask(mel_lengths, mels.size(2)), 1).to(
|
|
|
|
|
- mels.dtype
|
|
|
|
|
- )
|
|
|
|
|
- x = self.in_proj(mels) * x_mask
|
|
|
|
|
- x = self.encoder(x, x_mask)
|
|
|
|
|
|
|
+ x = self.in_proj(mels) * mel_masks
|
|
|
|
|
+ x = self.encoder(x, mel_masks)
|
|
|
|
|
|
|
|
# Avg Pooling
|
|
# Avg Pooling
|
|
|
- x = x * x_mask
|
|
|
|
|
- x = torch.sum(x, dim=2) / torch.sum(x_mask, dim=2)
|
|
|
|
|
|
|
+ x = x * mel_masks
|
|
|
|
|
+ x = torch.sum(x, dim=2) / torch.sum(mel_masks, dim=2)
|
|
|
x = self.out_proj(x)[..., None]
|
|
x = self.out_proj(x)[..., None]
|
|
|
|
|
|
|
|
return x
|
|
return x
|
|
@@ -219,7 +219,7 @@ class VQEncoder(nn.Module):
|
|
|
kmeans_init=False,
|
|
kmeans_init=False,
|
|
|
channel_last=False,
|
|
channel_last=False,
|
|
|
)
|
|
)
|
|
|
-
|
|
|
|
|
|
|
+ self.downsample = downsample
|
|
|
self.conv_in = nn.Conv1d(
|
|
self.conv_in = nn.Conv1d(
|
|
|
in_channels, vq_channels, kernel_size=downsample, stride=downsample
|
|
in_channels, vq_channels, kernel_size=downsample, stride=downsample
|
|
|
)
|
|
)
|
|
@@ -253,10 +253,17 @@ class VQEncoder(nn.Module):
|
|
|
|
|
|
|
|
self.vq.load_state_dict(state_dict, strict=True)
|
|
self.vq.load_state_dict(state_dict, strict=True)
|
|
|
|
|
|
|
|
- def forward(self, x):
|
|
|
|
|
- # x: [B, T, C]
|
|
|
|
|
- x = self.conv_in(x.mT)
|
|
|
|
|
|
|
+ def forward(self, x, x_mask):
|
|
|
|
|
+ # x: [B, C, T], x_mask: [B, 1, T]
|
|
|
|
|
+ x_len = x.shape[2]
|
|
|
|
|
+
|
|
|
|
|
+ if x_len % self.downsample != 0:
|
|
|
|
|
+ x = F.pad(x, (0, self.downsample - x_len % self.downsample))
|
|
|
|
|
+ x_mask = F.pad(x_mask, (0, self.downsample - x_len % self.downsample))
|
|
|
|
|
+
|
|
|
|
|
+ x = self.conv_in(x)
|
|
|
q, _, loss = self.vq(x)
|
|
q, _, loss = self.vq(x)
|
|
|
- x = self.conv_out(q).mT
|
|
|
|
|
|
|
+ x = self.conv_out(q) * x_mask
|
|
|
|
|
+ x = x[:, :, :x_len]
|
|
|
|
|
|
|
|
return x, loss
|
|
return x, loss
|