import math

import torch
import torch.nn as nn
import torch.nn.functional as F


def modulate(x, shift, scale):
    # adaLN modulation: shift and scale the (already normalized) activations.
    return x * (1 + scale) + shift


def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    # x: (bsz, seqlen, n_head, head_dim); freqs_cis: (seqlen, head_dim // 2, 2).
    xshaped = x.float().reshape(*x.shape[:-1], -1, 2)
    freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2)
    x_out2 = torch.stack(
        [
            xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1],
            xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1],
        ],
        -1,
    )
    x_out2 = x_out2.flatten(3)
    return x_out2.type_as(x)
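

# The stacked products above are a complex multiplication
# (x0 + i*x1) * (cos + i*sin) applied to each channel pair, i.e. rotary
# position embeddings (RoPE): positions enter as rotations, not additions.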


class TimestepEmbedder(nn.Module):
    """
    Embeds scalar timesteps into vector representations.
    """

    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.mlp = FeedForward(
            frequency_embedding_size, hidden_size, out_dim=hidden_size
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """
        Create sinusoidal timestep embeddings.
        :param t: a 1-D Tensor of N indices, one per batch element.
                  These may be fractional.
        :param dim: the dimension of the output.
        :param max_period: controls the minimum frequency of the embeddings.
        :return: an (N, D) Tensor of positional embeddings.
        """
        # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period)
            * torch.arange(start=0, end=half, dtype=torch.float32)
            / half
        ).to(device=t.device)
        args = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat(
                [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
            )
        return embedding

    def forward(self, t):
        t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
        t_emb = self.mlp(t_freq)
        return t_emb


def precompute_freqs_cis(seq_len: int, n_elem: int, base: int = 10000) -> torch.Tensor:
    # Precompute the RoPE rotation table for every position up to seq_len.
    freqs = 1.0 / (
        base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem)
    )
    t = torch.arange(seq_len, device=freqs.device)
    freqs = torch.outer(t, freqs)
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
    # Store (cos, sin) pairs: shape (seq_len, n_elem // 2, 2).
    cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)
    return cache.to(dtype=torch.bfloat16)
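

# The (cos, sin) cache is stored in bfloat16 to keep the buffer small;
# apply_rotary_emb upcasts the activations to float32, and the bfloat16
# table is promoted during the multiply.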


class Attention(nn.Module):
    def __init__(
        self,
        dim,
        n_head,
    ):
        super().__init__()
        assert dim % n_head == 0
        self.dim = dim
        self.n_head = n_head
        self.head_dim = dim // n_head
        self.wq = nn.Linear(dim, dim)
        self.wk = nn.Linear(dim, dim)
        self.wv = nn.Linear(dim, dim)
        self.wo = nn.Linear(dim, dim)

    def forward(self, q, freqs_cis, kv=None, mask=None):
        # Self-attention when kv is None; cross-attention otherwise.
        bsz, seqlen, _ = q.shape
        if kv is None:
            kv = q
        kv_seqlen = kv.shape[1]
        q = self.wq(q).view(bsz, seqlen, self.n_head, self.head_dim)
        k = self.wk(kv).view(bsz, kv_seqlen, self.n_head, self.head_dim)
        v = self.wv(kv).view(bsz, kv_seqlen, self.n_head, self.head_dim)
        q = apply_rotary_emb(q, freqs_cis[:seqlen])
        k = apply_rotary_emb(k, freqs_cis[:kv_seqlen])
        q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
        y = F.scaled_dot_product_attention(
            q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False
        )
        y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
        y = self.wo(y)
        return y


class FeedForward(nn.Module):
    def __init__(self, in_dim, intermediate_size, out_dim=None):
        super().__init__()
        self.w1 = nn.Linear(in_dim, intermediate_size)
        self.w3 = nn.Linear(in_dim, intermediate_size)
        self.w2 = nn.Linear(intermediate_size, out_dim or in_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w2(F.silu(self.w1(x)) * self.w3(x))
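

# A gated ("SwiGLU"-style) MLP: the SiLU-activated w1 branch gates the
# linear w3 branch before w2 projects back to the output dimension.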


class DiTBlock(nn.Module):
    def __init__(
        self,
        hidden_size,
        num_heads,
        mlp_ratio=4.0,
        use_self_attention=True,
        use_cross_attention=False,
    ):
        super().__init__()
        self.use_self_attention = use_self_attention
        self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        if use_self_attention:
            self.mix = Attention(hidden_size, num_heads)
        else:
            # Cheaper token mixer: a depthwise 1-D convolution instead of attention.
            self.mix = nn.Conv1d(
                hidden_size,
                hidden_size,
                kernel_size=7,
                padding=3,
                bias=True,
                groups=hidden_size,
            )
        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.mlp = FeedForward(hidden_size, int(hidden_size * mlp_ratio))
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True)
        )
        self.use_cross_attention = use_cross_attention
        if self.use_cross_attention:
            self.norm3 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
            self.norm4 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
            self.cross_attn = Attention(hidden_size, num_heads)
            self.adaLN_modulation_cross = nn.Sequential(
                nn.SiLU(), nn.Linear(hidden_size, 3 * hidden_size, bias=True)
            )
            self.adaLN_modulation_cross_condition = nn.Sequential(
                nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True)
            )

    def forward(
        self,
        x,
        condition,
        freqs_cis,
        self_mask=None,
        cross_condition=None,
        cross_mask=None,
    ):
        (
            shift_msa,
            scale_msa,
            gate_msa,
            shift_mlp,
            scale_mlp,
            gate_mlp,
        ) = self.adaLN_modulation(condition).chunk(6, dim=-1)
        # Token mixing (self-attention or depthwise conv)
        inp = modulate(self.norm1(x), shift_msa, scale_msa)
        if self.use_self_attention:
            inp = self.mix(inp, freqs_cis=freqs_cis, mask=self_mask)
        else:
            # Conv1d expects (bsz, channels, seqlen); .mT swaps the last two dims.
            # Note that self_mask is not applied on this path.
            inp = self.mix(inp.mT).mT
        x = x + gate_msa * inp
        # Cross-attention
        if self.use_cross_attention:
            (
                shift_cross,
                scale_cross,
                gate_cross,
            ) = self.adaLN_modulation_cross(condition).chunk(3, dim=-1)
            (
                shift_cross_condition,
                scale_cross_condition,
            ) = self.adaLN_modulation_cross_condition(cross_condition).chunk(2, dim=-1)
            inp = modulate(self.norm3(x), shift_cross, scale_cross)
            inp = self.cross_attn(
                inp,
                freqs_cis=freqs_cis,
                kv=modulate(
                    self.norm4(cross_condition),
                    shift_cross_condition,
                    scale_cross_condition,
                ),
                mask=cross_mask,
            )
            x = x + gate_cross * inp
        # MLP
        x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
        return x
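

# With adaLN-Zero initialization (see DiT.initialize_weights), gate_msa and
# gate_mlp start at zero, so the mixing and MLP branches are no-ops at init
# and each block begins as the identity map. Note that the cross-attention
# modulation heads keep their default init.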


class FinalLayer(nn.Module):
    """
    The final layer of DiT.
    """

    def __init__(self, hidden_size, out_channels):
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, out_channels, bias=True)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True)
        )

    def forward(self, x, c):
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
        x = modulate(self.norm_final(x), shift, scale)
        return self.linear(x)


class DiT(nn.Module):
    def __init__(
        self,
        hidden_size,
        num_heads,
        diffusion_num_layers,
        channels=160,
        mlp_ratio=4.0,
        max_seq_len=16384,
        condition_dim=512,
        style_dim=None,
        cross_condition_dim=None,
    ):
        super().__init__()
        self.max_seq_len = max_seq_len
        self.time_embedder = TimestepEmbedder(hidden_size)
        self.condition_embedder = FeedForward(
            condition_dim, int(hidden_size * mlp_ratio), out_dim=hidden_size
        )
        if cross_condition_dim is not None:
            self.cross_condition_embedder = FeedForward(
                cross_condition_dim, int(hidden_size * mlp_ratio), out_dim=hidden_size
            )
        self.use_style = style_dim is not None
        if self.use_style:
            self.style_embedder = FeedForward(
                style_dim, int(hidden_size * mlp_ratio), out_dim=hidden_size
            )
        # Every 4th block uses self-attention; the rest use depthwise-conv mixing.
        self.diffusion_blocks = nn.ModuleList(
            [
                DiTBlock(
                    hidden_size,
                    num_heads,
                    mlp_ratio,
                    use_self_attention=i % 4 == 0,
                    use_cross_attention=cross_condition_dim is not None,
                )
                for i in range(diffusion_num_layers)
            ]
        )
        # Input projection & output head
        self.input_embedder = FeedForward(
            channels, int(hidden_size * mlp_ratio), out_dim=hidden_size
        )
        self.final_layer = FinalLayer(hidden_size, channels)
        self.register_buffer(
            "freqs_cis", precompute_freqs_cis(max_seq_len, hidden_size // num_heads)
        )
        self.initialize_weights()

    def initialize_weights(self):
        # Initialize embedding MLPs:
        self.input_embedder.apply(self.init_weight)
        self.time_embedder.mlp.apply(self.init_weight)
        self.condition_embedder.apply(self.init_weight)
        if self.use_style:
            self.style_embedder.apply(self.init_weight)
        if hasattr(self, "cross_condition_embedder"):
            self.cross_condition_embedder.apply(self.init_weight)
        # adaLN-Zero: zero the modulation heads so each block starts as identity.
        for block in self.diffusion_blocks:
            nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
            nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
            block.mix.apply(self.init_weight)
        # Zero-out output layers:
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
        self.final_layer.linear.apply(self.init_weight)

    def init_weight(self, m):
        if isinstance(m, (nn.Conv1d, nn.ConvTranspose1d, nn.Linear)):
            nn.init.normal_(m.weight, 0, 0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(
        self,
        x,
        time,
        condition,
        style=None,
        self_mask=None,
        cross_condition=None,
        cross_mask=None,
    ):
        # Embed inputs
        x = self.input_embedder(x)
        t = self.time_embedder(time)
        condition = self.condition_embedder(condition)
        if self.use_style:
            style = self.style_embedder(style)
        if cross_condition is not None:
            cross_condition = self.cross_condition_embedder(cross_condition)
            cross_condition = t[:, None, :] + cross_condition
        # Merge t, condition, and style
        condition = t[:, None, :] + condition
        if self.use_style:
            condition = condition + style[:, None, :]
        # Broadcast boolean key-padding masks to (bsz, 1, 1, seqlen) for SDPA
        if self_mask is not None:
            self_mask = self_mask[:, None, None, :]
        if cross_mask is not None:
            cross_mask = cross_mask[:, None, None, :]
        # DiT blocks
        for block in self.diffusion_blocks:
            x = block(
                x,
                condition,
                self.freqs_cis,
                self_mask=self_mask,
                cross_condition=cross_condition,
                cross_mask=cross_mask,
            )
        x = self.final_layer(x, condition)
        return x
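

# Unlike the original DiT, where conditioning is a single vector per sample,
# condition here is per-token (bsz, seqlen, hidden), so the adaLN shift,
# scale, and gate can vary along the sequence.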


if __name__ == "__main__":
    model = DiT(
        hidden_size=384,
        num_heads=6,
        diffusion_num_layers=12,
        channels=160,
        condition_dim=512,
        style_dim=256,
    )
    bs, seq_len = 8, 1024
    x = torch.randn(bs, seq_len, 160)
    condition = torch.randn(bs, seq_len, 512)
    style = torch.randn(bs, 256)
    mask = torch.ones(bs, seq_len, dtype=torch.bool)
    mask[0, 5:] = False  # pad out everything past the first 5 tokens of sample 0
    time = torch.arange(bs)
    print(time)
    out = model(x, time, condition, style, self_mask=mask)
    print(out.shape)  # torch.Size([8, 1024, 160])
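
    # A minimal sketch of the cross-attention path (illustrative values: the
    # cross-condition length of 77 and dim of 512 are arbitrary assumptions).
    # Passing cross_condition_dim enables the cross_attn branch in every block.
    cross_model = DiT(
        hidden_size=384,
        num_heads=6,
        diffusion_num_layers=12,
        channels=160,
        condition_dim=512,
        cross_condition_dim=512,
    )
    cross_condition = torch.randn(bs, 77, 512)
    cross_mask = torch.ones(bs, 77, dtype=torch.bool)
    out = cross_model(
        x, time, condition, cross_condition=cross_condition, cross_mask=cross_mask
    )
    print(out.shape)  # torch.Size([8, 1024, 160])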

    # Print model size
    num_params = sum(p.numel() for p in model.parameters())
    print(f"Number of parameters: {num_params / 1e6:.1f}M")
|