Bladeren bron

Fix stop gradient

Lengyue 2 jaren geleden
bovenliggende
commit
c08afee273
2 gewijzigde bestanden met toevoegingen van 2 en 2 verwijderingen
  1. 1 1
      fish_speech/configs/vqgan_pretrain.yaml
  2. 1 1
      fish_speech/models/vqgan/lit_module.py

+ 1 - 1
fish_speech/configs/vqgan_pretrain.yaml

@@ -11,7 +11,7 @@ trainer:
   precision: bf16-mixed
   max_steps: 1_000_000
   val_check_interval: 1000
-  strategy: ddp_find_unused_parameters_true
+  strategy: ddp
 
 sample_rate: 44100
 hop_length: 512

+ 1 - 1
fish_speech/models/vqgan/lit_module.py

@@ -148,7 +148,7 @@ class VQGAN(L.LightningModule):
         v_pred = self.reflow(
             x_t,
             1000 * t,
-            vq_recon_features,
+            vq_recon_features.detach(),  # Stop gradients, avoid reflow to destroy the VQ
         )
 
         # Log L2 loss with