| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199 |
- from pathlib import Path
- from typing import Callable
- import ffmpeg
- import numpy as np
- from loguru import logger
- from tqdm import tqdm
- from sorawm.utils.video_utils import VideoLoader
- from sorawm.watermark_cleaner import WaterMarkCleaner
- from sorawm.watermark_detector import SoraWaterMarkDetector
- from sorawm.utils.imputation_utils import (
- find_2d_data_bkps,
- get_interval_average_bbox,
- find_idxs_interval,
- )
class SoraWM:
    """Two-pass watermark remover for a video file.

    Pass 1 runs the detector on every frame and records the watermark
    bounding box (or a miss).  Missed frames get a box imputed from the
    surrounding detections.  Pass 2 hands every frame plus a rectangular mask
    to the cleaner and streams the result to an ffmpeg encoder; finally the
    audio of the input file is muxed into the encoded video.
    """

    def __init__(self):
        self.detector = SoraWaterMarkDetector()
        self.cleaner = WaterMarkCleaner()

    def run(
        self,
        input_video_path: Path,
        output_video_path: Path,
        progress_callback: Callable[[int], None] | None = None,
    ):
        """Remove the watermark from ``input_video_path``.

        Args:
            input_video_path: Video to process.  It is decoded twice (once for
                detection, once for cleaning).
            output_video_path: Destination file; missing parent directories
                are created.  A ``temp_<name>`` file next to it holds the
                video-only encode until the audio has been merged.
            progress_callback: Optional callable receiving an integer
                percentage.  It is invoked on every 10th frame with values in
                10-50 during detection and 50-95 during cleaning, then with 95
                before and 99 after the audio merge.
        """
        input_video_loader = VideoLoader(input_video_path)
        output_video_path.parent.mkdir(parents=True, exist_ok=True)
        width = input_video_loader.width
        height = input_video_loader.height
        fps = input_video_loader.fps
        total_frames = input_video_loader.total_frames
        original_bitrate = input_video_loader.original_bitrate
        logger.debug(
            f"total frames: {total_frames}, fps: {fps}, width: {width}, height: {height}"
        )

        # Pass 1: locate the watermark on every frame.
        bboxes, detect_missed = self._detect_bboxes(
            input_video_loader, total_frames, progress_callback
        )
        logger.debug(f"detect missed frames: {detect_missed}")
        if detect_missed:
            bboxes = self._fill_missed_bboxes(bboxes, detect_missed)

        # Pass 2: clean and encode.  The encoder is only spawned now so that a
        # failure during detection cannot leave an idle ffmpeg process behind.
        temp_output_path = output_video_path.parent / f"temp_{output_video_path.name}"
        process_out = self._start_encoder(
            temp_output_path, width, height, fps, original_bitrate
        )
        encoded = False
        try:
            frames = tqdm(
                VideoLoader(input_video_path),
                total=total_frames,
                desc="Remove watermarks",
            )
            for idx, frame in enumerate(frames):
                # Frames beyond what pass 1 decoded have no box; pass them through.
                bbox = bboxes[idx] if idx < len(bboxes) else None
                if bbox is not None:
                    x1, y1, x2, y2 = bbox
                    mask = np.zeros((height, width), dtype=np.uint8)
                    mask[y1:y2, x1:x2] = 255
                    cleaned_frame = self.cleaner.clean(frame, mask)
                else:
                    cleaned_frame = frame
                process_out.stdin.write(cleaned_frame.tobytes())
                # 50% - 95%
                if progress_callback and idx % 10 == 0:
                    progress_callback(
                        self._scaled_progress(idx, total_frames, 50, 45)
                    )
            encoded = True
        finally:
            # Always release the encoder, even when cleaning failed part-way.
            self._close_encoder(process_out)
            if not encoded:
                # A truncated encode is useless; do not leave it on disk.
                temp_output_path.unlink(missing_ok=True)

        # 95% - 99%
        if progress_callback:
            progress_callback(95)
        self.merge_audio_track(input_video_path, temp_output_path, output_video_path)
        if progress_callback:
            progress_callback(99)

    @staticmethod
    def _scaled_progress(idx: int, total_frames: int, start: int, span: int) -> int:
        """Map frame ``idx`` of ``total_frames`` onto ``start .. start + span``.

        Guards against a zero/negative frame count and clamps the ratio so a
        loader yielding more frames than announced cannot overshoot the range.
        """
        if total_frames <= 0:
            return start
        return start + int(min(idx / total_frames, 1.0) * span)

    def _detect_bboxes(self, video_loader, total_frames, progress_callback):
        """Run the detector over ``video_loader``.

        Returns:
            A pair ``(bboxes, detect_missed)``: ``bboxes`` has one entry per
            decoded frame, either an ``(x1, y1, x2, y2)`` tuple or ``None``;
            ``detect_missed`` lists the indices of the ``None`` entries.
        """
        bboxes = []
        detect_missed = []
        frames = tqdm(video_loader, total=total_frames, desc="Detect watermarks")
        for idx, frame in enumerate(frames):
            detection_result = self.detector.detect(frame)
            if detection_result["detected"]:
                x1, y1, x2, y2 = detection_result["bbox"]
                bboxes.append((x1, y1, x2, y2))
            else:
                bboxes.append(None)
                detect_missed.append(idx)
            # 10% - 50%
            if progress_callback and idx % 10 == 0:
                progress_callback(self._scaled_progress(idx, total_frames, 10, 40))
        return bboxes, detect_missed

    @staticmethod
    def _fill_missed_bboxes(bboxes: list, detect_missed: list[int]) -> list:
        """Return a copy of ``bboxes`` with missed frames imputed.

        The frames are split into intervals at the breakpoints of the box
        centres; a missed frame takes the average box of its interval.  If
        that interval has no valid box, the box of the previous frame, else of
        the next frame, is used.  Frames for which none is available stay
        ``None``.
        """
        # Use the number of frames actually decoded rather than the metadata
        # count, so every index below is guaranteed to exist.
        frame_count = len(bboxes)
        bbox_centers = [
            None
            if box is None
            else (int((box[0] + box[2]) / 2), int((box[1] + box[3]) / 2))
            for box in bboxes
        ]
        # 1. find the breakpoints of the bbox centres and add both ends to
        #    form the complete interval boundaries
        bkps = find_2d_data_bkps(bbox_centers)
        bkps_full = [0] + bkps + [frame_count]
        # 2. average bbox of each interval (computed from real detections only)
        interval_bboxes = get_interval_average_bbox(bboxes, bkps_full)
        # 3. interval index of each missed frame
        missed_intervals = find_idxs_interval(detect_missed, bkps_full)
        # 4. fill the missed frames
        filled = list(bboxes)
        for missed_idx, interval_idx in zip(detect_missed, missed_intervals):
            if (
                interval_idx < len(interval_bboxes)
                and interval_bboxes[interval_idx] is not None
            ):
                filled[missed_idx] = interval_bboxes[interval_idx]
                logger.debug(f"Filled missed frame {missed_idx} with bbox:\n"
                             f" {interval_bboxes[interval_idx]}")
                continue
            # Fallback: borrow from a direct neighbour.  Earlier frames have
            # already been filled, so a box propagates forward through a run
            # of misses.
            before_box = filled[max(missed_idx - 1, 0)]
            after_box = filled[min(missed_idx + 1, frame_count - 1)]
            if before_box is not None:
                filled[missed_idx] = before_box
            elif after_box is not None:
                filled[missed_idx] = after_box
        return filled

    @staticmethod
    def _start_encoder(temp_output_path: Path, width, height, fps, original_bitrate):
        """Spawn ffmpeg reading raw ``bgr24`` frames from stdin into an H.264 file.

        When the source bitrate is known the target is 1.2x that bitrate;
        otherwise constant-quality encoding with CRF 18 is used.
        """
        output_options = {
            "pix_fmt": "yuv420p",
            "vcodec": "libx264",
            "preset": "slow",
        }
        if original_bitrate:
            output_options["video_bitrate"] = str(int(int(original_bitrate) * 1.2))
        else:
            output_options["crf"] = "18"
        return (
            ffmpeg.input(
                "pipe:",
                format="rawvideo",
                pix_fmt="bgr24",
                s=f"{width}x{height}",
                r=fps,
            )
            .output(str(temp_output_path), **output_options)
            .overwrite_output()
            .global_args("-loglevel", "error")
            .run_async(pipe_stdin=True)
        )

    @staticmethod
    def _close_encoder(process_out) -> None:
        """Close the encoder's stdin and wait for the process to exit."""
        try:
            process_out.stdin.close()
        except BrokenPipeError:
            # ffmpeg already exited, so flushing the remaining bytes failed.
            logger.warning("ffmpeg encoder exited before all frames were written")
        return_code = process_out.wait()
        if return_code:
            logger.error(f"ffmpeg encoder exited with code {return_code}")

    def merge_audio_track(
        self, input_video_path: Path, temp_output_path: Path, output_video_path: Path
    ):
        """Mux the video of ``temp_output_path`` with the audio of the input.

        The video stream is copied as-is and the audio is encoded as AAC into
        ``output_video_path``.  Inputs without an audio stream produce a
        video-only output.  The temporary file is deleted after a successful
        merge and kept when the merge fails.

        Raises:
            ffmpeg.Error: If the ffmpeg invocation fails; its stderr is logged.
        """
        logger.info("Merging audio track...")
        video_stream = ffmpeg.input(str(temp_output_path))
        # The trailing "?" makes the mapping optional, so an input without an
        # audio stream no longer makes ffmpeg abort.
        audio_stream = ffmpeg.input(str(input_video_path))["a?"]
        try:
            (
                ffmpeg.output(
                    video_stream,
                    audio_stream,
                    str(output_video_path),
                    vcodec="copy",
                    acodec="aac",
                )
                .overwrite_output()
                .run(quiet=True)
            )
        except ffmpeg.Error as err:
            # quiet=True captures stderr; surface it before propagating.
            details = err.stderr.decode(errors="replace") if err.stderr else ""
            logger.error(f"Failed to merge audio track: {details}")
            raise
        # Clean up temporary file
        temp_output_path.unlink()
        logger.info(f"Saved no watermark video with audio at: {output_video_path}")
if __name__ == "__main__":
    # Manual smoke run: clean the bundled sample clip and write the result
    # under outputs/.
    from pathlib import Path

    input_video_path = Path(
        "resources/19700121_1645_68e0a027836c8191a50bea3717ea7485.mp4"
    )
    output_video_path = Path("outputs/sora_watermark_removed.mp4")
    sora_wm = SoraWM()
    sora_wm.run(
        input_video_path=input_video_path,
        output_video_path=output_video_path,
    )
|