Просмотр исходного кода

Implement encoder decoder arch

Lengyue 2 лет назад
Родитель
Сommit
05d2764fdd
1 измененных файлов с 336 добавлено и 0 удалено
  1. 336 0
      fish_speech/models/text2semantic/modules.py

+ 336 - 0
fish_speech/models/text2semantic/modules.py

@@ -0,0 +1,336 @@
+import math
+
+import torch
+from einops import rearrange
+from torch import nn
+from torch.nn import functional as F
+
+try:
+    from xformers.ops import memory_efficient_attention
+except ImportError as e:
+    memory_efficient_attention = None
+# memory_efficient_attention = None
+
+
+class AlibiPostionEmbedding:
+    def __init__(self, nheads, maxpos):
+        context_position = torch.arange(maxpos)[:, None]
+        memory_position = torch.arange(maxpos)[None, :]
+        relative_position = memory_position - context_position
+        relative_position = (
+            torch.abs(relative_position).unsqueeze(0).expand(nheads, -1, -1)
+        )
+        self.slopes = torch.Tensor(self.get_slopes(nheads)) * -1
+        self.alibi = self.slopes.unsqueeze(1).unsqueeze(1) * relative_position
+        self.alibi = self.alibi.view(nheads, maxpos, maxpos)
+
+    @staticmethod
+    def get_slopes_power_of_2(n):
+        start = 2 ** (-(2 ** -(math.log2(n) - 3)))
+        ratio = start
+        return [start * ratio**i for i in range(n)]
+
+    def get_slopes(self, n):
+        if math.log2(n).is_integer():
+            return self.get_slopes_power_of_2(n)
+
+        closest_power_of_2 = 2 ** math.floor(math.log2(n))
+        return (
+            self.get_slopes_power_of_2(closest_power_of_2)
+            + self.get_slopes(2 * closest_power_of_2)[0::2][: n - closest_power_of_2]
+        )
+
+    def __call__(self, x):
+        # N, T, C
+        return self.alibi[:, : x.size(1), : x.size(1)].to(x.device)
+
+
+class MultiheadAttention(nn.Module):
+    def __init__(self, d_model, nhead, dropout=0.1):
+        super().__init__()
+        assert d_model % nhead == 0
+        self.nhead = nhead
+        self.d_model = d_model
+        self.head_dim = d_model // nhead
+
+        self.q_proj = nn.Linear(d_model, d_model)
+        self.k_proj = nn.Linear(d_model, d_model)
+        self.v_proj = nn.Linear(d_model, d_model)
+        self.out_proj = nn.Linear(d_model, d_model)
+        self.dropout = nn.Dropout(dropout)
+
+    def forward(
+        self,
+        q,
+        k,
+        v,
+        attn_mask=None,
+        key_padding_mask=None,
+        attn_bias=None,
+        past_kv=None,
+        return_weights=False,
+    ):
+        # (B, T, C)
+        batch_size = q.size(0)
+        q_length = q.size(1)
+        k_length = k.size(1)
+
+        if past_kv is not None:
+            k, v = torch.cat([past_kv, k], 1), torch.cat([past_kv, v], 1)
+
+        if attn_bias is not None:
+            assert attn_bias.size() == (
+                self.nhead,
+                q_length,
+                k_length,
+            ), f"Should be {(self.nhead, q_length, k_length)}. Got {attn_bias.size()}"
+
+            attn_bias = attn_bias.unsqueeze(0).expand(batch_size, -1, -1, -1)
+
+        if attn_mask is not None:
+            assert attn_mask.size() == (
+                q_length,
+                k_length,
+            ), f"Should be {(q_length, k_length)}. Got {attn_mask.size()}"
+            assert attn_mask.dtype == torch.bool
+            attn_mask = attn_mask.unsqueeze(0).expand(batch_size * self.nhead, -1, -1)
+
+        if key_padding_mask is not None:
+            assert key_padding_mask.size() == (
+                batch_size,
+                k_length,
+            ), f"Should be {(batch_size, k_length)}. Got {key_padding_mask.size()}"
+            assert key_padding_mask.dtype == torch.bool
+            key_padding_mask = (
+                key_padding_mask.unsqueeze(1)
+                .unsqueeze(1)
+                .expand(-1, self.nhead, -1, -1)
+            )
+            key_padding_mask = key_padding_mask.reshape(
+                batch_size * self.nhead, 1, k_length
+            )
+            if attn_mask is None:
+                attn_mask = key_padding_mask.expand(-1, q.size(1), -1)
+            else:
+                attn_mask = attn_mask.logical_or(key_padding_mask)
+
+        q, k, v = self.q_proj(q), self.k_proj(k), self.v_proj(v)
+
+        if (
+            return_weights is False
+            and memory_efficient_attention is not None
+            and q.device.type == "cuda"
+        ):
+            # (-> b, t,. n, d)
+            q = rearrange(q, "b t (n d) -> b t n d", n=self.nhead)
+            k = rearrange(k, "b t (n d) -> b t n d", n=self.nhead)
+            v = rearrange(v, "b t (n d) -> b t n d", n=self.nhead)
+
+            if attn_mask is not None:
+                attn_mask = rearrange(attn_mask, "(b n) q k -> b n q k", n=self.nhead)
+                attn_bias = attn_bias.masked_fill(attn_mask, float("-inf"))
+
+            attn_output = memory_efficient_attention(
+                q,
+                k,
+                v,
+                attn_bias=attn_bias,
+                scale=self.head_dim**-0.5,
+                p=self.dropout.p,
+            )
+            attn_output = rearrange(attn_output, "b t n d -> b t (n d)", n=self.nhead)
+
+            returned_weights = None
+        else:
+            q = rearrange(q, "b t (n d) -> (b n) t d", n=self.nhead)
+            k = rearrange(k, "b t (n d) -> (b n) t d", n=self.nhead)
+            v = rearrange(v, "b t (n d) -> (b n) t d", n=self.nhead)
+
+            attn_weights = torch.bmm(q, k.mT) * (self.head_dim**-0.5)
+            assert attn_weights.size() == (
+                batch_size * self.nhead,
+                q.size(1),
+                k.size(1),
+            )
+
+            if attn_bias is not None:
+                attn_bias = rearrange(attn_bias, "b n q k -> (b n) q k")
+                attn_weights = attn_weights + attn_bias
+
+            if attn_mask is not None:
+                attn_weights = attn_weights.masked_fill(attn_mask, float("-inf"))
+
+            attn_weights = F.softmax(attn_weights, dim=-1, dtype=attn_weights.dtype)
+            returned_weights = attn_weights.view(
+                batch_size, self.nhead, q.size(1), k.size(1)
+            )
+
+            attn_probs = self.dropout(attn_weights)
+            attn_output = torch.bmm(attn_probs, v)
+            attn_output = rearrange(attn_output, "(b n) t d -> b t (n d)", n=self.nhead)
+
+        attn_output = self.out_proj(attn_output)
+        return attn_output, returned_weights
+
+
+class GluMLP(nn.Module):
+    def __init__(self, hidden_size=1024, intermediate_size=None, activation=nn.SiLU):
+        super().__init__()
+
+        if intermediate_size is None:
+            intermediate_size = hidden_size * (11 / 3)
+            intermediate_size = round(intermediate_size / 8) * 8
+
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+
+        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+        self.act_fn = activation()
+
+    def forward(self, x):
+        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+
+
+class RMSNorm(nn.Module):
+    def __init__(self, hidden_size, eps=1e-6):
+        """
+        RMSNorm is equivalent to T5LayerNorm
+        """
+        super().__init__()
+
+        self.weight = nn.Parameter(torch.ones(hidden_size))
+        self.variance_epsilon = eps
+
+    def forward(self, hidden_states):
+        input_dtype = hidden_states.dtype
+        hidden_states = hidden_states.to(torch.float32)
+        variance = hidden_states.pow(2).mean(-1, keepdim=True)
+        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+
+        return self.weight * hidden_states.to(input_dtype)
+
+
+class CrossAttentionLayer(nn.Module):
+    def __init__(self, hidden_size=1024, intermediate_size=None, dropout=0.1):
+        super().__init__()
+
+        self.attn = MultiheadAttention(hidden_size, 1, dropout=dropout)
+        self.mlp = GluMLP(hidden_size=hidden_size, intermediate_size=intermediate_size)
+        self.input_layernorm_q = RMSNorm(hidden_size, eps=1e-6)
+        self.input_layernorm_kv = RMSNorm(hidden_size, eps=1e-6)
+        self.post_attention_layernorm = RMSNorm(hidden_size, eps=1e-6)
+
+    def forward(self, tgt, memory, memory_key_padding_mask=None):
+        residual = tgt
+        tgt, memory = self.input_layernorm_q(tgt), self.input_layernorm_kv(memory)
+        x, attn_weights = self.attn(
+            tgt,
+            memory,
+            memory,
+            key_padding_mask=memory_key_padding_mask,
+            return_weights=True,
+        )
+        residual = x + residual
+
+        x = self.post_attention_layernorm(residual)
+        x = self.mlp(x)
+        x = x + residual
+
+        return x, attn_weights
+
+
+class TransformerEncoderLayer(nn.Module):
+    def __init__(self, hidden_size=1024, intermediate_size=None, nhead=16, dropout=0.1):
+        super().__init__()
+
+        self.attn = MultiheadAttention(hidden_size, nhead, dropout=dropout)
+        self.mlp = GluMLP(hidden_size=hidden_size, intermediate_size=intermediate_size)
+        self.input_layernorm = RMSNorm(hidden_size, eps=1e-6)
+        self.post_attention_layernorm = RMSNorm(hidden_size, eps=1e-6)
+
+    def forward(self, x, attn_bias=None, key_padding_mask=None, tgt_mask=None):
+        residual = x
+        x = self.input_layernorm(x)
+        x, _ = self.attn(
+            x,
+            x,
+            x,
+            attn_bias=attn_bias,
+            key_padding_mask=key_padding_mask,
+            attn_mask=tgt_mask,
+            return_weights=False,
+        )
+        residual = x + residual
+
+        x = self.post_attention_layernorm(residual)
+        x = self.mlp(x)
+        x = x + residual
+
+        return x
+
+
+class FishSpeechTransformer(nn.Module):
+    def __init__(
+        self,
+        vocab_size,
+        codebook_size,
+        num_codebooks,
+        hidden_size=1024,
+        intermediate_size=None,
+        nhead=16,
+        num_encoder_layers=12,
+        num_decoder_layers=12,
+        dropout=0.1,
+    ):
+        self.embedding = nn.Embedding(vocab_size, hidden_size)
+        self.lm_head = nn.Linear(hidden_size, vocab_size * num_codebooks)
+
+
+if __name__ == "__main__":
+    mha = MultiheadAttention(512, 8, dropout=0)
+    mha.eval()
+    mha.cuda()
+
+    q, k, v = torch.randn(3, 10, 16, 512)
+    q, k, v = q.cuda(), k.cuda(), v.cuda()
+    alibi = AlibiPostionEmbedding(8, 1024)
+
+    mha.bfloat16()
+    q, k, v = q.bfloat16(), k.bfloat16(), v.bfloat16()
+    bias = alibi(q).bfloat16()
+
+    # Causual mask
+    attn_mask = torch.triu(torch.ones(16, 16), diagonal=1).bool().cuda()
+    o, w = mha(q, k, v, return_weights=True, attn_bias=bias, attn_mask=attn_mask)
+
+    print(o.size())
+    print(w.size())
+
+    o1, w = mha(q, k, v, return_weights=False, attn_bias=bias, attn_mask=attn_mask)
+    print(o1.size())
+
+    print(o[0], o1.float()[0])
+
+    assert torch.allclose(o.float(), o1.float(), atol=1e-2, rtol=1e-2)
+    print("ok")
+
+    cross = CrossAttentionLayer(512, 1024, dropout=0)
+    cross.eval()
+    cross.cuda()
+
+    tgt = torch.randn(3, 10, 512).cuda()
+    memory = torch.randn(3, 20, 512).cuda()
+    o, w = cross(tgt, memory)
+
+    print(o.size())
+    print(w.size())
+
+    ten = TransformerEncoderLayer(512, 1024, 8, dropout=0)
+    ten.eval()
+    ten.cuda()
+
+    tgt = torch.randn(3, 10, 512).cuda()
+    o = ten(tgt)
+    print(o.size())