batch_processing.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128
  1. import json
  2. from pathlib import Path
  3. from typing import Dict, Optional
  4. import cv2
  5. import numpy as np
  6. from loguru import logger
  7. from PIL import Image
  8. from rich.console import Console
  9. from rich.progress import (
  10. BarColumn,
  11. MofNCompleteColumn,
  12. Progress,
  13. SpinnerColumn,
  14. TaskProgressColumn,
  15. TextColumn,
  16. TimeElapsedColumn,
  17. )
  18. from sorawm.iopaint.helper import pil_to_bytes
  19. from sorawm.iopaint.model.utils import torch_gc
  20. from sorawm.iopaint.model_manager import ModelManager
  21. from sorawm.iopaint.schema import InpaintRequest
  22. def glob_images(path: Path) -> Dict[str, Path]:
  23. # png/jpg/jpeg
  24. if path.is_file():
  25. return {path.stem: path}
  26. elif path.is_dir():
  27. res = {}
  28. for it in path.glob("*.*"):
  29. if it.suffix.lower() in [".png", ".jpg", ".jpeg"]:
  30. res[it.stem] = it
  31. return res
  32. def batch_inpaint(
  33. model: str,
  34. device,
  35. image: Path,
  36. mask: Path,
  37. output: Path,
  38. config: Optional[Path] = None,
  39. concat: bool = False,
  40. ):
  41. if image.is_dir() and output.is_file():
  42. logger.error(
  43. "invalid --output: when image is a directory, output should be a directory"
  44. )
  45. exit(-1)
  46. output.mkdir(parents=True, exist_ok=True)
  47. image_paths = glob_images(image)
  48. mask_paths = glob_images(mask)
  49. if len(image_paths) == 0:
  50. logger.error("invalid --image: empty image folder")
  51. exit(-1)
  52. if len(mask_paths) == 0:
  53. logger.error("invalid --mask: empty mask folder")
  54. exit(-1)
  55. if config is None:
  56. inpaint_request = InpaintRequest()
  57. logger.info(f"Using default config: {inpaint_request}")
  58. else:
  59. with open(config, "r", encoding="utf-8") as f:
  60. inpaint_request = InpaintRequest(**json.load(f))
  61. logger.info(f"Using config: {inpaint_request}")
  62. model_manager = ModelManager(name=model, device=device)
  63. first_mask = list(mask_paths.values())[0]
  64. console = Console()
  65. with Progress(
  66. SpinnerColumn(),
  67. TextColumn("[progress.description]{task.description}"),
  68. BarColumn(),
  69. TaskProgressColumn(),
  70. MofNCompleteColumn(),
  71. TimeElapsedColumn(),
  72. console=console,
  73. transient=False,
  74. ) as progress:
  75. task = progress.add_task("Batch processing...", total=len(image_paths))
  76. for stem, image_p in image_paths.items():
  77. if stem not in mask_paths and mask.is_dir():
  78. progress.log(f"mask for {image_p} not found")
  79. progress.update(task, advance=1)
  80. continue
  81. mask_p = mask_paths.get(stem, first_mask)
  82. infos = Image.open(image_p).info
  83. img = np.array(Image.open(image_p).convert("RGB"))
  84. mask_img = np.array(Image.open(mask_p).convert("L"))
  85. if mask_img.shape[:2] != img.shape[:2]:
  86. progress.log(
  87. f"resize mask {mask_p.name} to image {image_p.name} size: {img.shape[:2]}"
  88. )
  89. mask_img = cv2.resize(
  90. mask_img,
  91. (img.shape[1], img.shape[0]),
  92. interpolation=cv2.INTER_NEAREST,
  93. )
  94. mask_img[mask_img >= 127] = 255
  95. mask_img[mask_img < 127] = 0
  96. # bgr
  97. inpaint_result = model_manager(img, mask_img, inpaint_request)
  98. inpaint_result = cv2.cvtColor(inpaint_result, cv2.COLOR_BGR2RGB)
  99. if concat:
  100. mask_img = cv2.cvtColor(mask_img, cv2.COLOR_GRAY2RGB)
  101. inpaint_result = cv2.hconcat([img, mask_img, inpaint_result])
  102. img_bytes = pil_to_bytes(Image.fromarray(inpaint_result), "png", 100, infos)
  103. save_p = output / f"{stem}.png"
  104. with open(save_p, "wb") as fw:
  105. fw.write(img_bytes)
  106. progress.update(task, advance=1)
  107. torch_gc()
  108. # pid = psutil.Process().pid
  109. # memory_info = psutil.Process(pid).memory_info()
  110. # memory_in_mb = memory_info.rss / (1024 * 1024)
  111. # print(f"原图大小:{img.shape},当前进程的内存占用:{memory_in_mb}MB")