|
|
@@ -312,6 +312,27 @@ class TextToSemantic(L.LightningModule):
|
|
|
logger=True,
|
|
|
)
|
|
|
|
|
|
+ if self.model.config.num_codebooks != self.model.config.num_in_codebooks:
|
|
|
+ _, indices = codebook_logits[
|
|
|
+ :, :, : self.model.config.num_in_codebooks
|
|
|
+ ].topk(5, dim=-1)
|
|
|
+ codebook_labels = codebook_labels[
|
|
|
+ :, :, : self.model.config.num_in_codebooks
|
|
|
+ ]
|
|
|
+ correct = indices.eq(codebook_labels.unsqueeze(-1))
|
|
|
+ correct[codebook_labels == -100] = 0
|
|
|
+ correct = correct.sum()
|
|
|
+ accuracy = correct / (codebook_labels != -100).sum()
|
|
|
+
|
|
|
+ self.log(
|
|
|
+ f"{stage}/top_5_accuracy_in",
|
|
|
+ accuracy,
|
|
|
+ on_step=True,
|
|
|
+ on_epoch=False,
|
|
|
+ prog_bar=True,
|
|
|
+ logger=True,
|
|
|
+ )
|
|
|
+
|
|
|
return loss
|
|
|
|
|
|
def training_step(self, batch, batch_idx):
|