Procházet zdrojové kódy

Fix x_1 stop gradient

Lengyue před 2 roky
rodič
revize
32c36dc12d
1 změnil soubory, kde provedl 2 přidání a 1 odebrání
  1. 2 1
      fish_speech/models/vqgan/lit_module.py

+ 2 - 1
fish_speech/models/vqgan/lit_module.py

@@ -135,7 +135,8 @@ class VQGAN(L.LightningModule):
         x_1 = self.norm_spec(gt_mels)
 
         if self.reflow_use_shallow:
-            x_1_aux = self.norm_spec(gen_mel)
+            # Detach the gradient of the generated mel
+            x_1_aux = self.norm_spec(gen_mel.detach())
         else:
             x_1_aux = x_1