core.py

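"""Core watermark-removal pipeline for Sora videos.

Two passes over the input: first detect the watermark bbox in every frame
(imputing frames where detection fails), then clean each frame and re-encode
it with ffmpeg, finally muxing the original audio back in.
"""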
from pathlib import Path
from typing import Callable

import ffmpeg
import numpy as np
from loguru import logger
from tqdm import tqdm

from sorawm.utils.video_utils import VideoLoader
from sorawm.watermark_cleaner import WaterMarkCleaner
from sorawm.watermark_detector import SoraWaterMarkDetector
from sorawm.utils.imputation_utils import (
    find_2d_data_bkps,
    get_interval_average_bbox,
    find_idxs_interval,
)


class SoraWM:
    def __init__(self):
        self.detector = SoraWaterMarkDetector()
        self.cleaner = WaterMarkCleaner()

    def run(
        self,
        input_video_path: Path,
        output_video_path: Path,
        progress_callback: Callable[[int], None] | None = None,
    ):
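        """Detect the Sora watermark in every frame of `input_video_path`,
        remove it, and write the result to `output_video_path`.

        If `progress_callback` is given, it is called with an integer
        percentage (roughly 10 to 99) as the pipeline advances.
        """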
        input_video_loader = VideoLoader(input_video_path)
        output_video_path.parent.mkdir(parents=True, exist_ok=True)

        width = input_video_loader.width
        height = input_video_loader.height
        fps = input_video_loader.fps
        total_frames = input_video_loader.total_frames

        temp_output_path = output_video_path.parent / f"temp_{output_video_path.name}"

        output_options = {
            "pix_fmt": "yuv420p",
            "vcodec": "libx264",
            "preset": "slow",
        }
        if input_video_loader.original_bitrate:
            # Encode slightly above the source bitrate to limit quality loss.
            output_options["video_bitrate"] = str(
                int(int(input_video_loader.original_bitrate) * 1.2)
            )
        else:
            output_options["crf"] = "18"
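
        # Cleaned frames are streamed to ffmpeg as raw BGR24 bytes over stdin and
        # encoded with libx264 into a temporary file; the audio is muxed back later.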
        process_out = (
            ffmpeg.input(
                "pipe:",
                format="rawvideo",
                pix_fmt="bgr24",
                s=f"{width}x{height}",
                r=fps,
            )
            .output(str(temp_output_path), **output_options)
            .overwrite_output()
            .global_args("-loglevel", "error")
            .run_async(pipe_stdin=True)
        )

        frame_bboxes = {}
        detect_missed = []
        bbox_centers = []
        bboxes = []
        logger.debug(
            f"total frames: {total_frames}, fps: {fps}, width: {width}, height: {height}"
        )
        for idx, frame in enumerate(
            tqdm(input_video_loader, total=total_frames, desc="Detect watermarks")
        ):
            detection_result = self.detector.detect(frame)
            if detection_result["detected"]:
                frame_bboxes[idx] = {"bbox": detection_result["bbox"]}
                x1, y1, x2, y2 = detection_result["bbox"]
                bbox_centers.append((int((x1 + x2) / 2), int((y1 + y2) / 2)))
                bboxes.append((x1, y1, x2, y2))
            else:
                frame_bboxes[idx] = {"bbox": None}
                detect_missed.append(idx)
                bbox_centers.append(None)
                bboxes.append(None)
            # Detection pass maps to 10% - 50% of the reported progress.
            if progress_callback and idx % 10 == 0:
                progress = 10 + int((idx / total_frames) * 40)
                progress_callback(progress)

        logger.debug(f"detect missed frames: {detect_missed}")
        # logger.debug(f"bbox centers: \n{bbox_centers}")
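
        # Detection can miss the watermark in some frames; impute those bboxes by
        # segmenting the detected centers at their change points (bkps) and reusing
        # each segment's average bbox.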
        if detect_missed:
            # 1. Find the change points (bkps) of the bbox centers.
            bkps = find_2d_data_bkps(bbox_centers)
            # Add the start and end positions to form the complete interval boundaries.
            bkps_full = [0] + bkps + [total_frames]
            # logger.debug(f"bkps intervals: {bkps_full}")
            # 2. Calculate the average bbox of each interval.
            interval_bboxes = get_interval_average_bbox(bboxes, bkps_full)
            # logger.debug(f"interval average bboxes: {interval_bboxes}")
            # 3. Find the interval index of each missed frame.
            missed_intervals = find_idxs_interval(detect_missed, bkps_full)
            # logger.debug(
            #     f"missed frame intervals: {list(zip(detect_missed, missed_intervals))}"
            # )
            # 4. Fill each missed frame with the average bbox of its interval.
            for missed_idx, interval_idx in zip(detect_missed, missed_intervals):
                if (
                    interval_idx < len(interval_bboxes)
                    and interval_bboxes[interval_idx] is not None
                ):
                    frame_bboxes[missed_idx]["bbox"] = interval_bboxes[interval_idx]
                    logger.debug(
                        f"Filled missed frame {missed_idx} with bbox:\n"
                        f" {interval_bboxes[interval_idx]}"
                    )
                else:
                    # If the interval has no valid bbox, fall back to the bbox of a
                    # neighbouring frame.
                    before = max(missed_idx - 1, 0)
                    after = min(missed_idx + 1, total_frames - 1)
                    before_box = frame_bboxes[before]["bbox"]
                    after_box = frame_bboxes[after]["bbox"]
                    if before_box:
                        frame_bboxes[missed_idx]["bbox"] = before_box
                    elif after_box:
                        frame_bboxes[missed_idx]["bbox"] = after_box
                    else:
                        # Neither neighbour has a bbox; leave it as None so the
                        # frame is written through unchanged.
                        pass

        del bboxes
        del bbox_centers
        del detect_missed
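
        # Second pass: re-read the video and clean each frame, masking the
        # watermark region given by its detected or imputed bbox.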
        input_video_loader = VideoLoader(input_video_path)
        for idx, frame in enumerate(
            tqdm(input_video_loader, total=total_frames, desc="Remove watermarks")
        ):
            bbox = frame_bboxes[idx]["bbox"]
            if bbox is not None:
                x1, y1, x2, y2 = bbox
                mask = np.zeros((height, width), dtype=np.uint8)
                mask[y1:y2, x1:x2] = 255
                cleaned_frame = self.cleaner.clean(frame, mask)
            else:
                cleaned_frame = frame
            process_out.stdin.write(cleaned_frame.tobytes())
            # Cleaning pass maps to 50% - 95% of the reported progress.
            if progress_callback and idx % 10 == 0:
                progress = 50 + int((idx / total_frames) * 45)
                progress_callback(progress)

        process_out.stdin.close()
        process_out.wait()

        # Muxing the audio back in covers the final 95% - 99% of progress.
        if progress_callback:
            progress_callback(95)
        self.merge_audio_track(input_video_path, temp_output_path, output_video_path)
        if progress_callback:
            progress_callback(99)

    def merge_audio_track(
        self, input_video_path: Path, temp_output_path: Path, output_video_path: Path
    ):
        logger.info("Merging audio track...")
        video_stream = ffmpeg.input(str(temp_output_path))
        audio_stream = ffmpeg.input(str(input_video_path)).audio
        (
            ffmpeg.output(
                video_stream,
                audio_stream,
                str(output_video_path),
                vcodec="copy",
                acodec="aac",
            )
            .overwrite_output()
            .run(quiet=True)
        )
        # Clean up the temporary file.
        temp_output_path.unlink()
        logger.info(f"Saved watermark-free video with audio at: {output_video_path}")


if __name__ == "__main__":
    input_video_path = Path(
        "resources/19700121_1645_68e0a027836c8191a50bea3717ea7485.mp4"
    )
    output_video_path = Path("outputs/sora_watermark_removed.mp4")

    sora_wm = SoraWM()
    sora_wm.run(input_video_path, output_video_path)
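
    # `run` also accepts an optional progress callback (any callable taking an
    # int percentage), e.g.:
    #   sora_wm.run(
    #       input_video_path,
    #       output_video_path,
    #       progress_callback=lambda p: print(f"progress: {p}%"),
    #   )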