# normalization.py
import torch
import torch.nn as nn
import torch.nn.functional as F
  4. class LayerNorm(nn.Module):
  5. def __init__(self, channels, eps=1e-5):
  6. super().__init__()
  7. self.channels = channels
  8. self.eps = eps
  9. self.gamma = nn.Parameter(torch.ones(channels))
  10. self.beta = nn.Parameter(torch.zeros(channels))
  11. def forward(self, x: torch.Tensor):
  12. x = F.layer_norm(x.mT, (self.channels,), self.gamma, self.beta, self.eps)
  13. return x.mT
  14. class CondLayerNorm(nn.Module):
  15. def __init__(self, channels, eps=1e-5, cond_channels=0):
  16. super().__init__()
  17. self.channels = channels
  18. self.eps = eps
  19. self.linear_gamma = nn.Linear(cond_channels, channels)
  20. self.linear_beta = nn.Linear(cond_channels, channels)
  21. def forward(self, x: torch.Tensor, cond: torch.Tensor):
  22. gamma = self.linear_gamma(cond)
  23. beta = self.linear_beta(cond)
  24. x = F.layer_norm(x.mT, (self.channels,), gamma, beta, self.eps)
  25. return x.mT