|
@@ -1,534 +0,0 @@
|
|
|
-import math
|
|
|
|
|
-from typing import Optional
|
|
|
|
|
-
|
|
|
|
|
-import torch
|
|
|
|
|
-from einops import rearrange
|
|
|
|
|
-from torch import nn
|
|
|
|
|
-from torch.nn import functional as F
|
|
|
|
|
-from transformers.modeling_attn_mask_utils import AttentionMaskConverter
|
|
|
|
|
-
|
|
|
|
|
-
|
|
|
|
|
-def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
|
|
|
|
|
- """
|
|
|
|
|
- Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
|
|
|
|
|
-
|
|
|
|
|
- This function calculates a frequency tensor with complex exponentials using the given dimension 'dim'
|
|
|
|
|
- and the end index 'end'. The 'theta' parameter scales the frequencies.
|
|
|
|
|
- The returned tensor contains complex values in complex64 data type.
|
|
|
|
|
-
|
|
|
|
|
- Args:
|
|
|
|
|
- dim (int): Dimension of the frequency tensor.
|
|
|
|
|
- end (int): End index for precomputing frequencies.
|
|
|
|
|
- theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0.
|
|
|
|
|
-
|
|
|
|
|
- Returns:
|
|
|
|
|
- torch.Tensor: Precomputed frequency tensor with complex exponentials.
|
|
|
|
|
- """
|
|
|
|
|
- freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
|
|
|
|
|
- t = torch.arange(end, device=freqs.device) # type: ignore
|
|
|
|
|
- freqs = torch.outer(t, freqs).float() # type: ignore
|
|
|
|
|
- freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
|
|
|
|
|
- return freqs_cis
|
|
|
|
|
-
|
|
|
|
|
-
|
|
|
|
|
-def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
|
|
|
|
|
- """
|
|
|
|
|
- Reshape frequency tensor for broadcasting it with another tensor.
|
|
|
|
|
-
|
|
|
|
|
- This function reshapes the frequency tensor to have the same shape as the target tensor 'x'
|
|
|
|
|
- for the purpose of broadcasting the frequency tensor during element-wise operations.
|
|
|
|
|
-
|
|
|
|
|
- Args:
|
|
|
|
|
- freqs_cis (torch.Tensor): Frequency tensor to be reshaped.
|
|
|
|
|
- x (torch.Tensor): Target tensor for broadcasting compatibility.
|
|
|
|
|
-
|
|
|
|
|
- Returns:
|
|
|
|
|
- torch.Tensor: Reshaped frequency tensor.
|
|
|
|
|
-
|
|
|
|
|
- Raises:
|
|
|
|
|
- AssertionError: If the frequency tensor doesn't match the expected shape.
|
|
|
|
|
- AssertionError: If the target tensor 'x' doesn't have the expected number of dimensions.
|
|
|
|
|
- """
|
|
|
|
|
- ndim = x.ndim
|
|
|
|
|
- assert 0 <= 1 < ndim
|
|
|
|
|
- assert freqs_cis.shape == (x.shape[1], x.shape[-1])
|
|
|
|
|
- shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
|
|
|
|
|
- return freqs_cis.view(*shape)
|
|
|
|
|
-
|
|
|
|
|
-
|
|
|
|
|
-def apply_rotary_emb(
|
|
|
|
|
- x: torch.Tensor,
|
|
|
|
|
- freqs_cis: torch.Tensor,
|
|
|
|
|
-) -> tuple[torch.Tensor, torch.Tensor]:
|
|
|
|
|
- x_ = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
|
|
|
|
|
- freqs_cis = reshape_for_broadcast(freqs_cis, x_)
|
|
|
|
|
- return torch.view_as_real(x_ * freqs_cis).flatten(3).type_as(x)
|
|
|
|
|
-
|
|
|
|
|
-
|
|
|
|
|
-class MultiheadAttention(nn.Module):
|
|
|
|
|
- def __init__(self, d_model, nhead, dropout=0.1, is_cross_attention=False):
|
|
|
|
|
- super().__init__()
|
|
|
|
|
- assert d_model % nhead == 0
|
|
|
|
|
- self.nhead = nhead
|
|
|
|
|
- self.d_model = d_model
|
|
|
|
|
- self.head_dim = d_model // nhead
|
|
|
|
|
- self.is_cross_attention = is_cross_attention
|
|
|
|
|
-
|
|
|
|
|
- # Auto fuse linear projection
|
|
|
|
|
- if is_cross_attention:
|
|
|
|
|
- self.q_proj = nn.Linear(d_model, d_model)
|
|
|
|
|
- self.kv_proj = nn.Linear(d_model, d_model * 2)
|
|
|
|
|
- else:
|
|
|
|
|
- self.qkv_proj = nn.Linear(d_model, d_model * 3)
|
|
|
|
|
-
|
|
|
|
|
- self.o_proj = nn.Linear(d_model, d_model)
|
|
|
|
|
- self.dropout = nn.Dropout(dropout)
|
|
|
|
|
-
|
|
|
|
|
- def forward(
|
|
|
|
|
- self,
|
|
|
|
|
- q,
|
|
|
|
|
- freqs_cis_q,
|
|
|
|
|
- kv=None,
|
|
|
|
|
- freqs_cis_kv=None,
|
|
|
|
|
- attn_mask=None,
|
|
|
|
|
- input_pos=None,
|
|
|
|
|
- kv_cache=None,
|
|
|
|
|
- ):
|
|
|
|
|
- if self.is_cross_attention:
|
|
|
|
|
- q = self.q_proj(q)
|
|
|
|
|
- if kv is None:
|
|
|
|
|
- assert self.kv_cache is not None, "kv_cache should be initialized"
|
|
|
|
|
- k, v = None
|
|
|
|
|
- else:
|
|
|
|
|
- # Using kv cache
|
|
|
|
|
- kv = self.kv_proj(kv)
|
|
|
|
|
- k, v = torch.chunk(kv, 2, dim=-1)
|
|
|
|
|
- else:
|
|
|
|
|
- assert kv is None, f"kv should be None for self attention"
|
|
|
|
|
- assert (
|
|
|
|
|
- freqs_cis_kv is None
|
|
|
|
|
- ), f"freqs_cis_kv should be None for self attention"
|
|
|
|
|
- q, k, v = torch.chunk(self.qkv_proj(q), 3, dim=-1)
|
|
|
|
|
-
|
|
|
|
|
- # max_batch_size, max_seq_length, n_heads, head_dim
|
|
|
|
|
- q = rearrange(q, "b t (h d) -> b t h d", h=self.nhead, d=self.head_dim)
|
|
|
|
|
- q = apply_rotary_emb(q, freqs_cis_q)
|
|
|
|
|
-
|
|
|
|
|
- if freqs_cis_kv is None:
|
|
|
|
|
- freqs_cis_kv = freqs_cis_q
|
|
|
|
|
-
|
|
|
|
|
- # Only do when self attention or cross attention without kv cache
|
|
|
|
|
- if k is not None:
|
|
|
|
|
- assert v is not None, "v should not be None when k is not None"
|
|
|
|
|
- k = rearrange(k, "b t (h d) -> b t h d", h=self.nhead, d=self.head_dim)
|
|
|
|
|
- v = rearrange(v, "b t (h d) -> b t h d", h=self.nhead, d=self.head_dim)
|
|
|
|
|
- k = apply_rotary_emb(k, freqs_cis_kv)
|
|
|
|
|
-
|
|
|
|
|
- if kv_cache is not None:
|
|
|
|
|
- if k is None:
|
|
|
|
|
- assert v is None, "v should be None when k is None"
|
|
|
|
|
- k, v = kv_cache[0], kv_cache[1]
|
|
|
|
|
- else:
|
|
|
|
|
- k = torch.cat([kv_cache[0], k], dim=1)
|
|
|
|
|
- v = torch.cat([kv_cache[1], v], dim=1)
|
|
|
|
|
- kv_cache = (k, v)
|
|
|
|
|
-
|
|
|
|
|
- q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
|
|
|
|
|
- value = F.scaled_dot_product_attention(
|
|
|
|
|
- q,
|
|
|
|
|
- k,
|
|
|
|
|
- v,
|
|
|
|
|
- attn_mask=attn_mask,
|
|
|
|
|
- dropout_p=self.dropout.p if self.training else 0,
|
|
|
|
|
- )
|
|
|
|
|
-
|
|
|
|
|
- value = rearrange(value, "b h t d -> b t (h d)")
|
|
|
|
|
- return self.o_proj(value), kv_cache
|
|
|
|
|
-
|
|
|
|
|
-
|
|
|
|
|
-class GluMLP(nn.Module):
|
|
|
|
|
- def __init__(self, hidden_size=1024, intermediate_size=None, activation=nn.SiLU):
|
|
|
|
|
- super().__init__()
|
|
|
|
|
-
|
|
|
|
|
- if intermediate_size is None:
|
|
|
|
|
- intermediate_size = hidden_size * (11 / 3)
|
|
|
|
|
- intermediate_size = round(intermediate_size / 8) * 8
|
|
|
|
|
-
|
|
|
|
|
- self.hidden_size = hidden_size
|
|
|
|
|
- self.intermediate_size = intermediate_size
|
|
|
|
|
-
|
|
|
|
|
- self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
|
|
|
|
- self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
|
|
|
|
- self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
|
|
|
|
|
- self.act_fn = activation()
|
|
|
|
|
-
|
|
|
|
|
- def forward(self, x):
|
|
|
|
|
- return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
|
|
|
|
|
-
|
|
|
|
|
-
|
|
|
|
|
-class RMSNorm(nn.Module):
|
|
|
|
|
- def __init__(self, hidden_size, eps=1e-6):
|
|
|
|
|
- """
|
|
|
|
|
- RMSNorm is equivalent to T5LayerNorm
|
|
|
|
|
- """
|
|
|
|
|
- super().__init__()
|
|
|
|
|
-
|
|
|
|
|
- self.weight = nn.Parameter(torch.ones(hidden_size))
|
|
|
|
|
- self.variance_epsilon = eps
|
|
|
|
|
-
|
|
|
|
|
- def forward(self, hidden_states):
|
|
|
|
|
- input_dtype = hidden_states.dtype
|
|
|
|
|
- hidden_states = hidden_states.to(torch.float32)
|
|
|
|
|
- variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
|
|
|
|
- hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
|
|
|
|
-
|
|
|
|
|
- return self.weight * hidden_states.to(input_dtype)
|
|
|
|
|
-
|
|
|
|
|
-
|
|
|
|
|
-class TransformerEncoderLayer(nn.Module):
|
|
|
|
|
- def __init__(self, hidden_size=1024, intermediate_size=None, nhead=16, dropout=0.1):
|
|
|
|
|
- super().__init__()
|
|
|
|
|
-
|
|
|
|
|
- self.attention = MultiheadAttention(hidden_size, nhead, dropout=dropout)
|
|
|
|
|
- self.ffn = GluMLP(hidden_size=hidden_size, intermediate_size=intermediate_size)
|
|
|
|
|
-
|
|
|
|
|
- self.attention_norm = RMSNorm(hidden_size, eps=1e-6)
|
|
|
|
|
- self.ffn_norm = RMSNorm(hidden_size, eps=1e-6)
|
|
|
|
|
-
|
|
|
|
|
- def forward(
|
|
|
|
|
- self,
|
|
|
|
|
- x,
|
|
|
|
|
- freqs_cis,
|
|
|
|
|
- attn_mask=None,
|
|
|
|
|
- input_pos=None,
|
|
|
|
|
- ):
|
|
|
|
|
- x = (
|
|
|
|
|
- x
|
|
|
|
|
- + self.attention(
|
|
|
|
|
- q=self.attention_norm(x),
|
|
|
|
|
- freqs_cis_q=freqs_cis,
|
|
|
|
|
- attn_mask=attn_mask,
|
|
|
|
|
- input_pos=input_pos,
|
|
|
|
|
- )[0]
|
|
|
|
|
- )
|
|
|
|
|
-
|
|
|
|
|
- return x + self.ffn(self.ffn_norm(x))
|
|
|
|
|
-
|
|
|
|
|
-
|
|
|
|
|
-class TransformerDecoderLayer(nn.Module):
|
|
|
|
|
- def __init__(self, hidden_size=1024, intermediate_size=None, nhead=16, dropout=0.1):
|
|
|
|
|
- super().__init__()
|
|
|
|
|
-
|
|
|
|
|
- self.self_attention = MultiheadAttention(hidden_size, nhead, dropout=dropout)
|
|
|
|
|
- self.cross_attention = MultiheadAttention(
|
|
|
|
|
- hidden_size, nhead, dropout=dropout, is_cross_attention=True
|
|
|
|
|
- )
|
|
|
|
|
- self.ffn = GluMLP(hidden_size=hidden_size, intermediate_size=intermediate_size)
|
|
|
|
|
-
|
|
|
|
|
- self.self_attention_norm = RMSNorm(hidden_size, eps=1e-6)
|
|
|
|
|
- self.cross_attention_norm = RMSNorm(hidden_size, eps=1e-6)
|
|
|
|
|
- self.ffn_norm = RMSNorm(hidden_size, eps=1e-6)
|
|
|
|
|
-
|
|
|
|
|
- def forward(
|
|
|
|
|
- self,
|
|
|
|
|
- x,
|
|
|
|
|
- context,
|
|
|
|
|
- freqs_cis_q,
|
|
|
|
|
- freqs_cis_kv,
|
|
|
|
|
- self_attn_mask=None,
|
|
|
|
|
- cross_attn_mask=None,
|
|
|
|
|
- input_pos=None,
|
|
|
|
|
- ):
|
|
|
|
|
- x = x + self.self_attention(
|
|
|
|
|
- q=self.self_attention_norm(x),
|
|
|
|
|
- freqs_cis_q=freqs_cis_q,
|
|
|
|
|
- attn_mask=self_attn_mask,
|
|
|
|
|
- input_pos=input_pos,
|
|
|
|
|
- )
|
|
|
|
|
-
|
|
|
|
|
- x = x + self.cross_attention(
|
|
|
|
|
- q=self.cross_attention_norm(x),
|
|
|
|
|
- kv=context,
|
|
|
|
|
- freqs_cis_q=freqs_cis_q,
|
|
|
|
|
- freqs_cis_kv=freqs_cis_kv,
|
|
|
|
|
- attn_mask=cross_attn_mask,
|
|
|
|
|
- input_pos=input_pos,
|
|
|
|
|
- )
|
|
|
|
|
-
|
|
|
|
|
- return x + self.ffn(self.ffn_norm(x))
|
|
|
|
|
-
|
|
|
|
|
-
|
|
|
|
|
-class Transformer(nn.Module):
|
|
|
|
|
- def __init__(
|
|
|
|
|
- self,
|
|
|
|
|
- vocab_size,
|
|
|
|
|
- codebook_size,
|
|
|
|
|
- num_codebooks,
|
|
|
|
|
- hidden_size=1024,
|
|
|
|
|
- intermediate_size=None,
|
|
|
|
|
- nhead=16,
|
|
|
|
|
- num_encoder_layers=12,
|
|
|
|
|
- num_decoder_layers=12,
|
|
|
|
|
- dropout=0.1,
|
|
|
|
|
- max_position=4096,
|
|
|
|
|
- ):
|
|
|
|
|
- super().__init__()
|
|
|
|
|
-
|
|
|
|
|
- self.encoder_embedding = nn.Embedding(vocab_size, hidden_size)
|
|
|
|
|
- self.decoder_embeddings = nn.ModuleList(
|
|
|
|
|
- [nn.Embedding(codebook_size, hidden_size) for _ in range(num_codebooks)]
|
|
|
|
|
- )
|
|
|
|
|
- self.decoder_head = nn.Linear(hidden_size, codebook_size * num_codebooks)
|
|
|
|
|
- self.codebook_size = codebook_size
|
|
|
|
|
- self.num_codebooks = num_codebooks
|
|
|
|
|
- self.nhead = nhead
|
|
|
|
|
-
|
|
|
|
|
- self.encoder = nn.ModuleList(
|
|
|
|
|
- [
|
|
|
|
|
- TransformerEncoderLayer(
|
|
|
|
|
- hidden_size=hidden_size,
|
|
|
|
|
- intermediate_size=intermediate_size,
|
|
|
|
|
- nhead=nhead,
|
|
|
|
|
- dropout=dropout,
|
|
|
|
|
- )
|
|
|
|
|
- for _ in range(num_encoder_layers)
|
|
|
|
|
- ]
|
|
|
|
|
- )
|
|
|
|
|
-
|
|
|
|
|
- self.decoder = nn.ModuleList(
|
|
|
|
|
- [
|
|
|
|
|
- TransformerDecoderLayer(
|
|
|
|
|
- hidden_size=hidden_size,
|
|
|
|
|
- intermediate_size=intermediate_size,
|
|
|
|
|
- nhead=nhead,
|
|
|
|
|
- dropout=dropout,
|
|
|
|
|
- )
|
|
|
|
|
- for _ in range(num_decoder_layers)
|
|
|
|
|
- ]
|
|
|
|
|
- )
|
|
|
|
|
-
|
|
|
|
|
- self.register_buffer(
|
|
|
|
|
- "freqs_cis",
|
|
|
|
|
- precompute_freqs_cis(hidden_size // nhead, max_position, theta=10000.0),
|
|
|
|
|
- )
|
|
|
|
|
-
|
|
|
|
|
- causual_mask = torch.triu(
|
|
|
|
|
- torch.ones(max_position, max_position), diagonal=1
|
|
|
|
|
- ).bool()
|
|
|
|
|
- causual_mask = torch.zeros(max_position, max_position).masked_fill(
|
|
|
|
|
- causual_mask, float("-inf")
|
|
|
|
|
- )
|
|
|
|
|
-
|
|
|
|
|
- self.register_buffer("causual_mask", causual_mask)
|
|
|
|
|
-
|
|
|
|
|
- # The following are reserved for kv cache
|
|
|
|
|
- self.max_batch_size = -1
|
|
|
|
|
- self.max_seq_length = -1
|
|
|
|
|
-
|
|
|
|
|
- def setup_kv_caches(self, max_batch_size, max_seq_length):
|
|
|
|
|
- if (
|
|
|
|
|
- self.max_seq_length >= max_seq_length
|
|
|
|
|
- and self.max_batch_size >= max_batch_size
|
|
|
|
|
- ):
|
|
|
|
|
- return
|
|
|
|
|
-
|
|
|
|
|
- if max_seq_length % 8 != 0:
|
|
|
|
|
- max_seq_length = max_seq_length + (8 - max_seq_length % 8)
|
|
|
|
|
-
|
|
|
|
|
- self.max_seq_length = max_seq_length
|
|
|
|
|
- self.max_batch_size = max_batch_size
|
|
|
|
|
-
|
|
|
|
|
- for b in self.decoder:
|
|
|
|
|
- b.self_attention.kv_cache = KVCache(
|
|
|
|
|
- max_batch_size,
|
|
|
|
|
- max_seq_length,
|
|
|
|
|
- b.self_attention.nhead,
|
|
|
|
|
- b.self_attention.head_dim,
|
|
|
|
|
- ).to(b.self_attention_norm.weight.device)
|
|
|
|
|
-
|
|
|
|
|
- b.cross_attention.kv_cache = KVCache(
|
|
|
|
|
- max_batch_size,
|
|
|
|
|
- max_seq_length,
|
|
|
|
|
- b.cross_attention.nhead,
|
|
|
|
|
- b.cross_attention.head_dim,
|
|
|
|
|
- ).to(b.cross_attention_norm.weight.device)
|
|
|
|
|
-
|
|
|
|
|
- def get_key_padding_mask(self, key_padding_mask, q_size=None):
|
|
|
|
|
- # inputs: (B, T) bool ->
|
|
|
|
|
- assert key_padding_mask.dtype == torch.bool and key_padding_mask.ndim == 2
|
|
|
|
|
-
|
|
|
|
|
- key_padding_mask = (
|
|
|
|
|
- key_padding_mask.unsqueeze(1).unsqueeze(1).expand(-1, self.nhead, -1, -1)
|
|
|
|
|
- )
|
|
|
|
|
-
|
|
|
|
|
- key_padding_mask = key_padding_mask.reshape(
|
|
|
|
|
- key_padding_mask.shape[0], self.nhead, 1, key_padding_mask.shape[1]
|
|
|
|
|
- )
|
|
|
|
|
-
|
|
|
|
|
- if q_size is not None:
|
|
|
|
|
- key_padding_mask = key_padding_mask.expand(-1, -1, q_size, -1)
|
|
|
|
|
-
|
|
|
|
|
- new_mask = torch.zeros(
|
|
|
|
|
- *key_padding_mask.shape, dtype=torch.float, device=key_padding_mask.device
|
|
|
|
|
- )
|
|
|
|
|
- new_mask = new_mask.masked_fill(key_padding_mask, float("-inf"))
|
|
|
|
|
-
|
|
|
|
|
- return new_mask
|
|
|
|
|
-
|
|
|
|
|
- def forward_encoder(
|
|
|
|
|
- self, inputs, input_mask=None
|
|
|
|
|
- ) -> tuple[torch.Tensor, torch.Tensor]:
|
|
|
|
|
- # inputs: (B, T)
|
|
|
|
|
- # input_mask: (B, T), bool mask
|
|
|
|
|
- inputs = self.encoder_embedding(inputs)
|
|
|
|
|
-
|
|
|
|
|
- # Calculate mask
|
|
|
|
|
- if input_mask is None:
|
|
|
|
|
- # Assume no padding
|
|
|
|
|
- input_mask = torch.zeros(
|
|
|
|
|
- inputs.shape[0], inputs.shape[1], dtype=torch.bool, device=inputs.device
|
|
|
|
|
- )
|
|
|
|
|
-
|
|
|
|
|
- input_mask = self.get_key_padding_mask(input_mask, q_size=None).to(inputs.dtype)
|
|
|
|
|
-
|
|
|
|
|
- freqs_cis = self.freqs_cis[: inputs.shape[1]]
|
|
|
|
|
- input_mask_self = input_mask.expand(-1, -1, inputs.shape[1], -1)
|
|
|
|
|
-
|
|
|
|
|
- for layer in self.encoder:
|
|
|
|
|
- inputs = layer(inputs, freqs_cis=freqs_cis, attn_mask=input_mask_self)
|
|
|
|
|
-
|
|
|
|
|
- return inputs, input_mask
|
|
|
|
|
-
|
|
|
|
|
- def forward_decoder(
|
|
|
|
|
- self, codes, inputs, input_mask, codes_mask=None, input_pos=None
|
|
|
|
|
- ):
|
|
|
|
|
- # codes: (B, C, T)
|
|
|
|
|
- # inputs: (B, T, N)
|
|
|
|
|
-
|
|
|
|
|
- print(f"Codes: {codes.shape}, Inputs: {inputs.shape}")
|
|
|
|
|
- codes = rearrange(codes, "b c t -> c b t")
|
|
|
|
|
- codes = torch.stack(
|
|
|
|
|
- [emb(code) for emb, code in zip(self.decoder_embeddings, codes)], dim=0
|
|
|
|
|
- )
|
|
|
|
|
- codes = torch.mean(codes, dim=0) # (B, T)
|
|
|
|
|
-
|
|
|
|
|
- # If kv cache is enabled
|
|
|
|
|
- input_mask = input_mask.expand(-1, -1, codes.shape[1], -1)
|
|
|
|
|
-
|
|
|
|
|
- # Calculate mask
|
|
|
|
|
- if input_pos is not None:
|
|
|
|
|
- attn_mask = self.causual_mask[: codes.shape[1], : codes.shape[1]]
|
|
|
|
|
- else:
|
|
|
|
|
- attn_mask = None
|
|
|
|
|
-
|
|
|
|
|
- # if codes_mask is not None:
|
|
|
|
|
- # codes_mask = self.get_key_padding_mask(codes_mask)
|
|
|
|
|
- # attn_mask = attn_mask + codes_mask
|
|
|
|
|
-
|
|
|
|
|
- # For kv cache
|
|
|
|
|
- if input_pos is not None:
|
|
|
|
|
- freqs_cis_q = self.freqs_cis[input_pos]
|
|
|
|
|
- else:
|
|
|
|
|
- freqs_cis_q = self.freqs_cis[: codes.shape[1]]
|
|
|
|
|
-
|
|
|
|
|
- freqs_cis_kv = self.freqs_cis[: inputs.shape[1]]
|
|
|
|
|
-
|
|
|
|
|
- for layer in self.decoder:
|
|
|
|
|
- codes = layer(
|
|
|
|
|
- codes,
|
|
|
|
|
- inputs,
|
|
|
|
|
- freqs_cis_q=freqs_cis_q,
|
|
|
|
|
- freqs_cis_kv=freqs_cis_kv,
|
|
|
|
|
- self_attn_mask=attn_mask,
|
|
|
|
|
- cross_attn_mask=input_mask,
|
|
|
|
|
- input_pos=input_pos,
|
|
|
|
|
- )
|
|
|
|
|
-
|
|
|
|
|
- codes = self.decoder_head(codes)
|
|
|
|
|
- codes = rearrange(
|
|
|
|
|
- codes, "b t (c d) -> b c t d", c=self.num_codebooks, d=self.codebook_size
|
|
|
|
|
- )
|
|
|
|
|
-
|
|
|
|
|
- return codes
|
|
|
|
|
-
|
|
|
|
|
- def forward(
|
|
|
|
|
- self,
|
|
|
|
|
- inputs,
|
|
|
|
|
- codes,
|
|
|
|
|
- input_mask=None,
|
|
|
|
|
- codes_mask=None,
|
|
|
|
|
- input_pos=None,
|
|
|
|
|
- ):
|
|
|
|
|
- # inputs: (B, T)
|
|
|
|
|
- # codes: (B, C, T)
|
|
|
|
|
- # input_mask: (B, T), bool mask
|
|
|
|
|
- # codes_mask: (B, T), bool mask
|
|
|
|
|
- # input_pos: (B, T), int mask
|
|
|
|
|
-
|
|
|
|
|
- inputs, input_mask = self.forward_encoder(inputs, input_mask)
|
|
|
|
|
- codes = self.forward_decoder(codes, inputs, input_mask, codes_mask, input_pos)
|
|
|
|
|
-
|
|
|
|
|
- return codes
|
|
|
|
|
-
|
|
|
|
|
-
|
|
|
|
|
-if __name__ == "__main__":
|
|
|
|
|
- mha = MultiheadAttention(512, 8, dropout=0, is_cross_attention=True)
|
|
|
|
|
- mha.eval()
|
|
|
|
|
- mha.cuda()
|
|
|
|
|
-
|
|
|
|
|
- q, kv = torch.randn(2, 10, 16, 512)
|
|
|
|
|
- q, kv = q.cuda(), kv.cuda()
|
|
|
|
|
-
|
|
|
|
|
- mha.bfloat16()
|
|
|
|
|
- q, kv = q.bfloat16(), kv.bfloat16()
|
|
|
|
|
- freqs_cis = precompute_freqs_cis(512 // 8, 4096 * 2).cuda()[:16]
|
|
|
|
|
-
|
|
|
|
|
- # Causual mask
|
|
|
|
|
- attn_mask = torch.triu(torch.ones(16, 16), diagonal=1).bool().cuda()
|
|
|
|
|
- o = mha(q, freqs_cis, kv=kv, attn_mask=attn_mask)
|
|
|
|
|
-
|
|
|
|
|
- trans = (
|
|
|
|
|
- Transformer(
|
|
|
|
|
- vocab_size=30000,
|
|
|
|
|
- codebook_size=120,
|
|
|
|
|
- num_codebooks=4,
|
|
|
|
|
- hidden_size=1024,
|
|
|
|
|
- intermediate_size=None,
|
|
|
|
|
- nhead=16,
|
|
|
|
|
- num_encoder_layers=12,
|
|
|
|
|
- num_decoder_layers=12,
|
|
|
|
|
- )
|
|
|
|
|
- .bfloat16()
|
|
|
|
|
- .cuda()
|
|
|
|
|
- )
|
|
|
|
|
- trans.eval()
|
|
|
|
|
-
|
|
|
|
|
- # Print n param
|
|
|
|
|
- print("Total params:", sum(i.numel() for i in trans.parameters()) / 1024 / 1024)
|
|
|
|
|
- inputs = torch.randint(0, 1000, (2, 16)).cuda()
|
|
|
|
|
- codes = torch.randint(0, 120, (2, 4, 128)).cuda()
|
|
|
|
|
- x = trans(inputs, codes)
|
|
|
|
|
- x1 = trans(inputs, codes)
|
|
|
|
|
-
|
|
|
|
|
- assert torch.allclose(x, x1, atol=1e-4, rtol=1e-3), "Model is not deterministic"
|
|
|
|
|
- print("Model is deterministic")
|
|
|
|
|
-
|
|
|
|
|
- # Test kv cache
|
|
|
|
|
- trans.setup_kv_caches(2, 1024)
|
|
|
|
|
- inputs, inputs_mask = trans.forward_encoder(inputs)
|
|
|
|
|
-
|
|
|
|
|
- outputs = []
|
|
|
|
|
-
|
|
|
|
|
- for i in range(128):
|
|
|
|
|
- code = codes[..., i].unsqueeze(-1)
|
|
|
|
|
- code_mask = torch.tensor([[1], [1]], dtype=torch.bool, device=code.device)
|
|
|
|
|
- input_pos = torch.tensor([i], dtype=torch.long, device=code.device)
|
|
|
|
|
- outputs.append(
|
|
|
|
|
- trans.forward_decoder(
|
|
|
|
|
- code, inputs, inputs_mask, code_mask, input_pos=input_pos
|
|
|
|
|
- )
|
|
|
|
|
- )
|
|
|
|
|
-
|
|
|
|
|
- outputs = torch.cat(outputs, dim=2)
|
|
|
|
|
- print(x.shape, outputs.shape)
|
|
|
|
|
- assert torch.allclose(x, outputs, atol=1e-4, rtol=1e-3), "KV cache is not working"
|
|
|