Browse Source

Add additional metrics

Lengyue 2 years ago
parent
commit
5bd88b131e
1 changed files with 21 additions and 0 deletions
  1. 21 0
      fish_speech/models/text2semantic/lit_module.py

+ 21 - 0
fish_speech/models/text2semantic/lit_module.py

@@ -312,6 +312,27 @@ class TextToSemantic(L.LightningModule):
             logger=True,
         )
 
+        if self.model.config.num_codebooks != self.model.config.num_in_codebooks:
+            _, indices = codebook_logits[
+                :, :, : self.model.config.num_in_codebooks
+            ].topk(5, dim=-1)
+            codebook_labels = codebook_labels[
+                :, :, : self.model.config.num_in_codebooks
+            ]
+            correct = indices.eq(codebook_labels.unsqueeze(-1))
+            correct[codebook_labels == -100] = 0
+            correct = correct.sum()
+            accuracy = correct / (codebook_labels != -100).sum()
+
+            self.log(
+                f"{stage}/top_5_accuracy_in",
+                accuracy,
+                on_step=True,
+                on_epoch=False,
+                prog_bar=True,
+                logger=True,
+            )
+
         return loss
 
     def training_step(self, batch, batch_idx):