common.py

import torch
import torch.nn as nn
import torch.nn.functional as F

class Hswish(nn.Module):
    """Hard-swish: x * relu6(x + 3) / 6."""

    def __init__(self, inplace=True):
        super(Hswish, self).__init__()
        self.inplace = inplace

    def forward(self, x):
        return x * F.relu6(x + 3.0, inplace=self.inplace) / 6.0
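
# A minimal sanity-check sketch (assumption: torch >= 1.6, which provides
# F.hardswish); Hswish above computes the same hardswish(x) = x * relu6(x + 3) / 6.
def _check_hswish():
    x = torch.randn(4)
    assert torch.allclose(Hswish(inplace=False)(x), F.hardswish(x))
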
# Hard-sigmoid: out = max(0, min(1, slope * x + offset)).
# PaddlePaddle: paddle.fluid.layers.hard_sigmoid(x, slope=0.2, offset=0.5, name=None)
class Hsigmoid(nn.Module):
    def __init__(self, inplace=True):
        super(Hsigmoid, self).__init__()
        self.inplace = inplace

    def forward(self, x):
        # torch's hard sigmoid:  F.relu6(x + 3., inplace=self.inplace) / 6.
        # paddle's hard sigmoid: F.relu6(1.2 * x + 3., inplace=self.inplace) / 6.
        # The Paddle variant (slope=0.2, offset=0.5) is kept here.
        return F.relu6(1.2 * x + 3.0, inplace=self.inplace) / 6.0
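
# A sketch of the slope difference (assumption: torch >= 1.6 for F.hardsigmoid).
# Paddle's hard_sigmoid is clip(0.2 * x + 0.5, 0, 1) == relu6(1.2 * x + 3) / 6,
# i.e. the module above; torch's F.hardsigmoid uses slope 1/6 and differs.
def _check_hsigmoid():
    x = torch.randn(4)
    paddle_style = torch.clamp(0.2 * x + 0.5, min=0.0, max=1.0)
    assert torch.allclose(Hsigmoid(inplace=False)(x), paddle_style)
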
class GELU(nn.Module):
    def __init__(self, inplace=True):
        super(GELU, self).__init__()
        self.inplace = inplace  # kept for interface symmetry; F.gelu has no inplace mode

    def forward(self, x):
        return F.gelu(x)
class Swish(nn.Module):
    def __init__(self, inplace=True):
        super(Swish, self).__init__()
        self.inplace = inplace

    def forward(self, x):
        if self.inplace:
            # In-place multiply saves memory, but mutates x and can break
            # autograd when x is needed for the backward pass.
            x.mul_(torch.sigmoid(x))
            return x
        else:
            return x * torch.sigmoid(x)
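
# Swish(x) = x * sigmoid(x) is what torch ships as SiLU (assumption: torch >= 1.7
# for F.silu); a minimal check against the out-of-place branch above.
def _check_swish():
    x = torch.randn(4)
    assert torch.allclose(Swish(inplace=False)(x), F.silu(x))
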
class Activation(nn.Module):
    """Factory that maps an activation name to its module."""

    def __init__(self, act_type, inplace=True):
        super(Activation, self).__init__()
        act_type = act_type.lower()
        if act_type == "relu":
            self.act = nn.ReLU(inplace=inplace)
        elif act_type == "relu6":
            self.act = nn.ReLU6(inplace=inplace)
        elif act_type == "sigmoid":
            raise NotImplementedError  # deliberate stub in the original
        elif act_type == "hard_sigmoid":
            self.act = Hsigmoid(inplace)
        elif act_type == "hard_swish":
            self.act = Hswish(inplace=inplace)
        elif act_type == "leakyrelu":
            self.act = nn.LeakyReLU(inplace=inplace)
        elif act_type == "gelu":
            self.act = GELU(inplace=inplace)
        elif act_type == "swish":
            self.act = Swish(inplace=inplace)
        else:
            raise NotImplementedError("unsupported act_type: %s" % act_type)

    def forward(self, inputs):
        return self.act(inputs)
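
# A hedged usage sketch of the Activation factory plus the checks above;
# the loop covers every name the factory accepts ("sigmoid" is a stub and skipped).
if __name__ == "__main__":
    _check_hswish()
    _check_hsigmoid()
    _check_swish()
    x = torch.randn(2, 3)
    for name in ("relu", "relu6", "hard_sigmoid", "hard_swish",
                 "leakyrelu", "gelu", "swish"):
        print(name, Activation(name, inplace=False)(x).shape)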