Explorar o código

Fix x_1 stop gradient

Lengyue %!s(int64=2) %!d(string=hai) anos
pai
achega
32c36dc12d
Modificáronse 1 ficheiros con 2 adicións e 1 borrados
  1. 2 1
      fish_speech/models/vqgan/lit_module.py

+ 2 - 1
fish_speech/models/vqgan/lit_module.py

@@ -135,7 +135,8 @@ class VQGAN(L.LightningModule):
         x_1 = self.norm_spec(gt_mels)
 
         if self.reflow_use_shallow:
-            x_1_aux = self.norm_spec(gen_mel)
+            # Detach the gradient of the generated mel
+            x_1_aux = self.norm_spec(gen_mel.detach())
         else:
             x_1_aux = x_1