# app.py - Streamlit front end for Sora Watermark Cleaner
  1. import shutil
  2. import tempfile
  3. from pathlib import Path
  4. import requests
  5. import time
  6. import streamlit as st
  7. from sorawm.core import SoraWM
  8. def main():
  9. st.set_page_config(
  10. page_title="Sora Watermark Cleaner", page_icon="🎬", layout="centered"
  11. )
  12. st.title("🎬 Sora Watermark Cleaner")
  13. st.markdown("Remove watermarks from Sora-generated videos with ease")
  14. # 后端模式切换
  15. use_backend = st.sidebar.checkbox("使用后端服务 (FastAPI)", value=True)
  16. backend_url = st.sidebar.text_input(
  17. "后端地址",
  18. value="http://192.168.245.14:5344",
  19. help="FastAPI 服务的基础地址",
  20. )
  21. # Initialize SoraWM 仅用于本地模式
  22. if not use_backend:
  23. if "sora_wm" not in st.session_state:
  24. with st.spinner("Loading AI models..."):
  25. st.session_state.sora_wm = SoraWM()
  26. st.markdown("---")
  27. # File uploader
  28. uploaded_file = st.file_uploader(
  29. "Upload your video",
  30. type=["mp4", "avi", "mov", "mkv"],
  31. help="Select a video file to remove watermarks",
  32. )
  33. if uploaded_file is not None:
  34. # Display video info
  35. st.success(f"✅ Uploaded: {uploaded_file.name}")
  36. st.video(uploaded_file)
  37. # Process button
  38. if st.button("🚀 Remove Watermark", type="primary", use_container_width=True):
  39. try:
  40. # Create progress bar and status text
  41. progress_bar = st.progress(0)
  42. status_text = st.empty()
  43. if use_backend:
  44. # 后端模式:提交任务
  45. submit_url = f"{backend_url}/submit_remove_task"
  46. video_bytes = uploaded_file.getvalue()
  47. files = {"video": (uploaded_file.name, video_bytes, "video/mp4")}
  48. status_text.text("📤 正在提交到后端...")
  49. resp = requests.post(submit_url, files=files, timeout=60)
  50. resp.raise_for_status()
  51. data = resp.json()
  52. task_id = data.get("task_id")
  53. if not task_id:
  54. raise RuntimeError("后端未返回 task_id")
  55. # 轮询任务状态
  56. status_url = f"{backend_url}/get_results"
  57. download_url = None
  58. while True:
  59. r = requests.get(status_url, params={"remove_task_id": task_id}, timeout=30)
  60. r.raise_for_status()
  61. info = r.json()
  62. percentage = int(info.get("percentage", 0))
  63. status = info.get("status", "UNKNOWN")
  64. progress_bar.progress(percentage / 100)
  65. if percentage < 50:
  66. status_text.text(f"🔍 后端检测水印... {percentage}% ({status})")
  67. elif percentage < 95:
  68. status_text.text(f"🧹 后端移除水印... {percentage}% ({status})")
  69. else:
  70. status_text.text(f"🎵 后端合并音频... {percentage}% ({status})")
  71. if status == "FINISHED":
  72. download_url = info.get("download_url")
  73. break
  74. elif status == "ERROR":
  75. raise RuntimeError("后端处理失败")
  76. time.sleep(1)
  77. if not download_url:
  78. raise RuntimeError("后端未提供下载链接")
  79. final_url = f"{backend_url}{download_url}"
  80. progress_bar.progress(1.0)
  81. status_text.text("✅ 后端处理完成!")
  82. st.success("✅ 水印移除成功!")
  83. st.markdown("### 处理结果")
  84. st.video(final_url)
  85. # 下载按钮(拉取到本地提供下载)
  86. status_text.text("⬇️ 正在拉取处理后的视频用于下载...")
  87. dl_resp = requests.get(final_url, timeout=300)
  88. dl_resp.raise_for_status()
  89. st.download_button(
  90. label="⬇️ 下载处理后视频",
  91. data=dl_resp.content,
  92. file_name=f"cleaned_{uploaded_file.name}",
  93. mime="video/mp4",
  94. use_container_width=True,
  95. )
  96. else:
  97. # 本地模式处理
  98. with tempfile.TemporaryDirectory() as tmp_dir:
  99. tmp_path = Path(tmp_dir)
  100. # Save uploaded file
  101. input_path = tmp_path / uploaded_file.name
  102. with open(input_path, "wb") as f:
  103. f.write(uploaded_file.read())
  104. # Process video
  105. output_path = tmp_path / f"cleaned_{uploaded_file.name}"
  106. def update_progress(progress: int):
  107. progress_bar.progress(progress / 100)
  108. if progress < 50:
  109. status_text.text(f"🔍 Detecting watermarks... {progress}%")
  110. elif progress < 95:
  111. status_text.text(f"🧹 Removing watermarks... {progress}%")
  112. else:
  113. status_text.text(f"🎵 Merging audio... {progress}%")
  114. # Run the watermark removal with progress callback
  115. st.session_state.sora_wm.run(
  116. input_path, output_path, progress_callback=update_progress
  117. )
  118. # Complete the progress bar
  119. progress_bar.progress(1.0)
  120. status_text.text("✅ Processing complete!")
  121. st.success("✅ Watermark removed successfully!")
  122. # Display result
  123. st.markdown("### Result")
  124. st.video(str(output_path))
  125. # Download button
  126. with open(output_path, "rb") as f:
  127. st.download_button(
  128. label="⬇️ Download Cleaned Video",
  129. data=f,
  130. file_name=f"cleaned_{uploaded_file.name}",
  131. mime="video/mp4",
  132. use_container_width=True,
  133. )
  134. except Exception as e:
  135. st.error(f"❌ 处理失败: {str(e)}")
  136. # Footer
  137. st.markdown("---")
  138. st.markdown(
  139. """
  140. <div style='text-align: center'>
  141. <p>Built with ❤️ using Streamlit and AI</p>
  142. <p><a href='https://github.com/linkedlist771/SoraWatermarkCleaner'>GitHub Repository</a></p>
  143. </div>
  144. """,
  145. unsafe_allow_html=True,
  146. )
  147. if __name__ == "__main__":
  148. main()