@@ -7,6 +7,7 @@ import torch.nn as nn
 from einops import rearrange
 from torch import Tensor
 from torch.nn import functional as F
+from torch.utils.checkpoint import checkpoint


 def find_multiple(n: int, k: int) -> int:
@@ -18,7 +19,8 @@ def find_multiple(n: int, k: int) -> int:
 @dataclass
 class ModelArgs:
     vocab_size: int = 32000
-    n_layer: int = 32
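+    # Dual-AR split: the slow stack runs over the token sequence, while the
+    # fast stack decodes the per-position codebooks.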
+    n_slow_layer: int = 32
+    n_fast_layer: int = 4
     n_head: int = 32
     dim: int = 4096
     intermediate_size: int = None
@@ -32,15 +34,11 @@ class ModelArgs:
     # Additional decoding heads
     codebook_size: int = 160
     num_codebooks: int = 4
-    num_in_codebooks: Optional[int] = None
     codebook_padding_idx: int = 0

     # Gradient checkpointing
     use_gradient_checkpointing: bool = True

-    # NEFT
-    neft_alpha: float = 0
-
     def __post_init__(self):
         if self.n_local_heads == -1:
             self.n_local_heads = self.n_head
@@ -48,8 +46,6 @@ class ModelArgs:
             hidden_dim = 4 * self.dim
             n_hidden = int(2 * hidden_dim / 3)
             self.intermediate_size = find_multiple(n_hidden, 256)
-        if self.num_in_codebooks is None:
-            self.num_in_codebooks = self.num_codebooks
         self.head_dim = self.dim // self.n_head


@@ -85,17 +81,34 @@ class Transformer(nn.Module):
         super().__init__()
         self.config = config

+        # Slow transformer
         self.embeddings = nn.Embedding(
-            config.vocab_size + config.codebook_size * config.num_in_codebooks,
+            config.vocab_size + config.codebook_size * config.num_codebooks,
             config.dim,
         )
-        self.layers = nn.ModuleList(
-            TransformerBlock(config) for _ in range(config.n_layer)
+        self.slow_layers = nn.ModuleList(
+            TransformerBlock(config, use_sdpa=True) for _ in range(config.n_slow_layer)
         )
-        self.norm = RMSNorm(config.dim, eps=config.norm_eps)
-        self.output = nn.Linear(
+        self.slow_norm = RMSNorm(config.dim, eps=config.norm_eps)
+        self.slow_output = nn.Linear(
             config.dim,
-            config.vocab_size + config.codebook_size * config.num_codebooks,
+            config.vocab_size,
+            bias=False,
+        )
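+        # slow_output predicts text tokens only; per-codebook logits now come
+        # from the fast head below.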
+
+        # Fast transformer
+        self.fast_embeddings = nn.Embedding(
+            config.codebook_size, config.dim, padding_idx=config.codebook_padding_idx
+        )
+
+        # The equivalent batch size (batch * seq_len) is so large that SDPA doesn't work
+        self.fast_layers = nn.ModuleList(
+            TransformerBlock(config, use_sdpa=False) for _ in range(config.n_fast_layer)
+        )
+        self.fast_norm = RMSNorm(config.dim, eps=config.norm_eps)
+        self.fast_output = nn.Linear(
+            config.dim,
+            config.codebook_size,
             bias=False,
         )

@@ -106,6 +119,7 @@ class Transformer(nn.Module):
                 config.dim // config.n_head,
                 config.rope_base,
             ),
+            persistent=False,
         )
         self.register_buffer(
             "causal_mask",
@@ -116,6 +130,7 @@ class Transformer(nn.Module):
                     dtype=torch.bool,
                 )
             ),
+            persistent=False,
         )

         # For kv cache
@@ -144,11 +159,11 @@ class Transformer(nn.Module):

     def embed(self, x: Tensor) -> Tensor:
         # Here we want to merge the embeddings of the codebooks
-        if self.config.num_in_codebooks == 0:
+        if self.config.num_codebooks == 0:
             return self.embeddings(x[:, 0])

         vocab_embeds = [self.embeddings(x[:, 0])]
-        for i in range(self.config.num_in_codebooks):
+        for i in range(self.config.num_codebooks):
             emb = self.embeddings(
                 x[:, i + 1] + i * self.config.codebook_size + self.config.vocab_size
             )
@@ -158,15 +173,6 @@ class Transformer(nn.Module):
         x = torch.stack(vocab_embeds, dim=3)
         x = x.sum(dim=3)

-        if self.config.neft_alpha > 0 and self.training:
-            # alpha / sqrt(L * D)
-            scaled_alpha = self.config.neft_alpha / math.sqrt(
-                self.config.dim * x.shape[2]
-            )
-            x += torch.rand_like(x) * scaled_alpha
-
-            print("NEFT alpha:", scaled_alpha)
-
         return x

     def compute(
@@ -176,11 +182,11 @@ class Transformer(nn.Module):
         mask: Tensor,
         input_pos: Optional[Tensor] = None,
     ) -> TransformerForwardResult:
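+        # Superseded by the slow/fast paths in forward(); kept as a guard.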
+        raise NotImplementedError
+
         for layer in self.layers:
             if self.config.use_gradient_checkpointing and self.training:
-                x = torch.utils.checkpoint.checkpoint(
-                    layer, x, freqs_cis, mask, input_pos, use_reentrant=True
-                )
+                x = checkpoint(layer, x, freqs_cis, mask, input_pos, use_reentrant=True)
             else:
                 x = layer(x, freqs_cis, mask, input_pos=input_pos)

@@ -210,6 +216,10 @@ class Transformer(nn.Module):
         # x: (batch, num_codebooks + 1, seq_len)
         seq_len = x.size(2)

+        # For the codebooks, decoding is shifted by 1,
+        # i.e. this is the labels section
+        codebooks = x[:, 1:]
+
         # Here we want to merge the embeddings of the codebooks
         x = self.embed(x)

@@ -222,9 +232,59 @@ class Transformer(nn.Module):
         if key_padding_mask is not None:
             mask = mask & key_padding_mask[:, None, None, :].logical_not()

-        return self.compute(x, freqs_cis, mask)
+        for layer in self.slow_layers:
+            if self.config.use_gradient_checkpointing and self.training:
+                x = checkpoint(layer, x, freqs_cis, mask, use_reentrant=True)
+            else:
+                x = layer(x, freqs_cis, mask)
+
+        # Slow transformer output
+        slow_out = self.slow_norm(x)
+        token_logits = self.slow_output(slow_out)
+
+        # Fast transformer
+        fast_seq_len = self.config.num_codebooks
+        fast_mask = self.causal_mask[
+            None, None, :fast_seq_len, :fast_seq_len
+        ]  # (B, N, Q, K)
+        fast_freqs_cis = self.freqs_cis[:fast_seq_len]
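+        # Each (batch, time) position becomes its own fast sequence, so the
+        # fast mask and rope span only num_codebooks steps.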
+
+        # There may be a bug here:
+        # say at t0 the given input is [/INST] for the semantic token.
+        # Then we want to predict <tok0>, <tok1>, ... (instead of <s> <s> <s>) given <feat>, <tok0>, <tok1>, ...
+        # Otherwise this becomes: decoding tokens from the same given tokens.
+        # Ignore the last token, since the input should be <feat>, <tok0>, <tok1>, ...
+        codebook_embeddings = self.fast_embeddings(codebooks[:, :-1])
+
+        x = torch.cat([x[:, None, 1:], codebook_embeddings], dim=1)  # (B, N + 1, S, D)
+        b, s = x.size(0), x.size(2)
+        x = rearrange(x, "b n s d -> (b s) n d")  # flatten the batch and seq_len
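+        # The flattened batch is b * s, which is why the fast layers disable
+        # SDPA (see __init__).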
+
+        for layer in self.fast_layers:
+            if self.config.use_gradient_checkpointing and self.training:
+                x = checkpoint(layer, x, fast_freqs_cis, fast_mask, use_reentrant=True)
+            else:
+                x = layer(x, fast_freqs_cis, fast_mask)
+
+        # unflatten the batch and num_codebooks
+        fast_out = self.fast_norm(x)
+        codebook_logits = self.fast_output(fast_out)
+        assert codebook_logits.shape[1] == self.config.num_codebooks
+        codebook_logits = rearrange(
+            codebook_logits,
+            "(b s) n d -> b s n d",
+            b=b,
+            s=s,
+            n=self.config.num_codebooks,
+        )
+
+        return TransformerForwardResult(
+            token_logits=token_logits,
+            codebook_logits=codebook_logits,
+        )

     def forward_generate(self, x: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
+        ### TODO: fix this
         # x: (batch, num_codebooks + 1, 1)

         assert (
@@ -244,9 +304,9 @@ class Transformer(nn.Module):


 class TransformerBlock(nn.Module):
-    def __init__(self, config: ModelArgs) -> None:
+    def __init__(self, config: ModelArgs, use_sdpa: bool = True) -> None:
         super().__init__()
-        self.attention = Attention(config)
+        self.attention = Attention(config, use_sdpa=use_sdpa)
         self.feed_forward = FeedForward(config)
         self.ffn_norm = RMSNorm(config.dim, config.norm_eps)
         self.attention_norm = RMSNorm(config.dim, config.norm_eps)
@@ -260,7 +320,7 @@ class TransformerBlock(nn.Module):


 class Attention(nn.Module):
-    def __init__(self, config: ModelArgs):
+    def __init__(self, config: ModelArgs, use_sdpa: bool = True):
         super().__init__()
         assert config.dim % config.n_head == 0

@@ -275,6 +335,7 @@ class Attention(nn.Module):
         self.head_dim = config.head_dim
         self.n_local_heads = config.n_local_heads
         self.dim = config.dim
+        self.use_sdpa = use_sdpa
         self._register_load_state_dict_pre_hook(self.load_hook)

     def load_hook(self, state_dict, prefix, *args):
@@ -310,18 +371,56 @@ class Attention(nn.Module):

         k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
         v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
-        y = F.scaled_dot_product_attention(
-            q,
-            k,
-            v,
-            attn_mask=mask,
-            dropout_p=self.dropout if self.training else 0.0,
-        )
+
+        if self.use_sdpa:
+            y = F.scaled_dot_product_attention(
+                q,
+                k,
+                v,
+                attn_mask=mask,
+                dropout_p=self.dropout if self.training else 0.0,
+            )
+        else:
+            y = self.eq_scaled_dot_product_attention(
+                q,
+                k,
+                v,
+                attn_mask=mask,
+                dropout_p=self.dropout if self.training else 0.0,
+            )

         y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)

         return self.wo(y)

+    def eq_scaled_dot_product_attention(
+        self,
+        query,
+        key,
+        value,
+        attn_mask=None,
+        dropout_p=0.0,
+    ) -> torch.Tensor:
+        # This is a plain (eager) scaled dot product attention.
+        # It's less efficient, but it doesn't raise a CUDA error.
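+        # (follows the reference implementation in the PyTorch
+        # scaled_dot_product_attention docs)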
+
+        L, S = query.size(-2), key.size(-2)
+        scale_factor = 1 / math.sqrt(query.size(-1))
+        attn_bias = torch.zeros(1, 1, L, S, dtype=query.dtype, device=query.device)
+
+        if attn_mask is not None:
+            if attn_mask.dtype == torch.bool:
+                attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
+            else:
+                attn_bias += attn_mask
+
+        attn_weight = query @ key.transpose(-2, -1) * scale_factor
+        attn_weight += attn_bias
+        attn_weight = torch.softmax(attn_weight, dim=-1)
+        attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
+
+        return attn_weight @ value
+

 class FeedForward(nn.Module):
     def __init__(self, config: ModelArgs) -> None:
@@ -378,13 +477,14 @@ if __name__ == "__main__":
     args = ModelArgs(
         max_seq_len=4096,
         vocab_size=32312,
-        n_layer=12,
+        n_slow_layer=12,
+        n_fast_layer=4,
         n_head=12,
         dim=768,
         rope_base=10000,
         norm_eps=1e-5,
-        codebook_size=0,
-        num_codebooks=0,
+        codebook_size=128,
+        num_codebooks=4,
     )

     model = Transformer(args)