import mimetypes
import os
import shutil
import tempfile
import time
from pathlib import Path

import requests
import streamlit as st

from sorawm.core import SoraWM

# Default FastAPI backend address shown in the sidebar; it can be overridden
# with the SORAWM_BACKEND_URL environment variable without editing the code.
DEFAULT_BACKEND_URL = os.environ.get(
    "SORAWM_BACKEND_URL", "http://192.168.245.14:5344"
)
# Seconds to wait between two polls of the backend task status.
POLL_INTERVAL_SECONDS = 1.0
# Stop waiting for a backend task after this many seconds so that a stuck
# task cannot block the Streamlit script run forever.
POLL_TIMEOUT_SECONDS = 60 * 60


def _clamp_percentage(value) -> int:
    """Coerce a reported progress value into an int in the range [0, 100].

    ``st.progress`` rejects values outside [0.0, 1.0], and the backend may
    report ``None``, a float, or a numeric string, so normalise first.
    Unparseable values are treated as 0.
    """
    try:
        pct = int(float(value))
    except (TypeError, ValueError):
        pct = 0
    return max(0, min(100, pct))


def _backend_stage_text(percentage: int, status: str) -> str:
    """Return the status line shown while the backend processes a task."""
    if percentage < 50:
        return f"🔍 后端检测水印... {percentage}% ({status})"
    if percentage < 95:
        return f"🧹 后端移除水印... {percentage}% ({status})"
    return f"🎵 后端合并音频... {percentage}% ({status})"


def _local_stage_text(progress: int) -> str:
    """Return the status line shown while the local SoraWM model runs."""
    if progress < 50:
        return f"🔍 Detecting watermarks... {progress}%"
    if progress < 95:
        return f"🧹 Removing watermarks... {progress}%"
    return f"🎵 Merging audio... {progress}%"


def _resolve_download_url(base_url: str, download_url: str) -> str:
    """Turn the backend's ``download_url`` into a fetchable absolute URL.

    Absolute URLs are used as-is; relative paths are appended to the backend
    base URL, inserting the ``/`` separator if the backend omitted it.
    """
    if download_url.startswith(("http://", "https://")):
        return download_url
    if not download_url.startswith("/"):
        download_url = "/" + download_url
    return f"{base_url}{download_url}"


def _poll_backend_task(base_url: str, task_id, progress_bar, status_text) -> str:
    """Poll ``/get_results`` until the task finishes and return its download URL.

    Updates ``progress_bar`` / ``status_text`` on every poll.

    Raises:
        RuntimeError: the backend reports ``ERROR``, finishes without a
            ``download_url``, or does not finish within POLL_TIMEOUT_SECONDS.
        requests.RequestException: on HTTP or network failures.
    """
    status_url = f"{base_url}/get_results"
    deadline = time.monotonic() + POLL_TIMEOUT_SECONDS
    while True:
        r = requests.get(status_url, params={"remove_task_id": task_id}, timeout=30)
        r.raise_for_status()
        info = r.json()
        percentage = _clamp_percentage(info.get("percentage", 0))
        status = info.get("status", "UNKNOWN")

        progress_bar.progress(percentage / 100)
        status_text.text(_backend_stage_text(percentage, status))

        if status == "FINISHED":
            download_url = info.get("download_url")
            if not download_url:
                raise RuntimeError("后端未提供下载链接")
            return download_url
        if status == "ERROR":
            raise RuntimeError("后端处理失败")
        if time.monotonic() >= deadline:
            raise RuntimeError(f"后端处理超时 ({POLL_TIMEOUT_SECONDS} 秒)")
        time.sleep(POLL_INTERVAL_SECONDS)


def _process_via_backend(uploaded_file, backend_url: str, progress_bar, status_text):
    """Send the upload to the FastAPI backend, wait for it, and show the result.

    ``uploaded_file`` is the object returned by ``st.file_uploader``. The
    cleaned video is embedded from the backend URL and also fetched into
    memory so it can be offered through ``st.download_button``.
    """
    base_url = backend_url.strip().rstrip("/")
    if not base_url:
        raise RuntimeError("后端地址为空")

    # Label the upload with its real content type instead of assuming MP4.
    mime_type = (
        uploaded_file.type
        or mimetypes.guess_type(uploaded_file.name)[0]
        or "video/mp4"
    )
    files = {"video": (uploaded_file.name, uploaded_file.getvalue(), mime_type)}

    # Submit the task.
    status_text.text("📤 正在提交到后端...")
    resp = requests.post(f"{base_url}/submit_remove_task", files=files, timeout=60)
    resp.raise_for_status()
    task_id = resp.json().get("task_id")
    if not task_id:
        raise RuntimeError("后端未返回 task_id")

    # Poll task status until it finishes (or fails / times out).
    download_url = _poll_backend_task(base_url, task_id, progress_bar, status_text)
    final_url = _resolve_download_url(base_url, download_url)

    progress_bar.progress(1.0)
    status_text.text("✅ 后端处理完成!")
    st.success("✅ 水印移除成功!")
    st.markdown("### 处理结果")
    st.video(final_url)

    # Download button: fetch the result locally so it can be served to the user.
    status_text.text("⬇️ 正在拉取处理后的视频用于下载...")
    dl_resp = requests.get(final_url, timeout=300)
    dl_resp.raise_for_status()
    status_text.text("✅ 后端处理完成!")
    st.download_button(
        label="⬇️ 下载处理后视频",
        data=dl_resp.content,
        file_name=f"cleaned_{uploaded_file.name}",
        mime="video/mp4",
        use_container_width=True,
    )


