@@ -17,10 +17,9 @@ def find_multiple(n: int, k: int) -> int:
 
 
 @dataclass
-class ModelArgs:
+class BaseModelArgs:
     vocab_size: int = 32000
-    n_slow_layer: int = 32
-    n_fast_layer: int = 4
+    n_layer: int = 32
     n_head: int = 32
     dim: int = 4096
     intermediate_size: int = None
@@ -31,9 +30,10 @@ class ModelArgs:
     max_seq_len: int = 2048
     dropout: float = 0.0
 
-    # Additional decoding heads
+    # Codebook configs
     codebook_size: int = 160
     num_codebooks: int = 4
+    num_in_codebooks: Optional[int] = None
     codebook_padding_idx: int = 0
 
     # Gradient checkpointing
@@ -46,9 +46,21 @@ class ModelArgs:
             hidden_dim = 4 * self.dim
             n_hidden = int(2 * hidden_dim / 3)
             self.intermediate_size = find_multiple(n_hidden, 256)
+        if self.num_in_codebooks is None:
+            self.num_in_codebooks = self.num_codebooks
         self.head_dim = self.dim // self.n_head
 
 
+@dataclass
+class NaiveModelArgs(BaseModelArgs):
+    pass
+
+
+@dataclass
+class DualARModelArgs(BaseModelArgs):
+    n_fast_layer: int = 4
+
+
 class KVCache(nn.Module):
     def __init__(
         self, max_batch_size, max_seq_len, n_heads, head_dim, dtype=torch.bfloat16
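
Not part of the diff: a quick sanity check of the __post_init__ arithmetic above, assuming the defaults shown in the hunks (dim=4096, n_head=32, num_codebooks=4, num_in_codebooks=None).

# hidden_dim = 4 * 4096 = 16384; n_hidden = int(2 * 16384 / 3) = 10922;
# find_multiple rounds 10922 up to the next multiple of 256.
args = BaseModelArgs()
assert args.intermediate_size == 11008
assert args.head_dim == 4096 // 32 == 128
assert args.num_in_codebooks == args.num_codebooks == 4  # None falls back to num_codebooks
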
@@ -76,42 +88,32 @@ class TransformerForwardResult:
     codebook_logits: Tensor
 
 
-class Transformer(nn.Module):
-    def __init__(self, config: ModelArgs) -> None:
+@dataclass
+class BaseTransformerForwardResult:
+    logits: Tensor
+    hidden_states: Tensor
+
+
+class BaseTransformer(nn.Module):
+    def __init__(self, config: BaseModelArgs) -> None:
         super().__init__()
         self.config = config
 
         # Slow transformer
         self.embeddings = nn.Embedding(
-            config.vocab_size + config.codebook_size * config.num_codebooks,
+            config.vocab_size + config.codebook_size * config.num_in_codebooks,
             config.dim,
         )
-        self.slow_layers = nn.ModuleList(
-            TransformerBlock(config, use_sdpa=True) for _ in range(config.n_slow_layer)
+        self.layers = nn.ModuleList(
+            TransformerBlock(config, use_sdpa=True) for _ in range(config.n_layer)
         )
-        self.slow_norm = RMSNorm(config.dim, eps=config.norm_eps)
-        self.slow_output = nn.Linear(
+        self.norm = RMSNorm(config.dim, eps=config.norm_eps)
+        self.output = nn.Linear(
            config.dim,
            config.vocab_size,
            bias=False,
        )
 
-        # Fast transformer
-        self.fast_embeddings = nn.Embedding(
-            config.codebook_size, config.dim, padding_idx=config.codebook_padding_idx
-        )
-
-        # The equivalent bs is so large that sdpa doesn't work
-        self.fast_layers = nn.ModuleList(
-            TransformerBlock(config, use_sdpa=False) for _ in range(config.n_fast_layer)
-        )
-        self.fast_norm = RMSNorm(config.dim, eps=config.norm_eps)
-        self.fast_output = nn.Linear(
-            config.dim,
-            config.codebook_size,
-            bias=False,
-        )
-
         self.register_buffer(
             "freqs_cis",
             precompute_freqs_cis(
@@ -148,8 +150,7 @@ class Transformer(nn.Module):
         self.max_seq_len = max_seq_len
         self.max_batch_size = max_batch_size
 
-        # Slow transformer
-        for b in self.slow_layers:
+        for b in self.layers:
             b.attention.kv_cache = KVCache(
                 max_batch_size,
                 max_seq_len,
@@ -158,24 +159,9 @@ class Transformer(nn.Module):
                 dtype=dtype,
             )
 
-        # Fast transformer
-        # The max seq len here is the number of codebooks
-        for b in self.fast_layers:
-            b.attention.kv_cache = KVCache(
-                max_batch_size,
-                self.config.num_codebooks,
-                self.config.n_local_heads,
-                head_dim,
-                dtype=dtype,
-            )
-
     def embed(self, x: Tensor) -> Tensor:
-        # Here we want to merge the embeddings of the codebooks
-        if self.config.num_codebooks == 0:
-            return self.embeddings(x[:, 0])
-
         vocab_embeds = [self.embeddings(x[:, 0])]
-        for i in range(self.config.num_codebooks):
+        for i in range(self.config.num_in_codebooks):
             emb = self.embeddings(
                 x[:, i + 1] + i * self.config.codebook_size + self.config.vocab_size
             )
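
Not part of the diff: the loop in embed() above maps each input codebook into its own slice of the shared embedding table; with the defaults shown earlier (vocab_size=32000, codebook_size=160), the row index works out as below.

vocab_size, codebook_size = 32000, 160
# Row used for code value c of input codebook i: c + i * codebook_size + vocab_size
assert 0 + 0 * codebook_size + vocab_size == 32000   # codebook 0 starts right after the text vocab
assert 7 + 2 * codebook_size + vocab_size == 32327   # codebook 2, code 7
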
@@ -189,10 +175,9 @@ class Transformer(nn.Module):
 
     def forward(
         self, inp: Tensor, key_padding_mask: Optional[Tensor] = None
-    ) -> TransformerForwardResult:
+    ) -> BaseTransformerForwardResult:
         # x: (batch, num_codebooks + 1, seq_len)
         seq_len = inp.size(2)
-        codebooks = inp[:, 1:]
 
         # Here we want to merge the embeddings of the codebooks
         x = self.embed(inp)
@@ -206,15 +191,136 @@ class Transformer(nn.Module):
         if key_padding_mask is not None:
             mask = mask & key_padding_mask[:, None, None, :].logical_not()
 
-        for layer in self.slow_layers:
+        for layer in self.layers:
             if self.config.use_gradient_checkpointing and self.training:
                 x = checkpoint(layer, x, freqs_cis, mask, use_reentrant=True)
             else:
                 x = layer(x, freqs_cis, mask)
 
         # We got slow_out here
-        slow_out = self.slow_norm(x)
-        token_logits = self.slow_output(slow_out)
+        slow_out = self.norm(x)
+        token_logits = self.output(slow_out)
+
+        return BaseTransformerForwardResult(
+            logits=token_logits,
+            hidden_states=x,
+        )
+
+    def forward_generate(
+        self, x: Tensor, input_pos: Optional[Tensor] = None
+    ) -> BaseTransformerForwardResult:
+        # This is used for generation, optimized for torch compile
+        assert (
+            self.max_seq_len != -1 and self.max_batch_size != -1
+        ), "Please call setup_caches before forward_generate"
+
+        x = self.embed(x)
+
+        mask = self.causal_mask[
+            None, None, input_pos, : self.max_seq_len
+        ]  # (B, N, Q, K)
+        freqs_cis = self.freqs_cis[input_pos]
+
+        for layer in self.layers:
+            x = layer(x, freqs_cis, mask, input_pos=input_pos)
+
+        # If prefill, we only calculate the logits of last token
+        if x.size(1) > 1:
+            x = x[:, -1:]
+
+        # We got slow_out here
+        slow_out = self.norm(x)
+        token_logits = self.output(slow_out)
+
+        return BaseTransformerForwardResult(
+            logits=token_logits,
+            hidden_states=x,
+        )
+
+
+class NaiveTransformer(BaseTransformer):
+    def __init__(self, config: NaiveModelArgs) -> None:
+        super().__init__(config)
+
+        self.codebook_norm = RMSNorm(config.dim, eps=config.norm_eps)
+        self.codebook_output = nn.Linear(
+            config.dim,
+            config.codebook_size * config.num_codebooks,
+            bias=False,
+        )
+
+    def decode(self, result: BaseTransformerForwardResult) -> TransformerForwardResult:
+        token_logits = result.logits
+        x = result.hidden_states
+
+        # Codebook
+        codebook_logits = self.codebook_output(self.codebook_norm(x))
+        codebook_logits = rearrange(
+            codebook_logits, "b n (c d) -> b n c d", c=self.config.num_codebooks
+        )
+
+        return TransformerForwardResult(
+            token_logits=token_logits,
+            codebook_logits=codebook_logits,
+        )
+
+    def forward(
+        self, inp: Tensor, key_padding_mask: Optional[Tensor] = None
+    ) -> TransformerForwardResult:
+        result = super().forward(inp, key_padding_mask)
+        return self.decode(result)
+
+    def forward_generate(
+        self, x: Tensor, input_pos: Optional[Tensor] = None
+    ) -> TransformerForwardResult:
+        result = super().forward_generate(x, input_pos)
+        return self.decode(result)
+
+
+class DualARTransformer(BaseTransformer):
+    def __init__(self, config: DualARModelArgs) -> None:
+        super().__init__(config)
+
+        # Fast transformer
+        self.fast_embeddings = nn.Embedding(
+            config.codebook_size, config.dim, padding_idx=config.codebook_padding_idx
+        )
+
+        # The equivalent bs is so large that sdpa doesn't work
+        self.fast_layers = nn.ModuleList(
+            TransformerBlock(config, use_sdpa=False) for _ in range(config.n_fast_layer)
+        )
+        self.fast_norm = RMSNorm(config.dim, eps=config.norm_eps)
+        self.fast_output = nn.Linear(
+            config.dim,
+            config.codebook_size,
+            bias=False,
+        )
+
+    def setup_caches(
+        self, max_batch_size: int, max_seq_len: int, dtype: torch.dtype = torch.bfloat16
+    ):
+        super().setup_caches(max_batch_size, max_seq_len, dtype)
+
+        head_dim = self.config.dim // self.config.n_head
+
+        # Fast transformer
+        # The max seq len here is the number of codebooks
+        for b in self.fast_layers:
+            b.attention.kv_cache = KVCache(
+                max_batch_size,
+                self.config.num_codebooks,
+                self.config.n_local_heads,
+                head_dim,
+                dtype=dtype,
+            )
+
+    def forward(
+        self, inp: Tensor, key_padding_mask: Optional[Tensor] = None
+    ) -> TransformerForwardResult:
+        parent_result = super().forward(inp, key_padding_mask)
+        token_logits = parent_result.logits
+        x = parent_result.hidden_states
 
         # Fast transformer
         fast_seq_len = self.config.num_codebooks
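
Not part of the diff: a shape check for the rearrange in NaiveTransformer.decode above, with an assumed toy batch of 2, sequence length 3, and the default codebook settings (num_codebooks=4, codebook_size=160).

import torch
from einops import rearrange

flat = torch.randn(2, 3, 4 * 160)                            # (b, n, num_codebooks * codebook_size)
per_codebook = rearrange(flat, "b n (c d) -> b n c d", c=4)  # one logit slice per codebook
assert per_codebook.shape == (2, 3, 4, 160)
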
@@ -224,7 +330,7 @@ class Transformer(nn.Module):
         fast_freqs_cis = self.freqs_cis[:fast_seq_len]
 
         # Drop the last token and rotate left
-        codebooks = codebooks[:, :-1, 1:]
+        codebooks = inp[:, 1:-1, 1:]
         codebooks = F.pad(codebooks, (0, 1), value=self.config.codebook_padding_idx)
         codebook_embeddings = self.fast_embeddings(codebooks)
         x = torch.cat([x[:, None], codebook_embeddings], dim=1)
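
Not part of the diff: the rewritten slice above is equivalent to the old two-step indexing (codebooks = inp[:, 1:] followed by [:, :-1, 1:]); a toy tensor makes that easy to verify.

import torch

inp = torch.arange(2 * 5 * 4).reshape(2, 5, 4)  # (batch, num_codebooks + 1, seq_len)
old = inp[:, 1:][:, :-1, 1:]                    # previous form, via the intermediate `codebooks`
new = inp[:, 1:-1, 1:]                          # one-step form used after this change
assert torch.equal(old, new)
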
@@ -272,36 +378,6 @@ class Transformer(nn.Module):
             codebook_logits=codebook_logits,
         )
 
-    def forward_generate_slow(
-        self, x: Tensor, input_pos: Optional[Tensor] = None
-    ) -> Tensor:
-        ### TODO: fix this
-        # x: (batch, num_codebooks + 1, 1)
-
-        assert (
-            self.max_seq_len != -1 and self.max_batch_size != -1
-        ), "Please call setup_caches before forward_generate"
-
-        x = self.embed(x)
-
-        mask = self.causal_mask[
-            None, None, input_pos, : self.max_seq_len
-        ]  # (B, N, Q, K)
-        freqs_cis = self.freqs_cis[input_pos]
-
-        for layer in self.slow_layers:
-            x = layer(x, freqs_cis, mask, input_pos=input_pos)
-
-        # If prefill, we only calculate the logits of last token
-        if x.size(1) > 1:
-            x = x[:, -1:]
-
-        # We got slow_out here
-        slow_out = self.slow_norm(x)
-        token_logits = self.slow_output(slow_out)
-
-        return x, token_logits
-
     def forward_generate_fast(
         self, x: Tensor, input_pos: Optional[Tensor] = None
     ) -> Tensor:
@@ -324,7 +400,7 @@ class Transformer(nn.Module):
 
 
 class TransformerBlock(nn.Module):
-    def __init__(self, config: ModelArgs, use_sdpa: bool = True) -> None:
+    def __init__(self, config: BaseModelArgs, use_sdpa: bool = True) -> None:
         super().__init__()
         self.attention = Attention(config, use_sdpa=use_sdpa)
         self.feed_forward = FeedForward(config)
@@ -340,7 +416,7 @@ class TransformerBlock(nn.Module):
 
 
 class Attention(nn.Module):
-    def __init__(self, config: ModelArgs, use_sdpa: bool = True):
+    def __init__(self, config: BaseModelArgs, use_sdpa: bool = True):
         super().__init__()
         assert config.dim % config.n_head == 0
 
@@ -443,7 +519,7 @@ class Attention(nn.Module):
 
 
 class FeedForward(nn.Module):
-    def __init__(self, config: ModelArgs) -> None:
+    def __init__(self, config: BaseModelArgs) -> None:
         super().__init__()
         self.w1 = nn.Linear(config.dim, config.intermediate_size, bias=False)
         self.w3 = nn.Linear(config.dim, config.intermediate_size, bias=False)
@@ -494,10 +570,10 @@ def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
 
 
 if __name__ == "__main__":
-    args = ModelArgs(
+    args = DualARModelArgs(
         max_seq_len=4096,
         vocab_size=32312,
-        n_slow_layer=12,
+        n_layer=12,
         n_fast_layer=4,
         n_head=12,
         dim=768,
@@ -507,7 +583,7 @@ if __name__ == "__main__":
         num_codebooks=4,
     )
 
-    model = Transformer(args)
+    model = DualARTransformer(args)
     model = model.cuda().bfloat16()
     print("Total params:", sum(i.numel() for i in model.parameters()) / 1024 / 1024)
 
@@ -516,4 +592,4 @@ if __name__ == "__main__":
     key_padding_mask[0, 2:] = True
     x1 = model(inputs, key_padding_mask=key_padding_mask)
     print(x1.token_logits.shape)
-    # print(x1.codebook_logits.shape)
+    print(x1.codebook_logits.shape)
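
Not part of the diff: a small sketch of how the same smoke test could exercise the naive variant, assuming the surrounding __main__ block (including inputs and key_padding_mask) stays as shown.

naive_model = NaiveTransformer(
    NaiveModelArgs(
        max_seq_len=4096, vocab_size=32312, n_layer=12, n_head=12, dim=768, num_codebooks=4
    )
).cuda().bfloat16()
x2 = naive_model(inputs, key_padding_mask=key_padding_mask)
print(x2.token_logits.shape, x2.codebook_logits.shape)
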