parsenet.py

  1. """Modified from https://github.com/chaofengc/PSFRGAN
  2. """
  3. import numpy as np
  4. import torch.nn as nn
  5. from torch.nn import functional as F
  6. class NormLayer(nn.Module):
  7. """Normalization Layers.
  8. Args:
  9. channels: input channels, for batch norm and instance norm.
  10. input_size: input shape without batch size, for layer norm.
  11. """
  12. def __init__(self, channels, normalize_shape=None, norm_type="bn"):
  13. super(NormLayer, self).__init__()
  14. norm_type = norm_type.lower()
  15. self.norm_type = norm_type
  16. if norm_type == "bn":
  17. self.norm = nn.BatchNorm2d(channels, affine=True)
  18. elif norm_type == "in":
  19. self.norm = nn.InstanceNorm2d(channels, affine=False)
  20. elif norm_type == "gn":
  21. self.norm = nn.GroupNorm(32, channels, affine=True)
  22. elif norm_type == "pixel":
  23. self.norm = lambda x: F.normalize(x, p=2, dim=1)
  24. elif norm_type == "layer":
  25. self.norm = nn.LayerNorm(normalize_shape)
  26. elif norm_type == "none":
  27. self.norm = lambda x: x * 1.0
  28. else:
  29. assert 1 == 0, f"Norm type {norm_type} not support."
  30. def forward(self, x, ref=None):
  31. if self.norm_type == "spade":
  32. return self.norm(x, ref)
  33. else:
  34. return self.norm(x)


class ReluLayer(nn.Module):
    """Relu Layer.

    Args:
        relu_type: type of relu layer, candidates are
            - ReLU
            - LeakyReLU: default relu slope 0.2
            - PRelu
            - SELU
            - none: direct pass
    """

    def __init__(self, channels, relu_type="relu"):
        super(ReluLayer, self).__init__()
        relu_type = relu_type.lower()
        if relu_type == "relu":
            self.func = nn.ReLU(True)
        elif relu_type == "leakyrelu":
            self.func = nn.LeakyReLU(0.2, inplace=True)
        elif relu_type == "prelu":
            self.func = nn.PReLU(channels)
        elif relu_type == "selu":
            self.func = nn.SELU(True)
        elif relu_type == "none":
            self.func = lambda x: x * 1.0
        else:
            assert 1 == 0, f"Relu type {relu_type} not supported."

    def forward(self, x):
        return self.func(x)


class ConvLayer(nn.Module):

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=3,
        scale="none",
        norm_type="none",
        relu_type="none",
        use_pad=True,
        bias=True,
    ):
        super(ConvLayer, self).__init__()
        self.use_pad = use_pad
        self.norm_type = norm_type
        if norm_type in ["bn"]:
            # BatchNorm provides its own learnable shift, so the conv bias is redundant.
            bias = False

        stride = 2 if scale == "down" else 1

        self.scale_func = lambda x: x
        if scale == "up":
            self.scale_func = lambda x: nn.functional.interpolate(
                x, scale_factor=2, mode="nearest"
            )

        self.reflection_pad = nn.ReflectionPad2d(int(np.ceil((kernel_size - 1.0) / 2)))
        self.conv2d = nn.Conv2d(
            in_channels, out_channels, kernel_size, stride, bias=bias
        )
        self.relu = ReluLayer(out_channels, relu_type)
        self.norm = NormLayer(out_channels, norm_type=norm_type)

    def forward(self, x):
        out = self.scale_func(x)
        if self.use_pad:
            out = self.reflection_pad(out)
        out = self.conv2d(out)
        out = self.norm(out)
        out = self.relu(out)
        return out


class ResidualBlock(nn.Module):
    """Residual block recommended in: http://torch.ch/blog/2016/02/04/resnets.html
    """

    def __init__(self, c_in, c_out, relu_type="prelu", norm_type="bn", scale="none"):
        super(ResidualBlock, self).__init__()

        if scale == "none" and c_in == c_out:
            self.shortcut_func = lambda x: x
        else:
            self.shortcut_func = ConvLayer(c_in, c_out, 3, scale)

        scale_config_dict = {
            "down": ["none", "down"],
            "up": ["up", "none"],
            "none": ["none", "none"],
        }
        scale_conf = scale_config_dict[scale]

        self.conv1 = ConvLayer(
            c_in, c_out, 3, scale_conf[0], norm_type=norm_type, relu_type=relu_type
        )
        self.conv2 = ConvLayer(
            c_out, c_out, 3, scale_conf[1], norm_type=norm_type, relu_type="none"
        )

    def forward(self, x):
        identity = self.shortcut_func(x)
        res = self.conv1(x)
        res = self.conv2(res)
        return identity + res


class ParseNet(nn.Module):

    def __init__(
        self,
        in_size=128,
        out_size=128,
        min_feat_size=32,
        base_ch=64,
        parsing_ch=19,
        res_depth=10,
        relu_type="LeakyReLU",
        norm_type="bn",
        ch_range=[32, 256],
    ):
        super().__init__()
        self.res_depth = res_depth
        act_args = {"norm_type": norm_type, "relu_type": relu_type}
        min_ch, max_ch = ch_range
        ch_clip = lambda x: max(min_ch, min(x, max_ch))  # noqa: E731
        min_feat_size = min(in_size, min_feat_size)

        down_steps = int(np.log2(in_size // min_feat_size))
        up_steps = int(np.log2(out_size // min_feat_size))

        # =============== define encoder-body-decoder ====================
        self.encoder = []
        self.encoder.append(ConvLayer(3, base_ch, 3, scale="none"))
        head_ch = base_ch
        for i in range(down_steps):
            cin, cout = ch_clip(head_ch), ch_clip(head_ch * 2)
            self.encoder.append(ResidualBlock(cin, cout, scale="down", **act_args))
            head_ch = head_ch * 2

        self.body = []
        for i in range(res_depth):
            self.body.append(
                ResidualBlock(ch_clip(head_ch), ch_clip(head_ch), **act_args)
            )

        self.decoder = []
        for i in range(up_steps):
            cin, cout = ch_clip(head_ch), ch_clip(head_ch // 2)
            self.decoder.append(ResidualBlock(cin, cout, scale="up", **act_args))
            head_ch = head_ch // 2

        self.encoder = nn.Sequential(*self.encoder)
        self.body = nn.Sequential(*self.body)
        self.decoder = nn.Sequential(*self.decoder)
        self.out_img_conv = ConvLayer(ch_clip(head_ch), 3)
        self.out_mask_conv = ConvLayer(ch_clip(head_ch), parsing_ch)

    def forward(self, x):
        feat = self.encoder(x)
        x = feat + self.body(feat)  # long skip connection around the residual body
        x = self.decoder(x)
        out_img = self.out_img_conv(x)
        out_mask = self.out_mask_conv(x)
        return out_mask, out_img
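

if __name__ == "__main__":
    # Minimal smoke-test sketch (not part of the original module): it assumes the
    # default ParseNet configuration and only checks the output shapes of a single
    # forward pass on a random 128x128 RGB tensor.
    import torch

    net = ParseNet()
    net.eval()
    with torch.no_grad():
        dummy = torch.rand(1, 3, 128, 128)
        out_mask, out_img = net(dummy)
    # Expected: parsing logits with `parsing_ch` channels and a reconstructed image.
    print(out_mask.shape)  # torch.Size([1, 19, 128, 128])
    print(out_img.shape)   # torch.Size([1, 3, 128, 128])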