# reference.py — ReferenceEncoder: WaveNet feature extractor with latent-query attention pooling.
from typing import Optional

import torch
import torch.nn.functional as F
from torch import nn

from fish_speech.utils import autocast_exclude_mps

from .wavenet import WaveNet
  7. class ReferenceEncoder(WaveNet):
  8. def __init__(
  9. self,
  10. input_channels: Optional[int] = None,
  11. output_channels: Optional[int] = None,
  12. residual_channels: int = 512,
  13. residual_layers: int = 20,
  14. dilation_cycle: Optional[int] = 4,
  15. num_heads: int = 8,
  16. latent_len: int = 4,
  17. ):
  18. super().__init__(
  19. input_channels=input_channels,
  20. residual_channels=residual_channels,
  21. residual_layers=residual_layers,
  22. dilation_cycle=dilation_cycle,
  23. )
  24. self.head_dim = residual_channels // num_heads
  25. self.num_heads = num_heads
  26. self.latent_len = latent_len
  27. self.latent = nn.Parameter(torch.zeros(1, self.latent_len, residual_channels))
  28. self.q = nn.Linear(residual_channels, residual_channels, bias=True)
  29. self.kv = nn.Linear(residual_channels, residual_channels * 2, bias=True)
  30. self.q_norm = nn.LayerNorm(self.head_dim)
  31. self.k_norm = nn.LayerNorm(self.head_dim)
  32. self.proj = nn.Linear(residual_channels, residual_channels)
  33. self.proj_drop = nn.Dropout(0.1)
  34. self.norm = nn.LayerNorm(residual_channels)
  35. self.mlp = nn.Sequential(
  36. nn.Linear(residual_channels, residual_channels * 4),
  37. nn.SiLU(),
  38. nn.Linear(residual_channels * 4, residual_channels),
  39. )
  40. self.output_projection_attn = nn.Linear(residual_channels, output_channels)
  41. torch.nn.init.trunc_normal_(self.latent, std=0.02)
  42. self.apply(self.init_weights)
  43. def init_weights(self, m):
  44. if isinstance(m, nn.Linear):
  45. torch.nn.init.trunc_normal_(m.weight, std=0.02)
  46. if m.bias is not None:
  47. torch.nn.init.constant_(m.bias, 0)
  48. def forward(self, x, attn_mask=None):
  49. x = super().forward(x).mT
  50. B, N, C = x.shape
  51. # Calculate mask
  52. if attn_mask is not None:
  53. assert attn_mask.shape == (B, N) and attn_mask.dtype == torch.bool
  54. attn_mask = attn_mask[:, None, None, :].expand(
  55. B, self.num_heads, self.latent_len, N
  56. )
  57. q_latent = self.latent.expand(B, -1, -1)
  58. q = (
  59. self.q(q_latent)
  60. .reshape(B, self.latent_len, self.num_heads, self.head_dim)
  61. .transpose(1, 2)
  62. )
  63. kv = (
  64. self.kv(x)
  65. .reshape(B, N, 2, self.num_heads, self.head_dim)
  66. .permute(2, 0, 3, 1, 4)
  67. )
  68. k, v = kv.unbind(0)
  69. q, k = self.q_norm(q), self.k_norm(k)
  70. x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
  71. x = x.transpose(1, 2).reshape(B, self.latent_len, C)
  72. x = self.proj(x)
  73. x = self.proj_drop(x)
  74. x = x + self.mlp(self.norm(x))
  75. x = self.output_projection_attn(x)
  76. x = x.mean(1)
  77. return x
  78. if __name__ == "__main__":
  79. with autocast_exclude_mps(device_type="cpu", dtype=torch.bfloat16):
  80. model = ReferenceEncoder(
  81. input_channels=128,
  82. output_channels=64,
  83. residual_channels=384,
  84. residual_layers=20,
  85. dilation_cycle=4,
  86. num_heads=8,
  87. )
  88. x = torch.randn(4, 128, 64)
  89. mask = torch.ones(4, 64, dtype=torch.bool)
  90. y = model(x, mask)
  91. print(y.shape)
  92. loss = F.mse_loss(y, torch.randn(4, 64))
  93. loss.backward()