# pytorch_diffusion + derived encoder decoder
import math

import numpy as np
import torch
import torch.nn as nn


def get_timestep_embedding(timesteps, embedding_dim):
    """
    Build sinusoidal timestep embeddings.

    This matches the implementation in Denoising Diffusion Probabilistic Models
    (taken from Fairseq) and in tensor2tensor, but differs slightly from the
    description in Section 3.5 of "Attention Is All You Need".
    """
    assert len(timesteps.shape) == 1

    half_dim = embedding_dim // 2
    emb = math.log(10000) / (half_dim - 1)
    emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
    emb = emb.to(device=timesteps.device)
    emb = timesteps.float()[:, None] * emb[None, :]
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
    if embedding_dim % 2 == 1:  # zero pad
        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
    return emb
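
# Illustrative usage sketch (toy values, not from the original file): for a 1-D
# batch of integer timesteps, the embedding has shape (batch, embedding_dim).
#
#   t = torch.tensor([0, 10, 999])           # (3,) LongTensor of timesteps
#   emb = get_timestep_embedding(t, 128)     # -> torch.Size([3, 128])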


def nonlinearity(x):
    # swish
    return x * torch.sigmoid(x)


def Normalize(in_channels, num_groups=32):
    return torch.nn.GroupNorm(
        num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True
    )


class Upsample(nn.Module):
    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            self.conv = torch.nn.Conv2d(
                in_channels, in_channels, kernel_size=3, stride=1, padding=1
            )

    def forward(self, x):
        x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
        if self.with_conv:
            x = self.conv(x)
        return x


class Downsample(nn.Module):
    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            # no asymmetric padding in torch conv, must do it ourselves
            self.conv = torch.nn.Conv2d(
                in_channels, in_channels, kernel_size=3, stride=2, padding=0
            )

    def forward(self, x):
        if self.with_conv:
            pad = (0, 1, 0, 1)
            x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
            x = self.conv(x)
        else:
            x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
        return x
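
# Illustrative shape sketch (toy shapes): both modules change only spatial size,
# by a fixed factor of 2 each way.
#
#   x = torch.randn(1, 64, 32, 32)
#   Upsample(64, with_conv=True)(x).shape    # -> torch.Size([1, 64, 64, 64])
#   Downsample(64, with_conv=True)(x).shape  # -> torch.Size([1, 64, 16, 16])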


class ResnetBlock(nn.Module):
    def __init__(
        self,
        *,
        in_channels,
        out_channels=None,
        conv_shortcut=False,
        dropout,
        temb_channels=512,
    ):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        self.norm1 = Normalize(in_channels)
        self.conv1 = torch.nn.Conv2d(
            in_channels, out_channels, kernel_size=3, stride=1, padding=1
        )
        if temb_channels > 0:
            self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
        self.norm2 = Normalize(out_channels)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = torch.nn.Conv2d(
            out_channels, out_channels, kernel_size=3, stride=1, padding=1
        )
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                self.conv_shortcut = torch.nn.Conv2d(
                    in_channels, out_channels, kernel_size=3, stride=1, padding=1
                )
            else:
                self.nin_shortcut = torch.nn.Conv2d(
                    in_channels, out_channels, kernel_size=1, stride=1, padding=0
                )

    def forward(self, x, temb):
        h = x
        h = self.norm1(h)
        h = nonlinearity(h)
        h = self.conv1(h)

        if temb is not None:
            h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]

        h = self.norm2(h)
        h = nonlinearity(h)
        h = self.dropout(h)
        h = self.conv2(h)

        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                x = self.conv_shortcut(x)
            else:
                x = self.nin_shortcut(x)

        return x + h
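
# Illustrative usage sketch (toy shapes): one block conditioned on a timestep
# embedding, and one without (temb_channels=0 skips temb_proj entirely).
#
#   blk = ResnetBlock(in_channels=64, out_channels=128, dropout=0.0, temb_channels=512)
#   h = blk(torch.randn(1, 64, 32, 32), torch.randn(1, 512))  # -> (1, 128, 32, 32)
#
#   blk = ResnetBlock(in_channels=64, dropout=0.0, temb_channels=0)
#   h = blk(torch.randn(1, 64, 32, 32), None)                 # no time conditioning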


class AttnBlock(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = Normalize(in_channels)
        self.q = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.k = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.v = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.proj_out = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )

    def forward(self, x):
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # compute attention
        b, c, h, w = q.shape
        q = q.reshape(b, c, h * w)
        q = q.permute(0, 2, 1)  # b,hw,c
        k = k.reshape(b, c, h * w)  # b,c,hw
        w_ = torch.bmm(q, k)  # b,hw,hw    w[b,i,j] = sum_c q[b,i,c] k[b,c,j]
        w_ = w_ * (int(c) ** (-0.5))
        w_ = torch.nn.functional.softmax(w_, dim=2)

        # attend to values
        v = v.reshape(b, c, h * w)
        w_ = w_.permute(0, 2, 1)  # b,hw,hw (first hw of k, second of q)
        h_ = torch.bmm(v, w_)  # b,c,hw (hw of q)    h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
        h_ = h_.reshape(b, c, h, w)

        h_ = self.proj_out(h_)

        return x + h_
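
# Illustrative sketch (toy shapes): the bmm/softmax above is single-head
# self-attention over the h*w spatial positions, i.e. softmax(Q K^T / sqrt(c)) V
# with one token per pixel, followed by a residual add.
#
#   attn = AttnBlock(64)
#   y = attn(torch.randn(1, 64, 16, 16))  # -> same shape as the input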


class AttnBlock2_0(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = Normalize(in_channels)
        self.q = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.k = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.v = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.proj_out = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )

    def forward(self, x):
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # compute attention: flatten the spatial dims so each pixel is a token,
        # giving SDPA tensors of shape (batch, num_heads=1, seq_len=h*w, head_dim=c).
        # (A bare transpose(1, 2) on the 4-D maps would attend over channels
        # instead of pixels, so the h*w flattening here is required.)
        b, c, h, w = q.shape
        q = q.reshape(b, 1, c, h * w).transpose(2, 3)  # b,1,hw,c
        k = k.reshape(b, 1, c, h * w).transpose(2, 3)  # b,1,hw,c
        v = v.reshape(b, 1, c, h * w).transpose(2, 3)  # b,1,hw,c

        hidden_states = torch.nn.functional.scaled_dot_product_attention(
            q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False
        )

        hidden_states = hidden_states.transpose(2, 3).reshape(b, c, h, w)
        hidden_states = hidden_states.to(q.dtype)

        h_ = self.proj_out(hidden_states)

        return x + h_


def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
    assert attn_type in [
        "vanilla",
        "vanilla-xformers",
        "memory-efficient-cross-attn",
        "linear",
        "none",
    ], f"attn_type {attn_type} unknown"
    assert attn_kwargs is None
    # Note: every accepted attn_type gets vanilla spatial self-attention here;
    # the SDPA-backed block is used whenever torch provides it (PyTorch >= 2.0).
    if hasattr(torch.nn.functional, "scaled_dot_product_attention"):
        # print("Using torch.nn.functional.scaled_dot_product_attention")
        return AttnBlock2_0(in_channels)
    return AttnBlock(in_channels)
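
# Illustrative sketch: on PyTorch 2.x this dispatches to the SDPA-backed block.
#
#   attn = make_attn(128, attn_type="vanilla")
#   isinstance(attn, (AttnBlock, AttnBlock2_0))  # -> True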


class Model(nn.Module):
    def __init__(
        self,
        *,
        ch,
        out_ch,
        ch_mult=(1, 2, 4, 8),
        num_res_blocks,
        attn_resolutions,
        dropout=0.0,
        resamp_with_conv=True,
        in_channels,
        resolution,
        use_timestep=True,
        use_linear_attn=False,
        attn_type="vanilla",
    ):
        super().__init__()
        if use_linear_attn:
            attn_type = "linear"
        self.ch = ch
        self.temb_ch = self.ch * 4
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels

        self.use_timestep = use_timestep
        if self.use_timestep:
            # timestep embedding
            self.temb = nn.Module()
            self.temb.dense = nn.ModuleList(
                [
                    torch.nn.Linear(self.ch, self.temb_ch),
                    torch.nn.Linear(self.temb_ch, self.temb_ch),
                ]
            )

        # downsampling
        self.conv_in = torch.nn.Conv2d(
            in_channels, self.ch, kernel_size=3, stride=1, padding=1
        )

        curr_res = resolution
        in_ch_mult = (1,) + tuple(ch_mult)
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = ch * in_ch_mult[i_level]
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks):
                block.append(
                    ResnetBlock(
                        in_channels=block_in,
                        out_channels=block_out,
                        temb_channels=self.temb_ch,
                        dropout=dropout,
                    )
                )
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(make_attn(block_in, attn_type=attn_type))
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions - 1:
                down.downsample = Downsample(block_in, resamp_with_conv)
                curr_res = curr_res // 2
            self.down.append(down)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout,
        )
        self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
        self.mid.block_2 = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout,
        )

        # upsampling
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch * ch_mult[i_level]
            skip_in = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks + 1):
                if i_block == self.num_res_blocks:
                    skip_in = ch * in_ch_mult[i_level]
                block.append(
                    ResnetBlock(
                        in_channels=block_in + skip_in,
                        out_channels=block_out,
                        temb_channels=self.temb_ch,
                        dropout=dropout,
                    )
                )
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(make_attn(block_in, attn_type=attn_type))
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                up.upsample = Upsample(block_in, resamp_with_conv)
                curr_res = curr_res * 2
            self.up.insert(0, up)  # prepend to get consistent order

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(
            block_in, out_ch, kernel_size=3, stride=1, padding=1
        )

    def forward(self, x, t=None, context=None):
        # assert x.shape[2] == x.shape[3] == self.resolution
        if context is not None:
            # assume aligned context, cat along channel axis
            x = torch.cat((x, context), dim=1)
        if self.use_timestep:
            # timestep embedding
            assert t is not None
            temb = get_timestep_embedding(t, self.ch)
            temb = self.temb.dense[0](temb)
            temb = nonlinearity(temb)
            temb = self.temb.dense[1](temb)
        else:
            temb = None

        # downsampling
        hs = [self.conv_in(x)]
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](hs[-1], temb)
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
                hs.append(h)
            if i_level != self.num_resolutions - 1:
                hs.append(self.down[i_level].downsample(hs[-1]))

        # middle
        h = hs[-1]
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        # upsampling
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks + 1):
                h = self.up[i_level].block[i_block](
                    torch.cat([h, hs.pop()], dim=1), temb
                )
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h)
            if i_level != 0:
                h = self.up[i_level].upsample(h)

        # end
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h

    def get_last_layer(self):
        return self.conv_out.weight
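
# Illustrative construction sketch (toy hyperparameters): a small timestep-
# conditioned UNet for 32x32 inputs, with attention at the 16x16 resolution.
#
#   model = Model(ch=64, out_ch=3, ch_mult=(1, 2), num_res_blocks=2,
#                 attn_resolutions=[16], in_channels=3, resolution=32)
#   eps = model(torch.randn(2, 3, 32, 32), t=torch.tensor([10, 500]))
#   eps.shape  # -> torch.Size([2, 3, 32, 32])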


class Encoder(nn.Module):
    def __init__(
        self,
        *,
        ch,
        out_ch,
        ch_mult=(1, 2, 4, 8),
        num_res_blocks,
        attn_resolutions,
        dropout=0.0,
        resamp_with_conv=True,
        in_channels,
        resolution,
        z_channels,
        double_z=True,
        use_linear_attn=False,
        attn_type="vanilla",
        **ignore_kwargs,
    ):
        super().__init__()
        if use_linear_attn:
            attn_type = "linear"
        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels

        # downsampling
        self.conv_in = torch.nn.Conv2d(
            in_channels, self.ch, kernel_size=3, stride=1, padding=1
        )

        curr_res = resolution
        in_ch_mult = (1,) + tuple(ch_mult)
        self.in_ch_mult = in_ch_mult
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = ch * in_ch_mult[i_level]
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks):
                block.append(
                    ResnetBlock(
                        in_channels=block_in,
                        out_channels=block_out,
                        temb_channels=self.temb_ch,
                        dropout=dropout,
                    )
                )
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(make_attn(block_in, attn_type=attn_type))
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions - 1:
                down.downsample = Downsample(block_in, resamp_with_conv)
                curr_res = curr_res // 2
            self.down.append(down)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout,
        )
        self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
        self.mid.block_2 = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout,
        )

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(
            block_in,
            2 * z_channels if double_z else z_channels,
            kernel_size=3,
            stride=1,
            padding=1,
        )

    def forward(self, x):
        # timestep embedding
        temb = None

        # downsampling
        hs = [self.conv_in(x)]
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](hs[-1], temb)
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
                hs.append(h)
            if i_level != self.num_resolutions - 1:
                hs.append(self.down[i_level].downsample(hs[-1]))

        # middle
        h = hs[-1]
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        # end
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h
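
# Illustrative sketch (toy hyperparameters): with ch_mult of length n the encoder
# downsamples by 2**(n-1); double_z=True doubles z_channels so a VAE can split the
# output into mean and logvar.
#
#   enc = Encoder(ch=64, out_ch=3, ch_mult=(1, 2, 4), num_res_blocks=2,
#                 attn_resolutions=[], in_channels=3, resolution=64, z_channels=4)
#   z = enc(torch.randn(1, 3, 64, 64))  # -> torch.Size([1, 8, 16, 16])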


class Decoder(nn.Module):
    def __init__(
        self,
        *,
        ch,
        out_ch,
        ch_mult=(1, 2, 4, 8),
        num_res_blocks,
        attn_resolutions,
        dropout=0.0,
        resamp_with_conv=True,
        in_channels,
        resolution,
        z_channels,
        give_pre_end=False,
        tanh_out=False,
        use_linear_attn=False,
        attn_type="vanilla",
        **ignorekwargs,
    ):
        super().__init__()
        if use_linear_attn:
            attn_type = "linear"
        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.give_pre_end = give_pre_end
        self.tanh_out = tanh_out

        # compute in_ch_mult, block_in and curr_res at lowest res
        in_ch_mult = (1,) + tuple(ch_mult)
        block_in = ch * ch_mult[self.num_resolutions - 1]
        curr_res = resolution // 2 ** (self.num_resolutions - 1)
        self.z_shape = (1, z_channels, curr_res, curr_res)
        print(
            "Working with z of shape {} = {} dimensions.".format(
                self.z_shape, np.prod(self.z_shape)
            )
        )

        # z to block_in
        self.conv_in = torch.nn.Conv2d(
            z_channels, block_in, kernel_size=3, stride=1, padding=1
        )

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout,
        )
        self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
        self.mid.block_2 = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout,
        )

        # upsampling
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks + 1):
                block.append(
                    ResnetBlock(
                        in_channels=block_in,
                        out_channels=block_out,
                        temb_channels=self.temb_ch,
                        dropout=dropout,
                    )
                )
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(make_attn(block_in, attn_type=attn_type))
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                up.upsample = Upsample(block_in, resamp_with_conv)
                curr_res = curr_res * 2
            self.up.insert(0, up)  # prepend to get consistent order

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(
            block_in, out_ch, kernel_size=3, stride=1, padding=1
        )

    def forward(self, z):
        # assert z.shape[1:] == self.z_shape[1:]
        self.last_z_shape = z.shape

        # timestep embedding
        temb = None

        # z to block_in
        h = self.conv_in(z)

        # middle
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        # upsampling
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks + 1):
                h = self.up[i_level].block[i_block](h, temb)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h)
            if i_level != 0:
                h = self.up[i_level].upsample(h)

        # end
        if self.give_pre_end:
            return h

        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        if self.tanh_out:
            h = torch.tanh(h)
        return h
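
# Illustrative sketch (toy hyperparameters): the decoder mirrors the encoder and
# upsamples the latent by 2**(n-1) for ch_mult of length n.
#
#   dec = Decoder(ch=64, out_ch=3, ch_mult=(1, 2, 4), num_res_blocks=2,
#                 attn_resolutions=[], in_channels=3, resolution=64, z_channels=4)
#   x = dec(torch.randn(1, 4, 16, 16))  # -> torch.Size([1, 3, 64, 64])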


class SimpleDecoder(nn.Module):
    def __init__(self, in_channels, out_channels, *args, **kwargs):
        super().__init__()
        self.model = nn.ModuleList(
            [
                nn.Conv2d(in_channels, in_channels, 1),
                ResnetBlock(
                    in_channels=in_channels,
                    out_channels=2 * in_channels,
                    temb_channels=0,
                    dropout=0.0,
                ),
                ResnetBlock(
                    in_channels=2 * in_channels,
                    out_channels=4 * in_channels,
                    temb_channels=0,
                    dropout=0.0,
                ),
                ResnetBlock(
                    in_channels=4 * in_channels,
                    out_channels=2 * in_channels,
                    temb_channels=0,
                    dropout=0.0,
                ),
                nn.Conv2d(2 * in_channels, in_channels, 1),
                Upsample(in_channels, with_conv=True),
            ]
        )
        # end
        self.norm_out = Normalize(in_channels)
        self.conv_out = torch.nn.Conv2d(
            in_channels, out_channels, kernel_size=3, stride=1, padding=1
        )

    def forward(self, x):
        for i, layer in enumerate(self.model):
            if i in [1, 2, 3]:
                x = layer(x, None)
            else:
                x = layer(x)

        h = self.norm_out(x)
        h = nonlinearity(h)
        x = self.conv_out(h)
        return x


class UpsampleDecoder(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        ch,
        num_res_blocks,
        resolution,
        ch_mult=(2, 2),
        dropout=0.0,
    ):
        super().__init__()
        # upsampling
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        block_in = in_channels
        curr_res = resolution // 2 ** (self.num_resolutions - 1)
        self.res_blocks = nn.ModuleList()
        self.upsample_blocks = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            res_block = []
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks + 1):
                res_block.append(
                    ResnetBlock(
                        in_channels=block_in,
                        out_channels=block_out,
                        temb_channels=self.temb_ch,
                        dropout=dropout,
                    )
                )
                block_in = block_out
            self.res_blocks.append(nn.ModuleList(res_block))
            if i_level != self.num_resolutions - 1:
                self.upsample_blocks.append(Upsample(block_in, True))
                curr_res = curr_res * 2

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(
            block_in, out_channels, kernel_size=3, stride=1, padding=1
        )

    def forward(self, x):
        # upsampling
        h = x
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks + 1):
                h = self.res_blocks[i_level][i_block](h, None)
            if i_level != self.num_resolutions - 1:
                h = self.upsample_blocks[i_level](h)
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h


class LatentRescaler(nn.Module):
    def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2):
        super().__init__()
        # residual block, interpolate, residual block
        self.factor = factor
        self.conv_in = nn.Conv2d(
            in_channels, mid_channels, kernel_size=3, stride=1, padding=1
        )
        self.res_block1 = nn.ModuleList(
            [
                ResnetBlock(
                    in_channels=mid_channels,
                    out_channels=mid_channels,
                    temb_channels=0,
                    dropout=0.0,
                )
                for _ in range(depth)
            ]
        )
        self.attn = AttnBlock(mid_channels)
        self.res_block2 = nn.ModuleList(
            [
                ResnetBlock(
                    in_channels=mid_channels,
                    out_channels=mid_channels,
                    temb_channels=0,
                    dropout=0.0,
                )
                for _ in range(depth)
            ]
        )
        self.conv_out = nn.Conv2d(
            mid_channels,
            out_channels,
            kernel_size=1,
        )

    def forward(self, x):
        x = self.conv_in(x)
        for block in self.res_block1:
            x = block(x, None)
        x = torch.nn.functional.interpolate(
            x,
            size=(
                int(round(x.shape[2] * self.factor)),
                int(round(x.shape[3] * self.factor)),
            ),
        )
        x = self.attn(x)
        for block in self.res_block2:
            x = block(x, None)
        x = self.conv_out(x)
        return x
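
# Illustrative sketch (toy values): rescales the latent spatially by `factor`
# while mapping in_channels -> out_channels.
#
#   r = LatentRescaler(factor=0.5, in_channels=8, mid_channels=16, out_channels=4)
#   r(torch.randn(1, 8, 32, 32)).shape  # -> torch.Size([1, 4, 16, 16])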


class MergedRescaleEncoder(nn.Module):
    def __init__(
        self,
        in_channels,
        ch,
        resolution,
        out_ch,
        num_res_blocks,
        attn_resolutions,
        dropout=0.0,
        resamp_with_conv=True,
        ch_mult=(1, 2, 4, 8),
        rescale_factor=1.0,
        rescale_module_depth=1,
    ):
        super().__init__()
        intermediate_chn = ch * ch_mult[-1]
        self.encoder = Encoder(
            in_channels=in_channels,
            num_res_blocks=num_res_blocks,
            ch=ch,
            ch_mult=ch_mult,
            z_channels=intermediate_chn,
            double_z=False,
            resolution=resolution,
            attn_resolutions=attn_resolutions,
            dropout=dropout,
            resamp_with_conv=resamp_with_conv,
            out_ch=None,
        )
        self.rescaler = LatentRescaler(
            factor=rescale_factor,
            in_channels=intermediate_chn,
            mid_channels=intermediate_chn,
            out_channels=out_ch,
            depth=rescale_module_depth,
        )

    def forward(self, x):
        x = self.encoder(x)
        x = self.rescaler(x)
        return x


class MergedRescaleDecoder(nn.Module):
    def __init__(
        self,
        z_channels,
        out_ch,
        resolution,
        num_res_blocks,
        attn_resolutions,
        ch,
        ch_mult=(1, 2, 4, 8),
        dropout=0.0,
        resamp_with_conv=True,
        rescale_factor=1.0,
        rescale_module_depth=1,
    ):
        super().__init__()
        tmp_chn = z_channels * ch_mult[-1]
        self.decoder = Decoder(
            out_ch=out_ch,
            z_channels=tmp_chn,
            attn_resolutions=attn_resolutions,
            dropout=dropout,
            resamp_with_conv=resamp_with_conv,
            in_channels=None,
            num_res_blocks=num_res_blocks,
            ch_mult=ch_mult,
            resolution=resolution,
            ch=ch,
        )
        self.rescaler = LatentRescaler(
            factor=rescale_factor,
            in_channels=z_channels,
            mid_channels=tmp_chn,
            out_channels=tmp_chn,
            depth=rescale_module_depth,
        )

    def forward(self, x):
        x = self.rescaler(x)
        x = self.decoder(x)
        return x


class Upsampler(nn.Module):
    def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2):
        super().__init__()
        assert out_size >= in_size
        num_blocks = int(np.log2(out_size // in_size)) + 1
        factor_up = 1.0 + (out_size % in_size)
        print(
            f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}"
        )
        self.rescaler = LatentRescaler(
            factor=factor_up,
            in_channels=in_channels,
            mid_channels=2 * in_channels,
            out_channels=in_channels,
        )
        self.decoder = Decoder(
            out_ch=out_channels,
            resolution=out_size,
            z_channels=in_channels,
            num_res_blocks=2,
            attn_resolutions=[],
            in_channels=None,
            ch=in_channels,
            ch_mult=[ch_mult for _ in range(num_blocks)],
        )

    def forward(self, x):
        x = self.rescaler(x)
        x = self.decoder(x)
        return x


class Resize(nn.Module):
    def __init__(self, in_channels=None, learned=False, mode="bilinear"):
        super().__init__()
        self.with_conv = learned
        self.mode = mode
        if self.with_conv:
            print(
                f"Note: {self.__class__.__name__} uses learned downsampling and will ignore the fixed {mode} mode"
            )
            raise NotImplementedError()
            assert in_channels is not None
            # no asymmetric padding in torch conv, must do it ourselves
            self.conv = torch.nn.Conv2d(
                in_channels, in_channels, kernel_size=4, stride=2, padding=1
            )

    def forward(self, x, scale_factor=1.0):
        if scale_factor == 1.0:
            return x
        else:
            x = torch.nn.functional.interpolate(
                x, mode=self.mode, align_corners=False, scale_factor=scale_factor
            )
        return x
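

# Illustrative smoke test (toy hyperparameters, guarded so importing this module
# stays side-effect free): round-trips a random image through an Encoder/Decoder
# pair. The Encoder's double_z=True default doubles z_channels, so the Decoder is
# built with the doubled channel count.
if __name__ == "__main__":
    _enc = Encoder(
        ch=32, out_ch=3, ch_mult=(1, 2), num_res_blocks=1, attn_resolutions=[],
        in_channels=3, resolution=32, z_channels=4,
    )
    _dec = Decoder(
        ch=32, out_ch=3, ch_mult=(1, 2), num_res_blocks=1, attn_resolutions=[],
        in_channels=3, resolution=32, z_channels=8,
    )
    _x = torch.randn(1, 3, 32, 32)
    _z = _enc(_x)    # -> (1, 8, 16, 16): 2 * z_channels from double_z
    _rec = _dec(_z)  # -> (1, 3, 32, 32)
    print(_x.shape, _z.shape, _rec.shape)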