"""Remove the Sora watermark from a video and re-encode it with the original audio."""

import shutil
from pathlib import Path
from typing import Callable

import ffmpeg
import numpy as np
from loguru import logger
from tqdm import tqdm

from sorawm.utils.imputation_utils import (
    find_2d_data_bkps,
    find_idxs_interval,
    get_interval_average_bbox,
)
from sorawm.utils.video_utils import VideoLoader
from sorawm.watermark_cleaner import WaterMarkCleaner
from sorawm.watermark_detector import SoraWaterMarkDetector


class SoraWM:
    """Two-pass watermark remover.

    Pass 1 runs the detector on every frame and fills in frames where the
    detection failed. Pass 2 re-reads the video, inpaints the watermark area
    of each frame and pipes the raw frames into an ffmpeg encoder. Finally
    the audio of the input video is muxed into the result.
    """

    def __init__(self):
        self.detector = SoraWaterMarkDetector()
        self.cleaner = WaterMarkCleaner()

    def run(
        self,
        input_video_path: Path,
        output_video_path: Path,
        progress_callback: Callable[[int], None] | None = None,
    ) -> None:
        """Remove the watermark from ``input_video_path`` and save the result.

        Args:
            input_video_path: Video to process.
            output_video_path: Destination file; its parent directory is
                created when missing. A sibling ``temp_<name>`` file is used
                for the intermediate, video-only encode.
            progress_callback: Optional callable receiving an integer
                progress value: 10-50 during detection, 50-95 while
                cleaning, 95 before and 99 after the audio merge.

        Raises:
            RuntimeError: If the ffmpeg encoder exits with a non-zero code.
        """
        input_video_loader = VideoLoader(input_video_path)
        output_video_path.parent.mkdir(parents=True, exist_ok=True)
        width = input_video_loader.width
        height = input_video_loader.height
        fps = input_video_loader.fps
        total_frames = input_video_loader.total_frames

        temp_output_path = output_video_path.parent / f"temp_{output_video_path.name}"
        # The encoder is started up front so that a broken ffmpeg setup fails
        # before the expensive detection pass.
        process_out = self._start_encoder(input_video_loader, temp_output_path)
        logger.debug(
            f"total frames: {total_frames}, fps: {fps}, width: {width}, height: {height}"
        )

        finished = False
        try:
            frame_bboxes = self._detect_bboxes(
                input_video_loader, total_frames, progress_callback
            )
            self._write_cleaned_frames(
                input_video_path,
                frame_bboxes,
                width,
                height,
                total_frames,
                process_out,
                progress_callback,
            )
            finished = True
        finally:
            # Always release the encoder, even when detection or cleaning
            # raised; otherwise the ffmpeg child would keep waiting on stdin.
            return_code = self._finish_encoder(process_out)
            if not finished or return_code != 0:
                # Do not leave a truncated intermediate file behind.
                temp_output_path.unlink(missing_ok=True)

        if return_code != 0:
            raise RuntimeError(f"ffmpeg encoder exited with code {return_code}")

        # 95% - 99%
        if progress_callback:
            progress_callback(95)

        self.merge_audio_track(input_video_path, temp_output_path, output_video_path)

        if progress_callback:
            progress_callback(99)

    @staticmethod
    def _start_encoder(video_loader: VideoLoader, temp_output_path: Path):
        """Start an ffmpeg process that encodes raw BGR frames read from stdin."""
        output_options = {
            "pix_fmt": "yuv420p",
            "vcodec": "libx264",
            "preset": "slow",
        }
        if video_loader.original_bitrate:
            # Target 1.2x the source bitrate when the loader reports one.
            output_options["video_bitrate"] = str(
                int(int(video_loader.original_bitrate) * 1.2)
            )
        else:
            output_options["crf"] = "18"

        return (
            ffmpeg.input(
                "pipe:",
                format="rawvideo",
                pix_fmt="bgr24",
                s=f"{video_loader.width}x{video_loader.height}",
                r=video_loader.fps,
            )
            .output(str(temp_output_path), **output_options)
            .overwrite_output()
            .global_args("-loglevel", "error")
            .run_async(pipe_stdin=True)
        )

    @staticmethod
    def _finish_encoder(process_out) -> int:
        """Close the encoder's stdin, wait for it to exit and return its exit code."""
        try:
            process_out.stdin.close()
        except OSError as err:
            # Typically a broken pipe because ffmpeg already died; the exit
            # code returned below is what reports the failure to the caller.
            logger.warning(f"Failed to close ffmpeg stdin cleanly: {err}")
        return process_out.wait()

    @staticmethod
    def _scaled_progress(idx: int, total_frames: int, start: int, span: int) -> int:
        """Map frame ``idx`` onto the progress range ``[start, start + span]``."""
        if total_frames <= 0:
            # The reported frame count can be missing/zero; avoid dividing by it.
            return start
        return start + int(min(idx / total_frames, 1.0) * span)

    def _detect_bboxes(
        self,
        video_loader: VideoLoader,
        total_frames: int,
        progress_callback: Callable[[int], None] | None,
    ) -> dict:
        """Run the detector on every frame and impute missed detections.

        Returns:
            A dict mapping frame index to its ``(x1, y1, x2, y2)`` bbox, or
            to ``None`` when no bbox could be detected or imputed.
        """
        frame_bboxes = {}
        detect_missed = []
        bbox_centers = []
        bboxes = []

        for idx, frame in enumerate(
            tqdm(video_loader, total=total_frames, desc="Detect watermarks")
        ):
            detection_result = self.detector.detect(frame)
            if detection_result["detected"]:
                x1, y1, x2, y2 = detection_result["bbox"]
                frame_bboxes[idx] = detection_result["bbox"]
                bbox_centers.append((int((x1 + x2) / 2), int((y1 + y2) / 2)))
                bboxes.append((x1, y1, x2, y2))
            else:
                frame_bboxes[idx] = None
                detect_missed.append(idx)
                bbox_centers.append(None)
                bboxes.append(None)

            # 10% - 50%
            if progress_callback and idx % 10 == 0:
                progress_callback(self._scaled_progress(idx, total_frames, 10, 40))

        logger.debug(f"detect missed frames: {detect_missed}")

        if len(detect_missed) == len(bboxes):
            # Either nothing was missed or nothing was detected at all; in the
            # latter case there is no data to impute from.
            return frame_bboxes

        self._fill_missed_bboxes(frame_bboxes, detect_missed, bbox_centers, bboxes)
        return frame_bboxes

    @staticmethod
    def _fill_missed_bboxes(
        frame_bboxes: dict,
        detect_missed: list,
        bbox_centers: list,
        bboxes: list,
    ) -> None:
        """Fill ``frame_bboxes`` in place for the frames listed in ``detect_missed``."""
        # Use the number of frames actually decoded rather than the container
        # metadata, so the intervals match the data they are computed from.
        frame_count = len(bboxes)

        # 1. find the bkps of the bbox centers
        bkps = find_2d_data_bkps(bbox_centers)
        # add the start and end position, to form the complete interval boundaries
        bkps_full = [0] + bkps + [frame_count]

        # 2. calculate the average bbox of each interval
        interval_bboxes = get_interval_average_bbox(bboxes, bkps_full)

        # 3. find the interval index of each missed frame
        missed_intervals = find_idxs_interval(detect_missed, bkps_full)

        # 4. fill the missed frames with the average bbox of the corresponding interval
        for missed_idx, interval_idx in zip(detect_missed, missed_intervals):
            if (
                interval_idx < len(interval_bboxes)
                and interval_bboxes[interval_idx] is not None
            ):
                frame_bboxes[missed_idx] = interval_bboxes[interval_idx]
                logger.debug(f"Filled missed frame {missed_idx} with bbox:\n"
                             f" {interval_bboxes[interval_idx]}")
                continue

            # Fallback: the interval has no valid bbox, so borrow the bbox of
            # the previous frame, or else of the next frame.
            before_box = frame_bboxes.get(max(missed_idx - 1, 0))
            after_box = frame_bboxes.get(min(missed_idx + 1, frame_count - 1))
            if before_box is not None:
                frame_bboxes[missed_idx] = before_box
            elif after_box is not None:
                frame_bboxes[missed_idx] = after_box

    def _write_cleaned_frames(
        self,
        input_video_path: Path,
        frame_bboxes: dict,
        width: int,
        height: int,
        total_frames: int,
        process_out,
        progress_callback: Callable[[int], None] | None,
    ) -> None:
        """Re-read the video, clean each frame and pipe it to the encoder."""
        input_video_loader = VideoLoader(input_video_path)

        for idx, frame in enumerate(
            tqdm(input_video_loader, total=total_frames, desc="Remove watermarks")
        ):
            bbox = frame_bboxes.get(idx)
            if bbox is not None:
                # Integer, non-negative bounds: a negative value would
                # otherwise be interpreted as "counted from the end".
                x1, y1, x2, y2 = (max(0, int(value)) for value in bbox)
                mask = np.zeros((height, width), dtype=np.uint8)
                mask[y1:y2, x1:x2] = 255
                cleaned_frame = self.cleaner.clean(frame, mask)
            else:
                cleaned_frame = frame
            process_out.stdin.write(cleaned_frame.tobytes())

            # 50% - 95%
            if progress_callback and idx % 10 == 0:
                progress_callback(self._scaled_progress(idx, total_frames, 50, 45))

    @staticmethod
    def _has_audio_stream(video_path: Path) -> bool:
        """Return whether ffprobe reports at least one audio stream.

        When probing itself fails, ``True`` is returned so the caller still
        attempts the merge.
        """
        try:
            probe = ffmpeg.probe(str(video_path))
        except (ffmpeg.Error, OSError) as err:
            logger.warning(f"Could not probe {video_path} for audio streams: {err}")
            return True
        return any(
            stream.get("codec_type") == "audio" for stream in probe.get("streams", [])
        )

    def merge_audio_track(
        self, input_video_path: Path, temp_output_path: Path, output_video_path: Path
    ) -> None:
        """Mux the cleaned video with the audio of the original video.

        Args:
            input_video_path: Original video, used as the audio source.
            temp_output_path: Video-only encode; removed once the final file
                has been written.
            output_video_path: Final output file (overwritten if present).
        """
        if not self._has_audio_stream(input_video_path):
            # Mapping a non-existent audio stream makes ffmpeg fail, so the
            # video-only encode simply becomes the final output.
            logger.info("Input video has no audio track, skipping audio merge.")
            shutil.move(str(temp_output_path), str(output_video_path))
            logger.info(f"Saved no watermark video at: {output_video_path}")
            return

        logger.info("Merging audio track...")
        video_stream = ffmpeg.input(str(temp_output_path))
        audio_stream = ffmpeg.input(str(input_video_path)).audio

        (
            ffmpeg.output(
                video_stream,
                audio_stream,
                str(output_video_path),
                vcodec="copy",
                acodec="aac",
            )
            .overwrite_output()
            .run(quiet=True)
        )
        # Clean up temporary file
        temp_output_path.unlink()
        logger.info(f"Saved no watermark video with audio at: {output_video_path}")


if __name__ == "__main__":
    input_video_path = Path(
        "resources/19700121_1645_68e0a027836c8191a50bea3717ea7485.mp4"
    )
    output_video_path = Path("outputs/sora_watermark_removed.mp4")
    sora_wm = SoraWM()
    sora_wm.run(input_video_path, output_video_path)