Parcourir la source

Fix x_1 stop gradient

Lengyue il y a 2 ans
Parent
commit
32c36dc12d
1 fichiers modifiés avec 2 ajouts et 1 suppressions
  1. 2 1
      fish_speech/models/vqgan/lit_module.py

+ 2 - 1
fish_speech/models/vqgan/lit_module.py

@@ -135,7 +135,8 @@ class VQGAN(L.LightningModule):
         x_1 = self.norm_spec(gt_mels)
 
         if self.reflow_use_shallow:
-            x_1_aux = self.norm_spec(gen_mel)
+            # Detach the gradient of the generated mel
+            x_1_aux = self.norm_spec(gen_mel.detach())
         else:
             x_1_aux = x_1