# cldm.py
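# ControlNet-style conditioning for AnyText text inpainting (as used in IOPaint):
#   - ControlledUnetModel: frozen SD UNet that consumes ControlNet residuals
#   - ControlNet: trainable encoder copy extended with glyph/position hint encoders
#   - ControlLDM: LatentDiffusion wrapper wiring control model, text embedding manager
#     and OCR recognizer together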

import copy
import os
from pathlib import Path

import einops
import torch
import torch as th
import torch.nn as nn
from easydict import EasyDict as edict
from einops import rearrange, repeat
from torchvision.utils import make_grid  # used by log_images for the diffusion-row grid

from sorawm.iopaint.model.anytext.ldm.models.diffusion.ddim import DDIMSampler
from sorawm.iopaint.model.anytext.ldm.models.diffusion.ddpm import LatentDiffusion
from sorawm.iopaint.model.anytext.ldm.modules.attention import SpatialTransformer
from sorawm.iopaint.model.anytext.ldm.modules.diffusionmodules.openaimodel import (
    AttentionBlock,
    Downsample,
    ResBlock,
    TimestepEmbedSequential,
    UNetModel,
)
from sorawm.iopaint.model.anytext.ldm.modules.diffusionmodules.util import (
    conv_nd,
    linear,
    timestep_embedding,
    zero_module,
)
from sorawm.iopaint.model.anytext.ldm.modules.distributions.distributions import (
    DiagonalGaussianDistribution,
)
from sorawm.iopaint.model.anytext.ldm.util import (
    exists,
    instantiate_from_config,
    log_txt_as_img,
)

from .recognizer import TextRecognizer, create_predictor

CURRENT_DIR = Path(os.path.dirname(os.path.abspath(__file__)))


def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

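# Standard SD UNet, except that the encoder and middle block run under no_grad (kept
# frozen) while the ControlNet residuals are added to the middle-block output and to
# each skip connection consumed by the decoder.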
class ControlledUnetModel(UNetModel):
    def forward(
        self,
        x,
        timesteps=None,
        context=None,
        control=None,
        only_mid_control=False,
        **kwargs,
    ):
        hs = []
        with torch.no_grad():
            t_emb = timestep_embedding(
                timesteps, self.model_channels, repeat_only=False
            )
            if self.use_fp16:
                t_emb = t_emb.half()
            emb = self.time_embed(t_emb)
            h = x.type(self.dtype)
            for module in self.input_blocks:
                h = module(h, emb, context)
                hs.append(h)
            h = self.middle_block(h, emb, context)

        if control is not None:
            h += control.pop()

        for i, module in enumerate(self.output_blocks):
            if only_mid_control or control is None:
                h = torch.cat([h, hs.pop()], dim=1)
            else:
                h = torch.cat([h, hs.pop() + control.pop()], dim=1)
            h = module(h, emb, context)

        h = h.type(x.dtype)
        return self.out(h)

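# Trainable control branch: mirrors the UNet encoder and emits one zero-initialized
# 1x1-conv output per scale. AnyText extends it with a glyph encoder, a position
# encoder and a fuse_block that merges their features with the masked latent into a
# "guided hint" injected after the first input block.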
class ControlNet(nn.Module):
    def __init__(
        self,
        image_size,
        in_channels,
        model_channels,
        glyph_channels,
        position_channels,
        num_res_blocks,
        attention_resolutions,
        dropout=0,
        channel_mult=(1, 2, 4, 8),
        conv_resample=True,
        dims=2,
        use_checkpoint=False,
        use_fp16=False,
        num_heads=-1,
        num_head_channels=-1,
        num_heads_upsample=-1,
        use_scale_shift_norm=False,
        resblock_updown=False,
        use_new_attention_order=False,
        use_spatial_transformer=False,  # custom transformer support
        transformer_depth=1,  # custom transformer support
        context_dim=None,  # custom transformer support
        n_embed=None,  # custom support for prediction of discrete ids into codebook of first stage vq model
        legacy=True,
        disable_self_attentions=None,
        num_attention_blocks=None,
        disable_middle_self_attn=False,
        use_linear_in_transformer=False,
    ):
        super().__init__()
        if use_spatial_transformer:
            assert (
                context_dim is not None
            ), "Fool!! You forgot to include the dimension of your cross-attention conditioning..."

        if context_dim is not None:
            assert (
                use_spatial_transformer
            ), "Fool!! You forgot to use the spatial transformer for your cross-attention conditioning..."
            from omegaconf.listconfig import ListConfig

            if type(context_dim) == ListConfig:
                context_dim = list(context_dim)

        if num_heads_upsample == -1:
            num_heads_upsample = num_heads

        if num_heads == -1:
            assert (
                num_head_channels != -1
            ), "Either num_heads or num_head_channels has to be set"

        if num_head_channels == -1:
            assert (
                num_heads != -1
            ), "Either num_heads or num_head_channels has to be set"

        self.dims = dims
        self.image_size = image_size
        self.in_channels = in_channels
        self.model_channels = model_channels
        if isinstance(num_res_blocks, int):
            self.num_res_blocks = len(channel_mult) * [num_res_blocks]
        else:
            if len(num_res_blocks) != len(channel_mult):
                raise ValueError(
                    "provide num_res_blocks either as an int (globally constant) or "
                    "as a list/tuple (per-level) with the same length as channel_mult"
                )
            self.num_res_blocks = num_res_blocks
        if disable_self_attentions is not None:
            # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
            assert len(disable_self_attentions) == len(channel_mult)
        if num_attention_blocks is not None:
            assert len(num_attention_blocks) == len(self.num_res_blocks)
            assert all(
                map(
                    lambda i: self.num_res_blocks[i] >= num_attention_blocks[i],
                    range(len(num_attention_blocks)),
                )
            )
            print(
                f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
                f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
                f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
                f"attention will still not be set."
            )

        self.attention_resolutions = attention_resolutions
        self.dropout = dropout
        self.channel_mult = channel_mult
        self.conv_resample = conv_resample
        self.use_checkpoint = use_checkpoint
        self.use_fp16 = use_fp16
        self.dtype = th.float16 if use_fp16 else th.float32
        self.num_heads = num_heads
        self.num_head_channels = num_head_channels
        self.num_heads_upsample = num_heads_upsample
        self.predict_codebook_ids = n_embed is not None

        time_embed_dim = model_channels * 4
        self.time_embed = nn.Sequential(
            linear(model_channels, time_embed_dim),
            nn.SiLU(),
            linear(time_embed_dim, time_embed_dim),
        )
        self.input_blocks = nn.ModuleList(
            [
                TimestepEmbedSequential(
                    conv_nd(dims, in_channels, model_channels, 3, padding=1)
                )
            ]
        )
        self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels)])

        self.glyph_block = TimestepEmbedSequential(
            conv_nd(dims, glyph_channels, 8, 3, padding=1),
            nn.SiLU(),
            conv_nd(dims, 8, 8, 3, padding=1),
            nn.SiLU(),
            conv_nd(dims, 8, 16, 3, padding=1, stride=2),
            nn.SiLU(),
            conv_nd(dims, 16, 16, 3, padding=1),
            nn.SiLU(),
            conv_nd(dims, 16, 32, 3, padding=1, stride=2),
            nn.SiLU(),
            conv_nd(dims, 32, 32, 3, padding=1),
            nn.SiLU(),
            conv_nd(dims, 32, 96, 3, padding=1, stride=2),
            nn.SiLU(),
            conv_nd(dims, 96, 96, 3, padding=1),
            nn.SiLU(),
            conv_nd(dims, 96, 256, 3, padding=1, stride=2),
            nn.SiLU(),
        )

        self.position_block = TimestepEmbedSequential(
            conv_nd(dims, position_channels, 8, 3, padding=1),
            nn.SiLU(),
            conv_nd(dims, 8, 8, 3, padding=1),
            nn.SiLU(),
            conv_nd(dims, 8, 16, 3, padding=1, stride=2),
            nn.SiLU(),
            conv_nd(dims, 16, 16, 3, padding=1),
            nn.SiLU(),
            conv_nd(dims, 16, 32, 3, padding=1, stride=2),
            nn.SiLU(),
            conv_nd(dims, 32, 32, 3, padding=1),
            nn.SiLU(),
            conv_nd(dims, 32, 64, 3, padding=1, stride=2),
            nn.SiLU(),
        )

        self.fuse_block = zero_module(
            conv_nd(dims, 256 + 64 + 4, model_channels, 3, padding=1)
        )
        self._feature_size = model_channels
        input_block_chans = [model_channels]
        ch = model_channels
        ds = 1
        for level, mult in enumerate(channel_mult):
            for nr in range(self.num_res_blocks[level]):
                layers = [
                    ResBlock(
                        ch,
                        time_embed_dim,
                        dropout,
                        out_channels=mult * model_channels,
                        dims=dims,
                        use_checkpoint=use_checkpoint,
                        use_scale_shift_norm=use_scale_shift_norm,
                    )
                ]
                ch = mult * model_channels
                if ds in attention_resolutions:
                    if num_head_channels == -1:
                        dim_head = ch // num_heads
                    else:
                        num_heads = ch // num_head_channels
                        dim_head = num_head_channels
                    if legacy:
                        # num_heads = 1
                        dim_head = (
                            ch // num_heads
                            if use_spatial_transformer
                            else num_head_channels
                        )
                    if exists(disable_self_attentions):
                        disabled_sa = disable_self_attentions[level]
                    else:
                        disabled_sa = False

                    if (
                        not exists(num_attention_blocks)
                        or nr < num_attention_blocks[level]
                    ):
                        layers.append(
                            AttentionBlock(
                                ch,
                                use_checkpoint=use_checkpoint,
                                num_heads=num_heads,
                                num_head_channels=dim_head,
                                use_new_attention_order=use_new_attention_order,
                            )
                            if not use_spatial_transformer
                            else SpatialTransformer(
                                ch,
                                num_heads,
                                dim_head,
                                depth=transformer_depth,
                                context_dim=context_dim,
                                disable_self_attn=disabled_sa,
                                use_linear=use_linear_in_transformer,
                                use_checkpoint=use_checkpoint,
                            )
                        )
                self.input_blocks.append(TimestepEmbedSequential(*layers))
                self.zero_convs.append(self.make_zero_conv(ch))
                self._feature_size += ch
                input_block_chans.append(ch)
            if level != len(channel_mult) - 1:
                out_ch = ch
                self.input_blocks.append(
                    TimestepEmbedSequential(
                        ResBlock(
                            ch,
                            time_embed_dim,
                            dropout,
                            out_channels=out_ch,
                            dims=dims,
                            use_checkpoint=use_checkpoint,
                            use_scale_shift_norm=use_scale_shift_norm,
                            down=True,
                        )
                        if resblock_updown
                        else Downsample(
                            ch, conv_resample, dims=dims, out_channels=out_ch
                        )
                    )
                )
                ch = out_ch
                input_block_chans.append(ch)
                self.zero_convs.append(self.make_zero_conv(ch))
                ds *= 2
                self._feature_size += ch
        if num_head_channels == -1:
            dim_head = ch // num_heads
        else:
            num_heads = ch // num_head_channels
            dim_head = num_head_channels
        if legacy:
            # num_heads = 1
            dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
        self.middle_block = TimestepEmbedSequential(
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
            AttentionBlock(
                ch,
                use_checkpoint=use_checkpoint,
                num_heads=num_heads,
                num_head_channels=dim_head,
                use_new_attention_order=use_new_attention_order,
            )
            if not use_spatial_transformer
            else SpatialTransformer(  # always uses a self-attn
                ch,
                num_heads,
                dim_head,
                depth=transformer_depth,
                context_dim=context_dim,
                disable_self_attn=disable_middle_self_attn,
                use_linear=use_linear_in_transformer,
                use_checkpoint=use_checkpoint,
            ),
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
        )
        self.middle_block_out = self.make_zero_conv(ch)
        self._feature_size += ch

    def make_zero_conv(self, channels):
        return TimestepEmbedSequential(
            zero_module(conv_nd(self.dims, channels, channels, 1, padding=0))
        )

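    # Forward: encode the summed glyph and position maps, fuse them with the masked
    # latent into guided_hint, inject it after the first input block, and return one
    # zero-conv feature map per scale plus the middle-block output (13 tensors,
    # matching ControlLDM.control_scales). Note the `hint` argument is accepted but
    # unused here; the guidance is built from text_info instead.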
    def forward(self, x, hint, text_info, timesteps, context, **kwargs):
        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
        if self.use_fp16:
            t_emb = t_emb.half()
        emb = self.time_embed(t_emb)

        # guided_hint from text_info
        B, C, H, W = x.shape
        glyphs = torch.cat(text_info["glyphs"], dim=1).sum(dim=1, keepdim=True)
        positions = torch.cat(text_info["positions"], dim=1).sum(dim=1, keepdim=True)
        enc_glyph = self.glyph_block(glyphs, emb, context)
        enc_pos = self.position_block(positions, emb, context)
        guided_hint = self.fuse_block(
            torch.cat([enc_glyph, enc_pos, text_info["masked_x"]], dim=1)
        )

        outs = []

        h = x.type(self.dtype)
        for module, zero_conv in zip(self.input_blocks, self.zero_convs):
            if guided_hint is not None:
                h = module(h, emb, context)
                h += guided_hint
                guided_hint = None
            else:
                h = module(h, emb, context)
            outs.append(zero_conv(h, emb, context))

        h = self.middle_block(h, emb, context)
        outs.append(self.middle_block_out(h, emb, context))

        return outs

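# LatentDiffusion subclass that owns the ControlNet (control_model), an optional
# embedding manager for per-glyph text embeddings, and an OCR text recognizer that is
# instantiated when loss_alpha/loss_beta or the embedding manager are enabled.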
class ControlLDM(LatentDiffusion):
    def __init__(
        self,
        control_stage_config,
        control_key,
        glyph_key,
        position_key,
        only_mid_control,
        loss_alpha=0,
        loss_beta=0,
        with_step_weight=False,
        use_vae_upsample=False,
        latin_weight=1.0,
        embedding_manager_config=None,
        *args,
        **kwargs,
    ):
        self.use_fp16 = kwargs.pop("use_fp16", False)
        super().__init__(*args, **kwargs)
        self.control_model = instantiate_from_config(control_stage_config)
        self.control_key = control_key
        self.glyph_key = glyph_key
        self.position_key = position_key
        self.only_mid_control = only_mid_control
        self.control_scales = [1.0] * 13
        self.loss_alpha = loss_alpha
        self.loss_beta = loss_beta
        self.with_step_weight = with_step_weight
        self.use_vae_upsample = use_vae_upsample
        self.latin_weight = latin_weight

        if (
            embedding_manager_config is not None
            and embedding_manager_config.params.valid
        ):
            self.embedding_manager = self.instantiate_embedding_manager(
                embedding_manager_config, self.cond_stage_model
            )
            for param in self.embedding_manager.embedding_parameters():
                param.requires_grad = True
        else:
            self.embedding_manager = None
        if self.loss_alpha > 0 or self.loss_beta > 0 or self.embedding_manager:
            if embedding_manager_config.params.emb_type == "ocr":
                self.text_predictor = create_predictor().eval()
                args = edict()
                args.rec_image_shape = "3, 48, 320"
                args.rec_batch_num = 6
                args.rec_char_dict_path = str(
                    CURRENT_DIR.parent / "ocr_recog" / "ppocr_keys_v1.txt"
                )
                args.use_fp16 = self.use_fp16
                self.cn_recognizer = TextRecognizer(args, self.text_predictor)
                for param in self.text_predictor.parameters():
                    param.requires_grad = False
                if self.embedding_manager:
                    self.embedding_manager.recog = self.cn_recognizer

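    # get_input: returns the latent x plus a cond dict with c_crossattn (caption),
    # c_concat (the control image, kept only for logging/losses) and text_info
    # (per-line glyphs, positions, gly_line crops, masked latent, inverse mask, etc.).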
    @torch.no_grad()
    def get_input(self, batch, k, bs=None, *args, **kwargs):
        if self.embedding_manager is None:  # fill in full caption
            self.fill_caption(batch)
        x, c, mx = super().get_input(
            batch, self.first_stage_key, mask_k="masked_img", *args, **kwargs
        )
        control = batch[
            self.control_key
        ]  # for log_images and loss_alpha, not real control
        if bs is not None:
            control = control[:bs]
        control = control.to(self.device)
        control = einops.rearrange(control, "b h w c -> b c h w")
        control = control.to(memory_format=torch.contiguous_format).float()

        inv_mask = batch["inv_mask"]
        if bs is not None:
            inv_mask = inv_mask[:bs]
        inv_mask = inv_mask.to(self.device)
        inv_mask = einops.rearrange(inv_mask, "b h w c -> b c h w")
        inv_mask = inv_mask.to(memory_format=torch.contiguous_format).float()

        glyphs = batch[self.glyph_key]
        gly_line = batch["gly_line"]
        positions = batch[self.position_key]
        n_lines = batch["n_lines"]
        language = batch["language"]
        texts = batch["texts"]
        assert len(glyphs) == len(positions)
        for i in range(len(glyphs)):
            if bs is not None:
                glyphs[i] = glyphs[i][:bs]
                gly_line[i] = gly_line[i][:bs]
                positions[i] = positions[i][:bs]
                n_lines = n_lines[:bs]
            glyphs[i] = glyphs[i].to(self.device)
            gly_line[i] = gly_line[i].to(self.device)
            positions[i] = positions[i].to(self.device)
            glyphs[i] = einops.rearrange(glyphs[i], "b h w c -> b c h w")
            gly_line[i] = einops.rearrange(gly_line[i], "b h w c -> b c h w")
            positions[i] = einops.rearrange(positions[i], "b h w c -> b c h w")
            glyphs[i] = glyphs[i].to(memory_format=torch.contiguous_format).float()
            gly_line[i] = gly_line[i].to(memory_format=torch.contiguous_format).float()
            positions[i] = (
                positions[i].to(memory_format=torch.contiguous_format).float()
            )
        info = {}
        info["glyphs"] = glyphs
        info["positions"] = positions
        info["n_lines"] = n_lines
        info["language"] = language
        info["texts"] = texts
        info["img"] = batch["img"]  # nhwc, (-1,1)
        info["masked_x"] = mx
        info["gly_line"] = gly_line
        info["inv_mask"] = inv_mask
        return x, dict(c_crossattn=[c], c_concat=[control], text_info=info)

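    # apply_model: run the ControlNet to get per-scale residuals, rescale them with
    # control_scales, then predict eps with the controlled UNet.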
    def apply_model(self, x_noisy, t, cond, *args, **kwargs):
        assert isinstance(cond, dict)
        diffusion_model = self.model.diffusion_model
        _cond = torch.cat(cond["c_crossattn"], 1)
        _hint = torch.cat(cond["c_concat"], 1)
        if self.use_fp16:
            x_noisy = x_noisy.half()
        control = self.control_model(
            x=x_noisy,
            timesteps=t,
            context=_cond,
            hint=_hint,
            text_info=cond["text_info"],
        )
        control = [c * scale for c, scale in zip(control, self.control_scales)]
        eps = diffusion_model(
            x=x_noisy,
            timesteps=t,
            context=_cond,
            control=control,
            only_mid_control=self.only_mid_control,
        )

        return eps

    def instantiate_embedding_manager(self, config, embedder):
        model = instantiate_from_config(config, embedder=embedder)
        return model

    @torch.no_grad()
    def get_unconditional_conditioning(self, N):
        return self.get_learned_conditioning(
            dict(c_crossattn=[[""] * N], text_info=None)
        )

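    # get_learned_conditioning: encode the caption with the cond-stage text encoder;
    # when an embedding manager is present it first consumes text_info (encode_text)
    # and is then passed into encode() so its learned embeddings can be injected.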
    def get_learned_conditioning(self, c):
        if self.cond_stage_forward is None:
            if hasattr(self.cond_stage_model, "encode") and callable(
                self.cond_stage_model.encode
            ):
                if self.embedding_manager is not None and c["text_info"] is not None:
                    self.embedding_manager.encode_text(c["text_info"])
                if isinstance(c, dict):
                    cond_txt = c["c_crossattn"][0]
                else:
                    cond_txt = c
                if self.embedding_manager is not None:
                    cond_txt = self.cond_stage_model.encode(
                        cond_txt, embedding_manager=self.embedding_manager
                    )
                else:
                    cond_txt = self.cond_stage_model.encode(cond_txt)
                if isinstance(c, dict):
                    c["c_crossattn"][0] = cond_txt
                else:
                    c = cond_txt
                if isinstance(c, DiagonalGaussianDistribution):
                    c = c.mode()
            else:
                c = self.cond_stage_model(c)
        else:
            assert hasattr(self.cond_stage_model, self.cond_stage_forward)
            c = getattr(self.cond_stage_model, self.cond_stage_forward)(c)
        return c

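    # fill_caption: replace each placeholder ("*" by default) in the caption with the
    # corresponding quoted text line; e.g. a (hypothetical) caption 'A sign that says *'
    # with texts ["SALE"] becomes 'A sign that says "SALE"'.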
    def fill_caption(self, batch, place_holder="*"):
        bs = len(batch["n_lines"])
        cond_list = copy.deepcopy(batch[self.cond_stage_key])
        for i in range(bs):
            n_lines = batch["n_lines"][i]
            if n_lines == 0:
                continue
            cur_cap = cond_list[i]
            for j in range(n_lines):
                r_txt = batch["texts"][j][i]
                cur_cap = cur_cap.replace(place_holder, f'"{r_txt}"', 1)
            cond_list[i] = cur_cap
        batch[self.cond_stage_key] = cond_list

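    # log_images: assemble a dict of visualizations (reconstruction, masked image,
    # control map, glyph canvas, caption rendering, optional diffusion/denoise rows and
    # classifier-free-guidance samples) for the first N items of the batch.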
    @torch.no_grad()
    def log_images(
        self,
        batch,
        N=4,
        n_row=2,
        sample=False,
        ddim_steps=50,
        ddim_eta=0.0,
        return_keys=None,
        quantize_denoised=True,
        inpaint=True,
        plot_denoise_rows=False,
        plot_progressive_rows=True,
        plot_diffusion_rows=False,
        unconditional_guidance_scale=9.0,
        unconditional_guidance_label=None,
        use_ema_scope=True,
        **kwargs,
    ):
        use_ddim = ddim_steps is not None

        log = dict()
        z, c = self.get_input(batch, self.first_stage_key, bs=N)
        if self.cond_stage_trainable:
            with torch.no_grad():
                c = self.get_learned_conditioning(c)
        c_crossattn = c["c_crossattn"][0][:N]
        c_cat = c["c_concat"][0][:N]
        text_info = c["text_info"]
        text_info["glyphs"] = [i[:N] for i in text_info["glyphs"]]
        text_info["gly_line"] = [i[:N] for i in text_info["gly_line"]]
        text_info["positions"] = [i[:N] for i in text_info["positions"]]
        text_info["n_lines"] = text_info["n_lines"][:N]
        text_info["masked_x"] = text_info["masked_x"][:N]
        text_info["img"] = text_info["img"][:N]

        N = min(z.shape[0], N)
        n_row = min(z.shape[0], n_row)
        log["reconstruction"] = self.decode_first_stage(z)
        log["masked_image"] = self.decode_first_stage(text_info["masked_x"])
        log["control"] = c_cat * 2.0 - 1.0
        log["img"] = text_info["img"].permute(0, 3, 1, 2)  # log source image if needed
        # get glyph
        glyph_bs = torch.stack(text_info["glyphs"])
        glyph_bs = torch.sum(glyph_bs, dim=0) * 2.0 - 1.0
        log["glyph"] = torch.nn.functional.interpolate(
            glyph_bs,
            size=(512, 512),
            mode="bilinear",
            align_corners=True,
        )
        # fill caption
        if not self.embedding_manager:
            self.fill_caption(batch)
        captions = batch[self.cond_stage_key]
        log["conditioning"] = log_txt_as_img((512, 512), captions, size=16)
        if plot_diffusion_rows:
            # get diffusion row
            diffusion_row = list()
            z_start = z[:n_row]
            for t in range(self.num_timesteps):
                if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
                    t = repeat(torch.tensor([t]), "1 -> b", b=n_row)
                    t = t.to(self.device).long()
                    noise = torch.randn_like(z_start)
                    z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
                    diffusion_row.append(self.decode_first_stage(z_noisy))

            diffusion_row = torch.stack(diffusion_row)  # n_log_step, n_row, C, H, W
            diffusion_grid = rearrange(diffusion_row, "n b c h w -> b n c h w")
            diffusion_grid = rearrange(diffusion_grid, "b n c h w -> (b n) c h w")
            diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
            log["diffusion_row"] = diffusion_grid

        if sample:
            # get denoise row
            samples, z_denoise_row = self.sample_log(
                cond={
                    "c_concat": [c_cat],
                    # pass the encoded text embedding; the raw cond dict `c` would not
                    # concatenate in apply_model
                    "c_crossattn": [c_crossattn],
                    "text_info": text_info,
                },
                batch_size=N,
                ddim=use_ddim,
                ddim_steps=ddim_steps,
                eta=ddim_eta,
            )
            x_samples = self.decode_first_stage(samples)
            log["samples"] = x_samples
            if plot_denoise_rows:
                denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
                log["denoise_row"] = denoise_grid

        if unconditional_guidance_scale > 1.0:
            uc_cross = self.get_unconditional_conditioning(N)
            uc_cat = c_cat  # torch.zeros_like(c_cat)
            uc_full = {
                "c_concat": [uc_cat],
                "c_crossattn": [uc_cross["c_crossattn"][0]],
                "text_info": text_info,
            }
            samples_cfg, tmps = self.sample_log(
                cond={
                    "c_concat": [c_cat],
                    "c_crossattn": [c_crossattn],
                    "text_info": text_info,
                },
                batch_size=N,
                ddim=use_ddim,
                ddim_steps=ddim_steps,
                eta=ddim_eta,
                unconditional_guidance_scale=unconditional_guidance_scale,
                unconditional_conditioning=uc_full,
            )
            x_samples_cfg = self.decode_first_stage(samples_cfg)
            log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg
            pred_x0 = False  # whether to log pred_x0
            if pred_x0:
                for idx in range(len(tmps["pred_x0"])):
                    pred_x0 = self.decode_first_stage(tmps["pred_x0"][idx])
                    log[f"pred_x0_{tmps['index'][idx]}"] = pred_x0

        return log

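    # sample_log: DDIM sampling in latent space; the latent H/W are derived from the
    # control image (factor-8 VAE downsampling) and intermediates are kept every 5 steps.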
    @torch.no_grad()
    def sample_log(self, cond, batch_size, ddim, ddim_steps, **kwargs):
        ddim_sampler = DDIMSampler(self)
        b, c, h, w = cond["c_concat"][0].shape
        shape = (self.channels, h // 8, w // 8)
        samples, intermediates = ddim_sampler.sample(
            ddim_steps, batch_size, shape, cond, verbose=False, log_every_t=5, **kwargs
        )
        return samples, intermediates

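    # configure_optimizers: train only the ControlNet parameters (plus, optionally, the
    # embedding manager, the UNet decoder when sd_locked is False, and the cross-attention
    # K/V projections when unlockKV is set).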
    def configure_optimizers(self):
        lr = self.learning_rate
        params = list(self.control_model.parameters())
        if self.embedding_manager:
            params += list(self.embedding_manager.embedding_parameters())
        if not self.sd_locked:
            # params += list(self.model.diffusion_model.input_blocks.parameters())
            # params += list(self.model.diffusion_model.middle_block.parameters())
            params += list(self.model.diffusion_model.output_blocks.parameters())
            params += list(self.model.diffusion_model.out.parameters())
        if self.unlockKV:
            nCount = 0
            for name, param in self.model.diffusion_model.named_parameters():
                if "attn2.to_k" in name or "attn2.to_v" in name:
                    params += [param]
                    nCount += 1
            print(
                f"Cross attention is unlocked, and {nCount} Wk or Wv are added to optimizers!!!"
            )

        opt = torch.optim.AdamW(params, lr=lr)
        return opt

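    # low_vram_shift: keep only the modules needed for the current phase on the GPU
    # (UNet + ControlNet while diffusing, VAE + text encoder otherwise).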
    def low_vram_shift(self, is_diffusing):
        if is_diffusing:
            self.model = self.model.cuda()
            self.control_model = self.control_model.cuda()
            self.first_stage_model = self.first_stage_model.cpu()
            self.cond_stage_model = self.cond_stage_model.cpu()
        else:
            self.model = self.model.cpu()
            self.control_model = self.control_model.cpu()
            self.first_stage_model = self.first_stage_model.cuda()
            self.cond_stage_model = self.cond_stage_model.cuda()