# stylegan2_clean_arch.py

import math
import random

import torch
from torch import nn
from torch.nn import functional as F

from sorawm.iopaint.plugins.basicsr.arch_util import default_init_weights


class NormStyleCode(nn.Module):
    def forward(self, x):
        """Normalize the style codes.

        Args:
            x (Tensor): Style codes with shape (b, c).

        Returns:
            Tensor: Normalized tensor.
        """
        return x * torch.rsqrt(torch.mean(x**2, dim=1, keepdim=True) + 1e-8)
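

# A minimal usage sketch (illustrative helper, not part of the upstream
# module): NormStyleCode rescales each style vector to unit root mean square,
# i.e. x / sqrt(mean(x ** 2) + 1e-8).
def _demo_norm_style_code():
    z = torch.randn(4, 512)
    z_norm = NormStyleCode()(z)
    # mean(z_norm ** 2) is ~1.0 for every row
    print(z_norm.pow(2).mean(dim=1))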


class ModulatedConv2d(nn.Module):
    """Modulated Conv2d used in StyleGAN2.

    There is no bias in ModulatedConv2d.

    Args:
        in_channels (int): Channel number of the input.
        out_channels (int): Channel number of the output.
        kernel_size (int): Size of the convolving kernel.
        num_style_feat (int): Channel number of style features.
        demodulate (bool): Whether to demodulate in the conv layer. Default: True.
        sample_mode (str | None): Indicating 'upsample', 'downsample' or None. Default: None.
        eps (float): A value added to the denominator for numerical stability. Default: 1e-8.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        num_style_feat,
        demodulate=True,
        sample_mode=None,
        eps=1e-8,
    ):
        super(ModulatedConv2d, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.demodulate = demodulate
        self.sample_mode = sample_mode
        self.eps = eps
        # modulation inside each modulated conv
        self.modulation = nn.Linear(num_style_feat, in_channels, bias=True)
        # initialization
        default_init_weights(
            self.modulation,
            scale=1,
            bias_fill=1,
            a=0,
            mode="fan_in",
            nonlinearity="linear",
        )
        self.weight = nn.Parameter(
            torch.randn(1, out_channels, in_channels, kernel_size, kernel_size)
            / math.sqrt(in_channels * kernel_size**2)
        )
        self.padding = kernel_size // 2

    def forward(self, x, style):
        """Forward function.

        Args:
            x (Tensor): Tensor with shape (b, c, h, w).
            style (Tensor): Tensor with shape (b, num_style_feat).

        Returns:
            Tensor: Modulated tensor after convolution.
        """
        b, c, h, w = x.shape  # c = c_in
        # weight modulation
        style = self.modulation(style).view(b, 1, c, 1, 1)
        # self.weight: (1, c_out, c_in, k, k); style: (b, 1, c, 1, 1)
        weight = self.weight * style  # (b, c_out, c_in, k, k)
        if self.demodulate:
            demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + self.eps)
            weight = weight * demod.view(b, self.out_channels, 1, 1, 1)
        weight = weight.view(
            b * self.out_channels, c, self.kernel_size, self.kernel_size
        )
        # upsample or downsample if necessary
        if self.sample_mode == "upsample":
            x = F.interpolate(x, scale_factor=2, mode="bilinear", align_corners=False)
        elif self.sample_mode == "downsample":
            x = F.interpolate(x, scale_factor=0.5, mode="bilinear", align_corners=False)
        b, c, h, w = x.shape
        x = x.view(1, b * c, h, w)
        # weight: (b*c_out, c_in, k, k), groups=b
        out = F.conv2d(x, weight, padding=self.padding, groups=b)
        out = out.view(b, self.out_channels, *out.shape[2:4])
        return out

    def __repr__(self):
        return (
            f"{self.__class__.__name__}(in_channels={self.in_channels}, out_channels={self.out_channels}, "
            f"kernel_size={self.kernel_size}, demodulate={self.demodulate}, sample_mode={self.sample_mode})"
        )
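

# A minimal usage sketch (illustrative helper, not part of the upstream
# module): ModulatedConv2d scales its shared kernels per sample with the
# learned affine of the style vector, optionally demodulates them, then runs
# a single grouped convolution (groups = batch size).
def _demo_modulated_conv():
    conv = ModulatedConv2d(
        in_channels=8, out_channels=16, kernel_size=3, num_style_feat=512
    )
    x = torch.randn(2, 8, 32, 32)
    style = torch.randn(2, 512)
    print(conv(x, style).shape)  # torch.Size([2, 16, 32, 32])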


class StyleConv(nn.Module):
    """Style conv used in StyleGAN2.

    Args:
        in_channels (int): Channel number of the input.
        out_channels (int): Channel number of the output.
        kernel_size (int): Size of the convolving kernel.
        num_style_feat (int): Channel number of style features.
        demodulate (bool): Whether to demodulate in the conv layer. Default: True.
        sample_mode (str | None): Indicating 'upsample', 'downsample' or None. Default: None.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        num_style_feat,
        demodulate=True,
        sample_mode=None,
    ):
        super(StyleConv, self).__init__()
        self.modulated_conv = ModulatedConv2d(
            in_channels,
            out_channels,
            kernel_size,
            num_style_feat,
            demodulate=demodulate,
            sample_mode=sample_mode,
        )
        self.weight = nn.Parameter(torch.zeros(1))  # for noise injection
        self.bias = nn.Parameter(torch.zeros(1, out_channels, 1, 1))
        self.activate = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x, style, noise=None):
        # modulate
        out = self.modulated_conv(x, style) * 2**0.5  # for conversion
        # noise injection
        if noise is None:
            b, _, h, w = out.shape
            noise = out.new_empty(b, 1, h, w).normal_()
        out = out + self.weight * noise
        # add bias
        out = out + self.bias
        # activation
        out = self.activate(out)
        return out
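

# A minimal usage sketch (illustrative helper, not part of the upstream
# module): StyleConv chains modulated conv, noise injection, bias and
# LeakyReLU; with sample_mode="upsample" the spatial size doubles.
def _demo_style_conv():
    conv = StyleConv(8, 16, kernel_size=3, num_style_feat=512, sample_mode="upsample")
    out = conv(torch.randn(2, 8, 16, 16), torch.randn(2, 512))
    print(out.shape)  # torch.Size([2, 16, 32, 32])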


class ToRGB(nn.Module):
    """To RGB (image space) from features.

    Args:
        in_channels (int): Channel number of input.
        num_style_feat (int): Channel number of style features.
        upsample (bool): Whether to upsample. Default: True.
    """

    def __init__(self, in_channels, num_style_feat, upsample=True):
        super(ToRGB, self).__init__()
        self.upsample = upsample
        self.modulated_conv = ModulatedConv2d(
            in_channels,
            3,
            kernel_size=1,
            num_style_feat=num_style_feat,
            demodulate=False,
            sample_mode=None,
        )
        self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))

    def forward(self, x, style, skip=None):
        """Forward function.

        Args:
            x (Tensor): Feature tensor with shape (b, c, h, w).
            style (Tensor): Tensor with shape (b, num_style_feat).
            skip (Tensor | None): Base/skip tensor. Default: None.

        Returns:
            Tensor: RGB images.
        """
        out = self.modulated_conv(x, style)
        out = out + self.bias
        if skip is not None:
            if self.upsample:
                skip = F.interpolate(
                    skip, scale_factor=2, mode="bilinear", align_corners=False
                )
            out = out + skip
        return out
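

# A minimal usage sketch (illustrative helper, not part of the upstream
# module): ToRGB maps features to a 3-channel image with a 1x1 modulated conv
# (no demodulation) and adds the bilinearly upsampled skip image.
def _demo_to_rgb():
    to_rgb = ToRGB(in_channels=16, num_style_feat=512)
    feat = torch.randn(2, 16, 32, 32)
    skip = torch.randn(2, 3, 16, 16)  # previous RGB output at half resolution
    print(to_rgb(feat, torch.randn(2, 512), skip=skip).shape)  # (2, 3, 32, 32)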


class ConstantInput(nn.Module):
    """Constant input.

    Args:
        num_channel (int): Channel number of constant input.
        size (int): Spatial size of constant input.
    """

    def __init__(self, num_channel, size):
        super(ConstantInput, self).__init__()
        self.weight = nn.Parameter(torch.randn(1, num_channel, size, size))

    def forward(self, batch):
        out = self.weight.repeat(batch, 1, 1, 1)
        return out


class StyleGAN2GeneratorClean(nn.Module):
    """Clean version of StyleGAN2 Generator.

    Args:
        out_size (int): The spatial size of outputs.
        num_style_feat (int): Channel number of style features. Default: 512.
        num_mlp (int): Layer number of MLP style layers. Default: 8.
        channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
        narrow (float): Narrow ratio for channels. Default: 1.0.
    """

    def __init__(
        self, out_size, num_style_feat=512, num_mlp=8, channel_multiplier=2, narrow=1
    ):
        super(StyleGAN2GeneratorClean, self).__init__()
        # Style MLP layers
        self.num_style_feat = num_style_feat
        style_mlp_layers = [NormStyleCode()]
        for i in range(num_mlp):
            style_mlp_layers.extend(
                [
                    nn.Linear(num_style_feat, num_style_feat, bias=True),
                    nn.LeakyReLU(negative_slope=0.2, inplace=True),
                ]
            )
        self.style_mlp = nn.Sequential(*style_mlp_layers)
        # initialization
        default_init_weights(
            self.style_mlp,
            scale=1,
            bias_fill=0,
            a=0.2,
            mode="fan_in",
            nonlinearity="leaky_relu",
        )
        # channel list
        channels = {
            "4": int(512 * narrow),
            "8": int(512 * narrow),
            "16": int(512 * narrow),
            "32": int(512 * narrow),
            "64": int(256 * channel_multiplier * narrow),
            "128": int(128 * channel_multiplier * narrow),
            "256": int(64 * channel_multiplier * narrow),
            "512": int(32 * channel_multiplier * narrow),
            "1024": int(16 * channel_multiplier * narrow),
        }
        self.channels = channels

        self.constant_input = ConstantInput(channels["4"], size=4)
        self.style_conv1 = StyleConv(
            channels["4"],
            channels["4"],
            kernel_size=3,
            num_style_feat=num_style_feat,
            demodulate=True,
            sample_mode=None,
        )
        self.to_rgb1 = ToRGB(channels["4"], num_style_feat, upsample=False)

        self.log_size = int(math.log(out_size, 2))
        self.num_layers = (self.log_size - 2) * 2 + 1
        self.num_latent = self.log_size * 2 - 2

        self.style_convs = nn.ModuleList()
        self.to_rgbs = nn.ModuleList()
        self.noises = nn.Module()

        in_channels = channels["4"]
        # noise
        for layer_idx in range(self.num_layers):
            resolution = 2 ** ((layer_idx + 5) // 2)
            shape = [1, 1, resolution, resolution]
            self.noises.register_buffer(f"noise{layer_idx}", torch.randn(*shape))
        # style convs and to_rgbs
        for i in range(3, self.log_size + 1):
            out_channels = channels[f"{2 ** i}"]
            self.style_convs.append(
                StyleConv(
                    in_channels,
                    out_channels,
                    kernel_size=3,
                    num_style_feat=num_style_feat,
                    demodulate=True,
                    sample_mode="upsample",
                )
            )
            self.style_convs.append(
                StyleConv(
                    out_channels,
                    out_channels,
                    kernel_size=3,
                    num_style_feat=num_style_feat,
                    demodulate=True,
                    sample_mode=None,
                )
            )
            self.to_rgbs.append(ToRGB(out_channels, num_style_feat, upsample=True))
            in_channels = out_channels

    def make_noise(self):
        """Make noise for noise injection."""
        device = self.constant_input.weight.device
        noises = [torch.randn(1, 1, 4, 4, device=device)]
        for i in range(3, self.log_size + 1):
            for _ in range(2):
                noises.append(torch.randn(1, 1, 2**i, 2**i, device=device))
        return noises

    def get_latent(self, x):
        return self.style_mlp(x)

    def mean_latent(self, num_latent):
        latent_in = torch.randn(
            num_latent, self.num_style_feat, device=self.constant_input.weight.device
        )
        latent = self.style_mlp(latent_in).mean(0, keepdim=True)
        return latent

    def forward(
        self,
        styles,
        input_is_latent=False,
        noise=None,
        randomize_noise=True,
        truncation=1,
        truncation_latent=None,
        inject_index=None,
        return_latents=False,
    ):
        """Forward function for StyleGAN2GeneratorClean.

        Args:
            styles (list[Tensor]): Sample codes of styles.
            input_is_latent (bool): Whether the input is latent style. Default: False.
            noise (list[Tensor] | None): Input noise or None. Default: None.
            randomize_noise (bool): Randomize noise, used when 'noise' is None. Default: True.
            truncation (float): The truncation ratio. Default: 1.
            truncation_latent (Tensor | None): The truncation latent tensor. Default: None.
            inject_index (int | None): The injection index for mixing noise. Default: None.
            return_latents (bool): Whether to return style latents. Default: False.

        Returns:
            tuple(Tensor, Tensor | None): Generated images with shape
                (b, 3, out_size, out_size), and the style latents if
                `return_latents` is True (otherwise None).
        """
        # style codes -> latents with Style MLP layer
        if not input_is_latent:
            styles = [self.style_mlp(s) for s in styles]
        # noises
        if noise is None:
            if randomize_noise:
                noise = [None] * self.num_layers  # for each style conv layer
            else:  # use the stored noise
                noise = [
                    getattr(self.noises, f"noise{i}") for i in range(self.num_layers)
                ]
        # style truncation
        if truncation < 1:
            style_truncation = []
            for style in styles:
                style_truncation.append(
                    truncation_latent + truncation * (style - truncation_latent)
                )
            styles = style_truncation
        # get style latents with injection
        if len(styles) == 1:
            inject_index = self.num_latent
            if styles[0].ndim < 3:
                # repeat latent code for all the layers
                latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
            else:  # used for encoder with different latent code for each layer
                latent = styles[0]
        elif len(styles) == 2:  # mixing noises
            if inject_index is None:
                inject_index = random.randint(1, self.num_latent - 1)
            latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
            latent2 = (
                styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1)
            )
            latent = torch.cat([latent1, latent2], 1)
        # main generation
        out = self.constant_input(latent.shape[0])
        out = self.style_conv1(out, latent[:, 0], noise=noise[0])
        skip = self.to_rgb1(out, latent[:, 1])
        i = 1
        for conv1, conv2, noise1, noise2, to_rgb in zip(
            self.style_convs[::2],
            self.style_convs[1::2],
            noise[1::2],
            noise[2::2],
            self.to_rgbs,
        ):
            out = conv1(out, latent[:, i], noise=noise1)
            out = conv2(out, latent[:, i + 1], noise=noise2)
            skip = to_rgb(out, latent[:, i + 2], skip)  # feature back to the rgb space
            i += 2
        image = skip

        if return_latents:
            return image, latent
        else:
            return image, None
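

# An end-to-end usage sketch (illustrative helper, not part of the upstream
# module). For out_size=64, log_size = 6, so the generator consumes
# num_latent = 2 * 6 - 2 = 10 per-layer latents and emits 64x64 RGB images.
def _demo_generator():
    gen = StyleGAN2GeneratorClean(out_size=64)
    z = torch.randn(2, 512)  # style codes in Z space
    image, _ = gen([z])  # mapped to W space internally by the style MLP
    print(image.shape)  # torch.Size([2, 3, 64, 64])
    # truncation trick: pull latents towards the mean latent
    mean_w = gen.mean_latent(num_latent=4096)
    image, latent = gen(
        [z], truncation=0.7, truncation_latent=mean_w, return_latents=True
    )
    print(latent.shape)  # torch.Size([2, 10, 512])


if __name__ == "__main__":
    _demo_norm_style_code()
    _demo_modulated_conv()
    _demo_style_conv()
    _demo_to_rgb()
    _demo_generator()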