# pytorch_diffusion + derived encoder decoder
import math

import numpy as np
import torch
import torch.nn as nn


def get_timestep_embedding(timesteps, embedding_dim):
    """
    Build sinusoidal timestep embeddings, as in Denoising Diffusion
    Probabilistic Models (ported from the Fairseq implementation).
    This matches the implementation in tensor2tensor, but differs slightly
    from the description in Section 3.5 of "Attention Is All You Need".
    """
    assert len(timesteps.shape) == 1

    half_dim = embedding_dim // 2
    emb = math.log(10000) / (half_dim - 1)
    emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
    emb = emb.to(device=timesteps.device)
    emb = timesteps.float()[:, None] * emb[None, :]
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
    if embedding_dim % 2 == 1:  # zero pad the odd final dimension
        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
    return emb
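
# Example (a minimal shape check; an addition, not part of the original module):
#   t = torch.arange(4)                       # (4,)
#   emb = get_timestep_embedding(t, 128)      # -> (4, 128): sin half, cos half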


def nonlinearity(x):
    # swish / SiLU
    return x * torch.sigmoid(x)


def Normalize(in_channels, num_groups=32):
    return torch.nn.GroupNorm(
        num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True
    )


class Upsample(nn.Module):
    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            self.conv = torch.nn.Conv2d(
                in_channels, in_channels, kernel_size=3, stride=1, padding=1
            )

    def forward(self, x):
        x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
        if self.with_conv:
            x = self.conv(x)
        return x


class Downsample(nn.Module):
    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            # no asymmetric padding in torch conv, must do it ourselves
            self.conv = torch.nn.Conv2d(
                in_channels, in_channels, kernel_size=3, stride=2, padding=0
            )

    def forward(self, x):
        if self.with_conv:
            pad = (0, 1, 0, 1)
            x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
            x = self.conv(x)
        else:
            x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
        return x
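
# Example (shape check; an addition): the asymmetric right/bottom padding plus
# the stride-2 conv halves the spatial size, e.g.
#   down = Downsample(64, with_conv=True)
#   y = down(torch.randn(1, 64, 32, 32))      # -> (1, 64, 16, 16)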


class ResnetBlock(nn.Module):
    def __init__(
        self,
        *,
        in_channels,
        out_channels=None,
        conv_shortcut=False,
        dropout,
        temb_channels=512,
    ):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        self.norm1 = Normalize(in_channels)
        self.conv1 = torch.nn.Conv2d(
            in_channels, out_channels, kernel_size=3, stride=1, padding=1
        )
        if temb_channels > 0:
            self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
        self.norm2 = Normalize(out_channels)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = torch.nn.Conv2d(
            out_channels, out_channels, kernel_size=3, stride=1, padding=1
        )
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                self.conv_shortcut = torch.nn.Conv2d(
                    in_channels, out_channels, kernel_size=3, stride=1, padding=1
                )
            else:
                self.nin_shortcut = torch.nn.Conv2d(
                    in_channels, out_channels, kernel_size=1, stride=1, padding=0
                )

    def forward(self, x, temb):
        h = x
        h = self.norm1(h)
        h = nonlinearity(h)
        h = self.conv1(h)

        if temb is not None:
            h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]

        h = self.norm2(h)
        h = nonlinearity(h)
        h = self.dropout(h)
        h = self.conv2(h)

        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                x = self.conv_shortcut(x)
            else:
                x = self.nin_shortcut(x)

        return x + h
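
# Example (an addition): when in_channels != out_channels the residual path
# goes through the 1x1 nin_shortcut (or a 3x3 conv when conv_shortcut=True):
#   block = ResnetBlock(in_channels=64, out_channels=128, dropout=0.0)
#   temb = torch.randn(1, 512)                # temb_channels defaults to 512
#   y = block(torch.randn(1, 64, 32, 32), temb)   # -> (1, 128, 32, 32)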


class AttnBlock(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = Normalize(in_channels)
        self.q = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.k = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.v = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.proj_out = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )

    def forward(self, x):
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # compute attention
        b, c, h, w = q.shape
        q = q.reshape(b, c, h * w)
        q = q.permute(0, 2, 1)  # b,hw,c
        k = k.reshape(b, c, h * w)  # b,c,hw
        w_ = torch.bmm(q, k)  # b,hw,hw    w[b,i,j] = sum_c q[b,i,c] k[b,c,j]
        w_ = w_ * (int(c) ** (-0.5))
        w_ = torch.nn.functional.softmax(w_, dim=2)

        # attend to values
        v = v.reshape(b, c, h * w)
        w_ = w_.permute(0, 2, 1)  # b,hw,hw (first hw of k, second of q)
        h_ = torch.bmm(v, w_)  # b,c,hw    h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
        h_ = h_.reshape(b, c, h, w)

        h_ = self.proj_out(h_)

        return x + h_
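
# Example (an addition): single-head self-attention over all h*w positions, so
# cost grows quadratically with spatial size; output keeps the input shape:
#   attn = AttnBlock(64)
#   y = attn(torch.randn(1, 64, 16, 16))      # -> (1, 64, 16, 16)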


class AttnBlock2_0(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = Normalize(in_channels)
        self.q = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.k = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.v = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.proj_out = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )

    def forward(self, x):
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # compute attention: flatten the spatial dims into the sequence axis
        # and use a single head of dimension c, so SDPA matches AttnBlock
        b, c, h, w = q.shape
        q = q.reshape(b, 1, c, h * w).transpose(2, 3)  # b, 1, hw, c
        k = k.reshape(b, 1, c, h * w).transpose(2, 3)  # b, 1, hw, c
        v = v.reshape(b, 1, c, h * w).transpose(2, 3)  # b, 1, hw, c
        hidden_states = torch.nn.functional.scaled_dot_product_attention(
            q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False
        )
        hidden_states = hidden_states.transpose(2, 3).reshape(b, c, h, w)
        hidden_states = hidden_states.to(q.dtype)

        h_ = self.proj_out(hidden_states)
        return x + h_


def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
    assert attn_type in [
        "vanilla",
        "vanilla-xformers",
        "memory-efficient-cross-attn",
        "linear",
        "none",
    ], f"attn_type {attn_type} unknown"
    assert attn_kwargs is None
    if hasattr(torch.nn.functional, "scaled_dot_product_attention"):
        return AttnBlock2_0(in_channels)
    return AttnBlock(in_channels)
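
# Example (an addition): in this derived version make_attn validates attn_type
# but always builds a full attention block, picking AttnBlock2_0 when PyTorch
# (>= 2.0) exposes scaled_dot_product_attention and AttnBlock otherwise:
#   attn = make_attn(512, attn_type="vanilla")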


class Model(nn.Module):
    def __init__(
        self,
        *,
        ch,
        out_ch,
        ch_mult=(1, 2, 4, 8),
        num_res_blocks,
        attn_resolutions,
        dropout=0.0,
        resamp_with_conv=True,
        in_channels,
        resolution,
        use_timestep=True,
        use_linear_attn=False,
        attn_type="vanilla",
    ):
        super().__init__()
        if use_linear_attn:
            attn_type = "linear"
        self.ch = ch
        self.temb_ch = self.ch * 4
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels

        self.use_timestep = use_timestep
        if self.use_timestep:
            # timestep embedding
            self.temb = nn.Module()
            self.temb.dense = nn.ModuleList(
                [
                    torch.nn.Linear(self.ch, self.temb_ch),
                    torch.nn.Linear(self.temb_ch, self.temb_ch),
                ]
            )

        # downsampling
        self.conv_in = torch.nn.Conv2d(
            in_channels, self.ch, kernel_size=3, stride=1, padding=1
        )

        curr_res = resolution
        in_ch_mult = (1,) + tuple(ch_mult)
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = ch * in_ch_mult[i_level]
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks):
                block.append(
                    ResnetBlock(
                        in_channels=block_in,
                        out_channels=block_out,
                        temb_channels=self.temb_ch,
                        dropout=dropout,
                    )
                )
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(make_attn(block_in, attn_type=attn_type))
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions - 1:
                down.downsample = Downsample(block_in, resamp_with_conv)
                curr_res = curr_res // 2
            self.down.append(down)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout,
        )
        self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
        self.mid.block_2 = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout,
        )

        # upsampling
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch * ch_mult[i_level]
            skip_in = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks + 1):
                if i_block == self.num_res_blocks:
                    skip_in = ch * in_ch_mult[i_level]
                block.append(
                    ResnetBlock(
                        in_channels=block_in + skip_in,
                        out_channels=block_out,
                        temb_channels=self.temb_ch,
                        dropout=dropout,
                    )
                )
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(make_attn(block_in, attn_type=attn_type))
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                up.upsample = Upsample(block_in, resamp_with_conv)
                curr_res = curr_res * 2
            self.up.insert(0, up)  # prepend to get consistent order

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(
            block_in, out_ch, kernel_size=3, stride=1, padding=1
        )

    def forward(self, x, t=None, context=None):
        # assert x.shape[2] == x.shape[3] == self.resolution
        if context is not None:
            # assume aligned context, cat along channel axis
            x = torch.cat((x, context), dim=1)
        if self.use_timestep:
            # timestep embedding
            assert t is not None
            temb = get_timestep_embedding(t, self.ch)
            temb = self.temb.dense[0](temb)
            temb = nonlinearity(temb)
            temb = self.temb.dense[1](temb)
        else:
            temb = None

        # downsampling
        hs = [self.conv_in(x)]
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](hs[-1], temb)
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
                hs.append(h)
            if i_level != self.num_resolutions - 1:
                hs.append(self.down[i_level].downsample(hs[-1]))

        # middle
        h = hs[-1]
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        # upsampling
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks + 1):
                h = self.up[i_level].block[i_block](
                    torch.cat([h, hs.pop()], dim=1), temb
                )
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h)
            if i_level != 0:
                h = self.up[i_level].upsample(h)

        # end
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h

    def get_last_layer(self):
        return self.conv_out.weight
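
# Example (an addition; hypothetical DDPM-style config, values illustrative,
# not taken from a specific checkpoint): the UNet maps an image batch and
# integer timesteps to a prediction of the same spatial size:
#   model = Model(ch=128, out_ch=3, ch_mult=(1, 2, 2, 2), num_res_blocks=2,
#                 attn_resolutions=[16], dropout=0.1, in_channels=3,
#                 resolution=32, use_timestep=True)
#   x = torch.randn(4, 3, 32, 32)
#   t = torch.randint(0, 1000, (4,))
#   eps = model(x, t)                         # -> (4, 3, 32, 32)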


class Encoder(nn.Module):
    def __init__(
        self,
        *,
        ch,
        out_ch,
        ch_mult=(1, 2, 4, 8),
        num_res_blocks,
        attn_resolutions,
        dropout=0.0,
        resamp_with_conv=True,
        in_channels,
        resolution,
        z_channels,
        double_z=True,
        use_linear_attn=False,
        attn_type="vanilla",
        **ignore_kwargs,
    ):
        super().__init__()
        if use_linear_attn:
            attn_type = "linear"
        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels

        # downsampling
        self.conv_in = torch.nn.Conv2d(
            in_channels, self.ch, kernel_size=3, stride=1, padding=1
        )

        curr_res = resolution
        in_ch_mult = (1,) + tuple(ch_mult)
        self.in_ch_mult = in_ch_mult
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = ch * in_ch_mult[i_level]
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks):
                block.append(
                    ResnetBlock(
                        in_channels=block_in,
                        out_channels=block_out,
                        temb_channels=self.temb_ch,
                        dropout=dropout,
                    )
                )
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(make_attn(block_in, attn_type=attn_type))
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions - 1:
                down.downsample = Downsample(block_in, resamp_with_conv)
                curr_res = curr_res // 2
            self.down.append(down)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout,
        )
        self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
        self.mid.block_2 = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout,
        )

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(
            block_in,
            2 * z_channels if double_z else z_channels,
            kernel_size=3,
            stride=1,
            padding=1,
        )

    def forward(self, x):
        # timestep embedding
        temb = None

        # downsampling
        hs = [self.conv_in(x)]
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](hs[-1], temb)
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
                hs.append(h)
            if i_level != self.num_resolutions - 1:
                hs.append(self.down[i_level].downsample(hs[-1]))

        # middle
        h = hs[-1]
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        # end
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h
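
# Example (an addition; hypothetical VAE-style config, values illustrative):
# with len(ch_mult) = 4 there are 3 downsampling stages, so a 256x256 image
# maps to a 32x32 latent with 2 * z_channels outputs when double_z=True
# (mean and logvar halves):
#   enc = Encoder(ch=128, out_ch=3, ch_mult=(1, 2, 4, 4), num_res_blocks=2,
#                 attn_resolutions=[], dropout=0.0, in_channels=3,
#                 resolution=256, z_channels=4, double_z=True)
#   z = enc(torch.randn(1, 3, 256, 256))      # -> (1, 8, 32, 32)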


class Decoder(nn.Module):
    def __init__(
        self,
        *,
        ch,
        out_ch,
        ch_mult=(1, 2, 4, 8),
        num_res_blocks,
        attn_resolutions,
        dropout=0.0,
        resamp_with_conv=True,
        in_channels,
        resolution,
        z_channels,
        give_pre_end=False,
        tanh_out=False,
        use_linear_attn=False,
        attn_type="vanilla",
        **ignorekwargs,
    ):
        super().__init__()
        if use_linear_attn:
            attn_type = "linear"
        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.give_pre_end = give_pre_end
        self.tanh_out = tanh_out

        # compute in_ch_mult, block_in and curr_res at lowest res
        in_ch_mult = (1,) + tuple(ch_mult)
        block_in = ch * ch_mult[self.num_resolutions - 1]
        curr_res = resolution // 2 ** (self.num_resolutions - 1)
        self.z_shape = (1, z_channels, curr_res, curr_res)
        print(
            "Working with z of shape {} = {} dimensions.".format(
                self.z_shape, np.prod(self.z_shape)
            )
        )

        # z to block_in
        self.conv_in = torch.nn.Conv2d(
            z_channels, block_in, kernel_size=3, stride=1, padding=1
        )

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout,
        )
        self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
        self.mid.block_2 = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout,
        )

        # upsampling
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks + 1):
                block.append(
                    ResnetBlock(
                        in_channels=block_in,
                        out_channels=block_out,
                        temb_channels=self.temb_ch,
                        dropout=dropout,
                    )
                )
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(make_attn(block_in, attn_type=attn_type))
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                up.upsample = Upsample(block_in, resamp_with_conv)
                curr_res = curr_res * 2
            self.up.insert(0, up)  # prepend to get consistent order

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(
            block_in, out_ch, kernel_size=3, stride=1, padding=1
        )

    def forward(self, z):
        # assert z.shape[1:] == self.z_shape[1:]
        self.last_z_shape = z.shape

        # timestep embedding
        temb = None

        # z to block_in
        h = self.conv_in(z)

        # middle
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        # upsampling
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks + 1):
                h = self.up[i_level].block[i_block](h, temb)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h)
            if i_level != 0:
                h = self.up[i_level].upsample(h)

        # end
        if self.give_pre_end:
            return h

        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        if self.tanh_out:
            h = torch.tanh(h)
        return h
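
# Example (an addition; mirrors the Encoder sketch above, same illustrative
# config):
#   dec = Decoder(ch=128, out_ch=3, ch_mult=(1, 2, 4, 4), num_res_blocks=2,
#                 attn_resolutions=[], dropout=0.0, in_channels=3,
#                 resolution=256, z_channels=4)
#   x_rec = dec(torch.randn(1, 4, 32, 32))    # -> (1, 3, 256, 256)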


class SimpleDecoder(nn.Module):
    def __init__(self, in_channels, out_channels, *args, **kwargs):
        super().__init__()
        self.model = nn.ModuleList(
            [
                nn.Conv2d(in_channels, in_channels, 1),
                ResnetBlock(
                    in_channels=in_channels,
                    out_channels=2 * in_channels,
                    temb_channels=0,
                    dropout=0.0,
                ),
                ResnetBlock(
                    in_channels=2 * in_channels,
                    out_channels=4 * in_channels,
                    temb_channels=0,
                    dropout=0.0,
                ),
                ResnetBlock(
                    in_channels=4 * in_channels,
                    out_channels=2 * in_channels,
                    temb_channels=0,
                    dropout=0.0,
                ),
                nn.Conv2d(2 * in_channels, in_channels, 1),
                Upsample(in_channels, with_conv=True),
            ]
        )
        # end
        self.norm_out = Normalize(in_channels)
        self.conv_out = torch.nn.Conv2d(
            in_channels, out_channels, kernel_size=3, stride=1, padding=1
        )

    def forward(self, x):
        for i, layer in enumerate(self.model):
            if i in [1, 2, 3]:
                # the ResnetBlocks expect a (here unused) timestep embedding
                x = layer(x, None)
            else:
                x = layer(x)

        h = self.norm_out(x)
        h = nonlinearity(h)
        x = self.conv_out(h)
        return x


class UpsampleDecoder(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        ch,
        num_res_blocks,
        resolution,
        ch_mult=(2, 2),
        dropout=0.0,
    ):
        super().__init__()
        # upsampling
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        block_in = in_channels
        curr_res = resolution // 2 ** (self.num_resolutions - 1)
        self.res_blocks = nn.ModuleList()
        self.upsample_blocks = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            res_block = []
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks + 1):
                res_block.append(
                    ResnetBlock(
                        in_channels=block_in,
                        out_channels=block_out,
                        temb_channels=self.temb_ch,
                        dropout=dropout,
                    )
                )
                block_in = block_out
            self.res_blocks.append(nn.ModuleList(res_block))
            if i_level != self.num_resolutions - 1:
                self.upsample_blocks.append(Upsample(block_in, True))
                curr_res = curr_res * 2

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(
            block_in, out_channels, kernel_size=3, stride=1, padding=1
        )

    def forward(self, x):
        # upsampling
        h = x
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks + 1):
                h = self.res_blocks[i_level][i_block](h, None)
            if i_level != self.num_resolutions - 1:
                # one upsample block per level except the last
                h = self.upsample_blocks[i_level](h)
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h


class LatentRescaler(nn.Module):
    def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2):
        super().__init__()
        # residual block, interpolate, residual block
        self.factor = factor
        self.conv_in = nn.Conv2d(
            in_channels, mid_channels, kernel_size=3, stride=1, padding=1
        )
        self.res_block1 = nn.ModuleList(
            [
                ResnetBlock(
                    in_channels=mid_channels,
                    out_channels=mid_channels,
                    temb_channels=0,
                    dropout=0.0,
                )
                for _ in range(depth)
            ]
        )
        self.attn = AttnBlock(mid_channels)
        self.res_block2 = nn.ModuleList(
            [
                ResnetBlock(
                    in_channels=mid_channels,
                    out_channels=mid_channels,
                    temb_channels=0,
                    dropout=0.0,
                )
                for _ in range(depth)
            ]
        )
        self.conv_out = nn.Conv2d(
            mid_channels,
            out_channels,
            kernel_size=1,
        )

    def forward(self, x):
        x = self.conv_in(x)
        for block in self.res_block1:
            x = block(x, None)
        x = torch.nn.functional.interpolate(
            x,
            size=(
                int(round(x.shape[2] * self.factor)),
                int(round(x.shape[3] * self.factor)),
            ),
        )
        x = self.attn(x)
        for block in self.res_block2:
            x = block(x, None)
        x = self.conv_out(x)
        return x
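
# Example (an addition; shape check): factor rescales the spatial size via
# interpolation between the two residual stacks, e.g.
#   rescaler = LatentRescaler(factor=2.0, in_channels=4, mid_channels=64,
#                             out_channels=4)
#   y = rescaler(torch.randn(1, 4, 16, 16))   # -> (1, 4, 32, 32)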


class MergedRescaleEncoder(nn.Module):
    def __init__(
        self,
        in_channels,
        ch,
        resolution,
        out_ch,
        num_res_blocks,
        attn_resolutions,
        dropout=0.0,
        resamp_with_conv=True,
        ch_mult=(1, 2, 4, 8),
        rescale_factor=1.0,
        rescale_module_depth=1,
    ):
        super().__init__()
        intermediate_chn = ch * ch_mult[-1]
        self.encoder = Encoder(
            in_channels=in_channels,
            num_res_blocks=num_res_blocks,
            ch=ch,
            ch_mult=ch_mult,
            z_channels=intermediate_chn,
            double_z=False,
            resolution=resolution,
            attn_resolutions=attn_resolutions,
            dropout=dropout,
            resamp_with_conv=resamp_with_conv,
            out_ch=None,
        )
        self.rescaler = LatentRescaler(
            factor=rescale_factor,
            in_channels=intermediate_chn,
            mid_channels=intermediate_chn,
            out_channels=out_ch,
            depth=rescale_module_depth,
        )

    def forward(self, x):
        x = self.encoder(x)
        x = self.rescaler(x)
        return x


class MergedRescaleDecoder(nn.Module):
    def __init__(
        self,
        z_channels,
        out_ch,
        resolution,
        num_res_blocks,
        attn_resolutions,
        ch,
        ch_mult=(1, 2, 4, 8),
        dropout=0.0,
        resamp_with_conv=True,
        rescale_factor=1.0,
        rescale_module_depth=1,
    ):
        super().__init__()
        tmp_chn = z_channels * ch_mult[-1]
        self.decoder = Decoder(
            out_ch=out_ch,
            z_channels=tmp_chn,
            attn_resolutions=attn_resolutions,
            dropout=dropout,
            resamp_with_conv=resamp_with_conv,
            in_channels=None,
            num_res_blocks=num_res_blocks,
            ch_mult=ch_mult,
            resolution=resolution,
            ch=ch,
        )
        self.rescaler = LatentRescaler(
            factor=rescale_factor,
            in_channels=z_channels,
            mid_channels=tmp_chn,
            out_channels=tmp_chn,
            depth=rescale_module_depth,
        )

    def forward(self, x):
        x = self.rescaler(x)
        x = self.decoder(x)
        return x


class Upsampler(nn.Module):
    def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2):
        super().__init__()
        assert out_size >= in_size
        num_blocks = int(np.log2(out_size // in_size)) + 1
        factor_up = 1.0 + (out_size % in_size)
        print(
            f"Building {self.__class__.__name__} with in_size: {in_size} "
            f"--> out_size {out_size} and factor {factor_up}"
        )
        self.rescaler = LatentRescaler(
            factor=factor_up,
            in_channels=in_channels,
            mid_channels=2 * in_channels,
            out_channels=in_channels,
        )
        self.decoder = Decoder(
            out_ch=out_channels,
            resolution=out_size,
            z_channels=in_channels,
            num_res_blocks=2,
            attn_resolutions=[],
            in_channels=None,
            ch=in_channels,
            ch_mult=[ch_mult for _ in range(num_blocks)],
        )

    def forward(self, x):
        x = self.rescaler(x)
        x = self.decoder(x)
        return x


class Resize(nn.Module):
    def __init__(self, in_channels=None, learned=False, mode="bilinear"):
        super().__init__()
        self.with_conv = learned
        self.mode = mode
        if self.with_conv:
            print(
                f"Note: {self.__class__.__name__} uses learned downsampling "
                f"and will ignore the fixed {mode} mode"
            )
            raise NotImplementedError()
            assert in_channels is not None
            # no asymmetric padding in torch conv, must do it ourselves
            self.conv = torch.nn.Conv2d(
                in_channels, in_channels, kernel_size=4, stride=2, padding=1
            )

    def forward(self, x, scale_factor=1.0):
        if scale_factor == 1.0:
            return x
        else:
            x = torch.nn.functional.interpolate(
                x, mode=self.mode, align_corners=False, scale_factor=scale_factor
            )
        return x
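

if __name__ == "__main__":
    # Minimal smoke test (an addition, not part of the original module): it
    # checks that the SDPA-based AttnBlock2_0 matches the explicit bmm-based
    # AttnBlock once their weights are shared, and that a tiny Encoder/Decoder
    # pair round-trips shapes as expected.
    torch.manual_seed(0)
    if hasattr(torch.nn.functional, "scaled_dot_product_attention"):
        x = torch.randn(2, 64, 16, 16)
        a1 = AttnBlock(64)
        a2 = AttnBlock2_0(64)
        a2.load_state_dict(a1.state_dict())
        with torch.no_grad():
            assert torch.allclose(a1(x), a2(x), atol=1e-5)
    enc = Encoder(ch=32, out_ch=3, ch_mult=(1, 2), num_res_blocks=1,
                  attn_resolutions=[], dropout=0.0, in_channels=3,
                  resolution=32, z_channels=4, double_z=True)
    dec = Decoder(ch=32, out_ch=3, ch_mult=(1, 2), num_res_blocks=1,
                  attn_resolutions=[], dropout=0.0, in_channels=3,
                  resolution=32, z_channels=4)
    with torch.no_grad():
        z = enc(torch.randn(1, 3, 32, 32))
        assert z.shape == (1, 8, 16, 16)
        # use the mean half of the (mean, logvar) output as a stand-in sample
        x_rec = dec(z[:, :4])
        assert x_rec.shape == (1, 3, 32, 32)
    print("smoke test passed")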