@@ -148,7 +148,7 @@ class Transformer(nn.Module):
         self.max_seq_len = max_seq_len
         self.max_batch_size = max_batch_size
 
-        for b in self.layers:
+        for b in self.slow_layers:
             b.attention.kv_cache = KVCache(
                 max_batch_size,
                 max_seq_len,
@@ -157,6 +157,8 @@ class Transformer(nn.Module):
                 dtype=dtype,
             )
 
+        # TODO: add fast transformer kv cache
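+        # NOTE: without a fast kv cache, the decode path below re-runs the
+        # fast layers over the full codebook prefix at every step.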
+
     def embed(self, x: Tensor) -> Tensor:
         # Here we want to merge the embeddings of the codebooks
         if self.config.num_codebooks == 0:
@@ -175,41 +177,6 @@ class Transformer(nn.Module):
 
         return x
 
-    def compute(
-        self,
-        x: Tensor,
-        freqs_cis: Tensor,
-        mask: Tensor,
-        input_pos: Optional[Tensor] = None,
-    ) -> TransformerForwardResult:
-        raise NotImplementedError
-
-        for layer in self.layers:
-            if self.config.use_gradient_checkpointing and self.training:
-                x = checkpoint(layer, x, freqs_cis, mask, input_pos, use_reentrant=True)
-            else:
-                x = layer(x, freqs_cis, mask, input_pos=input_pos)
-
-        x = self.norm(x)
-        logits = self.output(x)
-        token_logits = logits[:, :, : self.config.vocab_size]
-
-        if self.config.num_codebooks == 0:
-            return TransformerForwardResult(
-                token_logits=token_logits,
-                codebook_logits=None,
-            )
-
-        codebook_logits = logits[:, :, self.config.vocab_size :]
-        codebook_logits = rearrange(
-            codebook_logits, "b n (c d) -> b n c d", c=self.config.num_codebooks
-        )
-
-        return TransformerForwardResult(
-            token_logits=token_logits,
-            codebook_logits=codebook_logits,
-        )
-
     def forward(
         self, x: Tensor, key_padding_mask: Optional[Tensor] = None
     ) -> TransformerForwardResult:
@@ -248,15 +215,9 @@ class Transformer(nn.Module):
             None, None, :fast_seq_len, :fast_seq_len
         ]  # (B, N, Q, K)
         fast_freqs_cis = self.freqs_cis[:fast_seq_len]
-
-        # There should be a bug here
-        # Say at t0, the given input is [/INST] for semantic token
-        # Then we want to predict <tok0>, <tok1>, ... (instead of <s> <s> <s>) given <feat>, <tok0>, <tok1>, ...
-        # Otherwise this becomes: decode tokens from same given tokens
-        # Ignore the last token, since the input should be <feat>, <tok0>, <tok1>, ...
         codebook_embeddings = self.fast_embeddings(codebooks[:, :-1])
-        x = torch.cat([x[:, None, 1:], codebook_embeddings], dim=1)  # (B, N + 1, S, D)
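+        # Fix: keep all S positions. Along the codebook axis the input is
+        # <slow feature, cb_0 .. cb_{N-2}>, so each fast position predicts
+        # the next codebook (teacher forcing across the codebook axis).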
+        x = torch.cat([x[:, None], codebook_embeddings], dim=1)  # (B, N + 1, S, D)
         b, s = x.size(0), x.size(2)
         x = rearrange(x, "b n s d -> (b s) n d")  # flatten the batch and seq_len
 
@@ -298,9 +259,54 @@ class Transformer(nn.Module):
         ]  # (B, N, Q, K)
         freqs_cis = self.freqs_cis[input_pos]
 
-        # TODO: support key padding mask for generation
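+        # Slow transformer: one kv-cached step over the semantic-token stream.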
+        for layer in self.slow_layers:
+            x = layer(x, freqs_cis, mask, input_pos=input_pos)
+
+        # During prefill we only need the logits of the last token
+        if x.size(1) > 1:
+            x = x[:, -1:]
+
+        # Slow head: semantic-token logits
+        slow_out = self.slow_norm(x)
+        token_logits = self.slow_output(slow_out)
+
+        # Fast transformer: decode the codebooks autoregressively, one at a time
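+        # The slow hidden state sits at fast position 0; each predicted codebook
+        # is embedded and appended, then the fast layers re-run over the prefix
+        # (no kv cache yet; see the TODO above).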
+        fast_features = [x[:, None]]
+        fast_logits = []
+
+        for _ in range(self.config.num_codebooks):
+            x = torch.cat(fast_features, dim=1)  # (B, N + 1, S, D)
+            b, s = x.size(0), x.size(2)
+            x = rearrange(x, "b n s d -> (b s) n d")  # flatten the batch and seq_len
+
+            fast_seq_len = x.size(1)
+            fast_mask = self.causal_mask[
+                None, None, :fast_seq_len, :fast_seq_len
+            ]  # (B, N, Q, K)
+            fast_freqs_cis = self.freqs_cis[:fast_seq_len]
 
-        return self.compute(x, freqs_cis, mask, input_pos=input_pos)
+            for layer in self.fast_layers:
+                x = layer(x, fast_freqs_cis, fast_mask)
+
+            # Only the last fast position predicts the next codebook
+            fast_out = self.fast_norm(x[:, -1:])  # only take the last token
+            codebook_logits = self.fast_output(fast_out)
+            fast_logits.append(codebook_logits)
+
+            # Greedy pick: embed the argmax codebook and feed it back
+            codebook_idx = codebook_logits.argmax(dim=-1)
+            codebook_embeddings = self.fast_embeddings(codebook_idx)
+            fast_features.append(codebook_embeddings.view(b, 1, s, -1))
+
+        codebook_logits = torch.stack(fast_logits, dim=1)
+        assert codebook_logits.shape[1] == self.config.num_codebooks
+
+        codebook_logits = rearrange(codebook_logits, "b c n d -> b n c d")
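+        # token_logits: (B, 1, vocab); codebook_logits: (B, 1, num_codebooks, codebook vocab)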
+
+        return TransformerForwardResult(
+            token_logits=token_logits,
+            codebook_logits=codebook_logits,
+        )
 
 
 class TransformerBlock(nn.Module):