|
@@ -135,7 +135,8 @@ class VQGAN(L.LightningModule):
|
|
|
x_1 = self.norm_spec(gt_mels)
|
|
x_1 = self.norm_spec(gt_mels)
|
|
|
|
|
|
|
|
if self.reflow_use_shallow:
|
|
if self.reflow_use_shallow:
|
|
|
- x_1_aux = self.norm_spec(gen_mel)
|
|
|
|
|
|
|
+ # Detach the gradient of the generated mel
|
|
|
|
|
+ x_1_aux = self.norm_spec(gen_mel.detach())
|
|
|
else:
|
|
else:
|
|
|
x_1_aux = x_1
|
|
x_1_aux = x_1
|
|
|
|
|
|