# activations.py
  1. # Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license.
  2. # LICENSE is in incl_licenses directory.
  3. import torch
  4. from torch import nn, pow, sin
  5. from torch.nn import Parameter
  6. class Snake(nn.Module):
  7. """
  8. Implementation of a sine-based periodic activation function
  9. Shape:
  10. - Input: (B, C, T)
  11. - Output: (B, C, T), same shape as the input
  12. Parameters:
  13. - alpha - trainable parameter
  14. References:
  15. - This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
  16. https://arxiv.org/abs/2006.08195
  17. Examples:
  18. >>> a1 = snake(256)
  19. >>> x = torch.randn(256)
  20. >>> x = a1(x)
  21. """
  22. def __init__(
  23. self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False
  24. ):
  25. """
  26. Initialization.
  27. INPUT:
  28. - in_features: shape of the input
  29. - alpha: trainable parameter
  30. alpha is initialized to 1 by default, higher values = higher-frequency.
  31. alpha will be trained along with the rest of your model.
  32. """
  33. super(Snake, self).__init__()
  34. self.in_features = in_features
  35. # initialize alpha
  36. self.alpha_logscale = alpha_logscale
  37. if self.alpha_logscale: # log scale alphas initialized to zeros
  38. self.alpha = Parameter(torch.zeros(in_features) * alpha)
  39. else: # linear scale alphas initialized to ones
  40. self.alpha = Parameter(torch.ones(in_features) * alpha)
  41. self.alpha.requires_grad = alpha_trainable
  42. self.no_div_by_zero = 0.000000001
  43. def forward(self, x):
  44. """
  45. Forward pass of the function.
  46. Applies the function to the input elementwise.
  47. Snake ∶= x + 1/a * sin^2 (xa)
  48. """
  49. alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
  50. if self.alpha_logscale:
  51. alpha = torch.exp(alpha)
  52. x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
  53. return x
  54. class SnakeBeta(nn.Module):
  55. """
  56. A modified Snake function which uses separate parameters for the magnitude of the periodic components
  57. Shape:
  58. - Input: (B, C, T)
  59. - Output: (B, C, T), same shape as the input
  60. Parameters:
  61. - alpha - trainable parameter that controls frequency
  62. - beta - trainable parameter that controls magnitude
  63. References:
  64. - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
  65. https://arxiv.org/abs/2006.08195
  66. Examples:
  67. >>> a1 = snakebeta(256)
  68. >>> x = torch.randn(256)
  69. >>> x = a1(x)
  70. """
  71. def __init__(
  72. self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False
  73. ):
  74. """
  75. Initialization.
  76. INPUT:
  77. - in_features: shape of the input
  78. - alpha - trainable parameter that controls frequency
  79. - beta - trainable parameter that controls magnitude
  80. alpha is initialized to 1 by default, higher values = higher-frequency.
  81. beta is initialized to 1 by default, higher values = higher-magnitude.
  82. alpha will be trained along with the rest of your model.
  83. """
  84. super(SnakeBeta, self).__init__()
  85. self.in_features = in_features
  86. # initialize alpha
  87. self.alpha_logscale = alpha_logscale
  88. if self.alpha_logscale: # log scale alphas initialized to zeros
  89. self.alpha = Parameter(torch.zeros(in_features) * alpha)
  90. self.beta = Parameter(torch.zeros(in_features) * alpha)
  91. else: # linear scale alphas initialized to ones
  92. self.alpha = Parameter(torch.ones(in_features) * alpha)
  93. self.beta = Parameter(torch.ones(in_features) * alpha)
  94. self.alpha.requires_grad = alpha_trainable
  95. self.beta.requires_grad = alpha_trainable
  96. self.no_div_by_zero = 0.000000001
  97. def forward(self, x):
  98. """
  99. Forward pass of the function.
  100. Applies the function to the input elementwise.
  101. SnakeBeta ∶= x + 1/b * sin^2 (xa)
  102. """
  103. alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
  104. beta = self.beta.unsqueeze(0).unsqueeze(-1)
  105. if self.alpha_logscale:
  106. alpha = torch.exp(alpha)
  107. beta = torch.exp(beta)
  108. x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
  109. return x