# upscaling.py
  1. from functools import partial
  2. import numpy as np
  3. import torch
  4. import torch.nn as nn
  5. from sorawm.iopaint.model.anytext.ldm.modules.diffusionmodules.util import (
  6. extract_into_tensor,
  7. make_beta_schedule,
  8. )
  9. from sorawm.iopaint.model.anytext.ldm.util import default
  10. class AbstractLowScaleModel(nn.Module):
  11. # for concatenating a downsampled image to the latent representation
  12. def __init__(self, noise_schedule_config=None):
  13. super(AbstractLowScaleModel, self).__init__()
  14. if noise_schedule_config is not None:
  15. self.register_schedule(**noise_schedule_config)
  16. def register_schedule(
  17. self,
  18. beta_schedule="linear",
  19. timesteps=1000,
  20. linear_start=1e-4,
  21. linear_end=2e-2,
  22. cosine_s=8e-3,
  23. ):
  24. betas = make_beta_schedule(
  25. beta_schedule,
  26. timesteps,
  27. linear_start=linear_start,
  28. linear_end=linear_end,
  29. cosine_s=cosine_s,
  30. )
  31. alphas = 1.0 - betas
  32. alphas_cumprod = np.cumprod(alphas, axis=0)
  33. alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])
  34. (timesteps,) = betas.shape
  35. self.num_timesteps = int(timesteps)
  36. self.linear_start = linear_start
  37. self.linear_end = linear_end
  38. assert (
  39. alphas_cumprod.shape[0] == self.num_timesteps
  40. ), "alphas have to be defined for each timestep"
  41. to_torch = partial(torch.tensor, dtype=torch.float32)
  42. self.register_buffer("betas", to_torch(betas))
  43. self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
  44. self.register_buffer("alphas_cumprod_prev", to_torch(alphas_cumprod_prev))
  45. # calculations for diffusion q(x_t | x_{t-1}) and others
  46. self.register_buffer("sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod)))
  47. self.register_buffer(
  48. "sqrt_one_minus_alphas_cumprod", to_torch(np.sqrt(1.0 - alphas_cumprod))
  49. )
  50. self.register_buffer(
  51. "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod))
  52. )
  53. self.register_buffer(
  54. "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod))
  55. )
  56. self.register_buffer(
  57. "sqrt_recipm1_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod - 1))
  58. )
  59. def q_sample(self, x_start, t, noise=None):
  60. noise = default(noise, lambda: torch.randn_like(x_start))
  61. return (
  62. extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
  63. + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
  64. * noise
  65. )
  66. def forward(self, x):
  67. return x, None
  68. def decode(self, x):
  69. return x
  70. class SimpleImageConcat(AbstractLowScaleModel):
  71. # no noise level conditioning
  72. def __init__(self):
  73. super(SimpleImageConcat, self).__init__(noise_schedule_config=None)
  74. self.max_noise_level = 0
  75. def forward(self, x):
  76. # fix to constant noise level
  77. return x, torch.zeros(x.shape[0], device=x.device).long()
  78. class ImageConcatWithNoiseAugmentation(AbstractLowScaleModel):
  79. def __init__(self, noise_schedule_config, max_noise_level=1000, to_cuda=False):
  80. super().__init__(noise_schedule_config=noise_schedule_config)
  81. self.max_noise_level = max_noise_level
  82. def forward(self, x, noise_level=None):
  83. if noise_level is None:
  84. noise_level = torch.randint(
  85. 0, self.max_noise_level, (x.shape[0],), device=x.device
  86. ).long()
  87. else:
  88. assert isinstance(noise_level, torch.Tensor)
  89. z = self.q_sample(x, noise_level)
  90. return z, noise_level