# ddim_sampler.py

import numpy as np
import torch
from loguru import logger
from tqdm import tqdm

from .utils import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like

class DDIMSampler(object):
    """DDIM sampler that wraps a diffusion model and reuses its noise schedule."""

    def __init__(self, model, schedule="linear"):
        super().__init__()
        self.model = model
        self.ddpm_num_timesteps = model.num_timesteps
        self.schedule = schedule

    def register_buffer(self, name, attr):
        # Unlike nn.Module.register_buffer, this simply stores the tensor as an attribute.
        setattr(self, name, attr)

    def make_schedule(
        self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.0, verbose=True
    ):
        self.ddim_timesteps = make_ddim_timesteps(
            ddim_discr_method=ddim_discretize,
            num_ddim_timesteps=ddim_num_steps,
            # array([1])
            num_ddpm_timesteps=self.ddpm_num_timesteps,
            verbose=verbose,
        )
        alphas_cumprod = self.model.alphas_cumprod  # torch.Size([1000])
        assert (
            alphas_cumprod.shape[0] == self.ddpm_num_timesteps
        ), "alphas have to be defined for each timestep"
        to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)

        self.register_buffer("betas", to_torch(self.model.betas))
        self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
        self.register_buffer(
            "alphas_cumprod_prev", to_torch(self.model.alphas_cumprod_prev)
        )

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer(
            "sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod.cpu()))
        )
        self.register_buffer(
            "sqrt_one_minus_alphas_cumprod",
            to_torch(np.sqrt(1.0 - alphas_cumprod.cpu())),
        )
        self.register_buffer(
            "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod.cpu()))
        )
        self.register_buffer(
            "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod.cpu()))
        )
        self.register_buffer(
            "sqrt_recipm1_alphas_cumprod",
            to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1)),
        )

        # ddim sampling parameters
        ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(
            alphacums=alphas_cumprod.cpu(),
            ddim_timesteps=self.ddim_timesteps,
            eta=ddim_eta,
            verbose=verbose,
        )
        self.register_buffer("ddim_sigmas", ddim_sigmas)
        self.register_buffer("ddim_alphas", ddim_alphas)
        self.register_buffer("ddim_alphas_prev", ddim_alphas_prev)
        self.register_buffer("ddim_sqrt_one_minus_alphas", np.sqrt(1.0 - ddim_alphas))
        sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
            (1 - self.alphas_cumprod_prev)
            / (1 - self.alphas_cumprod)
            * (1 - self.alphas_cumprod / self.alphas_cumprod_prev)
        )
        self.register_buffer(
            "ddim_sigmas_for_original_num_steps", sigmas_for_original_sampling_steps
        )
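        # NOTE: sigmas_for_original_sampling_steps above is the DDIM variance term
        # (cf. Song et al., 2021, Eq. 16):
        #   sigma_t = eta * sqrt((1 - a_prev) / (1 - a_t)) * sqrt(1 - a_t / a_prev),
        # with a_t = alphas_cumprod[t]. When ddim_eta == 0 every sigma is zero, so
        # sampling is deterministic; larger eta reintroduces DDPM-style noise.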

    @torch.no_grad()
    def sample(self, steps, conditioning, batch_size, shape):
        # ddim_eta=0 makes the reverse process fully deterministic
        self.make_schedule(ddim_num_steps=steps, ddim_eta=0, verbose=False)
        # sampling
        C, H, W = shape
        size = (batch_size, C, H, W)
        # returned samples have shape (batch_size, C, H, W), e.g. (1, 3, 128, 128)
        return self.ddim_sampling(
            conditioning,
            size,
            quantize_denoised=False,
            ddim_use_original_steps=False,
            noise_dropout=0,
            temperature=1.0,
        )

    @torch.no_grad()
    def ddim_sampling(
        self,
        cond,
        shape,
        ddim_use_original_steps=False,
        quantize_denoised=False,
        temperature=1.0,
        noise_dropout=0.0,
    ):
        device = self.model.betas.device
        b = shape[0]
        img = torch.randn(shape, device=device, dtype=cond.dtype)
        timesteps = (
            self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
        )
        time_range = (
            reversed(range(0, timesteps))
            if ddim_use_original_steps
            else np.flip(timesteps)
        )
        total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
        logger.info(f"Running DDIM Sampling with {total_steps} timesteps")

        iterator = tqdm(time_range, desc="DDIM Sampler", total=total_steps)

        for i, step in enumerate(iterator):
            index = total_steps - i - 1
            ts = torch.full((b,), step, device=device, dtype=torch.long)
            outs = self.p_sample_ddim(
                img,
                cond,
                ts,
                index=index,
                use_original_steps=ddim_use_original_steps,
                quantize_denoised=quantize_denoised,
                temperature=temperature,
                noise_dropout=noise_dropout,
            )
            img, _ = outs
        return img
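
    # p_sample_ddim below performs a single reverse DDIM step (cf. Song et al., 2021,
    # Eq. 12). In the notation of the code:
    #   pred_x0 = (x_t - sqrt(1 - a_t) * e_t) / sqrt(a_t)
    #   x_{t-1} = sqrt(a_prev) * pred_x0 + sqrt(1 - a_prev - sigma_t^2) * e_t + sigma_t * z
    # where e_t is the model's predicted noise and z is fresh Gaussian noise
    # (irrelevant when eta == 0, since then sigma_t == 0).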
    @torch.no_grad()
    def p_sample_ddim(
        self,
        x,
        c,
        t,
        index,
        repeat_noise=False,
        use_original_steps=False,
        quantize_denoised=False,
        temperature=1.0,
        noise_dropout=0.0,
    ):
        b, *_, device = *x.shape, x.device
        e_t = self.model.apply_model(x, t, c)

        alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
        alphas_prev = (
            self.model.alphas_cumprod_prev
            if use_original_steps
            else self.ddim_alphas_prev
        )
        sqrt_one_minus_alphas = (
            self.model.sqrt_one_minus_alphas_cumprod
            if use_original_steps
            else self.ddim_sqrt_one_minus_alphas
        )
        sigmas = (
            self.model.ddim_sigmas_for_original_num_steps
            if use_original_steps
            else self.ddim_sigmas
        )
        # select parameters corresponding to the currently considered timestep
        a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
        a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
        sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
        sqrt_one_minus_at = torch.full(
            (b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device
        )

        # current prediction for x_0
        pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
        if quantize_denoised:  # not used: sample() always passes quantize_denoised=False
            pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
        # direction pointing to x_t
        dir_xt = (1.0 - a_prev - sigma_t**2).sqrt() * e_t
        noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
        if noise_dropout > 0.0:  # not used: sample() always passes noise_dropout=0
            noise = torch.nn.functional.dropout(noise, p=noise_dropout)
        x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
        return x_prev, pred_x0
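

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the sampler itself).
# DummyModel is hypothetical: it merely mimics the attributes this sampler reads
# from a real diffusion model (num_timesteps, device, betas, alphas_cumprod,
# alphas_cumprod_prev, apply_model); a trained model would predict noise with a
# network instead of returning random values. Run it as part of the package so
# the relative import of .utils resolves.
# ---------------------------------------------------------------------------
if __name__ == "__main__":

    class DummyModel:
        def __init__(self, num_timesteps=1000, device="cpu"):
            self.num_timesteps = num_timesteps
            self.device = torch.device(device)
            betas = torch.linspace(1e-4, 2e-2, num_timesteps)  # assumed linear schedule
            self.betas = betas
            self.alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
            self.alphas_cumprod_prev = torch.cat(
                [torch.ones(1), self.alphas_cumprod[:-1]]
            )

        def apply_model(self, x, t, c):
            # Stand-in for the trained denoiser: returns random "predicted noise".
            return torch.randn_like(x)

    sampler = DDIMSampler(DummyModel())
    cond = torch.zeros(1, 4, 16, 16)  # dummy conditioning; only its dtype is used here
    samples = sampler.sample(
        steps=10, conditioning=cond, batch_size=1, shape=(3, 32, 32)
    )
    print(samples.shape)  # expected: torch.Size([1, 3, 32, 32])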