import importlib
from inspect import isfunction

import numpy as np
import torch
from PIL import Image, ImageDraw, ImageFont
from torch import optim


def log_txt_as_img(wh, xc, size=10):
    # wh: a tuple of (width, height)
    # xc: a list of captions to plot
    b = len(xc)
    txts = list()
    for bi in range(b):
        txt = Image.new("RGB", wh, color="white")
        draw = ImageDraw.Draw(txt)
        font = ImageFont.truetype("font/Arial_Unicode.ttf", size=size)
        nc = int(32 * (wh[0] / 256))  # characters per line, scaled with image width
        lines = "\n".join(
            xc[bi][start : start + nc] for start in range(0, len(xc[bi]), nc)
        )
        try:
            draw.text((0, 0), lines, fill="black", font=font)
        except UnicodeEncodeError:
            print("Can't encode string for logging. Skipping.")
        txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0  # HWC -> CHW, scale to [-1, 1]
        txts.append(txt)
    txts = np.stack(txts)
    txts = torch.tensor(txts)
    return txts
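
# Illustrative usage sketch (an assumption, not part of the original module): render two
# captions into an image tensor for a logger. The caption strings and the 256x256 size
# are arbitrary, and the call assumes the font file referenced above exists on disk.
#
#   captions = ["a photograph of an astronaut riding a horse", "a red cube on a table"]
#   grid = log_txt_as_img((256, 256), captions, size=10)
#   # grid has shape (2, 3, 256, 256) with values in [-1, 1]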


def ismap(x):
    # A "map" here is a 4D tensor with more than 3 channels (e.g. a segmentation map).
    if not isinstance(x, torch.Tensor):
        return False
    return (len(x.shape) == 4) and (x.shape[1] > 3)


def isimage(x):
    # An "image" here is a 4D tensor with 1 or 3 channels (grayscale or RGB).
    if not isinstance(x, torch.Tensor):
        return False
    return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)


def exists(x):
    return x is not None


def default(val, d):
    # Return val unless it is None; fall back to d, calling it if it is a function.
    if exists(val):
        return val
    return d() if isfunction(d) else d
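
# Illustrative sketch (an assumption, not from the original source) of how the small
# helpers above behave:
#
#   exists(None)                         # False
#   default(None, 7)                     # 7 (plain fallback value)
#   default(None, lambda: [])            # [] (function fallbacks are invoked lazily)
#   ismap(torch.zeros(2, 8, 16, 16))     # True: 4D with more than 3 channels
#   isimage(torch.zeros(2, 3, 16, 16))   # True: 4D with 1 or 3 channels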


def mean_flat(tensor):
    """
    https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
    Take the mean over all non-batch dimensions.
    """
    return tensor.mean(dim=list(range(1, len(tensor.shape))))
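
# Illustrative sketch (an assumption, not from the original source): for a per-element
# loss tensor of shape (batch, C, H, W), mean_flat averages over C, H and W only,
# leaving one value per sample.
#
#   per_pixel_loss = torch.rand(4, 3, 8, 8)
#   per_sample_loss = mean_flat(per_pixel_loss)  # shape (4,)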


def count_params(model, verbose=False):
    total_params = sum(p.numel() for p in model.parameters())
    if verbose:
        print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.")
    return total_params
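
# Illustrative sketch (an assumption, not from the original source):
#
#   model = torch.nn.Linear(10, 5)       # 10 * 5 weights + 5 biases
#   count_params(model)                  # 55
#   count_params(model, verbose=True)    # also prints "Linear has 0.00 M params."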


def instantiate_from_config(config, **kwargs):
    if "target" not in config:
        if config == "__is_first_stage__":
            return None
        elif config == "__is_unconditional__":
            return None
        raise KeyError("Expected key `target` to instantiate.")
    return get_obj_from_str(config["target"])(**config.get("params", dict()), **kwargs)


def get_obj_from_str(string, reload=False):
    # Resolve a dotted path such as "package.module.ClassName" to the named attribute.
    module, cls = string.rsplit(".", 1)
    if reload:
        module_imp = importlib.import_module(module)
        importlib.reload(module_imp)
    return getattr(importlib.import_module(module, package=None), cls)
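
# Illustrative sketch (an assumption, not from the original source): any mapping with a
# dotted "target" path and an optional "params" dict works, e.g. a config section loaded
# from YAML.
#
#   config = {"target": "torch.nn.Linear", "params": {"in_features": 4, "out_features": 2}}
#   layer = instantiate_from_config(config)   # equivalent to torch.nn.Linear(4, 2)
#   get_obj_from_str("torch.nn.Linear")       # returns the class object itself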


class AdamWwithEMAandWings(optim.Optimizer):
    # credit to https://gist.github.com/crowsonkb/65f7265353f403714fce3b2595e0b298
    def __init__(
        self,
        params,
        lr=1.0e-3,
        betas=(0.9, 0.999),
        eps=1.0e-8,  # TODO: check hyperparameters before using
        weight_decay=1.0e-2,
        amsgrad=False,
        ema_decay=0.9999,  # ema decay to match previous code
        ema_power=1.0,
        param_names=(),
    ):
        """AdamW that saves EMA versions of the parameters."""
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        if not 0.0 <= weight_decay:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
        if not 0.0 <= ema_decay <= 1.0:
            raise ValueError("Invalid ema_decay value: {}".format(ema_decay))
        defaults = dict(
            lr=lr,
            betas=betas,
            eps=eps,
            weight_decay=weight_decay,
            amsgrad=amsgrad,
            ema_decay=ema_decay,
            ema_power=ema_power,
            param_names=param_names,
        )
        super().__init__(params, defaults)

    def __setstate__(self, state):
        super().__setstate__(state)
        for group in self.param_groups:
            group.setdefault("amsgrad", False)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.
        Args:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            params_with_grad = []
            grads = []
            exp_avgs = []
            exp_avg_sqs = []
            ema_params_with_grad = []
            state_sums = []
            max_exp_avg_sqs = []
            state_steps = []
            amsgrad = group["amsgrad"]
            beta1, beta2 = group["betas"]
            ema_decay = group["ema_decay"]
            ema_power = group["ema_power"]

            for p in group["params"]:
                if p.grad is None:
                    continue
                params_with_grad.append(p)
                if p.grad.is_sparse:
                    raise RuntimeError("AdamW does not support sparse gradients")
                grads.append(p.grad)

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state["step"] = 0
                    # Exponential moving average of gradient values
                    state["exp_avg"] = torch.zeros_like(
                        p, memory_format=torch.preserve_format
                    )
                    # Exponential moving average of squared gradient values
                    state["exp_avg_sq"] = torch.zeros_like(
                        p, memory_format=torch.preserve_format
                    )
                    if amsgrad:
                        # Maintains max of all exp. moving avg. of sq. grad. values
                        state["max_exp_avg_sq"] = torch.zeros_like(
                            p, memory_format=torch.preserve_format
                        )
                    # Exponential moving average of parameter values
                    state["param_exp_avg"] = p.detach().float().clone()

                exp_avgs.append(state["exp_avg"])
                exp_avg_sqs.append(state["exp_avg_sq"])
                ema_params_with_grad.append(state["param_exp_avg"])

                if amsgrad:
                    max_exp_avg_sqs.append(state["max_exp_avg_sq"])

                # update the steps for each param group update
                state["step"] += 1
                # record the step after step update
                state_steps.append(state["step"])

            optim._functional.adamw(
                params_with_grad,
                grads,
                exp_avgs,
                exp_avg_sqs,
                max_exp_avg_sqs,
                state_steps,
                amsgrad=amsgrad,
                beta1=beta1,
                beta2=beta2,
                lr=group["lr"],
                weight_decay=group["weight_decay"],
                eps=group["eps"],
                maximize=False,
            )

            cur_ema_decay = min(ema_decay, 1 - state["step"] ** -ema_power)
            for param, ema_param in zip(params_with_grad, ema_params_with_grad):
                ema_param.mul_(cur_ema_decay).add_(
                    param.float(), alpha=1 - cur_ema_decay
                )

        return loss
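
# Illustrative usage sketch (an assumption, not from the original source). Note that
# step() calls the private torch.optim._functional.adamw API, whose signature has changed
# across torch releases, so the optimizer is tied to the torch version this repo targets.
#
#   model = torch.nn.Linear(4, 2)
#   opt = AdamWwithEMAandWings(model.parameters(), lr=1e-4, ema_decay=0.9999)
#   loss = model(torch.randn(8, 4)).pow(2).mean()
#   loss.backward()
#   opt.step()
#   opt.zero_grad()
#   # The EMA copy of each parameter lives in the optimizer state:
#   ema_weights = [opt.state[p]["param_exp_avg"] for p in model.parameters()]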