from typing import Optional

import torch
import torch.nn.functional as F
from torch import nn

from fish_speech.utils import autocast_exclude_mps

from .wavenet import WaveNet


class ReferenceEncoder(WaveNet):
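    """Pool variable-length WaveNet features into a fixed-size embedding.

    A small set of learned latent queries cross-attends over the WaveNet
    output sequence, so the result has shape (batch, output_channels)
    regardless of input length.
    """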

    def __init__(
        self,
        input_channels: Optional[int] = None,
        output_channels: Optional[int] = None,
        residual_channels: int = 512,
        residual_layers: int = 20,
        dilation_cycle: Optional[int] = 4,
        num_heads: int = 8,
        latent_len: int = 4,
    ):
        super().__init__(
            input_channels=input_channels,
            residual_channels=residual_channels,
            residual_layers=residual_layers,
            dilation_cycle=dilation_cycle,
        )

        self.head_dim = residual_channels // num_heads
        self.num_heads = num_heads

        # Learned latent queries that pool the sequence into `latent_len` tokens.
        self.latent_len = latent_len
        self.latent = nn.Parameter(torch.zeros(1, self.latent_len, residual_channels))

        # Cross-attention projections: queries from the latents, keys/values
        # from the WaveNet features, with per-head LayerNorm on q and k.
        self.q = nn.Linear(residual_channels, residual_channels, bias=True)
        self.kv = nn.Linear(residual_channels, residual_channels * 2, bias=True)
        self.q_norm = nn.LayerNorm(self.head_dim)
        self.k_norm = nn.LayerNorm(self.head_dim)
        self.proj = nn.Linear(residual_channels, residual_channels)
        self.proj_drop = nn.Dropout(0.1)

        # Post-attention feed-forward block, applied with a residual connection.
        self.norm = nn.LayerNorm(residual_channels)
        self.mlp = nn.Sequential(
            nn.Linear(residual_channels, residual_channels * 4),
            nn.SiLU(),
            nn.Linear(residual_channels * 4, residual_channels),
        )
        self.output_projection_attn = nn.Linear(residual_channels, output_channels)

        torch.nn.init.trunc_normal_(self.latent, std=0.02)
        self.apply(self.init_weights)
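
    # Truncated-normal weights (std=0.02, matching the latent init) with zero biases.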
    def init_weights(self, m):
        if isinstance(m, nn.Linear):
            torch.nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                torch.nn.init.constant_(m.bias, 0)
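
    # Forward: WaveNet features -> latent cross-attention -> residual MLP ->
    # mean over the latent tokens, one embedding per batch item.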
    def forward(self, x, attn_mask=None):
        # WaveNet outputs (B, C, N); .mT transposes to (B, N, C) for attention.
        x = super().forward(x).mT
        B, N, C = x.shape

        # Broadcast the (B, N) padding mask (True = valid frame) over heads
        # and latent queries.
        if attn_mask is not None:
            assert attn_mask.shape == (B, N) and attn_mask.dtype == torch.bool

            attn_mask = attn_mask[:, None, None, :].expand(
                B, self.num_heads, self.latent_len, N
            )

        # Queries: the learned latents, expanded across the batch.
        q_latent = self.latent.expand(B, -1, -1)
        q = (
            self.q(q_latent)
            .reshape(B, self.latent_len, self.num_heads, self.head_dim)
            .transpose(1, 2)
        )

        # Keys/values: (2, B, num_heads, N, head_dim) after the permute.
        kv = (
            self.kv(x)
            .reshape(B, N, 2, self.num_heads, self.head_dim)
            .permute(2, 0, 3, 1, 4)
        )
        k, v = kv.unbind(0)

        # QK-norm keeps the attention logits well-scaled before SDPA.
        q, k = self.q_norm(q), self.k_norm(k)
        x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)

        x = x.transpose(1, 2).reshape(B, self.latent_len, C)
        x = self.proj(x)
        x = self.proj_drop(x)

        x = x + self.mlp(self.norm(x))
        x = self.output_projection_attn(x)

        # Average the latent tokens into a single embedding per batch item.
        x = x.mean(1)

        return x


if __name__ == "__main__":
    # Smoke test: one forward and backward pass under bf16 autocast.
    with autocast_exclude_mps(device_type="cpu", dtype=torch.bfloat16):
        model = ReferenceEncoder(
            input_channels=128,
            output_channels=64,
            residual_channels=384,
            residual_layers=20,
            dilation_cycle=4,
            num_heads=8,
        )
        x = torch.randn(4, 128, 64)
        mask = torch.ones(4, 64, dtype=torch.bool)
        y = model(x, mask)
        print(y.shape)  # torch.Size([4, 64])
        loss = F.mse_loss(y, torch.randn(4, 64))
        loss.backward()