@@ -192,9 +192,6 @@ class Transformer(nn.Module):
     ) -> TransformerForwardResult:
         # x: (batch, num_codebooks + 1, seq_len)
         seq_len = inp.size(2)
-
-        # For codebook, the decoding is actually shifted by 1
-        # Which is the labels section
         codebooks = inp[:, 1:]

         # Here we want to merge the embeddings of the codebooks
@@ -228,13 +225,20 @@ class Transformer(nn.Module):

         # Drop the last token and rotate left
         codebooks = codebooks[:, :-1, 1:]
-        codebooks = F.pad(codebooks, (0, 1), value=self.config.codebook_padding_idx)
+        codebooks = F.pad(
+            codebooks, (0, 1, 1, 0), value=self.config.codebook_padding_idx
+        )
         codebook_embeddings = self.fast_embeddings(codebooks)
-
-        x = torch.cat([x[:, None], codebook_embeddings], dim=1)  # (B, N + 1, S, D)
+        x = codebook_embeddings + x[:, None]  # (B, N + 1, S, D)
         b, s = x.size(0), x.size(2)
         x = rearrange(x, "b n s d -> (b s) n d")  # flatten the batch and seq_len

+        # Remove padded part
+        codebooks = rearrange(codebooks, "b n s -> (b s) n")
+        codebook_mask = (codebooks == self.config.codebook_padding_idx).all(dim=-1)
+        x_bs, x_len = x.size(0), x.size(1)
+        x = x[~codebook_mask]
+
         for layer in self.fast_layers:
             if self.config.use_gradient_checkpointing and self.training:
                 x = checkpoint(layer, x, fast_freqs_cis, fast_mask, use_reentrant=True)
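
A note on the `F.pad` change in the hunk above (an illustrative sketch, not part of the patch): the pad tuple is consumed from the last dimension backwards, so `(0, 1, 1, 0)` appends one padded position to the sequence (last) dimension and prepends one padded row to the codebook dimension. Positions whose codebooks are entirely `codebook_padding_idx` can then be dropped before the fast layers, which is what the added `codebook_mask` logic does. The tensor sizes below are made up for the demonstration.

    import torch
    import torch.nn.functional as F

    codebook_padding_idx = 0  # stand-in value; the real index lives in the model config

    # toy tensor shaped like (batch, num_codebooks, seq_len) after the drop/rotate step
    codebooks = torch.randint(1, 10, (2, 3, 5))

    # last dim (seq_len): pad 0 on the left, 1 on the right
    # second-to-last dim (codebooks): pad 1 on the left, 0 on the right
    padded = F.pad(codebooks, (0, 1, 1, 0), value=codebook_padding_idx)
    print(padded.shape)  # torch.Size([2, 4, 6])

    # flatten (batch, seq) and keep only positions that are not entirely padding,
    # mirroring the codebook_mask logic added above
    flat = padded.permute(0, 2, 1).reshape(-1, padded.size(1))
    codebook_mask = (flat == codebook_padding_idx).all(dim=-1)
    kept = flat[~codebook_mask]
    print(flat.shape, kept.shape)  # torch.Size([12, 4]) torch.Size([10, 4])
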
@@ -244,6 +248,12 @@ class Transformer(nn.Module):
         # unflatten the batch and num_codebooks
         fast_out = self.fast_norm(x)
         codebook_logits = self.fast_output(fast_out)
+
+        # Re-pad the codebook_logits
+        buffer = torch.zeros(x_bs, x_len, codebook_logits.size(-1), device=x.device)
+        buffer[~codebook_mask] = codebook_logits
+        codebook_logits = buffer
+
         assert codebook_logits.shape[1] == self.config.num_codebooks
         codebook_logits = rearrange(
             codebook_logits,
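
The re-padding step in the hunk above can be illustrated in isolation (sizes and names here are toy values, not from the patch): logits exist only for the kept positions, and writing them through the inverted mask restores the full `(batch * seq, num_codebooks, vocab)` layout, with zeros at the positions that were skipped.

    import torch

    bs, num_codebooks, vocab = 6, 4, 32  # toy sizes for (batch * seq, N, V)
    codebook_mask = torch.tensor([False, True, False, False, True, False])

    # logits were computed only for the ~codebook_mask rows
    kept_logits = torch.randn(int((~codebook_mask).sum()), num_codebooks, vocab)

    # scatter them back into a zero buffer of the original size
    buffer = torch.zeros(bs, num_codebooks, vocab)
    buffer[~codebook_mask] = kept_logits
    codebook_logits = buffer
    print(codebook_logits.shape)  # torch.Size([6, 4, 32])
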
@@ -258,6 +268,22 @@ class Transformer(nn.Module):
             codebook_logits=codebook_logits,
         )

+    def forward_fast(self, x: Tensor) -> Tensor:
+        # Fast transformer
+        fast_seq_len = x.shape[1]
+        fast_mask = self.causal_mask[
+            None, None, :fast_seq_len, :fast_seq_len
+        ]  # (B, N, Q, K)
+        fast_freqs_cis = self.freqs_cis[:fast_seq_len]
+
+        for layer in self.fast_layers:
+            x = layer(x, fast_freqs_cis, fast_mask)
+
+        fast_out = self.fast_norm(x)
+        codebook_logits = self.fast_output(fast_out)
+
+        return codebook_logits
+
     def forward_generate_slow(
         self, x: Tensor, input_pos: Optional[Tensor] = None
     ) -> Tensor:
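
For reference, the mask slicing used by the new `forward_fast` helper can be sketched on its own. The `torch.tril` construction of `causal_mask` below is an assumption about how that buffer is built, used here only to show how the `[None, None, :n, :n]` slice yields a mask that broadcasts over batch and attention heads.

    import torch

    max_seq_len = 8
    # assumed construction of the causal mask buffer, for illustration only
    causal_mask = torch.tril(torch.ones(max_seq_len, max_seq_len, dtype=torch.bool))

    fast_seq_len = 4  # number of codebooks processed by the fast transformer
    fast_mask = causal_mask[None, None, :fast_seq_len, :fast_seq_len]  # broadcastable (B, N, Q, K)
    print(fast_mask.shape)  # torch.Size([1, 1, 4, 4])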