discriminator.py 1.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344
  1. import torch
  2. from torch import nn
  3. from torch.nn.utils.parametrizations import weight_norm
  4. class Discriminator(nn.Module):
  5. def __init__(self):
  6. super().__init__()
  7. blocks = []
  8. convs = [
  9. (1, 64, (3, 9), 1, (1, 4)),
  10. (64, 128, (3, 9), (1, 2), (1, 4)),
  11. (128, 256, (3, 9), (1, 2), (1, 4)),
  12. (256, 512, (3, 9), (1, 2), (1, 4)),
  13. (512, 1024, (3, 3), 1, (1, 1)),
  14. (1024, 1, (3, 3), 1, (1, 1)),
  15. ]
  16. for idx, (in_channels, out_channels, kernel_size, stride, padding) in enumerate(
  17. convs
  18. ):
  19. blocks.append(
  20. weight_norm(
  21. nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
  22. )
  23. )
  24. if idx != len(convs) - 1:
  25. blocks.append(nn.SiLU(inplace=True))
  26. self.blocks = nn.Sequential(*blocks)
  27. def forward(self, x):
  28. return self.blocks(x[:, None])[:, 0]
  if __name__ == "__main__":
      # Smoke test: report the parameter count in millions, then push one
      # random (1, 128, 1024) tensor through the model and show the result.
      disc = Discriminator()
      n_params = sum(param.numel() for param in disc.parameters())
      print(n_params / 1_000_000)

      sample = torch.randn(1, 128, 1024)
      scores = disc(sample)
      print(scores.shape)
      print(scores)