""" Part of the implementation is borrowed and modified from ControlNet, publicly available at https://github.com/lllyasviel/ControlNet/blob/main/ldm/models/diffusion/ddpm.py """ import itertools from contextlib import contextmanager, nullcontext from functools import partial import cv2 import numpy as np import torch import torch.nn as nn from einops import rearrange, repeat from omegaconf import ListConfig from torch.optim.lr_scheduler import LambdaLR from torchvision.utils import make_grid from tqdm import tqdm from sorawm.iopaint.model.anytext.ldm.models.autoencoder import ( AutoencoderKL, IdentityFirstStage, ) from sorawm.iopaint.model.anytext.ldm.models.diffusion.ddim import DDIMSampler from sorawm.iopaint.model.anytext.ldm.modules.diffusionmodules.util import ( extract_into_tensor, make_beta_schedule, noise_like, ) from sorawm.iopaint.model.anytext.ldm.modules.distributions.distributions import ( DiagonalGaussianDistribution, normal_kl, ) from sorawm.iopaint.model.anytext.ldm.modules.ema import LitEma from sorawm.iopaint.model.anytext.ldm.util import ( count_params, default, exists, instantiate_from_config, isimage, ismap, log_txt_as_img, mean_flat, ) __conditioning_keys__ = {"concat": "c_concat", "crossattn": "c_crossattn", "adm": "y"} PRINT_DEBUG = False def print_grad(grad): # print('Gradient:', grad) # print(grad.shape) a = grad.max() b = grad.min() # print(f'mean={grad.mean():.4f}, max={a:.4f}, min={b:.4f}') s = 255.0 / (a - b) c = 255 * (-b / (a - b)) grad = grad * s + c # print(f'mean={grad.mean():.4f}, max={grad.max():.4f}, min={grad.min():.4f}') img = grad[0].permute(1, 2, 0).detach().cpu().numpy() if img.shape[0] == 512: cv2.imwrite("grad-img.jpg", img) elif img.shape[0] == 64: cv2.imwrite("grad-latent.jpg", img) def disabled_train(self, mode=True): """Overwrite model.train with this function to make sure train/eval mode does not change anymore.""" return self def uniform_on_device(r1, r2, shape, device): return (r1 - r2) * torch.rand(*shape, device=device) + r2 class DDPM(torch.nn.Module): # classic DDPM with Gaussian diffusion, in image space def __init__( self, unet_config, timesteps=1000, beta_schedule="linear", loss_type="l2", ckpt_path=None, ignore_keys=[], load_only_unet=False, monitor="val/loss", use_ema=True, first_stage_key="image", image_size=256, channels=3, log_every_t=100, clip_denoised=True, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3, given_betas=None, original_elbo_weight=0.0, v_posterior=0.0, # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta l_simple_weight=1.0, conditioning_key=None, parameterization="eps", # all assuming fixed variance schedules scheduler_config=None, use_positional_encodings=False, learn_logvar=False, logvar_init=0.0, make_it_fit=False, ucg_training=None, reset_ema=False, reset_num_ema_updates=False, ): super().__init__() assert parameterization in [ "eps", "x0", "v", ], 'currently only supporting "eps" and "x0" and "v"' self.parameterization = parameterization print( f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode" ) self.cond_stage_model = None self.clip_denoised = clip_denoised self.log_every_t = log_every_t self.first_stage_key = first_stage_key self.image_size = image_size # try conv? self.channels = channels self.use_positional_encodings = use_positional_encodings self.model = DiffusionWrapper(unet_config, conditioning_key) count_params(self.model, verbose=True) self.use_ema = use_ema if self.use_ema: self.model_ema = LitEma(self.model) print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") self.use_scheduler = scheduler_config is not None if self.use_scheduler: self.scheduler_config = scheduler_config self.v_posterior = v_posterior self.original_elbo_weight = original_elbo_weight self.l_simple_weight = l_simple_weight if monitor is not None: self.monitor = monitor self.make_it_fit = make_it_fit if reset_ema: assert exists(ckpt_path) if ckpt_path is not None: self.init_from_ckpt( ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet ) if reset_ema: assert self.use_ema print( f"Resetting ema to pure model weights. This is useful when restoring from an ema-only checkpoint." ) self.model_ema = LitEma(self.model) if reset_num_ema_updates: print( " +++++++++++ WARNING: RESETTING NUM_EMA UPDATES TO ZERO +++++++++++ " ) assert self.use_ema self.model_ema.reset_num_updates() self.register_schedule( given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps, linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s, ) self.loss_type = loss_type self.learn_logvar = learn_logvar logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,)) if self.learn_logvar: self.logvar = nn.Parameter(self.logvar, requires_grad=True) else: self.register_buffer("logvar", logvar) self.ucg_training = ucg_training or dict() if self.ucg_training: self.ucg_prng = np.random.RandomState() def register_schedule( self, given_betas=None, beta_schedule="linear", timesteps=1000, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3, ): if exists(given_betas): betas = given_betas else: betas = make_beta_schedule( beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s, ) alphas = 1.0 - betas alphas_cumprod = np.cumprod(alphas, axis=0) # np.save('1.npy', alphas_cumprod) alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1]) (timesteps,) = betas.shape self.num_timesteps = int(timesteps) self.linear_start = linear_start self.linear_end = linear_end assert ( alphas_cumprod.shape[0] == self.num_timesteps ), "alphas have to be defined for each timestep" to_torch = partial(torch.tensor, dtype=torch.float32) self.register_buffer("betas", to_torch(betas)) self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod)) self.register_buffer("alphas_cumprod_prev", to_torch(alphas_cumprod_prev)) # calculations for diffusion q(x_t | x_{t-1}) and others self.register_buffer("sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod))) self.register_buffer( "sqrt_one_minus_alphas_cumprod", to_torch(np.sqrt(1.0 - alphas_cumprod)) ) self.register_buffer( "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod)) ) self.register_buffer( "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod)) ) self.register_buffer( "sqrt_recipm1_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod - 1)) ) # calculations for posterior q(x_{t-1} | x_t, x_0) posterior_variance = (1 - self.v_posterior) * betas * ( 1.0 - alphas_cumprod_prev ) / (1.0 - alphas_cumprod) + self.v_posterior * betas # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) self.register_buffer("posterior_variance", to_torch(posterior_variance)) # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain self.register_buffer( "posterior_log_variance_clipped", to_torch(np.log(np.maximum(posterior_variance, 1e-20))), ) self.register_buffer( "posterior_mean_coef1", to_torch(betas * np.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)), ) self.register_buffer( "posterior_mean_coef2", to_torch( (1.0 - alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - alphas_cumprod) ), ) if self.parameterization == "eps": lvlb_weights = self.betas**2 / ( 2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod) ) elif self.parameterization == "x0": lvlb_weights = ( 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2.0 * 1 - torch.Tensor(alphas_cumprod)) ) elif self.parameterization == "v": lvlb_weights = torch.ones_like( self.betas**2 / ( 2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod) ) ) else: raise NotImplementedError("mu not supported") lvlb_weights[0] = lvlb_weights[1] self.register_buffer("lvlb_weights", lvlb_weights, persistent=False) assert not torch.isnan(self.lvlb_weights).all() @contextmanager def ema_scope(self, context=None): if self.use_ema: self.model_ema.store(self.model.parameters()) self.model_ema.copy_to(self.model) if context is not None: print(f"{context}: Switched to EMA weights") try: yield None finally: if self.use_ema: self.model_ema.restore(self.model.parameters()) if context is not None: print(f"{context}: Restored training weights") @torch.no_grad() def init_from_ckpt(self, path, ignore_keys=list(), only_model=False): sd = torch.load(path, map_location="cpu") if "state_dict" in list(sd.keys()): sd = sd["state_dict"] keys = list(sd.keys()) for k in keys: for ik in ignore_keys: if k.startswith(ik): print("Deleting key {} from state_dict.".format(k)) del sd[k] if self.make_it_fit: n_params = len( [ name for name, _ in itertools.chain( self.named_parameters(), self.named_buffers() ) ] ) for name, param in tqdm( itertools.chain(self.named_parameters(), self.named_buffers()), desc="Fitting old weights to new weights", total=n_params, ): if not name in sd: continue old_shape = sd[name].shape new_shape = param.shape assert len(old_shape) == len(new_shape) if len(new_shape) > 2: # we only modify first two axes assert new_shape[2:] == old_shape[2:] # assumes first axis corresponds to output dim if not new_shape == old_shape: new_param = param.clone() old_param = sd[name] if len(new_shape) == 1: for i in range(new_param.shape[0]): new_param[i] = old_param[i % old_shape[0]] elif len(new_shape) >= 2: for i in range(new_param.shape[0]): for j in range(new_param.shape[1]): new_param[i, j] = old_param[ i % old_shape[0], j % old_shape[1] ] n_used_old = torch.ones(old_shape[1]) for j in range(new_param.shape[1]): n_used_old[j % old_shape[1]] += 1 n_used_new = torch.zeros(new_shape[1]) for j in range(new_param.shape[1]): n_used_new[j] = n_used_old[j % old_shape[1]] n_used_new = n_used_new[None, :] while len(n_used_new.shape) < len(new_shape): n_used_new = n_used_new.unsqueeze(-1) new_param /= n_used_new sd[name] = new_param missing, unexpected = ( self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(sd, strict=False) ) print( f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys" ) if len(missing) > 0: print(f"Missing Keys:\n {missing}") if len(unexpected) > 0: print(f"\nUnexpected Keys:\n {unexpected}") def q_mean_variance(self, x_start, t): """ Get the distribution q(x_t | x_0). :param x_start: the [N x C x ...] tensor of noiseless inputs. :param t: the number of diffusion steps (minus 1). Here, 0 means one step. :return: A tuple (mean, variance, log_variance), all of x_start's shape. """ mean = extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start variance = extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape) log_variance = extract_into_tensor( self.log_one_minus_alphas_cumprod, t, x_start.shape ) return mean, variance, log_variance def predict_start_from_noise(self, x_t, t, noise): return ( extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise ) def predict_start_from_z_and_v(self, x_t, t, v): # self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod))) # self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod))) return ( extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t - extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v ) def predict_eps_from_z_and_v(self, x_t, t, v): return ( extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * v + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * x_t ) def q_posterior(self, x_start, x_t, t): posterior_mean = ( extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start + extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t ) posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape) posterior_log_variance_clipped = extract_into_tensor( self.posterior_log_variance_clipped, t, x_t.shape ) return posterior_mean, posterior_variance, posterior_log_variance_clipped def p_mean_variance(self, x, t, clip_denoised: bool): model_out = self.model(x, t) if self.parameterization == "eps": x_recon = self.predict_start_from_noise(x, t=t, noise=model_out) elif self.parameterization == "x0": x_recon = model_out if clip_denoised: x_recon.clamp_(-1.0, 1.0) model_mean, posterior_variance, posterior_log_variance = self.q_posterior( x_start=x_recon, x_t=x, t=t ) return model_mean, posterior_variance, posterior_log_variance @torch.no_grad() def p_sample(self, x, t, clip_denoised=True, repeat_noise=False): b, *_, device = *x.shape, x.device model_mean, _, model_log_variance = self.p_mean_variance( x=x, t=t, clip_denoised=clip_denoised ) noise = noise_like(x.shape, device, repeat_noise) # no noise when t == 0 nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))) return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise @torch.no_grad() def p_sample_loop(self, shape, return_intermediates=False): device = self.betas.device b = shape[0] img = torch.randn(shape, device=device) intermediates = [img] for i in tqdm( reversed(range(0, self.num_timesteps)), desc="Sampling t", total=self.num_timesteps, ): img = self.p_sample( img, torch.full((b,), i, device=device, dtype=torch.long), clip_denoised=self.clip_denoised, ) if i % self.log_every_t == 0 or i == self.num_timesteps - 1: intermediates.append(img) if return_intermediates: return img, intermediates return img @torch.no_grad() def sample(self, batch_size=16, return_intermediates=False): image_size = self.image_size channels = self.channels return self.p_sample_loop( (batch_size, channels, image_size, image_size), return_intermediates=return_intermediates, ) def q_sample(self, x_start, t, noise=None): noise = default(noise, lambda: torch.randn_like(x_start)) return ( extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise ) def get_v(self, x, noise, t): return ( extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * noise - extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x ) def get_loss(self, pred, target, mean=True): if self.loss_type == "l1": loss = (target - pred).abs() if mean: loss = loss.mean() elif self.loss_type == "l2": if mean: loss = torch.nn.functional.mse_loss(target, pred) else: loss = torch.nn.functional.mse_loss(target, pred, reduction="none") else: raise NotImplementedError("unknown loss type '{loss_type}'") return loss def p_losses(self, x_start, t, noise=None): noise = default(noise, lambda: torch.randn_like(x_start)) x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) model_out = self.model(x_noisy, t) loss_dict = {} if self.parameterization == "eps": target = noise elif self.parameterization == "x0": target = x_start elif self.parameterization == "v": target = self.get_v(x_start, noise, t) else: raise NotImplementedError( f"Parameterization {self.parameterization} not yet supported" ) loss = self.get_loss(model_out, target, mean=False).mean(dim=[1, 2, 3]) log_prefix = "train" if self.training else "val" loss_dict.update({f"{log_prefix}/loss_simple": loss.mean()}) loss_simple = loss.mean() * self.l_simple_weight loss_vlb = (self.lvlb_weights[t] * loss).mean() loss_dict.update({f"{log_prefix}/loss_vlb": loss_vlb}) loss = loss_simple + self.original_elbo_weight * loss_vlb loss_dict.update({f"{log_prefix}/loss": loss}) return loss, loss_dict def forward(self, x, *args, **kwargs): # b, c, h, w, device, img_size, = *x.shape, x.device, self.image_size # assert h == img_size and w == img_size, f'height and width of image must be {img_size}' t = torch.randint( 0, self.num_timesteps, (x.shape[0],), device=self.device ).long() return self.p_losses(x, t, *args, **kwargs) def get_input(self, batch, k): x = batch[k] if len(x.shape) == 3: x = x[..., None] x = rearrange(x, "b h w c -> b c h w") x = x.to(memory_format=torch.contiguous_format).float() return x def shared_step(self, batch): x = self.get_input(batch, self.first_stage_key) loss, loss_dict = self(x) return loss, loss_dict def training_step(self, batch, batch_idx): for k in self.ucg_training: p = self.ucg_training[k]["p"] val = self.ucg_training[k]["val"] if val is None: val = "" for i in range(len(batch[k])): if self.ucg_prng.choice(2, p=[1 - p, p]): batch[k][i] = val loss, loss_dict = self.shared_step(batch) self.log_dict( loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=True ) self.log( "global_step", self.global_step, prog_bar=True, logger=True, on_step=True, on_epoch=False, ) if self.use_scheduler: lr = self.optimizers().param_groups[0]["lr"] self.log( "lr_abs", lr, prog_bar=True, logger=True, on_step=True, on_epoch=False ) return loss @torch.no_grad() def validation_step(self, batch, batch_idx): _, loss_dict_no_ema = self.shared_step(batch) with self.ema_scope(): _, loss_dict_ema = self.shared_step(batch) loss_dict_ema = {key + "_ema": loss_dict_ema[key] for key in loss_dict_ema} self.log_dict( loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True ) self.log_dict( loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True ) def on_train_batch_end(self, *args, **kwargs): if self.use_ema: self.model_ema(self.model) def _get_rows_from_list(self, samples): n_imgs_per_row = len(samples) denoise_grid = rearrange(samples, "n b c h w -> b n c h w") denoise_grid = rearrange(denoise_grid, "b n c h w -> (b n) c h w") denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row) return denoise_grid @torch.no_grad() def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs): log = dict() x = self.get_input(batch, self.first_stage_key) N = min(x.shape[0], N) n_row = min(x.shape[0], n_row) x = x.to(self.device)[:N] log["inputs"] = x # get diffusion row diffusion_row = list() x_start = x[:n_row] for t in range(self.num_timesteps): if t % self.log_every_t == 0 or t == self.num_timesteps - 1: t = repeat(torch.tensor([t]), "1 -> b", b=n_row) t = t.to(self.device).long() noise = torch.randn_like(x_start) x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) diffusion_row.append(x_noisy) log["diffusion_row"] = self._get_rows_from_list(diffusion_row) if sample: # get denoise row with self.ema_scope("Plotting"): samples, denoise_row = self.sample( batch_size=N, return_intermediates=True ) log["samples"] = samples log["denoise_row"] = self._get_rows_from_list(denoise_row) if return_keys: if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0: return log else: return {key: log[key] for key in return_keys} return log def configure_optimizers(self): lr = self.learning_rate params = list(self.model.parameters()) if self.learn_logvar: params = params + [self.logvar] opt = torch.optim.AdamW(params, lr=lr) return opt class LatentDiffusion(DDPM): """main class""" def __init__( self, first_stage_config, cond_stage_config, num_timesteps_cond=None, cond_stage_key="image", cond_stage_trainable=False, concat_mode=True, cond_stage_forward=None, conditioning_key=None, scale_factor=1.0, scale_by_std=False, force_null_conditioning=False, *args, **kwargs, ): self.force_null_conditioning = force_null_conditioning self.num_timesteps_cond = default(num_timesteps_cond, 1) self.scale_by_std = scale_by_std assert self.num_timesteps_cond <= kwargs["timesteps"] # for backwards compatibility after implementation of DiffusionWrapper if conditioning_key is None: conditioning_key = "concat" if concat_mode else "crossattn" if ( cond_stage_config == "__is_unconditional__" and not self.force_null_conditioning ): conditioning_key = None ckpt_path = kwargs.pop("ckpt_path", None) reset_ema = kwargs.pop("reset_ema", False) reset_num_ema_updates = kwargs.pop("reset_num_ema_updates", False) ignore_keys = kwargs.pop("ignore_keys", []) super().__init__(conditioning_key=conditioning_key, *args, **kwargs) self.concat_mode = concat_mode self.cond_stage_trainable = cond_stage_trainable self.cond_stage_key = cond_stage_key try: self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1 except: self.num_downs = 0 if not scale_by_std: self.scale_factor = scale_factor else: self.register_buffer("scale_factor", torch.tensor(scale_factor)) self.instantiate_first_stage(first_stage_config) self.instantiate_cond_stage(cond_stage_config) self.cond_stage_forward = cond_stage_forward self.clip_denoised = False self.bbox_tokenizer = None self.restarted_from_ckpt = False if ckpt_path is not None: self.init_from_ckpt(ckpt_path, ignore_keys) self.restarted_from_ckpt = True if reset_ema: assert self.use_ema print( f"Resetting ema to pure model weights. This is useful when restoring from an ema-only checkpoint." ) self.model_ema = LitEma(self.model) if reset_num_ema_updates: print( " +++++++++++ WARNING: RESETTING NUM_EMA UPDATES TO ZERO +++++++++++ " ) assert self.use_ema self.model_ema.reset_num_updates() def make_cond_schedule( self, ): self.cond_ids = torch.full( size=(self.num_timesteps,), fill_value=self.num_timesteps - 1, dtype=torch.long, ) ids = torch.round( torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond) ).long() self.cond_ids[: self.num_timesteps_cond] = ids @torch.no_grad() def on_train_batch_start(self, batch, batch_idx, dataloader_idx): # only for very first batch if ( self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 and batch_idx == 0 and not self.restarted_from_ckpt ): assert ( self.scale_factor == 1.0 ), "rather not use custom rescaling and std-rescaling simultaneously" # set rescale weight to 1./std of encodings print("### USING STD-RESCALING ###") x = super().get_input(batch, self.first_stage_key) x = x.to(self.device) encoder_posterior = self.encode_first_stage(x) z = self.get_first_stage_encoding(encoder_posterior).detach() del self.scale_factor self.register_buffer("scale_factor", 1.0 / z.flatten().std()) print(f"setting self.scale_factor to {self.scale_factor}") print("### USING STD-RESCALING ###") def register_schedule( self, given_betas=None, beta_schedule="linear", timesteps=1000, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3, ): super().register_schedule( given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s ) self.shorten_cond_schedule = self.num_timesteps_cond > 1 if self.shorten_cond_schedule: self.make_cond_schedule() def instantiate_first_stage(self, config): model = instantiate_from_config(config) self.first_stage_model = model.eval() self.first_stage_model.train = disabled_train for param in self.first_stage_model.parameters(): param.requires_grad = False def instantiate_cond_stage(self, config): if not self.cond_stage_trainable: if config == "__is_first_stage__": print("Using first stage also as cond stage.") self.cond_stage_model = self.first_stage_model elif config == "__is_unconditional__": print(f"Training {self.__class__.__name__} as an unconditional model.") self.cond_stage_model = None # self.be_unconditional = True else: model = instantiate_from_config(config) self.cond_stage_model = model.eval() self.cond_stage_model.train = disabled_train for param in self.cond_stage_model.parameters(): param.requires_grad = False else: assert config != "__is_first_stage__" assert config != "__is_unconditional__" model = instantiate_from_config(config) self.cond_stage_model = model def _get_denoise_row_from_list( self, samples, desc="", force_no_decoder_quantization=False ): denoise_row = [] for zd in tqdm(samples, desc=desc): denoise_row.append( self.decode_first_stage( zd.to(self.device), force_not_quantize=force_no_decoder_quantization ) ) n_imgs_per_row = len(denoise_row) denoise_row = torch.stack(denoise_row) # n_log_step, n_row, C, H, W denoise_grid = rearrange(denoise_row, "n b c h w -> b n c h w") denoise_grid = rearrange(denoise_grid, "b n c h w -> (b n) c h w") denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row) return denoise_grid def get_first_stage_encoding(self, encoder_posterior): if isinstance(encoder_posterior, DiagonalGaussianDistribution): z = encoder_posterior.sample() elif isinstance(encoder_posterior, torch.Tensor): z = encoder_posterior else: raise NotImplementedError( f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented" ) return self.scale_factor * z def get_learned_conditioning(self, c): if self.cond_stage_forward is None: if hasattr(self.cond_stage_model, "encode") and callable( self.cond_stage_model.encode ): c = self.cond_stage_model.encode(c) if isinstance(c, DiagonalGaussianDistribution): c = c.mode() else: c = self.cond_stage_model(c) else: assert hasattr(self.cond_stage_model, self.cond_stage_forward) c = getattr(self.cond_stage_model, self.cond_stage_forward)(c) return c def meshgrid(self, h, w): y = torch.arange(0, h).view(h, 1, 1).repeat(1, w, 1) x = torch.arange(0, w).view(1, w, 1).repeat(h, 1, 1) arr = torch.cat([y, x], dim=-1) return arr def delta_border(self, h, w): """ :param h: height :param w: width :return: normalized distance to image border, wtith min distance = 0 at border and max dist = 0.5 at image center """ lower_right_corner = torch.tensor([h - 1, w - 1]).view(1, 1, 2) arr = self.meshgrid(h, w) / lower_right_corner dist_left_up = torch.min(arr, dim=-1, keepdims=True)[0] dist_right_down = torch.min(1 - arr, dim=-1, keepdims=True)[0] edge_dist = torch.min( torch.cat([dist_left_up, dist_right_down], dim=-1), dim=-1 )[0] return edge_dist def get_weighting(self, h, w, Ly, Lx, device): weighting = self.delta_border(h, w) weighting = torch.clip( weighting, self.split_input_params["clip_min_weight"], self.split_input_params["clip_max_weight"], ) weighting = weighting.view(1, h * w, 1).repeat(1, 1, Ly * Lx).to(device) if self.split_input_params["tie_braker"]: L_weighting = self.delta_border(Ly, Lx) L_weighting = torch.clip( L_weighting, self.split_input_params["clip_min_tie_weight"], self.split_input_params["clip_max_tie_weight"], ) L_weighting = L_weighting.view(1, 1, Ly * Lx).to(device) weighting = weighting * L_weighting return weighting def get_fold_unfold( self, x, kernel_size, stride, uf=1, df=1 ): # todo load once not every time, shorten code """ :param x: img of size (bs, c, h, w) :return: n img crops of size (n, bs, c, kernel_size[0], kernel_size[1]) """ bs, nc, h, w = x.shape # number of crops in image Ly = (h - kernel_size[0]) // stride[0] + 1 Lx = (w - kernel_size[1]) // stride[1] + 1 if uf == 1 and df == 1: fold_params = dict( kernel_size=kernel_size, dilation=1, padding=0, stride=stride ) unfold = torch.nn.Unfold(**fold_params) fold = torch.nn.Fold(output_size=x.shape[2:], **fold_params) weighting = self.get_weighting( kernel_size[0], kernel_size[1], Ly, Lx, x.device ).to(x.dtype) normalization = fold(weighting).view(1, 1, h, w) # normalizes the overlap weighting = weighting.view((1, 1, kernel_size[0], kernel_size[1], Ly * Lx)) elif uf > 1 and df == 1: fold_params = dict( kernel_size=kernel_size, dilation=1, padding=0, stride=stride ) unfold = torch.nn.Unfold(**fold_params) fold_params2 = dict( kernel_size=(kernel_size[0] * uf, kernel_size[0] * uf), dilation=1, padding=0, stride=(stride[0] * uf, stride[1] * uf), ) fold = torch.nn.Fold( output_size=(x.shape[2] * uf, x.shape[3] * uf), **fold_params2 ) weighting = self.get_weighting( kernel_size[0] * uf, kernel_size[1] * uf, Ly, Lx, x.device ).to(x.dtype) normalization = fold(weighting).view( 1, 1, h * uf, w * uf ) # normalizes the overlap weighting = weighting.view( (1, 1, kernel_size[0] * uf, kernel_size[1] * uf, Ly * Lx) ) elif df > 1 and uf == 1: fold_params = dict( kernel_size=kernel_size, dilation=1, padding=0, stride=stride ) unfold = torch.nn.Unfold(**fold_params) fold_params2 = dict( kernel_size=(kernel_size[0] // df, kernel_size[0] // df), dilation=1, padding=0, stride=(stride[0] // df, stride[1] // df), ) fold = torch.nn.Fold( output_size=(x.shape[2] // df, x.shape[3] // df), **fold_params2 ) weighting = self.get_weighting( kernel_size[0] // df, kernel_size[1] // df, Ly, Lx, x.device ).to(x.dtype) normalization = fold(weighting).view( 1, 1, h // df, w // df ) # normalizes the overlap weighting = weighting.view( (1, 1, kernel_size[0] // df, kernel_size[1] // df, Ly * Lx) ) else: raise NotImplementedError return fold, unfold, normalization, weighting @torch.no_grad() def get_input( self, batch, k, return_first_stage_outputs=False, force_c_encode=False, cond_key=None, return_original_cond=False, bs=None, return_x=False, mask_k=None, ): x = super().get_input(batch, k) if bs is not None: x = x[:bs] x = x.to(self.device) encoder_posterior = self.encode_first_stage(x) z = self.get_first_stage_encoding(encoder_posterior).detach() if mask_k is not None: mx = super().get_input(batch, mask_k) if bs is not None: mx = mx[:bs] mx = mx.to(self.device) encoder_posterior = self.encode_first_stage(mx) mx = self.get_first_stage_encoding(encoder_posterior).detach() if self.model.conditioning_key is not None and not self.force_null_conditioning: if cond_key is None: cond_key = self.cond_stage_key if cond_key != self.first_stage_key: if cond_key in ["caption", "coordinates_bbox", "txt"]: xc = batch[cond_key] elif cond_key in ["class_label", "cls"]: xc = batch else: xc = super().get_input(batch, cond_key).to(self.device) else: xc = x if not self.cond_stage_trainable or force_c_encode: if isinstance(xc, dict) or isinstance(xc, list): c = self.get_learned_conditioning(xc) else: c = self.get_learned_conditioning(xc.to(self.device)) else: c = xc if bs is not None: c = c[:bs] if self.use_positional_encodings: pos_x, pos_y = self.compute_latent_shifts(batch) ckey = __conditioning_keys__[self.model.conditioning_key] c = {ckey: c, "pos_x": pos_x, "pos_y": pos_y} else: c = None xc = None if self.use_positional_encodings: pos_x, pos_y = self.compute_latent_shifts(batch) c = {"pos_x": pos_x, "pos_y": pos_y} out = [z, c] if return_first_stage_outputs: xrec = self.decode_first_stage(z) out.extend([x, xrec]) if return_x: out.extend([x]) if return_original_cond: out.append(xc) if mask_k: out.append(mx) return out @torch.no_grad() def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False): if predict_cids: if z.dim() == 4: z = torch.argmax(z.exp(), dim=1).long() z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None) z = rearrange(z, "b h w c -> b c h w").contiguous() z = 1.0 / self.scale_factor * z return self.first_stage_model.decode(z) def decode_first_stage_grad(self, z, predict_cids=False, force_not_quantize=False): if predict_cids: if z.dim() == 4: z = torch.argmax(z.exp(), dim=1).long() z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None) z = rearrange(z, "b h w c -> b c h w").contiguous() z = 1.0 / self.scale_factor * z return self.first_stage_model.decode(z) @torch.no_grad() def encode_first_stage(self, x): return self.first_stage_model.encode(x) def shared_step(self, batch, **kwargs): x, c = self.get_input(batch, self.first_stage_key) loss = self(x, c) return loss def forward(self, x, c, *args, **kwargs): t = torch.randint( 0, self.num_timesteps, (x.shape[0],), device=self.device ).long() # t = torch.randint(500, 501, (x.shape[0],), device=self.device).long() if self.model.conditioning_key is not None: assert c is not None if self.cond_stage_trainable: c = self.get_learned_conditioning(c) if self.shorten_cond_schedule: # TODO: drop this option tc = self.cond_ids[t].to(self.device) c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float())) return self.p_losses(x, c, t, *args, **kwargs) def apply_model(self, x_noisy, t, cond, return_ids=False): if isinstance(cond, dict): # hybrid case, cond is expected to be a dict pass else: if not isinstance(cond, list): cond = [cond] key = ( "c_concat" if self.model.conditioning_key == "concat" else "c_crossattn" ) cond = {key: cond} x_recon = self.model(x_noisy, t, **cond) if isinstance(x_recon, tuple) and not return_ids: return x_recon[0] else: return x_recon def _predict_eps_from_xstart(self, x_t, t, pred_xstart): return ( extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart ) / extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) def _prior_bpd(self, x_start): """ Get the prior KL term for the variational lower-bound, measured in bits-per-dim. This term can't be optimized, as it only depends on the encoder. :param x_start: the [N x C x ...] tensor of inputs. :return: a batch of [N] KL values (in bits), one per batch element. """ batch_size = x_start.shape[0] t = torch.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device) qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t) kl_prior = normal_kl( mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0 ) return mean_flat(kl_prior) / np.log(2.0) def p_mean_variance( self, x, c, t, clip_denoised: bool, return_codebook_ids=False, quantize_denoised=False, return_x0=False, score_corrector=None, corrector_kwargs=None, ): t_in = t model_out = self.apply_model(x, t_in, c, return_ids=return_codebook_ids) if score_corrector is not None: assert self.parameterization == "eps" model_out = score_corrector.modify_score( self, model_out, x, t, c, **corrector_kwargs ) if return_codebook_ids: model_out, logits = model_out if self.parameterization == "eps": x_recon = self.predict_start_from_noise(x, t=t, noise=model_out) elif self.parameterization == "x0": x_recon = model_out else: raise NotImplementedError() if clip_denoised: x_recon.clamp_(-1.0, 1.0) if quantize_denoised: x_recon, _, [_, _, indices] = self.first_stage_model.quantize(x_recon) model_mean, posterior_variance, posterior_log_variance = self.q_posterior( x_start=x_recon, x_t=x, t=t ) if return_codebook_ids: return model_mean, posterior_variance, posterior_log_variance, logits elif return_x0: return model_mean, posterior_variance, posterior_log_variance, x_recon else: return model_mean, posterior_variance, posterior_log_variance @torch.no_grad() def p_sample( self, x, c, t, clip_denoised=False, repeat_noise=False, return_codebook_ids=False, quantize_denoised=False, return_x0=False, temperature=1.0, noise_dropout=0.0, score_corrector=None, corrector_kwargs=None, ): b, *_, device = *x.shape, x.device outputs = self.p_mean_variance( x=x, c=c, t=t, clip_denoised=clip_denoised, return_codebook_ids=return_codebook_ids, quantize_denoised=quantize_denoised, return_x0=return_x0, score_corrector=score_corrector, corrector_kwargs=corrector_kwargs, ) if return_codebook_ids: raise DeprecationWarning("Support dropped.") model_mean, _, model_log_variance, logits = outputs elif return_x0: model_mean, _, model_log_variance, x0 = outputs else: model_mean, _, model_log_variance = outputs noise = noise_like(x.shape, device, repeat_noise) * temperature if noise_dropout > 0.0: noise = torch.nn.functional.dropout(noise, p=noise_dropout) # no noise when t == 0 nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))) if return_codebook_ids: return ( model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, logits.argmax(dim=1), ) if return_x0: return ( model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, x0, ) else: return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise @torch.no_grad() def progressive_denoising( self, cond, shape, verbose=True, callback=None, quantize_denoised=False, img_callback=None, mask=None, x0=None, temperature=1.0, noise_dropout=0.0, score_corrector=None, corrector_kwargs=None, batch_size=None, x_T=None, start_T=None, log_every_t=None, ): if not log_every_t: log_every_t = self.log_every_t timesteps = self.num_timesteps if batch_size is not None: b = batch_size if batch_size is not None else shape[0] shape = [batch_size] + list(shape) else: b = batch_size = shape[0] if x_T is None: img = torch.randn(shape, device=self.device) else: img = x_T intermediates = [] if cond is not None: if isinstance(cond, dict): cond = { key: cond[key][:batch_size] if not isinstance(cond[key], list) else list(map(lambda x: x[:batch_size], cond[key])) for key in cond } else: cond = ( [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size] ) if start_T is not None: timesteps = min(timesteps, start_T) iterator = ( tqdm( reversed(range(0, timesteps)), desc="Progressive Generation", total=timesteps, ) if verbose else reversed(range(0, timesteps)) ) if type(temperature) == float: temperature = [temperature] * timesteps for i in iterator: ts = torch.full((b,), i, device=self.device, dtype=torch.long) if self.shorten_cond_schedule: assert self.model.conditioning_key != "hybrid" tc = self.cond_ids[ts].to(cond.device) cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond)) img, x0_partial = self.p_sample( img, cond, ts, clip_denoised=self.clip_denoised, quantize_denoised=quantize_denoised, return_x0=True, temperature=temperature[i], noise_dropout=noise_dropout, score_corrector=score_corrector, corrector_kwargs=corrector_kwargs, ) if mask is not None: assert x0 is not None img_orig = self.q_sample(x0, ts) img = img_orig * mask + (1.0 - mask) * img if i % log_every_t == 0 or i == timesteps - 1: intermediates.append(x0_partial) if callback: callback(i) if img_callback: img_callback(img, i) return img, intermediates @torch.no_grad() def p_sample_loop( self, cond, shape, return_intermediates=False, x_T=None, verbose=True, callback=None, timesteps=None, quantize_denoised=False, mask=None, x0=None, img_callback=None, start_T=None, log_every_t=None, ): if not log_every_t: log_every_t = self.log_every_t device = self.betas.device b = shape[0] if x_T is None: img = torch.randn(shape, device=device) else: img = x_T intermediates = [img] if timesteps is None: timesteps = self.num_timesteps if start_T is not None: timesteps = min(timesteps, start_T) iterator = ( tqdm(reversed(range(0, timesteps)), desc="Sampling t", total=timesteps) if verbose else reversed(range(0, timesteps)) ) if mask is not None: assert x0 is not None assert x0.shape[2:3] == mask.shape[2:3] # spatial size has to match for i in iterator: ts = torch.full((b,), i, device=device, dtype=torch.long) if self.shorten_cond_schedule: assert self.model.conditioning_key != "hybrid" tc = self.cond_ids[ts].to(cond.device) cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond)) img = self.p_sample( img, cond, ts, clip_denoised=self.clip_denoised, quantize_denoised=quantize_denoised, ) if mask is not None: img_orig = self.q_sample(x0, ts) img = img_orig * mask + (1.0 - mask) * img if i % log_every_t == 0 or i == timesteps - 1: intermediates.append(img) if callback: callback(i) if img_callback: img_callback(img, i) if return_intermediates: return img, intermediates return img @torch.no_grad() def sample( self, cond, batch_size=16, return_intermediates=False, x_T=None, verbose=True, timesteps=None, quantize_denoised=False, mask=None, x0=None, shape=None, **kwargs, ): if shape is None: shape = (batch_size, self.channels, self.image_size, self.image_size) if cond is not None: if isinstance(cond, dict): cond = { key: cond[key][:batch_size] if not isinstance(cond[key], list) else list(map(lambda x: x[:batch_size], cond[key])) for key in cond } else: cond = ( [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size] ) return self.p_sample_loop( cond, shape, return_intermediates=return_intermediates, x_T=x_T, verbose=verbose, timesteps=timesteps, quantize_denoised=quantize_denoised, mask=mask, x0=x0, ) @torch.no_grad() def sample_log(self, cond, batch_size, ddim, ddim_steps, **kwargs): if ddim: ddim_sampler = DDIMSampler(self) shape = (self.channels, self.image_size, self.image_size) samples, intermediates = ddim_sampler.sample( ddim_steps, batch_size, shape, cond, verbose=False, **kwargs ) else: samples, intermediates = self.sample( cond=cond, batch_size=batch_size, return_intermediates=True, **kwargs ) return samples, intermediates @torch.no_grad() def get_unconditional_conditioning(self, batch_size, null_label=None): if null_label is not None: xc = null_label if isinstance(xc, ListConfig): xc = list(xc) if isinstance(xc, dict) or isinstance(xc, list): c = self.get_learned_conditioning(xc) else: if hasattr(xc, "to"): xc = xc.to(self.device) c = self.get_learned_conditioning(xc) else: if self.cond_stage_key in ["class_label", "cls"]: xc = self.cond_stage_model.get_unconditional_conditioning( batch_size, device=self.device ) return self.get_learned_conditioning(xc) else: raise NotImplementedError("todo") if isinstance(c, list): # in case the encoder gives us a list for i in range(len(c)): c[i] = repeat(c[i], "1 ... -> b ...", b=batch_size).to(self.device) else: c = repeat(c, "1 ... -> b ...", b=batch_size).to(self.device) return c @torch.no_grad() def log_images( self, batch, N=8, n_row=4, sample=True, ddim_steps=50, ddim_eta=0.0, return_keys=None, quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True, plot_diffusion_rows=True, unconditional_guidance_scale=1.0, unconditional_guidance_label=None, use_ema_scope=True, **kwargs, ): ema_scope = self.ema_scope if use_ema_scope else nullcontext use_ddim = ddim_steps is not None log = dict() z, c, x, xrec, xc = self.get_input( batch, self.first_stage_key, return_first_stage_outputs=True, force_c_encode=True, return_original_cond=True, bs=N, ) N = min(x.shape[0], N) n_row = min(x.shape[0], n_row) log["inputs"] = x log["reconstruction"] = xrec if self.model.conditioning_key is not None: if hasattr(self.cond_stage_model, "decode"): xc = self.cond_stage_model.decode(c) log["conditioning"] = xc elif self.cond_stage_key in ["caption", "txt"]: xc = log_txt_as_img( (x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2] // 25, ) log["conditioning"] = xc elif self.cond_stage_key in ["class_label", "cls"]: try: xc = log_txt_as_img( (x.shape[2], x.shape[3]), batch["human_label"], size=x.shape[2] // 25, ) log["conditioning"] = xc except KeyError: # probably no "human_label" in batch pass elif isimage(xc): log["conditioning"] = xc if ismap(xc): log["original_conditioning"] = self.to_rgb(xc) if plot_diffusion_rows: # get diffusion row diffusion_row = list() z_start = z[:n_row] for t in range(self.num_timesteps): if t % self.log_every_t == 0 or t == self.num_timesteps - 1: t = repeat(torch.tensor([t]), "1 -> b", b=n_row) t = t.to(self.device).long() noise = torch.randn_like(z_start) z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise) diffusion_row.append(self.decode_first_stage(z_noisy)) diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W diffusion_grid = rearrange(diffusion_row, "n b c h w -> b n c h w") diffusion_grid = rearrange(diffusion_grid, "b n c h w -> (b n) c h w") diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0]) log["diffusion_row"] = diffusion_grid if sample: # get denoise row with ema_scope("Sampling"): samples, z_denoise_row = self.sample_log( cond=c, batch_size=N, ddim=use_ddim, ddim_steps=ddim_steps, eta=ddim_eta, ) # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True) x_samples = self.decode_first_stage(samples) log["samples"] = x_samples if plot_denoise_rows: denoise_grid = self._get_denoise_row_from_list(z_denoise_row) log["denoise_row"] = denoise_grid if ( quantize_denoised and not isinstance(self.first_stage_model, AutoencoderKL) and not isinstance(self.first_stage_model, IdentityFirstStage) ): # also display when quantizing x0 while sampling with ema_scope("Plotting Quantized Denoised"): samples, z_denoise_row = self.sample_log( cond=c, batch_size=N, ddim=use_ddim, ddim_steps=ddim_steps, eta=ddim_eta, quantize_denoised=True, ) # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True, # quantize_denoised=True) x_samples = self.decode_first_stage(samples.to(self.device)) log["samples_x0_quantized"] = x_samples if unconditional_guidance_scale > 1.0: uc = self.get_unconditional_conditioning(N, unconditional_guidance_label) if self.model.conditioning_key == "crossattn-adm": uc = {"c_crossattn": [uc], "c_adm": c["c_adm"]} with ema_scope("Sampling with classifier-free guidance"): samples_cfg, _ = self.sample_log( cond=c, batch_size=N, ddim=use_ddim, ddim_steps=ddim_steps, eta=ddim_eta, unconditional_guidance_scale=unconditional_guidance_scale, unconditional_conditioning=uc, ) x_samples_cfg = self.decode_first_stage(samples_cfg) log[ f"samples_cfg_scale_{unconditional_guidance_scale:.2f}" ] = x_samples_cfg if inpaint: # make a simple center square b, h, w = z.shape[0], z.shape[2], z.shape[3] mask = torch.ones(N, h, w).to(self.device) # zeros will be filled in mask[:, h // 4 : 3 * h // 4, w // 4 : 3 * w // 4] = 0.0 mask = mask[:, None, ...] with ema_scope("Plotting Inpaint"): samples, _ = self.sample_log( cond=c, batch_size=N, ddim=use_ddim, eta=ddim_eta, ddim_steps=ddim_steps, x0=z[:N], mask=mask, ) x_samples = self.decode_first_stage(samples.to(self.device)) log["samples_inpainting"] = x_samples log["mask"] = mask # outpaint mask = 1.0 - mask with ema_scope("Plotting Outpaint"): samples, _ = self.sample_log( cond=c, batch_size=N, ddim=use_ddim, eta=ddim_eta, ddim_steps=ddim_steps, x0=z[:N], mask=mask, ) x_samples = self.decode_first_stage(samples.to(self.device)) log["samples_outpainting"] = x_samples if plot_progressive_rows: with ema_scope("Plotting Progressives"): img, progressives = self.progressive_denoising( c, shape=(self.channels, self.image_size, self.image_size), batch_size=N, ) prog_row = self._get_denoise_row_from_list( progressives, desc="Progressive Generation" ) log["progressive_row"] = prog_row if return_keys: if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0: return log else: return {key: log[key] for key in return_keys} return log def configure_optimizers(self): lr = self.learning_rate params = list(self.model.parameters()) if self.cond_stage_trainable: print(f"{self.__class__.__name__}: Also optimizing conditioner params!") params = params + list(self.cond_stage_model.parameters()) if self.learn_logvar: print("Diffusion model optimizing logvar") params.append(self.logvar) opt = torch.optim.AdamW(params, lr=lr) if self.use_scheduler: assert "target" in self.scheduler_config scheduler = instantiate_from_config(self.scheduler_config) print("Setting up LambdaLR scheduler...") scheduler = [ { "scheduler": LambdaLR(opt, lr_lambda=scheduler.schedule), "interval": "step", "frequency": 1, } ] return [opt], scheduler return opt @torch.no_grad() def to_rgb(self, x): x = x.float() if not hasattr(self, "colorize"): self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x) x = nn.functional.conv2d(x, weight=self.colorize) x = 2.0 * (x - x.min()) / (x.max() - x.min()) - 1.0 return x class DiffusionWrapper(torch.nn.Module): def __init__(self, diff_model_config, conditioning_key): super().__init__() self.sequential_cross_attn = diff_model_config.pop( "sequential_crossattn", False ) self.diffusion_model = instantiate_from_config(diff_model_config) self.conditioning_key = conditioning_key assert self.conditioning_key in [ None, "concat", "crossattn", "hybrid", "adm", "hybrid-adm", "crossattn-adm", ] def forward( self, x, t, c_concat: list = None, c_crossattn: list = None, c_adm=None ): if self.conditioning_key is None: out = self.diffusion_model(x, t) elif self.conditioning_key == "concat": xc = torch.cat([x] + c_concat, dim=1) out = self.diffusion_model(xc, t) elif self.conditioning_key == "crossattn": if not self.sequential_cross_attn: cc = torch.cat(c_crossattn, 1) else: cc = c_crossattn out = self.diffusion_model(x, t, context=cc) elif self.conditioning_key == "hybrid": xc = torch.cat([x] + c_concat, dim=1) cc = torch.cat(c_crossattn, 1) out = self.diffusion_model(xc, t, context=cc) elif self.conditioning_key == "hybrid-adm": assert c_adm is not None xc = torch.cat([x] + c_concat, dim=1) cc = torch.cat(c_crossattn, 1) out = self.diffusion_model(xc, t, context=cc, y=c_adm) elif self.conditioning_key == "crossattn-adm": assert c_adm is not None cc = torch.cat(c_crossattn, 1) out = self.diffusion_model(x, t, context=cc, y=c_adm) elif self.conditioning_key == "adm": cc = c_crossattn[0] out = self.diffusion_model(x, t, y=cc) else: raise NotImplementedError() return out class LatentUpscaleDiffusion(LatentDiffusion): def __init__( self, *args, low_scale_config, low_scale_key="LR", noise_level_key=None, **kwargs, ): super().__init__(*args, **kwargs) # assumes that neither the cond_stage nor the low_scale_model contain trainable params assert not self.cond_stage_trainable self.instantiate_low_stage(low_scale_config) self.low_scale_key = low_scale_key self.noise_level_key = noise_level_key def instantiate_low_stage(self, config): model = instantiate_from_config(config) self.low_scale_model = model.eval() self.low_scale_model.train = disabled_train for param in self.low_scale_model.parameters(): param.requires_grad = False @torch.no_grad() def get_input(self, batch, k, cond_key=None, bs=None, log_mode=False): if not log_mode: z, c = super().get_input(batch, k, force_c_encode=True, bs=bs) else: z, c, x, xrec, xc = super().get_input( batch, self.first_stage_key, return_first_stage_outputs=True, force_c_encode=True, return_original_cond=True, bs=bs, ) x_low = batch[self.low_scale_key][:bs] x_low = rearrange(x_low, "b h w c -> b c h w") x_low = x_low.to(memory_format=torch.contiguous_format).float() zx, noise_level = self.low_scale_model(x_low) if self.noise_level_key is not None: # get noise level from batch instead, e.g. when extracting a custom noise level for bsr raise NotImplementedError("TODO") all_conds = {"c_concat": [zx], "c_crossattn": [c], "c_adm": noise_level} if log_mode: # TODO: maybe disable if too expensive x_low_rec = self.low_scale_model.decode(zx) return z, all_conds, x, xrec, xc, x_low, x_low_rec, noise_level return z, all_conds @torch.no_grad() def log_images( self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1.0, return_keys=None, plot_denoise_rows=False, plot_progressive_rows=True, plot_diffusion_rows=True, unconditional_guidance_scale=1.0, unconditional_guidance_label=None, use_ema_scope=True, **kwargs, ): ema_scope = self.ema_scope if use_ema_scope else nullcontext use_ddim = ddim_steps is not None log = dict() z, c, x, xrec, xc, x_low, x_low_rec, noise_level = self.get_input( batch, self.first_stage_key, bs=N, log_mode=True ) N = min(x.shape[0], N) n_row = min(x.shape[0], n_row) log["inputs"] = x log["reconstruction"] = xrec log["x_lr"] = x_low log[ f"x_lr_rec_@noise_levels{'-'.join(map(lambda x: str(x), list(noise_level.cpu().numpy())))}" ] = x_low_rec if self.model.conditioning_key is not None: if hasattr(self.cond_stage_model, "decode"): xc = self.cond_stage_model.decode(c) log["conditioning"] = xc elif self.cond_stage_key in ["caption", "txt"]: xc = log_txt_as_img( (x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2] // 25, ) log["conditioning"] = xc elif self.cond_stage_key in ["class_label", "cls"]: xc = log_txt_as_img( (x.shape[2], x.shape[3]), batch["human_label"], size=x.shape[2] // 25, ) log["conditioning"] = xc elif isimage(xc): log["conditioning"] = xc if ismap(xc): log["original_conditioning"] = self.to_rgb(xc) if plot_diffusion_rows: # get diffusion row diffusion_row = list() z_start = z[:n_row] for t in range(self.num_timesteps): if t % self.log_every_t == 0 or t == self.num_timesteps - 1: t = repeat(torch.tensor([t]), "1 -> b", b=n_row) t = t.to(self.device).long() noise = torch.randn_like(z_start) z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise) diffusion_row.append(self.decode_first_stage(z_noisy)) diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W diffusion_grid = rearrange(diffusion_row, "n b c h w -> b n c h w") diffusion_grid = rearrange(diffusion_grid, "b n c h w -> (b n) c h w") diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0]) log["diffusion_row"] = diffusion_grid if sample: # get denoise row with ema_scope("Sampling"): samples, z_denoise_row = self.sample_log( cond=c, batch_size=N, ddim=use_ddim, ddim_steps=ddim_steps, eta=ddim_eta, ) # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True) x_samples = self.decode_first_stage(samples) log["samples"] = x_samples if plot_denoise_rows: denoise_grid = self._get_denoise_row_from_list(z_denoise_row) log["denoise_row"] = denoise_grid if unconditional_guidance_scale > 1.0: uc_tmp = self.get_unconditional_conditioning( N, unconditional_guidance_label ) # TODO explore better "unconditional" choices for the other keys # maybe guide away from empty text label and highest noise level and maximally degraded zx? uc = dict() for k in c: if k == "c_crossattn": assert isinstance(c[k], list) and len(c[k]) == 1 uc[k] = [uc_tmp] elif k == "c_adm": # todo: only run with text-based guidance? assert isinstance(c[k], torch.Tensor) # uc[k] = torch.ones_like(c[k]) * self.low_scale_model.max_noise_level uc[k] = c[k] elif isinstance(c[k], list): uc[k] = [c[k][i] for i in range(len(c[k]))] else: uc[k] = c[k] with ema_scope("Sampling with classifier-free guidance"): samples_cfg, _ = self.sample_log( cond=c, batch_size=N, ddim=use_ddim, ddim_steps=ddim_steps, eta=ddim_eta, unconditional_guidance_scale=unconditional_guidance_scale, unconditional_conditioning=uc, ) x_samples_cfg = self.decode_first_stage(samples_cfg) log[ f"samples_cfg_scale_{unconditional_guidance_scale:.2f}" ] = x_samples_cfg if plot_progressive_rows: with ema_scope("Plotting Progressives"): img, progressives = self.progressive_denoising( c, shape=(self.channels, self.image_size, self.image_size), batch_size=N, ) prog_row = self._get_denoise_row_from_list( progressives, desc="Progressive Generation" ) log["progressive_row"] = prog_row return log class LatentFinetuneDiffusion(LatentDiffusion): """ Basis for different finetunas, such as inpainting or depth2image To disable finetuning mode, set finetune_keys to None """ def __init__( self, concat_keys: tuple, finetune_keys=( "model.diffusion_model.input_blocks.0.0.weight", "model_ema.diffusion_modelinput_blocks00weight", ), keep_finetune_dims=4, # if model was trained without concat mode before and we would like to keep these channels c_concat_log_start=None, # to log reconstruction of c_concat codes c_concat_log_end=None, *args, **kwargs, ): ckpt_path = kwargs.pop("ckpt_path", None) ignore_keys = kwargs.pop("ignore_keys", list()) super().__init__(*args, **kwargs) self.finetune_keys = finetune_keys self.concat_keys = concat_keys self.keep_dims = keep_finetune_dims self.c_concat_log_start = c_concat_log_start self.c_concat_log_end = c_concat_log_end if exists(self.finetune_keys): assert exists(ckpt_path), "can only finetune from a given checkpoint" if exists(ckpt_path): self.init_from_ckpt(ckpt_path, ignore_keys) def init_from_ckpt(self, path, ignore_keys=list(), only_model=False): sd = torch.load(path, map_location="cpu") if "state_dict" in list(sd.keys()): sd = sd["state_dict"] keys = list(sd.keys()) for k in keys: for ik in ignore_keys: if k.startswith(ik): print("Deleting key {} from state_dict.".format(k)) del sd[k] # make it explicit, finetune by including extra input channels if exists(self.finetune_keys) and k in self.finetune_keys: new_entry = None for name, param in self.named_parameters(): if name in self.finetune_keys: print( f"modifying key '{name}' and keeping its original {self.keep_dims} (channels) dimensions only" ) new_entry = torch.zeros_like(param) # zero init assert exists(new_entry), "did not find matching parameter to modify" new_entry[:, : self.keep_dims, ...] = sd[k] sd[k] = new_entry missing, unexpected = ( self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(sd, strict=False) ) print( f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys" ) if len(missing) > 0: print(f"Missing Keys: {missing}") if len(unexpected) > 0: print(f"Unexpected Keys: {unexpected}") @torch.no_grad() def log_images( self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1.0, return_keys=None, quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True, plot_diffusion_rows=True, unconditional_guidance_scale=1.0, unconditional_guidance_label=None, use_ema_scope=True, **kwargs, ): ema_scope = self.ema_scope if use_ema_scope else nullcontext use_ddim = ddim_steps is not None log = dict() z, c, x, xrec, xc = self.get_input( batch, self.first_stage_key, bs=N, return_first_stage_outputs=True ) c_cat, c = c["c_concat"][0], c["c_crossattn"][0] N = min(x.shape[0], N) n_row = min(x.shape[0], n_row) log["inputs"] = x log["reconstruction"] = xrec if self.model.conditioning_key is not None: if hasattr(self.cond_stage_model, "decode"): xc = self.cond_stage_model.decode(c) log["conditioning"] = xc elif self.cond_stage_key in ["caption", "txt"]: xc = log_txt_as_img( (x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2] // 25, ) log["conditioning"] = xc elif self.cond_stage_key in ["class_label", "cls"]: xc = log_txt_as_img( (x.shape[2], x.shape[3]), batch["human_label"], size=x.shape[2] // 25, ) log["conditioning"] = xc elif isimage(xc): log["conditioning"] = xc if ismap(xc): log["original_conditioning"] = self.to_rgb(xc) if not (self.c_concat_log_start is None and self.c_concat_log_end is None): log["c_concat_decoded"] = self.decode_first_stage( c_cat[:, self.c_concat_log_start : self.c_concat_log_end] ) if plot_diffusion_rows: # get diffusion row diffusion_row = list() z_start = z[:n_row] for t in range(self.num_timesteps): if t % self.log_every_t == 0 or t == self.num_timesteps - 1: t = repeat(torch.tensor([t]), "1 -> b", b=n_row) t = t.to(self.device).long() noise = torch.randn_like(z_start) z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise) diffusion_row.append(self.decode_first_stage(z_noisy)) diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W diffusion_grid = rearrange(diffusion_row, "n b c h w -> b n c h w") diffusion_grid = rearrange(diffusion_grid, "b n c h w -> (b n) c h w") diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0]) log["diffusion_row"] = diffusion_grid if sample: # get denoise row with ema_scope("Sampling"): samples, z_denoise_row = self.sample_log( cond={"c_concat": [c_cat], "c_crossattn": [c]}, batch_size=N, ddim=use_ddim, ddim_steps=ddim_steps, eta=ddim_eta, ) # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True) x_samples = self.decode_first_stage(samples) log["samples"] = x_samples if plot_denoise_rows: denoise_grid = self._get_denoise_row_from_list(z_denoise_row) log["denoise_row"] = denoise_grid if unconditional_guidance_scale > 1.0: uc_cross = self.get_unconditional_conditioning( N, unconditional_guidance_label ) uc_cat = c_cat uc_full = {"c_concat": [uc_cat], "c_crossattn": [uc_cross]} with ema_scope("Sampling with classifier-free guidance"): samples_cfg, _ = self.sample_log( cond={"c_concat": [c_cat], "c_crossattn": [c]}, batch_size=N, ddim=use_ddim, ddim_steps=ddim_steps, eta=ddim_eta, unconditional_guidance_scale=unconditional_guidance_scale, unconditional_conditioning=uc_full, ) x_samples_cfg = self.decode_first_stage(samples_cfg) log[ f"samples_cfg_scale_{unconditional_guidance_scale:.2f}" ] = x_samples_cfg return log class LatentInpaintDiffusion(LatentFinetuneDiffusion): """ can either run as pure inpainting model (only concat mode) or with mixed conditionings, e.g. mask as concat and text via cross-attn. To disable finetuning mode, set finetune_keys to None """ def __init__( self, concat_keys=("mask", "masked_image"), masked_image_key="masked_image", *args, **kwargs, ): super().__init__(concat_keys, *args, **kwargs) self.masked_image_key = masked_image_key assert self.masked_image_key in concat_keys @torch.no_grad() def get_input( self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False ): # note: restricted to non-trainable encoders currently assert ( not self.cond_stage_trainable ), "trainable cond stages not yet supported for inpainting" z, c, x, xrec, xc = super().get_input( batch, self.first_stage_key, return_first_stage_outputs=True, force_c_encode=True, return_original_cond=True, bs=bs, ) assert exists(self.concat_keys) c_cat = list() for ck in self.concat_keys: cc = ( rearrange(batch[ck], "b h w c -> b c h w") .to(memory_format=torch.contiguous_format) .float() ) if bs is not None: cc = cc[:bs] cc = cc.to(self.device) bchw = z.shape if ck != self.masked_image_key: cc = torch.nn.functional.interpolate(cc, size=bchw[-2:]) else: cc = self.get_first_stage_encoding(self.encode_first_stage(cc)) c_cat.append(cc) c_cat = torch.cat(c_cat, dim=1) all_conds = {"c_concat": [c_cat], "c_crossattn": [c]} if return_first_stage_outputs: return z, all_conds, x, xrec, xc return z, all_conds @torch.no_grad() def log_images(self, *args, **kwargs): log = super(LatentInpaintDiffusion, self).log_images(*args, **kwargs) log["masked_image"] = ( rearrange(args[0]["masked_image"], "b h w c -> b c h w") .to(memory_format=torch.contiguous_format) .float() ) return log class LatentDepth2ImageDiffusion(LatentFinetuneDiffusion): """ condition on monocular depth estimation """ def __init__(self, depth_stage_config, concat_keys=("midas_in",), *args, **kwargs): super().__init__(concat_keys=concat_keys, *args, **kwargs) self.depth_model = instantiate_from_config(depth_stage_config) self.depth_stage_key = concat_keys[0] @torch.no_grad() def get_input( self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False ): # note: restricted to non-trainable encoders currently assert ( not self.cond_stage_trainable ), "trainable cond stages not yet supported for depth2img" z, c, x, xrec, xc = super().get_input( batch, self.first_stage_key, return_first_stage_outputs=True, force_c_encode=True, return_original_cond=True, bs=bs, ) assert exists(self.concat_keys) assert len(self.concat_keys) == 1 c_cat = list() for ck in self.concat_keys: cc = batch[ck] if bs is not None: cc = cc[:bs] cc = cc.to(self.device) cc = self.depth_model(cc) cc = torch.nn.functional.interpolate( cc, size=z.shape[2:], mode="bicubic", align_corners=False, ) depth_min, depth_max = ( torch.amin(cc, dim=[1, 2, 3], keepdim=True), torch.amax(cc, dim=[1, 2, 3], keepdim=True), ) cc = 2.0 * (cc - depth_min) / (depth_max - depth_min + 0.001) - 1.0 c_cat.append(cc) c_cat = torch.cat(c_cat, dim=1) all_conds = {"c_concat": [c_cat], "c_crossattn": [c]} if return_first_stage_outputs: return z, all_conds, x, xrec, xc return z, all_conds @torch.no_grad() def log_images(self, *args, **kwargs): log = super().log_images(*args, **kwargs) depth = self.depth_model(args[0][self.depth_stage_key]) depth_min, depth_max = ( torch.amin(depth, dim=[1, 2, 3], keepdim=True), torch.amax(depth, dim=[1, 2, 3], keepdim=True), ) log["depth"] = 2.0 * (depth - depth_min) / (depth_max - depth_min) - 1.0 return log class LatentUpscaleFinetuneDiffusion(LatentFinetuneDiffusion): """ condition on low-res image (and optionally on some spatial noise augmentation) """ def __init__( self, concat_keys=("lr",), reshuffle_patch_size=None, low_scale_config=None, low_scale_key=None, *args, **kwargs, ): super().__init__(concat_keys=concat_keys, *args, **kwargs) self.reshuffle_patch_size = reshuffle_patch_size self.low_scale_model = None if low_scale_config is not None: print("Initializing a low-scale model") assert exists(low_scale_key) self.instantiate_low_stage(low_scale_config) self.low_scale_key = low_scale_key def instantiate_low_stage(self, config): model = instantiate_from_config(config) self.low_scale_model = model.eval() self.low_scale_model.train = disabled_train for param in self.low_scale_model.parameters(): param.requires_grad = False @torch.no_grad() def get_input( self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False ): # note: restricted to non-trainable encoders currently assert ( not self.cond_stage_trainable ), "trainable cond stages not yet supported for upscaling-ft" z, c, x, xrec, xc = super().get_input( batch, self.first_stage_key, return_first_stage_outputs=True, force_c_encode=True, return_original_cond=True, bs=bs, ) assert exists(self.concat_keys) assert len(self.concat_keys) == 1 # optionally make spatial noise_level here c_cat = list() noise_level = None for ck in self.concat_keys: cc = batch[ck] cc = rearrange(cc, "b h w c -> b c h w") if exists(self.reshuffle_patch_size): assert isinstance(self.reshuffle_patch_size, int) cc = rearrange( cc, "b c (p1 h) (p2 w) -> b (p1 p2 c) h w", p1=self.reshuffle_patch_size, p2=self.reshuffle_patch_size, ) if bs is not None: cc = cc[:bs] cc = cc.to(self.device) if exists(self.low_scale_model) and ck == self.low_scale_key: cc, noise_level = self.low_scale_model(cc) c_cat.append(cc) c_cat = torch.cat(c_cat, dim=1) if exists(noise_level): all_conds = {"c_concat": [c_cat], "c_crossattn": [c], "c_adm": noise_level} else: all_conds = {"c_concat": [c_cat], "c_crossattn": [c]} if return_first_stage_outputs: return z, all_conds, x, xrec, xc return z, all_conds @torch.no_grad() def log_images(self, *args, **kwargs): log = super().log_images(*args, **kwargs) log["lr"] = rearrange(args[0]["lr"], "b h w c -> b c h w") return log