util.py

import importlib
from inspect import isfunction

import numpy as np
import torch
from PIL import Image, ImageDraw, ImageFont
from torch import optim


def log_txt_as_img(wh, xc, size=10):
    # wh: a tuple of (width, height)
    # xc: a list of captions to plot
    b = len(xc)
    txts = list()
    for bi in range(b):
        txt = Image.new("RGB", wh, color="white")
        draw = ImageDraw.Draw(txt)
        font = ImageFont.truetype("font/Arial_Unicode.ttf", size=size)
        nc = int(32 * (wh[0] / 256))
        lines = "\n".join(
            xc[bi][start : start + nc] for start in range(0, len(xc[bi]), nc)
        )
        try:
            draw.text((0, 0), lines, fill="black", font=font)
        except UnicodeEncodeError:
            print("Can't encode string for logging. Skipping.")
        txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
        txts.append(txt)
    txts = np.stack(txts)
    txts = torch.tensor(txts)
    return txts
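
# Usage sketch (not part of the original module): render a batch of captions
# into an image tensor for an image logger. Assumes the font file referenced
# above exists relative to the working directory; the caption text is made up.
#
#     captions = ["a photograph of an astronaut riding a horse"]
#     txt_imgs = log_txt_as_img((256, 256), captions, size=10)
#     # txt_imgs: torch.Tensor of shape (1, 3, 256, 256), values in [-1, 1]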


def ismap(x):
    if not isinstance(x, torch.Tensor):
        return False
    return (len(x.shape) == 4) and (x.shape[1] > 3)


def isimage(x):
    if not isinstance(x, torch.Tensor):
        return False
    return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)


def exists(x):
    return x is not None


def default(val, d):
    if exists(val):
        return val
    return d() if isfunction(d) else d
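
# Usage sketch (illustrative, not from the original file): `default` returns
# `val` when it is not None, otherwise the fallback `d`; if `d` is a function
# it is only called on the fallback path, so expensive defaults stay lazy.
#
#     steps = default(None, 1000)                      # -> 1000
#     noise = default(None, lambda: torch.randn(4))    # lambda evaluated here
#     cond = default("prompt", "fallback")             # -> "prompt"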


def mean_flat(tensor):
    """
    https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
    Take the mean over all non-batch dimensions.
    """
    return tensor.mean(dim=list(range(1, len(tensor.shape))))
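
# Usage sketch (illustrative; tensor names are assumed): reduce a per-element
# loss to one value per batch item, keeping only the batch dimension.
#
#     loss = (noise - pred) ** 2    # shape (B, C, H, W)
#     loss = mean_flat(loss)        # shape (B,)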


def count_params(model, verbose=False):
    total_params = sum(p.numel() for p in model.parameters())
    if verbose:
        print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.")
    return total_params


def instantiate_from_config(config, **kwargs):
    if "target" not in config:
        if config == "__is_first_stage__":
            return None
        elif config == "__is_unconditional__":
            return None
        raise KeyError("Expected key `target` to instantiate.")
    return get_obj_from_str(config["target"])(**config.get("params", dict()), **kwargs)


def get_obj_from_str(string, reload=False):
    module, cls = string.rsplit(".", 1)
    if reload:
        module_imp = importlib.import_module(module)
        importlib.reload(module_imp)
    return getattr(importlib.import_module(module, package=None), cls)
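
# Usage sketch (illustrative; the target class below is just an example):
# `instantiate_from_config` expects a dict/OmegaConf-style node with a dotted
# `target` import path and optional constructor `params`; extra **kwargs are
# forwarded to the constructor as well.
#
#     config = {"target": "torch.nn.Identity", "params": {}}
#     layer = instantiate_from_config(config)    # -> torch.nn.Identity()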


class AdamWwithEMAandWings(optim.Optimizer):
    # credit to https://gist.github.com/crowsonkb/65f7265353f403714fce3b2595e0b298
    def __init__(
        self,
        params,
        lr=1.0e-3,
        betas=(0.9, 0.999),
        eps=1.0e-8,  # TODO: check hyperparameters before using
        weight_decay=1.0e-2,
        amsgrad=False,
        ema_decay=0.9999,  # ema decay to match previous code
        ema_power=1.0,
        param_names=(),
    ):
        """AdamW that saves EMA versions of the parameters."""
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        if not 0.0 <= weight_decay:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
        if not 0.0 <= ema_decay <= 1.0:
            raise ValueError("Invalid ema_decay value: {}".format(ema_decay))
        defaults = dict(
            lr=lr,
            betas=betas,
            eps=eps,
            weight_decay=weight_decay,
            amsgrad=amsgrad,
            ema_decay=ema_decay,
            ema_power=ema_power,
            param_names=param_names,
        )
        super().__init__(params, defaults)

    def __setstate__(self, state):
        super().__setstate__(state)
        for group in self.param_groups:
            group.setdefault("amsgrad", False)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.
        Args:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            params_with_grad = []
            grads = []
            exp_avgs = []
            exp_avg_sqs = []
            ema_params_with_grad = []
            state_sums = []
            max_exp_avg_sqs = []
            state_steps = []
            amsgrad = group["amsgrad"]
            beta1, beta2 = group["betas"]
            ema_decay = group["ema_decay"]
            ema_power = group["ema_power"]

            for p in group["params"]:
                if p.grad is None:
                    continue
                params_with_grad.append(p)
                if p.grad.is_sparse:
                    raise RuntimeError("AdamW does not support sparse gradients")
                grads.append(p.grad)

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state["step"] = 0
                    # Exponential moving average of gradient values
                    state["exp_avg"] = torch.zeros_like(
                        p, memory_format=torch.preserve_format
                    )
                    # Exponential moving average of squared gradient values
                    state["exp_avg_sq"] = torch.zeros_like(
                        p, memory_format=torch.preserve_format
                    )
                    if amsgrad:
                        # Maintains max of all exp. moving avg. of sq. grad. values
                        state["max_exp_avg_sq"] = torch.zeros_like(
                            p, memory_format=torch.preserve_format
                        )
                    # Exponential moving average of parameter values
                    state["param_exp_avg"] = p.detach().float().clone()

                exp_avgs.append(state["exp_avg"])
                exp_avg_sqs.append(state["exp_avg_sq"])
                ema_params_with_grad.append(state["param_exp_avg"])

                if amsgrad:
                    max_exp_avg_sqs.append(state["max_exp_avg_sq"])

                # update the steps for each param group update
                state["step"] += 1
                # record the step after step update
                state_steps.append(state["step"])

            optim._functional.adamw(
                params_with_grad,
                grads,
                exp_avgs,
                exp_avg_sqs,
                max_exp_avg_sqs,
                state_steps,
                amsgrad=amsgrad,
                beta1=beta1,
                beta2=beta2,
                lr=group["lr"],
                weight_decay=group["weight_decay"],
                eps=group["eps"],
                maximize=False,
            )

            cur_ema_decay = min(ema_decay, 1 - state["step"] ** -ema_power)
            for param, ema_param in zip(params_with_grad, ema_params_with_grad):
                ema_param.mul_(cur_ema_decay).add_(
                    param.float(), alpha=1 - cur_ema_decay
                )

        return loss
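
# Usage sketch (illustrative; the model and loss below are made up): after each
# AdamW step the optimizer keeps a float32 EMA copy of every updated parameter
# in state["param_exp_avg"], with decay min(ema_decay, 1 - step ** -ema_power).
# Note that it calls the private `optim._functional.adamw`, whose signature has
# changed across PyTorch releases, so this assumes a compatible version.
#
#     model = torch.nn.Linear(8, 8)
#     opt = AdamWwithEMAandWings(model.parameters(), lr=1.0e-4, ema_decay=0.9999)
#     loss = model(torch.randn(2, 8)).pow(2).mean()
#     loss.backward()
#     opt.step()
#     ema_weights = [opt.state[p]["param_exp_avg"] for p in model.parameters()]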