@@ -188,17 +188,17 @@ class Transformer(nn.Module):
         return x
 
     def forward(
-        self, x: Tensor, key_padding_mask: Optional[Tensor] = None
+        self, inp: Tensor, key_padding_mask: Optional[Tensor] = None
     ) -> TransformerForwardResult:
         # x: (batch, num_codebooks + 1, seq_len)
-        seq_len = x.size(2)
+        seq_len = inp.size(2)
 
         # For codebook, the decoding is actually shifted by 1
         # Which is the labels section
-        codebooks = x[:, 1:]
+        codebooks = inp[:, 1:]
 
         # Here we want to merge the embeddings of the codebooks
-        x = self.embed(x)
+        x = self.embed(inp)
 
         mask = self.causal_mask[None, None, :seq_len, :seq_len]  # (B, N, Q, K)
         freqs_cis = self.freqs_cis[:seq_len]
@@ -225,7 +225,11 @@ class Transformer(nn.Module):
             None, None, :fast_seq_len, :fast_seq_len
         ]  # (B, N, Q, K)
         fast_freqs_cis = self.freqs_cis[:fast_seq_len]
-        codebook_embeddings = self.fast_embeddings(codebooks[:, :-1])
+
+        # Drop the last token and rotate left
+        codebooks = codebooks[:, :-1, 1:]
+        codebooks = F.pad(codebooks, (0, 1), value=self.config.codebook_padding_idx)
+        codebook_embeddings = self.fast_embeddings(codebooks)
 
         x = torch.cat([x[:, None], codebook_embeddings], dim=1)  # (B, N + 1, S, D)
         b, s = x.size(0), x.size(2)
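
The core of the second hunk is the slice-and-pad that replaces the old `codebooks[:, :-1]`. The snippet below is a minimal standalone sketch of that operation on a toy tensor, not code from the repository: it assumes `codebooks` has shape (batch, num_codebooks, seq_len), as implied by the `# x: (batch, num_codebooks + 1, seq_len)` comment and the `inp[:, 1:]` slice, and uses a hypothetical `PAD = 0` as a stand-in for `self.config.codebook_padding_idx`.

import torch
import torch.nn.functional as F

PAD = 0  # stand-in for self.config.codebook_padding_idx (actual value unknown)

# (B=1, num_codebooks=3, seq_len=4)
codebooks = torch.arange(1, 13).reshape(1, 3, 4)
# tensor([[[ 1,  2,  3,  4],
#          [ 5,  6,  7,  8],
#          [ 9, 10, 11, 12]]])

shifted = codebooks[:, :-1, 1:]              # drop the last codebook row and the first time step
shifted = F.pad(shifted, (0, 1), value=PAD)  # pad one slot at the end of the time axis
# tensor([[[2, 3, 4, 0],
#          [6, 7, 8, 0]]])

Each remaining row ends up shifted one step to the left, so position t now holds the token that previously sat at t + 1, with the padding index filling the final slot — which is the "drop the last token and rotate left" behaviour the new comment in the diff describes.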