|
|
@@ -197,9 +197,7 @@ class TextToSemantic(L.LightningModule):
|
|
|
|
|
|
chosen_rewards = self.dpo_beta * positive_codebook_logps.detach()
|
|
|
rejected_rewards = self.dpo_beta * negative_codebook_logps.detach()
|
|
|
- reward_accuracy = (
|
|
|
- (positive_codebook_logps > negative_codebook_logps).float().mean()
|
|
|
- )
|
|
|
+ reward_accuracy = (chosen_rewards > rejected_rewards).float().mean()
|
|
|
chosen_rewards, rejected_rewards = (
|
|
|
chosen_rewards.mean(),
|
|
|
rejected_rewards.mean(),
|