utils.py 37 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032
  1. import collections
  2. import gc
  3. import math
  4. import random
  5. import traceback
  6. from itertools import repeat
  7. from typing import Any
  8. import numpy as np
  9. import torch
  10. from diffusers import (
  11. DDIMScheduler,
  12. DPMSolverMultistepScheduler,
  13. DPMSolverSinglestepScheduler,
  14. EulerAncestralDiscreteScheduler,
  15. EulerDiscreteScheduler,
  16. HeunDiscreteScheduler,
  17. KDPM2AncestralDiscreteScheduler,
  18. KDPM2DiscreteScheduler,
  19. LCMScheduler,
  20. LMSDiscreteScheduler,
  21. PNDMScheduler,
  22. UniPCMultistepScheduler,
  23. )
  24. from loguru import logger
  25. from torch import conv2d, conv_transpose2d
  26. from sorawm.iopaint.schema import SDSampler
  27. def make_beta_schedule(
  28. device, schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3
  29. ):
  30. if schedule == "linear":
  31. betas = (
  32. torch.linspace(
  33. linear_start**0.5, linear_end**0.5, n_timestep, dtype=torch.float64
  34. )
  35. ** 2
  36. )
  37. elif schedule == "cosine":
  38. timesteps = (
  39. torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
  40. ).to(device)
  41. alphas = timesteps / (1 + cosine_s) * np.pi / 2
  42. alphas = torch.cos(alphas).pow(2).to(device)
  43. alphas = alphas / alphas[0]
  44. betas = 1 - alphas[1:] / alphas[:-1]
  45. betas = np.clip(betas, a_min=0, a_max=0.999)
  46. elif schedule == "sqrt_linear":
  47. betas = torch.linspace(
  48. linear_start, linear_end, n_timestep, dtype=torch.float64
  49. )
  50. elif schedule == "sqrt":
  51. betas = (
  52. torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
  53. ** 0.5
  54. )
  55. else:
  56. raise ValueError(f"schedule '{schedule}' unknown.")
  57. return betas.numpy()
  58. def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
  59. # select alphas for computing the variance schedule
  60. alphas = alphacums[ddim_timesteps]
  61. alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())
  62. # according the the formula provided in https://arxiv.org/abs/2010.02502
  63. sigmas = eta * np.sqrt(
  64. (1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)
  65. )
  66. if verbose:
  67. print(
  68. f"Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}"
  69. )
  70. print(
  71. f"For the chosen value of eta, which is {eta}, "
  72. f"this results in the following sigma_t schedule for ddim sampler {sigmas}"
  73. )
  74. return sigmas, alphas, alphas_prev
  75. def make_ddim_timesteps(
  76. ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True
  77. ):
  78. if ddim_discr_method == "uniform":
  79. c = num_ddpm_timesteps // num_ddim_timesteps
  80. ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
  81. elif ddim_discr_method == "quad":
  82. ddim_timesteps = (
  83. (np.linspace(0, np.sqrt(num_ddpm_timesteps * 0.8), num_ddim_timesteps)) ** 2
  84. ).astype(int)
  85. else:
  86. raise NotImplementedError(
  87. f'There is no ddim discretization method called "{ddim_discr_method}"'
  88. )
  89. # assert ddim_timesteps.shape[0] == num_ddim_timesteps
  90. # add one to get the final alpha values right (the ones from first scale to data during sampling)
  91. steps_out = ddim_timesteps + 1
  92. if verbose:
  93. print(f"Selected timesteps for ddim sampler: {steps_out}")
  94. return steps_out
  95. def noise_like(shape, device, repeat=False):
  96. repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(
  97. shape[0], *((1,) * (len(shape) - 1))
  98. )
  99. noise = lambda: torch.randn(shape, device=device)
  100. return repeat_noise() if repeat else noise()
  101. def timestep_embedding(device, timesteps, dim, max_period=10000, repeat_only=False):
  102. """
  103. Create sinusoidal timestep embeddings.
  104. :param timesteps: a 1-D Tensor of N indices, one per batch element.
  105. These may be fractional.
  106. :param dim: the dimension of the output.
  107. :param max_period: controls the minimum frequency of the embeddings.
  108. :return: an [N x dim] Tensor of positional embeddings.
  109. """
  110. half = dim // 2
  111. freqs = torch.exp(
  112. -math.log(max_period)
  113. * torch.arange(start=0, end=half, dtype=torch.float32)
  114. / half
  115. ).to(device=device)
  116. args = timesteps[:, None].float() * freqs[None]
  117. embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
  118. if dim % 2:
  119. embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
  120. return embedding
  121. ###### MAT and FcF #######
  122. def normalize_2nd_moment(x, dim=1):
  123. return (
  124. x * (x.square().mean(dim=dim, keepdim=True) + torch.finfo(x.dtype).eps).rsqrt()
  125. )
  126. class EasyDict(dict):
  127. """Convenience class that behaves like a dict but allows access with the attribute syntax."""
  128. def __getattr__(self, name: str) -> Any:
  129. try:
  130. return self[name]
  131. except KeyError:
  132. raise AttributeError(name)
  133. def __setattr__(self, name: str, value: Any) -> None:
  134. self[name] = value
  135. def __delattr__(self, name: str) -> None:
  136. del self[name]
  137. def _bias_act_ref(x, b=None, dim=1, act="linear", alpha=None, gain=None, clamp=None):
  138. """Slow reference implementation of `bias_act()` using standard TensorFlow ops."""
  139. assert isinstance(x, torch.Tensor)
  140. assert clamp is None or clamp >= 0
  141. spec = activation_funcs[act]
  142. alpha = float(alpha if alpha is not None else spec.def_alpha)
  143. gain = float(gain if gain is not None else spec.def_gain)
  144. clamp = float(clamp if clamp is not None else -1)
  145. # Add bias.
  146. if b is not None:
  147. assert isinstance(b, torch.Tensor) and b.ndim == 1
  148. assert 0 <= dim < x.ndim
  149. assert b.shape[0] == x.shape[dim]
  150. x = x + b.reshape([-1 if i == dim else 1 for i in range(x.ndim)])
  151. # Evaluate activation function.
  152. alpha = float(alpha)
  153. x = spec.func(x, alpha=alpha)
  154. # Scale by gain.
  155. gain = float(gain)
  156. if gain != 1:
  157. x = x * gain
  158. # Clamp.
  159. if clamp >= 0:
  160. x = x.clamp(-clamp, clamp) # pylint: disable=invalid-unary-operand-type
  161. return x
  162. def bias_act(
  163. x, b=None, dim=1, act="linear", alpha=None, gain=None, clamp=None, impl="ref"
  164. ):
  165. r"""Fused bias and activation function.
  166. Adds bias `b` to activation tensor `x`, evaluates activation function `act`,
  167. and scales the result by `gain`. Each of the steps is optional. In most cases,
  168. the fused op is considerably more efficient than performing the same calculation
  169. using standard PyTorch ops. It supports first and second order gradients,
  170. but not third order gradients.
  171. Args:
  172. x: Input activation tensor. Can be of any shape.
  173. b: Bias vector, or `None` to disable. Must be a 1D tensor of the same type
  174. as `x`. The shape must be known, and it must match the dimension of `x`
  175. corresponding to `dim`.
  176. dim: The dimension in `x` corresponding to the elements of `b`.
  177. The value of `dim` is ignored if `b` is not specified.
  178. act: Name of the activation function to evaluate, or `"linear"` to disable.
  179. Can be e.g. `"relu"`, `"lrelu"`, `"tanh"`, `"sigmoid"`, `"swish"`, etc.
  180. See `activation_funcs` for a full list. `None` is not allowed.
  181. alpha: Shape parameter for the activation function, or `None` to use the default.
  182. gain: Scaling factor for the output tensor, or `None` to use default.
  183. See `activation_funcs` for the default scaling of each activation function.
  184. If unsure, consider specifying 1.
  185. clamp: Clamp the output values to `[-clamp, +clamp]`, or `None` to disable
  186. the clamping (default).
  187. impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default).
  188. Returns:
  189. Tensor of the same shape and datatype as `x`.
  190. """
  191. assert isinstance(x, torch.Tensor)
  192. assert impl in ["ref", "cuda"]
  193. return _bias_act_ref(
  194. x=x, b=b, dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp
  195. )
  196. def _get_filter_size(f):
  197. if f is None:
  198. return 1, 1
  199. assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
  200. fw = f.shape[-1]
  201. fh = f.shape[0]
  202. fw = int(fw)
  203. fh = int(fh)
  204. assert fw >= 1 and fh >= 1
  205. return fw, fh
  206. def _get_weight_shape(w):
  207. shape = [int(sz) for sz in w.shape]
  208. return shape
  209. def _parse_scaling(scaling):
  210. if isinstance(scaling, int):
  211. scaling = [scaling, scaling]
  212. assert isinstance(scaling, (list, tuple))
  213. assert all(isinstance(x, int) for x in scaling)
  214. sx, sy = scaling
  215. assert sx >= 1 and sy >= 1
  216. return sx, sy
  217. def _parse_padding(padding):
  218. if isinstance(padding, int):
  219. padding = [padding, padding]
  220. assert isinstance(padding, (list, tuple))
  221. assert all(isinstance(x, int) for x in padding)
  222. if len(padding) == 2:
  223. padx, pady = padding
  224. padding = [padx, padx, pady, pady]
  225. padx0, padx1, pady0, pady1 = padding
  226. return padx0, padx1, pady0, pady1
  227. def setup_filter(
  228. f,
  229. device=torch.device("cpu"),
  230. normalize=True,
  231. flip_filter=False,
  232. gain=1,
  233. separable=None,
  234. ):
  235. r"""Convenience function to setup 2D FIR filter for `upfirdn2d()`.
  236. Args:
  237. f: Torch tensor, numpy array, or python list of the shape
  238. `[filter_height, filter_width]` (non-separable),
  239. `[filter_taps]` (separable),
  240. `[]` (impulse), or
  241. `None` (identity).
  242. device: Result device (default: cpu).
  243. normalize: Normalize the filter so that it retains the magnitude
  244. for constant input signal (DC)? (default: True).
  245. flip_filter: Flip the filter? (default: False).
  246. gain: Overall scaling factor for signal magnitude (default: 1).
  247. separable: Return a separable filter? (default: select automatically).
  248. Returns:
  249. Float32 tensor of the shape
  250. `[filter_height, filter_width]` (non-separable) or
  251. `[filter_taps]` (separable).
  252. """
  253. # Validate.
  254. if f is None:
  255. f = 1
  256. f = torch.as_tensor(f, dtype=torch.float32)
  257. assert f.ndim in [0, 1, 2]
  258. assert f.numel() > 0
  259. if f.ndim == 0:
  260. f = f[np.newaxis]
  261. # Separable?
  262. if separable is None:
  263. separable = f.ndim == 1 and f.numel() >= 8
  264. if f.ndim == 1 and not separable:
  265. f = f.ger(f)
  266. assert f.ndim == (1 if separable else 2)
  267. # Apply normalize, flip, gain, and device.
  268. if normalize:
  269. f /= f.sum()
  270. if flip_filter:
  271. f = f.flip(list(range(f.ndim)))
  272. f = f * (gain ** (f.ndim / 2))
  273. f = f.to(device=device)
  274. return f
  275. def _ntuple(n):
  276. def parse(x):
  277. if isinstance(x, collections.abc.Iterable):
  278. return x
  279. return tuple(repeat(x, n))
  280. return parse
  281. to_2tuple = _ntuple(2)
  282. activation_funcs = {
  283. "linear": EasyDict(
  284. func=lambda x, **_: x,
  285. def_alpha=0,
  286. def_gain=1,
  287. cuda_idx=1,
  288. ref="",
  289. has_2nd_grad=False,
  290. ),
  291. "relu": EasyDict(
  292. func=lambda x, **_: torch.nn.functional.relu(x),
  293. def_alpha=0,
  294. def_gain=np.sqrt(2),
  295. cuda_idx=2,
  296. ref="y",
  297. has_2nd_grad=False,
  298. ),
  299. "lrelu": EasyDict(
  300. func=lambda x, alpha, **_: torch.nn.functional.leaky_relu(x, alpha),
  301. def_alpha=0.2,
  302. def_gain=np.sqrt(2),
  303. cuda_idx=3,
  304. ref="y",
  305. has_2nd_grad=False,
  306. ),
  307. "tanh": EasyDict(
  308. func=lambda x, **_: torch.tanh(x),
  309. def_alpha=0,
  310. def_gain=1,
  311. cuda_idx=4,
  312. ref="y",
  313. has_2nd_grad=True,
  314. ),
  315. "sigmoid": EasyDict(
  316. func=lambda x, **_: torch.sigmoid(x),
  317. def_alpha=0,
  318. def_gain=1,
  319. cuda_idx=5,
  320. ref="y",
  321. has_2nd_grad=True,
  322. ),
  323. "elu": EasyDict(
  324. func=lambda x, **_: torch.nn.functional.elu(x),
  325. def_alpha=0,
  326. def_gain=1,
  327. cuda_idx=6,
  328. ref="y",
  329. has_2nd_grad=True,
  330. ),
  331. "selu": EasyDict(
  332. func=lambda x, **_: torch.nn.functional.selu(x),
  333. def_alpha=0,
  334. def_gain=1,
  335. cuda_idx=7,
  336. ref="y",
  337. has_2nd_grad=True,
  338. ),
  339. "softplus": EasyDict(
  340. func=lambda x, **_: torch.nn.functional.softplus(x),
  341. def_alpha=0,
  342. def_gain=1,
  343. cuda_idx=8,
  344. ref="y",
  345. has_2nd_grad=True,
  346. ),
  347. "swish": EasyDict(
  348. func=lambda x, **_: torch.sigmoid(x) * x,
  349. def_alpha=0,
  350. def_gain=np.sqrt(2),
  351. cuda_idx=9,
  352. ref="x",
  353. has_2nd_grad=True,
  354. ),
  355. }
  356. def upfirdn2d(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1, impl="cuda"):
  357. r"""Pad, upsample, filter, and downsample a batch of 2D images.
  358. Performs the following sequence of operations for each channel:
  359. 1. Upsample the image by inserting N-1 zeros after each pixel (`up`).
  360. 2. Pad the image with the specified number of zeros on each side (`padding`).
  361. Negative padding corresponds to cropping the image.
  362. 3. Convolve the image with the specified 2D FIR filter (`f`), shrinking it
  363. so that the footprint of all output pixels lies within the input image.
  364. 4. Downsample the image by keeping every Nth pixel (`down`).
  365. This sequence of operations bears close resemblance to scipy.signal.upfirdn().
  366. The fused op is considerably more efficient than performing the same calculation
  367. using standard PyTorch ops. It supports gradients of arbitrary order.
  368. Args:
  369. x: Float32/float64/float16 input tensor of the shape
  370. `[batch_size, num_channels, in_height, in_width]`.
  371. f: Float32 FIR filter of the shape
  372. `[filter_height, filter_width]` (non-separable),
  373. `[filter_taps]` (separable), or
  374. `None` (identity).
  375. up: Integer upsampling factor. Can be a single int or a list/tuple
  376. `[x, y]` (default: 1).
  377. down: Integer downsampling factor. Can be a single int or a list/tuple
  378. `[x, y]` (default: 1).
  379. padding: Padding with respect to the upsampled image. Can be a single number
  380. or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
  381. (default: 0).
  382. flip_filter: False = convolution, True = correlation (default: False).
  383. gain: Overall scaling factor for signal magnitude (default: 1).
  384. impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
  385. Returns:
  386. Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
  387. """
  388. # assert isinstance(x, torch.Tensor)
  389. # assert impl in ['ref', 'cuda']
  390. return _upfirdn2d_ref(
  391. x, f, up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain
  392. )
  393. def _upfirdn2d_ref(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1):
  394. """Slow reference implementation of `upfirdn2d()` using standard PyTorch ops."""
  395. # Validate arguments.
  396. assert isinstance(x, torch.Tensor) and x.ndim == 4
  397. if f is None:
  398. f = torch.ones([1, 1], dtype=torch.float32, device=x.device)
  399. assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
  400. assert not f.requires_grad
  401. batch_size, num_channels, in_height, in_width = x.shape
  402. # upx, upy = _parse_scaling(up)
  403. # downx, downy = _parse_scaling(down)
  404. upx, upy = up, up
  405. downx, downy = down, down
  406. # padx0, padx1, pady0, pady1 = _parse_padding(padding)
  407. padx0, padx1, pady0, pady1 = padding[0], padding[1], padding[2], padding[3]
  408. # Upsample by inserting zeros.
  409. x = x.reshape([batch_size, num_channels, in_height, 1, in_width, 1])
  410. x = torch.nn.functional.pad(x, [0, upx - 1, 0, 0, 0, upy - 1])
  411. x = x.reshape([batch_size, num_channels, in_height * upy, in_width * upx])
  412. # Pad or crop.
  413. x = torch.nn.functional.pad(
  414. x, [max(padx0, 0), max(padx1, 0), max(pady0, 0), max(pady1, 0)]
  415. )
  416. x = x[
  417. :,
  418. :,
  419. max(-pady0, 0) : x.shape[2] - max(-pady1, 0),
  420. max(-padx0, 0) : x.shape[3] - max(-padx1, 0),
  421. ]
  422. # Setup filter.
  423. f = f * (gain ** (f.ndim / 2))
  424. f = f.to(x.dtype)
  425. if not flip_filter:
  426. f = f.flip(list(range(f.ndim)))
  427. # Convolve with the filter.
  428. f = f[np.newaxis, np.newaxis].repeat([num_channels, 1] + [1] * f.ndim)
  429. if f.ndim == 4:
  430. x = conv2d(input=x, weight=f, groups=num_channels)
  431. else:
  432. x = conv2d(input=x, weight=f.unsqueeze(2), groups=num_channels)
  433. x = conv2d(input=x, weight=f.unsqueeze(3), groups=num_channels)
  434. # Downsample by throwing away pixels.
  435. x = x[:, :, ::downy, ::downx]
  436. return x
  437. def downsample2d(x, f, down=2, padding=0, flip_filter=False, gain=1, impl="cuda"):
  438. r"""Downsample a batch of 2D images using the given 2D FIR filter.
  439. By default, the result is padded so that its shape is a fraction of the input.
  440. User-specified padding is applied on top of that, with negative values
  441. indicating cropping. Pixels outside the image are assumed to be zero.
  442. Args:
  443. x: Float32/float64/float16 input tensor of the shape
  444. `[batch_size, num_channels, in_height, in_width]`.
  445. f: Float32 FIR filter of the shape
  446. `[filter_height, filter_width]` (non-separable),
  447. `[filter_taps]` (separable), or
  448. `None` (identity).
  449. down: Integer downsampling factor. Can be a single int or a list/tuple
  450. `[x, y]` (default: 1).
  451. padding: Padding with respect to the input. Can be a single number or a
  452. list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
  453. (default: 0).
  454. flip_filter: False = convolution, True = correlation (default: False).
  455. gain: Overall scaling factor for signal magnitude (default: 1).
  456. impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
  457. Returns:
  458. Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
  459. """
  460. downx, downy = _parse_scaling(down)
  461. # padx0, padx1, pady0, pady1 = _parse_padding(padding)
  462. padx0, padx1, pady0, pady1 = padding, padding, padding, padding
  463. fw, fh = _get_filter_size(f)
  464. p = [
  465. padx0 + (fw - downx + 1) // 2,
  466. padx1 + (fw - downx) // 2,
  467. pady0 + (fh - downy + 1) // 2,
  468. pady1 + (fh - downy) // 2,
  469. ]
  470. return upfirdn2d(
  471. x, f, down=down, padding=p, flip_filter=flip_filter, gain=gain, impl=impl
  472. )
  473. def upsample2d(x, f, up=2, padding=0, flip_filter=False, gain=1, impl="cuda"):
  474. r"""Upsample a batch of 2D images using the given 2D FIR filter.
  475. By default, the result is padded so that its shape is a multiple of the input.
  476. User-specified padding is applied on top of that, with negative values
  477. indicating cropping. Pixels outside the image are assumed to be zero.
  478. Args:
  479. x: Float32/float64/float16 input tensor of the shape
  480. `[batch_size, num_channels, in_height, in_width]`.
  481. f: Float32 FIR filter of the shape
  482. `[filter_height, filter_width]` (non-separable),
  483. `[filter_taps]` (separable), or
  484. `None` (identity).
  485. up: Integer upsampling factor. Can be a single int or a list/tuple
  486. `[x, y]` (default: 1).
  487. padding: Padding with respect to the output. Can be a single number or a
  488. list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
  489. (default: 0).
  490. flip_filter: False = convolution, True = correlation (default: False).
  491. gain: Overall scaling factor for signal magnitude (default: 1).
  492. impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
  493. Returns:
  494. Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
  495. """
  496. upx, upy = _parse_scaling(up)
  497. # upx, upy = up, up
  498. padx0, padx1, pady0, pady1 = _parse_padding(padding)
  499. # padx0, padx1, pady0, pady1 = padding, padding, padding, padding
  500. fw, fh = _get_filter_size(f)
  501. p = [
  502. padx0 + (fw + upx - 1) // 2,
  503. padx1 + (fw - upx) // 2,
  504. pady0 + (fh + upy - 1) // 2,
  505. pady1 + (fh - upy) // 2,
  506. ]
  507. return upfirdn2d(
  508. x,
  509. f,
  510. up=up,
  511. padding=p,
  512. flip_filter=flip_filter,
  513. gain=gain * upx * upy,
  514. impl=impl,
  515. )
  516. class MinibatchStdLayer(torch.nn.Module):
  517. def __init__(self, group_size, num_channels=1):
  518. super().__init__()
  519. self.group_size = group_size
  520. self.num_channels = num_channels
  521. def forward(self, x):
  522. N, C, H, W = x.shape
  523. G = (
  524. torch.min(torch.as_tensor(self.group_size), torch.as_tensor(N))
  525. if self.group_size is not None
  526. else N
  527. )
  528. F = self.num_channels
  529. c = C // F
  530. y = x.reshape(
  531. G, -1, F, c, H, W
  532. ) # [GnFcHW] Split minibatch N into n groups of size G, and channels C into F groups of size c.
  533. y = y - y.mean(dim=0) # [GnFcHW] Subtract mean over group.
  534. y = y.square().mean(dim=0) # [nFcHW] Calc variance over group.
  535. y = (y + 1e-8).sqrt() # [nFcHW] Calc stddev over group.
  536. y = y.mean(dim=[2, 3, 4]) # [nF] Take average over channels and pixels.
  537. y = y.reshape(-1, F, 1, 1) # [nF11] Add missing dimensions.
  538. y = y.repeat(G, 1, H, W) # [NFHW] Replicate over group and pixels.
  539. x = torch.cat([x, y], dim=1) # [NCHW] Append to input as new channels.
  540. return x
  541. class FullyConnectedLayer(torch.nn.Module):
  542. def __init__(
  543. self,
  544. in_features, # Number of input features.
  545. out_features, # Number of output features.
  546. bias=True, # Apply additive bias before the activation function?
  547. activation="linear", # Activation function: 'relu', 'lrelu', etc.
  548. lr_multiplier=1, # Learning rate multiplier.
  549. bias_init=0, # Initial value for the additive bias.
  550. ):
  551. super().__init__()
  552. self.weight = torch.nn.Parameter(
  553. torch.randn([out_features, in_features]) / lr_multiplier
  554. )
  555. self.bias = (
  556. torch.nn.Parameter(torch.full([out_features], np.float32(bias_init)))
  557. if bias
  558. else None
  559. )
  560. self.activation = activation
  561. self.weight_gain = lr_multiplier / np.sqrt(in_features)
  562. self.bias_gain = lr_multiplier
  563. def forward(self, x):
  564. w = self.weight * self.weight_gain
  565. b = self.bias
  566. if b is not None and self.bias_gain != 1:
  567. b = b * self.bias_gain
  568. if self.activation == "linear" and b is not None:
  569. # out = torch.addmm(b.unsqueeze(0), x, w.t())
  570. x = x.matmul(w.t())
  571. out = x + b.reshape([-1 if i == x.ndim - 1 else 1 for i in range(x.ndim)])
  572. else:
  573. x = x.matmul(w.t())
  574. out = bias_act(x, b, act=self.activation, dim=x.ndim - 1)
  575. return out
  576. def _conv2d_wrapper(
  577. x, w, stride=1, padding=0, groups=1, transpose=False, flip_weight=True
  578. ):
  579. """Wrapper for the underlying `conv2d()` and `conv_transpose2d()` implementations."""
  580. out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w)
  581. # Flip weight if requested.
  582. if (
  583. not flip_weight
  584. ): # conv2d() actually performs correlation (flip_weight=True) not convolution (flip_weight=False).
  585. w = w.flip([2, 3])
  586. # Workaround performance pitfall in cuDNN 8.0.5, triggered when using
  587. # 1x1 kernel + memory_format=channels_last + less than 64 channels.
  588. if (
  589. kw == 1
  590. and kh == 1
  591. and stride == 1
  592. and padding in [0, [0, 0], (0, 0)]
  593. and not transpose
  594. ):
  595. if x.stride()[1] == 1 and min(out_channels, in_channels_per_group) < 64:
  596. if out_channels <= 4 and groups == 1:
  597. in_shape = x.shape
  598. x = w.squeeze(3).squeeze(2) @ x.reshape(
  599. [in_shape[0], in_channels_per_group, -1]
  600. )
  601. x = x.reshape([in_shape[0], out_channels, in_shape[2], in_shape[3]])
  602. else:
  603. x = x.to(memory_format=torch.contiguous_format)
  604. w = w.to(memory_format=torch.contiguous_format)
  605. x = conv2d(x, w, groups=groups)
  606. return x.to(memory_format=torch.channels_last)
  607. # Otherwise => execute using conv2d_gradfix.
  608. op = conv_transpose2d if transpose else conv2d
  609. return op(x, w, stride=stride, padding=padding, groups=groups)
  610. def conv2d_resample(
  611. x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight=True, flip_filter=False
  612. ):
  613. r"""2D convolution with optional up/downsampling.
  614. Padding is performed only once at the beginning, not between the operations.
  615. Args:
  616. x: Input tensor of shape
  617. `[batch_size, in_channels, in_height, in_width]`.
  618. w: Weight tensor of shape
  619. `[out_channels, in_channels//groups, kernel_height, kernel_width]`.
  620. f: Low-pass filter for up/downsampling. Must be prepared beforehand by
  621. calling setup_filter(). None = identity (default).
  622. up: Integer upsampling factor (default: 1).
  623. down: Integer downsampling factor (default: 1).
  624. padding: Padding with respect to the upsampled image. Can be a single number
  625. or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
  626. (default: 0).
  627. groups: Split input channels into N groups (default: 1).
  628. flip_weight: False = convolution, True = correlation (default: True).
  629. flip_filter: False = convolution, True = correlation (default: False).
  630. Returns:
  631. Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
  632. """
  633. # Validate arguments.
  634. assert isinstance(x, torch.Tensor) and (x.ndim == 4)
  635. assert isinstance(w, torch.Tensor) and (w.ndim == 4) and (w.dtype == x.dtype)
  636. assert f is None or (isinstance(f, torch.Tensor) and f.ndim in [1, 2])
  637. assert isinstance(up, int) and (up >= 1)
  638. assert isinstance(down, int) and (down >= 1)
  639. # assert isinstance(groups, int) and (groups >= 1), f"!!!!!! groups: {groups} isinstance(groups, int) {isinstance(groups, int)} {type(groups)}"
  640. out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w)
  641. fw, fh = _get_filter_size(f)
  642. # px0, px1, py0, py1 = _parse_padding(padding)
  643. px0, px1, py0, py1 = padding, padding, padding, padding
  644. # Adjust padding to account for up/downsampling.
  645. if up > 1:
  646. px0 += (fw + up - 1) // 2
  647. px1 += (fw - up) // 2
  648. py0 += (fh + up - 1) // 2
  649. py1 += (fh - up) // 2
  650. if down > 1:
  651. px0 += (fw - down + 1) // 2
  652. px1 += (fw - down) // 2
  653. py0 += (fh - down + 1) // 2
  654. py1 += (fh - down) // 2
  655. # Fast path: 1x1 convolution with downsampling only => downsample first, then convolve.
  656. if kw == 1 and kh == 1 and (down > 1 and up == 1):
  657. x = upfirdn2d(
  658. x=x, f=f, down=down, padding=[px0, px1, py0, py1], flip_filter=flip_filter
  659. )
  660. x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
  661. return x
  662. # Fast path: 1x1 convolution with upsampling only => convolve first, then upsample.
  663. if kw == 1 and kh == 1 and (up > 1 and down == 1):
  664. x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
  665. x = upfirdn2d(
  666. x=x,
  667. f=f,
  668. up=up,
  669. padding=[px0, px1, py0, py1],
  670. gain=up**2,
  671. flip_filter=flip_filter,
  672. )
  673. return x
  674. # Fast path: downsampling only => use strided convolution.
  675. if down > 1 and up == 1:
  676. x = upfirdn2d(x=x, f=f, padding=[px0, px1, py0, py1], flip_filter=flip_filter)
  677. x = _conv2d_wrapper(
  678. x=x, w=w, stride=down, groups=groups, flip_weight=flip_weight
  679. )
  680. return x
  681. # Fast path: upsampling with optional downsampling => use transpose strided convolution.
  682. if up > 1:
  683. if groups == 1:
  684. w = w.transpose(0, 1)
  685. else:
  686. w = w.reshape(groups, out_channels // groups, in_channels_per_group, kh, kw)
  687. w = w.transpose(1, 2)
  688. w = w.reshape(
  689. groups * in_channels_per_group, out_channels // groups, kh, kw
  690. )
  691. px0 -= kw - 1
  692. px1 -= kw - up
  693. py0 -= kh - 1
  694. py1 -= kh - up
  695. pxt = max(min(-px0, -px1), 0)
  696. pyt = max(min(-py0, -py1), 0)
  697. x = _conv2d_wrapper(
  698. x=x,
  699. w=w,
  700. stride=up,
  701. padding=[pyt, pxt],
  702. groups=groups,
  703. transpose=True,
  704. flip_weight=(not flip_weight),
  705. )
  706. x = upfirdn2d(
  707. x=x,
  708. f=f,
  709. padding=[px0 + pxt, px1 + pxt, py0 + pyt, py1 + pyt],
  710. gain=up**2,
  711. flip_filter=flip_filter,
  712. )
  713. if down > 1:
  714. x = upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter)
  715. return x
  716. # Fast path: no up/downsampling, padding supported by the underlying implementation => use plain conv2d.
  717. if up == 1 and down == 1:
  718. if px0 == px1 and py0 == py1 and px0 >= 0 and py0 >= 0:
  719. return _conv2d_wrapper(
  720. x=x, w=w, padding=[py0, px0], groups=groups, flip_weight=flip_weight
  721. )
  722. # Fallback: Generic reference implementation.
  723. x = upfirdn2d(
  724. x=x,
  725. f=(f if up > 1 else None),
  726. up=up,
  727. padding=[px0, px1, py0, py1],
  728. gain=up**2,
  729. flip_filter=flip_filter,
  730. )
  731. x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
  732. if down > 1:
  733. x = upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter)
  734. return x
  735. class Conv2dLayer(torch.nn.Module):
  736. def __init__(
  737. self,
  738. in_channels, # Number of input channels.
  739. out_channels, # Number of output channels.
  740. kernel_size, # Width and height of the convolution kernel.
  741. bias=True, # Apply additive bias before the activation function?
  742. activation="linear", # Activation function: 'relu', 'lrelu', etc.
  743. up=1, # Integer upsampling factor.
  744. down=1, # Integer downsampling factor.
  745. resample_filter=[
  746. 1,
  747. 3,
  748. 3,
  749. 1,
  750. ], # Low-pass filter to apply when resampling activations.
  751. conv_clamp=None, # Clamp the output to +-X, None = disable clamping.
  752. channels_last=False, # Expect the input to have memory_format=channels_last?
  753. trainable=True, # Update the weights of this layer during training?
  754. ):
  755. super().__init__()
  756. self.activation = activation
  757. self.up = up
  758. self.down = down
  759. self.register_buffer("resample_filter", setup_filter(resample_filter))
  760. self.conv_clamp = conv_clamp
  761. self.padding = kernel_size // 2
  762. self.weight_gain = 1 / np.sqrt(in_channels * (kernel_size**2))
  763. self.act_gain = activation_funcs[activation].def_gain
  764. memory_format = (
  765. torch.channels_last if channels_last else torch.contiguous_format
  766. )
  767. weight = torch.randn([out_channels, in_channels, kernel_size, kernel_size]).to(
  768. memory_format=memory_format
  769. )
  770. bias = torch.zeros([out_channels]) if bias else None
  771. if trainable:
  772. self.weight = torch.nn.Parameter(weight)
  773. self.bias = torch.nn.Parameter(bias) if bias is not None else None
  774. else:
  775. self.register_buffer("weight", weight)
  776. if bias is not None:
  777. self.register_buffer("bias", bias)
  778. else:
  779. self.bias = None
  780. def forward(self, x, gain=1):
  781. w = self.weight * self.weight_gain
  782. x = conv2d_resample(
  783. x=x,
  784. w=w,
  785. f=self.resample_filter,
  786. up=self.up,
  787. down=self.down,
  788. padding=self.padding,
  789. )
  790. act_gain = self.act_gain * gain
  791. act_clamp = self.conv_clamp * gain if self.conv_clamp is not None else None
  792. out = bias_act(
  793. x, self.bias, act=self.activation, gain=act_gain, clamp=act_clamp
  794. )
  795. return out
  796. def torch_gc():
  797. if torch.cuda.is_available():
  798. torch.cuda.empty_cache()
  799. torch.cuda.ipc_collect()
  800. gc.collect()
  801. def set_seed(seed: int):
  802. random.seed(seed)
  803. np.random.seed(seed)
  804. torch.manual_seed(seed)
  805. torch.cuda.manual_seed_all(seed)
  806. def get_scheduler(sd_sampler, scheduler_config):
  807. # https://github.com/huggingface/diffusers/issues/4167
  808. keys_to_pop = ["use_karras_sigmas", "algorithm_type"]
  809. scheduler_config = dict(scheduler_config)
  810. for it in keys_to_pop:
  811. scheduler_config.pop(it, None)
  812. # fmt: off
  813. samplers = {
  814. SDSampler.dpm_plus_plus_2m: [DPMSolverMultistepScheduler],
  815. SDSampler.dpm_plus_plus_2m_karras: [DPMSolverMultistepScheduler, dict(use_karras_sigmas=True)],
  816. SDSampler.dpm_plus_plus_2m_sde: [DPMSolverMultistepScheduler, dict(algorithm_type="sde-dpmsolver++")],
  817. SDSampler.dpm_plus_plus_2m_sde_karras: [DPMSolverMultistepScheduler, dict(algorithm_type="sde-dpmsolver++", use_karras_sigmas=True)],
  818. SDSampler.dpm_plus_plus_sde: [DPMSolverSinglestepScheduler],
  819. SDSampler.dpm_plus_plus_sde_karras: [DPMSolverSinglestepScheduler, dict(use_karras_sigmas=True)],
  820. SDSampler.dpm2: [KDPM2DiscreteScheduler],
  821. SDSampler.dpm2_karras: [KDPM2DiscreteScheduler, dict(use_karras_sigmas=True)],
  822. SDSampler.dpm2_a: [KDPM2AncestralDiscreteScheduler],
  823. SDSampler.dpm2_a_karras: [KDPM2AncestralDiscreteScheduler, dict(use_karras_sigmas=True)],
  824. SDSampler.euler: [EulerDiscreteScheduler],
  825. SDSampler.euler_a: [EulerAncestralDiscreteScheduler],
  826. SDSampler.heun: [HeunDiscreteScheduler],
  827. SDSampler.lms: [LMSDiscreteScheduler],
  828. SDSampler.lms_karras: [LMSDiscreteScheduler, dict(use_karras_sigmas=True)],
  829. SDSampler.ddim: [DDIMScheduler],
  830. SDSampler.pndm: [PNDMScheduler],
  831. SDSampler.uni_pc: [UniPCMultistepScheduler],
  832. SDSampler.lcm: [LCMScheduler],
  833. }
  834. # fmt: on
  835. if sd_sampler in samplers:
  836. if len(samplers[sd_sampler]) == 2:
  837. scheduler_cls, kwargs = samplers[sd_sampler]
  838. else:
  839. scheduler_cls, kwargs = samplers[sd_sampler][0], {}
  840. return scheduler_cls.from_config(scheduler_config, **kwargs)
  841. else:
  842. raise ValueError(sd_sampler)
  843. def is_local_files_only(**kwargs) -> bool:
  844. from huggingface_hub.constants import HF_HUB_OFFLINE
  845. return HF_HUB_OFFLINE or kwargs.get("local_files_only", False)
  846. def handle_from_pretrained_exceptions(func, **kwargs):
  847. try:
  848. return func(**kwargs)
  849. except ValueError as e:
  850. if "You are trying to load the model files of the `variant=fp16`" in str(e):
  851. logger.info("variant=fp16 not found, try revision=fp16")
  852. try:
  853. return func(**{**kwargs, "variant": None, "revision": "fp16"})
  854. except Exception as e:
  855. logger.info("revision=fp16 not found, try revision=main")
  856. return func(**{**kwargs, "variant": None, "revision": "main"})
  857. raise e
  858. except OSError as e:
  859. previous_traceback = traceback.format_exc()
  860. if "RevisionNotFoundError: 404 Client Error." in previous_traceback:
  861. logger.info("revision=fp16 not found, try revision=main")
  862. return func(**{**kwargs, "variant": None, "revision": "main"})
  863. elif "Max retries exceeded" in previous_traceback:
  864. logger.exception(
  865. "Fetching model from HuggingFace failed. "
  866. "If this is your first time downloading the model, you may need to set up proxy in terminal."
  867. "If the model has already been downloaded, you can add --local-files-only when starting."
  868. )
  869. exit(-1)
  870. raise e
  871. except Exception as e:
  872. raise e
  873. def get_torch_dtype(device, no_half: bool):
  874. device = str(device)
  875. use_fp16 = not no_half
  876. use_gpu = device == "cuda"
  877. # https://github.com/huggingface/diffusers/issues/4480
  878. # pipe.enable_attention_slicing and float16 will cause black output on mps
  879. # if device in ["cuda", "mps"] and use_fp16:
  880. if device in ["cuda"] and use_fp16:
  881. return use_gpu, torch.float16
  882. return use_gpu, torch.float32
  883. def enable_low_mem(pipe, enable: bool):
  884. if torch.backends.mps.is_available():
  885. # https://huggingface.co/docs/diffusers/v0.25.0/en/api/pipelines/stable_diffusion/image_variation#diffusers.StableDiffusionImageVariationPipeline.enable_attention_slicing
  886. # CUDA: Don't enable attention slicing if you're already using `scaled_dot_product_attention` (SDPA) from PyTorch 2.0 or xFormers.
  887. if enable:
  888. pipe.enable_attention_slicing("max")
  889. else:
  890. # https://huggingface.co/docs/diffusers/optimization/mps
  891. # Devices with less than 64GB of memory are recommended to use enable_attention_slicing
  892. pipe.enable_attention_slicing()
  893. if enable:
  894. pipe.vae.enable_tiling()