@@ -0,0 +1,360 @@
+import math
+
+import torch
+from einops import rearrange
+from torch import nn
+from torch.nn import functional as F
+
+try:
+    from xformers.ops import memory_efficient_attention
+except ImportError:
+    memory_efficient_attention = None
+# memory_efficient_attention = None  # uncomment to force the PyTorch fallback
+
+
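+# ALiBi (Attention with Linear Biases; Press et al.) replaces positional
+# embeddings with a head-specific linear penalty on attention scores. The
+# symmetric |i - j| distance below is a bidirectional (encoder-style) variant.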
+class AlibiPositionEmbedding:
+    def __init__(self, nheads, maxpos):
+        context_position = torch.arange(maxpos)[:, None]
+        memory_position = torch.arange(maxpos)[None, :]
+        relative_position = memory_position - context_position
+        relative_position = (
+            torch.abs(relative_position).unsqueeze(0).expand(nheads, -1, -1)
+        )
+        self.slopes = torch.Tensor(self.get_slopes(nheads)) * -1
+        self.alibi = self.slopes.unsqueeze(1).unsqueeze(1) * relative_position
+        self.alibi = self.alibi.view(nheads, maxpos, maxpos)
+
+    @staticmethod
+    def get_slopes_power_of_2(n):
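+        # Slopes form a geometric sequence starting at 2^(-8/n);
+        # e.g. for n=8 heads: 1/2, 1/4, ..., 1/256.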
+        start = 2 ** (-(2 ** -(math.log2(n) - 3)))
+        ratio = start
+        return [start * ratio**i for i in range(n)]
+
+    def get_slopes(self, n):
+        if math.log2(n).is_integer():
+            return self.get_slopes_power_of_2(n)
+
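+        # For non-power-of-2 head counts, take the slopes of the closest lower
+        # power of 2 and fill the remainder with every other slope from the
+        # next power of 2, as in the ALiBi reference implementation.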
+        closest_power_of_2 = 2 ** math.floor(math.log2(n))
+        return (
+            self.get_slopes_power_of_2(closest_power_of_2)
+            + self.get_slopes(2 * closest_power_of_2)[0::2][: n - closest_power_of_2]
+        )
+
+    def __call__(self, x):
+        # x: (B, T, C) -> bias of shape (nheads, T, T) on x's device
+        return self.alibi[:, : x.size(1), : x.size(1)].to(x.device)
+
+
+class MultiheadAttention(nn.Module):
+    def __init__(self, d_model, nhead, dropout=0.1):
+        super().__init__()
+        assert d_model % nhead == 0
+        self.nhead = nhead
+        self.d_model = d_model
+        self.head_dim = d_model // nhead
+
+        self.q_proj = nn.Linear(d_model, d_model)
+        self.k_proj = nn.Linear(d_model, d_model)
+        self.v_proj = nn.Linear(d_model, d_model)
+        self.out_proj = nn.Linear(d_model, d_model)
+        self.dropout = nn.Dropout(dropout)
+
+    def forward(
+        self,
+        q,
+        k,
+        v,
+        attn_mask=None,
+        key_padding_mask=None,
+        attn_bias=None,
+        past_kv=None,
+        return_weights=False,
+    ):
+        # (B, T, C)
+        batch_size = q.size(0)
+        q_length = q.size(1)
+        k_length = k.size(1)
+
+        if past_kv is not None:
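+            # past_kv caches the *input* hidden states (pre-projection), so it
+            # is prepended to both the key and value inputs.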
+            k, v = torch.cat([past_kv, k], 1), torch.cat([past_kv, v], 1)
+            k_length = k.size(1)
+
+        if attn_bias is not None:
+            assert attn_bias.size() == (
+                self.nhead,
+                q_length,
+                k_length,
+            ), f"Should be {(self.nhead, q_length, k_length)}. Got {attn_bias.size()}"
+
+            attn_bias = attn_bias.unsqueeze(0).expand(batch_size, -1, -1, -1)
+
+        if attn_mask is not None:
+            assert attn_mask.size() == (
+                q_length,
+                k_length,
+            ), f"Should be {(q_length, k_length)}. Got {attn_mask.size()}"
+            assert attn_mask.dtype == torch.bool
+            attn_mask = attn_mask.unsqueeze(0).expand(batch_size * self.nhead, -1, -1)
+
+        if key_padding_mask is not None:
+            assert key_padding_mask.size() == (
+                batch_size,
+                k_length,
+            ), f"Should be {(batch_size, k_length)}. Got {key_padding_mask.size()}"
+            assert key_padding_mask.dtype == torch.bool
+            key_padding_mask = (
+                key_padding_mask.unsqueeze(1)
+                .unsqueeze(1)
+                .expand(-1, self.nhead, -1, -1)
+            )
+            key_padding_mask = key_padding_mask.reshape(
+                batch_size * self.nhead, 1, k_length
+            )
+            if attn_mask is None:
+                attn_mask = key_padding_mask.expand(-1, q.size(1), -1)
+            else:
+                attn_mask = attn_mask.logical_or(key_padding_mask)
+
+        q, k, v = self.q_proj(q), self.k_proj(k), self.v_proj(v)
+
+        if (
+            return_weights is False
+            and memory_efficient_attention is not None
+            and q.device.type == "cuda"
+        ):
+            # -> (b, t, n, d); xformers expects heads in their own dimension
+            q = rearrange(q, "b t (n d) -> b t n d", n=self.nhead)
+            k = rearrange(k, "b t (n d) -> b t n d", n=self.nhead)
+            v = rearrange(v, "b t (n d) -> b t n d", n=self.nhead)
+
+            if attn_mask is not None:
+                attn_mask = rearrange(attn_mask, "(b n) q k -> b n q k", n=self.nhead)
+                if attn_bias is None:
+                    attn_bias = torch.zeros_like(attn_mask, dtype=q.dtype)
+                attn_bias = attn_bias.masked_fill(attn_mask, float("-inf"))
+
+            attn_output = memory_efficient_attention(
+                q,
+                k,
+                v,
+                attn_bias=attn_bias,
+                scale=self.head_dim**-0.5,
+                p=self.dropout.p if self.training else 0.0,
+            )
+            attn_output = rearrange(attn_output, "b t n d -> b t (n d)", n=self.nhead)
+
+            returned_weights = None
+        else:
+            q = rearrange(q, "b t (n d) -> (b n) t d", n=self.nhead)
+            k = rearrange(k, "b t (n d) -> (b n) t d", n=self.nhead)
+            v = rearrange(v, "b t (n d) -> (b n) t d", n=self.nhead)
+
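+            # Fallback: explicit scaled dot-product attention, softmax(QK^T/sqrt(d))V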
+            attn_weights = torch.bmm(q, k.mT) * (self.head_dim**-0.5)
+            assert attn_weights.size() == (
+                batch_size * self.nhead,
+                q.size(1),
+                k.size(1),
+            )
+
+            if attn_bias is not None:
+                attn_bias = rearrange(attn_bias, "b n q k -> (b n) q k")
+                attn_weights = attn_weights + attn_bias
+
+            if attn_mask is not None:
+                attn_weights = attn_weights.masked_fill(attn_mask, float("-inf"))
+
+            attn_weights = F.softmax(attn_weights, dim=-1)
+            returned_weights = attn_weights.view(
+                batch_size, self.nhead, q.size(1), k.size(1)
+            )
+
+            attn_probs = self.dropout(attn_weights)
+            attn_output = torch.bmm(attn_probs, v)
+            attn_output = rearrange(attn_output, "(b n) t d -> b t (n d)", n=self.nhead)
+
+        attn_output = self.out_proj(attn_output)
+        return attn_output, returned_weights
+
+
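+# Gated-linear-unit MLP: down(act(gate(x)) * up(x)). With the default SiLU
+# activation this is the SwiGLU feed-forward block used in LLaMA-style models.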
+class GluMLP(nn.Module):
+    def __init__(self, hidden_size=1024, intermediate_size=None, activation=nn.SiLU):
+        super().__init__()
+
+        if intermediate_size is None:
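+            # Default to ~11/3 x hidden size, rounded to a multiple of 8.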
+            intermediate_size = hidden_size * (11 / 3)
+            intermediate_size = round(intermediate_size / 8) * 8
+
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+
+        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+        self.act_fn = activation()
+
+    def forward(self, x):
+        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+
+
+class RMSNorm(nn.Module):
+    def __init__(self, hidden_size, eps=1e-6):
+        """
+        RMSNorm is equivalent to T5LayerNorm
+        """
+        super().__init__()
+
+        self.weight = nn.Parameter(torch.ones(hidden_size))
+        self.variance_epsilon = eps
+
+    def forward(self, hidden_states):
+        input_dtype = hidden_states.dtype
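+        # Normalize in float32 for numerical stability, then cast back.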
+        hidden_states = hidden_states.to(torch.float32)
+        variance = hidden_states.pow(2).mean(-1, keepdim=True)
+        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+
+        return self.weight * hidden_states.to(input_dtype)
+
+
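+# Pre-norm cross-attention block (RMSNorm -> single-head attention -> RMSNorm
+# -> GluMLP) with residual connections around both sublayers.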
+class CrossAttentionLayer(nn.Module):
+    def __init__(self, hidden_size=1024, intermediate_size=None, dropout=0.1):
+        super().__init__()
+
+        self.attn = MultiheadAttention(hidden_size, 1, dropout=dropout)
+        self.mlp = GluMLP(hidden_size=hidden_size, intermediate_size=intermediate_size)
+        self.input_layernorm_q = RMSNorm(hidden_size, eps=1e-6)
+        self.input_layernorm_kv = RMSNorm(hidden_size, eps=1e-6)
+        self.post_attention_layernorm = RMSNorm(hidden_size, eps=1e-6)
+
+    def forward(self, tgt, memory, memory_key_padding_mask=None):
+        residual = tgt
+        tgt, memory = self.input_layernorm_q(tgt), self.input_layernorm_kv(memory)
+        x, attn_weights = self.attn(
+            tgt,
+            memory,
+            memory,
+            key_padding_mask=memory_key_padding_mask,
+            return_weights=True,
+        )
+        residual = x + residual
+
+        x = self.post_attention_layernorm(residual)
+        x = self.mlp(x)
+        x = x + residual
+
+        return x, attn_weights
+
+
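+# Pre-norm self-attention encoder block; attn_bias can carry the ALiBi bias.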
+class TransformerEncoderLayer(nn.Module):
+    def __init__(self, hidden_size=1024, intermediate_size=None, nhead=16, dropout=0.1):
+        super().__init__()
+
+        self.attn = MultiheadAttention(hidden_size, nhead, dropout=dropout)
+        self.mlp = GluMLP(hidden_size=hidden_size, intermediate_size=intermediate_size)
+        self.input_layernorm = RMSNorm(hidden_size, eps=1e-6)
+        self.post_attention_layernorm = RMSNorm(hidden_size, eps=1e-6)
+
+    def forward(self, x, attn_bias=None, key_padding_mask=None, tgt_mask=None):
+        residual = x
+        x = self.input_layernorm(x)
+        x, _ = self.attn(
+            x,
+            x,
+            x,
+            attn_bias=attn_bias,
+            key_padding_mask=key_padding_mask,
+            attn_mask=tgt_mask,
+            return_weights=False,
+        )
+        residual = x + residual
+
+        x = self.post_attention_layernorm(residual)
+        x = self.mlp(x)
+        x = x + residual
+
+        return x
+
+
+class FishSpeechTransformer(nn.Module):
+    def __init__(
+        self,
+        vocab_size,
+        codebook_size,
+        num_codebooks,
+        hidden_size=1024,
+        intermediate_size=None,
+        nhead=16,
+        num_encoder_layers=12,
+        num_decoder_layers=12,
+        dropout=0.1,
+    ):
+        super().__init__()
+
+        self.embedding = nn.Embedding(vocab_size, hidden_size)
+        self.lm_head = nn.Linear(hidden_size, vocab_size * num_codebooks)
+
+
+if __name__ == "__main__":
+    mha = MultiheadAttention(512, 8, dropout=0)
+    mha.eval()
+    mha.cuda()
+
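+    # Sanity check: the xformers path and the PyTorch fallback should agree.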
+    q, k, v = torch.randn(3, 10, 16, 512)
+    q, k, v = q.cuda(), k.cuda(), v.cuda()
+    alibi = AlibiPositionEmbedding(8, 1024)
+
+    mha.bfloat16()
+    q, k, v = q.bfloat16(), k.bfloat16(), v.bfloat16()
+    bias = alibi(q).bfloat16()
+
+    # Causal mask
+    attn_mask = torch.triu(torch.ones(16, 16), diagonal=1).bool().cuda()
+    o, w = mha(q, k, v, return_weights=True, attn_bias=bias, attn_mask=attn_mask)
+
+    print(o.size())
+    print(w.size())
+
+    o1, _ = mha(q, k, v, return_weights=False, attn_bias=bias, attn_mask=attn_mask)
+    print(o1.size())
+
+    print(o[0], o1.float()[0])
+
+    assert torch.allclose(o.float(), o1.float(), atol=1e-2, rtol=1e-2)
+    print("ok")
+
+    cross = CrossAttentionLayer(512, 1024, dropout=0)
+    cross.eval()
+    cross.cuda()
+
+    tgt = torch.randn(3, 10, 512).cuda()
+    memory = torch.randn(3, 20, 512).cuda()
+    o, w = cross(tgt, memory)
+
+    print(o.size())
+    print(w.size())
+
+    ten = TransformerEncoderLayer(512, 1024, 8, dropout=0)
+    ten.eval()
+    ten.cuda()
+
+    tgt = torch.randn(3, 10, 512).cuda()
+    o = ten(tgt)
+    print(o.size())