|
|
@@ -2,6 +2,7 @@ import torch
|
|
|
from torch import nn
|
|
|
from torch.nn.utils.parametrizations import weight_norm
|
|
|
|
|
|
+
|
|
|
class Discriminator(nn.Module):
|
|
|
def __init__(self):
|
|
|
super().__init__()
|
|
|
@@ -16,12 +17,18 @@ class Discriminator(nn.Module):
|
|
|
(1024, 1, (3, 3), 1, (1, 1)),
|
|
|
]
|
|
|
|
|
|
- for idx, (in_channels, out_channels, kernel_size, stride, padding) in enumerate(convs):
|
|
|
- blocks.append(weight_norm(nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)))
|
|
|
+ for idx, (in_channels, out_channels, kernel_size, stride, padding) in enumerate(
|
|
|
+ convs
|
|
|
+ ):
|
|
|
+ blocks.append(
|
|
|
+ weight_norm(
|
|
|
+ nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
|
|
|
+ )
|
|
|
+ )
|
|
|
|
|
|
if idx != len(convs) - 1:
|
|
|
blocks.append(nn.SiLU(inplace=True))
|
|
|
-
|
|
|
+
|
|
|
self.blocks = nn.Sequential(*blocks)
|
|
|
|
|
|
def forward(self, x):
|