def _process_locally(uploaded_file, sora_wm, progress_bar, status_text):
    """Run ``sora_wm`` in-process on the upload and show the result.

    Input and output live in a temporary directory that is removed when this
    function returns; the result is handed to ``st.video`` and
    ``st.download_button`` before the directory is deleted.
    """
    # Use only the final path component of the client-supplied name so it
    # cannot point outside the temporary directory.
    safe_name = Path(uploaded_file.name).name or "input.mp4"

    with tempfile.TemporaryDirectory() as tmp_dir:
        tmp_path = Path(tmp_dir)

        # Save the uploaded file (getvalue() does not depend on read position).
        input_path = tmp_path / safe_name
        input_path.write_bytes(uploaded_file.getvalue())
        output_path = tmp_path / f"cleaned_{safe_name}"

        def update_progress(progress: int) -> None:
            pct = _clamp_percentage(progress)
            progress_bar.progress(pct / 100)
            status_text.text(_local_stage_text(pct))

        # Run the watermark removal with the progress callback.
        sora_wm.run(input_path, output_path, progress_callback=update_progress)

        progress_bar.progress(1.0)
        status_text.text("✅ Processing complete!")
        st.success("✅ Watermark removed successfully!")

        st.markdown("### Result")
        st.video(str(output_path))

        with open(output_path, "rb") as f:
            st.download_button(
                label="⬇️ Download Cleaned Video",
                data=f,
                file_name=f"cleaned_{uploaded_file.name}",
                mime="video/mp4",
                use_container_width=True,
            )


def main():
    """Render the Sora Watermark Cleaner Streamlit page.

    The user uploads a video and removes its watermark either through a
    remote FastAPI backend (the default) or in-process with a local SoraWM
    model, selected via a sidebar checkbox.
    """
    st.set_page_config(
        page_title="Sora Watermark Cleaner", page_icon="🎬", layout="centered"
    )

    st.title("🎬 Sora Watermark Cleaner")
    st.markdown("Remove watermarks from Sora-generated videos with ease")

    # Backend mode toggle.
    use_backend = st.sidebar.checkbox("使用后端服务 (FastAPI)", value=True)
    backend_url = st.sidebar.text_input(
        "后端地址",
        value=DEFAULT_BACKEND_URL,
        help="FastAPI 服务的基础地址",
    )

    # SoraWM is only needed in local mode; it is stored in session_state so
    # the models are loaded once per session rather than on every rerun.
    if not use_backend and "sora_wm" not in st.session_state:
        with st.spinner("Loading AI models..."):
            st.session_state.sora_wm = SoraWM()

    st.markdown("---")

    uploaded_file = st.file_uploader(
        "Upload your video",
        type=["mp4", "avi", "mov", "mkv"],
        help="Select a video file to remove watermarks",
    )

    if uploaded_file is not None:
        st.success(f"✅ Uploaded: {uploaded_file.name}")
        st.video(uploaded_file)

        if st.button("🚀 Remove Watermark", type="primary", use_container_width=True):
            try:
                progress_bar = st.progress(0)
                status_text = st.empty()
                if use_backend:
                    _process_via_backend(
                        uploaded_file, backend_url, progress_bar, status_text
                    )
                else:
                    _process_locally(
                        uploaded_file,
                        st.session_state.sora_wm,
                        progress_bar,
                        status_text,
                    )
            except Exception as e:
                # Top-level UI boundary: report any failure to the user.
                st.error(f"❌ 处理失败: {str(e)}")

    # Footer (unsafe_allow_html lets this string carry raw HTML markup).
    st.markdown("---")
    st.markdown(
        """

Built with ❤️ using Streamlit and AI

GitHub Repository

""",
        unsafe_allow_html=True,
    )


if __name__ == "__main__":
    main()