@@ -1,3 +1,4 @@
+import dataclasses
 import json
 import math
 from collections import OrderedDict
@@ -57,6 +58,10 @@ class BaseModelArgs:
     # Initialize the model
     initializer_range: float = 0.02
 
+    # Dummy vars
+    is_reward_model: bool = False
+    share_codebook_embeddings: bool = True
+
     def __post_init__(self):
         if self.n_local_heads == -1:
             self.n_local_heads = self.n_head
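The two "dummy" fields are accepted but never read by this model; presumably they exist so config files written for other variants still deserialize without a TypeError. A minimal sketch, assuming the remaining BaseModelArgs fields all have defaults:

    # Hypothetical config dict containing the dummy keys.
    cfg = {"is_reward_model": False, "share_codebook_embeddings": True}
    args = BaseModelArgs(**cfg)  # accepted here, ignored downstream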
@@ -100,6 +105,28 @@ class NaiveModelArgs(BaseModelArgs):
 class DualARModelArgs(BaseModelArgs):
     model_type: str = "dual_ar"
     n_fast_layer: int = 4
+    fast_dim: int | None = None
+    fast_n_head: int | None = None
+    fast_n_local_heads: int | None = None
+    fast_head_dim: int | None = None
+    fast_intermediate_size: int | None = None
+    fast_attention_qkv_bias: bool | None = None
+
+    def __post_init__(self):
+        super().__post_init__()
+
+        self.fast_dim = self.fast_dim or self.dim
+        self.fast_n_head = self.fast_n_head or self.n_head
+        self.fast_n_local_heads = self.fast_n_local_heads or self.n_local_heads
+        self.fast_head_dim = self.fast_head_dim or self.head_dim
+        self.fast_intermediate_size = (
+            self.fast_intermediate_size or self.intermediate_size
+        )
+        self.fast_attention_qkv_bias = (
+            self.fast_attention_qkv_bias
+            if self.fast_attention_qkv_bias is not None
+            else self.attention_qkv_bias
+        )
 
 
 class KVCache(nn.Module):
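After `__post_init__`, every `fast_*` field resolves to its slow-path counterpart unless set explicitly, so existing configs keep their old behavior. The conditional expression for `fast_attention_qkv_bias` (rather than `or`) matters: `or` would silently discard an explicit `False`. A hedged sketch, assuming the other constructor arguments default sensibly:

    args = DualARModelArgs(dim=1024, n_head=16)
    assert args.fast_dim == 1024 and args.fast_n_head == 16  # inherited

    small = DualARModelArgs(dim=1024, n_head=16, fast_dim=512, fast_n_head=8)
    assert small.fast_dim == 512  # explicit override wins

    off = DualARModelArgs(attention_qkv_bias=True, fast_attention_qkv_bias=False)
    assert off.fast_attention_qkv_bias is False  # False survives, unlike with `or`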
@@ -474,20 +501,46 @@ class DualARTransformer(BaseTransformer):
     def __init__(self, config: NaiveModelArgs, tokenizer: AutoTokenizer) -> None:
         super().__init__(config, init_weights=False, tokenizer=tokenizer)
 
+        # Project to fast dim if needed
+        if config.fast_dim is not None and config.fast_dim != config.dim:
+            self.fast_project_in = nn.Linear(config.dim, config.fast_dim)
+        else:
+            self.fast_project_in = nn.Identity()
+
         # Fast transformer
-        self.fast_embeddings = nn.Embedding(config.codebook_size, config.dim)
+        self.fast_embeddings = nn.Embedding(config.codebook_size, config.fast_dim)
 
         # The equivalent bs is so large that sdpa doesn't work
+        override_config = dataclasses.replace(
+            config,
+            dim=config.fast_dim,
+            n_head=config.fast_n_head,
+            n_local_heads=config.fast_n_local_heads,
+            head_dim=config.fast_head_dim,
+            intermediate_size=config.fast_intermediate_size,
+            attention_qkv_bias=config.fast_attention_qkv_bias,
+        )
+
         self.fast_layers = nn.ModuleList(
-            TransformerBlock(config, use_sdpa=False) for _ in range(config.n_fast_layer)
+            TransformerBlock(override_config, use_sdpa=False)
+            for _ in range(config.n_fast_layer)
         )
-        self.fast_norm = RMSNorm(config.dim, eps=config.norm_eps)
+        self.fast_norm = RMSNorm(config.fast_dim, eps=config.norm_eps)
         self.fast_output = nn.Linear(
-            config.dim,
+            config.fast_dim,
             config.codebook_size,
             bias=False,
         )
 
+        self.register_buffer(
+            "fast_freqs_cis",
+            precompute_freqs_cis(
+                config.num_codebooks,
+                config.fast_dim // config.fast_n_head,
+                config.rope_base,
+            ),
+            persistent=False,
+        )
         self.apply(self._init_weights)
 
     def setup_caches(
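`dataclasses.replace` is the crux of the hunk: it clones the config with the slow-trunk geometry swapped for the `fast_*` values, so `TransformerBlock` needs no new code path. Note that `replace` constructs a fresh instance and therefore re-runs `__post_init__`; since every `fast_*` field is concrete by then, the fallbacks are no-ops. A sketch with assumed sizes:

    base = DualARModelArgs(dim=1024, n_head=16, fast_dim=512, fast_n_head=8)
    fast_cfg = dataclasses.replace(base, dim=base.fast_dim, n_head=base.fast_n_head)
    assert (fast_cfg.dim, fast_cfg.n_head) == (512, 8)
    assert base.dim == 1024  # the original config is untouched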
@@ -495,7 +548,7 @@ class DualARTransformer(BaseTransformer):
     ):
         super().setup_caches(max_batch_size, max_seq_len, dtype)
 
-        head_dim = self.config.dim // self.config.n_head
+        head_dim = self.config.fast_dim // self.config.fast_n_head
 
         # Fast transformer
         # The max seq len here is the number of codebooks
@@ -503,7 +556,7 @@ class DualARTransformer(BaseTransformer):
             b.attention.kv_cache = KVCache(
                 max_batch_size,
                 self.config.num_codebooks,
-                self.config.n_local_heads,
+                self.config.fast_n_local_heads,
                 head_dim,
                 dtype=dtype,
             )
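These two hunks re-derive the fast trunk's cache geometry from the new fields; the old formulas were only correct while both trunks shared one geometry. A worked example with assumed sizes:

    dim, n_head = 1024, 16            # slow trunk
    fast_dim, fast_n_head = 512, 16   # smaller fast trunk
    assert dim // n_head == 64              # what the old formula allocated
    assert fast_dim // fast_n_head == 32    # what the fast KV cache needs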
@@ -516,13 +569,13 @@ class DualARTransformer(BaseTransformer):
         parent_result = super().forward(inp, key_padding_mask)
         token_logits = parent_result.logits
         x = parent_result.hidden_states
+        x = self.fast_project_in(x)
 
         # Fast transformer
         fast_seq_len = self.config.num_codebooks
         fast_mask = self.causal_mask[
             None, None, :fast_seq_len, :fast_seq_len
         ]  # (B, N, Q, K)
-        fast_freqs_cis = self.freqs_cis[:fast_seq_len]
 
         # Drop the last token and rotate left
         codebooks = inp[:, 1:-1, 1:]
@@ -545,9 +598,11 @@ class DualARTransformer(BaseTransformer):
 
         for layer in self.fast_layers:
             if self.config.use_gradient_checkpointing and self.training:
-                x = checkpoint(layer, x, fast_freqs_cis, fast_mask, use_reentrant=True)
+                x = checkpoint(
+                    layer, x, self.fast_freqs_cis, fast_mask, use_reentrant=True
+                )
             else:
-                x = layer(x, fast_freqs_cis, fast_mask)
+                x = layer(x, self.fast_freqs_cis, fast_mask)
 
         # unflatten the batch and num_codebooks
         fast_out = self.fast_norm(x)
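The move from slicing `self.freqs_cis` to a dedicated `fast_freqs_cis` buffer follows from the same geometry split: a rotary table's width tracks the head dimension it was built for, so the slow table cannot be reused once the head dims diverge. Roughly, with free variables standing in for the config fields:

    slow_table = precompute_freqs_cis(max_seq_len, dim // n_head, rope_base)
    fast_table = precompute_freqs_cis(num_codebooks, fast_dim // fast_n_head, rope_base)
    # slow_table[:num_codebooks] has the wrong last dimension once head dims differ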
@@ -587,7 +642,7 @@ class DualARTransformer(BaseTransformer):
         fast_mask = self.causal_mask[
             None, None, input_pos, : self.config.num_codebooks
         ]  # (B, N, Q, K)
-        fast_freqs_cis = self.freqs_cis[input_pos]
+        fast_freqs_cis = self.fast_freqs_cis[input_pos]
 
         for layer in self.fast_layers:
             x = layer(x, fast_freqs_cis, fast_mask, input_pos=input_pos)
@@ -598,6 +653,13 @@ class DualARTransformer(BaseTransformer):
 
         return codebook_logits
 
+    def forward_generate(
+        self, x: Tensor, input_pos: Optional[Tensor] = None
+    ) -> TransformerForwardResult:
+        x = super().forward_generate(x, input_pos)
+        x.hidden_states = self.fast_project_in(x.hidden_states)
+        return x
+
 
 class TransformerBlock(nn.Module):
     def __init__(self, config: BaseModelArgs, use_sdpa: bool = True) -> None:
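The new `forward_generate` override mirrors the training-path change: after the slow trunk's incremental step, hidden states are projected into the fast width before the fast decoder consumes them. A hypothetical check (model and sizes assumed):

    out = model.forward_generate(tokens, input_pos)
    assert out.hidden_states.shape[-1] == model.config.fast_dim  # not config.dim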