import os

import numpy as np
import torch
import torch.nn as nn
from loguru import logger

from sorawm.iopaint.helper import (
    download_model,
    get_cache_path_by_url,
    load_jit_model,
    norm_img,
)
from sorawm.iopaint.schema import InpaintRequest, LDMSampler

from .base import InpaintModel
from .ddim_sampler import DDIMSampler
from .plms_sampler import PLMSSampler
from .utils import make_beta_schedule, timestep_embedding

# Fix the global seed so diffusion sampling is reproducible across runs.
torch.manual_seed(42)

LDM_ENCODE_MODEL_URL = os.environ.get(
    "LDM_ENCODE_MODEL_URL",
    "https://github.com/Sanster/models/releases/download/add_ldm/cond_stage_model_encode.pt",
)
LDM_ENCODE_MODEL_MD5 = os.environ.get(
    "LDM_ENCODE_MODEL_MD5", "23239fc9081956a3e70de56472b3f296"
)

LDM_DECODE_MODEL_URL = os.environ.get(
    "LDM_DECODE_MODEL_URL",
    "https://github.com/Sanster/models/releases/download/add_ldm/cond_stage_model_decode.pt",
)
LDM_DECODE_MODEL_MD5 = os.environ.get(
    "LDM_DECODE_MODEL_MD5", "fe419cd15a750d37a4733589d0d3585c"
)

LDM_DIFFUSION_MODEL_URL = os.environ.get(
    "LDM_DIFFUSION_MODEL_URL",
    "https://github.com/Sanster/models/releases/download/add_ldm/diffusion.pt",
)
LDM_DIFFUSION_MODEL_MD5 = os.environ.get(
    "LDM_DIFFUSION_MODEL_MD5", "b0afda12bf790c03aba2a7431f11d22d"
)
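
# The three checkpoints above are TorchScript (JIT) exports: the diffusion
# UNet plus the encode/decode halves of the cond-stage model. URLs and MD5
# checksums can be overridden through environment variables of the same names.
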
class DDPM(nn.Module):
    """Classic DDPM with Gaussian diffusion, in image space."""

    def __init__(
        self,
        device,
        timesteps=1000,
        beta_schedule="linear",
        linear_start=0.0015,
        linear_end=0.0205,
        cosine_s=0.008,
        original_elbo_weight=0.0,
        v_posterior=0.0,  # weight for choosing posterior variance as sigma = (1 - v) * beta_tilde + v * beta
        l_simple_weight=1.0,
        parameterization="eps",  # all assuming fixed variance schedules
        use_positional_encodings=False,
    ):
        super().__init__()
        self.device = device
        self.parameterization = parameterization
        self.use_positional_encodings = use_positional_encodings
        self.v_posterior = v_posterior
        self.original_elbo_weight = original_elbo_weight
        self.l_simple_weight = l_simple_weight
        self.register_schedule(
            beta_schedule=beta_schedule,
            timesteps=timesteps,
            linear_start=linear_start,
            linear_end=linear_end,
            cosine_s=cosine_s,
        )

    def register_schedule(
        self,
        given_betas=None,
        beta_schedule="linear",
        timesteps=1000,
        linear_start=1e-4,
        linear_end=2e-2,
        cosine_s=8e-3,
    ):
        betas = make_beta_schedule(
            self.device,
            beta_schedule,
            timesteps,
            linear_start=linear_start,
            linear_end=linear_end,
            cosine_s=cosine_s,
        )
        alphas = 1.0 - betas
        alphas_cumprod = np.cumprod(alphas, axis=0)
        alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])

        (timesteps,) = betas.shape
        self.num_timesteps = int(timesteps)
        self.linear_start = linear_start
        self.linear_end = linear_end
        assert (
            alphas_cumprod.shape[0] == self.num_timesteps
        ), "alphas have to be defined for each timestep"

        to_torch = lambda x: torch.tensor(x, dtype=torch.float32).to(self.device)

        self.register_buffer("betas", to_torch(betas))
        self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
        self.register_buffer("alphas_cumprod_prev", to_torch(alphas_cumprod_prev))

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer("sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod)))
        self.register_buffer(
            "sqrt_one_minus_alphas_cumprod", to_torch(np.sqrt(1.0 - alphas_cumprod))
        )
        self.register_buffer(
            "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod))
        )
        self.register_buffer(
            "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod))
        )
        self.register_buffer(
            "sqrt_recipm1_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod - 1))
        )
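        # The buffers above implement the closed form of the forward process,
        #   q(x_t | x_0) = N(sqrt(alphas_cumprod[t]) * x_0, (1 - alphas_cumprod[t]) * I),
        # i.e. x_t = sqrt_alphas_cumprod[t] * x_0 + sqrt_one_minus_alphas_cumprod[t] * eps,
        # and its inversion
        #   x_0 = sqrt_recip_alphas_cumprod[t] * x_t - sqrt_recipm1_alphas_cumprod[t] * eps.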
        # calculations for posterior q(x_{t-1} | x_t, x_0)
        posterior_variance = (1 - self.v_posterior) * betas * (
            1.0 - alphas_cumprod_prev
        ) / (1.0 - alphas_cumprod) + self.v_posterior * betas
        # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
        self.register_buffer("posterior_variance", to_torch(posterior_variance))
        # below: log calculation clipped because the posterior variance is 0
        # at the beginning of the diffusion chain
        self.register_buffer(
            "posterior_log_variance_clipped",
            to_torch(np.log(np.maximum(posterior_variance, 1e-20))),
        )
        self.register_buffer(
            "posterior_mean_coef1",
            to_torch(betas * np.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)),
        )
        self.register_buffer(
            "posterior_mean_coef2",
            to_torch(
                (1.0 - alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - alphas_cumprod)
            ),
        )
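        # The posterior q(x_{t-1} | x_t, x_0) is Gaussian with mean
        #   posterior_mean_coef1[t] * x_0 + posterior_mean_coef2[t] * x_t
        # and the clipped log variance registered above.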
        if self.parameterization == "eps":
            lvlb_weights = self.betas**2 / (
                2
                * self.posterior_variance
                * to_torch(alphas)
                * (1 - self.alphas_cumprod)
            )
        elif self.parameterization == "x0":
            # Upstream latent-diffusion writes this as
            # 0.5 * np.sqrt(torch.Tensor(a)) / (2. * 1 - torch.Tensor(a)),
            # mixing numpy and CPU-tensor ops; rewritten with torch ops on the
            # model device, preserving the original `2 * 1 - a` denominator.
            alphas_cumprod_t = to_torch(alphas_cumprod)
            lvlb_weights = (
                0.5 * torch.sqrt(alphas_cumprod_t) / (2.0 * 1 - alphas_cumprod_t)
            )
        else:
            raise NotImplementedError("mu parameterization not supported")

        # TODO how to choose this term
        # lvlb_weights[0] is degenerate (posterior_variance[0] == 0), so copy
        # the weight from t = 1.
        lvlb_weights[0] = lvlb_weights[1]
        self.register_buffer("lvlb_weights", lvlb_weights, persistent=False)
        assert not torch.isnan(self.lvlb_weights).any()

class LatentDiffusion(DDPM):
    def __init__(
        self,
        diffusion_model,
        device,
        cond_stage_key="image",
        cond_stage_trainable=False,
        concat_mode=True,
        scale_factor=1.0,
        scale_by_std=False,
        *args,
        **kwargs,
    ):
        # Must be set before super().__init__(), which calls the
        # register_schedule() override below and reads this attribute.
        self.num_timesteps_cond = 1
        self.scale_by_std = scale_by_std
        super().__init__(device, *args, **kwargs)
        self.diffusion_model = diffusion_model
        self.concat_mode = concat_mode
        self.cond_stage_trainable = cond_stage_trainable
        self.cond_stage_key = cond_stage_key
        self.num_downs = 2
        self.scale_factor = scale_factor

    def make_cond_schedule(self):
        self.cond_ids = torch.full(
            size=(self.num_timesteps,),
            fill_value=self.num_timesteps - 1,
            dtype=torch.long,
        )
        ids = torch.round(
            torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)
        ).long()
        self.cond_ids[: self.num_timesteps_cond] = ids
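
    # Note: `num_timesteps_cond` is hard-coded to 1 in __init__, so
    # `shorten_cond_schedule` below is always False and `make_cond_schedule`
    # is never invoked here; both are kept for parity with the upstream
    # latent-diffusion implementation.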
    def register_schedule(
        self,
        given_betas=None,
        beta_schedule="linear",
        timesteps=1000,
        linear_start=1e-4,
        linear_end=2e-2,
        cosine_s=8e-3,
    ):
        super().register_schedule(
            given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s
        )
        self.shorten_cond_schedule = self.num_timesteps_cond > 1
        if self.shorten_cond_schedule:
            self.make_cond_schedule()

    def apply_model(self, x_noisy, t, cond):
        # x_recon = self.model(x_noisy, t, cond['c_concat'][0])  # cond['c_concat'][0].shape 1,4,128,128
        # The scripted UNet takes a precomputed 256-dim sinusoidal timestep
        # embedding rather than raw timestep indices.
        t_emb = timestep_embedding(x_noisy.device, t, 256, repeat_only=False)
        x_recon = self.diffusion_model(x_noisy, t_emb, cond)
        return x_recon

class LDM(InpaintModel):
    name = "ldm"
    pad_mod = 32
    is_erase_model = True

    def __init__(self, device, fp16: bool = True, **kwargs):
        # Must be set before super().__init__(), which triggers init_model().
        self.fp16 = fp16
        super().__init__(device)
        self.device = device

    def init_model(self, device, **kwargs):
        self.diffusion_model = load_jit_model(
            LDM_DIFFUSION_MODEL_URL, device, LDM_DIFFUSION_MODEL_MD5
        )
        self.cond_stage_model_decode = load_jit_model(
            LDM_DECODE_MODEL_URL, device, LDM_DECODE_MODEL_MD5
        )
        self.cond_stage_model_encode = load_jit_model(
            LDM_ENCODE_MODEL_URL, device, LDM_ENCODE_MODEL_MD5
        )
        if self.fp16 and "cuda" in str(device):
            self.diffusion_model = self.diffusion_model.half()
            self.cond_stage_model_decode = self.cond_stage_model_decode.half()
            self.cond_stage_model_encode = self.cond_stage_model_encode.half()
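        # Half precision is applied only on CUDA, since fp16 kernels are
        # generally unsupported or slow on CPU. `LatentDiffusion` below only
        # supplies the noise schedule and `apply_model` for the samplers; the
        # actual networks are the three scripted modules loaded above.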
        self.model = LatentDiffusion(self.diffusion_model, device)

    @staticmethod
    def download():
        download_model(LDM_DIFFUSION_MODEL_URL, LDM_DIFFUSION_MODEL_MD5)
        download_model(LDM_DECODE_MODEL_URL, LDM_DECODE_MODEL_MD5)
        download_model(LDM_ENCODE_MODEL_URL, LDM_ENCODE_MODEL_MD5)

    @staticmethod
    def is_downloaded() -> bool:
        model_paths = [
            get_cache_path_by_url(LDM_DIFFUSION_MODEL_URL),
            get_cache_path_by_url(LDM_DECODE_MODEL_URL),
            get_cache_path_by_url(LDM_ENCODE_MODEL_URL),
        ]
        return all(os.path.exists(it) for it in model_paths)

    @torch.cuda.amp.autocast()
    def forward(self, image, mask, config: InpaintRequest):
        """
        image: [H, W, C] RGB
        mask: [H, W, 1]
        return: BGR IMAGE
        """
        # image: [1, 3, 512, 512] float32
        # mask: [1, 1, 512, 512] float32
        # masked_image: [1, 3, 512, 512] float32
        if config.ldm_sampler == LDMSampler.ddim:
            sampler = DDIMSampler(self.model)
        elif config.ldm_sampler == LDMSampler.plms:
            sampler = PLMSSampler(self.model)
        else:
            raise ValueError(f"Unknown LDM sampler: {config.ldm_sampler}")

        steps = config.ldm_steps
        image = norm_img(image)
        mask = norm_img(mask)

        # Binarize the mask: 1 marks pixels to inpaint, 0 marks pixels to keep.
        mask[mask < 0.5] = 0
        mask[mask >= 0.5] = 1

        image = torch.from_numpy(image).unsqueeze(0).to(self.device)
        mask = torch.from_numpy(mask).unsqueeze(0).to(self.device)

        # Blank out the masked region, then rescale both inputs to [-1, 1].
        masked_image = (1 - mask) * image
        mask = self._norm(mask)
        masked_image = self._norm(masked_image)

        c = self.cond_stage_model_encode(masked_image)
        torch.cuda.empty_cache()

        # Downsample the mask to the latent resolution and stack it onto the
        # encoded image as an extra conditioning channel.
        cc = torch.nn.functional.interpolate(mask, size=c.shape[-2:])  # 1, 1, 128, 128
        c = torch.cat((c, cc), dim=1)  # 1, 4, 128, 128

        # Sample latents at the same spatial size as `c`, minus the mask channel.
        shape = (c.shape[1] - 1,) + c.shape[2:]
        samples_ddim = sampler.sample(
            steps=steps, conditioning=c, batch_size=c.shape[0], shape=shape
        )
        torch.cuda.empty_cache()
        x_samples_ddim = self.cond_stage_model_decode(
            samples_ddim
        )  # samples_ddim: 1, 3, 128, 128 float32
        torch.cuda.empty_cache()

        # image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
        # mask = torch.clamp((mask + 1.0) / 2.0, min=0.0, max=1.0)
        inpainted_image = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
        # inpainted = (1 - mask) * image + mask * predicted_image

        # [1, 3, H, W] in [0, 1] -> [H, W, 3] uint8, then RGB -> BGR.
        inpainted_image = inpainted_image.cpu().numpy().transpose(0, 2, 3, 1)[0] * 255
        inpainted_image = inpainted_image.astype(np.uint8)[:, :, ::-1]
        return inpainted_image

    def _norm(self, tensor):
        # Map from [0, 1] to [-1, 1].
        return tensor * 2.0 - 1.0
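
# ---------------------------------------------------------------------------
# Usage sketch (illustrative, kept commented so importing this module has no
# side effects). It assumes `InpaintModel` makes instances callable on a numpy
# image/mask pair (the usual iopaint convention of padding to `pad_mod`,
# running `forward`, and cropping back), and that `InpaintRequest` accepts the
# `ldm_steps` / `ldm_sampler` fields used in `forward` above. File names are
# hypothetical placeholders.
#
# if __name__ == "__main__":
#     import cv2
#
#     LDM.download()  # fetch the three JIT checkpoints into the local cache
#     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#     model = LDM(device)
#     image = cv2.cvtColor(cv2.imread("input.png"), cv2.COLOR_BGR2RGB)  # [H, W, 3] RGB
#     mask = cv2.imread("mask.png", cv2.IMREAD_GRAYSCALE)[:, :, None]   # [H, W, 1]
#     config = InpaintRequest(ldm_steps=50, ldm_sampler=LDMSampler.ddim)
#     result_bgr = model(image, mask, config)  # forward returns a BGR uint8 image
#     cv2.imwrite("output.png", result_bgr)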