Explorar o código

fix:(vqgan) to remove not calculated loss (#126)

* fix:(vqgan) to remove not calculated loss

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Stardust·减 %!s(int64=2) %!d(string=hai) anos
pai
achega
94218857d5
Modificáronse 1 ficheiros con 10 adicións e 3 borrados
  1. 10 3
      fish_speech/models/vqgan/modules/discriminator.py

+ 10 - 3
fish_speech/models/vqgan/modules/discriminator.py

@@ -2,6 +2,7 @@ import torch
 from torch import nn
 from torch.nn.utils.parametrizations import weight_norm
 
+
 class Discriminator(nn.Module):
     def __init__(self):
         super().__init__()
@@ -16,12 +17,18 @@ class Discriminator(nn.Module):
             (1024, 1, (3, 3), 1, (1, 1)),
         ]
 
-        for idx, (in_channels, out_channels, kernel_size, stride, padding) in enumerate(convs):
-            blocks.append(weight_norm(nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)))
+        for idx, (in_channels, out_channels, kernel_size, stride, padding) in enumerate(
+            convs
+        ):
+            blocks.append(
+                weight_norm(
+                    nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
+                )
+            )
 
             if idx != len(convs) - 1:
                 blocks.append(nn.SiLU(inplace=True))
-        
+
         self.blocks = nn.Sequential(*blocks)
 
     def forward(self, x):