watermark_detector.py

from pathlib import Path

import numpy as np
from loguru import logger
from ultralytics import YOLO

from sorawm.configs import WATER_MARK_DETECT_YOLO_WEIGHTS
from sorawm.utils.download_utils import download_detector_weights
from sorawm.utils.devices_utils import get_device
from sorawm.utils.video_utils import VideoLoader


# Based on the Sora watermark template: detect the whole watermark first, then derive the icon area from it.
class SoraWaterMarkDetector:
    def __init__(self):
        download_detector_weights()
        logger.debug("Begin to load the YOLO watermark detection model.")
        self.model = YOLO(WATER_MARK_DETECT_YOLO_WEIGHTS)
        self.model.to(str(get_device()))
        logger.debug("YOLO watermark detection model loaded.")
        self.model.eval()

    def detect(self, input_image: np.ndarray):
        """Detect the Sora watermark in a single frame.

        Returns a dict with keys "detected", "bbox" (x1, y1, x2, y2),
        "confidence", and "center" (cx, cy); the last three are None when
        nothing is found.
        """
        # Run YOLO inference
        results = self.model(input_image, verbose=False)
        # Extract predictions from the first (and only) result
        result = results[0]
        # Check if any detections were made
        if len(result.boxes) == 0:
            return {"detected": False, "bbox": None, "confidence": None, "center": None}
        # Get the first detection (highest confidence)
        box = result.boxes[0]
        # Extract bounding box coordinates (xyxy format)
        # Convert tensor to numpy, then to Python float, finally to int
        xyxy = box.xyxy[0].cpu().numpy()
        x1, y1, x2, y2 = float(xyxy[0]), float(xyxy[1]), float(xyxy[2]), float(xyxy[3])
        # Extract confidence score
        confidence = float(box.conf[0].cpu().numpy())
        # Calculate center point
        center_x = (x1 + x2) / 2
        center_y = (y1 + y2) / 2
        return {
            "detected": True,
            "bbox": (int(x1), int(y1), int(x2), int(y2)),
            "confidence": confidence,
            "center": (int(center_x), int(center_y)),
        }
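
# Minimal usage sketch (assumptions: the weights referenced by
# WATER_MARK_DETECT_YOLO_WEIGHTS are available, and the frame path below is
# hypothetical -- any BGR image loaded with cv2 works):
#
#   import cv2
#   detector = SoraWaterMarkDetector()
#   frame = cv2.imread("resources/example_frame.png")  # hypothetical path
#   result = detector.detect(frame)
#   if result["detected"]:
#       x1, y1, x2, y2 = result["bbox"]
#       print(f"Watermark at ({x1}, {y1})-({x2}, {y2}), conf={result['confidence']:.2f}")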

if __name__ == "__main__":
    import cv2
    from tqdm import tqdm

    # ========= Configuration =========
    # video_path = Path("resources/puppies.mp4")  # 19700121_1645_68e0a027836c8191a50bea3717ea7485.mp4
    video_path = Path("resources/19700121_1645_68e0a027836c8191a50bea3717ea7485.mp4")
    save_video = True
    out_path = Path("outputs/sora_watermark_yolo_detected.mp4")
    window = "Sora Watermark YOLO Detection"
    # =================================

    # Initialize the detector
    detector = SoraWaterMarkDetector()

    # Initialize the video loader
    video_loader = VideoLoader(video_path)

    # Pre-fetch one frame to determine the frame size and FPS
    first_frame = None
    for first_frame in video_loader:
        break
    assert first_frame is not None, "Failed to read any frame from the video"
    H, W = first_frame.shape[:2]
    fps = getattr(video_loader, "fps", 30)

    # Output video setup
    writer = None
    if save_video:
        out_path.parent.mkdir(parents=True, exist_ok=True)
        fourcc = cv2.VideoWriter_fourcc(*"avc1")
        writer = cv2.VideoWriter(str(out_path), fourcc, fps, (W, H))
        if not writer.isOpened():
            # Fall back to MJPG if the avc1 codec is unavailable
            fourcc = cv2.VideoWriter_fourcc(*"MJPG")
            writer = cv2.VideoWriter(str(out_path), fourcc, fps, (W, H))
        assert writer.isOpened(), "Failed to create the output video file"

    cv2.namedWindow(window, cv2.WINDOW_NORMAL)

    def visualize_detection(frame, detection_result, frame_idx):
        """Draw the detection result onto a copy of the frame."""
        vis = frame.copy()
        if detection_result["detected"]:
            # Draw the bounding box
            x1, y1, x2, y2 = detection_result["bbox"]
            cv2.rectangle(vis, (x1, y1), (x2, y2), (0, 255, 0), 2)
            # Draw the center point
            cx, cy = detection_result["center"]
            cv2.circle(vis, (cx, cy), 5, (0, 0, 255), -1)
            # Show the confidence score
            conf = detection_result["confidence"]
            label = f"Watermark: {conf:.2f}"
            # Text background
            (text_w, text_h), baseline = cv2.getTextSize(
                label, cv2.FONT_HERSHEY_SIMPLEX, 0.6, 2
            )
            cv2.rectangle(
                vis, (x1, y1 - text_h - 10), (x1 + text_w + 5, y1), (0, 255, 0), -1
            )
            # Draw the label text
            cv2.putText(
                vis,
                label,
                (x1 + 2, y1 - 5),
                cv2.FONT_HERSHEY_SIMPLEX,
                0.6,
                (0, 0, 0),
                2,
            )
            status = f"Frame {frame_idx} | DETECTED | Conf: {conf:.3f}"
            status_color = (0, 255, 0)
        else:
            status = f"Frame {frame_idx} | NO WATERMARK"
            status_color = (0, 0, 255)
        # Show per-frame status in the top-left corner
        cv2.putText(
            vis, status, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.8, status_color, 2
        )
        return vis

    # Process the first frame (counted here so the final statistics include it)
    print("Start processing the video...")
    total_frames = 0
    detected_frames = 0
    detection = detector.detect(first_frame)
    vis_frame = visualize_detection(first_frame, detection, 0)
    total_frames += 1
    if detection["detected"]:
        detected_frames += 1
    cv2.imshow(window, vis_frame)
    if writer is not None:
        writer.write(vis_frame)

    # Process the remaining frames
    for idx, frame in enumerate(
        tqdm(video_loader, desc="Processing frames", initial=1, unit="f"), start=1
    ):
        # YOLO detection
        detection = detector.detect(frame)
        # Visualization
        vis_frame = visualize_detection(frame, detection, idx)
        # Statistics
        total_frames += 1
        if detection["detected"]:
            detected_frames += 1
        # Display
        cv2.imshow(window, vis_frame)
        # Save
        if writer is not None:
            writer.write(vis_frame)
        # Keyboard controls
        key = cv2.waitKey(max(1, int(1000 / max(1, int(fps))))) & 0xFF
        if key == ord("q"):
            break
        elif key == ord(" "):  # space pauses playback
            while True:
                k = cv2.waitKey(50) & 0xFF
                if k in (ord(" "), ord("q")):
                    if k == ord("q"):
                        # Flag quit-from-pause so the outer loop also exits
                        idx = 10**9
                    break
            if idx >= 10**9:
                break

    # Cleanup
    if writer is not None:
        writer.release()
        print(f"\n[Done] Visualization video saved to: {out_path}")

    # Print statistics (the first frame is already included in the counters above)
    print("\n=== Detection statistics ===")
    print(f"Total frames: {total_frames}")
    print(f"Frames with watermark: {detected_frames}")
    print(f"Detection rate: {detected_frames / total_frames * 100:.2f}%")
    cv2.destroyAllWindows()