@@ -225,11 +225,9 @@ class Transformer(nn.Module):
         # Drop the last token and rotate left
         codebooks = codebooks[:, :-1, 1:]
-        codebooks = F.pad(
-            codebooks, (0, 1, 1, 0), value=self.config.codebook_padding_idx
-        )
+        codebooks = F.pad(codebooks, (0, 1), value=self.config.codebook_padding_idx)
         codebook_embeddings = self.fast_embeddings(codebooks)
-        x = codebook_embeddings + x[:, None]  # (B, N + 1, S, D)
+        x = torch.cat([x[:, None], codebook_embeddings], dim=1)
         b, s = x.size(0), x.size(2)
         x = rearrange(x, "b n s d -> (b s) n d")  # flatten the batch and seq_len
@@ -268,22 +266,6 @@ class Transformer(nn.Module):
             codebook_logits=codebook_logits,
         )

-    def forward_fast(self, x: Tensor) -> Tensor:
-        # Fast transformer
-        fast_seq_len = x.shape[1]
-        fast_mask = self.causal_mask[
-            None, None, :fast_seq_len, :fast_seq_len
-        ]  # (B, N, Q, K)
-        fast_freqs_cis = self.freqs_cis[:fast_seq_len]
-
-        for layer in self.fast_layers:
-            x = layer(x, fast_freqs_cis, fast_mask)
-
-        fast_out = self.fast_norm(x)
-        codebook_logits = self.fast_output(fast_out)
-
-        return codebook_logits
-
     def forward_generate_slow(
         self, x: Tensor, input_pos: Optional[Tensor] = None
     ) -> Tensor: