| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108 |
- from functools import partial
- import numpy as np
- import torch
- import torch.nn as nn
- from sorawm.iopaint.model.anytext.ldm.modules.diffusionmodules.util import (
- extract_into_tensor,
- make_beta_schedule,
- )
- from sorawm.iopaint.model.anytext.ldm.util import default
class AbstractLowScaleModel(nn.Module):
    """Base class for low-scale conditioning models.

    These models produce a downsampled image that is concatenated to the
    latent representation. When a ``noise_schedule_config`` is supplied,
    a DDPM-style beta schedule is registered so subclasses can noise
    their input via :meth:`q_sample`.
    """

    def __init__(self, noise_schedule_config=None):
        super().__init__()
        # Only build the diffusion schedule when explicitly configured.
        if noise_schedule_config is not None:
            self.register_schedule(**noise_schedule_config)

    def register_schedule(
        self,
        beta_schedule="linear",
        timesteps=1000,
        linear_start=1e-4,
        linear_end=2e-2,
        cosine_s=8e-3,
    ):
        """Precompute the noise schedule and register derived buffers.

        All quantities needed by q(x_t | x_0) are stored as float32
        non-persistent module buffers (via ``register_buffer``).
        """
        betas = make_beta_schedule(
            beta_schedule,
            timesteps,
            linear_start=linear_start,
            linear_end=linear_end,
            cosine_s=cosine_s,
        )
        alphas = 1.0 - betas
        alphas_cumprod = np.cumprod(alphas, axis=0)
        # Shifted cumulative product: ᾱ_{t-1}, with ᾱ_{-1} := 1.
        alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])

        (timesteps,) = betas.shape
        self.num_timesteps = int(timesteps)
        self.linear_start = linear_start
        self.linear_end = linear_end
        assert (
            alphas_cumprod.shape[0] == self.num_timesteps
        ), "alphas have to be defined for each timestep"

        # Table-driven registration keeps every schedule buffer in one place.
        # Quantities for the forward diffusion q(x_t | x_{t-1}) and others.
        schedule = {
            "betas": betas,
            "alphas_cumprod": alphas_cumprod,
            "alphas_cumprod_prev": alphas_cumprod_prev,
            "sqrt_alphas_cumprod": np.sqrt(alphas_cumprod),
            "sqrt_one_minus_alphas_cumprod": np.sqrt(1.0 - alphas_cumprod),
            "log_one_minus_alphas_cumprod": np.log(1.0 - alphas_cumprod),
            "sqrt_recip_alphas_cumprod": np.sqrt(1.0 / alphas_cumprod),
            "sqrt_recipm1_alphas_cumprod": np.sqrt(1.0 / alphas_cumprod - 1),
        }
        for buf_name, values in schedule.items():
            self.register_buffer(buf_name, torch.tensor(values, dtype=torch.float32))

    def q_sample(self, x_start, t, noise=None):
        """Diffuse ``x_start`` to step ``t``: sqrt(ᾱ_t)·x₀ + sqrt(1-ᾱ_t)·ε."""
        noise = default(noise, lambda: torch.randn_like(x_start))
        signal = extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
        perturbation = (
            extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
        )
        return signal + perturbation

    def forward(self, x):
        # Identity pass-through; no noise-level conditioning at this level.
        return x, None

    def decode(self, x):
        # The low-scale image needs no decoding step.
        return x
class SimpleImageConcat(AbstractLowScaleModel):
    """Concat conditioning without noise augmentation.

    The noise level reported to the model is a constant zero.
    """

    def __init__(self):
        super().__init__(noise_schedule_config=None)
        # This variant never noises its input.
        self.max_noise_level = 0

    def forward(self, x):
        # Pass x through unchanged; report noise level 0 per batch element.
        batch = x.shape[0]
        return x, torch.zeros(batch, device=x.device).long()
class ImageConcatWithNoiseAugmentation(AbstractLowScaleModel):
    """Concat conditioning that noises the input via the registered schedule."""

    def __init__(self, noise_schedule_config, max_noise_level=1000, to_cuda=False):
        # NOTE(review): ``to_cuda`` is accepted for signature compatibility
        # but is not used anywhere in this class — confirm against callers.
        super().__init__(noise_schedule_config=noise_schedule_config)
        self.max_noise_level = max_noise_level

    def forward(self, x, noise_level=None):
        """Return (noised x, noise_level); draws a random level when none given."""
        if noise_level is None:
            # One level per batch element, uniform over [0, max_noise_level).
            batch = x.shape[0]
            noise_level = torch.randint(
                0, self.max_noise_level, (batch,), device=x.device
            ).long()
        else:
            assert isinstance(noise_level, torch.Tensor)
        noised = self.q_sample(x, noise_level)
        return noised, noise_level
|