fcf.py
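"""FcF inpainting model.

This module contains the generator used by the FcF inpainting model: a StyleGAN2-style
encoder / mapping / synthesis stack whose skip connections pass through Fast Fourier
Convolution (FFC) blocks, plus the `FcF` class that adapts the generator to the
`InpaintModel` interface (padding, cropping, and pre/post-processing).
"""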

import os
import random

import cv2
import numpy as np
import torch
import torch.fft as fft
import torch.nn.functional as F
from torch import conv2d, nn

from sorawm.iopaint.helper import (
    boxes_from_mask,
    download_model,
    get_cache_path_by_url,
    load_model,
    norm_img,
    resize_max_size,
)
from sorawm.iopaint.schema import InpaintRequest

from .base import InpaintModel
from .utils import (
    Conv2dLayer,
    FullyConnectedLayer,
    MinibatchStdLayer,
    _parse_padding,
    _parse_scaling,
    activation_funcs,
    bias_act,
    conv2d_resample,
    downsample2d,
    normalize_2nd_moment,
    setup_filter,
    upsample2d,
)
def upfirdn2d(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1, impl="cuda"):
    assert isinstance(x, torch.Tensor)
    return _upfirdn2d_ref(
        x, f, up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain
    )
def _upfirdn2d_ref(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1):
    """Slow reference implementation of `upfirdn2d()` using standard PyTorch ops."""
    # Validate arguments.
    assert isinstance(x, torch.Tensor) and x.ndim == 4
    if f is None:
        f = torch.ones([1, 1], dtype=torch.float32, device=x.device)
    assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
    assert f.dtype == torch.float32 and not f.requires_grad
    batch_size, num_channels, in_height, in_width = x.shape
    upx, upy = _parse_scaling(up)
    downx, downy = _parse_scaling(down)
    padx0, padx1, pady0, pady1 = _parse_padding(padding)

    # Upsample by inserting zeros.
    x = x.reshape([batch_size, num_channels, in_height, 1, in_width, 1])
    x = torch.nn.functional.pad(x, [0, upx - 1, 0, 0, 0, upy - 1])
    x = x.reshape([batch_size, num_channels, in_height * upy, in_width * upx])

    # Pad or crop.
    x = torch.nn.functional.pad(
        x, [max(padx0, 0), max(padx1, 0), max(pady0, 0), max(pady1, 0)]
    )
    x = x[
        :,
        :,
        max(-pady0, 0) : x.shape[2] - max(-pady1, 0),
        max(-padx0, 0) : x.shape[3] - max(-padx1, 0),
    ]

    # Setup filter.
    f = f * (gain ** (f.ndim / 2))
    f = f.to(x.dtype)
    if not flip_filter:
        f = f.flip(list(range(f.ndim)))

    # Convolve with the filter.
    f = f[np.newaxis, np.newaxis].repeat([num_channels, 1] + [1] * f.ndim)
    if f.ndim == 4:
        x = conv2d(input=x, weight=f, groups=num_channels)
    else:
        x = conv2d(input=x, weight=f.unsqueeze(2), groups=num_channels)
        x = conv2d(input=x, weight=f.unsqueeze(3), groups=num_channels)

    # Downsample by throwing away pixels.
    x = x[:, :, ::downy, ::downx]
    return x
class EncoderEpilogue(torch.nn.Module):
    def __init__(
        self,
        in_channels,  # Number of input channels.
        cmap_dim,  # Dimensionality of mapped conditioning label, 0 = no label.
        z_dim,  # Output Latent (Z) dimensionality.
        resolution,  # Resolution of this block.
        img_channels,  # Number of input color channels.
        architecture="resnet",  # Architecture: 'orig', 'skip', 'resnet'.
        mbstd_group_size=4,  # Group size for the minibatch standard deviation layer, None = entire minibatch.
        mbstd_num_channels=1,  # Number of features for the minibatch standard deviation layer, 0 = disable.
        activation="lrelu",  # Activation function: 'relu', 'lrelu', etc.
        conv_clamp=None,  # Clamp the output of convolution layers to +-X, None = disable clamping.
    ):
        assert architecture in ["orig", "skip", "resnet"]
        super().__init__()
        self.in_channels = in_channels
        self.cmap_dim = cmap_dim
        self.resolution = resolution
        self.img_channels = img_channels
        self.architecture = architecture
        if architecture == "skip":
            self.fromrgb = Conv2dLayer(
                self.img_channels, in_channels, kernel_size=1, activation=activation
            )
        self.mbstd = (
            MinibatchStdLayer(
                group_size=mbstd_group_size, num_channels=mbstd_num_channels
            )
            if mbstd_num_channels > 0
            else None
        )
        self.conv = Conv2dLayer(
            in_channels + mbstd_num_channels,
            in_channels,
            kernel_size=3,
            activation=activation,
            conv_clamp=conv_clamp,
        )
        self.fc = FullyConnectedLayer(
            in_channels * (resolution**2), z_dim, activation=activation
        )
        self.dropout = torch.nn.Dropout(p=0.5)

    def forward(self, x, cmap, force_fp32=False):
        _ = force_fp32  # unused
        dtype = torch.float32
        memory_format = torch.contiguous_format

        # FromRGB.
        x = x.to(dtype=dtype, memory_format=memory_format)

        # Main layers.
        if self.mbstd is not None:
            x = self.mbstd(x)
        const_e = self.conv(x)
        x = self.fc(const_e.flatten(1))
        x = self.dropout(x)

        # Conditioning.
        if self.cmap_dim > 0:
            x = (x * cmap).sum(dim=1, keepdim=True) * (1 / np.sqrt(self.cmap_dim))

        assert x.dtype == dtype
        return x, const_e
class EncoderBlock(torch.nn.Module):
    def __init__(
        self,
        in_channels,  # Number of input channels, 0 = first block.
        tmp_channels,  # Number of intermediate channels.
        out_channels,  # Number of output channels.
        resolution,  # Resolution of this block.
        img_channels,  # Number of input color channels.
        first_layer_idx,  # Index of the first layer.
        architecture="skip",  # Architecture: 'orig', 'skip', 'resnet'.
        activation="lrelu",  # Activation function: 'relu', 'lrelu', etc.
        resample_filter=[
            1,
            3,
            3,
            1,
        ],  # Low-pass filter to apply when resampling activations.
        conv_clamp=None,  # Clamp the output of convolution layers to +-X, None = disable clamping.
        use_fp16=False,  # Use FP16 for this block?
        fp16_channels_last=False,  # Use channels-last memory format with FP16?
        freeze_layers=0,  # Freeze-D: Number of layers to freeze.
    ):
        assert in_channels in [0, tmp_channels]
        assert architecture in ["orig", "skip", "resnet"]
        super().__init__()
        self.in_channels = in_channels
        self.resolution = resolution
        self.img_channels = img_channels + 1
        self.first_layer_idx = first_layer_idx
        self.architecture = architecture
        self.use_fp16 = use_fp16
        self.channels_last = use_fp16 and fp16_channels_last
        self.register_buffer("resample_filter", setup_filter(resample_filter))
        self.num_layers = 0

        def trainable_gen():
            while True:
                layer_idx = self.first_layer_idx + self.num_layers
                trainable = layer_idx >= freeze_layers
                self.num_layers += 1
                yield trainable

        trainable_iter = trainable_gen()

        if in_channels == 0:
            self.fromrgb = Conv2dLayer(
                self.img_channels,
                tmp_channels,
                kernel_size=1,
                activation=activation,
                trainable=next(trainable_iter),
                conv_clamp=conv_clamp,
                channels_last=self.channels_last,
            )

        self.conv0 = Conv2dLayer(
            tmp_channels,
            tmp_channels,
            kernel_size=3,
            activation=activation,
            trainable=next(trainable_iter),
            conv_clamp=conv_clamp,
            channels_last=self.channels_last,
        )

        self.conv1 = Conv2dLayer(
            tmp_channels,
            out_channels,
            kernel_size=3,
            activation=activation,
            down=2,
            trainable=next(trainable_iter),
            resample_filter=resample_filter,
            conv_clamp=conv_clamp,
            channels_last=self.channels_last,
        )

        if architecture == "resnet":
            self.skip = Conv2dLayer(
                tmp_channels,
                out_channels,
                kernel_size=1,
                bias=False,
                down=2,
                trainable=next(trainable_iter),
                resample_filter=resample_filter,
                channels_last=self.channels_last,
            )

    def forward(self, x, img, force_fp32=False):
        # dtype = torch.float16 if self.use_fp16 and not force_fp32 else torch.float32
        dtype = torch.float32
        memory_format = (
            torch.channels_last
            if self.channels_last and not force_fp32
            else torch.contiguous_format
        )

        # Input.
        if x is not None:
            x = x.to(dtype=dtype, memory_format=memory_format)

        # FromRGB.
        if self.in_channels == 0:
            img = img.to(dtype=dtype, memory_format=memory_format)
            y = self.fromrgb(img)
            x = x + y if x is not None else y
            img = (
                downsample2d(img, self.resample_filter)
                if self.architecture == "skip"
                else None
            )

        # Main layers.
        if self.architecture == "resnet":
            y = self.skip(x, gain=np.sqrt(0.5))
            x = self.conv0(x)
            feat = x.clone()
            x = self.conv1(x, gain=np.sqrt(0.5))
            x = y.add_(x)
        else:
            x = self.conv0(x)
            feat = x.clone()
            x = self.conv1(x)

        assert x.dtype == dtype
        return x, img, feat
class EncoderNetwork(torch.nn.Module):
    def __init__(
        self,
        c_dim,  # Conditioning label (C) dimensionality.
        z_dim,  # Input latent (Z) dimensionality.
        img_resolution,  # Input resolution.
        img_channels,  # Number of input color channels.
        architecture="orig",  # Architecture: 'orig', 'skip', 'resnet'.
        channel_base=16384,  # Overall multiplier for the number of channels.
        channel_max=512,  # Maximum number of channels in any layer.
        num_fp16_res=0,  # Use FP16 for the N highest resolutions.
        conv_clamp=None,  # Clamp the output of convolution layers to +-X, None = disable clamping.
        cmap_dim=None,  # Dimensionality of mapped conditioning label, None = default.
        block_kwargs={},  # Arguments for DiscriminatorBlock.
        mapping_kwargs={},  # Arguments for MappingNetwork.
        epilogue_kwargs={},  # Arguments for EncoderEpilogue.
    ):
        super().__init__()
        self.c_dim = c_dim
        self.z_dim = z_dim
        self.img_resolution = img_resolution
        self.img_resolution_log2 = int(np.log2(img_resolution))
        self.img_channels = img_channels
        self.block_resolutions = [
            2**i for i in range(self.img_resolution_log2, 2, -1)
        ]
        channels_dict = {
            res: min(channel_base // res, channel_max)
            for res in self.block_resolutions + [4]
        }
        fp16_resolution = max(2 ** (self.img_resolution_log2 + 1 - num_fp16_res), 8)

        if cmap_dim is None:
            cmap_dim = channels_dict[4]
        if c_dim == 0:
            cmap_dim = 0

        common_kwargs = dict(
            img_channels=img_channels, architecture=architecture, conv_clamp=conv_clamp
        )
        cur_layer_idx = 0
        for res in self.block_resolutions:
            in_channels = channels_dict[res] if res < img_resolution else 0
            tmp_channels = channels_dict[res]
            out_channels = channels_dict[res // 2]
            use_fp16 = res >= fp16_resolution
            use_fp16 = False
            block = EncoderBlock(
                in_channels,
                tmp_channels,
                out_channels,
                resolution=res,
                first_layer_idx=cur_layer_idx,
                use_fp16=use_fp16,
                **block_kwargs,
                **common_kwargs,
            )
            setattr(self, f"b{res}", block)
            cur_layer_idx += block.num_layers
        if c_dim > 0:
            self.mapping = MappingNetwork(
                z_dim=0,
                c_dim=c_dim,
                w_dim=cmap_dim,
                num_ws=None,
                w_avg_beta=None,
                **mapping_kwargs,
            )
        self.b4 = EncoderEpilogue(
            channels_dict[4],
            cmap_dim=cmap_dim,
            z_dim=z_dim * 2,
            resolution=4,
            **epilogue_kwargs,
            **common_kwargs,
        )

    def forward(self, img, c, **block_kwargs):
        x = None
        feats = {}
        for res in self.block_resolutions:
            block = getattr(self, f"b{res}")
            x, img, feat = block(x, img, **block_kwargs)
            feats[res] = feat
        cmap = None
        if self.c_dim > 0:
            cmap = self.mapping(None, c)
        x, const_e = self.b4(x, cmap)
        feats[4] = const_e
        B, _ = x.shape
        z = torch.zeros(
            (B, self.z_dim), requires_grad=False, dtype=x.dtype, device=x.device
        )  ## Noise for Co-Modulation
        return x, z, feats
def fma(a, b, c):  # => a * b + c
    return _FusedMultiplyAdd.apply(a, b, c)


class _FusedMultiplyAdd(torch.autograd.Function):  # a * b + c
    @staticmethod
    def forward(ctx, a, b, c):  # pylint: disable=arguments-differ
        out = torch.addcmul(c, a, b)
        ctx.save_for_backward(a, b)
        ctx.c_shape = c.shape
        return out

    @staticmethod
    def backward(ctx, dout):  # pylint: disable=arguments-differ
        a, b = ctx.saved_tensors
        c_shape = ctx.c_shape
        da = None
        db = None
        dc = None
        if ctx.needs_input_grad[0]:
            da = _unbroadcast(dout * b, a.shape)
        if ctx.needs_input_grad[1]:
            db = _unbroadcast(dout * a, b.shape)
        if ctx.needs_input_grad[2]:
            dc = _unbroadcast(dout, c_shape)
        return da, db, dc


def _unbroadcast(x, shape):
    extra_dims = x.ndim - len(shape)
    assert extra_dims >= 0
    dim = [
        i
        for i in range(x.ndim)
        if x.shape[i] > 1 and (i < extra_dims or shape[i - extra_dims] == 1)
    ]
    if len(dim):
        x = x.sum(dim=dim, keepdim=True)
        if extra_dims:
            x = x.reshape(-1, *x.shape[extra_dims + 1 :])
    assert x.shape == shape
    return x
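# Note: `modulated_conv2d` below implements StyleGAN2-style weight (de)modulation.
# The per-sample style vector scales the convolution weights along the input-channel
# axis, an rsqrt of the summed squared weights rescales each output channel back
# toward unit variance, and the fused path runs all samples in a single grouped
# convolution (groups=batch_size) after folding the batch into the channel dimension.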
def modulated_conv2d(
    x,  # Input tensor of shape [batch_size, in_channels, in_height, in_width].
    weight,  # Weight tensor of shape [out_channels, in_channels, kernel_height, kernel_width].
    styles,  # Modulation coefficients of shape [batch_size, in_channels].
    noise=None,  # Optional noise tensor to add to the output activations.
    up=1,  # Integer upsampling factor.
    down=1,  # Integer downsampling factor.
    padding=0,  # Padding with respect to the upsampled image.
    resample_filter=None,
    # Low-pass filter to apply when resampling activations. Must be prepared beforehand by calling setup_filter().
    demodulate=True,  # Apply weight demodulation?
    flip_weight=True,  # False = convolution, True = correlation (matches torch.nn.functional.conv2d).
    fused_modconv=True,  # Perform modulation, convolution, and demodulation as a single fused operation?
):
    batch_size = x.shape[0]
    out_channels, in_channels, kh, kw = weight.shape

    # Pre-normalize inputs to avoid FP16 overflow.
    if x.dtype == torch.float16 and demodulate:
        weight = weight * (
            1
            / np.sqrt(in_channels * kh * kw)
            / weight.norm(float("inf"), dim=[1, 2, 3], keepdim=True)
        )  # max_Ikk
        styles = styles / styles.norm(float("inf"), dim=1, keepdim=True)  # max_I

    # Calculate per-sample weights and demodulation coefficients.
    w = None
    dcoefs = None
    if demodulate or fused_modconv:
        w = weight.unsqueeze(0)  # [NOIkk]
        w = w * styles.reshape(batch_size, 1, -1, 1, 1)  # [NOIkk]
    if demodulate:
        dcoefs = (w.square().sum(dim=[2, 3, 4]) + 1e-8).rsqrt()  # [NO]
    if demodulate and fused_modconv:
        w = w * dcoefs.reshape(batch_size, -1, 1, 1, 1)  # [NOIkk]

    # Execute by scaling the activations before and after the convolution.
    if not fused_modconv:
        x = x * styles.to(x.dtype).reshape(batch_size, -1, 1, 1)
        # `conv2d_resample` is the helper imported from .utils (a function, not a
        # module as in the upstream StyleGAN2 code).
        x = conv2d_resample(
            x=x,
            w=weight.to(x.dtype),
            f=resample_filter,
            up=up,
            down=down,
            padding=padding,
            flip_weight=flip_weight,
        )
        if demodulate and noise is not None:
            x = fma(
                x, dcoefs.to(x.dtype).reshape(batch_size, -1, 1, 1), noise.to(x.dtype)
            )
        elif demodulate:
            x = x * dcoefs.to(x.dtype).reshape(batch_size, -1, 1, 1)
        elif noise is not None:
            x = x.add_(noise.to(x.dtype))
        return x

    # Execute as one fused op using grouped convolution.
    batch_size = int(batch_size)
    x = x.reshape(1, -1, *x.shape[2:])
    w = w.reshape(-1, in_channels, kh, kw)
    x = conv2d_resample(
        x=x,
        w=w.to(x.dtype),
        f=resample_filter,
        up=up,
        down=down,
        padding=padding,
        groups=batch_size,
        flip_weight=flip_weight,
    )
    x = x.reshape(batch_size, -1, *x.shape[2:])
    if noise is not None:
        x = x.add_(noise)
    return x
class SynthesisLayer(torch.nn.Module):
    def __init__(
        self,
        in_channels,  # Number of input channels.
        out_channels,  # Number of output channels.
        w_dim,  # Intermediate latent (W) dimensionality.
        resolution,  # Resolution of this layer.
        kernel_size=3,  # Convolution kernel size.
        up=1,  # Integer upsampling factor.
        use_noise=True,  # Enable noise input?
        activation="lrelu",  # Activation function: 'relu', 'lrelu', etc.
        resample_filter=[
            1,
            3,
            3,
            1,
        ],  # Low-pass filter to apply when resampling activations.
        conv_clamp=None,  # Clamp the output of convolution layers to +-X, None = disable clamping.
        channels_last=False,  # Use channels_last format for the weights?
    ):
        super().__init__()
        self.resolution = resolution
        self.up = up
        self.use_noise = use_noise
        self.activation = activation
        self.conv_clamp = conv_clamp
        self.register_buffer("resample_filter", setup_filter(resample_filter))
        self.padding = kernel_size // 2
        self.act_gain = activation_funcs[activation].def_gain

        self.affine = FullyConnectedLayer(w_dim, in_channels, bias_init=1)
        memory_format = (
            torch.channels_last if channels_last else torch.contiguous_format
        )
        self.weight = torch.nn.Parameter(
            torch.randn([out_channels, in_channels, kernel_size, kernel_size]).to(
                memory_format=memory_format
            )
        )
        if use_noise:
            self.register_buffer("noise_const", torch.randn([resolution, resolution]))
            self.noise_strength = torch.nn.Parameter(torch.zeros([]))
        self.bias = torch.nn.Parameter(torch.zeros([out_channels]))

    def forward(self, x, w, noise_mode="none", fused_modconv=True, gain=1):
        assert noise_mode in ["random", "const", "none"]
        in_resolution = self.resolution // self.up
        styles = self.affine(w)

        noise = None
        if self.use_noise and noise_mode == "random":
            noise = (
                torch.randn(
                    [x.shape[0], 1, self.resolution, self.resolution], device=x.device
                )
                * self.noise_strength
            )
        if self.use_noise and noise_mode == "const":
            noise = self.noise_const * self.noise_strength

        flip_weight = self.up == 1  # slightly faster
        x = modulated_conv2d(
            x=x,
            weight=self.weight,
            styles=styles,
            noise=noise,
            up=self.up,
            padding=self.padding,
            resample_filter=self.resample_filter,
            flip_weight=flip_weight,
            fused_modconv=fused_modconv,
        )

        act_gain = self.act_gain * gain
        act_clamp = self.conv_clamp * gain if self.conv_clamp is not None else None
        x = F.leaky_relu(x, negative_slope=0.2, inplace=False)
        if act_gain != 1:
            x = x * act_gain
        if act_clamp is not None:
            x = x.clamp(-act_clamp, act_clamp)
        return x
class ToRGBLayer(torch.nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        w_dim,
        kernel_size=1,
        conv_clamp=None,
        channels_last=False,
    ):
        super().__init__()
        self.conv_clamp = conv_clamp
        self.affine = FullyConnectedLayer(w_dim, in_channels, bias_init=1)
        memory_format = (
            torch.channels_last if channels_last else torch.contiguous_format
        )
        self.weight = torch.nn.Parameter(
            torch.randn([out_channels, in_channels, kernel_size, kernel_size]).to(
                memory_format=memory_format
            )
        )
        self.bias = torch.nn.Parameter(torch.zeros([out_channels]))
        self.weight_gain = 1 / np.sqrt(in_channels * (kernel_size**2))

    def forward(self, x, w, fused_modconv=True):
        styles = self.affine(w) * self.weight_gain
        x = modulated_conv2d(
            x=x,
            weight=self.weight,
            styles=styles,
            demodulate=False,
            fused_modconv=fused_modconv,
        )
        x = bias_act(x, self.bias.to(x.dtype), clamp=self.conv_clamp)
        return x
class SynthesisForeword(torch.nn.Module):
    def __init__(
        self,
        z_dim,  # Output Latent (Z) dimensionality.
        resolution,  # Resolution of this block.
        in_channels,
        img_channels,  # Number of input color channels.
        architecture="skip",  # Architecture: 'orig', 'skip', 'resnet'.
        activation="lrelu",  # Activation function: 'relu', 'lrelu', etc.
    ):
        super().__init__()
        self.in_channels = in_channels
        self.z_dim = z_dim
        self.resolution = resolution
        self.img_channels = img_channels
        self.architecture = architecture

        self.fc = FullyConnectedLayer(
            self.z_dim, (self.z_dim // 2) * 4 * 4, activation=activation
        )
        self.conv = SynthesisLayer(
            self.in_channels, self.in_channels, w_dim=(z_dim // 2) * 3, resolution=4
        )

        if architecture == "skip":
            self.torgb = ToRGBLayer(
                self.in_channels,
                self.img_channels,
                kernel_size=1,
                w_dim=(z_dim // 2) * 3,
            )

    def forward(self, x, ws, feats, img, force_fp32=False):
        _ = force_fp32  # unused
        dtype = torch.float32
        memory_format = torch.contiguous_format

        x_global = x.clone()
        # ToRGB.
        x = self.fc(x)
        x = x.view(-1, self.z_dim // 2, 4, 4)
        x = x.to(dtype=dtype, memory_format=memory_format)

        # Main layers.
        x_skip = feats[4].clone()
        x = x + x_skip

        mod_vector = []
        mod_vector.append(ws[:, 0])
        mod_vector.append(x_global.clone())
        mod_vector = torch.cat(mod_vector, dim=1)

        x = self.conv(x, mod_vector)

        mod_vector = []
        mod_vector.append(ws[:, 2 * 2 - 3])
        mod_vector.append(x_global.clone())
        mod_vector = torch.cat(mod_vector, dim=1)

        if self.architecture == "skip":
            img = self.torgb(x, mod_vector)
            img = img.to(dtype=torch.float32, memory_format=torch.contiguous_format)

        assert x.dtype == dtype
        return x, img
class SELayer(nn.Module):
    def __init__(self, channel, reduction=16):
        super(SELayer, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction, bias=False),
            nn.ReLU(inplace=False),
            nn.Linear(channel // reduction, channel, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        res = x * y.expand_as(x)
        return res
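# FourierUnit below performs a convolution in the frequency domain: the input is
# transformed with a real FFT, the real and imaginary parts are stacked as extra
# channels, a 1x1 convolution plus ReLU mixes them, and an inverse real FFT maps the
# result back to the spatial domain, giving the layer an image-wide receptive field.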
class FourierUnit(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        groups=1,
        spatial_scale_factor=None,
        spatial_scale_mode="bilinear",
        spectral_pos_encoding=False,
        use_se=False,
        se_kwargs=None,
        ffc3d=False,
        fft_norm="ortho",
    ):
        # bn_layer not used
        super(FourierUnit, self).__init__()
        self.groups = groups

        self.conv_layer = torch.nn.Conv2d(
            in_channels=in_channels * 2 + (2 if spectral_pos_encoding else 0),
            out_channels=out_channels * 2,
            kernel_size=1,
            stride=1,
            padding=0,
            groups=self.groups,
            bias=False,
        )
        self.relu = torch.nn.ReLU(inplace=False)

        # squeeze and excitation block
        self.use_se = use_se
        if use_se:
            if se_kwargs is None:
                se_kwargs = {}
            self.se = SELayer(self.conv_layer.in_channels, **se_kwargs)

        self.spatial_scale_factor = spatial_scale_factor
        self.spatial_scale_mode = spatial_scale_mode
        self.spectral_pos_encoding = spectral_pos_encoding
        self.ffc3d = ffc3d
        self.fft_norm = fft_norm

    def forward(self, x):
        batch = x.shape[0]

        if self.spatial_scale_factor is not None:
            orig_size = x.shape[-2:]
            x = F.interpolate(
                x,
                scale_factor=self.spatial_scale_factor,
                mode=self.spatial_scale_mode,
                align_corners=False,
            )

        r_size = x.size()
        # (batch, c, h, w/2+1, 2)
        fft_dim = (-3, -2, -1) if self.ffc3d else (-2, -1)
        ffted = fft.rfftn(x, dim=fft_dim, norm=self.fft_norm)
        ffted = torch.stack((ffted.real, ffted.imag), dim=-1)
        ffted = ffted.permute(0, 1, 4, 2, 3).contiguous()  # (batch, c, 2, h, w/2+1)
        ffted = ffted.view(
            (
                batch,
                -1,
            )
            + ffted.size()[3:]
        )

        if self.spectral_pos_encoding:
            height, width = ffted.shape[-2:]
            coords_vert = (
                torch.linspace(0, 1, height)[None, None, :, None]
                .expand(batch, 1, height, width)
                .to(ffted)
            )
            coords_hor = (
                torch.linspace(0, 1, width)[None, None, None, :]
                .expand(batch, 1, height, width)
                .to(ffted)
            )
            ffted = torch.cat((coords_vert, coords_hor, ffted), dim=1)

        if self.use_se:
            ffted = self.se(ffted)

        ffted = self.conv_layer(ffted)  # (batch, c*2, h, w/2+1)
        ffted = self.relu(ffted)

        ffted = (
            ffted.view(
                (
                    batch,
                    -1,
                    2,
                )
                + ffted.size()[2:]
            )
            .permute(0, 1, 3, 4, 2)
            .contiguous()
        )  # (batch,c, t, h, w/2+1, 2)
        ffted = torch.complex(ffted[..., 0], ffted[..., 1])

        ifft_shape_slice = x.shape[-3:] if self.ffc3d else x.shape[-2:]
        output = torch.fft.irfftn(
            ffted, s=ifft_shape_slice, dim=fft_dim, norm=self.fft_norm
        )

        if self.spatial_scale_factor is not None:
            output = F.interpolate(
                output,
                size=orig_size,
                mode=self.spatial_scale_mode,
                align_corners=False,
            )

        return output
class SpectralTransform(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        stride=1,
        groups=1,
        enable_lfu=True,
        **fu_kwargs,
    ):
        # bn_layer not used
        super(SpectralTransform, self).__init__()
        self.enable_lfu = enable_lfu
        if stride == 2:
            self.downsample = nn.AvgPool2d(kernel_size=(2, 2), stride=2)
        else:
            self.downsample = nn.Identity()

        self.stride = stride
        self.conv1 = nn.Sequential(
            nn.Conv2d(
                in_channels, out_channels // 2, kernel_size=1, groups=groups, bias=False
            ),
            # nn.BatchNorm2d(out_channels // 2),
            nn.ReLU(inplace=True),
        )
        self.fu = FourierUnit(out_channels // 2, out_channels // 2, groups, **fu_kwargs)
        if self.enable_lfu:
            self.lfu = FourierUnit(out_channels // 2, out_channels // 2, groups)
        self.conv2 = torch.nn.Conv2d(
            out_channels // 2, out_channels, kernel_size=1, groups=groups, bias=False
        )

    def forward(self, x):
        x = self.downsample(x)
        x = self.conv1(x)
        output = self.fu(x)

        if self.enable_lfu:
            n, c, h, w = x.shape
            split_no = 2
            split_s = h // split_no
            xs = torch.cat(
                torch.split(x[:, : c // 4], split_s, dim=-2), dim=1
            ).contiguous()
            xs = torch.cat(torch.split(xs, split_s, dim=-1), dim=1).contiguous()
            xs = self.lfu(xs)
            xs = xs.repeat(1, 1, split_no, split_no).contiguous()
        else:
            xs = 0

        output = self.conv2(x + output + xs)
        return output
class FFC(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        ratio_gin,
        ratio_gout,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        bias=False,
        enable_lfu=True,
        padding_type="reflect",
        gated=False,
        **spectral_kwargs,
    ):
        super(FFC, self).__init__()

        assert stride == 1 or stride == 2, "Stride should be 1 or 2."
        self.stride = stride

        in_cg = int(in_channels * ratio_gin)
        in_cl = in_channels - in_cg
        out_cg = int(out_channels * ratio_gout)
        out_cl = out_channels - out_cg
        # groups_g = 1 if groups == 1 else int(groups * ratio_gout)
        # groups_l = 1 if groups == 1 else groups - groups_g

        self.ratio_gin = ratio_gin
        self.ratio_gout = ratio_gout
        self.global_in_num = in_cg

        module = nn.Identity if in_cl == 0 or out_cl == 0 else nn.Conv2d
        self.convl2l = module(
            in_cl,
            out_cl,
            kernel_size,
            stride,
            padding,
            dilation,
            groups,
            bias,
            padding_mode=padding_type,
        )
        module = nn.Identity if in_cl == 0 or out_cg == 0 else nn.Conv2d
        self.convl2g = module(
            in_cl,
            out_cg,
            kernel_size,
            stride,
            padding,
            dilation,
            groups,
            bias,
            padding_mode=padding_type,
        )
        module = nn.Identity if in_cg == 0 or out_cl == 0 else nn.Conv2d
        self.convg2l = module(
            in_cg,
            out_cl,
            kernel_size,
            stride,
            padding,
            dilation,
            groups,
            bias,
            padding_mode=padding_type,
        )
        module = nn.Identity if in_cg == 0 or out_cg == 0 else SpectralTransform
        self.convg2g = module(
            in_cg,
            out_cg,
            stride,
            1 if groups == 1 else groups // 2,
            enable_lfu,
            **spectral_kwargs,
        )

        self.gated = gated
        module = (
            nn.Identity if in_cg == 0 or out_cl == 0 or not self.gated else nn.Conv2d
        )
        self.gate = module(in_channels, 2, 1)

    def forward(self, x, fname=None):
        x_l, x_g = x if type(x) is tuple else (x, 0)
        out_xl, out_xg = 0, 0

        if self.gated:
            total_input_parts = [x_l]
            if torch.is_tensor(x_g):
                total_input_parts.append(x_g)
            total_input = torch.cat(total_input_parts, dim=1)

            gates = torch.sigmoid(self.gate(total_input))
            g2l_gate, l2g_gate = gates.chunk(2, dim=1)
        else:
            g2l_gate, l2g_gate = 1, 1

        spec_x = self.convg2g(x_g)

        if self.ratio_gout != 1:
            out_xl = self.convl2l(x_l) + self.convg2l(x_g) * g2l_gate
        if self.ratio_gout != 0:
            out_xg = self.convl2g(x_l) * l2g_gate + spec_x

        return out_xl, out_xg
class FFC_BN_ACT(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        ratio_gin,
        ratio_gout,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        bias=False,
        norm_layer=nn.SyncBatchNorm,
        activation_layer=nn.Identity,
        padding_type="reflect",
        enable_lfu=True,
        **kwargs,
    ):
        super(FFC_BN_ACT, self).__init__()
        self.ffc = FFC(
            in_channels,
            out_channels,
            kernel_size,
            ratio_gin,
            ratio_gout,
            stride,
            padding,
            dilation,
            groups,
            bias,
            enable_lfu,
            padding_type=padding_type,
            **kwargs,
        )
        lnorm = nn.Identity if ratio_gout == 1 else norm_layer
        gnorm = nn.Identity if ratio_gout == 0 else norm_layer
        global_channels = int(out_channels * ratio_gout)
        # self.bn_l = lnorm(out_channels - global_channels)
        # self.bn_g = gnorm(global_channels)

        lact = nn.Identity if ratio_gout == 1 else activation_layer
        gact = nn.Identity if ratio_gout == 0 else activation_layer
        self.act_l = lact(inplace=True)
        self.act_g = gact(inplace=True)

    def forward(self, x, fname=None):
        x_l, x_g = self.ffc(
            x,
            fname=fname,
        )
        x_l = self.act_l(x_l)
        x_g = self.act_g(x_g)
        return x_l, x_g
class FFCResnetBlock(nn.Module):
    def __init__(
        self,
        dim,
        padding_type,
        norm_layer,
        activation_layer=nn.ReLU,
        dilation=1,
        spatial_transform_kwargs=None,
        inline=False,
        ratio_gin=0.75,
        ratio_gout=0.75,
    ):
        super().__init__()
        self.conv1 = FFC_BN_ACT(
            dim,
            dim,
            kernel_size=3,
            padding=dilation,
            dilation=dilation,
            norm_layer=norm_layer,
            activation_layer=activation_layer,
            padding_type=padding_type,
            ratio_gin=ratio_gin,
            ratio_gout=ratio_gout,
        )
        self.conv2 = FFC_BN_ACT(
            dim,
            dim,
            kernel_size=3,
            padding=dilation,
            dilation=dilation,
            norm_layer=norm_layer,
            activation_layer=activation_layer,
            padding_type=padding_type,
            ratio_gin=ratio_gin,
            ratio_gout=ratio_gout,
        )
        self.inline = inline

    def forward(self, x, fname=None):
        if self.inline:
            x_l, x_g = (
                x[:, : -self.conv1.ffc.global_in_num],
                x[:, -self.conv1.ffc.global_in_num :],
            )
        else:
            x_l, x_g = x if type(x) is tuple else (x, 0)

        id_l, id_g = x_l, x_g

        x_l, x_g = self.conv1((x_l, x_g), fname=fname)
        x_l, x_g = self.conv2((x_l, x_g), fname=fname)

        x_l, x_g = id_l + x_l, id_g + x_g
        out = x_l, x_g
        if self.inline:
            out = torch.cat(out, dim=1)
        return out
class ConcatTupleLayer(nn.Module):
    def forward(self, x):
        assert isinstance(x, tuple)
        x_l, x_g = x
        assert torch.is_tensor(x_l) or torch.is_tensor(x_g)
        if not torch.is_tensor(x_g):
            return x_l
        return torch.cat(x, dim=1)
class FFCBlock(torch.nn.Module):
    def __init__(
        self,
        dim,  # Number of output/input channels.
        kernel_size,  # Width and height of the convolution kernel.
        padding,
        ratio_gin=0.75,
        ratio_gout=0.75,
        activation="linear",  # Activation function: 'relu', 'lrelu', etc.
    ):
        super().__init__()
        if activation == "linear":
            self.activation = nn.Identity
        else:
            self.activation = nn.ReLU
        self.padding = padding
        self.kernel_size = kernel_size
        self.ffc_block = FFCResnetBlock(
            dim=dim,
            padding_type="reflect",
            norm_layer=nn.SyncBatchNorm,
            activation_layer=self.activation,
            dilation=1,
            ratio_gin=ratio_gin,
            ratio_gout=ratio_gout,
        )

        self.concat_layer = ConcatTupleLayer()

    def forward(self, gen_ft, mask, fname=None):
        x = gen_ft.float()

        x_l, x_g = (
            x[:, : -self.ffc_block.conv1.ffc.global_in_num],
            x[:, -self.ffc_block.conv1.ffc.global_in_num :],
        )
        id_l, id_g = x_l, x_g

        x_l, x_g = self.ffc_block((x_l, x_g), fname=fname)
        x_l, x_g = id_l + x_l, id_g + x_g
        x = self.concat_layer((x_l, x_g))

        return x + gen_ft.float()
class FFCSkipLayer(torch.nn.Module):
    def __init__(
        self,
        dim,  # Number of input/output channels.
        kernel_size=3,  # Convolution kernel size.
        ratio_gin=0.75,
        ratio_gout=0.75,
    ):
        super().__init__()
        self.padding = kernel_size // 2

        self.ffc_act = FFCBlock(
            dim=dim,
            kernel_size=kernel_size,
            activation=nn.ReLU,
            padding=self.padding,
            ratio_gin=ratio_gin,
            ratio_gout=ratio_gout,
        )

    def forward(self, gen_ft, mask, fname=None):
        x = self.ffc_act(gen_ft, mask, fname=fname)
        return x
class SynthesisBlock(torch.nn.Module):
    def __init__(
        self,
        in_channels,  # Number of input channels, 0 = first block.
        out_channels,  # Number of output channels.
        w_dim,  # Intermediate latent (W) dimensionality.
        resolution,  # Resolution of this block.
        img_channels,  # Number of output color channels.
        is_last,  # Is this the last block?
        architecture="skip",  # Architecture: 'orig', 'skip', 'resnet'.
        resample_filter=[
            1,
            3,
            3,
            1,
        ],  # Low-pass filter to apply when resampling activations.
        conv_clamp=None,  # Clamp the output of convolution layers to +-X, None = disable clamping.
        use_fp16=False,  # Use FP16 for this block?
        fp16_channels_last=False,  # Use channels-last memory format with FP16?
        **layer_kwargs,  # Arguments for SynthesisLayer.
    ):
        assert architecture in ["orig", "skip", "resnet"]
        super().__init__()
        self.in_channels = in_channels
        self.w_dim = w_dim
        self.resolution = resolution
        self.img_channels = img_channels
        self.is_last = is_last
        self.architecture = architecture
        self.use_fp16 = use_fp16
        self.channels_last = use_fp16 and fp16_channels_last
        self.register_buffer("resample_filter", setup_filter(resample_filter))
        self.num_conv = 0
        self.num_torgb = 0
        self.res_ffc = {4: 0, 8: 0, 16: 0, 32: 1, 64: 1, 128: 1, 256: 1, 512: 1}

        if in_channels != 0 and resolution >= 8:
            self.ffc_skip = nn.ModuleList()
            for _ in range(self.res_ffc[resolution]):
                self.ffc_skip.append(FFCSkipLayer(dim=out_channels))

        if in_channels == 0:
            self.const = torch.nn.Parameter(
                torch.randn([out_channels, resolution, resolution])
            )

        if in_channels != 0:
            self.conv0 = SynthesisLayer(
                in_channels,
                out_channels,
                w_dim=w_dim * 3,
                resolution=resolution,
                up=2,
                resample_filter=resample_filter,
                conv_clamp=conv_clamp,
                channels_last=self.channels_last,
                **layer_kwargs,
            )
            self.num_conv += 1

        self.conv1 = SynthesisLayer(
            out_channels,
            out_channels,
            w_dim=w_dim * 3,
            resolution=resolution,
            conv_clamp=conv_clamp,
            channels_last=self.channels_last,
            **layer_kwargs,
        )
        self.num_conv += 1

        if is_last or architecture == "skip":
            self.torgb = ToRGBLayer(
                out_channels,
                img_channels,
                w_dim=w_dim * 3,
                conv_clamp=conv_clamp,
                channels_last=self.channels_last,
            )
            self.num_torgb += 1

        if in_channels != 0 and architecture == "resnet":
            self.skip = Conv2dLayer(
                in_channels,
                out_channels,
                kernel_size=1,
                bias=False,
                up=2,
                resample_filter=resample_filter,
                channels_last=self.channels_last,
            )

    def forward(
        self,
        x,
        mask,
        feats,
        img,
        ws,
        fname=None,
        force_fp32=False,
        fused_modconv=None,
        **layer_kwargs,
    ):
        dtype = torch.float16 if self.use_fp16 and not force_fp32 else torch.float32
        dtype = torch.float32
        memory_format = (
            torch.channels_last
            if self.channels_last and not force_fp32
            else torch.contiguous_format
        )
        if fused_modconv is None:
            fused_modconv = (not self.training) and (
                dtype == torch.float32 or int(x.shape[0]) == 1
            )

        x = x.to(dtype=dtype, memory_format=memory_format)
        x_skip = (
            feats[self.resolution].clone().to(dtype=dtype, memory_format=memory_format)
        )

        # Main layers.
        if self.in_channels == 0:
            x = self.conv1(x, ws[1], fused_modconv=fused_modconv, **layer_kwargs)
        elif self.architecture == "resnet":
            y = self.skip(x, gain=np.sqrt(0.5))
            x = self.conv0(
                x, ws[0].clone(), fused_modconv=fused_modconv, **layer_kwargs
            )
            if len(self.ffc_skip) > 0:
                mask = F.interpolate(
                    mask,
                    size=x_skip.shape[2:],
                )
                z = x + x_skip
                for fres in self.ffc_skip:
                    z = fres(z, mask)
                x = x + z
            else:
                x = x + x_skip
            x = self.conv1(
                x,
                ws[1].clone(),
                fused_modconv=fused_modconv,
                gain=np.sqrt(0.5),
                **layer_kwargs,
            )
            x = y.add_(x)
        else:
            x = self.conv0(
                x, ws[0].clone(), fused_modconv=fused_modconv, **layer_kwargs
            )
            if len(self.ffc_skip) > 0:
                mask = F.interpolate(
                    mask,
                    size=x_skip.shape[2:],
                )
                z = x + x_skip
                for fres in self.ffc_skip:
                    z = fres(z, mask)
                x = x + z
            else:
                x = x + x_skip
            x = self.conv1(
                x, ws[1].clone(), fused_modconv=fused_modconv, **layer_kwargs
            )

        # ToRGB.
        if img is not None:
            img = upsample2d(img, self.resample_filter)
        if self.is_last or self.architecture == "skip":
            y = self.torgb(x, ws[2].clone(), fused_modconv=fused_modconv)
            y = y.to(dtype=torch.float32, memory_format=torch.contiguous_format)
            img = img.add_(y) if img is not None else y
        x = x.to(dtype=dtype)
        assert x.dtype == dtype
        assert img is None or img.dtype == torch.float32
        return x, img
class SynthesisNetwork(torch.nn.Module):
    def __init__(
        self,
        w_dim,  # Intermediate latent (W) dimensionality.
        z_dim,  # Output Latent (Z) dimensionality.
        img_resolution,  # Output image resolution.
        img_channels,  # Number of color channels.
        channel_base=16384,  # Overall multiplier for the number of channels.
        channel_max=512,  # Maximum number of channels in any layer.
        num_fp16_res=0,  # Use FP16 for the N highest resolutions.
        **block_kwargs,  # Arguments for SynthesisBlock.
    ):
        assert img_resolution >= 4 and img_resolution & (img_resolution - 1) == 0
        super().__init__()
        self.w_dim = w_dim
        self.img_resolution = img_resolution
        self.img_resolution_log2 = int(np.log2(img_resolution))
        self.img_channels = img_channels
        self.block_resolutions = [
            2**i for i in range(3, self.img_resolution_log2 + 1)
        ]
        channels_dict = {
            res: min(channel_base // res, channel_max) for res in self.block_resolutions
        }
        fp16_resolution = max(2 ** (self.img_resolution_log2 + 1 - num_fp16_res), 8)

        self.foreword = SynthesisForeword(
            img_channels=img_channels,
            in_channels=min(channel_base // 4, channel_max),
            z_dim=z_dim * 2,
            resolution=4,
        )

        self.num_ws = self.img_resolution_log2 * 2 - 2
        for res in self.block_resolutions:
            if res // 2 in channels_dict.keys():
                in_channels = channels_dict[res // 2] if res > 4 else 0
            else:
                in_channels = min(channel_base // (res // 2), channel_max)
            out_channels = channels_dict[res]
            use_fp16 = res >= fp16_resolution
            use_fp16 = False
            is_last = res == self.img_resolution
            block = SynthesisBlock(
                in_channels,
                out_channels,
                w_dim=w_dim,
                resolution=res,
                img_channels=img_channels,
                is_last=is_last,
                use_fp16=use_fp16,
                **block_kwargs,
            )
            setattr(self, f"b{res}", block)

    def forward(self, x_global, mask, feats, ws, fname=None, **block_kwargs):
        img = None

        x, img = self.foreword(x_global, ws, feats, img)

        for res in self.block_resolutions:
            block = getattr(self, f"b{res}")
            mod_vector0 = []
            mod_vector0.append(ws[:, int(np.log2(res)) * 2 - 5])
            mod_vector0.append(x_global.clone())
            mod_vector0 = torch.cat(mod_vector0, dim=1)

            mod_vector1 = []
            mod_vector1.append(ws[:, int(np.log2(res)) * 2 - 4])
            mod_vector1.append(x_global.clone())
            mod_vector1 = torch.cat(mod_vector1, dim=1)

            mod_vector_rgb = []
            mod_vector_rgb.append(ws[:, int(np.log2(res)) * 2 - 3])
            mod_vector_rgb.append(x_global.clone())
            mod_vector_rgb = torch.cat(mod_vector_rgb, dim=1)
            x, img = block(
                x,
                mask,
                feats,
                img,
                (mod_vector0, mod_vector1, mod_vector_rgb),
                fname=fname,
                **block_kwargs,
            )
        return img
class MappingNetwork(torch.nn.Module):
    def __init__(
        self,
        z_dim,  # Input latent (Z) dimensionality, 0 = no latent.
        c_dim,  # Conditioning label (C) dimensionality, 0 = no label.
        w_dim,  # Intermediate latent (W) dimensionality.
        num_ws,  # Number of intermediate latents to output, None = do not broadcast.
        num_layers=8,  # Number of mapping layers.
        embed_features=None,  # Label embedding dimensionality, None = same as w_dim.
        layer_features=None,  # Number of intermediate features in the mapping layers, None = same as w_dim.
        activation="lrelu",  # Activation function: 'relu', 'lrelu', etc.
        lr_multiplier=0.01,  # Learning rate multiplier for the mapping layers.
        w_avg_beta=0.995,  # Decay for tracking the moving average of W during training, None = do not track.
    ):
        super().__init__()
        self.z_dim = z_dim
        self.c_dim = c_dim
        self.w_dim = w_dim
        self.num_ws = num_ws
        self.num_layers = num_layers
        self.w_avg_beta = w_avg_beta

        if embed_features is None:
            embed_features = w_dim
        if c_dim == 0:
            embed_features = 0
        if layer_features is None:
            layer_features = w_dim
        features_list = (
            [z_dim + embed_features] + [layer_features] * (num_layers - 1) + [w_dim]
        )

        if c_dim > 0:
            self.embed = FullyConnectedLayer(c_dim, embed_features)
        for idx in range(num_layers):
            in_features = features_list[idx]
            out_features = features_list[idx + 1]
            layer = FullyConnectedLayer(
                in_features,
                out_features,
                activation=activation,
                lr_multiplier=lr_multiplier,
            )
            setattr(self, f"fc{idx}", layer)

        if num_ws is not None and w_avg_beta is not None:
            self.register_buffer("w_avg", torch.zeros([w_dim]))

    def forward(
        self, z, c, truncation_psi=1, truncation_cutoff=None, skip_w_avg_update=False
    ):
        # Embed, normalize, and concat inputs.
        x = None
        with torch.autograd.profiler.record_function("input"):
            if self.z_dim > 0:
                x = normalize_2nd_moment(z.to(torch.float32))
            if self.c_dim > 0:
                y = normalize_2nd_moment(self.embed(c.to(torch.float32)))
                x = torch.cat([x, y], dim=1) if x is not None else y

        # Main layers.
        for idx in range(self.num_layers):
            layer = getattr(self, f"fc{idx}")
            x = layer(x)

        # Update moving average of W.
        if self.w_avg_beta is not None and self.training and not skip_w_avg_update:
            with torch.autograd.profiler.record_function("update_w_avg"):
                self.w_avg.copy_(
                    x.detach().mean(dim=0).lerp(self.w_avg, self.w_avg_beta)
                )

        # Broadcast.
        if self.num_ws is not None:
            with torch.autograd.profiler.record_function("broadcast"):
                x = x.unsqueeze(1).repeat([1, self.num_ws, 1])

        # Apply truncation.
        if truncation_psi != 1:
            with torch.autograd.profiler.record_function("truncate"):
                assert self.w_avg_beta is not None
                if self.num_ws is None or truncation_cutoff is None:
                    x = self.w_avg.lerp(x, truncation_psi)
                else:
                    x[:, :truncation_cutoff] = self.w_avg.lerp(
                        x[:, :truncation_cutoff], truncation_psi
                    )
        return x
class Generator(torch.nn.Module):
    def __init__(
        self,
        z_dim,  # Input latent (Z) dimensionality.
        c_dim,  # Conditioning label (C) dimensionality.
        w_dim,  # Intermediate latent (W) dimensionality.
        img_resolution,  # Output resolution.
        img_channels,  # Number of output color channels.
        encoder_kwargs={},  # Arguments for EncoderNetwork.
        mapping_kwargs={},  # Arguments for MappingNetwork.
        synthesis_kwargs={},  # Arguments for SynthesisNetwork.
    ):
        super().__init__()
        self.z_dim = z_dim
        self.c_dim = c_dim
        self.w_dim = w_dim
        self.img_resolution = img_resolution
        self.img_channels = img_channels
        self.encoder = EncoderNetwork(
            c_dim=c_dim,
            z_dim=z_dim,
            img_resolution=img_resolution,
            img_channels=img_channels,
            **encoder_kwargs,
        )
        self.synthesis = SynthesisNetwork(
            z_dim=z_dim,
            w_dim=w_dim,
            img_resolution=img_resolution,
            img_channels=img_channels,
            **synthesis_kwargs,
        )
        self.num_ws = self.synthesis.num_ws
        self.mapping = MappingNetwork(
            z_dim=z_dim, c_dim=c_dim, w_dim=w_dim, num_ws=self.num_ws, **mapping_kwargs
        )

    def forward(
        self,
        img,
        c,
        fname=None,
        truncation_psi=1,
        truncation_cutoff=None,
        **synthesis_kwargs,
    ):
        mask = img[:, -1].unsqueeze(1)
        x_global, z, feats = self.encoder(img, c)
        ws = self.mapping(
            z, c, truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff
        )
        img = self.synthesis(x_global, mask, feats, ws, fname=fname, **synthesis_kwargs)
        return img
FCF_MODEL_URL = os.environ.get(
    "FCF_MODEL_URL",
    "https://github.com/Sanster/models/releases/download/add_fcf/places_512_G.pth",
)
FCF_MODEL_MD5 = os.environ.get("FCF_MODEL_MD5", "3323152bc01bf1c56fd8aba74435a211")
class FcF(InpaintModel):
    name = "fcf"
    min_size = 512
    pad_mod = 512
    pad_to_square = True
    is_erase_model = True

    def init_model(self, device, **kwargs):
        seed = 0
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

        kwargs = {
            "channel_base": 1 * 32768,
            "channel_max": 512,
            "num_fp16_res": 4,
            "conv_clamp": 256,
        }
        G = Generator(
            z_dim=512,
            c_dim=0,
            w_dim=512,
            img_resolution=512,
            img_channels=3,
            synthesis_kwargs=kwargs,
            encoder_kwargs=kwargs,
            mapping_kwargs={"num_layers": 2},
        )
        self.model = load_model(G, FCF_MODEL_URL, device, FCF_MODEL_MD5)
        self.label = torch.zeros([1, self.model.c_dim], device=device)

    @staticmethod
    def download():
        download_model(FCF_MODEL_URL, FCF_MODEL_MD5)

    @staticmethod
    def is_downloaded() -> bool:
        return os.path.exists(get_cache_path_by_url(FCF_MODEL_URL))

    @torch.no_grad()
    def __call__(self, image, mask, config: InpaintRequest):
        """
        images: [H, W, C] RGB, not normalized
        masks: [H, W]
        return: BGR IMAGE
        """
        if image.shape[0] == 512 and image.shape[1] == 512:
            return self._pad_forward(image, mask, config)

        boxes = boxes_from_mask(mask)
        crop_result = []
        config.hd_strategy_crop_margin = 128
        for box in boxes:
            crop_image, crop_mask, crop_box = self._crop_box(image, mask, box, config)
            origin_size = crop_image.shape[:2]
            resize_image = resize_max_size(crop_image, size_limit=512)
            resize_mask = resize_max_size(crop_mask, size_limit=512)
            inpaint_result = self._pad_forward(resize_image, resize_mask, config)

            # only paste masked area result
            inpaint_result = cv2.resize(
                inpaint_result,
                (origin_size[1], origin_size[0]),
                interpolation=cv2.INTER_CUBIC,
            )

            original_pixel_indices = crop_mask < 127
            inpaint_result[original_pixel_indices] = crop_image[:, :, ::-1][
                original_pixel_indices
            ]

            crop_result.append((inpaint_result, crop_box))

        inpaint_result = image[:, :, ::-1].copy()
        for crop_image, crop_box in crop_result:
            x1, y1, x2, y2 = crop_box
            inpaint_result[y1:y2, x1:x2, :] = crop_image

        return inpaint_result

    def forward(self, image, mask, config: InpaintRequest):
        """Input images and output images have same size
        images: [H, W, C] RGB
        masks: [H, W] mask area == 255
        return: BGR IMAGE
        """
        image = norm_img(image)  # [0, 1]
        image = image * 2 - 1  # [0, 1] -> [-1, 1]
        mask = (mask > 120) * 255
        mask = norm_img(mask)

        image = torch.from_numpy(image).unsqueeze(0).to(self.device)
        mask = torch.from_numpy(mask).unsqueeze(0).to(self.device)

        erased_img = image * (1 - mask)
        input_image = torch.cat([0.5 - mask, erased_img], dim=1)

        output = self.model(
            input_image, self.label, truncation_psi=0.1, noise_mode="none"
        )
        output = (
            (output.permute(0, 2, 3, 1) * 127.5 + 127.5)
            .round()
            .clamp(0, 255)
            .to(torch.uint8)
        )
        output = output[0].cpu().numpy()
        cur_res = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
        return cur_res
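# Minimal usage sketch (illustrative only; the exact wiring depends on the surrounding
# sorawm.iopaint application). It assumes InpaintModel.__init__ accepts a device string
# and that InpaintRequest can be constructed with default values:
#
#   model = FcF("cuda")  # downloads places_512_G.pth on first use
#   image_rgb = cv2.cvtColor(cv2.imread("photo.png"), cv2.COLOR_BGR2RGB)  # [H, W, 3] RGB
#   mask = cv2.imread("mask.png", cv2.IMREAD_GRAYSCALE)  # [H, W], 255 = area to erase
#   result_bgr = model(image_rgb, mask, InpaintRequest())
#   cv2.imwrite("result.png", result_bgr)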