ldm.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336
  1. import os
  2. import numpy as np
  3. import torch
  4. from loguru import logger
  5. from sorawm.iopaint.schema import InpaintRequest, LDMSampler
  6. from .base import InpaintModel
  7. from .ddim_sampler import DDIMSampler
  8. from .plms_sampler import PLMSSampler
  9. torch.manual_seed(42)
  10. import torch.nn as nn
  11. from sorawm.iopaint.helper import (
  12. download_model,
  13. get_cache_path_by_url,
  14. load_jit_model,
  15. norm_img,
  16. )
  17. from .utils import make_beta_schedule, timestep_embedding
  18. LDM_ENCODE_MODEL_URL = os.environ.get(
  19. "LDM_ENCODE_MODEL_URL",
  20. "https://github.com/Sanster/models/releases/download/add_ldm/cond_stage_model_encode.pt",
  21. )
  22. LDM_ENCODE_MODEL_MD5 = os.environ.get(
  23. "LDM_ENCODE_MODEL_MD5", "23239fc9081956a3e70de56472b3f296"
  24. )
  25. LDM_DECODE_MODEL_URL = os.environ.get(
  26. "LDM_DECODE_MODEL_URL",
  27. "https://github.com/Sanster/models/releases/download/add_ldm/cond_stage_model_decode.pt",
  28. )
  29. LDM_DECODE_MODEL_MD5 = os.environ.get(
  30. "LDM_DECODE_MODEL_MD5", "fe419cd15a750d37a4733589d0d3585c"
  31. )
  32. LDM_DIFFUSION_MODEL_URL = os.environ.get(
  33. "LDM_DIFFUSION_MODEL_URL",
  34. "https://github.com/Sanster/models/releases/download/add_ldm/diffusion.pt",
  35. )
  36. LDM_DIFFUSION_MODEL_MD5 = os.environ.get(
  37. "LDM_DIFFUSION_MODEL_MD5", "b0afda12bf790c03aba2a7431f11d22d"
  38. )
  39. class DDPM(nn.Module):
  40. # classic DDPM with Gaussian diffusion, in image space
  41. def __init__(
  42. self,
  43. device,
  44. timesteps=1000,
  45. beta_schedule="linear",
  46. linear_start=0.0015,
  47. linear_end=0.0205,
  48. cosine_s=0.008,
  49. original_elbo_weight=0.0,
  50. v_posterior=0.0, # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta
  51. l_simple_weight=1.0,
  52. parameterization="eps", # all assuming fixed variance schedules
  53. use_positional_encodings=False,
  54. ):
  55. super().__init__()
  56. self.device = device
  57. self.parameterization = parameterization
  58. self.use_positional_encodings = use_positional_encodings
  59. self.v_posterior = v_posterior
  60. self.original_elbo_weight = original_elbo_weight
  61. self.l_simple_weight = l_simple_weight
  62. self.register_schedule(
  63. beta_schedule=beta_schedule,
  64. timesteps=timesteps,
  65. linear_start=linear_start,
  66. linear_end=linear_end,
  67. cosine_s=cosine_s,
  68. )
  69. def register_schedule(
  70. self,
  71. given_betas=None,
  72. beta_schedule="linear",
  73. timesteps=1000,
  74. linear_start=1e-4,
  75. linear_end=2e-2,
  76. cosine_s=8e-3,
  77. ):
  78. betas = make_beta_schedule(
  79. self.device,
  80. beta_schedule,
  81. timesteps,
  82. linear_start=linear_start,
  83. linear_end=linear_end,
  84. cosine_s=cosine_s,
  85. )
  86. alphas = 1.0 - betas
  87. alphas_cumprod = np.cumprod(alphas, axis=0)
  88. alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])
  89. (timesteps,) = betas.shape
  90. self.num_timesteps = int(timesteps)
  91. self.linear_start = linear_start
  92. self.linear_end = linear_end
  93. assert (
  94. alphas_cumprod.shape[0] == self.num_timesteps
  95. ), "alphas have to be defined for each timestep"
  96. to_torch = lambda x: torch.tensor(x, dtype=torch.float32).to(self.device)
  97. self.register_buffer("betas", to_torch(betas))
  98. self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
  99. self.register_buffer("alphas_cumprod_prev", to_torch(alphas_cumprod_prev))
  100. # calculations for diffusion q(x_t | x_{t-1}) and others
  101. self.register_buffer("sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod)))
  102. self.register_buffer(
  103. "sqrt_one_minus_alphas_cumprod", to_torch(np.sqrt(1.0 - alphas_cumprod))
  104. )
  105. self.register_buffer(
  106. "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod))
  107. )
  108. self.register_buffer(
  109. "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod))
  110. )
  111. self.register_buffer(
  112. "sqrt_recipm1_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod - 1))
  113. )
  114. # calculations for posterior q(x_{t-1} | x_t, x_0)
  115. posterior_variance = (1 - self.v_posterior) * betas * (
  116. 1.0 - alphas_cumprod_prev
  117. ) / (1.0 - alphas_cumprod) + self.v_posterior * betas
  118. # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
  119. self.register_buffer("posterior_variance", to_torch(posterior_variance))
  120. # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
  121. self.register_buffer(
  122. "posterior_log_variance_clipped",
  123. to_torch(np.log(np.maximum(posterior_variance, 1e-20))),
  124. )
  125. self.register_buffer(
  126. "posterior_mean_coef1",
  127. to_torch(betas * np.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)),
  128. )
  129. self.register_buffer(
  130. "posterior_mean_coef2",
  131. to_torch(
  132. (1.0 - alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - alphas_cumprod)
  133. ),
  134. )
  135. if self.parameterization == "eps":
  136. lvlb_weights = self.betas**2 / (
  137. 2
  138. * self.posterior_variance
  139. * to_torch(alphas)
  140. * (1 - self.alphas_cumprod)
  141. )
  142. elif self.parameterization == "x0":
  143. lvlb_weights = (
  144. 0.5
  145. * np.sqrt(torch.Tensor(alphas_cumprod))
  146. / (2.0 * 1 - torch.Tensor(alphas_cumprod))
  147. )
  148. else:
  149. raise NotImplementedError("mu not supported")
  150. # TODO how to choose this term
  151. lvlb_weights[0] = lvlb_weights[1]
  152. self.register_buffer("lvlb_weights", lvlb_weights, persistent=False)
  153. assert not torch.isnan(self.lvlb_weights).all()
  154. class LatentDiffusion(DDPM):
  155. def __init__(
  156. self,
  157. diffusion_model,
  158. device,
  159. cond_stage_key="image",
  160. cond_stage_trainable=False,
  161. concat_mode=True,
  162. scale_factor=1.0,
  163. scale_by_std=False,
  164. *args,
  165. **kwargs,
  166. ):
  167. self.num_timesteps_cond = 1
  168. self.scale_by_std = scale_by_std
  169. super().__init__(device, *args, **kwargs)
  170. self.diffusion_model = diffusion_model
  171. self.concat_mode = concat_mode
  172. self.cond_stage_trainable = cond_stage_trainable
  173. self.cond_stage_key = cond_stage_key
  174. self.num_downs = 2
  175. self.scale_factor = scale_factor
  176. def make_cond_schedule(
  177. self,
  178. ):
  179. self.cond_ids = torch.full(
  180. size=(self.num_timesteps,),
  181. fill_value=self.num_timesteps - 1,
  182. dtype=torch.long,
  183. )
  184. ids = torch.round(
  185. torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)
  186. ).long()
  187. self.cond_ids[: self.num_timesteps_cond] = ids
  188. def register_schedule(
  189. self,
  190. given_betas=None,
  191. beta_schedule="linear",
  192. timesteps=1000,
  193. linear_start=1e-4,
  194. linear_end=2e-2,
  195. cosine_s=8e-3,
  196. ):
  197. super().register_schedule(
  198. given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s
  199. )
  200. self.shorten_cond_schedule = self.num_timesteps_cond > 1
  201. if self.shorten_cond_schedule:
  202. self.make_cond_schedule()
  203. def apply_model(self, x_noisy, t, cond):
  204. # x_recon = self.model(x_noisy, t, cond['c_concat'][0]) # cond['c_concat'][0].shape 1,4,128,128
  205. t_emb = timestep_embedding(x_noisy.device, t, 256, repeat_only=False)
  206. x_recon = self.diffusion_model(x_noisy, t_emb, cond)
  207. return x_recon
  208. class LDM(InpaintModel):
  209. name = "ldm"
  210. pad_mod = 32
  211. is_erase_model = True
  212. def __init__(self, device, fp16: bool = True, **kwargs):
  213. self.fp16 = fp16
  214. super().__init__(device)
  215. self.device = device
  216. def init_model(self, device, **kwargs):
  217. self.diffusion_model = load_jit_model(
  218. LDM_DIFFUSION_MODEL_URL, device, LDM_DIFFUSION_MODEL_MD5
  219. )
  220. self.cond_stage_model_decode = load_jit_model(
  221. LDM_DECODE_MODEL_URL, device, LDM_DECODE_MODEL_MD5
  222. )
  223. self.cond_stage_model_encode = load_jit_model(
  224. LDM_ENCODE_MODEL_URL, device, LDM_ENCODE_MODEL_MD5
  225. )
  226. if self.fp16 and "cuda" in str(device):
  227. self.diffusion_model = self.diffusion_model.half()
  228. self.cond_stage_model_decode = self.cond_stage_model_decode.half()
  229. self.cond_stage_model_encode = self.cond_stage_model_encode.half()
  230. self.model = LatentDiffusion(self.diffusion_model, device)
  231. @staticmethod
  232. def download():
  233. download_model(LDM_DIFFUSION_MODEL_URL, LDM_DIFFUSION_MODEL_MD5)
  234. download_model(LDM_DECODE_MODEL_URL, LDM_DECODE_MODEL_MD5)
  235. download_model(LDM_ENCODE_MODEL_URL, LDM_ENCODE_MODEL_MD5)
  236. @staticmethod
  237. def is_downloaded() -> bool:
  238. model_paths = [
  239. get_cache_path_by_url(LDM_DIFFUSION_MODEL_URL),
  240. get_cache_path_by_url(LDM_DECODE_MODEL_URL),
  241. get_cache_path_by_url(LDM_ENCODE_MODEL_URL),
  242. ]
  243. return all([os.path.exists(it) for it in model_paths])
  244. @torch.cuda.amp.autocast()
  245. def forward(self, image, mask, config: InpaintRequest):
  246. """
  247. image: [H, W, C] RGB
  248. mask: [H, W, 1]
  249. return: BGR IMAGE
  250. """
  251. # image [1,3,512,512] float32
  252. # mask: [1,1,512,512] float32
  253. # masked_image: [1,3,512,512] float32
  254. if config.ldm_sampler == LDMSampler.ddim:
  255. sampler = DDIMSampler(self.model)
  256. elif config.ldm_sampler == LDMSampler.plms:
  257. sampler = PLMSSampler(self.model)
  258. else:
  259. raise ValueError()
  260. steps = config.ldm_steps
  261. image = norm_img(image)
  262. mask = norm_img(mask)
  263. mask[mask < 0.5] = 0
  264. mask[mask >= 0.5] = 1
  265. image = torch.from_numpy(image).unsqueeze(0).to(self.device)
  266. mask = torch.from_numpy(mask).unsqueeze(0).to(self.device)
  267. masked_image = (1 - mask) * image
  268. mask = self._norm(mask)
  269. masked_image = self._norm(masked_image)
  270. c = self.cond_stage_model_encode(masked_image)
  271. torch.cuda.empty_cache()
  272. cc = torch.nn.functional.interpolate(mask, size=c.shape[-2:]) # 1,1,128,128
  273. c = torch.cat((c, cc), dim=1) # 1,4,128,128
  274. shape = (c.shape[1] - 1,) + c.shape[2:]
  275. samples_ddim = sampler.sample(
  276. steps=steps, conditioning=c, batch_size=c.shape[0], shape=shape
  277. )
  278. torch.cuda.empty_cache()
  279. x_samples_ddim = self.cond_stage_model_decode(
  280. samples_ddim
  281. ) # samples_ddim: 1, 3, 128, 128 float32
  282. torch.cuda.empty_cache()
  283. # image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
  284. # mask = torch.clamp((mask + 1.0) / 2.0, min=0.0, max=1.0)
  285. inpainted_image = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
  286. # inpainted = (1 - mask) * image + mask * predicted_image
  287. inpainted_image = inpainted_image.cpu().numpy().transpose(0, 2, 3, 1)[0] * 255
  288. inpainted_image = inpainted_image.astype(np.uint8)[:, :, ::-1]
  289. return inpainted_image
  290. def _norm(self, tensor):
  291. return tensor * 2.0 - 1.0