  1. """
  2. Part of the implementation is borrowed and modified from ControlNet, publicly available at https://github.com/lllyasviel/ControlNet/blob/main/ldm/models/diffusion/ddpm.py
  3. """
  4. import itertools
  5. from contextlib import contextmanager, nullcontext
  6. from functools import partial
  7. import cv2
  8. import numpy as np
  9. import torch
  10. import torch.nn as nn
  11. from einops import rearrange, repeat
  12. from omegaconf import ListConfig
  13. from torch.optim.lr_scheduler import LambdaLR
  14. from torchvision.utils import make_grid
  15. from tqdm import tqdm
  16. from sorawm.iopaint.model.anytext.ldm.models.autoencoder import (
  17. AutoencoderKL,
  18. IdentityFirstStage,
  19. )
  20. from sorawm.iopaint.model.anytext.ldm.models.diffusion.ddim import DDIMSampler
  21. from sorawm.iopaint.model.anytext.ldm.modules.diffusionmodules.util import (
  22. extract_into_tensor,
  23. make_beta_schedule,
  24. noise_like,
  25. )
  26. from sorawm.iopaint.model.anytext.ldm.modules.distributions.distributions import (
  27. DiagonalGaussianDistribution,
  28. normal_kl,
  29. )
  30. from sorawm.iopaint.model.anytext.ldm.modules.ema import LitEma
  31. from sorawm.iopaint.model.anytext.ldm.util import (
  32. count_params,
  33. default,
  34. exists,
  35. instantiate_from_config,
  36. isimage,
  37. ismap,
  38. log_txt_as_img,
  39. mean_flat,
  40. )
  41. __conditioning_keys__ = {"concat": "c_concat", "crossattn": "c_crossattn", "adm": "y"}
  42. PRINT_DEBUG = False
  43. def print_grad(grad):
  44. # print('Gradient:', grad)
  45. # print(grad.shape)
  46. a = grad.max()
  47. b = grad.min()
  48. # print(f'mean={grad.mean():.4f}, max={a:.4f}, min={b:.4f}')
  49. s = 255.0 / (a - b)
  50. c = 255 * (-b / (a - b))
  51. grad = grad * s + c
  52. # print(f'mean={grad.mean():.4f}, max={grad.max():.4f}, min={grad.min():.4f}')
  53. img = grad[0].permute(1, 2, 0).detach().cpu().numpy()
  54. if img.shape[0] == 512:
  55. cv2.imwrite("grad-img.jpg", img)
  56. elif img.shape[0] == 64:
  57. cv2.imwrite("grad-latent.jpg", img)
  58. def disabled_train(self, mode=True):
  59. """Overwrite model.train with this function to make sure train/eval mode
  60. does not change anymore."""
  61. return self
  62. def uniform_on_device(r1, r2, shape, device):
  63. return (r1 - r2) * torch.rand(*shape, device=device) + r2
  64. class DDPM(torch.nn.Module):
  65. # classic DDPM with Gaussian diffusion, in image space
  66. def __init__(
  67. self,
  68. unet_config,
  69. timesteps=1000,
  70. beta_schedule="linear",
  71. loss_type="l2",
  72. ckpt_path=None,
  73. ignore_keys=[],
  74. load_only_unet=False,
  75. monitor="val/loss",
  76. use_ema=True,
  77. first_stage_key="image",
  78. image_size=256,
  79. channels=3,
  80. log_every_t=100,
  81. clip_denoised=True,
  82. linear_start=1e-4,
  83. linear_end=2e-2,
  84. cosine_s=8e-3,
  85. given_betas=None,
  86. original_elbo_weight=0.0,
  87. v_posterior=0.0, # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta
  88. l_simple_weight=1.0,
  89. conditioning_key=None,
  90. parameterization="eps", # all assuming fixed variance schedules
  91. scheduler_config=None,
  92. use_positional_encodings=False,
  93. learn_logvar=False,
  94. logvar_init=0.0,
  95. make_it_fit=False,
  96. ucg_training=None,
  97. reset_ema=False,
  98. reset_num_ema_updates=False,
  99. ):
  100. super().__init__()
  101. assert parameterization in [
  102. "eps",
  103. "x0",
  104. "v",
  105. ], 'currently only supporting "eps" and "x0" and "v"'
  106. self.parameterization = parameterization
  107. print(
  108. f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode"
  109. )
  110. self.cond_stage_model = None
  111. self.clip_denoised = clip_denoised
  112. self.log_every_t = log_every_t
  113. self.first_stage_key = first_stage_key
  114. self.image_size = image_size # try conv?
  115. self.channels = channels
  116. self.use_positional_encodings = use_positional_encodings
  117. self.model = DiffusionWrapper(unet_config, conditioning_key)
  118. count_params(self.model, verbose=True)
  119. self.use_ema = use_ema
  120. if self.use_ema:
  121. self.model_ema = LitEma(self.model)
  122. print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
  123. self.use_scheduler = scheduler_config is not None
  124. if self.use_scheduler:
  125. self.scheduler_config = scheduler_config
  126. self.v_posterior = v_posterior
  127. self.original_elbo_weight = original_elbo_weight
  128. self.l_simple_weight = l_simple_weight
  129. if monitor is not None:
  130. self.monitor = monitor
  131. self.make_it_fit = make_it_fit
  132. if reset_ema:
  133. assert exists(ckpt_path)
  134. if ckpt_path is not None:
  135. self.init_from_ckpt(
  136. ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet
  137. )
  138. if reset_ema:
  139. assert self.use_ema
  140. print(
  141. f"Resetting ema to pure model weights. This is useful when restoring from an ema-only checkpoint."
  142. )
  143. self.model_ema = LitEma(self.model)
  144. if reset_num_ema_updates:
  145. print(
  146. " +++++++++++ WARNING: RESETTING NUM_EMA UPDATES TO ZERO +++++++++++ "
  147. )
  148. assert self.use_ema
  149. self.model_ema.reset_num_updates()
  150. self.register_schedule(
  151. given_betas=given_betas,
  152. beta_schedule=beta_schedule,
  153. timesteps=timesteps,
  154. linear_start=linear_start,
  155. linear_end=linear_end,
  156. cosine_s=cosine_s,
  157. )
  158. self.loss_type = loss_type
  159. self.learn_logvar = learn_logvar
  160. logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,))
  161. if self.learn_logvar:
162. self.logvar = nn.Parameter(logvar, requires_grad=True)
  163. else:
  164. self.register_buffer("logvar", logvar)
  165. self.ucg_training = ucg_training or dict()
  166. if self.ucg_training:
  167. self.ucg_prng = np.random.RandomState()
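# register_schedule (below) precomputes the whole diffusion schedule as registered
# buffers: betas, the cumulative products alphas_cumprod, their square roots, and the
# coefficients of the posterior q(x_{t-1} | x_t, x_0). Everything is derived from the
# beta schedule alone, so changing `timesteps` or `beta_schedule` only requires
# calling this method again.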
  168. def register_schedule(
  169. self,
  170. given_betas=None,
  171. beta_schedule="linear",
  172. timesteps=1000,
  173. linear_start=1e-4,
  174. linear_end=2e-2,
  175. cosine_s=8e-3,
  176. ):
  177. if exists(given_betas):
  178. betas = given_betas
  179. else:
  180. betas = make_beta_schedule(
  181. beta_schedule,
  182. timesteps,
  183. linear_start=linear_start,
  184. linear_end=linear_end,
  185. cosine_s=cosine_s,
  186. )
  187. alphas = 1.0 - betas
  188. alphas_cumprod = np.cumprod(alphas, axis=0)
  189. # np.save('1.npy', alphas_cumprod)
  190. alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])
  191. (timesteps,) = betas.shape
  192. self.num_timesteps = int(timesteps)
  193. self.linear_start = linear_start
  194. self.linear_end = linear_end
  195. assert (
  196. alphas_cumprod.shape[0] == self.num_timesteps
  197. ), "alphas have to be defined for each timestep"
  198. to_torch = partial(torch.tensor, dtype=torch.float32)
  199. self.register_buffer("betas", to_torch(betas))
  200. self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
  201. self.register_buffer("alphas_cumprod_prev", to_torch(alphas_cumprod_prev))
  202. # calculations for diffusion q(x_t | x_{t-1}) and others
  203. self.register_buffer("sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod)))
  204. self.register_buffer(
  205. "sqrt_one_minus_alphas_cumprod", to_torch(np.sqrt(1.0 - alphas_cumprod))
  206. )
  207. self.register_buffer(
  208. "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod))
  209. )
  210. self.register_buffer(
  211. "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod))
  212. )
  213. self.register_buffer(
  214. "sqrt_recipm1_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod - 1))
  215. )
  216. # calculations for posterior q(x_{t-1} | x_t, x_0)
  217. posterior_variance = (1 - self.v_posterior) * betas * (
  218. 1.0 - alphas_cumprod_prev
  219. ) / (1.0 - alphas_cumprod) + self.v_posterior * betas
  220. # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
  221. self.register_buffer("posterior_variance", to_torch(posterior_variance))
  222. # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
  223. self.register_buffer(
  224. "posterior_log_variance_clipped",
  225. to_torch(np.log(np.maximum(posterior_variance, 1e-20))),
  226. )
  227. self.register_buffer(
  228. "posterior_mean_coef1",
  229. to_torch(betas * np.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)),
  230. )
  231. self.register_buffer(
  232. "posterior_mean_coef2",
  233. to_torch(
  234. (1.0 - alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - alphas_cumprod)
  235. ),
  236. )
  237. if self.parameterization == "eps":
  238. lvlb_weights = self.betas**2 / (
  239. 2
  240. * self.posterior_variance
  241. * to_torch(alphas)
  242. * (1 - self.alphas_cumprod)
  243. )
  244. elif self.parameterization == "x0":
  245. lvlb_weights = (
  246. 0.5
  247. * np.sqrt(torch.Tensor(alphas_cumprod))
  248. / (2.0 * 1 - torch.Tensor(alphas_cumprod))
  249. )
  250. elif self.parameterization == "v":
  251. lvlb_weights = torch.ones_like(
  252. self.betas**2
  253. / (
  254. 2
  255. * self.posterior_variance
  256. * to_torch(alphas)
  257. * (1 - self.alphas_cumprod)
  258. )
  259. )
  260. else:
  261. raise NotImplementedError("mu not supported")
  262. lvlb_weights[0] = lvlb_weights[1]
  263. self.register_buffer("lvlb_weights", lvlb_weights, persistent=False)
264. assert not torch.isnan(self.lvlb_weights).any()
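# ema_scope temporarily swaps the online weights for their EMA shadow copies and
# restores them on exit; sampling and validation are typically run inside this
# context so that logged results reflect the smoothed weights.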
  265. @contextmanager
  266. def ema_scope(self, context=None):
  267. if self.use_ema:
  268. self.model_ema.store(self.model.parameters())
  269. self.model_ema.copy_to(self.model)
  270. if context is not None:
  271. print(f"{context}: Switched to EMA weights")
  272. try:
  273. yield None
  274. finally:
  275. if self.use_ema:
  276. self.model_ema.restore(self.model.parameters())
  277. if context is not None:
  278. print(f"{context}: Restored training weights")
  279. @torch.no_grad()
  280. def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
  281. sd = torch.load(path, map_location="cpu")
  282. if "state_dict" in list(sd.keys()):
  283. sd = sd["state_dict"]
  284. keys = list(sd.keys())
  285. for k in keys:
  286. for ik in ignore_keys:
  287. if k.startswith(ik):
  288. print("Deleting key {} from state_dict.".format(k))
  289. del sd[k]
  290. if self.make_it_fit:
  291. n_params = len(
  292. [
  293. name
  294. for name, _ in itertools.chain(
  295. self.named_parameters(), self.named_buffers()
  296. )
  297. ]
  298. )
  299. for name, param in tqdm(
  300. itertools.chain(self.named_parameters(), self.named_buffers()),
  301. desc="Fitting old weights to new weights",
  302. total=n_params,
  303. ):
304. if name not in sd:
  305. continue
  306. old_shape = sd[name].shape
  307. new_shape = param.shape
  308. assert len(old_shape) == len(new_shape)
  309. if len(new_shape) > 2:
  310. # we only modify first two axes
  311. assert new_shape[2:] == old_shape[2:]
  312. # assumes first axis corresponds to output dim
  313. if not new_shape == old_shape:
  314. new_param = param.clone()
  315. old_param = sd[name]
  316. if len(new_shape) == 1:
  317. for i in range(new_param.shape[0]):
  318. new_param[i] = old_param[i % old_shape[0]]
  319. elif len(new_shape) >= 2:
  320. for i in range(new_param.shape[0]):
  321. for j in range(new_param.shape[1]):
  322. new_param[i, j] = old_param[
  323. i % old_shape[0], j % old_shape[1]
  324. ]
  325. n_used_old = torch.ones(old_shape[1])
  326. for j in range(new_param.shape[1]):
  327. n_used_old[j % old_shape[1]] += 1
  328. n_used_new = torch.zeros(new_shape[1])
  329. for j in range(new_param.shape[1]):
  330. n_used_new[j] = n_used_old[j % old_shape[1]]
  331. n_used_new = n_used_new[None, :]
  332. while len(n_used_new.shape) < len(new_shape):
  333. n_used_new = n_used_new.unsqueeze(-1)
  334. new_param /= n_used_new
  335. sd[name] = new_param
  336. missing, unexpected = (
  337. self.load_state_dict(sd, strict=False)
  338. if not only_model
  339. else self.model.load_state_dict(sd, strict=False)
  340. )
  341. print(
  342. f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys"
  343. )
  344. if len(missing) > 0:
  345. print(f"Missing Keys:\n {missing}")
  346. if len(unexpected) > 0:
  347. print(f"\nUnexpected Keys:\n {unexpected}")
  348. def q_mean_variance(self, x_start, t):
  349. """
  350. Get the distribution q(x_t | x_0).
  351. :param x_start: the [N x C x ...] tensor of noiseless inputs.
  352. :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
  353. :return: A tuple (mean, variance, log_variance), all of x_start's shape.
  354. """
  355. mean = extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
  356. variance = extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
  357. log_variance = extract_into_tensor(
  358. self.log_one_minus_alphas_cumprod, t, x_start.shape
  359. )
  360. return mean, variance, log_variance
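# predict_start_from_noise inverts the closed-form forward process
#   x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps
# to recover x_0 from a predicted eps:
#   x_0 = sqrt(1 / alpha_bar_t) * x_t - sqrt(1 / alpha_bar_t - 1) * eps
# which is exactly what the two sqrt_recip* buffers registered above encode.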
  361. def predict_start_from_noise(self, x_t, t, noise):
  362. return (
  363. extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
  364. - extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
  365. * noise
  366. )
  367. def predict_start_from_z_and_v(self, x_t, t, v):
  368. # self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
  369. # self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
  370. return (
  371. extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t
  372. - extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v
  373. )
  374. def predict_eps_from_z_and_v(self, x_t, t, v):
  375. return (
  376. extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * v
  377. + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape)
  378. * x_t
  379. )
  380. def q_posterior(self, x_start, x_t, t):
  381. posterior_mean = (
  382. extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start
  383. + extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
  384. )
  385. posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape)
  386. posterior_log_variance_clipped = extract_into_tensor(
  387. self.posterior_log_variance_clipped, t, x_t.shape
  388. )
  389. return posterior_mean, posterior_variance, posterior_log_variance_clipped
  390. def p_mean_variance(self, x, t, clip_denoised: bool):
  391. model_out = self.model(x, t)
  392. if self.parameterization == "eps":
  393. x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
  394. elif self.parameterization == "x0":
  395. x_recon = model_out
  396. if clip_denoised:
  397. x_recon.clamp_(-1.0, 1.0)
  398. model_mean, posterior_variance, posterior_log_variance = self.q_posterior(
  399. x_start=x_recon, x_t=x, t=t
  400. )
  401. return model_mean, posterior_variance, posterior_log_variance
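# p_sample draws one ancestral step x_{t-1} ~ p(x_{t-1} | x_t):
#   x_{t-1} = model_mean + exp(0.5 * log_var) * z,  z ~ N(0, I),
# with the noise term masked out at t == 0 so the final step is deterministic.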
  402. @torch.no_grad()
  403. def p_sample(self, x, t, clip_denoised=True, repeat_noise=False):
  404. b, *_, device = *x.shape, x.device
  405. model_mean, _, model_log_variance = self.p_mean_variance(
  406. x=x, t=t, clip_denoised=clip_denoised
  407. )
  408. noise = noise_like(x.shape, device, repeat_noise)
  409. # no noise when t == 0
  410. nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
  411. return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
  412. @torch.no_grad()
  413. def p_sample_loop(self, shape, return_intermediates=False):
  414. device = self.betas.device
  415. b = shape[0]
  416. img = torch.randn(shape, device=device)
  417. intermediates = [img]
  418. for i in tqdm(
  419. reversed(range(0, self.num_timesteps)),
  420. desc="Sampling t",
  421. total=self.num_timesteps,
  422. ):
  423. img = self.p_sample(
  424. img,
  425. torch.full((b,), i, device=device, dtype=torch.long),
  426. clip_denoised=self.clip_denoised,
  427. )
  428. if i % self.log_every_t == 0 or i == self.num_timesteps - 1:
  429. intermediates.append(img)
  430. if return_intermediates:
  431. return img, intermediates
  432. return img
  433. @torch.no_grad()
  434. def sample(self, batch_size=16, return_intermediates=False):
  435. image_size = self.image_size
  436. channels = self.channels
  437. return self.p_sample_loop(
  438. (batch_size, channels, image_size, image_size),
  439. return_intermediates=return_intermediates,
  440. )
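# q_sample draws from q(x_t | x_0) in closed form:
#   x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise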
  441. def q_sample(self, x_start, t, noise=None):
  442. noise = default(noise, lambda: torch.randn_like(x_start))
  443. return (
  444. extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
  445. + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
  446. * noise
  447. )
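# get_v returns the regression target used by v-prediction models:
#   v = sqrt(alpha_bar_t) * noise - sqrt(1 - alpha_bar_t) * x_0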
  448. def get_v(self, x, noise, t):
  449. return (
  450. extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * noise
  451. - extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x
  452. )
  453. def get_loss(self, pred, target, mean=True):
  454. if self.loss_type == "l1":
  455. loss = (target - pred).abs()
  456. if mean:
  457. loss = loss.mean()
  458. elif self.loss_type == "l2":
  459. if mean:
  460. loss = torch.nn.functional.mse_loss(target, pred)
  461. else:
  462. loss = torch.nn.functional.mse_loss(target, pred, reduction="none")
  463. else:
464. raise NotImplementedError(f"unknown loss type '{self.loss_type}'")
  465. return loss
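# p_losses computes the simple denoising objective. The regression target depends on
# the parameterization: the injected noise for "eps", the clean input for "x0", and
# get_v(x_start, noise, t) for "v". The per-timestep loss is additionally reweighted
# by lvlb_weights to form the (optional) variational-bound term added below.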
  466. def p_losses(self, x_start, t, noise=None):
  467. noise = default(noise, lambda: torch.randn_like(x_start))
  468. x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
  469. model_out = self.model(x_noisy, t)
  470. loss_dict = {}
  471. if self.parameterization == "eps":
  472. target = noise
  473. elif self.parameterization == "x0":
  474. target = x_start
  475. elif self.parameterization == "v":
  476. target = self.get_v(x_start, noise, t)
  477. else:
  478. raise NotImplementedError(
  479. f"Parameterization {self.parameterization} not yet supported"
  480. )
  481. loss = self.get_loss(model_out, target, mean=False).mean(dim=[1, 2, 3])
  482. log_prefix = "train" if self.training else "val"
  483. loss_dict.update({f"{log_prefix}/loss_simple": loss.mean()})
  484. loss_simple = loss.mean() * self.l_simple_weight
  485. loss_vlb = (self.lvlb_weights[t] * loss).mean()
  486. loss_dict.update({f"{log_prefix}/loss_vlb": loss_vlb})
  487. loss = loss_simple + self.original_elbo_weight * loss_vlb
  488. loss_dict.update({f"{log_prefix}/loss": loss})
  489. return loss, loss_dict
  490. def forward(self, x, *args, **kwargs):
  491. # b, c, h, w, device, img_size, = *x.shape, x.device, self.image_size
  492. # assert h == img_size and w == img_size, f'height and width of image must be {img_size}'
  493. t = torch.randint(
  494. 0, self.num_timesteps, (x.shape[0],), device=self.device
  495. ).long()
  496. return self.p_losses(x, t, *args, **kwargs)
  497. def get_input(self, batch, k):
  498. x = batch[k]
  499. if len(x.shape) == 3:
  500. x = x[..., None]
  501. x = rearrange(x, "b h w c -> b c h w")
  502. x = x.to(memory_format=torch.contiguous_format).float()
  503. return x
  504. def shared_step(self, batch):
  505. x = self.get_input(batch, self.first_stage_key)
  506. loss, loss_dict = self(x)
  507. return loss, loss_dict
  508. def training_step(self, batch, batch_idx):
  509. for k in self.ucg_training:
  510. p = self.ucg_training[k]["p"]
  511. val = self.ucg_training[k]["val"]
  512. if val is None:
  513. val = ""
  514. for i in range(len(batch[k])):
  515. if self.ucg_prng.choice(2, p=[1 - p, p]):
  516. batch[k][i] = val
  517. loss, loss_dict = self.shared_step(batch)
  518. self.log_dict(
  519. loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=True
  520. )
  521. self.log(
  522. "global_step",
  523. self.global_step,
  524. prog_bar=True,
  525. logger=True,
  526. on_step=True,
  527. on_epoch=False,
  528. )
  529. if self.use_scheduler:
  530. lr = self.optimizers().param_groups[0]["lr"]
  531. self.log(
  532. "lr_abs", lr, prog_bar=True, logger=True, on_step=True, on_epoch=False
  533. )
  534. return loss
  535. @torch.no_grad()
  536. def validation_step(self, batch, batch_idx):
  537. _, loss_dict_no_ema = self.shared_step(batch)
  538. with self.ema_scope():
  539. _, loss_dict_ema = self.shared_step(batch)
  540. loss_dict_ema = {key + "_ema": loss_dict_ema[key] for key in loss_dict_ema}
  541. self.log_dict(
  542. loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True
  543. )
  544. self.log_dict(
  545. loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True
  546. )
  547. def on_train_batch_end(self, *args, **kwargs):
  548. if self.use_ema:
  549. self.model_ema(self.model)
  550. def _get_rows_from_list(self, samples):
  551. n_imgs_per_row = len(samples)
  552. denoise_grid = rearrange(samples, "n b c h w -> b n c h w")
  553. denoise_grid = rearrange(denoise_grid, "b n c h w -> (b n) c h w")
  554. denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)
  555. return denoise_grid
  556. @torch.no_grad()
  557. def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs):
  558. log = dict()
  559. x = self.get_input(batch, self.first_stage_key)
  560. N = min(x.shape[0], N)
  561. n_row = min(x.shape[0], n_row)
  562. x = x.to(self.device)[:N]
  563. log["inputs"] = x
  564. # get diffusion row
  565. diffusion_row = list()
  566. x_start = x[:n_row]
  567. for t in range(self.num_timesteps):
  568. if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
  569. t = repeat(torch.tensor([t]), "1 -> b", b=n_row)
  570. t = t.to(self.device).long()
  571. noise = torch.randn_like(x_start)
  572. x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
  573. diffusion_row.append(x_noisy)
  574. log["diffusion_row"] = self._get_rows_from_list(diffusion_row)
  575. if sample:
  576. # get denoise row
  577. with self.ema_scope("Plotting"):
  578. samples, denoise_row = self.sample(
  579. batch_size=N, return_intermediates=True
  580. )
  581. log["samples"] = samples
  582. log["denoise_row"] = self._get_rows_from_list(denoise_row)
  583. if return_keys:
  584. if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
  585. return log
  586. else:
  587. return {key: log[key] for key in return_keys}
  588. return log
  589. def configure_optimizers(self):
  590. lr = self.learning_rate
  591. params = list(self.model.parameters())
  592. if self.learn_logvar:
  593. params = params + [self.logvar]
  594. opt = torch.optim.AdamW(params, lr=lr)
  595. return opt
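# LatentDiffusion runs the same DDPM machinery in the latent space of a frozen
# first-stage autoencoder: inputs are encoded with encode_first_stage(), scaled by
# scale_factor, denoised by the UNet under the chosen conditioning mode
# ("concat" or "crossattn"), and decoded back to pixels with decode_first_stage().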
  596. class LatentDiffusion(DDPM):
  597. """main class"""
  598. def __init__(
  599. self,
  600. first_stage_config,
  601. cond_stage_config,
  602. num_timesteps_cond=None,
  603. cond_stage_key="image",
  604. cond_stage_trainable=False,
  605. concat_mode=True,
  606. cond_stage_forward=None,
  607. conditioning_key=None,
  608. scale_factor=1.0,
  609. scale_by_std=False,
  610. force_null_conditioning=False,
  611. *args,
  612. **kwargs,
  613. ):
  614. self.force_null_conditioning = force_null_conditioning
  615. self.num_timesteps_cond = default(num_timesteps_cond, 1)
  616. self.scale_by_std = scale_by_std
  617. assert self.num_timesteps_cond <= kwargs["timesteps"]
  618. # for backwards compatibility after implementation of DiffusionWrapper
  619. if conditioning_key is None:
  620. conditioning_key = "concat" if concat_mode else "crossattn"
  621. if (
  622. cond_stage_config == "__is_unconditional__"
  623. and not self.force_null_conditioning
  624. ):
  625. conditioning_key = None
  626. ckpt_path = kwargs.pop("ckpt_path", None)
  627. reset_ema = kwargs.pop("reset_ema", False)
  628. reset_num_ema_updates = kwargs.pop("reset_num_ema_updates", False)
  629. ignore_keys = kwargs.pop("ignore_keys", [])
  630. super().__init__(conditioning_key=conditioning_key, *args, **kwargs)
  631. self.concat_mode = concat_mode
  632. self.cond_stage_trainable = cond_stage_trainable
  633. self.cond_stage_key = cond_stage_key
  634. try:
  635. self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1
636. except Exception:
  637. self.num_downs = 0
  638. if not scale_by_std:
  639. self.scale_factor = scale_factor
  640. else:
  641. self.register_buffer("scale_factor", torch.tensor(scale_factor))
  642. self.instantiate_first_stage(first_stage_config)
  643. self.instantiate_cond_stage(cond_stage_config)
  644. self.cond_stage_forward = cond_stage_forward
  645. self.clip_denoised = False
  646. self.bbox_tokenizer = None
  647. self.restarted_from_ckpt = False
  648. if ckpt_path is not None:
  649. self.init_from_ckpt(ckpt_path, ignore_keys)
  650. self.restarted_from_ckpt = True
  651. if reset_ema:
  652. assert self.use_ema
  653. print(
  654. f"Resetting ema to pure model weights. This is useful when restoring from an ema-only checkpoint."
  655. )
  656. self.model_ema = LitEma(self.model)
  657. if reset_num_ema_updates:
  658. print(
  659. " +++++++++++ WARNING: RESETTING NUM_EMA UPDATES TO ZERO +++++++++++ "
  660. )
  661. assert self.use_ema
  662. self.model_ema.reset_num_updates()
  663. def make_cond_schedule(
  664. self,
  665. ):
  666. self.cond_ids = torch.full(
  667. size=(self.num_timesteps,),
  668. fill_value=self.num_timesteps - 1,
  669. dtype=torch.long,
  670. )
  671. ids = torch.round(
  672. torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)
  673. ).long()
  674. self.cond_ids[: self.num_timesteps_cond] = ids
  675. @torch.no_grad()
  676. def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
  677. # only for very first batch
  678. if (
  679. self.scale_by_std
  680. and self.current_epoch == 0
  681. and self.global_step == 0
  682. and batch_idx == 0
  683. and not self.restarted_from_ckpt
  684. ):
  685. assert (
  686. self.scale_factor == 1.0
  687. ), "rather not use custom rescaling and std-rescaling simultaneously"
  688. # set rescale weight to 1./std of encodings
  689. print("### USING STD-RESCALING ###")
  690. x = super().get_input(batch, self.first_stage_key)
  691. x = x.to(self.device)
  692. encoder_posterior = self.encode_first_stage(x)
  693. z = self.get_first_stage_encoding(encoder_posterior).detach()
  694. del self.scale_factor
  695. self.register_buffer("scale_factor", 1.0 / z.flatten().std())
  696. print(f"setting self.scale_factor to {self.scale_factor}")
  697. print("### USING STD-RESCALING ###")
  698. def register_schedule(
  699. self,
  700. given_betas=None,
  701. beta_schedule="linear",
  702. timesteps=1000,
  703. linear_start=1e-4,
  704. linear_end=2e-2,
  705. cosine_s=8e-3,
  706. ):
  707. super().register_schedule(
  708. given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s
  709. )
  710. self.shorten_cond_schedule = self.num_timesteps_cond > 1
  711. if self.shorten_cond_schedule:
  712. self.make_cond_schedule()
  713. def instantiate_first_stage(self, config):
  714. model = instantiate_from_config(config)
  715. self.first_stage_model = model.eval()
  716. self.first_stage_model.train = disabled_train
  717. for param in self.first_stage_model.parameters():
  718. param.requires_grad = False
  719. def instantiate_cond_stage(self, config):
  720. if not self.cond_stage_trainable:
  721. if config == "__is_first_stage__":
  722. print("Using first stage also as cond stage.")
  723. self.cond_stage_model = self.first_stage_model
  724. elif config == "__is_unconditional__":
  725. print(f"Training {self.__class__.__name__} as an unconditional model.")
  726. self.cond_stage_model = None
  727. # self.be_unconditional = True
  728. else:
  729. model = instantiate_from_config(config)
  730. self.cond_stage_model = model.eval()
  731. self.cond_stage_model.train = disabled_train
  732. for param in self.cond_stage_model.parameters():
  733. param.requires_grad = False
  734. else:
  735. assert config != "__is_first_stage__"
  736. assert config != "__is_unconditional__"
  737. model = instantiate_from_config(config)
  738. self.cond_stage_model = model
  739. def _get_denoise_row_from_list(
  740. self, samples, desc="", force_no_decoder_quantization=False
  741. ):
  742. denoise_row = []
  743. for zd in tqdm(samples, desc=desc):
  744. denoise_row.append(
  745. self.decode_first_stage(
  746. zd.to(self.device), force_not_quantize=force_no_decoder_quantization
  747. )
  748. )
  749. n_imgs_per_row = len(denoise_row)
  750. denoise_row = torch.stack(denoise_row) # n_log_step, n_row, C, H, W
  751. denoise_grid = rearrange(denoise_row, "n b c h w -> b n c h w")
  752. denoise_grid = rearrange(denoise_grid, "b n c h w -> (b n) c h w")
  753. denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)
  754. return denoise_grid
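# get_first_stage_encoding samples (or passes through) the first-stage latent and
# multiplies it by scale_factor, the constant that normalizes latents to roughly unit
# std; decode_first_stage divides the same factor back out before decoding.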
  755. def get_first_stage_encoding(self, encoder_posterior):
  756. if isinstance(encoder_posterior, DiagonalGaussianDistribution):
  757. z = encoder_posterior.sample()
  758. elif isinstance(encoder_posterior, torch.Tensor):
  759. z = encoder_posterior
  760. else:
  761. raise NotImplementedError(
  762. f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented"
  763. )
  764. return self.scale_factor * z
  765. def get_learned_conditioning(self, c):
  766. if self.cond_stage_forward is None:
  767. if hasattr(self.cond_stage_model, "encode") and callable(
  768. self.cond_stage_model.encode
  769. ):
  770. c = self.cond_stage_model.encode(c)
  771. if isinstance(c, DiagonalGaussianDistribution):
  772. c = c.mode()
  773. else:
  774. c = self.cond_stage_model(c)
  775. else:
  776. assert hasattr(self.cond_stage_model, self.cond_stage_forward)
  777. c = getattr(self.cond_stage_model, self.cond_stage_forward)(c)
  778. return c
  779. def meshgrid(self, h, w):
  780. y = torch.arange(0, h).view(h, 1, 1).repeat(1, w, 1)
  781. x = torch.arange(0, w).view(1, w, 1).repeat(h, 1, 1)
  782. arr = torch.cat([y, x], dim=-1)
  783. return arr
  784. def delta_border(self, h, w):
  785. """
  786. :param h: height
  787. :param w: width
  788. :return: normalized distance to image border,
789. with min distance = 0 at border and max dist = 0.5 at image center
  790. """
  791. lower_right_corner = torch.tensor([h - 1, w - 1]).view(1, 1, 2)
  792. arr = self.meshgrid(h, w) / lower_right_corner
  793. dist_left_up = torch.min(arr, dim=-1, keepdims=True)[0]
  794. dist_right_down = torch.min(1 - arr, dim=-1, keepdims=True)[0]
  795. edge_dist = torch.min(
  796. torch.cat([dist_left_up, dist_right_down], dim=-1), dim=-1
  797. )[0]
  798. return edge_dist
  799. def get_weighting(self, h, w, Ly, Lx, device):
  800. weighting = self.delta_border(h, w)
  801. weighting = torch.clip(
  802. weighting,
  803. self.split_input_params["clip_min_weight"],
  804. self.split_input_params["clip_max_weight"],
  805. )
  806. weighting = weighting.view(1, h * w, 1).repeat(1, 1, Ly * Lx).to(device)
  807. if self.split_input_params["tie_braker"]:
  808. L_weighting = self.delta_border(Ly, Lx)
  809. L_weighting = torch.clip(
  810. L_weighting,
  811. self.split_input_params["clip_min_tie_weight"],
  812. self.split_input_params["clip_max_tie_weight"],
  813. )
  814. L_weighting = L_weighting.view(1, 1, Ly * Lx).to(device)
  815. weighting = weighting * L_weighting
  816. return weighting
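# get_fold_unfold builds the torch.nn.Unfold/Fold pair used to split an image into
# overlapping crops and stitch the processed crops back together, plus a per-pixel
# weighting map whose fold() gives the normalization for the overlaps. uf / df cover
# the cases where the model output is an up- or down-scaled version of each crop.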
  817. def get_fold_unfold(
  818. self, x, kernel_size, stride, uf=1, df=1
  819. ): # todo load once not every time, shorten code
  820. """
  821. :param x: img of size (bs, c, h, w)
  822. :return: n img crops of size (n, bs, c, kernel_size[0], kernel_size[1])
  823. """
  824. bs, nc, h, w = x.shape
  825. # number of crops in image
  826. Ly = (h - kernel_size[0]) // stride[0] + 1
  827. Lx = (w - kernel_size[1]) // stride[1] + 1
  828. if uf == 1 and df == 1:
  829. fold_params = dict(
  830. kernel_size=kernel_size, dilation=1, padding=0, stride=stride
  831. )
  832. unfold = torch.nn.Unfold(**fold_params)
  833. fold = torch.nn.Fold(output_size=x.shape[2:], **fold_params)
  834. weighting = self.get_weighting(
  835. kernel_size[0], kernel_size[1], Ly, Lx, x.device
  836. ).to(x.dtype)
  837. normalization = fold(weighting).view(1, 1, h, w) # normalizes the overlap
  838. weighting = weighting.view((1, 1, kernel_size[0], kernel_size[1], Ly * Lx))
  839. elif uf > 1 and df == 1:
  840. fold_params = dict(
  841. kernel_size=kernel_size, dilation=1, padding=0, stride=stride
  842. )
  843. unfold = torch.nn.Unfold(**fold_params)
  844. fold_params2 = dict(
  845. kernel_size=(kernel_size[0] * uf, kernel_size[0] * uf),
  846. dilation=1,
  847. padding=0,
  848. stride=(stride[0] * uf, stride[1] * uf),
  849. )
  850. fold = torch.nn.Fold(
  851. output_size=(x.shape[2] * uf, x.shape[3] * uf), **fold_params2
  852. )
  853. weighting = self.get_weighting(
  854. kernel_size[0] * uf, kernel_size[1] * uf, Ly, Lx, x.device
  855. ).to(x.dtype)
  856. normalization = fold(weighting).view(
  857. 1, 1, h * uf, w * uf
  858. ) # normalizes the overlap
  859. weighting = weighting.view(
  860. (1, 1, kernel_size[0] * uf, kernel_size[1] * uf, Ly * Lx)
  861. )
  862. elif df > 1 and uf == 1:
  863. fold_params = dict(
  864. kernel_size=kernel_size, dilation=1, padding=0, stride=stride
  865. )
  866. unfold = torch.nn.Unfold(**fold_params)
  867. fold_params2 = dict(
  868. kernel_size=(kernel_size[0] // df, kernel_size[0] // df),
  869. dilation=1,
  870. padding=0,
  871. stride=(stride[0] // df, stride[1] // df),
  872. )
  873. fold = torch.nn.Fold(
  874. output_size=(x.shape[2] // df, x.shape[3] // df), **fold_params2
  875. )
  876. weighting = self.get_weighting(
  877. kernel_size[0] // df, kernel_size[1] // df, Ly, Lx, x.device
  878. ).to(x.dtype)
  879. normalization = fold(weighting).view(
  880. 1, 1, h // df, w // df
  881. ) # normalizes the overlap
  882. weighting = weighting.view(
  883. (1, 1, kernel_size[0] // df, kernel_size[1] // df, Ly * Lx)
  884. )
  885. else:
  886. raise NotImplementedError
  887. return fold, unfold, normalization, weighting
  888. @torch.no_grad()
  889. def get_input(
  890. self,
  891. batch,
  892. k,
  893. return_first_stage_outputs=False,
  894. force_c_encode=False,
  895. cond_key=None,
  896. return_original_cond=False,
  897. bs=None,
  898. return_x=False,
  899. mask_k=None,
  900. ):
  901. x = super().get_input(batch, k)
  902. if bs is not None:
  903. x = x[:bs]
  904. x = x.to(self.device)
  905. encoder_posterior = self.encode_first_stage(x)
  906. z = self.get_first_stage_encoding(encoder_posterior).detach()
  907. if mask_k is not None:
  908. mx = super().get_input(batch, mask_k)
  909. if bs is not None:
  910. mx = mx[:bs]
  911. mx = mx.to(self.device)
  912. encoder_posterior = self.encode_first_stage(mx)
  913. mx = self.get_first_stage_encoding(encoder_posterior).detach()
  914. if self.model.conditioning_key is not None and not self.force_null_conditioning:
  915. if cond_key is None:
  916. cond_key = self.cond_stage_key
  917. if cond_key != self.first_stage_key:
  918. if cond_key in ["caption", "coordinates_bbox", "txt"]:
  919. xc = batch[cond_key]
  920. elif cond_key in ["class_label", "cls"]:
  921. xc = batch
  922. else:
  923. xc = super().get_input(batch, cond_key).to(self.device)
  924. else:
  925. xc = x
  926. if not self.cond_stage_trainable or force_c_encode:
  927. if isinstance(xc, dict) or isinstance(xc, list):
  928. c = self.get_learned_conditioning(xc)
  929. else:
  930. c = self.get_learned_conditioning(xc.to(self.device))
  931. else:
  932. c = xc
  933. if bs is not None:
  934. c = c[:bs]
  935. if self.use_positional_encodings:
  936. pos_x, pos_y = self.compute_latent_shifts(batch)
  937. ckey = __conditioning_keys__[self.model.conditioning_key]
  938. c = {ckey: c, "pos_x": pos_x, "pos_y": pos_y}
  939. else:
  940. c = None
  941. xc = None
  942. if self.use_positional_encodings:
  943. pos_x, pos_y = self.compute_latent_shifts(batch)
  944. c = {"pos_x": pos_x, "pos_y": pos_y}
  945. out = [z, c]
  946. if return_first_stage_outputs:
  947. xrec = self.decode_first_stage(z)
  948. out.extend([x, xrec])
  949. if return_x:
  950. out.extend([x])
  951. if return_original_cond:
  952. out.append(xc)
  953. if mask_k:
  954. out.append(mx)
  955. return out
  956. @torch.no_grad()
  957. def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
  958. if predict_cids:
  959. if z.dim() == 4:
  960. z = torch.argmax(z.exp(), dim=1).long()
  961. z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
  962. z = rearrange(z, "b h w c -> b c h w").contiguous()
  963. z = 1.0 / self.scale_factor * z
  964. return self.first_stage_model.decode(z)
  965. def decode_first_stage_grad(self, z, predict_cids=False, force_not_quantize=False):
  966. if predict_cids:
  967. if z.dim() == 4:
  968. z = torch.argmax(z.exp(), dim=1).long()
  969. z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
  970. z = rearrange(z, "b h w c -> b c h w").contiguous()
  971. z = 1.0 / self.scale_factor * z
  972. return self.first_stage_model.decode(z)
  973. @torch.no_grad()
  974. def encode_first_stage(self, x):
  975. return self.first_stage_model.encode(x)
  976. def shared_step(self, batch, **kwargs):
  977. x, c = self.get_input(batch, self.first_stage_key)
  978. loss = self(x, c)
  979. return loss
  980. def forward(self, x, c, *args, **kwargs):
  981. t = torch.randint(
  982. 0, self.num_timesteps, (x.shape[0],), device=self.device
  983. ).long()
  984. # t = torch.randint(500, 501, (x.shape[0],), device=self.device).long()
  985. if self.model.conditioning_key is not None:
  986. assert c is not None
  987. if self.cond_stage_trainable:
  988. c = self.get_learned_conditioning(c)
  989. if self.shorten_cond_schedule: # TODO: drop this option
  990. tc = self.cond_ids[t].to(self.device)
  991. c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float()))
  992. return self.p_losses(x, c, t, *args, **kwargs)
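# apply_model normalizes the conditioning into the dict format expected by
# DiffusionWrapper: a bare tensor or list is wrapped under "c_concat" or
# "c_crossattn" depending on conditioning_key, while hybrid conditionings are
# assumed to arrive as a dict already.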
  993. def apply_model(self, x_noisy, t, cond, return_ids=False):
  994. if isinstance(cond, dict):
  995. # hybrid case, cond is expected to be a dict
  996. pass
  997. else:
  998. if not isinstance(cond, list):
  999. cond = [cond]
  1000. key = (
  1001. "c_concat" if self.model.conditioning_key == "concat" else "c_crossattn"
  1002. )
  1003. cond = {key: cond}
  1004. x_recon = self.model(x_noisy, t, **cond)
  1005. if isinstance(x_recon, tuple) and not return_ids:
  1006. return x_recon[0]
  1007. else:
  1008. return x_recon
  1009. def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
  1010. return (
  1011. extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
  1012. - pred_xstart
  1013. ) / extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
  1014. def _prior_bpd(self, x_start):
  1015. """
  1016. Get the prior KL term for the variational lower-bound, measured in
  1017. bits-per-dim.
  1018. This term can't be optimized, as it only depends on the encoder.
  1019. :param x_start: the [N x C x ...] tensor of inputs.
  1020. :return: a batch of [N] KL values (in bits), one per batch element.
  1021. """
  1022. batch_size = x_start.shape[0]
  1023. t = torch.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
  1024. qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
  1025. kl_prior = normal_kl(
  1026. mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0
  1027. )
  1028. return mean_flat(kl_prior) / np.log(2.0)
  1029. def p_mean_variance(
  1030. self,
  1031. x,
  1032. c,
  1033. t,
  1034. clip_denoised: bool,
  1035. return_codebook_ids=False,
  1036. quantize_denoised=False,
  1037. return_x0=False,
  1038. score_corrector=None,
  1039. corrector_kwargs=None,
  1040. ):
  1041. t_in = t
  1042. model_out = self.apply_model(x, t_in, c, return_ids=return_codebook_ids)
  1043. if score_corrector is not None:
  1044. assert self.parameterization == "eps"
  1045. model_out = score_corrector.modify_score(
  1046. self, model_out, x, t, c, **corrector_kwargs
  1047. )
  1048. if return_codebook_ids:
  1049. model_out, logits = model_out
  1050. if self.parameterization == "eps":
  1051. x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
  1052. elif self.parameterization == "x0":
  1053. x_recon = model_out
  1054. else:
  1055. raise NotImplementedError()
  1056. if clip_denoised:
  1057. x_recon.clamp_(-1.0, 1.0)
  1058. if quantize_denoised:
  1059. x_recon, _, [_, _, indices] = self.first_stage_model.quantize(x_recon)
  1060. model_mean, posterior_variance, posterior_log_variance = self.q_posterior(
  1061. x_start=x_recon, x_t=x, t=t
  1062. )
  1063. if return_codebook_ids:
  1064. return model_mean, posterior_variance, posterior_log_variance, logits
  1065. elif return_x0:
  1066. return model_mean, posterior_variance, posterior_log_variance, x_recon
  1067. else:
  1068. return model_mean, posterior_variance, posterior_log_variance
  1069. @torch.no_grad()
  1070. def p_sample(
  1071. self,
  1072. x,
  1073. c,
  1074. t,
  1075. clip_denoised=False,
  1076. repeat_noise=False,
  1077. return_codebook_ids=False,
  1078. quantize_denoised=False,
  1079. return_x0=False,
  1080. temperature=1.0,
  1081. noise_dropout=0.0,
  1082. score_corrector=None,
  1083. corrector_kwargs=None,
  1084. ):
  1085. b, *_, device = *x.shape, x.device
  1086. outputs = self.p_mean_variance(
  1087. x=x,
  1088. c=c,
  1089. t=t,
  1090. clip_denoised=clip_denoised,
  1091. return_codebook_ids=return_codebook_ids,
  1092. quantize_denoised=quantize_denoised,
  1093. return_x0=return_x0,
  1094. score_corrector=score_corrector,
  1095. corrector_kwargs=corrector_kwargs,
  1096. )
  1097. if return_codebook_ids:
  1098. raise DeprecationWarning("Support dropped.")
  1099. model_mean, _, model_log_variance, logits = outputs
  1100. elif return_x0:
  1101. model_mean, _, model_log_variance, x0 = outputs
  1102. else:
  1103. model_mean, _, model_log_variance = outputs
  1104. noise = noise_like(x.shape, device, repeat_noise) * temperature
  1105. if noise_dropout > 0.0:
  1106. noise = torch.nn.functional.dropout(noise, p=noise_dropout)
  1107. # no noise when t == 0
  1108. nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
  1109. if return_codebook_ids:
  1110. return (
  1111. model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise,
  1112. logits.argmax(dim=1),
  1113. )
  1114. if return_x0:
  1115. return (
  1116. model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise,
  1117. x0,
  1118. )
  1119. else:
  1120. return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
  1121. @torch.no_grad()
  1122. def progressive_denoising(
  1123. self,
  1124. cond,
  1125. shape,
  1126. verbose=True,
  1127. callback=None,
  1128. quantize_denoised=False,
  1129. img_callback=None,
  1130. mask=None,
  1131. x0=None,
  1132. temperature=1.0,
  1133. noise_dropout=0.0,
  1134. score_corrector=None,
  1135. corrector_kwargs=None,
  1136. batch_size=None,
  1137. x_T=None,
  1138. start_T=None,
  1139. log_every_t=None,
  1140. ):
  1141. if not log_every_t:
  1142. log_every_t = self.log_every_t
  1143. timesteps = self.num_timesteps
  1144. if batch_size is not None:
1145. b = batch_size
  1146. shape = [batch_size] + list(shape)
  1147. else:
  1148. b = batch_size = shape[0]
  1149. if x_T is None:
  1150. img = torch.randn(shape, device=self.device)
  1151. else:
  1152. img = x_T
  1153. intermediates = []
  1154. if cond is not None:
  1155. if isinstance(cond, dict):
  1156. cond = {
  1157. key: cond[key][:batch_size]
  1158. if not isinstance(cond[key], list)
  1159. else list(map(lambda x: x[:batch_size], cond[key]))
  1160. for key in cond
  1161. }
  1162. else:
  1163. cond = (
  1164. [c[:batch_size] for c in cond]
  1165. if isinstance(cond, list)
  1166. else cond[:batch_size]
  1167. )
  1168. if start_T is not None:
  1169. timesteps = min(timesteps, start_T)
  1170. iterator = (
  1171. tqdm(
  1172. reversed(range(0, timesteps)),
  1173. desc="Progressive Generation",
  1174. total=timesteps,
  1175. )
  1176. if verbose
  1177. else reversed(range(0, timesteps))
  1178. )
1179. if isinstance(temperature, float):
  1180. temperature = [temperature] * timesteps
  1181. for i in iterator:
  1182. ts = torch.full((b,), i, device=self.device, dtype=torch.long)
  1183. if self.shorten_cond_schedule:
  1184. assert self.model.conditioning_key != "hybrid"
  1185. tc = self.cond_ids[ts].to(cond.device)
  1186. cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
  1187. img, x0_partial = self.p_sample(
  1188. img,
  1189. cond,
  1190. ts,
  1191. clip_denoised=self.clip_denoised,
  1192. quantize_denoised=quantize_denoised,
  1193. return_x0=True,
  1194. temperature=temperature[i],
  1195. noise_dropout=noise_dropout,
  1196. score_corrector=score_corrector,
  1197. corrector_kwargs=corrector_kwargs,
  1198. )
  1199. if mask is not None:
  1200. assert x0 is not None
  1201. img_orig = self.q_sample(x0, ts)
  1202. img = img_orig * mask + (1.0 - mask) * img
  1203. if i % log_every_t == 0 or i == timesteps - 1:
  1204. intermediates.append(x0_partial)
  1205. if callback:
  1206. callback(i)
  1207. if img_callback:
  1208. img_callback(img, i)
  1209. return img, intermediates
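# p_sample_loop below is the standard ancestral sampler: starting from x_T ~ N(0, I)
# (or a provided x_T), it iterates p_sample from t = T - 1 down to 0, optionally
# blending a fixed region back in via `mask`/`x0` for inpainting-style sampling.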
  1210. @torch.no_grad()
  1211. def p_sample_loop(
  1212. self,
  1213. cond,
  1214. shape,
  1215. return_intermediates=False,
  1216. x_T=None,
  1217. verbose=True,
  1218. callback=None,
  1219. timesteps=None,
  1220. quantize_denoised=False,
  1221. mask=None,
  1222. x0=None,
  1223. img_callback=None,
  1224. start_T=None,
  1225. log_every_t=None,
  1226. ):
  1227. if not log_every_t:
  1228. log_every_t = self.log_every_t
  1229. device = self.betas.device
  1230. b = shape[0]
  1231. if x_T is None:
  1232. img = torch.randn(shape, device=device)
  1233. else:
  1234. img = x_T
  1235. intermediates = [img]
  1236. if timesteps is None:
  1237. timesteps = self.num_timesteps
  1238. if start_T is not None:
  1239. timesteps = min(timesteps, start_T)
  1240. iterator = (
  1241. tqdm(reversed(range(0, timesteps)), desc="Sampling t", total=timesteps)
  1242. if verbose
  1243. else reversed(range(0, timesteps))
  1244. )
  1245. if mask is not None:
  1246. assert x0 is not None
  1247. assert x0.shape[2:3] == mask.shape[2:3] # spatial size has to match
  1248. for i in iterator:
  1249. ts = torch.full((b,), i, device=device, dtype=torch.long)
  1250. if self.shorten_cond_schedule:
  1251. assert self.model.conditioning_key != "hybrid"
  1252. tc = self.cond_ids[ts].to(cond.device)
  1253. cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
  1254. img = self.p_sample(
  1255. img,
  1256. cond,
  1257. ts,
  1258. clip_denoised=self.clip_denoised,
  1259. quantize_denoised=quantize_denoised,
  1260. )
  1261. if mask is not None:
  1262. img_orig = self.q_sample(x0, ts)
  1263. img = img_orig * mask + (1.0 - mask) * img
  1264. if i % log_every_t == 0 or i == timesteps - 1:
  1265. intermediates.append(img)
  1266. if callback:
  1267. callback(i)
  1268. if img_callback:
  1269. img_callback(img, i)
  1270. if return_intermediates:
  1271. return img, intermediates
  1272. return img

    @torch.no_grad()
    def sample(
        self,
        cond,
        batch_size=16,
        return_intermediates=False,
        x_T=None,
        verbose=True,
        timesteps=None,
        quantize_denoised=False,
        mask=None,
        x0=None,
        shape=None,
        **kwargs,
    ):
        if shape is None:
            shape = (batch_size, self.channels, self.image_size, self.image_size)
        if cond is not None:
            if isinstance(cond, dict):
                cond = {
                    key: cond[key][:batch_size]
                    if not isinstance(cond[key], list)
                    else list(map(lambda x: x[:batch_size], cond[key]))
                    for key in cond
                }
            else:
                cond = (
                    [c[:batch_size] for c in cond]
                    if isinstance(cond, list)
                    else cond[:batch_size]
                )
        return self.p_sample_loop(
            cond,
            shape,
            return_intermediates=return_intermediates,
            x_T=x_T,
            verbose=verbose,
            timesteps=timesteps,
            quantize_denoised=quantize_denoised,
            mask=mask,
            x0=x0,
        )

    @torch.no_grad()
    def sample_log(self, cond, batch_size, ddim, ddim_steps, **kwargs):
        if ddim:
            ddim_sampler = DDIMSampler(self)
            shape = (self.channels, self.image_size, self.image_size)
            samples, intermediates = ddim_sampler.sample(
                ddim_steps, batch_size, shape, cond, verbose=False, **kwargs
            )
        else:
            samples, intermediates = self.sample(
                cond=cond, batch_size=batch_size, return_intermediates=True, **kwargs
            )

        return samples, intermediates

    @torch.no_grad()
    def get_unconditional_conditioning(self, batch_size, null_label=None):
        if null_label is not None:
            xc = null_label
            if isinstance(xc, ListConfig):
                xc = list(xc)
            if isinstance(xc, dict) or isinstance(xc, list):
                c = self.get_learned_conditioning(xc)
            else:
                if hasattr(xc, "to"):
                    xc = xc.to(self.device)
                c = self.get_learned_conditioning(xc)
        else:
            if self.cond_stage_key in ["class_label", "cls"]:
                xc = self.cond_stage_model.get_unconditional_conditioning(
                    batch_size, device=self.device
                )
                return self.get_learned_conditioning(xc)
            else:
                raise NotImplementedError("todo")
        if isinstance(c, list):  # in case the encoder gives us a list
            for i in range(len(c)):
                c[i] = repeat(c[i], "1 ... -> b ...", b=batch_size).to(self.device)
        else:
            c = repeat(c, "1 ... -> b ...", b=batch_size).to(self.device)
        return c

    @torch.no_grad()
    def log_images(
        self,
        batch,
        N=8,
        n_row=4,
        sample=True,
        ddim_steps=50,
        ddim_eta=0.0,
        return_keys=None,
        quantize_denoised=True,
        inpaint=True,
        plot_denoise_rows=False,
        plot_progressive_rows=True,
        plot_diffusion_rows=True,
        unconditional_guidance_scale=1.0,
        unconditional_guidance_label=None,
        use_ema_scope=True,
        **kwargs,
    ):
        ema_scope = self.ema_scope if use_ema_scope else nullcontext
        use_ddim = ddim_steps is not None

        log = dict()
        z, c, x, xrec, xc = self.get_input(
            batch,
            self.first_stage_key,
            return_first_stage_outputs=True,
            force_c_encode=True,
            return_original_cond=True,
            bs=N,
        )
        N = min(x.shape[0], N)
        n_row = min(x.shape[0], n_row)
        log["inputs"] = x
        log["reconstruction"] = xrec
        if self.model.conditioning_key is not None:
            if hasattr(self.cond_stage_model, "decode"):
                xc = self.cond_stage_model.decode(c)
                log["conditioning"] = xc
            elif self.cond_stage_key in ["caption", "txt"]:
                xc = log_txt_as_img(
                    (x.shape[2], x.shape[3]),
                    batch[self.cond_stage_key],
                    size=x.shape[2] // 25,
                )
                log["conditioning"] = xc
            elif self.cond_stage_key in ["class_label", "cls"]:
                try:
                    xc = log_txt_as_img(
                        (x.shape[2], x.shape[3]),
                        batch["human_label"],
                        size=x.shape[2] // 25,
                    )
                    log["conditioning"] = xc
                except KeyError:
                    # probably no "human_label" in batch
                    pass
            elif isimage(xc):
                log["conditioning"] = xc
            if ismap(xc):
                log["original_conditioning"] = self.to_rgb(xc)

        if plot_diffusion_rows:
            # get diffusion row
            diffusion_row = list()
            z_start = z[:n_row]
            for t in range(self.num_timesteps):
                if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
                    t = repeat(torch.tensor([t]), "1 -> b", b=n_row)
                    t = t.to(self.device).long()
                    noise = torch.randn_like(z_start)
                    z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
                    diffusion_row.append(self.decode_first_stage(z_noisy))

            diffusion_row = torch.stack(diffusion_row)  # n_log_step, n_row, C, H, W
            diffusion_grid = rearrange(diffusion_row, "n b c h w -> b n c h w")
            diffusion_grid = rearrange(diffusion_grid, "b n c h w -> (b n) c h w")
            diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
            log["diffusion_row"] = diffusion_grid

        if sample:
            # get denoise row
            with ema_scope("Sampling"):
                samples, z_denoise_row = self.sample_log(
                    cond=c,
                    batch_size=N,
                    ddim=use_ddim,
                    ddim_steps=ddim_steps,
                    eta=ddim_eta,
                )
            # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
            x_samples = self.decode_first_stage(samples)
            log["samples"] = x_samples
            if plot_denoise_rows:
                denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
                log["denoise_row"] = denoise_grid

            if (
                quantize_denoised
                and not isinstance(self.first_stage_model, AutoencoderKL)
                and not isinstance(self.first_stage_model, IdentityFirstStage)
            ):
                # also display when quantizing x0 while sampling
                with ema_scope("Plotting Quantized Denoised"):
                    samples, z_denoise_row = self.sample_log(
                        cond=c,
                        batch_size=N,
                        ddim=use_ddim,
                        ddim_steps=ddim_steps,
                        eta=ddim_eta,
                        quantize_denoised=True,
                    )
                # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True,
                #                                      quantize_denoised=True)
                x_samples = self.decode_first_stage(samples.to(self.device))
                log["samples_x0_quantized"] = x_samples

        if unconditional_guidance_scale > 1.0:
            uc = self.get_unconditional_conditioning(N, unconditional_guidance_label)
            if self.model.conditioning_key == "crossattn-adm":
                uc = {"c_crossattn": [uc], "c_adm": c["c_adm"]}
            with ema_scope("Sampling with classifier-free guidance"):
                samples_cfg, _ = self.sample_log(
                    cond=c,
                    batch_size=N,
                    ddim=use_ddim,
                    ddim_steps=ddim_steps,
                    eta=ddim_eta,
                    unconditional_guidance_scale=unconditional_guidance_scale,
                    unconditional_conditioning=uc,
                )
                x_samples_cfg = self.decode_first_stage(samples_cfg)
                log[
                    f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"
                ] = x_samples_cfg

        if inpaint:
            # make a simple center square
            b, h, w = z.shape[0], z.shape[2], z.shape[3]
            mask = torch.ones(N, h, w).to(self.device)
            # zeros will be filled in
            mask[:, h // 4 : 3 * h // 4, w // 4 : 3 * w // 4] = 0.0
            mask = mask[:, None, ...]
            with ema_scope("Plotting Inpaint"):
                samples, _ = self.sample_log(
                    cond=c,
                    batch_size=N,
                    ddim=use_ddim,
                    eta=ddim_eta,
                    ddim_steps=ddim_steps,
                    x0=z[:N],
                    mask=mask,
                )
            x_samples = self.decode_first_stage(samples.to(self.device))
            log["samples_inpainting"] = x_samples
            log["mask"] = mask

            # outpaint
            mask = 1.0 - mask
            with ema_scope("Plotting Outpaint"):
                samples, _ = self.sample_log(
                    cond=c,
                    batch_size=N,
                    ddim=use_ddim,
                    eta=ddim_eta,
                    ddim_steps=ddim_steps,
                    x0=z[:N],
                    mask=mask,
                )
            x_samples = self.decode_first_stage(samples.to(self.device))
            log["samples_outpainting"] = x_samples

        if plot_progressive_rows:
            with ema_scope("Plotting Progressives"):
                img, progressives = self.progressive_denoising(
                    c,
                    shape=(self.channels, self.image_size, self.image_size),
                    batch_size=N,
                )
            prog_row = self._get_denoise_row_from_list(
                progressives, desc="Progressive Generation"
            )
            log["progressive_row"] = prog_row

        if return_keys:
            if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
                return log
            else:
                return {key: log[key] for key in return_keys}
        return log

    def configure_optimizers(self):
        lr = self.learning_rate
        params = list(self.model.parameters())
        if self.cond_stage_trainable:
            print(f"{self.__class__.__name__}: Also optimizing conditioner params!")
            params = params + list(self.cond_stage_model.parameters())
        if self.learn_logvar:
            print("Diffusion model optimizing logvar")
            params.append(self.logvar)
        opt = torch.optim.AdamW(params, lr=lr)
        if self.use_scheduler:
            assert "target" in self.scheduler_config
            scheduler = instantiate_from_config(self.scheduler_config)

            print("Setting up LambdaLR scheduler...")
            scheduler = [
                {
                    "scheduler": LambdaLR(opt, lr_lambda=scheduler.schedule),
                    "interval": "step",
                    "frequency": 1,
                }
            ]
            return [opt], scheduler
        return opt

    @torch.no_grad()
    def to_rgb(self, x):
        x = x.float()
        if not hasattr(self, "colorize"):
            self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x)
        x = nn.functional.conv2d(x, weight=self.colorize)
        x = 2.0 * (x - x.min()) / (x.max() - x.min()) - 1.0
        return x
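

# --- Illustrative sketches (not part of the original module) -------------------
# The helpers below are hedged, standalone examples added for orientation only;
# the names are hypothetical and nothing here is called by the classes above.
# They restate, in isolation, three pieces of LatentDiffusion's sampling logic:
# the masked update used for inpainting/outpainting, the classifier-free
# guidance combination (the actual combination lives in the DDIM sampler, not
# in this file), and a minimal end-to-end sampling call, assuming `model` is an
# already-instantiated LatentDiffusion and `c` is valid conditioning.
def _example_masked_update(img, img_orig_noised, mask):
    # mask == 1 keeps the (noised) known content, mask == 0 keeps the sample,
    # exactly as in p_sample_loop / progressive_denoising above
    return img_orig_noised * mask + (1.0 - mask) * img


def _example_cfg_combine(eps_cond, eps_uncond, guidance_scale=7.5):
    # standard classifier-free guidance: push the prediction away from the
    # unconditional output and towards the conditional one
    return eps_uncond + guidance_scale * (eps_cond - eps_uncond)


def _example_sampling_usage(model, c, batch_size=4):
    # ancestral DDPM sampling in latent space, then decoding to image space
    shape = (batch_size, model.channels, model.image_size, model.image_size)
    z = model.p_sample_loop(c, shape, return_intermediates=False, verbose=True)
    return model.decode_first_stage(z)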


class DiffusionWrapper(torch.nn.Module):
    def __init__(self, diff_model_config, conditioning_key):
        super().__init__()
        self.sequential_cross_attn = diff_model_config.pop(
            "sequential_crossattn", False
        )
        self.diffusion_model = instantiate_from_config(diff_model_config)
        self.conditioning_key = conditioning_key
        assert self.conditioning_key in [
            None,
            "concat",
            "crossattn",
            "hybrid",
            "adm",
            "hybrid-adm",
            "crossattn-adm",
        ]

    def forward(
        self, x, t, c_concat: list = None, c_crossattn: list = None, c_adm=None
    ):
        if self.conditioning_key is None:
            out = self.diffusion_model(x, t)
        elif self.conditioning_key == "concat":
            xc = torch.cat([x] + c_concat, dim=1)
            out = self.diffusion_model(xc, t)
        elif self.conditioning_key == "crossattn":
            if not self.sequential_cross_attn:
                cc = torch.cat(c_crossattn, 1)
            else:
                cc = c_crossattn
            out = self.diffusion_model(x, t, context=cc)
        elif self.conditioning_key == "hybrid":
            xc = torch.cat([x] + c_concat, dim=1)
            cc = torch.cat(c_crossattn, 1)
            out = self.diffusion_model(xc, t, context=cc)
        elif self.conditioning_key == "hybrid-adm":
            assert c_adm is not None
            xc = torch.cat([x] + c_concat, dim=1)
            cc = torch.cat(c_crossattn, 1)
            out = self.diffusion_model(xc, t, context=cc, y=c_adm)
        elif self.conditioning_key == "crossattn-adm":
            assert c_adm is not None
            cc = torch.cat(c_crossattn, 1)
            out = self.diffusion_model(x, t, context=cc, y=c_adm)
        elif self.conditioning_key == "adm":
            cc = c_crossattn[0]
            out = self.diffusion_model(x, t, y=cc)
        else:
            raise NotImplementedError()

        return out
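

# --- Illustrative usage sketch (not part of the original module) ---------------
# DiffusionWrapper.forward is driven by the conditioning dict assembled in
# LatentDiffusion (keys "c_concat", "c_crossattn", "c_adm"). A minimal sketch of
# a "hybrid" call, with hypothetical tensors of compatible shapes; `wrapper` is
# assumed to be an instantiated DiffusionWrapper with conditioning_key "hybrid":
def _example_hybrid_forward(wrapper, x, t, concat_cond, crossattn_cond):
    # concat conditioning is stacked along channels inside the wrapper,
    # cross-attention context is concatenated along the sequence dimension
    return wrapper(x, t, c_concat=[concat_cond], c_crossattn=[crossattn_cond])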


class LatentUpscaleDiffusion(LatentDiffusion):
    def __init__(
        self,
        *args,
        low_scale_config,
        low_scale_key="LR",
        noise_level_key=None,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        # assumes that neither the cond_stage nor the low_scale_model contain trainable params
        assert not self.cond_stage_trainable
        self.instantiate_low_stage(low_scale_config)
        self.low_scale_key = low_scale_key
        self.noise_level_key = noise_level_key

    def instantiate_low_stage(self, config):
        model = instantiate_from_config(config)
        self.low_scale_model = model.eval()
        self.low_scale_model.train = disabled_train
        for param in self.low_scale_model.parameters():
            param.requires_grad = False

    @torch.no_grad()
    def get_input(self, batch, k, cond_key=None, bs=None, log_mode=False):
        if not log_mode:
            z, c = super().get_input(batch, k, force_c_encode=True, bs=bs)
        else:
            z, c, x, xrec, xc = super().get_input(
                batch,
                self.first_stage_key,
                return_first_stage_outputs=True,
                force_c_encode=True,
                return_original_cond=True,
                bs=bs,
            )
        x_low = batch[self.low_scale_key][:bs]
        x_low = rearrange(x_low, "b h w c -> b c h w")
        x_low = x_low.to(memory_format=torch.contiguous_format).float()
        zx, noise_level = self.low_scale_model(x_low)
        if self.noise_level_key is not None:
            # get noise level from batch instead, e.g. when extracting a custom noise level for bsr
            raise NotImplementedError("TODO")

        all_conds = {"c_concat": [zx], "c_crossattn": [c], "c_adm": noise_level}
        if log_mode:
            # TODO: maybe disable if too expensive
            x_low_rec = self.low_scale_model.decode(zx)
            return z, all_conds, x, xrec, xc, x_low, x_low_rec, noise_level
        return z, all_conds

    @torch.no_grad()
    def log_images(
        self,
        batch,
        N=8,
        n_row=4,
        sample=True,
        ddim_steps=200,
        ddim_eta=1.0,
        return_keys=None,
        plot_denoise_rows=False,
        plot_progressive_rows=True,
        plot_diffusion_rows=True,
        unconditional_guidance_scale=1.0,
        unconditional_guidance_label=None,
        use_ema_scope=True,
        **kwargs,
    ):
        ema_scope = self.ema_scope if use_ema_scope else nullcontext
        use_ddim = ddim_steps is not None

        log = dict()
        z, c, x, xrec, xc, x_low, x_low_rec, noise_level = self.get_input(
            batch, self.first_stage_key, bs=N, log_mode=True
        )
        N = min(x.shape[0], N)
        n_row = min(x.shape[0], n_row)
        log["inputs"] = x
        log["reconstruction"] = xrec
        log["x_lr"] = x_low
        log[
            f"x_lr_rec_@noise_levels{'-'.join(map(lambda x: str(x), list(noise_level.cpu().numpy())))}"
        ] = x_low_rec
        if self.model.conditioning_key is not None:
            if hasattr(self.cond_stage_model, "decode"):
                xc = self.cond_stage_model.decode(c)
                log["conditioning"] = xc
            elif self.cond_stage_key in ["caption", "txt"]:
                xc = log_txt_as_img(
                    (x.shape[2], x.shape[3]),
                    batch[self.cond_stage_key],
                    size=x.shape[2] // 25,
                )
                log["conditioning"] = xc
            elif self.cond_stage_key in ["class_label", "cls"]:
                xc = log_txt_as_img(
                    (x.shape[2], x.shape[3]),
                    batch["human_label"],
                    size=x.shape[2] // 25,
                )
                log["conditioning"] = xc
            elif isimage(xc):
                log["conditioning"] = xc
            if ismap(xc):
                log["original_conditioning"] = self.to_rgb(xc)

        if plot_diffusion_rows:
            # get diffusion row
            diffusion_row = list()
            z_start = z[:n_row]
            for t in range(self.num_timesteps):
                if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
                    t = repeat(torch.tensor([t]), "1 -> b", b=n_row)
                    t = t.to(self.device).long()
                    noise = torch.randn_like(z_start)
                    z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
                    diffusion_row.append(self.decode_first_stage(z_noisy))

            diffusion_row = torch.stack(diffusion_row)  # n_log_step, n_row, C, H, W
            diffusion_grid = rearrange(diffusion_row, "n b c h w -> b n c h w")
            diffusion_grid = rearrange(diffusion_grid, "b n c h w -> (b n) c h w")
            diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
            log["diffusion_row"] = diffusion_grid

        if sample:
            # get denoise row
            with ema_scope("Sampling"):
                samples, z_denoise_row = self.sample_log(
                    cond=c,
                    batch_size=N,
                    ddim=use_ddim,
                    ddim_steps=ddim_steps,
                    eta=ddim_eta,
                )
            # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
            x_samples = self.decode_first_stage(samples)
            log["samples"] = x_samples
            if plot_denoise_rows:
                denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
                log["denoise_row"] = denoise_grid

        if unconditional_guidance_scale > 1.0:
            uc_tmp = self.get_unconditional_conditioning(
                N, unconditional_guidance_label
            )
            # TODO explore better "unconditional" choices for the other keys
            # maybe guide away from empty text label and highest noise level and maximally degraded zx?
            uc = dict()
            for k in c:
                if k == "c_crossattn":
                    assert isinstance(c[k], list) and len(c[k]) == 1
                    uc[k] = [uc_tmp]
                elif k == "c_adm":  # todo: only run with text-based guidance?
                    assert isinstance(c[k], torch.Tensor)
                    # uc[k] = torch.ones_like(c[k]) * self.low_scale_model.max_noise_level
                    uc[k] = c[k]
                elif isinstance(c[k], list):
                    uc[k] = [c[k][i] for i in range(len(c[k]))]
                else:
                    uc[k] = c[k]

            with ema_scope("Sampling with classifier-free guidance"):
                samples_cfg, _ = self.sample_log(
                    cond=c,
                    batch_size=N,
                    ddim=use_ddim,
                    ddim_steps=ddim_steps,
                    eta=ddim_eta,
                    unconditional_guidance_scale=unconditional_guidance_scale,
                    unconditional_conditioning=uc,
                )
                x_samples_cfg = self.decode_first_stage(samples_cfg)
                log[
                    f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"
                ] = x_samples_cfg

        if plot_progressive_rows:
            with ema_scope("Plotting Progressives"):
                img, progressives = self.progressive_denoising(
                    c,
                    shape=(self.channels, self.image_size, self.image_size),
                    batch_size=N,
                )
            prog_row = self._get_denoise_row_from_list(
                progressives, desc="Progressive Generation"
            )
            log["progressive_row"] = prog_row

        return log
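

# --- Illustrative sketch (not part of the original module) ---------------------
# For classifier-free guidance the upscaler above only swaps the text embedding;
# the concatenated low-res latents and the noise-level embedding (c_adm) are
# reused unchanged, mirroring the loop in LatentUpscaleDiffusion.log_images.
# Minimal standalone version, assuming `c` is the dict built by get_input and
# `uc_cross` an "empty" text embedding of matching shape (hypothetical names):
def _example_upscaler_unconditional(c, uc_cross):
    uc = dict(c)  # shallow copy; the other entries are intentionally reused
    uc["c_crossattn"] = [uc_cross]
    return uc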


class LatentFinetuneDiffusion(LatentDiffusion):
    """
    Basis for different finetunes, such as inpainting or depth2image.
    To disable finetuning mode, set finetune_keys to None
    """

    def __init__(
        self,
        concat_keys: tuple,
        finetune_keys=(
            "model.diffusion_model.input_blocks.0.0.weight",
            "model_ema.diffusion_modelinput_blocks00weight",
        ),
        keep_finetune_dims=4,
        # if model was trained without concat mode before and we would like to keep these channels
        c_concat_log_start=None,  # to log reconstruction of c_concat codes
        c_concat_log_end=None,
        *args,
        **kwargs,
    ):
        ckpt_path = kwargs.pop("ckpt_path", None)
        ignore_keys = kwargs.pop("ignore_keys", list())
        super().__init__(*args, **kwargs)
        self.finetune_keys = finetune_keys
        self.concat_keys = concat_keys
        self.keep_dims = keep_finetune_dims
        self.c_concat_log_start = c_concat_log_start
        self.c_concat_log_end = c_concat_log_end
        if exists(self.finetune_keys):
            assert exists(ckpt_path), "can only finetune from a given checkpoint"
        if exists(ckpt_path):
            self.init_from_ckpt(ckpt_path, ignore_keys)

    def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
        sd = torch.load(path, map_location="cpu")
        if "state_dict" in list(sd.keys()):
            sd = sd["state_dict"]
        keys = list(sd.keys())
        for k in keys:
            for ik in ignore_keys:
                if k.startswith(ik):
                    print("Deleting key {} from state_dict.".format(k))
                    del sd[k]

            # make it explicit, finetune by including extra input channels
            if exists(self.finetune_keys) and k in self.finetune_keys:
                new_entry = None
                for name, param in self.named_parameters():
                    if name in self.finetune_keys:
                        print(
                            f"modifying key '{name}' and keeping its original {self.keep_dims} (channels) dimensions only"
                        )
                        new_entry = torch.zeros_like(param)  # zero init
                assert exists(new_entry), "did not find matching parameter to modify"
                new_entry[:, : self.keep_dims, ...] = sd[k]
                sd[k] = new_entry

        missing, unexpected = (
            self.load_state_dict(sd, strict=False)
            if not only_model
            else self.model.load_state_dict(sd, strict=False)
        )
        print(
            f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys"
        )
        if len(missing) > 0:
            print(f"Missing Keys: {missing}")
        if len(unexpected) > 0:
            print(f"Unexpected Keys: {unexpected}")

    @torch.no_grad()
    def log_images(
        self,
        batch,
        N=8,
        n_row=4,
        sample=True,
        ddim_steps=200,
        ddim_eta=1.0,
        return_keys=None,
        quantize_denoised=True,
        inpaint=True,
        plot_denoise_rows=False,
        plot_progressive_rows=True,
        plot_diffusion_rows=True,
        unconditional_guidance_scale=1.0,
        unconditional_guidance_label=None,
        use_ema_scope=True,
        **kwargs,
    ):
        ema_scope = self.ema_scope if use_ema_scope else nullcontext
        use_ddim = ddim_steps is not None

        log = dict()
        z, c, x, xrec, xc = self.get_input(
            batch, self.first_stage_key, bs=N, return_first_stage_outputs=True
        )
        c_cat, c = c["c_concat"][0], c["c_crossattn"][0]
        N = min(x.shape[0], N)
        n_row = min(x.shape[0], n_row)
        log["inputs"] = x
        log["reconstruction"] = xrec
        if self.model.conditioning_key is not None:
            if hasattr(self.cond_stage_model, "decode"):
                xc = self.cond_stage_model.decode(c)
                log["conditioning"] = xc
            elif self.cond_stage_key in ["caption", "txt"]:
                xc = log_txt_as_img(
                    (x.shape[2], x.shape[3]),
                    batch[self.cond_stage_key],
                    size=x.shape[2] // 25,
                )
                log["conditioning"] = xc
            elif self.cond_stage_key in ["class_label", "cls"]:
                xc = log_txt_as_img(
                    (x.shape[2], x.shape[3]),
                    batch["human_label"],
                    size=x.shape[2] // 25,
                )
                log["conditioning"] = xc
            elif isimage(xc):
                log["conditioning"] = xc
            if ismap(xc):
                log["original_conditioning"] = self.to_rgb(xc)

        if not (self.c_concat_log_start is None and self.c_concat_log_end is None):
            log["c_concat_decoded"] = self.decode_first_stage(
                c_cat[:, self.c_concat_log_start : self.c_concat_log_end]
            )

        if plot_diffusion_rows:
            # get diffusion row
            diffusion_row = list()
            z_start = z[:n_row]
            for t in range(self.num_timesteps):
                if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
                    t = repeat(torch.tensor([t]), "1 -> b", b=n_row)
                    t = t.to(self.device).long()
                    noise = torch.randn_like(z_start)
                    z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
                    diffusion_row.append(self.decode_first_stage(z_noisy))

            diffusion_row = torch.stack(diffusion_row)  # n_log_step, n_row, C, H, W
            diffusion_grid = rearrange(diffusion_row, "n b c h w -> b n c h w")
            diffusion_grid = rearrange(diffusion_grid, "b n c h w -> (b n) c h w")
            diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
            log["diffusion_row"] = diffusion_grid

        if sample:
            # get denoise row
            with ema_scope("Sampling"):
                samples, z_denoise_row = self.sample_log(
                    cond={"c_concat": [c_cat], "c_crossattn": [c]},
                    batch_size=N,
                    ddim=use_ddim,
                    ddim_steps=ddim_steps,
                    eta=ddim_eta,
                )
            # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
            x_samples = self.decode_first_stage(samples)
            log["samples"] = x_samples
            if plot_denoise_rows:
                denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
                log["denoise_row"] = denoise_grid

        if unconditional_guidance_scale > 1.0:
            uc_cross = self.get_unconditional_conditioning(
                N, unconditional_guidance_label
            )
            uc_cat = c_cat
            uc_full = {"c_concat": [uc_cat], "c_crossattn": [uc_cross]}
            with ema_scope("Sampling with classifier-free guidance"):
                samples_cfg, _ = self.sample_log(
                    cond={"c_concat": [c_cat], "c_crossattn": [c]},
                    batch_size=N,
                    ddim=use_ddim,
                    ddim_steps=ddim_steps,
                    eta=ddim_eta,
                    unconditional_guidance_scale=unconditional_guidance_scale,
                    unconditional_conditioning=uc_full,
                )
                x_samples_cfg = self.decode_first_stage(samples_cfg)
                log[
                    f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"
                ] = x_samples_cfg

        return log
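

# --- Illustrative sketch (not part of the original module) ---------------------
# init_from_ckpt above widens the first UNet convolution when finetuning with
# extra concat channels: the pretrained weight (keep_finetune_dims input
# channels) is copied into a zero-initialised tensor of the new, larger shape,
# so the additional channels start out as a no-op. A standalone version,
# assuming weights of shape (out, c_old, kh, kw) and (out, c_new, kh, kw) with
# c_new >= c_old == keep_dims:
def _example_expand_input_conv(old_weight, new_weight_template, keep_dims):
    new_weight = torch.zeros_like(new_weight_template)  # zero init extra channels
    new_weight[:, :keep_dims, ...] = old_weight
    return new_weight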


class LatentInpaintDiffusion(LatentFinetuneDiffusion):
    """
    can either run as pure inpainting model (only concat mode) or with mixed conditionings,
    e.g. mask as concat and text via cross-attn.
    To disable finetuning mode, set finetune_keys to None
    """

    def __init__(
        self,
        concat_keys=("mask", "masked_image"),
        masked_image_key="masked_image",
        *args,
        **kwargs,
    ):
        super().__init__(concat_keys, *args, **kwargs)
        self.masked_image_key = masked_image_key
        assert self.masked_image_key in concat_keys

    @torch.no_grad()
    def get_input(
        self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False
    ):
        # note: restricted to non-trainable encoders currently
        assert (
            not self.cond_stage_trainable
        ), "trainable cond stages not yet supported for inpainting"
        z, c, x, xrec, xc = super().get_input(
            batch,
            self.first_stage_key,
            return_first_stage_outputs=True,
            force_c_encode=True,
            return_original_cond=True,
            bs=bs,
        )

        assert exists(self.concat_keys)
        c_cat = list()
        for ck in self.concat_keys:
            cc = (
                rearrange(batch[ck], "b h w c -> b c h w")
                .to(memory_format=torch.contiguous_format)
                .float()
            )
            if bs is not None:
                cc = cc[:bs]
                cc = cc.to(self.device)
            bchw = z.shape
            if ck != self.masked_image_key:
                cc = torch.nn.functional.interpolate(cc, size=bchw[-2:])
            else:
                cc = self.get_first_stage_encoding(self.encode_first_stage(cc))
            c_cat.append(cc)
        c_cat = torch.cat(c_cat, dim=1)
        all_conds = {"c_concat": [c_cat], "c_crossattn": [c]}
        if return_first_stage_outputs:
            return z, all_conds, x, xrec, xc
        return z, all_conds

    @torch.no_grad()
    def log_images(self, *args, **kwargs):
        log = super(LatentInpaintDiffusion, self).log_images(*args, **kwargs)
        log["masked_image"] = (
            rearrange(args[0]["masked_image"], "b h w c -> b c h w")
            .to(memory_format=torch.contiguous_format)
            .float()
        )
        return log
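

# --- Illustrative sketch (not part of the original module) ---------------------
# LatentInpaintDiffusion.get_input expects the batch to carry "mask" and
# "masked_image" in b h w c layout. A minimal sketch of preparing those two
# entries from an image and a binary mask (1 = region to be inpainted), under
# the common convention that the masked image zeroes out the region to be
# filled; the names and the exact convention are assumptions for illustration:
def _example_prepare_inpaint_batch(image_bhwc, mask_bhw1):
    masked_image = image_bhwc * (1.0 - mask_bhw1)  # hide the region to inpaint
    return {"mask": mask_bhw1, "masked_image": masked_image}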


class LatentDepth2ImageDiffusion(LatentFinetuneDiffusion):
    """
    condition on monocular depth estimation
    """

    def __init__(self, depth_stage_config, concat_keys=("midas_in",), *args, **kwargs):
        super().__init__(concat_keys=concat_keys, *args, **kwargs)
        self.depth_model = instantiate_from_config(depth_stage_config)
        self.depth_stage_key = concat_keys[0]

    @torch.no_grad()
    def get_input(
        self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False
    ):
        # note: restricted to non-trainable encoders currently
        assert (
            not self.cond_stage_trainable
        ), "trainable cond stages not yet supported for depth2img"
        z, c, x, xrec, xc = super().get_input(
            batch,
            self.first_stage_key,
            return_first_stage_outputs=True,
            force_c_encode=True,
            return_original_cond=True,
            bs=bs,
        )

        assert exists(self.concat_keys)
        assert len(self.concat_keys) == 1
        c_cat = list()
        for ck in self.concat_keys:
            cc = batch[ck]
            if bs is not None:
                cc = cc[:bs]
                cc = cc.to(self.device)
            cc = self.depth_model(cc)
            cc = torch.nn.functional.interpolate(
                cc,
                size=z.shape[2:],
                mode="bicubic",
                align_corners=False,
            )

            depth_min, depth_max = (
                torch.amin(cc, dim=[1, 2, 3], keepdim=True),
                torch.amax(cc, dim=[1, 2, 3], keepdim=True),
            )
            cc = 2.0 * (cc - depth_min) / (depth_max - depth_min + 0.001) - 1.0
            c_cat.append(cc)
        c_cat = torch.cat(c_cat, dim=1)
        all_conds = {"c_concat": [c_cat], "c_crossattn": [c]}
        if return_first_stage_outputs:
            return z, all_conds, x, xrec, xc
        return z, all_conds

    @torch.no_grad()
    def log_images(self, *args, **kwargs):
        log = super().log_images(*args, **kwargs)
        depth = self.depth_model(args[0][self.depth_stage_key])
        depth_min, depth_max = (
            torch.amin(depth, dim=[1, 2, 3], keepdim=True),
            torch.amax(depth, dim=[1, 2, 3], keepdim=True),
        )
        log["depth"] = 2.0 * (depth - depth_min) / (depth_max - depth_min) - 1.0
        return log
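

# --- Illustrative sketch (not part of the original module) ---------------------
# Both get_input and log_images above normalise the predicted depth map to
# [-1, 1] per sample before it is used as concat conditioning. A standalone
# version of that normalisation, assuming a (B, 1, H, W) depth tensor:
def _example_normalize_depth(depth, eps=0.001):
    d_min = torch.amin(depth, dim=[1, 2, 3], keepdim=True)
    d_max = torch.amax(depth, dim=[1, 2, 3], keepdim=True)
    return 2.0 * (depth - d_min) / (d_max - d_min + eps) - 1.0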


class LatentUpscaleFinetuneDiffusion(LatentFinetuneDiffusion):
    """
    condition on low-res image (and optionally on some spatial noise augmentation)
    """

    def __init__(
        self,
        concat_keys=("lr",),
        reshuffle_patch_size=None,
        low_scale_config=None,
        low_scale_key=None,
        *args,
        **kwargs,
    ):
        super().__init__(concat_keys=concat_keys, *args, **kwargs)
        self.reshuffle_patch_size = reshuffle_patch_size
        self.low_scale_model = None
        if low_scale_config is not None:
            print("Initializing a low-scale model")
            assert exists(low_scale_key)
            self.instantiate_low_stage(low_scale_config)
            self.low_scale_key = low_scale_key

    def instantiate_low_stage(self, config):
        model = instantiate_from_config(config)
        self.low_scale_model = model.eval()
        self.low_scale_model.train = disabled_train
        for param in self.low_scale_model.parameters():
            param.requires_grad = False

    @torch.no_grad()
    def get_input(
        self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False
    ):
        # note: restricted to non-trainable encoders currently
        assert (
            not self.cond_stage_trainable
        ), "trainable cond stages not yet supported for upscaling-ft"
        z, c, x, xrec, xc = super().get_input(
            batch,
            self.first_stage_key,
            return_first_stage_outputs=True,
            force_c_encode=True,
            return_original_cond=True,
            bs=bs,
        )

        assert exists(self.concat_keys)
        assert len(self.concat_keys) == 1
        # optionally make spatial noise_level here
        c_cat = list()
        noise_level = None
        for ck in self.concat_keys:
            cc = batch[ck]
            cc = rearrange(cc, "b h w c -> b c h w")
            if exists(self.reshuffle_patch_size):
                assert isinstance(self.reshuffle_patch_size, int)
                cc = rearrange(
                    cc,
                    "b c (p1 h) (p2 w) -> b (p1 p2 c) h w",
                    p1=self.reshuffle_patch_size,
                    p2=self.reshuffle_patch_size,
                )
            if bs is not None:
                cc = cc[:bs]
                cc = cc.to(self.device)
            if exists(self.low_scale_model) and ck == self.low_scale_key:
                cc, noise_level = self.low_scale_model(cc)
            c_cat.append(cc)
        c_cat = torch.cat(c_cat, dim=1)
        if exists(noise_level):
            all_conds = {"c_concat": [c_cat], "c_crossattn": [c], "c_adm": noise_level}
        else:
            all_conds = {"c_concat": [c_cat], "c_crossattn": [c]}
        if return_first_stage_outputs:
            return z, all_conds, x, xrec, xc
        return z, all_conds

    @torch.no_grad()
    def log_images(self, *args, **kwargs):
        log = super().log_images(*args, **kwargs)
        log["lr"] = rearrange(args[0]["lr"], "b h w c -> b c h w")
        return log
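

# --- Illustrative sketch (not part of the original module) ---------------------
# The optional reshuffle in LatentUpscaleFinetuneDiffusion.get_input is a
# space-to-depth operation: the low-res conditioning image is split into
# p x p patches and the patch pixels are moved into the channel dimension,
# shrinking the spatial size by a factor of p. A standalone version, assuming
# a (B, C, H, W) tensor with H and W divisible by p:
def _example_patch_reshuffle(x, p):
    return rearrange(x, "b c (p1 h) (p2 w) -> b (p1 p2 c) h w", p1=p, p2=p)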