# controlnet_preprocess.py
  1. import cv2
  2. import numpy as np
  3. import PIL
  4. import torch
  5. from PIL import Image
  6. from sorawm.iopaint.helper import pad_img_to_modulo
  7. def make_canny_control_image(image: np.ndarray) -> Image:
  8. canny_image = cv2.Canny(image, 100, 200)
  9. canny_image = canny_image[:, :, None]
  10. canny_image = np.concatenate([canny_image, canny_image, canny_image], axis=2)
  11. canny_image = PIL.Image.fromarray(canny_image)
  12. control_image = canny_image
  13. return control_image
  14. def make_openpose_control_image(image: np.ndarray) -> Image:
  15. from controlnet_aux import OpenposeDetector
  16. processor = OpenposeDetector.from_pretrained("lllyasviel/ControlNet")
  17. control_image = processor(image, hand_and_face=True)
  18. return control_image
  19. def resize_image(input_image, resolution):
  20. H, W, C = input_image.shape
  21. H = float(H)
  22. W = float(W)
  23. k = float(resolution) / min(H, W)
  24. H *= k
  25. W *= k
  26. H = int(np.round(H / 64.0)) * 64
  27. W = int(np.round(W / 64.0)) * 64
  28. img = cv2.resize(
  29. input_image,
  30. (W, H),
  31. interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA,
  32. )
  33. return img
  34. def make_depth_control_image(image: np.ndarray) -> Image:
  35. from controlnet_aux import MidasDetector
  36. midas = MidasDetector.from_pretrained("lllyasviel/Annotators")
  37. origin_height, origin_width = image.shape[:2]
  38. pad_image = pad_img_to_modulo(image, mod=64, square=False, min_size=512)
  39. depth_image = midas(pad_image)
  40. depth_image = depth_image[0:origin_height, 0:origin_width]
  41. depth_image = depth_image[:, :, None]
  42. depth_image = np.concatenate([depth_image, depth_image, depth_image], axis=2)
  43. control_image = PIL.Image.fromarray(depth_image)
  44. return control_image
  45. def make_inpaint_control_image(image: np.ndarray, mask: np.ndarray) -> torch.Tensor:
  46. """
  47. image: [H, W, C] RGB
  48. mask: [H, W, 1] 255 means area to repaint
  49. """
  50. image = image.astype(np.float32) / 255.0
  51. image[mask[:, :, -1] > 128] = -1.0 # set as masked pixel
  52. image = np.expand_dims(image, 0).transpose(0, 3, 1, 2)
  53. image = torch.from_numpy(image)
  54. return image