- import collections
- import gc
- import math
- import random
- import traceback
- from itertools import repeat
- from typing import Any
- import numpy as np
- import torch
- from diffusers import (
- DDIMScheduler,
- DPMSolverMultistepScheduler,
- DPMSolverSinglestepScheduler,
- EulerAncestralDiscreteScheduler,
- EulerDiscreteScheduler,
- HeunDiscreteScheduler,
- KDPM2AncestralDiscreteScheduler,
- KDPM2DiscreteScheduler,
- LCMScheduler,
- LMSDiscreteScheduler,
- PNDMScheduler,
- UniPCMultistepScheduler,
- )
- from loguru import logger
- from torch import conv2d, conv_transpose2d
- from sorawm.iopaint.schema import SDSampler
- def make_beta_schedule(
- device, schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3
- ):
- if schedule == "linear":
- betas = (
- torch.linspace(
- linear_start**0.5, linear_end**0.5, n_timestep, dtype=torch.float64
- )
- ** 2
- )
- elif schedule == "cosine":
- timesteps = (
- torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
- ).to(device)
- alphas = timesteps / (1 + cosine_s) * np.pi / 2
- alphas = torch.cos(alphas).pow(2).to(device)
- alphas = alphas / alphas[0]
- betas = 1 - alphas[1:] / alphas[:-1]
- betas = np.clip(betas, a_min=0, a_max=0.999)
- elif schedule == "sqrt_linear":
- betas = torch.linspace(
- linear_start, linear_end, n_timestep, dtype=torch.float64
- )
- elif schedule == "sqrt":
- betas = (
- torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
- ** 0.5
- )
- else:
- raise ValueError(f"schedule '{schedule}' unknown.")
- return betas.numpy()
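- # Illustrative sketch (not part of the original module): building a linear beta
- # schedule and the cumulative-product alphas that samplers index into. The step
- # count of 1000 is just an example value.
- def _example_beta_schedule():
-     betas = make_beta_schedule("cpu", "linear", n_timestep=1000)  # np.ndarray, shape (1000,)
-     alphas_cumprod = np.cumprod(1.0 - betas)  # decreases monotonically from ~1 towards 0
-     return betas, alphas_cumprod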
- def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
- # select alphas for computing the variance schedule
- alphas = alphacums[ddim_timesteps]
- alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())
- # according to the formula provided in https://arxiv.org/abs/2010.02502
- sigmas = eta * np.sqrt(
- (1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)
- )
- if verbose:
- print(
- f"Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}"
- )
- print(
- f"For the chosen value of eta, which is {eta}, "
- f"this results in the following sigma_t schedule for ddim sampler {sigmas}"
- )
- return sigmas, alphas, alphas_prev
- def make_ddim_timesteps(
- ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True
- ):
- if ddim_discr_method == "uniform":
- c = num_ddpm_timesteps // num_ddim_timesteps
- ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
- elif ddim_discr_method == "quad":
- ddim_timesteps = (
- (np.linspace(0, np.sqrt(num_ddpm_timesteps * 0.8), num_ddim_timesteps)) ** 2
- ).astype(int)
- else:
- raise NotImplementedError(
- f'There is no ddim discretization method called "{ddim_discr_method}"'
- )
- # assert ddim_timesteps.shape[0] == num_ddim_timesteps
- # add one to get the final alpha values right (the ones from first scale to data during sampling)
- steps_out = ddim_timesteps + 1
- if verbose:
- print(f"Selected timesteps for ddim sampler: {steps_out}")
- return steps_out
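- # Illustrative sketch (not part of the original module): deriving a 50-step DDIM
- # schedule from a 1000-step DDPM schedule and computing the matching sigmas.
- # With eta=0 the sampler is deterministic, so every sigma comes out as zero.
- def _example_ddim_parameters():
-     betas = make_beta_schedule("cpu", "linear", n_timestep=1000)
-     alphas_cumprod = np.cumprod(1.0 - betas)
-     ddim_steps = make_ddim_timesteps(
-         "uniform", num_ddim_timesteps=50, num_ddpm_timesteps=1000, verbose=False
-     )
-     sigmas, alphas, alphas_prev = make_ddim_sampling_parameters(
-         alphas_cumprod, ddim_steps, eta=0.0, verbose=False
-     )
-     return sigmas, alphas, alphas_prev  # each of shape (50,)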
- def noise_like(shape, device, repeat=False):
- repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(
- shape[0], *((1,) * (len(shape) - 1))
- )
- noise = lambda: torch.randn(shape, device=device)
- return repeat_noise() if repeat else noise()
- def timestep_embedding(device, timesteps, dim, max_period=10000, repeat_only=False):
- """
- Create sinusoidal timestep embeddings.
- :param timesteps: a 1-D Tensor of N indices, one per batch element.
- These may be fractional.
- :param dim: the dimension of the output.
- :param max_period: controls the minimum frequency of the embeddings.
- :return: an [N x dim] Tensor of positional embeddings.
- """
- half = dim // 2
- freqs = torch.exp(
- -math.log(max_period)
- * torch.arange(start=0, end=half, dtype=torch.float32)
- / half
- ).to(device=device)
- args = timesteps[:, None].float() * freqs[None]
- embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
- if dim % 2:
- embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
- return embedding
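- # Illustrative sketch (not part of the original module): the embedding has one row
- # per timestep and `dim` columns; 320 matches a common UNet time-embedding width,
- # but any even value behaves the same way.
- def _example_timestep_embedding():
-     t = torch.arange(8)
-     emb = timestep_embedding("cpu", t, dim=320)
-     return emb.shape  # torch.Size([8, 320])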
- ###### MAT and FcF #######
- def normalize_2nd_moment(x, dim=1):
- return (
- x * (x.square().mean(dim=dim, keepdim=True) + torch.finfo(x.dtype).eps).rsqrt()
- )
- class EasyDict(dict):
- """Convenience class that behaves like a dict but allows access with the attribute syntax."""
- def __getattr__(self, name: str) -> Any:
- try:
- return self[name]
- except KeyError:
- raise AttributeError(name)
- def __setattr__(self, name: str, value: Any) -> None:
- self[name] = value
- def __delattr__(self, name: str) -> None:
- del self[name]
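- # Illustrative sketch (not part of the original module): EasyDict keys can be read,
- # written, and deleted through attribute access while remaining a plain dict.
- def _example_easydict():
-     cfg = EasyDict(lr=1e-4)
-     cfg.betas = (0.9, 0.999)  # equivalent to cfg["betas"] = (0.9, 0.999)
-     assert cfg["lr"] == cfg.lr
-     del cfg.betas
-     return cfg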
- def _bias_act_ref(x, b=None, dim=1, act="linear", alpha=None, gain=None, clamp=None):
- """Slow reference implementation of `bias_act()` using standard TensorFlow ops."""
- assert isinstance(x, torch.Tensor)
- assert clamp is None or clamp >= 0
- spec = activation_funcs[act]
- alpha = float(alpha if alpha is not None else spec.def_alpha)
- gain = float(gain if gain is not None else spec.def_gain)
- clamp = float(clamp if clamp is not None else -1)
- # Add bias.
- if b is not None:
- assert isinstance(b, torch.Tensor) and b.ndim == 1
- assert 0 <= dim < x.ndim
- assert b.shape[0] == x.shape[dim]
- x = x + b.reshape([-1 if i == dim else 1 for i in range(x.ndim)])
- # Evaluate activation function.
- alpha = float(alpha)
- x = spec.func(x, alpha=alpha)
- # Scale by gain.
- gain = float(gain)
- if gain != 1:
- x = x * gain
- # Clamp.
- if clamp >= 0:
- x = x.clamp(-clamp, clamp) # pylint: disable=invalid-unary-operand-type
- return x
- def bias_act(
- x, b=None, dim=1, act="linear", alpha=None, gain=None, clamp=None, impl="ref"
- ):
- r"""Fused bias and activation function.
- Adds bias `b` to activation tensor `x`, evaluates activation function `act`,
- and scales the result by `gain`. Each of the steps is optional. In most cases,
- the fused op is considerably more efficient than performing the same calculation
- using standard PyTorch ops. It supports first and second order gradients,
- but not third order gradients.
- Args:
- x: Input activation tensor. Can be of any shape.
- b: Bias vector, or `None` to disable. Must be a 1D tensor of the same type
- as `x`. The shape must be known, and it must match the dimension of `x`
- corresponding to `dim`.
- dim: The dimension in `x` corresponding to the elements of `b`.
- The value of `dim` is ignored if `b` is not specified.
- act: Name of the activation function to evaluate, or `"linear"` to disable.
- Can be e.g. `"relu"`, `"lrelu"`, `"tanh"`, `"sigmoid"`, `"swish"`, etc.
- See `activation_funcs` for a full list. `None` is not allowed.
- alpha: Shape parameter for the activation function, or `None` to use the default.
- gain: Scaling factor for the output tensor, or `None` to use default.
- See `activation_funcs` for the default scaling of each activation function.
- If unsure, consider specifying 1.
- clamp: Clamp the output values to `[-clamp, +clamp]`, or `None` to disable
- the clamping (default).
- impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default: `"ref"`;
- this trimmed build always falls back to the reference implementation).
- Returns:
- Tensor of the same shape and datatype as `x`.
- """
- assert isinstance(x, torch.Tensor)
- assert impl in ["ref", "cuda"]
- return _bias_act_ref(
- x=x, b=b, dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp
- )
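- # Illustrative sketch (not part of the original module): bias plus leaky ReLU with
- # the activation's default sqrt(2) gain, clamped to +-256. The tensor shapes are
- # arbitrary example values.
- def _example_bias_act():
-     x = torch.randn(2, 8, 16, 16)
-     b = torch.zeros(8)
-     y = bias_act(x, b, dim=1, act="lrelu", clamp=256)
-     return y.shape  # same shape as x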
- def _get_filter_size(f):
- if f is None:
- return 1, 1
- assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
- fw = f.shape[-1]
- fh = f.shape[0]
- fw = int(fw)
- fh = int(fh)
- assert fw >= 1 and fh >= 1
- return fw, fh
- def _get_weight_shape(w):
- shape = [int(sz) for sz in w.shape]
- return shape
- def _parse_scaling(scaling):
- if isinstance(scaling, int):
- scaling = [scaling, scaling]
- assert isinstance(scaling, (list, tuple))
- assert all(isinstance(x, int) for x in scaling)
- sx, sy = scaling
- assert sx >= 1 and sy >= 1
- return sx, sy
- def _parse_padding(padding):
- if isinstance(padding, int):
- padding = [padding, padding]
- assert isinstance(padding, (list, tuple))
- assert all(isinstance(x, int) for x in padding)
- if len(padding) == 2:
- padx, pady = padding
- padding = [padx, padx, pady, pady]
- padx0, padx1, pady0, pady1 = padding
- return padx0, padx1, pady0, pady1
- def setup_filter(
- f,
- device=torch.device("cpu"),
- normalize=True,
- flip_filter=False,
- gain=1,
- separable=None,
- ):
- r"""Convenience function to setup 2D FIR filter for `upfirdn2d()`.
- Args:
- f: Torch tensor, numpy array, or python list of the shape
- `[filter_height, filter_width]` (non-separable),
- `[filter_taps]` (separable),
- `[]` (impulse), or
- `None` (identity).
- device: Result device (default: cpu).
- normalize: Normalize the filter so that it retains the magnitude
- for constant input signal (DC)? (default: True).
- flip_filter: Flip the filter? (default: False).
- gain: Overall scaling factor for signal magnitude (default: 1).
- separable: Return a separable filter? (default: select automatically).
- Returns:
- Float32 tensor of the shape
- `[filter_height, filter_width]` (non-separable) or
- `[filter_taps]` (separable).
- """
- # Validate.
- if f is None:
- f = 1
- f = torch.as_tensor(f, dtype=torch.float32)
- assert f.ndim in [0, 1, 2]
- assert f.numel() > 0
- if f.ndim == 0:
- f = f[np.newaxis]
- # Separable?
- if separable is None:
- separable = f.ndim == 1 and f.numel() >= 8
- if f.ndim == 1 and not separable:
- f = f.ger(f)
- assert f.ndim == (1 if separable else 2)
- # Apply normalize, flip, gain, and device.
- if normalize:
- f /= f.sum()
- if flip_filter:
- f = f.flip(list(range(f.ndim)))
- f = f * (gain ** (f.ndim / 2))
- f = f.to(device=device)
- return f
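- # Illustrative sketch (not part of the original module): the common [1, 3, 3, 1]
- # taps are expanded into a normalized 4x4 outer-product filter, because fewer than
- # 8 taps are treated as non-separable.
- def _example_setup_filter():
-     f = setup_filter([1, 3, 3, 1])
-     return f.shape, float(f.sum())  # (torch.Size([4, 4]), 1.0)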
- def _ntuple(n):
- def parse(x):
- if isinstance(x, collections.abc.Iterable):
- return x
- return tuple(repeat(x, n))
- return parse
- to_2tuple = _ntuple(2)
- activation_funcs = {
- "linear": EasyDict(
- func=lambda x, **_: x,
- def_alpha=0,
- def_gain=1,
- cuda_idx=1,
- ref="",
- has_2nd_grad=False,
- ),
- "relu": EasyDict(
- func=lambda x, **_: torch.nn.functional.relu(x),
- def_alpha=0,
- def_gain=np.sqrt(2),
- cuda_idx=2,
- ref="y",
- has_2nd_grad=False,
- ),
- "lrelu": EasyDict(
- func=lambda x, alpha, **_: torch.nn.functional.leaky_relu(x, alpha),
- def_alpha=0.2,
- def_gain=np.sqrt(2),
- cuda_idx=3,
- ref="y",
- has_2nd_grad=False,
- ),
- "tanh": EasyDict(
- func=lambda x, **_: torch.tanh(x),
- def_alpha=0,
- def_gain=1,
- cuda_idx=4,
- ref="y",
- has_2nd_grad=True,
- ),
- "sigmoid": EasyDict(
- func=lambda x, **_: torch.sigmoid(x),
- def_alpha=0,
- def_gain=1,
- cuda_idx=5,
- ref="y",
- has_2nd_grad=True,
- ),
- "elu": EasyDict(
- func=lambda x, **_: torch.nn.functional.elu(x),
- def_alpha=0,
- def_gain=1,
- cuda_idx=6,
- ref="y",
- has_2nd_grad=True,
- ),
- "selu": EasyDict(
- func=lambda x, **_: torch.nn.functional.selu(x),
- def_alpha=0,
- def_gain=1,
- cuda_idx=7,
- ref="y",
- has_2nd_grad=True,
- ),
- "softplus": EasyDict(
- func=lambda x, **_: torch.nn.functional.softplus(x),
- def_alpha=0,
- def_gain=1,
- cuda_idx=8,
- ref="y",
- has_2nd_grad=True,
- ),
- "swish": EasyDict(
- func=lambda x, **_: torch.sigmoid(x) * x,
- def_alpha=0,
- def_gain=np.sqrt(2),
- cuda_idx=9,
- ref="x",
- has_2nd_grad=True,
- ),
- }
- def upfirdn2d(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1, impl="cuda"):
- r"""Pad, upsample, filter, and downsample a batch of 2D images.
- Performs the following sequence of operations for each channel:
- 1. Upsample the image by inserting N-1 zeros after each pixel (`up`).
- 2. Pad the image with the specified number of zeros on each side (`padding`).
- Negative padding corresponds to cropping the image.
- 3. Convolve the image with the specified 2D FIR filter (`f`), shrinking it
- so that the footprint of all output pixels lies within the input image.
- 4. Downsample the image by keeping every Nth pixel (`down`).
- This sequence of operations bears close resemblance to scipy.signal.upfirdn().
- The fused op is considerably more efficient than performing the same calculation
- using standard PyTorch ops. It supports gradients of arbitrary order.
- Args:
- x: Float32/float64/float16 input tensor of the shape
- `[batch_size, num_channels, in_height, in_width]`.
- f: Float32 FIR filter of the shape
- `[filter_height, filter_width]` (non-separable),
- `[filter_taps]` (separable), or
- `None` (identity).
- up: Integer upsampling factor. Can be a single int or a list/tuple
- `[x, y]` (default: 1).
- down: Integer downsampling factor. Can be a single int or a list/tuple
- `[x, y]` (default: 1).
- padding: Padding with respect to the upsampled image. Can be a single number
- or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
- (default: 0).
- flip_filter: False = convolution, True = correlation (default: False).
- gain: Overall scaling factor for signal magnitude (default: 1).
- impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
- Returns:
- Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
- """
- # assert isinstance(x, torch.Tensor)
- # assert impl in ['ref', 'cuda']
- return _upfirdn2d_ref(
- x, f, up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain
- )
- def _upfirdn2d_ref(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1):
- """Slow reference implementation of `upfirdn2d()` using standard PyTorch ops."""
- # Validate arguments.
- assert isinstance(x, torch.Tensor) and x.ndim == 4
- if f is None:
- f = torch.ones([1, 1], dtype=torch.float32, device=x.device)
- assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
- assert not f.requires_grad
- batch_size, num_channels, in_height, in_width = x.shape
- # upx, upy = _parse_scaling(up)
- # downx, downy = _parse_scaling(down)
- upx, upy = up, up
- downx, downy = down, down
- # padx0, padx1, pady0, pady1 = _parse_padding(padding)
- padx0, padx1, pady0, pady1 = padding[0], padding[1], padding[2], padding[3]
- # Upsample by inserting zeros.
- x = x.reshape([batch_size, num_channels, in_height, 1, in_width, 1])
- x = torch.nn.functional.pad(x, [0, upx - 1, 0, 0, 0, upy - 1])
- x = x.reshape([batch_size, num_channels, in_height * upy, in_width * upx])
- # Pad or crop.
- x = torch.nn.functional.pad(
- x, [max(padx0, 0), max(padx1, 0), max(pady0, 0), max(pady1, 0)]
- )
- x = x[
- :,
- :,
- max(-pady0, 0) : x.shape[2] - max(-pady1, 0),
- max(-padx0, 0) : x.shape[3] - max(-padx1, 0),
- ]
- # Setup filter.
- f = f * (gain ** (f.ndim / 2))
- f = f.to(x.dtype)
- if not flip_filter:
- f = f.flip(list(range(f.ndim)))
- # Convolve with the filter.
- f = f[np.newaxis, np.newaxis].repeat([num_channels, 1] + [1] * f.ndim)
- if f.ndim == 4:
- x = conv2d(input=x, weight=f, groups=num_channels)
- else:
- x = conv2d(input=x, weight=f.unsqueeze(2), groups=num_channels)
- x = conv2d(input=x, weight=f.unsqueeze(3), groups=num_channels)
- # Downsample by throwing away pixels.
- x = x[:, :, ::downy, ::downx]
- return x
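- # Illustrative sketch (not part of the original module): in this trimmed build the
- # _parse_scaling/_parse_padding calls are bypassed, so `up`/`down` must be plain
- # ints and `padding` a 4-element list [x_before, x_after, y_before, y_after]. The
- # padding chosen here matches what upsample2d() computes for a 4-tap filter with
- # up=2, giving an exact 2x spatial upsampling.
- def _example_upfirdn2d():
-     x = torch.randn(1, 3, 16, 16)
-     f = setup_filter([1, 3, 3, 1])
-     y = upfirdn2d(x, f, up=2, down=1, padding=[2, 1, 2, 1])
-     return y.shape  # torch.Size([1, 3, 32, 32])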
- def downsample2d(x, f, down=2, padding=0, flip_filter=False, gain=1, impl="cuda"):
- r"""Downsample a batch of 2D images using the given 2D FIR filter.
- By default, the result is padded so that its shape is a fraction of the input.
- User-specified padding is applied on top of that, with negative values
- indicating cropping. Pixels outside the image are assumed to be zero.
- Args:
- x: Float32/float64/float16 input tensor of the shape
- `[batch_size, num_channels, in_height, in_width]`.
- f: Float32 FIR filter of the shape
- `[filter_height, filter_width]` (non-separable),
- `[filter_taps]` (separable), or
- `None` (identity).
- down: Integer downsampling factor. Can be a single int or a list/tuple
- `[x, y]` (default: 1).
- padding: Padding with respect to the input. Can be a single number or a
- list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
- (default: 0).
- flip_filter: False = convolution, True = correlation (default: False).
- gain: Overall scaling factor for signal magnitude (default: 1).
- impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
- Returns:
- Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
- """
- downx, downy = _parse_scaling(down)
- # padx0, padx1, pady0, pady1 = _parse_padding(padding)
- padx0, padx1, pady0, pady1 = padding, padding, padding, padding
- fw, fh = _get_filter_size(f)
- p = [
- padx0 + (fw - downx + 1) // 2,
- padx1 + (fw - downx) // 2,
- pady0 + (fh - downy + 1) // 2,
- pady1 + (fh - downy) // 2,
- ]
- return upfirdn2d(
- x, f, down=down, padding=p, flip_filter=flip_filter, gain=gain, impl=impl
- )
- def upsample2d(x, f, up=2, padding=0, flip_filter=False, gain=1, impl="cuda"):
- r"""Upsample a batch of 2D images using the given 2D FIR filter.
- By default, the result is padded so that its shape is a multiple of the input.
- User-specified padding is applied on top of that, with negative values
- indicating cropping. Pixels outside the image are assumed to be zero.
- Args:
- x: Float32/float64/float16 input tensor of the shape
- `[batch_size, num_channels, in_height, in_width]`.
- f: Float32 FIR filter of the shape
- `[filter_height, filter_width]` (non-separable),
- `[filter_taps]` (separable), or
- `None` (identity).
- up: Integer upsampling factor. Can be a single int or a list/tuple
- `[x, y]` (default: 1).
- padding: Padding with respect to the output. Can be a single number or a
- list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
- (default: 0).
- flip_filter: False = convolution, True = correlation (default: False).
- gain: Overall scaling factor for signal magnitude (default: 1).
- impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
- Returns:
- Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
- """
- upx, upy = _parse_scaling(up)
- # upx, upy = up, up
- padx0, padx1, pady0, pady1 = _parse_padding(padding)
- # padx0, padx1, pady0, pady1 = padding, padding, padding, padding
- fw, fh = _get_filter_size(f)
- p = [
- padx0 + (fw + upx - 1) // 2,
- padx1 + (fw - upx) // 2,
- pady0 + (fh + upy - 1) // 2,
- pady1 + (fh - upy) // 2,
- ]
- return upfirdn2d(
- x,
- f,
- up=up,
- padding=p,
- flip_filter=flip_filter,
- gain=gain * upx * upy,
- impl=impl,
- )
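- # Illustrative sketch (not part of the original module): a 2x upsample followed by
- # a 2x downsample with the same filter returns to the original spatial size (the
- # values are low-pass filtered, not bit-identical).
- def _example_resample2d():
-     x = torch.randn(1, 3, 16, 16)
-     f = setup_filter([1, 3, 3, 1])
-     up = upsample2d(x, f, up=2)         # torch.Size([1, 3, 32, 32])
-     down = downsample2d(up, f, down=2)  # torch.Size([1, 3, 16, 16])
-     return up.shape, down.shape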
- class MinibatchStdLayer(torch.nn.Module):
- def __init__(self, group_size, num_channels=1):
- super().__init__()
- self.group_size = group_size
- self.num_channels = num_channels
- def forward(self, x):
- N, C, H, W = x.shape
- G = (
- torch.min(torch.as_tensor(self.group_size), torch.as_tensor(N))
- if self.group_size is not None
- else N
- )
- F = self.num_channels
- c = C // F
- y = x.reshape(
- G, -1, F, c, H, W
- ) # [GnFcHW] Split minibatch N into n groups of size G, and channels C into F groups of size c.
- y = y - y.mean(dim=0) # [GnFcHW] Subtract mean over group.
- y = y.square().mean(dim=0) # [nFcHW] Calc variance over group.
- y = (y + 1e-8).sqrt() # [nFcHW] Calc stddev over group.
- y = y.mean(dim=[2, 3, 4]) # [nF] Take average over channels and pixels.
- y = y.reshape(-1, F, 1, 1) # [nF11] Add missing dimensions.
- y = y.repeat(G, 1, H, W) # [NFHW] Replicate over group and pixels.
- x = torch.cat([x, y], dim=1) # [NCHW] Append to input as new channels.
- return x
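- # Illustrative sketch (not part of the original module): the layer appends
- # `num_channels` statistics channels, so a 16-channel input comes out with 17
- # channels. Batch and group sizes here are arbitrary example values.
- def _example_minibatch_std():
-     layer = MinibatchStdLayer(group_size=4, num_channels=1)
-     x = torch.randn(8, 16, 32, 32)
-     return layer(x).shape  # torch.Size([8, 17, 32, 32])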
- class FullyConnectedLayer(torch.nn.Module):
- def __init__(
- self,
- in_features, # Number of input features.
- out_features, # Number of output features.
- bias=True, # Apply additive bias before the activation function?
- activation="linear", # Activation function: 'relu', 'lrelu', etc.
- lr_multiplier=1, # Learning rate multiplier.
- bias_init=0, # Initial value for the additive bias.
- ):
- super().__init__()
- self.weight = torch.nn.Parameter(
- torch.randn([out_features, in_features]) / lr_multiplier
- )
- self.bias = (
- torch.nn.Parameter(torch.full([out_features], np.float32(bias_init)))
- if bias
- else None
- )
- self.activation = activation
- self.weight_gain = lr_multiplier / np.sqrt(in_features)
- self.bias_gain = lr_multiplier
- def forward(self, x):
- w = self.weight * self.weight_gain
- b = self.bias
- if b is not None and self.bias_gain != 1:
- b = b * self.bias_gain
- if self.activation == "linear" and b is not None:
- # out = torch.addmm(b.unsqueeze(0), x, w.t())
- x = x.matmul(w.t())
- out = x + b.reshape([-1 if i == x.ndim - 1 else 1 for i in range(x.ndim)])
- else:
- x = x.matmul(w.t())
- out = bias_act(x, b, act=self.activation, dim=x.ndim - 1)
- return out
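- # Illustrative sketch (not part of the original module): an equalized-learning-rate
- # fully connected layer mapping 512 -> 128 features with leaky ReLU; the sizes are
- # example values.
- def _example_fully_connected():
-     layer = FullyConnectedLayer(in_features=512, out_features=128, activation="lrelu")
-     x = torch.randn(4, 512)
-     return layer(x).shape  # torch.Size([4, 128])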
- def _conv2d_wrapper(
- x, w, stride=1, padding=0, groups=1, transpose=False, flip_weight=True
- ):
- """Wrapper for the underlying `conv2d()` and `conv_transpose2d()` implementations."""
- out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w)
- # Flip weight if requested.
- if (
- not flip_weight
- ): # conv2d() actually performs correlation (flip_weight=True) not convolution (flip_weight=False).
- w = w.flip([2, 3])
- # Workaround performance pitfall in cuDNN 8.0.5, triggered when using
- # 1x1 kernel + memory_format=channels_last + less than 64 channels.
- if (
- kw == 1
- and kh == 1
- and stride == 1
- and padding in [0, [0, 0], (0, 0)]
- and not transpose
- ):
- if x.stride()[1] == 1 and min(out_channels, in_channels_per_group) < 64:
- if out_channels <= 4 and groups == 1:
- in_shape = x.shape
- x = w.squeeze(3).squeeze(2) @ x.reshape(
- [in_shape[0], in_channels_per_group, -1]
- )
- x = x.reshape([in_shape[0], out_channels, in_shape[2], in_shape[3]])
- else:
- x = x.to(memory_format=torch.contiguous_format)
- w = w.to(memory_format=torch.contiguous_format)
- x = conv2d(x, w, groups=groups)
- return x.to(memory_format=torch.channels_last)
- # Otherwise => execute using conv2d_gradfix.
- op = conv_transpose2d if transpose else conv2d
- return op(x, w, stride=stride, padding=padding, groups=groups)
- def conv2d_resample(
- x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight=True, flip_filter=False
- ):
- r"""2D convolution with optional up/downsampling.
- Padding is performed only once at the beginning, not between the operations.
- Args:
- x: Input tensor of shape
- `[batch_size, in_channels, in_height, in_width]`.
- w: Weight tensor of shape
- `[out_channels, in_channels//groups, kernel_height, kernel_width]`.
- f: Low-pass filter for up/downsampling. Must be prepared beforehand by
- calling setup_filter(). None = identity (default).
- up: Integer upsampling factor (default: 1).
- down: Integer downsampling factor (default: 1).
- padding: Padding with respect to the upsampled image. Can be a single number
- or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
- (default: 0).
- groups: Split input channels into N groups (default: 1).
- flip_weight: False = convolution, True = correlation (default: True).
- flip_filter: False = convolution, True = correlation (default: False).
- Returns:
- Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
- """
- # Validate arguments.
- assert isinstance(x, torch.Tensor) and (x.ndim == 4)
- assert isinstance(w, torch.Tensor) and (w.ndim == 4) and (w.dtype == x.dtype)
- assert f is None or (isinstance(f, torch.Tensor) and f.ndim in [1, 2])
- assert isinstance(up, int) and (up >= 1)
- assert isinstance(down, int) and (down >= 1)
- # assert isinstance(groups, int) and (groups >= 1), f"!!!!!! groups: {groups} isinstance(groups, int) {isinstance(groups, int)} {type(groups)}"
- out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w)
- fw, fh = _get_filter_size(f)
- # px0, px1, py0, py1 = _parse_padding(padding)
- px0, px1, py0, py1 = padding, padding, padding, padding
- # Adjust padding to account for up/downsampling.
- if up > 1:
- px0 += (fw + up - 1) // 2
- px1 += (fw - up) // 2
- py0 += (fh + up - 1) // 2
- py1 += (fh - up) // 2
- if down > 1:
- px0 += (fw - down + 1) // 2
- px1 += (fw - down) // 2
- py0 += (fh - down + 1) // 2
- py1 += (fh - down) // 2
- # Fast path: 1x1 convolution with downsampling only => downsample first, then convolve.
- if kw == 1 and kh == 1 and (down > 1 and up == 1):
- x = upfirdn2d(
- x=x, f=f, down=down, padding=[px0, px1, py0, py1], flip_filter=flip_filter
- )
- x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
- return x
- # Fast path: 1x1 convolution with upsampling only => convolve first, then upsample.
- if kw == 1 and kh == 1 and (up > 1 and down == 1):
- x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
- x = upfirdn2d(
- x=x,
- f=f,
- up=up,
- padding=[px0, px1, py0, py1],
- gain=up**2,
- flip_filter=flip_filter,
- )
- return x
- # Fast path: downsampling only => use strided convolution.
- if down > 1 and up == 1:
- x = upfirdn2d(x=x, f=f, padding=[px0, px1, py0, py1], flip_filter=flip_filter)
- x = _conv2d_wrapper(
- x=x, w=w, stride=down, groups=groups, flip_weight=flip_weight
- )
- return x
- # Fast path: upsampling with optional downsampling => use transpose strided convolution.
- if up > 1:
- if groups == 1:
- w = w.transpose(0, 1)
- else:
- w = w.reshape(groups, out_channels // groups, in_channels_per_group, kh, kw)
- w = w.transpose(1, 2)
- w = w.reshape(
- groups * in_channels_per_group, out_channels // groups, kh, kw
- )
- px0 -= kw - 1
- px1 -= kw - up
- py0 -= kh - 1
- py1 -= kh - up
- pxt = max(min(-px0, -px1), 0)
- pyt = max(min(-py0, -py1), 0)
- x = _conv2d_wrapper(
- x=x,
- w=w,
- stride=up,
- padding=[pyt, pxt],
- groups=groups,
- transpose=True,
- flip_weight=(not flip_weight),
- )
- x = upfirdn2d(
- x=x,
- f=f,
- padding=[px0 + pxt, px1 + pxt, py0 + pyt, py1 + pyt],
- gain=up**2,
- flip_filter=flip_filter,
- )
- if down > 1:
- x = upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter)
- return x
- # Fast path: no up/downsampling, padding supported by the underlying implementation => use plain conv2d.
- if up == 1 and down == 1:
- if px0 == px1 and py0 == py1 and px0 >= 0 and py0 >= 0:
- return _conv2d_wrapper(
- x=x, w=w, padding=[py0, px0], groups=groups, flip_weight=flip_weight
- )
- # Fallback: Generic reference implementation.
- x = upfirdn2d(
- x=x,
- f=(f if up > 1 else None),
- up=up,
- padding=[px0, px1, py0, py1],
- gain=up**2,
- flip_filter=flip_filter,
- )
- x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
- if down > 1:
- x = upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter)
- return x
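- # Illustrative sketch (not part of the original module): a 3x3 convolution fused
- # with 2x upsampling takes the transpose-convolution fast path; with padding=1 a
- # 32x32 input comes out at exactly twice the resolution.
- def _example_conv2d_resample():
-     x = torch.randn(1, 8, 32, 32)
-     w = torch.randn(16, 8, 3, 3)
-     f = setup_filter([1, 3, 3, 1])
-     y = conv2d_resample(x, w, f=f, up=2, padding=1)
-     return y.shape  # torch.Size([1, 16, 64, 64])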
- class Conv2dLayer(torch.nn.Module):
- def __init__(
- self,
- in_channels, # Number of input channels.
- out_channels, # Number of output channels.
- kernel_size, # Width and height of the convolution kernel.
- bias=True, # Apply additive bias before the activation function?
- activation="linear", # Activation function: 'relu', 'lrelu', etc.
- up=1, # Integer upsampling factor.
- down=1, # Integer downsampling factor.
- resample_filter=[
- 1,
- 3,
- 3,
- 1,
- ], # Low-pass filter to apply when resampling activations.
- conv_clamp=None, # Clamp the output to +-X, None = disable clamping.
- channels_last=False, # Expect the input to have memory_format=channels_last?
- trainable=True, # Update the weights of this layer during training?
- ):
- super().__init__()
- self.activation = activation
- self.up = up
- self.down = down
- self.register_buffer("resample_filter", setup_filter(resample_filter))
- self.conv_clamp = conv_clamp
- self.padding = kernel_size // 2
- self.weight_gain = 1 / np.sqrt(in_channels * (kernel_size**2))
- self.act_gain = activation_funcs[activation].def_gain
- memory_format = (
- torch.channels_last if channels_last else torch.contiguous_format
- )
- weight = torch.randn([out_channels, in_channels, kernel_size, kernel_size]).to(
- memory_format=memory_format
- )
- bias = torch.zeros([out_channels]) if bias else None
- if trainable:
- self.weight = torch.nn.Parameter(weight)
- self.bias = torch.nn.Parameter(bias) if bias is not None else None
- else:
- self.register_buffer("weight", weight)
- if bias is not None:
- self.register_buffer("bias", bias)
- else:
- self.bias = None
- def forward(self, x, gain=1):
- w = self.weight * self.weight_gain
- x = conv2d_resample(
- x=x,
- w=w,
- f=self.resample_filter,
- up=self.up,
- down=self.down,
- padding=self.padding,
- )
- act_gain = self.act_gain * gain
- act_clamp = self.conv_clamp * gain if self.conv_clamp is not None else None
- out = bias_act(
- x, self.bias, act=self.activation, gain=act_gain, clamp=act_clamp
- )
- return out
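- # Illustrative sketch (not part of the original module): a Conv2dLayer with down=2
- # low-pass filters and then strides, halving the spatial resolution; channel counts
- # and input size are example values.
- def _example_conv2d_layer():
-     layer = Conv2dLayer(in_channels=3, out_channels=64, kernel_size=3, activation="lrelu", down=2)
-     x = torch.randn(1, 3, 64, 64)
-     return layer(x).shape  # torch.Size([1, 64, 32, 32])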
- def torch_gc():
- if torch.cuda.is_available():
- torch.cuda.empty_cache()
- torch.cuda.ipc_collect()
- gc.collect()
- def set_seed(seed: int):
- random.seed(seed)
- np.random.seed(seed)
- torch.manual_seed(seed)
- torch.cuda.manual_seed_all(seed)
- def get_scheduler(sd_sampler, scheduler_config):
- # https://github.com/huggingface/diffusers/issues/4167
- keys_to_pop = ["use_karras_sigmas", "algorithm_type"]
- scheduler_config = dict(scheduler_config)
- for it in keys_to_pop:
- scheduler_config.pop(it, None)
- # fmt: off
- samplers = {
- SDSampler.dpm_plus_plus_2m: [DPMSolverMultistepScheduler],
- SDSampler.dpm_plus_plus_2m_karras: [DPMSolverMultistepScheduler, dict(use_karras_sigmas=True)],
- SDSampler.dpm_plus_plus_2m_sde: [DPMSolverMultistepScheduler, dict(algorithm_type="sde-dpmsolver++")],
- SDSampler.dpm_plus_plus_2m_sde_karras: [DPMSolverMultistepScheduler, dict(algorithm_type="sde-dpmsolver++", use_karras_sigmas=True)],
- SDSampler.dpm_plus_plus_sde: [DPMSolverSinglestepScheduler],
- SDSampler.dpm_plus_plus_sde_karras: [DPMSolverSinglestepScheduler, dict(use_karras_sigmas=True)],
- SDSampler.dpm2: [KDPM2DiscreteScheduler],
- SDSampler.dpm2_karras: [KDPM2DiscreteScheduler, dict(use_karras_sigmas=True)],
- SDSampler.dpm2_a: [KDPM2AncestralDiscreteScheduler],
- SDSampler.dpm2_a_karras: [KDPM2AncestralDiscreteScheduler, dict(use_karras_sigmas=True)],
- SDSampler.euler: [EulerDiscreteScheduler],
- SDSampler.euler_a: [EulerAncestralDiscreteScheduler],
- SDSampler.heun: [HeunDiscreteScheduler],
- SDSampler.lms: [LMSDiscreteScheduler],
- SDSampler.lms_karras: [LMSDiscreteScheduler, dict(use_karras_sigmas=True)],
- SDSampler.ddim: [DDIMScheduler],
- SDSampler.pndm: [PNDMScheduler],
- SDSampler.uni_pc: [UniPCMultistepScheduler],
- SDSampler.lcm: [LCMScheduler],
- }
- # fmt: on
- if sd_sampler in samplers:
- if len(samplers[sd_sampler]) == 2:
- scheduler_cls, kwargs = samplers[sd_sampler]
- else:
- scheduler_cls, kwargs = samplers[sd_sampler][0], {}
- return scheduler_cls.from_config(scheduler_config, **kwargs)
- else:
- raise ValueError(sd_sampler)
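- # Illustrative sketch (not part of the original module): swapping a scheduler config
- # into a DPM++ 2M Karras scheduler. In IOPaint the config comes from an existing
- # pipeline (`pipe.scheduler.config`); a default DDIMScheduler config is used here
- # only to keep the example self-contained.
- def _example_get_scheduler():
-     base_config = DDIMScheduler().config
-     scheduler = get_scheduler(SDSampler.dpm_plus_plus_2m_karras, base_config)
-     return type(scheduler).__name__  # "DPMSolverMultistepScheduler"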
- def is_local_files_only(**kwargs) -> bool:
- from huggingface_hub.constants import HF_HUB_OFFLINE
- return HF_HUB_OFFLINE or kwargs.get("local_files_only", False)
- def handle_from_pretrained_exceptions(func, **kwargs):
- try:
- return func(**kwargs)
- except ValueError as e:
- if "You are trying to load the model files of the `variant=fp16`" in str(e):
- logger.info("variant=fp16 not found, try revision=fp16")
- try:
- return func(**{**kwargs, "variant": None, "revision": "fp16"})
- except Exception as e:
- logger.info("revision=fp16 not found, try revision=main")
- return func(**{**kwargs, "variant": None, "revision": "main"})
- raise e
- except OSError as e:
- previous_traceback = traceback.format_exc()
- if "RevisionNotFoundError: 404 Client Error." in previous_traceback:
- logger.info("revision=fp16 not found, try revision=main")
- return func(**{**kwargs, "variant": None, "revision": "main"})
- elif "Max retries exceeded" in previous_traceback:
- logger.exception(
- "Fetching model from HuggingFace failed. "
- "If this is your first time downloading the model, you may need to set up proxy in terminal."
- "If the model has already been downloaded, you can add --local-files-only when starting."
- )
- exit(-1)
- raise e
- except Exception as e:
- raise e
- def get_torch_dtype(device, no_half: bool):
- device = str(device)
- use_fp16 = not no_half
- use_gpu = device == "cuda"
- # https://github.com/huggingface/diffusers/issues/4480
- # pipe.enable_attention_slicing and float16 will cause black output on mps
- # if device in ["cuda", "mps"] and use_fp16:
- if device in ["cuda"] and use_fp16:
- return use_gpu, torch.float16
- return use_gpu, torch.float32
- def enable_low_mem(pipe, enable: bool):
- if torch.backends.mps.is_available():
- # https://huggingface.co/docs/diffusers/v0.25.0/en/api/pipelines/stable_diffusion/image_variation#diffusers.StableDiffusionImageVariationPipeline.enable_attention_slicing
- # CUDA: Don't enable attention slicing if you're already using `scaled_dot_product_attention` (SDPA) from PyTorch 2.0 or xFormers.
- if enable:
- pipe.enable_attention_slicing("max")
- else:
- # https://huggingface.co/docs/diffusers/optimization/mps
- # Devices with less than 64GB of memory are recommended to use enable_attention_slicing
- pipe.enable_attention_slicing()
- if enable:
- pipe.vae.enable_tiling()
|