| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411 |
- import base64
- import hashlib
- import imghdr
- import io
- import os
- import sys
- from typing import Dict, List, Optional, Tuple
- from urllib.parse import urlparse
- import cv2
- import numpy as np
- import torch
- from loguru import logger
- from PIL import Image, ImageOps, PngImagePlugin
- from torch.hub import download_url_to_file, get_dir
- from sorawm.iopaint.const import MPS_UNSUPPORT_MODELS
def md5sum(filename):
    """Return the hex MD5 digest of the file at ``filename``, read in chunks."""
    digest = hashlib.md5()
    chunk_size = 128 * digest.block_size
    with open(filename, "rb") as fp:
        while True:
            chunk = fp.read(chunk_size)
            if not chunk:
                break
            digest.update(chunk)
    return digest.hexdigest()
def switch_mps_device(model_name, device):
    """Return a CPU device for models known not to work on the MPS backend.

    Any other (model, device) combination is passed through unchanged.
    """
    wants_mps = str(device) == "mps"
    if wants_mps and model_name in MPS_UNSUPPORT_MODELS:
        logger.info(f"{model_name} not support mps, switch to cpu")
        return torch.device("cpu")
    return device
def get_cache_path_by_url(url):
    """Map a model URL to its local cache path under the torch hub dir.

    The filename is the last component of the URL path; the file lives in
    ``<torch hub dir>/checkpoints``, which is created if missing.

    Args:
        url: remote model URL.

    Returns:
        Absolute path the downloaded file should be cached at.
    """
    parts = urlparse(url)
    model_dir = os.path.join(get_dir(), "checkpoints")
    # exist_ok=True avoids the check-then-create race of the old
    # `if not os.path.isdir(...): os.makedirs(...)` pattern.
    os.makedirs(model_dir, exist_ok=True)
    filename = os.path.basename(parts.path)
    return os.path.join(model_dir, filename)
def download_model(url, model_md5: Optional[str] = None):
    """Download a model to the torch hub cache (if needed) and verify its MD5.

    Args:
        url: remote URL, or an existing local file path which is returned
            as-is without any download or verification.
        model_md5: expected MD5 hex digest; checked only right after a
            fresh download. On mismatch the bad file is deleted and the
            process exits.

    Returns:
        Path to the local model file.
    """
    if os.path.exists(url):
        cached_file = url
    else:
        cached_file = get_cache_path_by_url(url)
        if not os.path.exists(cached_file):
            sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
            hash_prefix = None
            download_url_to_file(url, cached_file, hash_prefix, progress=True)
            if model_md5:
                _md5 = md5sum(cached_file)
                if model_md5 == _md5:
                    logger.info(f"Download model success, md5: {_md5}")
                else:
                    try:
                        os.remove(cached_file)
                        logger.error(
                            f"Model md5: {_md5}, expected md5: {model_md5}, wrong model deleted. Please restart sorawm.iopaint."
                            f"If you still have errors, please try download model manually first https://lama-cleaner-docs.vercel.app/install/download_model_manually.\n"
                        )
                    # narrow from a bare `except:` — only removal failures
                    # (permissions, file in use) are expected here
                    except OSError:
                        logger.error(
                            f"Model md5: {_md5}, expected md5: {model_md5}, please delete {cached_file} and restart sorawm.iopaint."
                        )
                    exit(-1)
    return cached_file
def ceil_modulo(x, mod):
    """Round ``x`` up to the nearest multiple of ``mod``."""
    # adding the complement of x modulo mod lands exactly on the
    # next multiple (or on x itself when already divisible)
    return x + (-x) % mod
def handle_error(model_path, model_md5, e):
    """Report a model-loading failure, clean up a corrupt file, and exit.

    If the on-disk file's MD5 differs from the expected digest the corrupt
    file is removed and the user is told to restart; otherwise the original
    exception is surfaced with a pointer to the issue tracker. Always
    terminates the process with exit code -1.

    Args:
        model_path: local path of the model that failed to load.
        model_md5: expected MD5 hex digest.
        e: the exception raised while loading.
    """
    _md5 = md5sum(model_path)
    if _md5 != model_md5:
        try:
            os.remove(model_path)
            logger.error(
                f"Model md5: {_md5}, expected md5: {model_md5}, wrong model deleted. Please restart sorawm.iopaint."
                f"If you still have errors, please try download model manually first https://lama-cleaner-docs.vercel.app/install/download_model_manually.\n"
            )
        # narrow from a bare `except:` — only removal failures
        # (permissions, file in use) are expected here
        except OSError:
            logger.error(
                f"Model md5: {_md5}, expected md5: {model_md5}, please delete {model_path} and restart sorawm.iopaint."
            )
    else:
        logger.error(
            f"Failed to load model {model_path},"
            f"please submit an issue at https://github.com/Sanster/lama-cleaner/issues and include a screenshot of the error:\n{e}"
        )
    exit(-1)
