- """Modified from https://github.com/chaofengc/PSFRGAN
- """
- import numpy as np
- import torch.nn as nn
- from torch.nn import functional as F
class NormLayer(nn.Module):
    """Normalization layers.

    Args:
        channels: number of input channels, for batch, instance and group norm.
        normalize_shape: input shape without batch size, for layer norm.
    """

    def __init__(self, channels, normalize_shape=None, norm_type="bn"):
        super(NormLayer, self).__init__()
        norm_type = norm_type.lower()
        self.norm_type = norm_type
        if norm_type == "bn":
            self.norm = nn.BatchNorm2d(channels, affine=True)
        elif norm_type == "in":
            self.norm = nn.InstanceNorm2d(channels, affine=False)
        elif norm_type == "gn":
            self.norm = nn.GroupNorm(32, channels, affine=True)
        elif norm_type == "pixel":
            self.norm = lambda x: F.normalize(x, p=2, dim=1)
        elif norm_type == "layer":
            self.norm = nn.LayerNorm(normalize_shape)
        elif norm_type == "none":
            self.norm = lambda x: x * 1.0
        else:
            raise NotImplementedError(f"Norm type {norm_type} is not supported.")

    def forward(self, x, ref=None):
        if self.norm_type == "spade":
            # SPADE-style norms take a reference input; no such norm is built
            # in this module, so this branch is kept only for compatibility
            # with the original PSFRGAN code.
            return self.norm(x, ref)
        else:
            return self.norm(x)


class ReluLayer(nn.Module):
    """ReLU layer.

    Args:
        channels: number of input channels, only used by PReLU.
        relu_type: type of relu layer, candidates are
            - ReLU
            - LeakyReLU: negative slope 0.2
            - PReLU
            - SELU
            - none: direct pass-through
    """

    def __init__(self, channels, relu_type="relu"):
        super(ReluLayer, self).__init__()
        relu_type = relu_type.lower()
        if relu_type == "relu":
            self.func = nn.ReLU(True)
        elif relu_type == "leakyrelu":
            self.func = nn.LeakyReLU(0.2, inplace=True)
        elif relu_type == "prelu":
            self.func = nn.PReLU(channels)
        elif relu_type == "selu":
            self.func = nn.SELU(True)
        elif relu_type == "none":
            self.func = lambda x: x * 1.0
        else:
            raise NotImplementedError(f"Relu type {relu_type} is not supported.")

    def forward(self, x):
        return self.func(x)


class ConvLayer(nn.Module):

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=3,
        scale="none",
        norm_type="none",
        relu_type="none",
        use_pad=True,
        bias=True,
    ):
        super(ConvLayer, self).__init__()
        self.use_pad = use_pad
        self.norm_type = norm_type
        if norm_type in ["bn"]:
            # BatchNorm has its own learnable shift, so the conv bias is redundant.
            bias = False

        # "down" halves the resolution via stride, "up" doubles it via nearest
        # interpolation before the convolution.
        stride = 2 if scale == "down" else 1
        self.scale_func = lambda x: x
        if scale == "up":
            self.scale_func = lambda x: F.interpolate(
                x, scale_factor=2, mode="nearest"
            )

        self.reflection_pad = nn.ReflectionPad2d(int(np.ceil((kernel_size - 1.0) / 2)))
        self.conv2d = nn.Conv2d(
            in_channels, out_channels, kernel_size, stride, bias=bias
        )
        self.relu = ReluLayer(out_channels, relu_type)
        self.norm = NormLayer(out_channels, norm_type=norm_type)

    def forward(self, x):
        out = self.scale_func(x)
        if self.use_pad:
            out = self.reflection_pad(out)
        out = self.conv2d(out)
        out = self.norm(out)
        out = self.relu(out)
        return out


class ResidualBlock(nn.Module):
    """Residual block recommended in: http://torch.ch/blog/2016/02/04/resnets.html
    """

    def __init__(self, c_in, c_out, relu_type="prelu", norm_type="bn", scale="none"):
        super(ResidualBlock, self).__init__()

        if scale == "none" and c_in == c_out:
            self.shortcut_func = lambda x: x
        else:
            self.shortcut_func = ConvLayer(c_in, c_out, 3, scale)

        # The resolution change is applied by the first conv when upsampling
        # and by the second conv when downsampling.
        scale_config_dict = {
            "down": ["none", "down"],
            "up": ["up", "none"],
            "none": ["none", "none"],
        }
        scale_conf = scale_config_dict[scale]

        self.conv1 = ConvLayer(
            c_in, c_out, 3, scale_conf[0], norm_type=norm_type, relu_type=relu_type
        )
        self.conv2 = ConvLayer(
            c_out, c_out, 3, scale_conf[1], norm_type=norm_type, relu_type="none"
        )

    def forward(self, x):
        identity = self.shortcut_func(x)
        res = self.conv1(x)
        res = self.conv2(res)
        return identity + res


class ParseNet(nn.Module):
    """Encoder-body-decoder network that predicts a face parsing mask
    (``parsing_ch`` channels) together with an RGB image output."""

    def __init__(
        self,
        in_size=128,
        out_size=128,
        min_feat_size=32,
        base_ch=64,
        parsing_ch=19,
        res_depth=10,
        relu_type="LeakyReLU",
        norm_type="bn",
        ch_range=[32, 256],
    ):
        super().__init__()
        self.res_depth = res_depth
        act_args = {"norm_type": norm_type, "relu_type": relu_type}
        min_ch, max_ch = ch_range

        ch_clip = lambda x: max(min_ch, min(x, max_ch))  # noqa: E731
        min_feat_size = min(in_size, min_feat_size)

        down_steps = int(np.log2(in_size // min_feat_size))
        up_steps = int(np.log2(out_size // min_feat_size))

        # =============== define encoder-body-decoder ====================
        self.encoder = []
        self.encoder.append(ConvLayer(3, base_ch, 3, "none"))
        head_ch = base_ch
        for i in range(down_steps):
            cin, cout = ch_clip(head_ch), ch_clip(head_ch * 2)
            self.encoder.append(ResidualBlock(cin, cout, scale="down", **act_args))
            head_ch = head_ch * 2

        self.body = []
        for i in range(res_depth):
            self.body.append(
                ResidualBlock(ch_clip(head_ch), ch_clip(head_ch), **act_args)
            )

        self.decoder = []
        for i in range(up_steps):
            cin, cout = ch_clip(head_ch), ch_clip(head_ch // 2)
            self.decoder.append(ResidualBlock(cin, cout, scale="up", **act_args))
            head_ch = head_ch // 2

        self.encoder = nn.Sequential(*self.encoder)
        self.body = nn.Sequential(*self.body)
        self.decoder = nn.Sequential(*self.decoder)
        self.out_img_conv = ConvLayer(ch_clip(head_ch), 3)
        self.out_mask_conv = ConvLayer(ch_clip(head_ch), parsing_ch)

    def forward(self, x):
        feat = self.encoder(x)
        x = feat + self.body(feat)
        x = self.decoder(x)
        out_img = self.out_img_conv(x)
        out_mask = self.out_mask_conv(x)
        return out_mask, out_img
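

# Minimal usage sketch: builds the default 128x128 configuration and runs a
# dummy batch to illustrate the expected output shapes (a shape check only, no
# pretrained weights assumed). With the defaults, down_steps = up_steps = 2,
# so both outputs come back at the input resolution.
if __name__ == "__main__":
    import torch

    net = ParseNet(in_size=128, out_size=128)
    net.eval()
    dummy = torch.randn(1, 3, 128, 128)
    with torch.no_grad():
        out_mask, out_img = net(dummy)
    print(out_mask.shape)  # torch.Size([1, 19, 128, 128])
    print(out_img.shape)  # torch.Size([1, 3, 128, 128])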