watermark_cleaner.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329
  1. from pathlib import Path
  2. import cv2
  3. import numpy as np
  4. import torch
  5. from loguru import logger
  6. from sorawm.configs import DEFAULT_WATERMARK_REMOVE_MODEL
  7. from sorawm.iopaint.const import DEFAULT_MODEL_DIR
  8. from sorawm.iopaint.download import cli_download_model, scan_models
  9. from sorawm.iopaint.model_manager import ModelManager
  10. from sorawm.iopaint.schema import InpaintRequest
  11. from sorawm.utils.devices_utils import get_device
# The inpainting backend (sorawm.iopaint) comes from https://github.com/Sanster/IOPaint — thanks for their amazing work!
  13. class WaterMarkCleaner:
  14. def __init__(self):
  15. self.model = DEFAULT_WATERMARK_REMOVE_MODEL
  16. self.device = get_device()
  17. scanned_models = scan_models()
  18. if self.model not in [it.name for it in scanned_models]:
  19. logger.info(
  20. f"{self.model} not found in {DEFAULT_MODEL_DIR}, try to downloading"
  21. )
  22. cli_download_model(self.model)
  23. self.model_manager = ModelManager(name=self.model, device=self.device)
  24. self.inpaint_request = InpaintRequest()
  25. def clean(self, input_image: np.array, watermark_mask: np.array) -> np.array:
  26. inpaint_result = self.model_manager(
  27. input_image, watermark_mask, self.inpaint_request
  28. )
  29. inpaint_result = cv2.cvtColor(inpaint_result, cv2.COLOR_BGR2RGB)
  30. return inpaint_result
  31. if __name__ == "__main__":
  32. from pathlib import Path
  33. import cv2
  34. import numpy as np
  35. from tqdm import tqdm
  36. # ========= 配置 =========
  37. video_path = Path("resources/puppies.mp4")
  38. save_video = True
  39. out_path = Path("outputs/dog_vs_sam_detected.mp4")
  40. window = "Sora watermark (threshold+morph+shape + tracking)"
  41. # 追踪/回退策略参数
  42. PREV_ROI_EXPAND = 2.2 # 上一框宽高的膨胀倍数(>1)
  43. AREA1 = (1000, 2000) # 主检测面积范围
  44. AREA2 = (600, 4000) # 回退阶段面积范围
  45. # =======================
  46. cleaner = SoraWaterMarkCleaner(video_path, video_path)
  47. # 预取一帧确定尺寸/FPS
  48. first_frame = None
  49. for first_frame in cleaner.input_video_loader:
  50. break
  51. assert first_frame is not None, "无法读取视频帧"
  52. H, W = first_frame.shape[:2]
  53. fps = getattr(cleaner.input_video_loader, "fps", 30)
  54. # 输出视频(原 | bw | all-contours | vis 四联画)
  55. writer = None
  56. if save_video:
  57. out_path.parent.mkdir(parents=True, exist_ok=True)
  58. fourcc = cv2.VideoWriter_fourcc(*"avc1")
  59. writer = cv2.VideoWriter(str(out_path), fourcc, fps, (W * 4, H))
  60. if not writer.isOpened():
  61. fourcc = cv2.VideoWriter_fourcc(*"MJPG")
  62. writer = cv2.VideoWriter(str(out_path), fourcc, fps, (W * 4, H))
  63. assert writer.isOpened(), "无法创建输出视频文件"
  64. cv2.namedWindow(window, cv2.WINDOW_NORMAL)
  65. # ---- 工具函数 ----
  66. def _clip_rect(x0, y0, x1, y1, w_img, h_img):
  67. x0 = max(0, min(x0, w_img - 1))
  68. x1 = max(0, min(x1, w_img))
  69. y0 = max(0, min(y0, h_img - 1))
  70. y1 = max(0, min(y1, h_img))
  71. if x1 <= x0:
  72. x1 = x0 + 1
  73. if y1 <= y0:
  74. y1 = y0 + 1
  75. return x0, y0, x1, y1
  76. def _cnt_bbox(cnt):
  77. x, y, w, h = cv2.boundingRect(cnt)
  78. return (x, y, x + w, y + h)
  79. def _bbox_center(b):
  80. x0, y0, x1, y1 = b
  81. return ((x0 + x1) // 2, (y0 + y1) // 2)
  82. def detect_flower_like(image, prev_bbox=None):
  83. """
  84. 识别流程:
  85. 灰度范围 → 自适应阈值 → 仅在 3 个区域 + (可选)上一帧膨胀ROI 内找轮廓
  86. 三个区域:1) 左上20% 2) 左下20% 3) 中间水平带 y∈[0.4H, 0.6H], x∈[0,W]
  87. 返回: bw_region, best_cnt, contours_region, region_boxes, prev_roi_box
  88. """
  89. gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
  90. # 208 ± 20% 亮度范围
  91. low, high = int(round(208 * 0.9)), int(round(208 * 1.1))
  92. mask = ((gray >= low) & (gray <= high)).astype(np.uint8) * 255
  93. # 自适应阈值并限制到亮度范围
  94. bw = cv2.adaptiveThreshold(
  95. gray, 255, cv2.ADAPTIVE_THRESH_MEAN_C, cv2.THRESH_BINARY, 31, -5
  96. )
  97. bw = cv2.bitwise_and(bw, mask)
  98. # -------- 三个区域:左上/左下/中间带 --------
  99. h_img, w_img = gray.shape[:2]
  100. r_top_left = (0, 0, int(0.2 * w_img), int(0.2 * h_img))
  101. r_bot_left = (0, int(0.8 * h_img), int(0.2 * w_img), h_img)
  102. y0, y1 = int(0.40 * h_img), int(0.60 * h_img) # 中间带
  103. r_mid_band = (0, y0, w_img, y1)
  104. region_mask = np.zeros_like(bw, dtype=np.uint8)
  105. for x0, ys, x1, ye in (r_top_left, r_bot_left):
  106. region_mask[ys:ye, x0:x1] = 255
  107. region_mask[y0:y1, :] = 255
  108. # -------- 追加:上一帧膨胀ROI --------
  109. prev_roi_box = None
  110. if prev_bbox is not None:
  111. px0, py0, px1, py1 = prev_bbox
  112. pw, ph = (px1 - px0), (py1 - py0)
  113. cx, cy = _bbox_center(prev_bbox)
  114. rw = int(pw * PREV_ROI_EXPAND)
  115. rh = int(ph * PREV_ROI_EXPAND)
  116. rx0, ry0 = cx - rw // 2, cy - rh // 2
  117. rx1, ry1 = cx + rw // 2, cy + rh // 2
  118. rx0, ry0, rx1, ry1 = _clip_rect(rx0, ry0, rx1, ry1, w_img, h_img)
  119. region_mask[ry0:ry1, rx0:rx1] = 255
  120. prev_roi_box = (rx0, ry0, rx1, ry1)
  121. bw_region = cv2.bitwise_and(bw, region_mask)
  122. # -------- 轮廓 + 形状筛选 --------
  123. def select_candidates(bw_bin, area_rng):
  124. contours, _ = cv2.findContours(
  125. bw_bin, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
  126. )
  127. cand = []
  128. for cnt in contours:
  129. area = cv2.contourArea(cnt)
  130. if area < area_rng[0] or area > area_rng[1]:
  131. continue
  132. peri = cv2.arcLength(cnt, True)
  133. if peri == 0:
  134. continue
  135. circularity = 4.0 * np.pi * area / (peri * peri)
  136. if 0.55 <= circularity <= 0.95:
  137. cand.append(cnt)
  138. return contours, cand
  139. contours_region, cand1 = select_candidates(bw_region, AREA1)
  140. best_cnt = None
  141. if cand1:
  142. # 若有上一帧,用“离上一帧中心最近”优先;否则取面积最大
  143. if prev_bbox is None:
  144. best_cnt = max(cand1, key=lambda c: cv2.contourArea(c))
  145. else:
  146. pcx, pcy = _bbox_center(prev_bbox)
  147. best_cnt = max(
  148. cand1,
  149. key=lambda c: -(
  150. (((_cnt_bbox(c)[0] + _cnt_bbox(c)[2]) // 2 - pcx) ** 2)
  151. + (((_cnt_bbox(c)[1] + _cnt_bbox(c)[3]) // 2 - pcy) ** 2)
  152. ),
  153. )
  154. else:
  155. # 回退1:仅在上一帧 ROI 内放宽面积
  156. if prev_roi_box is not None:
  157. rx0, ry0, rx1, ry1 = prev_roi_box
  158. roi = np.zeros_like(bw_region)
  159. roi[ry0:ry1, rx0:rx1] = bw_region[ry0:ry1, rx0:rx1]
  160. _, cand2 = select_candidates(roi, AREA2)
  161. if cand2:
  162. if prev_bbox is None:
  163. best_cnt = max(cand2, key=lambda c: cv2.contourArea(c))
  164. else:
  165. pcx, pcy = _bbox_center(prev_bbox)
  166. best_cnt = max(
  167. cand2,
  168. key=lambda c: -(
  169. (((_cnt_bbox(c)[0] + _cnt_bbox(c)[2]) // 2 - pcx) ** 2)
  170. + (
  171. ((_cnt_bbox(c)[1] + _cnt_bbox(c)[3]) // 2 - pcy)
  172. ** 2
  173. )
  174. ),
  175. )
  176. else:
  177. # 回退2:全区域 cand,选最近中心
  178. if prev_bbox is not None:
  179. _, cand3 = select_candidates(bw_region, AREA2)
  180. if cand3:
  181. pcx, pcy = _bbox_center(prev_bbox)
  182. best_cnt = max(
  183. cand3,
  184. key=lambda c: -(
  185. (
  186. ((_cnt_bbox(c)[0] + _cnt_bbox(c)[2]) // 2 - pcx)
  187. ** 2
  188. )
  189. + (
  190. ((_cnt_bbox(c)[1] + _cnt_bbox(c)[3]) // 2 - pcy)
  191. ** 2
  192. )
  193. ),
  194. )
  195. region_boxes = (r_top_left, r_bot_left, r_mid_band, (y0, y1))
  196. return bw_region, best_cnt, contours_region, region_boxes, prev_roi_box
  197. # ---- 时序追踪状态(用字典避免 nonlocal/global) ----
  198. state = {"bbox": None} # 保存上一帧外接框 (x0,y0,x1,y1)
  199. def process_and_show(frame, idx):
  200. img = frame.copy()
  201. bw, best, contours, region_boxes, prev_roi_box = detect_flower_like(
  202. img, state["bbox"]
  203. )
  204. r_top_left, r_bot_left, r_mid_band, (y0, y1) = region_boxes
  205. # 所有轮廓(黄)
  206. allc = img.copy()
  207. if contours:
  208. cv2.drawContours(allc, contours, -1, (0, 255, 255), 1)
  209. # 画三个区域:红框 + 中间带上下红线
  210. def draw_rect(im, rect, color=(0, 0, 255), th=2):
  211. x0, y0r, x1, y1r = rect
  212. cv2.rectangle(im, (x0, y0r), (x1, y1r), color, th)
  213. draw_rect(allc, r_top_left)
  214. draw_rect(allc, r_bot_left)
  215. draw_rect(allc, (r_mid_band[0], r_mid_band[1], r_mid_band[2], r_mid_band[3]))
  216. cv2.line(allc, (0, y0), (img.shape[1], y0), (0, 0, 255), 2)
  217. cv2.line(allc, (0, y1), (img.shape[1], y1), (0, 0, 255), 2)
  218. # 画上一帧的膨胀 ROI(青色)
  219. if prev_roi_box is not None:
  220. x0, y0r, x1, y1r = prev_roi_box
  221. cv2.rectangle(allc, (x0, y0r), (x1, y1r), (255, 255, 0), 2)
  222. # 最终检测
  223. vis = img.copy()
  224. title = "no-detect"
  225. if best is not None:
  226. cv2.drawContours(vis, [best], -1, (0, 255, 0), 2)
  227. x0, y0r, x1, y1r = _cnt_bbox(best)
  228. state["bbox"] = (x0, y0r, x1, y1r) # 更新追踪状态
  229. M = cv2.moments(best)
  230. if M["m00"] > 0:
  231. cx, cy = int(M["m10"] / M["m00"]), int(M["m01"] / M["m00"])
  232. cv2.circle(vis, (cx, cy), 4, (0, 0, 255), -1)
  233. title = "detected"
  234. else:
  235. # 若仍未检测,维持上一状态
  236. cv2.putText(
  237. vis,
  238. "No detection (kept last state)",
  239. (12, 28),
  240. cv2.FONT_HERSHEY_SIMPLEX,
  241. 0.8,
  242. (0, 0, 255),
  243. 2,
  244. )
  245. if state["bbox"] is not None:
  246. x0, y0r, x1, y1r = state["bbox"]
  247. cv2.rectangle(vis, (x0, y0r), (x1, y1r), (255, 255, 0), 2)
  248. # 四联画:原图 | 区域内bw | 所有轮廓 | 最终检测
  249. panel = np.hstack([img, cv2.cvtColor(bw, cv2.COLOR_GRAY2BGR), allc, vis])
  250. cv2.putText(
  251. panel,
  252. f"Frame {idx} | {title}",
  253. (12, 28),
  254. cv2.FONT_HERSHEY_SIMPLEX,
  255. 0.9,
  256. (255, 255, 255),
  257. 2,
  258. )
  259. cv2.imshow(window, panel)
  260. if writer is not None:
  261. if panel.shape[:2] != (H, W * 4):
  262. panel = cv2.resize(panel, (W * 4, H), interpolation=cv2.INTER_AREA)
  263. writer.write(panel)
  264. # 先处理已取出的第一帧
  265. process_and_show(first_frame, 0)
  266. # 按你的遍历方式继续
  267. for idx, frame in enumerate(
  268. tqdm(cleaner.input_video_loader, desc="Processing frames", initial=1, unit="f")
  269. ):
  270. process_and_show(frame, idx)
  271. key = cv2.waitKey(max(1, int(1000 / max(1, int(fps))))) & 0xFF
  272. if key == ord("q"):
  273. break
  274. elif key == ord(" "):
  275. while True:
  276. k = cv2.waitKey(50) & 0xFF
  277. if k in (ord(" "), ord("q")):
  278. if k == ord("q"):
  279. idx = 10**9
  280. break
  281. if idx >= 10**9:
  282. break
  283. if writer is not None:
  284. writer.release()
  285. print(f"[OK] 可视化视频已保存: {out_path}")
  286. cv2.destroyAllWindows()