import math
from abc import abstractmethod

import numpy as np
import torch as th
import torch.nn as nn
import torch.nn.functional as F

from sorawm.iopaint.model.anytext.ldm.modules.attention import SpatialTransformer
from sorawm.iopaint.model.anytext.ldm.modules.diffusionmodules.util import (
    avg_pool_nd,
    checkpoint,
    conv_nd,
    linear,
    normalization,
    timestep_embedding,
    zero_module,
)
from sorawm.iopaint.model.anytext.ldm.util import exists


# dummy replace
def convert_module_to_f16(x):
    pass


def convert_module_to_f32(x):
    pass


## go
class AttentionPool2d(nn.Module):
    """
    Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
    """

    def __init__(
        self,
        spacial_dim: int,
        embed_dim: int,
        num_heads_channels: int,
        output_dim: int = None,
    ):
        super().__init__()
        self.positional_embedding = nn.Parameter(
            th.randn(embed_dim, spacial_dim**2 + 1) / embed_dim**0.5
        )
        self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
        self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
        self.num_heads = embed_dim // num_heads_channels
        self.attention = QKVAttention(self.num_heads)

    def forward(self, x):
        b, c, *_spatial = x.shape
        x = x.reshape(b, c, -1)  # NC(HW)
        x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1)  # NC(HW+1)
        x = x + self.positional_embedding[None, :, :].to(x.dtype)  # NC(HW+1)
        x = self.qkv_proj(x)
        x = self.attention(x)
        x = self.c_proj(x)
        return x[:, :, 0]
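
# Example (illustrative sketch, not part of the upstream module): pooling an 8x8
# feature map with 64 channels down to a single 128-dim vector per sample. The
# sizes below (batch 2, 64 channels, 8 channels per head) are arbitrary.
#
#   pool = AttentionPool2d(spacial_dim=8, embed_dim=64, num_heads_channels=8, output_dim=128)
#   x = th.randn(2, 64, 8, 8)
#   v = pool(x)  # -> [2, 128]; attention runs over 65 tokens (mean token + 64 positions)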


class TimestepBlock(nn.Module):
    """
    Any module where forward() takes timestep embeddings as a second argument.
    """

    @abstractmethod
    def forward(self, x, emb):
        """
        Apply the module to `x` given `emb` timestep embeddings.
        """


class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
    """
    A sequential module that passes timestep embeddings to the children that
    support it as an extra input.
    """

    def forward(self, x, emb, context=None):
        for layer in self:
            if isinstance(layer, TimestepBlock):
                x = layer(x, emb)
            elif isinstance(layer, SpatialTransformer):
                x = layer(x, context)
            else:
                x = layer(x)
        return x
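
# Example (illustrative sketch): only TimestepBlock children receive `emb` and only
# SpatialTransformer children receive `context`; everything else is called with `x`
# alone. The sizes below are arbitrary.
#
#   block = TimestepEmbedSequential(
#       conv_nd(2, 4, 64, 3, padding=1),   # plain layer: called as layer(x)
#       ResBlock(64, 256, dropout=0.0),    # TimestepBlock: called as layer(x, emb)
#   )
#   x, emb = th.randn(2, 4, 32, 32), th.randn(2, 256)
#   h = block(x, emb)  # context=None is simply ignored here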


class Upsample(nn.Module):
    """
    An upsampling layer with an optional convolution.
    :param channels: channels in the inputs and outputs.
    :param use_conv: a bool determining if a convolution is applied.
    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
                 upsampling occurs in the inner-two dimensions.
    """

    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        if use_conv:
            self.conv = conv_nd(
                dims, self.channels, self.out_channels, 3, padding=padding
            )

    def forward(self, x):
        assert x.shape[1] == self.channels
        if self.dims == 3:
            x = F.interpolate(
                x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
            )
        else:
            x = F.interpolate(x, scale_factor=2, mode="nearest")
        if self.use_conv:
            x = self.conv(x)
        return x
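
# Example (illustrative sketch): nearest-neighbour 2x upsampling followed by the
# optional 3x3 convolution. Sizes are arbitrary.
#
#   up = Upsample(channels=64, use_conv=True)
#   x = th.randn(1, 64, 16, 16)
#   y = up(x)  # -> [1, 64, 32, 32]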


class TransposedUpsample(nn.Module):
    "Learned 2x upsampling without padding"

    def __init__(self, channels, out_channels=None, ks=5):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.up = nn.ConvTranspose2d(
            self.channels, self.out_channels, kernel_size=ks, stride=2
        )

    def forward(self, x):
        return self.up(x)


class Downsample(nn.Module):
    """
    A downsampling layer with an optional convolution.
    :param channels: channels in the inputs and outputs.
    :param use_conv: a bool determining if a convolution is applied.
    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
                 downsampling occurs in the inner-two dimensions.
    """

    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        stride = 2 if dims != 3 else (1, 2, 2)
        if use_conv:
            self.op = conv_nd(
                dims,
                self.channels,
                self.out_channels,
                3,
                stride=stride,
                padding=padding,
            )
        else:
            assert self.channels == self.out_channels
            self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)

    def forward(self, x):
        assert x.shape[1] == self.channels
        return self.op(x)
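
# Example (illustrative sketch): the strided 3x3 convolution (or average pool when
# use_conv=False) halves the spatial resolution, inverting Upsample above.
#
#   down = Downsample(channels=64, use_conv=True)
#   x = th.randn(1, 64, 32, 32)
#   y = down(x)  # -> [1, 64, 16, 16]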


class ResBlock(TimestepBlock):
    """
    A residual block that can optionally change the number of channels.
    :param channels: the number of input channels.
    :param emb_channels: the number of timestep embedding channels.
    :param dropout: the rate of dropout.
    :param out_channels: if specified, the number of out channels.
    :param use_conv: if True and out_channels is specified, use a spatial
        convolution instead of a smaller 1x1 convolution to change the
        channels in the skip connection.
    :param dims: determines if the signal is 1D, 2D, or 3D.
    :param use_checkpoint: if True, use gradient checkpointing on this module.
    :param up: if True, use this block for upsampling.
    :param down: if True, use this block for downsampling.
    """

    def __init__(
        self,
        channels,
        emb_channels,
        dropout,
        out_channels=None,
        use_conv=False,
        use_scale_shift_norm=False,
        dims=2,
        use_checkpoint=False,
        up=False,
        down=False,
    ):
        super().__init__()
        self.channels = channels
        self.emb_channels = emb_channels
        self.dropout = dropout
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_checkpoint = use_checkpoint
        self.use_scale_shift_norm = use_scale_shift_norm

        self.in_layers = nn.Sequential(
            normalization(channels),
            nn.SiLU(),
            conv_nd(dims, channels, self.out_channels, 3, padding=1),
        )

        self.updown = up or down

        if up:
            self.h_upd = Upsample(channels, False, dims)
            self.x_upd = Upsample(channels, False, dims)
        elif down:
            self.h_upd = Downsample(channels, False, dims)
            self.x_upd = Downsample(channels, False, dims)
        else:
            self.h_upd = self.x_upd = nn.Identity()

        self.emb_layers = nn.Sequential(
            nn.SiLU(),
            linear(
                emb_channels,
                2 * self.out_channels if use_scale_shift_norm else self.out_channels,
            ),
        )
        self.out_layers = nn.Sequential(
            normalization(self.out_channels),
            nn.SiLU(),
            nn.Dropout(p=dropout),
            zero_module(
                conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
            ),
        )

        if self.out_channels == channels:
            self.skip_connection = nn.Identity()
        elif use_conv:
            self.skip_connection = conv_nd(
                dims, channels, self.out_channels, 3, padding=1
            )
        else:
            self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)

    def forward(self, x, emb):
        """
        Apply the block to a Tensor, conditioned on a timestep embedding.
        :param x: an [N x C x ...] Tensor of features.
        :param emb: an [N x emb_channels] Tensor of timestep embeddings.
        :return: an [N x C x ...] Tensor of outputs.
        """
        return checkpoint(
            self._forward, (x, emb), self.parameters(), self.use_checkpoint
        )

    def _forward(self, x, emb):
        if self.updown:
            in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
            h = in_rest(x)
            h = self.h_upd(h)
            x = self.x_upd(x)
            h = in_conv(h)
        else:
            h = self.in_layers(x)
        emb_out = self.emb_layers(emb).type(h.dtype)
        while len(emb_out.shape) < len(h.shape):
            emb_out = emb_out[..., None]
        if self.use_scale_shift_norm:
            out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
            scale, shift = th.chunk(emb_out, 2, dim=1)
            h = out_norm(h) * (1 + scale) + shift
            h = out_rest(h)
        else:
            h = h + emb_out
            h = self.out_layers(h)
        return self.skip_connection(x) + h
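
# Example (illustrative sketch): a residual block that widens 64 -> 128 channels
# while being conditioned on a 256-dim timestep embedding. Sizes are arbitrary.
#
#   block = ResBlock(channels=64, emb_channels=256, dropout=0.0, out_channels=128)
#   x = th.randn(2, 64, 32, 32)
#   emb = th.randn(2, 256)
#   h = block(x, emb)  # -> [2, 128, 32, 32]; the skip connection uses a 1x1 conv here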


class AttentionBlock(nn.Module):
    """
    An attention block that allows spatial positions to attend to each other.
    Originally ported from here, but adapted to the N-d case.
    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
    """

    def __init__(
        self,
        channels,
        num_heads=1,
        num_head_channels=-1,
        use_checkpoint=False,
        use_new_attention_order=False,
    ):
        super().__init__()
        self.channels = channels
        if num_head_channels == -1:
            self.num_heads = num_heads
        else:
            assert (
                channels % num_head_channels == 0
            ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
            self.num_heads = channels // num_head_channels
        self.use_checkpoint = use_checkpoint
        self.norm = normalization(channels)
        self.qkv = conv_nd(1, channels, channels * 3, 1)
        if use_new_attention_order:
            # split qkv before split heads
            self.attention = QKVAttention(self.num_heads)
        else:
            # split heads before split qkv
            self.attention = QKVAttentionLegacy(self.num_heads)
        self.proj_out = zero_module(conv_nd(1, channels, channels, 1))

    def forward(self, x):
        return checkpoint(
            self._forward, (x,), self.parameters(), True
        )  # TODO: check checkpoint usage, is True # TODO: fix the .half call!!!
        # return pt_checkpoint(self._forward, x)  # pytorch

    def _forward(self, x):
        b, c, *spatial = x.shape
        x = x.reshape(b, c, -1)
        qkv = self.qkv(self.norm(x))
        h = self.attention(qkv)
        h = self.proj_out(h)
        return (x + h).reshape(b, c, *spatial)
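
# Example (illustrative sketch): self-attention over all spatial positions of a
# feature map; the output shape matches the input. Sizes are arbitrary.
#
#   attn = AttentionBlock(channels=64, num_heads=4)
#   x = th.randn(2, 64, 16, 16)
#   y = attn(x)  # -> [2, 64, 16, 16]; the 256 positions attend to each other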


def count_flops_attn(model, _x, y):
    """
    A counter for the `thop` package to count the operations in an
    attention operation.
    Meant to be used like:
        macs, params = thop.profile(
            model,
            inputs=(inputs, timestamps),
            custom_ops={QKVAttention: QKVAttention.count_flops},
        )
    """
    b, c, *spatial = y[0].shape
    num_spatial = int(np.prod(spatial))
    # We perform two matmuls with the same number of ops.
    # The first computes the weight matrix, the second computes
    # the combination of the value vectors.
    matmul_ops = 2 * b * (num_spatial**2) * c
    model.total_ops += th.DoubleTensor([matmul_ops])


class QKVAttentionLegacy(nn.Module):
    """
    A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping
    """

    def __init__(self, n_heads):
        super().__init__()
        self.n_heads = n_heads

    def forward(self, qkv):
        """
        Apply QKV attention.
        :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
        :return: an [N x (H * C) x T] tensor after attention.
        """
        bs, width, length = qkv.shape
        assert width % (3 * self.n_heads) == 0
        ch = width // (3 * self.n_heads)
        q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
        scale = 1 / math.sqrt(math.sqrt(ch))
        weight = th.einsum(
            "bct,bcs->bts", q * scale, k * scale
        )  # More stable with f16 than dividing afterwards
        weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
        a = th.einsum("bts,bcs->bct", weight, v)
        return a.reshape(bs, -1, length)

    @staticmethod
    def count_flops(model, _x, y):
        return count_flops_attn(model, _x, y)
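
# Example (illustrative sketch): the legacy layout packs the channel dimension as
# [N, H * 3 * C, T], i.e. heads are split before q/k/v. Sizes are arbitrary.
#
#   attn = QKVAttentionLegacy(n_heads=4)
#   qkv = th.randn(2, 4 * 3 * 16, 128)  # H=4 heads, C=16 channels per head, T=128 tokens
#   out = attn(qkv)  # -> [2, 64, 128] == [N, H * C, T]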


class QKVAttention(nn.Module):
    """
    A module which performs QKV attention and splits in a different order.
    """

    def __init__(self, n_heads):
        super().__init__()
        self.n_heads = n_heads

    def forward(self, qkv):
        """
        Apply QKV attention.
        :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
        :return: an [N x (H * C) x T] tensor after attention.
        """
        bs, width, length = qkv.shape
        assert width % (3 * self.n_heads) == 0
        ch = width // (3 * self.n_heads)
        q, k, v = qkv.chunk(3, dim=1)
        scale = 1 / math.sqrt(math.sqrt(ch))
        weight = th.einsum(
            "bct,bcs->bts",
            (q * scale).view(bs * self.n_heads, ch, length),
            (k * scale).view(bs * self.n_heads, ch, length),
        )  # More stable with f16 than dividing afterwards
        weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
        a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
        return a.reshape(bs, -1, length)

    @staticmethod
    def count_flops(model, _x, y):
        return count_flops_attn(model, _x, y)
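
# Example (illustrative sketch): the "new" layout packs the channel dimension as
# [N, 3 * H * C, T], i.e. q/k/v are split before heads (the opposite of the legacy
# ordering above, even though the total width is the same).
#
#   attn = QKVAttention(n_heads=4)
#   qkv = th.randn(2, 3 * 4 * 16, 128)  # 3 x (H=4 heads x C=16 channels per head), T=128 tokens
#   out = attn(qkv)  # -> [2, 64, 128] == [N, H * C, T]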


class UNetModel(nn.Module):
    """
    The full UNet model with attention and timestep embedding.
    :param in_channels: channels in the input Tensor.
    :param model_channels: base channel count for the model.
    :param out_channels: channels in the output Tensor.
    :param num_res_blocks: number of residual blocks per downsample.
    :param attention_resolutions: a collection of downsample rates at which
        attention will take place. May be a set, list, or tuple.
        For example, if this contains 4, then at 4x downsampling, attention
        will be used.
    :param dropout: the dropout probability.
    :param channel_mult: channel multiplier for each level of the UNet.
    :param conv_resample: if True, use learned convolutions for upsampling and
        downsampling.
    :param dims: determines if the signal is 1D, 2D, or 3D.
    :param num_classes: if specified (as an int), then this model will be
        class-conditional with `num_classes` classes.
    :param use_checkpoint: use gradient checkpointing to reduce memory usage.
    :param num_heads: the number of attention heads in each attention layer.
    :param num_head_channels: if specified, ignore num_heads and instead use
        a fixed channel width per attention head.
    :param num_heads_upsample: works with num_heads to set a different number
        of heads for upsampling. Deprecated.
    :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
    :param resblock_updown: use residual blocks for up/downsampling.
    :param use_new_attention_order: use a different attention pattern for potentially
        increased efficiency.
    """

    def __init__(
        self,
        image_size,
        in_channels,
        model_channels,
        out_channels,
        num_res_blocks,
        attention_resolutions,
        dropout=0,
        channel_mult=(1, 2, 4, 8),
        conv_resample=True,
        dims=2,
        num_classes=None,
        use_checkpoint=False,
        use_fp16=False,
        num_heads=-1,
        num_head_channels=-1,
        num_heads_upsample=-1,
        use_scale_shift_norm=False,
        resblock_updown=False,
        use_new_attention_order=False,
        use_spatial_transformer=False,  # custom transformer support
        transformer_depth=1,  # custom transformer support
        context_dim=None,  # custom transformer support
        n_embed=None,  # custom support for prediction of discrete ids into codebook of first stage vq model
        legacy=True,
        disable_self_attentions=None,
        num_attention_blocks=None,
        disable_middle_self_attn=False,
        use_linear_in_transformer=False,
    ):
        super().__init__()
        if use_spatial_transformer:
            assert (
                context_dim is not None
            ), "Fool!! You forgot to include the dimension of your cross-attention conditioning..."
        if context_dim is not None:
            assert (
                use_spatial_transformer
            ), "Fool!! You forgot to use the spatial transformer for your cross-attention conditioning..."
            from omegaconf.listconfig import ListConfig

            if type(context_dim) == ListConfig:
                context_dim = list(context_dim)
        if num_heads_upsample == -1:
            num_heads_upsample = num_heads
        if num_heads == -1:
            assert (
                num_head_channels != -1
            ), "Either num_heads or num_head_channels has to be set"
        if num_head_channels == -1:
            assert (
                num_heads != -1
            ), "Either num_heads or num_head_channels has to be set"
        self.image_size = image_size
        self.in_channels = in_channels
        self.model_channels = model_channels
        self.out_channels = out_channels
        if isinstance(num_res_blocks, int):
            self.num_res_blocks = len(channel_mult) * [num_res_blocks]
        else:
            if len(num_res_blocks) != len(channel_mult):
                raise ValueError(
                    "provide num_res_blocks either as an int (globally constant) or "
                    "as a list/tuple (per-level) with the same length as channel_mult"
                )
            self.num_res_blocks = num_res_blocks
        if disable_self_attentions is not None:
            # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
            assert len(disable_self_attentions) == len(channel_mult)
        if num_attention_blocks is not None:
            assert len(num_attention_blocks) == len(self.num_res_blocks)
            assert all(
                map(
                    lambda i: self.num_res_blocks[i] >= num_attention_blocks[i],
                    range(len(num_attention_blocks)),
                )
            )
            print(
                f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
                f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
                f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
                f"attention will still not be set."
            )
        self.use_fp16 = use_fp16
        self.attention_resolutions = attention_resolutions
        self.dropout = dropout
        self.channel_mult = channel_mult
        self.conv_resample = conv_resample
        self.num_classes = num_classes
        self.use_checkpoint = use_checkpoint
        self.dtype = th.float16 if use_fp16 else th.float32
        self.num_heads = num_heads
        self.num_head_channels = num_head_channels
        self.num_heads_upsample = num_heads_upsample
        self.predict_codebook_ids = n_embed is not None

        time_embed_dim = model_channels * 4
        self.time_embed = nn.Sequential(
            linear(model_channels, time_embed_dim),
            nn.SiLU(),
            linear(time_embed_dim, time_embed_dim),
        )

        if self.num_classes is not None:
            if isinstance(self.num_classes, int):
                self.label_emb = nn.Embedding(num_classes, time_embed_dim)
            elif self.num_classes == "continuous":
                print("setting up linear c_adm embedding layer")
                self.label_emb = nn.Linear(1, time_embed_dim)
            else:
                raise ValueError()

        self.input_blocks = nn.ModuleList(
            [
                TimestepEmbedSequential(
                    conv_nd(dims, in_channels, model_channels, 3, padding=1)
                )
            ]
        )
        self._feature_size = model_channels
        input_block_chans = [model_channels]
        ch = model_channels
        ds = 1
        for level, mult in enumerate(channel_mult):
            for nr in range(self.num_res_blocks[level]):
                layers = [
                    ResBlock(
                        ch,
                        time_embed_dim,
                        dropout,
                        out_channels=mult * model_channels,
                        dims=dims,
                        use_checkpoint=use_checkpoint,
                        use_scale_shift_norm=use_scale_shift_norm,
                    )
                ]
                ch = mult * model_channels
                if ds in attention_resolutions:
                    if num_head_channels == -1:
                        dim_head = ch // num_heads
                    else:
                        num_heads = ch // num_head_channels
                        dim_head = num_head_channels
                    if legacy:
                        # num_heads = 1
                        dim_head = (
                            ch // num_heads
                            if use_spatial_transformer
                            else num_head_channels
                        )
                    if exists(disable_self_attentions):
                        disabled_sa = disable_self_attentions[level]
                    else:
                        disabled_sa = False
                    if (
                        not exists(num_attention_blocks)
                        or nr < num_attention_blocks[level]
                    ):
                        layers.append(
                            AttentionBlock(
                                ch,
                                use_checkpoint=use_checkpoint,
                                num_heads=num_heads,
                                num_head_channels=dim_head,
                                use_new_attention_order=use_new_attention_order,
                            )
                            if not use_spatial_transformer
                            else SpatialTransformer(
                                ch,
                                num_heads,
                                dim_head,
                                depth=transformer_depth,
                                context_dim=context_dim,
                                disable_self_attn=disabled_sa,
                                use_linear=use_linear_in_transformer,
                                use_checkpoint=use_checkpoint,
                            )
                        )
                self.input_blocks.append(TimestepEmbedSequential(*layers))
                self._feature_size += ch
                input_block_chans.append(ch)
            if level != len(channel_mult) - 1:
                out_ch = ch
                self.input_blocks.append(
                    TimestepEmbedSequential(
                        ResBlock(
                            ch,
                            time_embed_dim,
                            dropout,
                            out_channels=out_ch,
                            dims=dims,
                            use_checkpoint=use_checkpoint,
                            use_scale_shift_norm=use_scale_shift_norm,
                            down=True,
                        )
                        if resblock_updown
                        else Downsample(
                            ch, conv_resample, dims=dims, out_channels=out_ch
                        )
                    )
                )
                ch = out_ch
                input_block_chans.append(ch)
                ds *= 2
                self._feature_size += ch
        if num_head_channels == -1:
            dim_head = ch // num_heads
        else:
            num_heads = ch // num_head_channels
            dim_head = num_head_channels
        if legacy:
            # num_heads = 1
            dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
        self.middle_block = TimestepEmbedSequential(
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
            AttentionBlock(
                ch,
                use_checkpoint=use_checkpoint,
                num_heads=num_heads,
                num_head_channels=dim_head,
                use_new_attention_order=use_new_attention_order,
            )
            if not use_spatial_transformer
            else SpatialTransformer(  # always uses a self-attn
                ch,
                num_heads,
                dim_head,
                depth=transformer_depth,
                context_dim=context_dim,
                disable_self_attn=disable_middle_self_attn,
                use_linear=use_linear_in_transformer,
                use_checkpoint=use_checkpoint,
            ),
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
        )
        self._feature_size += ch
        self.output_blocks = nn.ModuleList([])
        for level, mult in list(enumerate(channel_mult))[::-1]:
            for i in range(self.num_res_blocks[level] + 1):
                ich = input_block_chans.pop()
                layers = [
                    ResBlock(
                        ch + ich,
                        time_embed_dim,
                        dropout,
                        out_channels=model_channels * mult,
                        dims=dims,
                        use_checkpoint=use_checkpoint,
                        use_scale_shift_norm=use_scale_shift_norm,
                    )
                ]
                ch = model_channels * mult
                if ds in attention_resolutions:
                    if num_head_channels == -1:
                        dim_head = ch // num_heads
                    else:
                        num_heads = ch // num_head_channels
                        dim_head = num_head_channels
                    if legacy:
                        # num_heads = 1
                        dim_head = (
                            ch // num_heads
                            if use_spatial_transformer
                            else num_head_channels
                        )
                    if exists(disable_self_attentions):
                        disabled_sa = disable_self_attentions[level]
                    else:
                        disabled_sa = False
                    if (
                        not exists(num_attention_blocks)
                        or i < num_attention_blocks[level]
                    ):
                        layers.append(
                            AttentionBlock(
                                ch,
                                use_checkpoint=use_checkpoint,
                                num_heads=num_heads_upsample,
                                num_head_channels=dim_head,
                                use_new_attention_order=use_new_attention_order,
                            )
                            if not use_spatial_transformer
                            else SpatialTransformer(
                                ch,
                                num_heads,
                                dim_head,
                                depth=transformer_depth,
                                context_dim=context_dim,
                                disable_self_attn=disabled_sa,
                                use_linear=use_linear_in_transformer,
                                use_checkpoint=use_checkpoint,
                            )
                        )
                if level and i == self.num_res_blocks[level]:
                    out_ch = ch
                    layers.append(
                        ResBlock(
                            ch,
                            time_embed_dim,
                            dropout,
                            out_channels=out_ch,
                            dims=dims,
                            use_checkpoint=use_checkpoint,
                            use_scale_shift_norm=use_scale_shift_norm,
                            up=True,
                        )
                        if resblock_updown
                        else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
                    )
                    ds //= 2
                self.output_blocks.append(TimestepEmbedSequential(*layers))
                self._feature_size += ch
        self.out = nn.Sequential(
            normalization(ch),
            nn.SiLU(),
            zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
        )
        if self.predict_codebook_ids:
            self.id_predictor = nn.Sequential(
                normalization(ch),
                conv_nd(dims, model_channels, n_embed, 1),
                # nn.LogSoftmax(dim=1)  # change to cross_entropy and produce non-normalized logits
            )

    def convert_to_fp16(self):
        """
        Convert the torso of the model to float16.
        """
        self.input_blocks.apply(convert_module_to_f16)
        self.middle_block.apply(convert_module_to_f16)
        self.output_blocks.apply(convert_module_to_f16)

    def convert_to_fp32(self):
        """
        Convert the torso of the model to float32.
        """
        self.input_blocks.apply(convert_module_to_f32)
        self.middle_block.apply(convert_module_to_f32)
        self.output_blocks.apply(convert_module_to_f32)

    def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
        """
        Apply the model to an input batch.
        :param x: an [N x C x ...] Tensor of inputs.
        :param timesteps: a 1-D batch of timesteps.
        :param context: conditioning plugged in via crossattn
        :param y: an [N] Tensor of labels, if class-conditional.
        :return: an [N x C x ...] Tensor of outputs.
        """
        assert (y is not None) == (
            self.num_classes is not None
        ), "must specify y if and only if the model is class-conditional"
        hs = []
        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
        emb = self.time_embed(t_emb)
        if self.num_classes is not None:
            assert y.shape[0] == x.shape[0]
            emb = emb + self.label_emb(y)
        h = x.type(self.dtype)
        for module in self.input_blocks:
            h = module(h, emb, context)
            hs.append(h)
        h = self.middle_block(h, emb, context)
        for module in self.output_blocks:
            h = th.cat([h, hs.pop()], dim=1)
            h = module(h, emb, context)
        h = h.type(x.dtype)
        if self.predict_codebook_ids:
            return self.id_predictor(h)
        else:
            return self.out(h)
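
# Example (illustrative sketch, not part of the upstream module): building a small
# cross-attention UNet and running one denoising forward pass. All sizes below
# (latent 2x4x32x32, 64 base channels, 77x768 context) are arbitrary assumptions.
#
#   unet = UNetModel(
#       image_size=32,
#       in_channels=4,
#       model_channels=64,
#       out_channels=4,
#       num_res_blocks=1,
#       attention_resolutions=(2, 4),   # attend at 2x and 4x downsampling
#       channel_mult=(1, 2, 4),
#       num_head_channels=32,
#       use_spatial_transformer=True,
#       transformer_depth=1,
#       context_dim=768,
#   )
#   x = th.randn(2, 4, 32, 32)               # noisy latents
#   t = th.randint(0, 1000, (2,))            # diffusion timesteps
#   ctx = th.randn(2, 77, 768)               # cross-attention conditioning (e.g. text embeddings)
#   eps = unet(x, timesteps=t, context=ctx)  # -> [2, 4, 32, 32]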