|
|
@@ -189,7 +189,7 @@ class TextToSemantic(L.LightningModule):
|
|
|
# We want to shift the labels by one to the right
|
|
|
codebook_labels = labels[:, 1 : 1 + self.model.config.num_codebooks, :-1]
|
|
|
codebook_labels = torch.nn.functional.pad(
|
|
|
- codebook_labels, (0, 1), value=-100
|
|
|
+ codebook_labels, (1, 0), value=-100
|
|
|
).mT
|
|
|
|
|
|
semantic_loss = F.cross_entropy(
|