def load_jit_model(url_or_path, device, model_md5: str):
    """Load a TorchScript model from a local path or URL onto ``device``.

    Downloads to the hub cache when given a URL. On load failure,
    ``handle_error`` verifies the file and terminates the process.
    """
    model_path = (
        url_or_path
        if os.path.exists(url_or_path)
        else download_model(url_or_path, model_md5)
    )
    logger.info(f"Loading model from: {model_path}")
    try:
        model = torch.jit.load(model_path, map_location="cpu").to(device)
    except Exception as e:
        handle_error(model_path, model_md5, e)
    model.eval()
    return model
def load_model(model: torch.nn.Module, url_or_path, device, model_md5):
    """Load a state dict into ``model`` from a local path or URL, move to device.

    Downloads to the hub cache when given a URL. On any failure,
    ``handle_error`` verifies the file and terminates the process.
    Returns the model in eval mode.
    """
    model_path = (
        url_or_path
        if os.path.exists(url_or_path)
        else download_model(url_or_path, model_md5)
    )
    try:
        logger.info(f"Loading model from: {model_path}")
        state_dict = torch.load(model_path, map_location="cpu")
        model.load_state_dict(state_dict, strict=True)
        model.to(device)
    except Exception as e:
        handle_error(model_path, model_md5, e)
    model.eval()
    return model
def numpy_to_bytes(image_numpy: np.ndarray, ext: str) -> bytes:
    """Encode a numpy image to bytes in the given format (e.g. "png", "jpg").

    Uses maximum JPEG quality and zero PNG compression; the irrelevant
    parameter is ignored by cv2 depending on the chosen format.
    """
    encode_params = [
        int(cv2.IMWRITE_JPEG_QUALITY), 100,
        int(cv2.IMWRITE_PNG_COMPRESSION), 0,
    ]
    _, buffer = cv2.imencode(f".{ext}", image_numpy, encode_params)
    return buffer.tobytes()
def pil_to_bytes(pil_img, ext: str, quality: int = 95, infos: Optional[Dict] = None) -> bytes:
    """Serialize a PIL image to bytes.

    Args:
        pil_img: image to save (anything with a PIL-compatible ``save``).
        ext: target format; "jpg" is normalized to "jpeg" for PIL.
        quality: quality value forwarded to ``save``.
        infos: optional metadata dict; entries with None values are
            dropped, and a "parameters" entry is embedded as a PNG text
            chunk when saving PNGs.

    Returns:
        The encoded image bytes.
    """
    # `infos=None` sentinel replaces the old mutable default `infos={}`,
    # which was shared across all calls.
    if infos is None:
        infos = {}
    with io.BytesIO() as output:
        kwargs = {k: v for k, v in infos.items() if v is not None}
        if ext == "jpg":
            ext = "jpeg"
        if "png" == ext.lower() and "parameters" in kwargs:
            pnginfo_data = PngImagePlugin.PngInfo()
            pnginfo_data.add_text("parameters", kwargs["parameters"])
            kwargs["pnginfo"] = pnginfo_data
        pil_img.save(output, format=ext, quality=quality, **kwargs)
        image_bytes = output.getvalue()
    return image_bytes
def load_img(img_bytes, gray: bool = False, return_info: bool = False):
    """Decode raw image bytes into a numpy array.

    Args:
        img_bytes: encoded image data.
        gray: decode to a single-channel grayscale array instead of RGB.
        return_info: also return the PIL metadata dict.

    Returns:
        (np_img, alpha_channel) — plus infos when return_info is True.
        alpha_channel is the A plane of RGBA sources, else None.
    """
    alpha_channel = None
    image = Image.open(io.BytesIO(img_bytes))

    # capture metadata before any conversion, matching the original order
    if return_info:
        infos = image.info

    try:
        image = ImageOps.exif_transpose(image)
    except:
        # best-effort: ignore images with corrupt EXIF orientation data
        pass

    if gray:
        np_img = np.array(image.convert("L"))
    elif image.mode == "RGBA":
        np_img = np.array(image)
        alpha_channel = np_img[:, :, -1]
        np_img = cv2.cvtColor(np_img, cv2.COLOR_RGBA2RGB)
    else:
        np_img = np.array(image.convert("RGB"))

    if return_info:
        return np_img, alpha_channel, infos
    return np_img, alpha_channel
def norm_img(np_img):
    """Convert an HWC (or HW) uint8 image to a CHW float32 array in [0, 1]."""
    if np_img.ndim == 2:
        # promote grayscale to a single channel before transposing
        np_img = np_img[:, :, np.newaxis]
    chw = np.transpose(np_img, (2, 0, 1))
    return chw.astype("float32") / 255
