RecMv1_enhance.py

import torch
import torch.nn as nn
import torch.nn.functional as F

from .common import Activation


class ConvBNLayer(nn.Module):
    """Convolution followed by batch normalization and an optional activation."""

    def __init__(
        self,
        num_channels,
        filter_size,
        num_filters,
        stride,
        padding,
        channels=None,  # unused; kept for config compatibility
        num_groups=1,
        act="hard_swish",
    ):
        super(ConvBNLayer, self).__init__()
        self.act = act
        self._conv = nn.Conv2d(
            in_channels=num_channels,
            out_channels=num_filters,
            kernel_size=filter_size,
            stride=stride,
            padding=padding,
            groups=num_groups,
            bias=False,
        )
        self._batch_norm = nn.BatchNorm2d(num_filters)
        if self.act is not None:
            self._act = Activation(act_type=act, inplace=True)

    def forward(self, inputs):
        y = self._conv(inputs)
        y = self._batch_norm(y)
        if self.act is not None:
            y = self._act(y)
        return y


class DepthwiseSeparable(nn.Module):
    """Depthwise convolution followed by a 1x1 pointwise convolution, with an
    optional squeeze-and-excitation block in between. All widths are scaled by
    `scale`."""

    def __init__(
        self,
        num_channels,
        num_filters1,
        num_filters2,
        num_groups,
        stride,
        scale,
        dw_size=3,
        padding=1,
        use_se=False,
    ):
        super(DepthwiseSeparable, self).__init__()
        self.use_se = use_se
        # Depthwise convolution: one group per (scaled) input channel.
        self._depthwise_conv = ConvBNLayer(
            num_channels=num_channels,
            num_filters=int(num_filters1 * scale),
            filter_size=dw_size,
            stride=stride,
            padding=padding,
            num_groups=int(num_groups * scale),
        )
        if use_se:
            self._se = SEModule(int(num_filters1 * scale))
        # Pointwise (1x1) convolution mixes channels.
        self._pointwise_conv = ConvBNLayer(
            num_channels=int(num_filters1 * scale),
            filter_size=1,
            num_filters=int(num_filters2 * scale),
            stride=1,
            padding=0,
        )

    def forward(self, inputs):
        y = self._depthwise_conv(inputs)
        if self.use_se:
            y = self._se(y)
        y = self._pointwise_conv(y)
        return y


class MobileNetV1Enhance(nn.Module):
    """Enhanced MobileNetV1 backbone for text recognition: deeper stages use
    5x5 depthwise kernels and SE blocks, and the (2, 1) strides downsample the
    feature-map height while preserving its width."""

    def __init__(
        self,
        in_channels=3,
        scale=0.5,
        last_conv_stride=1,
        last_pool_type="max",
        **kwargs,
    ):
        super().__init__()
        self.scale = scale
        self.block_list = []

        # Stem: standard 3x3 convolution with stride 2.
        self.conv1 = ConvBNLayer(
            num_channels=in_channels,
            filter_size=3,
            channels=3,
            num_filters=int(32 * scale),
            stride=2,
            padding=1,
        )

        conv2_1 = DepthwiseSeparable(
            num_channels=int(32 * scale),
            num_filters1=32,
            num_filters2=64,
            num_groups=32,
            stride=1,
            scale=scale,
        )
        self.block_list.append(conv2_1)

        conv2_2 = DepthwiseSeparable(
            num_channels=int(64 * scale),
            num_filters1=64,
            num_filters2=128,
            num_groups=64,
            stride=1,
            scale=scale,
        )
        self.block_list.append(conv2_2)

        conv3_1 = DepthwiseSeparable(
            num_channels=int(128 * scale),
            num_filters1=128,
            num_filters2=128,
            num_groups=128,
            stride=1,
            scale=scale,
        )
        self.block_list.append(conv3_1)

        conv3_2 = DepthwiseSeparable(
            num_channels=int(128 * scale),
            num_filters1=128,
            num_filters2=256,
            num_groups=128,
            stride=(2, 1),
            scale=scale,
        )
        self.block_list.append(conv3_2)

        conv4_1 = DepthwiseSeparable(
            num_channels=int(256 * scale),
            num_filters1=256,
            num_filters2=256,
            num_groups=256,
            stride=1,
            scale=scale,
        )
        self.block_list.append(conv4_1)

        conv4_2 = DepthwiseSeparable(
            num_channels=int(256 * scale),
            num_filters1=256,
            num_filters2=512,
            num_groups=256,
            stride=(2, 1),
            scale=scale,
        )
        self.block_list.append(conv4_2)

        # Five identical 512-channel blocks with 5x5 depthwise kernels.
        for _ in range(5):
            conv5 = DepthwiseSeparable(
                num_channels=int(512 * scale),
                num_filters1=512,
                num_filters2=512,
                num_groups=512,
                stride=1,
                dw_size=5,
                padding=2,
                scale=scale,
                use_se=False,
            )
            self.block_list.append(conv5)

        conv5_6 = DepthwiseSeparable(
            num_channels=int(512 * scale),
            num_filters1=512,
            num_filters2=1024,
            num_groups=512,
            stride=(2, 1),
            dw_size=5,
            padding=2,
            scale=scale,
            use_se=True,
        )
        self.block_list.append(conv5_6)

        conv6 = DepthwiseSeparable(
            num_channels=int(1024 * scale),
            num_filters1=1024,
            num_filters2=1024,
            num_groups=1024,
            stride=last_conv_stride,
            dw_size=5,
            padding=2,
            use_se=True,
            scale=scale,
        )
        self.block_list.append(conv6)

        self.block_list = nn.Sequential(*self.block_list)
        if last_pool_type == "avg":
            self.pool = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
        else:
            self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        self.out_channels = int(1024 * scale)

    def forward(self, inputs):
        y = self.conv1(inputs)
        y = self.block_list(y)
        y = self.pool(y)
        return y


def hardsigmoid(x):
    # Piecewise-linear approximation of the sigmoid: relu6(x + 3) / 6.
    return F.relu6(x + 3.0, inplace=True) / 6.0


class SEModule(nn.Module):
    """Squeeze-and-excitation block: global average pooling, a 1x1 bottleneck,
    and hard-sigmoid gating of the input channels."""

    def __init__(self, channel, reduction=4):
        super(SEModule, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.conv1 = nn.Conv2d(
            in_channels=channel,
            out_channels=channel // reduction,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=True,
        )
        self.conv2 = nn.Conv2d(
            in_channels=channel // reduction,
            out_channels=channel,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=True,
        )

    def forward(self, inputs):
        outputs = self.avg_pool(inputs)
        outputs = self.conv1(outputs)
        outputs = F.relu(outputs)
        outputs = self.conv2(outputs)
        outputs = hardsigmoid(outputs)
        x = torch.mul(inputs, outputs)
        return x
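

# Minimal usage sketch (not part of the original file): builds the backbone and
# runs a dummy forward pass. The (1, 3, 32, 320) input shape is an assumption
# based on typical text-recognition inputs, and running this standalone assumes
# the sibling `common` module providing `Activation` is importable.
if __name__ == "__main__":
    model = MobileNetV1Enhance(in_channels=3, scale=0.5, last_pool_type="max")
    dummy = torch.randn(1, 3, 32, 320)
    features = model(dummy)
    # The channel dimension of `features` equals model.out_channels == int(1024 * scale).
    print(features.shape, model.out_channels)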