Explorar el Código

Fix x_1 stop gradient

Lengyue hace 2 años
padre
commit
32c36dc12d
Se han modificado 1 ficheros con 2 adiciones y 1 borrados
  1. 2 1
      fish_speech/models/vqgan/lit_module.py

+ 2 - 1
fish_speech/models/vqgan/lit_module.py

@@ -135,7 +135,8 @@ class VQGAN(L.LightningModule):
         x_1 = self.norm_spec(gt_mels)
 
         if self.reflow_use_shallow:
-            x_1_aux = self.norm_spec(gen_mel)
+            # Detach the gradient of the generated mel
+            x_1_aux = self.norm_spec(gen_mel.detach())
         else:
             x_1_aux = x_1