@@ -11,7 +11,7 @@ trainer:
precision: bf16-mixed
max_steps: 1_000_000
val_check_interval: 1000
- strategy: ddp_find_unused_parameters_true
+ strategy: ddp
sample_rate: 44100
hop_length: 512
@@ -148,7 +148,7 @@ class VQGAN(L.LightningModule):
v_pred = self.reflow(
x_t,
1000 * t,
- vq_recon_features,
+ vq_recon_features.detach(), # Stop gradients, avoid reflow to destroy the VQ
)
# Log L2 loss with