Explorar o código

optimize dpo behavior

Lengyue %!s(int64=2) %!d(string=hai) anos
pai
achega
0e0332396f
Modificáronse 1 ficheiros con 1 adicións e 3 borrados
  1. 1 3
      fish_speech/models/text2semantic/lit_module.py

+ 1 - 3
fish_speech/models/text2semantic/lit_module.py

@@ -197,9 +197,7 @@ class TextToSemantic(L.LightningModule):
 
             chosen_rewards = self.dpo_beta * positive_codebook_logps.detach()
             rejected_rewards = self.dpo_beta * negative_codebook_logps.detach()
-            reward_accuracy = (
-                (positive_codebook_logps > negative_codebook_logps).float().mean()
-            )
+            reward_accuracy = (chosen_rewards > rejected_rewards).float().mean()
             chosen_rewards, rejected_rewards = (
                 chosen_rewards.mean(),
                 rejected_rewards.mean(),