helper.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411
  1. import base64
  2. import hashlib
  3. import imghdr
  4. import io
  5. import os
  6. import sys
  7. from typing import Dict, List, Optional, Tuple
  8. from urllib.parse import urlparse
  9. import cv2
  10. import numpy as np
  11. import torch
  12. from loguru import logger
  13. from PIL import Image, ImageOps, PngImagePlugin
  14. from torch.hub import download_url_to_file, get_dir
  15. from sorawm.iopaint.const import MPS_UNSUPPORT_MODELS
  16. def md5sum(filename):
  17. md5 = hashlib.md5()
  18. with open(filename, "rb") as f:
  19. for chunk in iter(lambda: f.read(128 * md5.block_size), b""):
  20. md5.update(chunk)
  21. return md5.hexdigest()
  22. def switch_mps_device(model_name, device):
  23. if model_name in MPS_UNSUPPORT_MODELS and str(device) == "mps":
  24. logger.info(f"{model_name} not support mps, switch to cpu")
  25. return torch.device("cpu")
  26. return device
  27. def get_cache_path_by_url(url):
  28. parts = urlparse(url)
  29. hub_dir = get_dir()
  30. model_dir = os.path.join(hub_dir, "checkpoints")
  31. if not os.path.isdir(model_dir):
  32. os.makedirs(model_dir)
  33. filename = os.path.basename(parts.path)
  34. cached_file = os.path.join(model_dir, filename)
  35. return cached_file
  36. def download_model(url, model_md5: str = None):
  37. if os.path.exists(url):
  38. cached_file = url
  39. else:
  40. cached_file = get_cache_path_by_url(url)
  41. if not os.path.exists(cached_file):
  42. sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
  43. hash_prefix = None
  44. download_url_to_file(url, cached_file, hash_prefix, progress=True)
  45. if model_md5:
  46. _md5 = md5sum(cached_file)
  47. if model_md5 == _md5:
  48. logger.info(f"Download model success, md5: {_md5}")
  49. else:
  50. try:
  51. os.remove(cached_file)
  52. logger.error(
  53. f"Model md5: {_md5}, expected md5: {model_md5}, wrong model deleted. Please restart sorawm.iopaint."
  54. f"If you still have errors, please try download model manually first https://lama-cleaner-docs.vercel.app/install/download_model_manually.\n"
  55. )
  56. except:
  57. logger.error(
  58. f"Model md5: {_md5}, expected md5: {model_md5}, please delete {cached_file} and restart sorawm.iopaint."
  59. )
  60. exit(-1)
  61. return cached_file
  62. def ceil_modulo(x, mod):
  63. if x % mod == 0:
  64. return x
  65. return (x // mod + 1) * mod
  66. def handle_error(model_path, model_md5, e):
  67. _md5 = md5sum(model_path)
  68. if _md5 != model_md5:
  69. try:
  70. os.remove(model_path)
  71. logger.error(
  72. f"Model md5: {_md5}, expected md5: {model_md5}, wrong model deleted. Please restart sorawm.iopaint."
  73. f"If you still have errors, please try download model manually first https://lama-cleaner-docs.vercel.app/install/download_model_manually.\n"
  74. )
  75. except:
  76. logger.error(
  77. f"Model md5: {_md5}, expected md5: {model_md5}, please delete {model_path} and restart sorawm.iopaint."
  78. )
  79. else:
  80. logger.error(
  81. f"Failed to load model {model_path},"
  82. f"please submit an issue at https://github.com/Sanster/lama-cleaner/issues and include a screenshot of the error:\n{e}"
  83. )
  84. exit(-1)
  85. def load_jit_model(url_or_path, device, model_md5: str):
  86. if os.path.exists(url_or_path):
  87. model_path = url_or_path
  88. else:
  89. model_path = download_model(url_or_path, model_md5)
  90. logger.info(f"Loading model from: {model_path}")
  91. try:
  92. model = torch.jit.load(model_path, map_location="cpu").to(device)
  93. except Exception as e:
  94. handle_error(model_path, model_md5, e)
  95. model.eval()
  96. return model
  97. def load_model(model: torch.nn.Module, url_or_path, device, model_md5):
  98. if os.path.exists(url_or_path):
  99. model_path = url_or_path
  100. else:
  101. model_path = download_model(url_or_path, model_md5)
  102. try:
  103. logger.info(f"Loading model from: {model_path}")
  104. state_dict = torch.load(model_path, map_location="cpu")
  105. model.load_state_dict(state_dict, strict=True)
  106. model.to(device)
  107. except Exception as e:
  108. handle_error(model_path, model_md5, e)
  109. model.eval()
  110. return model
  111. def numpy_to_bytes(image_numpy: np.ndarray, ext: str) -> bytes:
  112. data = cv2.imencode(
  113. f".{ext}",
  114. image_numpy,
  115. [int(cv2.IMWRITE_JPEG_QUALITY), 100, int(cv2.IMWRITE_PNG_COMPRESSION), 0],
  116. )[1]
  117. image_bytes = data.tobytes()
  118. return image_bytes
  119. def pil_to_bytes(pil_img, ext: str, quality: int = 95, infos={}) -> bytes:
  120. with io.BytesIO() as output:
  121. kwargs = {k: v for k, v in infos.items() if v is not None}
  122. if ext == "jpg":
  123. ext = "jpeg"
  124. if "png" == ext.lower() and "parameters" in kwargs:
  125. pnginfo_data = PngImagePlugin.PngInfo()
  126. pnginfo_data.add_text("parameters", kwargs["parameters"])
  127. kwargs["pnginfo"] = pnginfo_data
  128. pil_img.save(output, format=ext, quality=quality, **kwargs)
  129. image_bytes = output.getvalue()
  130. return image_bytes
  131. def load_img(img_bytes, gray: bool = False, return_info: bool = False):
  132. alpha_channel = None
  133. image = Image.open(io.BytesIO(img_bytes))
  134. if return_info:
  135. infos = image.info
  136. try:
  137. image = ImageOps.exif_transpose(image)
  138. except:
  139. pass
  140. if gray:
  141. image = image.convert("L")
  142. np_img = np.array(image)
  143. else:
  144. if image.mode == "RGBA":
  145. np_img = np.array(image)
  146. alpha_channel = np_img[:, :, -1]
  147. np_img = cv2.cvtColor(np_img, cv2.COLOR_RGBA2RGB)
  148. else:
  149. image = image.convert("RGB")
  150. np_img = np.array(image)
  151. if return_info:
  152. return np_img, alpha_channel, infos
  153. return np_img, alpha_channel
  154. def norm_img(np_img):
  155. if len(np_img.shape) == 2:
  156. np_img = np_img[:, :, np.newaxis]
  157. np_img = np.transpose(np_img, (2, 0, 1))
  158. np_img = np_img.astype("float32") / 255
  159. return np_img
  160. def resize_max_size(
  161. np_img, size_limit: int, interpolation=cv2.INTER_CUBIC
  162. ) -> np.ndarray:
  163. # Resize image's longer size to size_limit if longer size larger than size_limit
  164. h, w = np_img.shape[:2]
  165. if max(h, w) > size_limit:
  166. ratio = size_limit / max(h, w)
  167. new_w = int(w * ratio + 0.5)
  168. new_h = int(h * ratio + 0.5)
  169. return cv2.resize(np_img, dsize=(new_w, new_h), interpolation=interpolation)
  170. else:
  171. return np_img
  172. def pad_img_to_modulo(
  173. img: np.ndarray, mod: int, square: bool = False, min_size: Optional[int] = None
  174. ):
  175. """
  176. Args:
  177. img: [H, W, C]
  178. mod:
  179. square: 是否为正方形
  180. min_size:
  181. Returns:
  182. """
  183. if len(img.shape) == 2:
  184. img = img[:, :, np.newaxis]
  185. height, width = img.shape[:2]
  186. out_height = ceil_modulo(height, mod)
  187. out_width = ceil_modulo(width, mod)
  188. if min_size is not None:
  189. assert min_size % mod == 0
  190. out_width = max(min_size, out_width)
  191. out_height = max(min_size, out_height)
  192. if square:
  193. max_size = max(out_height, out_width)
  194. out_height = max_size
  195. out_width = max_size
  196. return np.pad(
  197. img,
  198. ((0, out_height - height), (0, out_width - width), (0, 0)),
  199. mode="symmetric",
  200. )
  201. def boxes_from_mask(mask: np.ndarray) -> List[np.ndarray]:
  202. """
  203. Args:
  204. mask: (h, w, 1) 0~255
  205. Returns:
  206. """
  207. height, width = mask.shape[:2]
  208. _, thresh = cv2.threshold(mask, 127, 255, 0)
  209. contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
  210. boxes = []
  211. for cnt in contours:
  212. x, y, w, h = cv2.boundingRect(cnt)
  213. box = np.array([x, y, x + w, y + h]).astype(int)
  214. box[::2] = np.clip(box[::2], 0, width)
  215. box[1::2] = np.clip(box[1::2], 0, height)
  216. boxes.append(box)
  217. return boxes
  218. def only_keep_largest_contour(mask: np.ndarray) -> List[np.ndarray]:
  219. """
  220. Args:
  221. mask: (h, w) 0~255
  222. Returns:
  223. """
  224. _, thresh = cv2.threshold(mask, 127, 255, 0)
  225. contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
  226. max_area = 0
  227. max_index = -1
  228. for i, cnt in enumerate(contours):
  229. area = cv2.contourArea(cnt)
  230. if area > max_area:
  231. max_area = area
  232. max_index = i
  233. if max_index != -1:
  234. new_mask = np.zeros_like(mask)
  235. return cv2.drawContours(new_mask, contours, max_index, 255, -1)
  236. else:
  237. return mask
  238. def is_mac():
  239. return sys.platform == "darwin"
  240. def get_image_ext(img_bytes):
  241. w = imghdr.what("", img_bytes)
  242. if w is None:
  243. w = "jpeg"
  244. return w
  245. def decode_base64_to_image(
  246. encoding: str, gray=False
  247. ) -> Tuple[np.array, Optional[np.array], Dict, str]:
  248. if encoding.startswith("data:image/") or encoding.startswith(
  249. "data:application/octet-stream;base64,"
  250. ):
  251. encoding = encoding.split(";")[1].split(",")[1]
  252. image_bytes = base64.b64decode(encoding)
  253. ext = get_image_ext(image_bytes)
  254. image = Image.open(io.BytesIO(image_bytes))
  255. alpha_channel = None
  256. try:
  257. image = ImageOps.exif_transpose(image)
  258. except:
  259. pass
  260. # exif_transpose will remove exif rotate info,we must call image.info after exif_transpose
  261. infos = image.info
  262. if gray:
  263. image = image.convert("L")
  264. np_img = np.array(image)
  265. else:
  266. if image.mode == "RGBA":
  267. np_img = np.array(image)
  268. alpha_channel = np_img[:, :, -1]
  269. np_img = cv2.cvtColor(np_img, cv2.COLOR_RGBA2RGB)
  270. else:
  271. image = image.convert("RGB")
  272. np_img = np.array(image)
  273. return np_img, alpha_channel, infos, ext
  274. def encode_pil_to_base64(image: Image, quality: int, infos: Dict) -> bytes:
  275. img_bytes = pil_to_bytes(
  276. image,
  277. "png",
  278. quality=quality,
  279. infos=infos,
  280. )
  281. return base64.b64encode(img_bytes)
  282. def concat_alpha_channel(rgb_np_img, alpha_channel) -> np.ndarray:
  283. if alpha_channel is not None:
  284. if alpha_channel.shape[:2] != rgb_np_img.shape[:2]:
  285. alpha_channel = cv2.resize(
  286. alpha_channel, dsize=(rgb_np_img.shape[1], rgb_np_img.shape[0])
  287. )
  288. rgb_np_img = np.concatenate(
  289. (rgb_np_img, alpha_channel[:, :, np.newaxis]), axis=-1
  290. )
  291. return rgb_np_img
  292. def adjust_mask(mask: np.ndarray, kernel_size: int, operate):
  293. # fronted brush color "ffcc00bb"
  294. # kernel_size = kernel_size*2+1
  295. mask[mask >= 127] = 255
  296. mask[mask < 127] = 0
  297. if operate == "reverse":
  298. mask = 255 - mask
  299. else:
  300. kernel = cv2.getStructuringElement(
  301. cv2.MORPH_ELLIPSE, (2 * kernel_size + 1, 2 * kernel_size + 1)
  302. )
  303. if operate == "expand":
  304. mask = cv2.dilate(
  305. mask,
  306. kernel,
  307. iterations=1,
  308. )
  309. else:
  310. mask = cv2.erode(
  311. mask,
  312. kernel,
  313. iterations=1,
  314. )
  315. res_mask = np.zeros((mask.shape[0], mask.shape[1], 4), dtype=np.uint8)
  316. res_mask[mask > 128] = [255, 203, 0, int(255 * 0.73)]
  317. res_mask = cv2.cvtColor(res_mask, cv2.COLOR_BGRA2RGBA)
  318. return res_mask
  319. def gen_frontend_mask(bgr_or_gray_mask):
  320. if len(bgr_or_gray_mask.shape) == 3 and bgr_or_gray_mask.shape[2] != 1:
  321. bgr_or_gray_mask = cv2.cvtColor(bgr_or_gray_mask, cv2.COLOR_BGR2GRAY)
  322. # fronted brush color "ffcc00bb"
  323. # TODO: how to set kernel size?
  324. kernel_size = 9
  325. bgr_or_gray_mask = cv2.dilate(
  326. bgr_or_gray_mask,
  327. np.ones((kernel_size, kernel_size), np.uint8),
  328. iterations=1,
  329. )
  330. res_mask = np.zeros(
  331. (bgr_or_gray_mask.shape[0], bgr_or_gray_mask.shape[1], 4), dtype=np.uint8
  332. )
  333. res_mask[bgr_or_gray_mask > 128] = [255, 203, 0, int(255 * 0.73)]
  334. res_mask = cv2.cvtColor(res_mask, cv2.COLOR_BGRA2RGBA)
  335. return res_mask