| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126 |
- # Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license.
- # LICENSE is in incl_licenses directory.
- import torch
- from torch import nn, pow, sin
- from torch.nn import Parameter
class Snake(nn.Module):
    """
    Sine-based periodic activation: ``Snake(x) = x + (1 / a) * sin^2(a * x)``.

    Shape:
        - Input: (B, C, T)
        - Output: (B, C, T), same shape as the input
    Parameters:
        - alpha - per-channel frequency parameter of shape ``(in_features,)``;
          trainable unless ``alpha_trainable`` is False
    References:
        - This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
          https://arxiv.org/abs/2006.08195
    Examples:
        >>> a1 = Snake(256)
        >>> x = torch.randn(4, 256, 100)
        >>> x = a1(x)
    """

    def __init__(
        self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False
    ):
        """
        Initialization.

        INPUT:
            - in_features: number of channels C; one alpha is stored per channel
            - alpha: initial effective frequency (default 1.0),
              higher values = higher-frequency
            - alpha_trainable: when True, alpha is trained along with the rest
              of your model; when False it is frozen
            - alpha_logscale: when True, the stored parameter is log(alpha) and
              the forward pass uses exp(parameter) as the frequency

        RAISES:
            - ValueError: if ``alpha_logscale`` is set and ``alpha`` is not
              strictly positive (such a value has no logarithm)
        """
        super().__init__()
        self.in_features = in_features
        self.alpha_logscale = alpha_logscale

        # `not alpha > 0` (rather than `alpha <= 0`) also rejects NaN.
        if self.alpha_logscale and not alpha > 0:
            raise ValueError(
                "alpha must be strictly positive when alpha_logscale=True"
            )

        initial = torch.ones(in_features) * alpha
        if self.alpha_logscale:
            # Store log(alpha) so that exp(parameter) == alpha in forward().
            # With the default alpha=1.0 this is exactly zeros.
            initial = torch.log(initial)
        self.alpha = Parameter(initial, requires_grad=alpha_trainable)

        # Small constant added to the denominator in forward().
        self.no_div_by_zero = 1e-9

    def forward(self, x):
        """
        Forward pass of the function.
        Applies the function to the input elementwise.
        Snake := x + 1/a * sin^2 (xa)
        """
        # (C,) -> (1, C, 1) so alpha broadcasts against x of shape [B, C, T].
        alpha = self.alpha.unsqueeze(0).unsqueeze(-1)
        if self.alpha_logscale:
            alpha = torch.exp(alpha)
        periodic = torch.sin(x * alpha).pow(2)
        return x + (1.0 / (alpha + self.no_div_by_zero)) * periodic
class SnakeBeta(nn.Module):
    """
    A modified Snake function which uses separate parameters for the frequency
    and the magnitude of the periodic component:
    ``SnakeBeta(x) = x + (1 / b) * sin^2(a * x)``.

    Shape:
        - Input: (B, C, T)
        - Output: (B, C, T), same shape as the input
    Parameters:
        - alpha - per-channel parameter of shape ``(in_features,)`` that controls frequency
        - beta - per-channel parameter of shape ``(in_features,)`` that controls magnitude
    References:
        - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
          https://arxiv.org/abs/2006.08195
    Examples:
        >>> a1 = SnakeBeta(256)
        >>> x = torch.randn(4, 256, 100)
        >>> x = a1(x)
    """

    def __init__(
        self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False
    ):
        """
        Initialization.

        INPUT:
            - in_features: number of channels C; one alpha and one beta are
              stored per channel
            - alpha: initial effective value for BOTH alpha and beta
              (default 1.0); there is no separate beta argument
            - alpha_trainable: when True, alpha and beta are trained along with
              the rest of your model; when False both are frozen
            - alpha_logscale: when True, the stored parameters are log-values
              and the forward pass uses exp(parameter)

        RAISES:
            - ValueError: if ``alpha_logscale`` is set and ``alpha`` is not
              strictly positive (such a value has no logarithm)
        """
        super().__init__()
        self.in_features = in_features
        self.alpha_logscale = alpha_logscale

        # `not alpha > 0` (rather than `alpha <= 0`) also rejects NaN.
        if self.alpha_logscale and not alpha > 0:
            raise ValueError(
                "alpha must be strictly positive when alpha_logscale=True"
            )

        initial = torch.ones(in_features) * alpha
        if self.alpha_logscale:
            # Store log(alpha) so that exp(parameter) == alpha in forward().
            # With the default alpha=1.0 this is exactly zeros.
            initial = torch.log(initial)

        self.alpha = Parameter(initial, requires_grad=alpha_trainable)
        # clone() so beta owns its own storage and is updated independently.
        self.beta = Parameter(initial.clone(), requires_grad=alpha_trainable)

        # Small constant added to the denominator in forward().
        self.no_div_by_zero = 1e-9

    def forward(self, x):
        """
        Forward pass of the function.
        Applies the function to the input elementwise.
        SnakeBeta := x + 1/b * sin^2 (xa)
        """
        # (C,) -> (1, C, 1) so both parameters broadcast against x of shape [B, C, T].
        alpha = self.alpha.unsqueeze(0).unsqueeze(-1)
        beta = self.beta.unsqueeze(0).unsqueeze(-1)
        if self.alpha_logscale:
            alpha = torch.exp(alpha)
            beta = torch.exp(beta)
        periodic = torch.sin(x * alpha).pow(2)
        return x + (1.0 / (beta + self.no_div_by_zero)) * periodic
|