import math
import random

import torch
from torch import nn
from torch.nn import functional as F

from sorawm.iopaint.plugins.basicsr.arch_util import default_init_weights


class NormStyleCode(nn.Module):
    def forward(self, x):
        """Normalize the style codes.

        Args:
            x (Tensor): Style codes with shape (b, c).

        Returns:
            Tensor: Normalized tensor.
        """
        return x * torch.rsqrt(torch.mean(x**2, dim=1, keepdim=True) + 1e-8)
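# Usage sketch (illustrative, not part of the original module): NormStyleCode
# rescales each style code to unit RMS before it enters the style MLP, e.g.
#   codes = torch.randn(4, 512)
#   normed = NormStyleCode()(codes)  # (4, 512), per-row RMS ~ 1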


class ModulatedConv2d(nn.Module):
    """Modulated Conv2d used in StyleGAN2.

    There is no bias in ModulatedConv2d.

    Args:
        in_channels (int): Channel number of the input.
        out_channels (int): Channel number of the output.
        kernel_size (int): Size of the convolving kernel.
        num_style_feat (int): Channel number of style features.
        demodulate (bool): Whether to demodulate in the conv layer. Default: True.
        sample_mode (str | None): Indicating 'upsample', 'downsample' or None. Default: None.
        eps (float): A value added to the denominator for numerical stability. Default: 1e-8.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        num_style_feat,
        demodulate=True,
        sample_mode=None,
        eps=1e-8,
    ):
        super(ModulatedConv2d, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.demodulate = demodulate
        self.sample_mode = sample_mode
        self.eps = eps
        # modulation inside each modulated conv
        self.modulation = nn.Linear(num_style_feat, in_channels, bias=True)
        # initialization
        default_init_weights(
            self.modulation,
            scale=1,
            bias_fill=1,
            a=0,
            mode="fan_in",
            nonlinearity="linear",
        )
        self.weight = nn.Parameter(
            torch.randn(1, out_channels, in_channels, kernel_size, kernel_size)
            / math.sqrt(in_channels * kernel_size**2)
        )
        self.padding = kernel_size // 2

    def forward(self, x, style):
        """Forward function.

        Args:
            x (Tensor): Tensor with shape (b, c, h, w).
            style (Tensor): Tensor with shape (b, num_style_feat).

        Returns:
            Tensor: Modulated tensor after convolution.
        """
        b, c, h, w = x.shape  # c = c_in
        # weight modulation
        style = self.modulation(style).view(b, 1, c, 1, 1)
        # self.weight: (1, c_out, c_in, k, k); style: (b, 1, c, 1, 1)
        weight = self.weight * style  # (b, c_out, c_in, k, k)
        if self.demodulate:
            demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + self.eps)
            weight = weight * demod.view(b, self.out_channels, 1, 1, 1)
        weight = weight.view(
            b * self.out_channels, c, self.kernel_size, self.kernel_size
        )
        # upsample or downsample if necessary
        if self.sample_mode == "upsample":
            x = F.interpolate(x, scale_factor=2, mode="bilinear", align_corners=False)
        elif self.sample_mode == "downsample":
            x = F.interpolate(x, scale_factor=0.5, mode="bilinear", align_corners=False)
        b, c, h, w = x.shape
        x = x.view(1, b * c, h, w)
        # weight: (b*c_out, c_in, k, k), groups=b
        out = F.conv2d(x, weight, padding=self.padding, groups=b)
        out = out.view(b, self.out_channels, *out.shape[2:4])
        return out

    def __repr__(self):
        return (
            f"{self.__class__.__name__}(in_channels={self.in_channels}, out_channels={self.out_channels}, "
            f"kernel_size={self.kernel_size}, demodulate={self.demodulate}, sample_mode={self.sample_mode})"
        )


class StyleConv(nn.Module):
    """Style conv used in StyleGAN2.

    Args:
        in_channels (int): Channel number of the input.
        out_channels (int): Channel number of the output.
        kernel_size (int): Size of the convolving kernel.
        num_style_feat (int): Channel number of style features.
        demodulate (bool): Whether to demodulate in the conv layer. Default: True.
        sample_mode (str | None): Indicating 'upsample', 'downsample' or None. Default: None.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        num_style_feat,
        demodulate=True,
        sample_mode=None,
    ):
        super(StyleConv, self).__init__()
        self.modulated_conv = ModulatedConv2d(
            in_channels,
            out_channels,
            kernel_size,
            num_style_feat,
            demodulate=demodulate,
            sample_mode=sample_mode,
        )
        self.weight = nn.Parameter(torch.zeros(1))  # for noise injection
        self.bias = nn.Parameter(torch.zeros(1, out_channels, 1, 1))
        self.activate = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x, style, noise=None):
        # modulate
        out = self.modulated_conv(x, style) * 2**0.5  # for conversion
        # noise injection
        if noise is None:
            b, _, h, w = out.shape
            noise = out.new_empty(b, 1, h, w).normal_()
        out = out + self.weight * noise
        # add bias
        out = out + self.bias
        # activation
        out = self.activate(out)
        return out
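# Usage sketch (illustrative): a fresh single-channel noise map is drawn and
# broadcast over channels whenever `noise` is None.
#   sconv = StyleConv(64, 64, kernel_size=3, num_style_feat=512)
#   feat = sconv(torch.randn(2, 64, 32, 32), torch.randn(2, 512))  # (2, 64, 32, 32)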


class ToRGB(nn.Module):
    """To RGB (image space) from features.

    Args:
        in_channels (int): Channel number of input.
        num_style_feat (int): Channel number of style features.
        upsample (bool): Whether to upsample. Default: True.
    """

    def __init__(self, in_channels, num_style_feat, upsample=True):
        super(ToRGB, self).__init__()
        self.upsample = upsample
        self.modulated_conv = ModulatedConv2d(
            in_channels,
            3,
            kernel_size=1,
            num_style_feat=num_style_feat,
            demodulate=False,
            sample_mode=None,
        )
        self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))

    def forward(self, x, style, skip=None):
        """Forward function.

        Args:
            x (Tensor): Feature tensor with shape (b, c, h, w).
            style (Tensor): Tensor with shape (b, num_style_feat).
            skip (Tensor): Base/skip tensor. Default: None.

        Returns:
            Tensor: RGB images.
        """
        out = self.modulated_conv(x, style)
        out = out + self.bias
        if skip is not None:
            if self.upsample:
                skip = F.interpolate(
                    skip, scale_factor=2, mode="bilinear", align_corners=False
                )
            out = out + skip
        return out
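# Usage sketch (illustrative): ToRGB accumulates a growing skip image across
# resolutions, upsampling the previous RGB before adding the new contribution.
#   to_rgb = ToRGB(64, num_style_feat=512, upsample=True)
#   rgb = to_rgb(torch.randn(2, 64, 32, 32), torch.randn(2, 512),
#                skip=torch.randn(2, 3, 16, 16))  # -> (2, 3, 32, 32)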


class ConstantInput(nn.Module):
    """Constant input.

    Args:
        num_channel (int): Channel number of constant input.
        size (int): Spatial size of constant input.
    """

    def __init__(self, num_channel, size):
        super(ConstantInput, self).__init__()
        self.weight = nn.Parameter(torch.randn(1, num_channel, size, size))

    def forward(self, batch):
        out = self.weight.repeat(batch, 1, 1, 1)
        return out
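# Usage sketch (illustrative): the learned constant replaces an input image;
# the batch dimension is produced by repetition, so all samples start identical.
#   const = ConstantInput(512, size=4)
#   out = const(2)  # (2, 512, 4, 4)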


class StyleGAN2GeneratorClean(nn.Module):
    """Clean version of StyleGAN2 Generator.

    Args:
        out_size (int): The spatial size of outputs.
        num_style_feat (int): Channel number of style features. Default: 512.
        num_mlp (int): Layer number of MLP style layers. Default: 8.
        channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
        narrow (float): Narrow ratio for channels. Default: 1.0.
    """

    def __init__(
        self, out_size, num_style_feat=512, num_mlp=8, channel_multiplier=2, narrow=1
    ):
        super(StyleGAN2GeneratorClean, self).__init__()
        # Style MLP layers
        self.num_style_feat = num_style_feat
        style_mlp_layers = [NormStyleCode()]
        for i in range(num_mlp):
            style_mlp_layers.extend(
                [
                    nn.Linear(num_style_feat, num_style_feat, bias=True),
                    nn.LeakyReLU(negative_slope=0.2, inplace=True),
                ]
            )
        self.style_mlp = nn.Sequential(*style_mlp_layers)
        # initialization
        default_init_weights(
            self.style_mlp,
            scale=1,
            bias_fill=0,
            a=0.2,
            mode="fan_in",
            nonlinearity="leaky_relu",
        )
        # channel list
        channels = {
            "4": int(512 * narrow),
            "8": int(512 * narrow),
            "16": int(512 * narrow),
            "32": int(512 * narrow),
            "64": int(256 * channel_multiplier * narrow),
            "128": int(128 * channel_multiplier * narrow),
            "256": int(64 * channel_multiplier * narrow),
            "512": int(32 * channel_multiplier * narrow),
            "1024": int(16 * channel_multiplier * narrow),
        }
        self.channels = channels

        self.constant_input = ConstantInput(channels["4"], size=4)
        self.style_conv1 = StyleConv(
            channels["4"],
            channels["4"],
            kernel_size=3,
            num_style_feat=num_style_feat,
            demodulate=True,
            sample_mode=None,
        )
        self.to_rgb1 = ToRGB(channels["4"], num_style_feat, upsample=False)

        self.log_size = int(math.log(out_size, 2))
        self.num_layers = (self.log_size - 2) * 2 + 1
        self.num_latent = self.log_size * 2 - 2

        self.style_convs = nn.ModuleList()
        self.to_rgbs = nn.ModuleList()
        self.noises = nn.Module()

        in_channels = channels["4"]
        # noise
        for layer_idx in range(self.num_layers):
            resolution = 2 ** ((layer_idx + 5) // 2)
            shape = [1, 1, resolution, resolution]
            self.noises.register_buffer(f"noise{layer_idx}", torch.randn(*shape))
        # style convs and to_rgbs
        for i in range(3, self.log_size + 1):
            out_channels = channels[f"{2 ** i}"]
            self.style_convs.append(
                StyleConv(
                    in_channels,
                    out_channels,
                    kernel_size=3,
                    num_style_feat=num_style_feat,
                    demodulate=True,
                    sample_mode="upsample",
                )
            )
            self.style_convs.append(
                StyleConv(
                    out_channels,
                    out_channels,
                    kernel_size=3,
                    num_style_feat=num_style_feat,
                    demodulate=True,
                    sample_mode=None,
                )
            )
            self.to_rgbs.append(ToRGB(out_channels, num_style_feat, upsample=True))
            in_channels = out_channels
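    # Layer-count sanity check (illustrative): for out_size=512, log_size=9, so
    # num_layers = (9 - 2) * 2 + 1 = 15 noise maps and num_latent = 9 * 2 - 2 = 16
    # per-layer style codes; in forward(), each ToRGB shares its latent index
    # with the first conv of the next resolution block.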
    def make_noise(self):
        """Make noise for noise injection."""
        device = self.constant_input.weight.device
        noises = [torch.randn(1, 1, 4, 4, device=device)]
        for i in range(3, self.log_size + 1):
            for _ in range(2):
                noises.append(torch.randn(1, 1, 2**i, 2**i, device=device))
        return noises

    def get_latent(self, x):
        return self.style_mlp(x)

    def mean_latent(self, num_latent):
        latent_in = torch.randn(
            num_latent, self.num_style_feat, device=self.constant_input.weight.device
        )
        latent = self.style_mlp(latent_in).mean(0, keepdim=True)
        return latent
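    # Usage sketch (illustrative): mean_latent estimates the average W-space code
    # from random z samples; it is the natural choice for `truncation_latent` in
    # forward() when truncation < 1.
    #   g = StyleGAN2GeneratorClean(out_size=512)
    #   w_mean = g.mean_latent(4096)  # (1, 512)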
    def forward(
        self,
        styles,
        input_is_latent=False,
        noise=None,
        randomize_noise=True,
        truncation=1,
        truncation_latent=None,
        inject_index=None,
        return_latents=False,
    ):
        """Forward function for StyleGAN2GeneratorClean.

        Args:
            styles (list[Tensor]): Sample codes of styles.
            input_is_latent (bool): Whether the input is a latent style code. Default: False.
            noise (list[Tensor] | None): Input noise maps or None. Default: None.
            randomize_noise (bool): Randomize noise, used when 'noise' is None. Default: True.
            truncation (float): The truncation ratio. Default: 1.
            truncation_latent (Tensor | None): The truncation latent tensor. Default: None.
            inject_index (int | None): The injection index for style mixing. Default: None.
            return_latents (bool): Whether to return style latents. Default: False.
        """
        # style codes -> latents with Style MLP layer
        if not input_is_latent:
            styles = [self.style_mlp(s) for s in styles]
        # noises
        if noise is None:
            if randomize_noise:
                noise = [None] * self.num_layers  # for each style conv layer
            else:  # use the stored noise
                noise = [
                    getattr(self.noises, f"noise{i}") for i in range(self.num_layers)
                ]
        # style truncation
        if truncation < 1:
            style_truncation = []
            for style in styles:
                style_truncation.append(
                    truncation_latent + truncation * (style - truncation_latent)
                )
            styles = style_truncation
        # get style latents with injection
        if len(styles) == 1:
            inject_index = self.num_latent

            if styles[0].ndim < 3:
                # repeat latent code for all the layers
                latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
            else:  # used for encoder with different latent code for each layer
                latent = styles[0]
        elif len(styles) == 2:  # style mixing with two codes
            if inject_index is None:
                inject_index = random.randint(1, self.num_latent - 1)
            latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
            latent2 = (
                styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1)
            )
            latent = torch.cat([latent1, latent2], 1)

        # main generation
        out = self.constant_input(latent.shape[0])
        out = self.style_conv1(out, latent[:, 0], noise=noise[0])
        skip = self.to_rgb1(out, latent[:, 1])

        i = 1
        for conv1, conv2, noise1, noise2, to_rgb in zip(
            self.style_convs[::2],
            self.style_convs[1::2],
            noise[1::2],
            noise[2::2],
            self.to_rgbs,
        ):
            out = conv1(out, latent[:, i], noise=noise1)
            out = conv2(out, latent[:, i + 1], noise=noise2)
            skip = to_rgb(out, latent[:, i + 2], skip)  # feature back to the rgb space
            i += 2

        image = skip

        if return_latents:
            return image, latent
        else:
            return image, None
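

# End-to-end usage sketch (illustrative; the argument values are assumptions):
#   g = StyleGAN2GeneratorClean(out_size=512)
#   z = torch.randn(2, 512)
#   with torch.no_grad():
#       image, _ = g([z], truncation=0.7, truncation_latent=g.mean_latent(4096))
#   # image: (2, 3, 512, 512); values are unbounded at initialization, while
#   # trained checkpoints typically produce images in roughly [-1, 1].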