def resize_max_size(
    np_img, size_limit: int, interpolation=cv2.INTER_CUBIC
) -> np.ndarray:
    """Downscale so the longer side is at most ``size_limit``; never upscales."""
    h, w = np_img.shape[:2]
    longest = max(h, w)
    if longest <= size_limit:
        return np_img
    scale = size_limit / longest
    # round-half-up to the nearest pixel
    new_size = (int(w * scale + 0.5), int(h * scale + 0.5))
    return cv2.resize(np_img, dsize=new_size, interpolation=interpolation)
def pad_img_to_modulo(
    img: np.ndarray, mod: int, square: bool = False, min_size: Optional[int] = None
):
    """Symmetrically pad an image so H and W become multiples of ``mod``.

    Args:
        img: [H, W, C] array; a 2-D input is promoted to one channel.
        mod: target multiple for both spatial dimensions.
        square: pad both dimensions to the larger padded side.
        min_size: optional lower bound for both dims (must divide by mod).

    Returns:
        Padded [H', W', C] array (bottom/right padding, symmetric mode).
    """
    if img.ndim == 2:
        img = img[:, :, np.newaxis]
    height, width = img.shape[:2]
    # ceil each dimension to the next multiple of mod (inlined ceil_modulo)
    out_height = -(-height // mod) * mod
    out_width = -(-width // mod) * mod

    if min_size is not None:
        assert min_size % mod == 0
        out_height = max(min_size, out_height)
        out_width = max(min_size, out_width)

    if square:
        side = max(out_height, out_width)
        out_height = out_width = side

    pad_spec = ((0, out_height - height), (0, out_width - width), (0, 0))
    return np.pad(img, pad_spec, mode="symmetric")
def boxes_from_mask(mask: np.ndarray) -> List[np.ndarray]:
    """Bounding boxes [x0, y0, x1, y1] for each external contour of a mask.

    Args:
        mask: (h, w, 1) 0~255
    Returns:
        List of int arrays, each clamped to the image bounds.
    """
    height, width = mask.shape[:2]
    _, thresh = cv2.threshold(mask, 127, 255, 0)
    contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    boxes = []
    for contour in contours:
        x, y, w, h = cv2.boundingRect(contour)
        box = np.array([x, y, x + w, y + h]).astype(int)
        # even indices hold x coordinates, odd indices hold y coordinates
        box[::2] = np.clip(box[::2], 0, width)
        box[1::2] = np.clip(box[1::2], 0, height)
        boxes.append(box)
    return boxes
def only_keep_largest_contour(mask: np.ndarray) -> List[np.ndarray]:
    """Keep only the largest external contour of a 0~255 (h, w) mask.

    Returns a new mask with the largest contour filled with 255, or the
    input mask unchanged when no contour with positive area exists.
    """
    _, thresh = cv2.threshold(mask, 127, 255, 0)
    contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    # strict `>` keeps the first of equal-area contours and rejects
    # zero-area ones, matching the original selection rule
    best_idx = -1
    best_area = 0
    for idx, contour in enumerate(contours):
        area = cv2.contourArea(contour)
        if area > best_area:
            best_area = area
            best_idx = idx

    if best_idx == -1:
        return mask
    return cv2.drawContours(np.zeros_like(mask), contours, best_idx, 255, -1)
def is_mac():
    """Return True when running on macOS."""
    return "darwin" == sys.platform
def get_image_ext(img_bytes):
    """Sniff an image format name from its leading magic bytes.

    Replaces the stdlib ``imghdr`` call: ``imghdr`` was deprecated by
    PEP 594 and removed in Python 3.13. Returns the same names imghdr
    used for the formats this app serves ("png", "jpeg", "gif", "webp",
    "tiff", "bmp"); anything unrecognized falls back to "jpeg", matching
    the old default when detection failed.

    Args:
        img_bytes: encoded image data (at least the first 12 bytes).

    Returns:
        Lowercase format name string.
    """
    header = bytes(img_bytes[:12])
    if header.startswith(b"\x89PNG\r\n\x1a\n"):
        return "png"
    if header.startswith(b"\xff\xd8\xff"):
        return "jpeg"
    if header.startswith((b"GIF87a", b"GIF89a")):
        return "gif"
    if header.startswith(b"RIFF") and header[8:12] == b"WEBP":
        return "webp"
    if header.startswith((b"II*\x00", b"MM\x00*")):
        return "tiff"
    if header.startswith(b"BM"):
        return "bmp"
    return "jpeg"
def decode_base64_to_image(
    encoding: str, gray=False
) -> Tuple[np.ndarray, Optional[np.ndarray], Dict, str]:
    """Decode a base64 string (optionally a data URL) into a numpy image.

    Args:
        encoding: raw base64 payload, or a full data URL of the form
            "data:image/...;base64,..." / "data:application/octet-stream;base64,...".
        gray: convert to single-channel grayscale instead of RGB.

    Returns:
        (np_img, alpha_channel, infos, ext): the decoded array, the alpha
        plane of RGBA sources (else None), the PIL metadata dict, and the
        sniffed image format name.
    """
    if encoding.startswith("data:image/") or encoding.startswith(
        "data:application/octet-stream;base64,"
    ):
        # strip the "data:<mime>;base64," prefix, keeping only the payload
        encoding = encoding.split(";")[1].split(",")[1]
    image_bytes = base64.b64decode(encoding)
    ext = get_image_ext(image_bytes)
    image = Image.open(io.BytesIO(image_bytes))
    alpha_channel = None
    try:
        image = ImageOps.exif_transpose(image)
    except:
        # best-effort: some images carry corrupt EXIF orientation data
        pass
    # exif_transpose will remove exif rotate info,we must call image.info after exif_transpose
    infos = image.info
    if gray:
        image = image.convert("L")
        np_img = np.array(image)
    else:
        if image.mode == "RGBA":
            np_img = np.array(image)
            # keep the alpha plane so callers can re-attach it after processing
            alpha_channel = np_img[:, :, -1]
            np_img = cv2.cvtColor(np_img, cv2.COLOR_RGBA2RGB)
        else:
            image = image.convert("RGB")
            np_img = np.array(image)
    return np_img, alpha_channel, infos, ext
def encode_pil_to_base64(image: Image, quality: int, infos: Dict) -> bytes:
    """PNG-encode a PIL image (with its metadata) and return base64 bytes."""
    png_bytes = pil_to_bytes(image, "png", quality=quality, infos=infos)
    return base64.b64encode(png_bytes)
def concat_alpha_channel(rgb_np_img, alpha_channel) -> np.ndarray:
    """Append an alpha plane to an RGB image, resizing the alpha if needed.

    Returns the input array unchanged when ``alpha_channel`` is None.
    """
    if alpha_channel is None:
        return rgb_np_img
    h, w = rgb_np_img.shape[:2]
    if alpha_channel.shape[:2] != (h, w):
        # bring the alpha plane to the image's spatial size
        alpha_channel = cv2.resize(alpha_channel, dsize=(w, h))
    return np.concatenate((rgb_np_img, alpha_channel[:, :, np.newaxis]), axis=-1)
def adjust_mask(mask: np.ndarray, kernel_size: int, operate):
    """Binarize a mask, apply a morphological op, and recolor it for the UI.

    Args:
        mask: (h, w) 0~255 array; NOTE: binarized in place.
        kernel_size: radius of the elliptical structuring element.
        operate: "reverse" inverts the mask, "expand" dilates it, any
            other value erodes it.

    Returns:
        (h, w, 4) RGBA overlay using the frontend brush color "ffcc00bb".
    """
    # hard-binarize at the 127 threshold (mutates the caller's array)
    mask[mask >= 127] = 255
    mask[mask < 127] = 0

    if operate == "reverse":
        mask = 255 - mask
    else:
        diameter = 2 * kernel_size + 1
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (diameter, diameter))
        morph = cv2.dilate if operate == "expand" else cv2.erode
        mask = morph(mask, kernel, iterations=1)

    res_mask = np.zeros((mask.shape[0], mask.shape[1], 4), dtype=np.uint8)
    res_mask[mask > 128] = [255, 203, 0, int(255 * 0.73)]
    return cv2.cvtColor(res_mask, cv2.COLOR_BGRA2RGBA)
def gen_frontend_mask(bgr_or_gray_mask):
    """Turn a BGR or grayscale mask into the RGBA overlay drawn by the frontend.

    The mask is dilated slightly so thin strokes stay visible; covered
    pixels receive the frontend brush color "ffcc00bb".
    """
    if bgr_or_gray_mask.ndim == 3 and bgr_or_gray_mask.shape[2] != 1:
        bgr_or_gray_mask = cv2.cvtColor(bgr_or_gray_mask, cv2.COLOR_BGR2GRAY)

    # TODO: how to set kernel size?
    kernel_size = 9
    dilated = cv2.dilate(
        bgr_or_gray_mask,
        np.ones((kernel_size, kernel_size), np.uint8),
        iterations=1,
    )

    h, w = dilated.shape[:2]
    res_mask = np.zeros((h, w, 4), dtype=np.uint8)
    res_mask[dilated > 128] = [255, 203, 0, int(255 * 0.73)]
    return cv2.cvtColor(res_mask, cv2.COLOR_BGRA2RGBA)
|