Browse Source

Optimize DPO behavior

Lengyue 2 years ago
parent
commit
0e0332396f
1 changed file with 1 addition and 3 deletions
  1. 1 3
      fish_speech/models/text2semantic/lit_module.py

+ 1 - 3
fish_speech/models/text2semantic/lit_module.py

@@ -197,9 +197,7 @@ class TextToSemantic(L.LightningModule):
 
             chosen_rewards = self.dpo_beta * positive_codebook_logps.detach()
             rejected_rewards = self.dpo_beta * negative_codebook_logps.detach()
-            reward_accuracy = (
-                (positive_codebook_logps > negative_codebook_logps).float().mean()
-            )
+            reward_accuracy = (chosen_rewards > rejected_rewards).float().mean()
             chosen_rewards, rejected_rewards = (
                 chosen_rewards.mean(),
                 rejected_rewards.mean(),