import math
import os
from inspect import isfunction
from typing import Any, Optional

import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from torch import einsum, nn

from sorawm.iopaint.model.anytext.ldm.modules.diffusionmodules.util import checkpoint

# CrossAttn precision handling
_ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32")


def exists(val):
    return val is not None


def uniq(arr):
    return {el: True for el in arr}.keys()


def default(val, d):
    if exists(val):
        return val
    return d() if isfunction(d) else d


def max_neg_value(t):
    return -torch.finfo(t.dtype).max


def init_(tensor):
    dim = tensor.shape[-1]
    std = 1 / math.sqrt(dim)
    tensor.uniform_(-std, std)
    return tensor
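

# Illustrative use of `default` (the values below are made up): a callable fallback
# is only evaluated when the primary value is missing.
#
#   default(None, 3)                   # -> 3
#   default(7, 3)                      # -> 7
#   default(None, lambda: nn.GELU())   # -> nn.GELU(), built lazily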


# feedforward
class GEGLU(nn.Module):
    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def forward(self, x):
        # project to 2 * dim_out, then gate one half with GELU of the other
        x, gate = self.proj(x).chunk(2, dim=-1)
        return x * F.gelu(gate)
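

# Shape sketch for GEGLU (dimensions are illustrative assumptions):
#
#   geglu = GEGLU(dim_in=320, dim_out=1280)
#   y = geglu(torch.randn(2, 77, 320))   # proj -> (2, 77, 2560), gated -> (2, 77, 1280)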


class FeedForward(nn.Module):
    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
        super().__init__()
        inner_dim = int(dim * mult)
        dim_out = default(dim_out, dim)
        project_in = (
            nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
            if not glu
            else GEGLU(dim, inner_dim)
        )

        self.net = nn.Sequential(
            project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
        )

    def forward(self, x):
        return self.net(x)
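

# Usage sketch for FeedForward (sizes are assumptions for illustration):
#
#   ff = FeedForward(dim=320, mult=4, glu=True)   # GEGLU variant used by the blocks below
#   y = ff(torch.randn(2, 4096, 320))             # -> (2, 4096, 320), inner width 1280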


def zero_module(module):
    """
    Zero out the parameters of a module and return it.
    """
    for p in module.parameters():
        p.detach().zero_()
    return module


def Normalize(in_channels):
    return torch.nn.GroupNorm(
        num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
    )
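

# Why zero_module: zero-initializing the final projection makes a freshly added
# block act as an identity in its residual path at the start of training.
# Sketch (shapes are illustrative):
#
#   proj = zero_module(nn.Conv2d(320, 320, kernel_size=1))
#   proj(torch.randn(1, 320, 8, 8)).abs().max()   # -> tensor(0.)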


class SpatialSelfAttention(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = Normalize(in_channels)
        self.q = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.k = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.v = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.proj_out = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )

    def forward(self, x):
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # compute attention
        b, c, h, w = q.shape
        q = rearrange(q, "b c h w -> b (h w) c")
        k = rearrange(k, "b c h w -> b c (h w)")
        w_ = torch.einsum("bij,bjk->bik", q, k)

        w_ = w_ * (int(c) ** (-0.5))
        w_ = torch.nn.functional.softmax(w_, dim=2)

        # attend to values
        v = rearrange(v, "b c h w -> b c (h w)")
        w_ = rearrange(w_, "b i j -> b j i")
        h_ = torch.einsum("bij,bjk->bik", v, w_)
        h_ = rearrange(h_, "b c (h w) -> b c h w", h=h)
        h_ = self.proj_out(h_)

        return x + h_
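

# Usage sketch for SpatialSelfAttention (in_channels must be divisible by 32 for the
# GroupNorm; sizes are illustrative):
#
#   attn = SpatialSelfAttention(in_channels=64)
#   y = attn(torch.randn(1, 64, 32, 32))   # residual self-attention over h*w = 1024 positions
#   y.shape                                # -> torch.Size([1, 64, 32, 32])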


class CrossAttention(nn.Module):
    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
        super().__init__()
        inner_dim = dim_head * heads
        context_dim = default(context_dim, query_dim)

        self.scale = dim_head**-0.5
        self.heads = heads

        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
        )

    def forward(self, x, context=None, mask=None):
        h = self.heads

        q = self.to_q(x)
        context = default(context, x)
        k = self.to_k(context)
        v = self.to_v(context)

        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v))

        # force cast to fp32 to avoid overflowing
        if _ATTN_PRECISION == "fp32":
            with torch.autocast(enabled=False, device_type="cuda"):
                q, k = q.float(), k.float()
                sim = einsum("b i d, b j d -> b i j", q, k) * self.scale
        else:
            sim = einsum("b i d, b j d -> b i j", q, k) * self.scale

        del q, k

        if exists(mask):
            mask = rearrange(mask, "b ... -> b (...)")
            max_neg_value = -torch.finfo(sim.dtype).max
            mask = repeat(mask, "b j -> (b h) () j", h=h)
            sim.masked_fill_(~mask, max_neg_value)

        # attention, what we cannot get enough of
        sim = sim.softmax(dim=-1)

        out = einsum("b i j, b j d -> b i d", sim, v)
        out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
        return self.to_out(out)
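

# Usage sketch for CrossAttention (dimensions mimic a typical latent-diffusion setup
# but are assumptions here): queries come from the image tokens, keys/values from a
# text context; with context=None it degenerates to self-attention. An optional
# boolean `mask` marks which context positions to keep (True) per batch element.
#
#   attn = CrossAttention(query_dim=320, context_dim=768, heads=8, dim_head=40)
#   x = torch.randn(2, 4096, 320)    # (batch, h*w tokens, query_dim)
#   ctx = torch.randn(2, 77, 768)    # (batch, context tokens, context_dim)
#   out = attn(x, context=ctx)       # -> (2, 4096, 320)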


class SDPACrossAttention(CrossAttention):
    def forward(self, x, context=None, mask=None):
        batch_size, sequence_length, inner_dim = x.shape
        h = self.heads

        q_in = self.to_q(x)
        context = default(context, x)
        k_in = self.to_k(context)
        v_in = self.to_v(context)

        if mask is not None:
            # same convention as CrossAttention above: boolean mask, True = attend;
            # reshape to (batch, 1, 1, key_len) so it broadcasts over heads and queries
            mask = rearrange(mask, "b ... -> b (...)")
            mask = mask[:, None, None, :]

        head_dim = inner_dim // h
        q = q_in.view(batch_size, -1, h, head_dim).transpose(1, 2)
        k = k_in.view(batch_size, -1, h, head_dim).transpose(1, 2)
        v = v_in.view(batch_size, -1, h, head_dim).transpose(1, 2)
        del q_in, k_in, v_in

        dtype = q.dtype
        if _ATTN_PRECISION == "fp32":
            q, k, v = q.float(), k.float(), v.float()

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        hidden_states = torch.nn.functional.scaled_dot_product_attention(
            q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False
        )

        hidden_states = hidden_states.transpose(1, 2).reshape(
            batch_size, -1, h * head_dim
        )
        hidden_states = hidden_states.to(dtype)

        # linear proj
        hidden_states = self.to_out[0](hidden_states)
        # dropout
        hidden_states = self.to_out[1](hidden_states)
        return hidden_states
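

# SDPACrossAttention is a drop-in replacement for CrossAttention that routes the
# attention computation through torch.nn.functional.scaled_dot_product_attention
# (PyTorch >= 2.0), which can dispatch to flash / memory-efficient kernels.
# Sketch (sizes are illustrative assumptions):
#
#   attn = SDPACrossAttention(query_dim=320, context_dim=768, heads=8, dim_head=40)
#   out = attn(torch.randn(2, 4096, 320), context=torch.randn(2, 77, 768))
#   out.shape   # -> torch.Size([2, 4096, 320]), same as the einsum-based path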


class BasicTransformerBlock(nn.Module):
    def __init__(
        self,
        dim,
        n_heads,
        d_head,
        dropout=0.0,
        context_dim=None,
        gated_ff=True,
        checkpoint=True,
        disable_self_attn=False,
    ):
        super().__init__()
        if hasattr(torch.nn.functional, "scaled_dot_product_attention"):
            attn_cls = SDPACrossAttention
        else:
            attn_cls = CrossAttention

        self.disable_self_attn = disable_self_attn
        self.attn1 = attn_cls(
            query_dim=dim,
            heads=n_heads,
            dim_head=d_head,
            dropout=dropout,
            context_dim=context_dim if self.disable_self_attn else None,
        )  # is a self-attention if not self.disable_self_attn
        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
        self.attn2 = attn_cls(
            query_dim=dim,
            context_dim=context_dim,
            heads=n_heads,
            dim_head=d_head,
            dropout=dropout,
        )  # is self-attn if context is none
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.norm3 = nn.LayerNorm(dim)
        self.checkpoint = checkpoint

    def forward(self, x, context=None):
        return checkpoint(
            self._forward, (x, context), self.parameters(), self.checkpoint
        )

    def _forward(self, x, context=None):
        x = (
            self.attn1(
                self.norm1(x), context=context if self.disable_self_attn else None
            )
            + x
        )
        x = self.attn2(self.norm2(x), context=context) + x
        x = self.ff(self.norm3(x)) + x
        return x
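

# Usage sketch for BasicTransformerBlock (pre-norm residual layout: self-attention,
# then cross-attention over `context`, then feed-forward; sizes are illustrative):
#
#   block = BasicTransformerBlock(dim=320, n_heads=8, d_head=40, context_dim=768)
#   x = torch.randn(2, 4096, 320)
#   ctx = torch.randn(2, 77, 768)
#   y = block(x, context=ctx)   # -> (2, 4096, 320); wrapped in `checkpoint` when enabled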


class SpatialTransformer(nn.Module):
    """
    Transformer block for image-like data.
    First, project the input (aka embedding) and reshape it to (b, t, d).
    Then apply standard transformer blocks and finally reshape back to an image.
    NEW: `use_linear` replaces the 1x1 convs with linear layers for efficiency.
    """

    def __init__(
        self,
        in_channels,
        n_heads,
        d_head,
        depth=1,
        dropout=0.0,
        context_dim=None,
        disable_self_attn=False,
        use_linear=False,
        use_checkpoint=True,
    ):
        super().__init__()
        if exists(context_dim) and not isinstance(context_dim, list):
            context_dim = [context_dim]
        self.in_channels = in_channels
        inner_dim = n_heads * d_head
        self.norm = Normalize(in_channels)
        if not use_linear:
            self.proj_in = nn.Conv2d(
                in_channels, inner_dim, kernel_size=1, stride=1, padding=0
            )
        else:
            self.proj_in = nn.Linear(in_channels, inner_dim)

        self.transformer_blocks = nn.ModuleList(
            [
                BasicTransformerBlock(
                    inner_dim,
                    n_heads,
                    d_head,
                    dropout=dropout,
                    context_dim=context_dim[d],
                    disable_self_attn=disable_self_attn,
                    checkpoint=use_checkpoint,
                )
                for d in range(depth)
            ]
        )
        if not use_linear:
            self.proj_out = zero_module(
                nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
            )
        else:
            # only valid when inner_dim == in_channels, which holds for the configs used here
            self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
        self.use_linear = use_linear

    def forward(self, x, context=None):
        # note: if no context is given, cross-attention defaults to self-attention
        if not isinstance(context, list):
            context = [context]
        b, c, h, w = x.shape
        x_in = x
        x = self.norm(x)
        if not self.use_linear:
            x = self.proj_in(x)
        x = rearrange(x, "b c h w -> b (h w) c").contiguous()
        if self.use_linear:
            x = self.proj_in(x)
        for i, block in enumerate(self.transformer_blocks):
            x = block(x, context=context[i])
        if self.use_linear:
            x = self.proj_out(x)
        x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous()
        if not self.use_linear:
            x = self.proj_out(x)
        return x + x_in
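

# Minimal smoke test, only run when this file is executed directly. The sizes below
# are illustrative assumptions (for the use_linear path, inner_dim = n_heads * d_head
# must equal in_channels).
if __name__ == "__main__":
    st = SpatialTransformer(
        in_channels=320,
        n_heads=8,
        d_head=40,
        depth=1,
        context_dim=768,
        use_linear=True,
        use_checkpoint=False,
    )
    feat = torch.randn(1, 320, 32, 32)  # image-like features (b, c, h, w)
    cond = torch.randn(1, 77, 768)      # e.g. text-encoder hidden states (b, t, d)
    out = st(feat, context=cond)
    print(out.shape)  # torch.Size([1, 320, 32, 32])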