# app.py (7.0 KB)

  1. import shutil
  2. import tempfile
  3. from pathlib import Path
  4. import requests
  5. import time
  6. import streamlit as st
  7. from sorawm.core import SoraWM
  8. def main():
  9. st.set_page_config(
  10. page_title="Sora Watermark Cleaner", page_icon="🎬", layout="centered"
  11. )
  12. st.title("🎬 Sora Watermark Cleaner")
  13. st.markdown("Remove watermarks from Sora-generated videos with ease")
  14. # 后端模式切换
  15. use_backend = st.sidebar.checkbox("使用后端服务 (FastAPI)", value=True)
  16. backend_url = st.sidebar.text_input(
  17. "后端地址",
  18. # value="http://192.168.245.14:5344",
  19. value="http://127.0.0.1:5344",
  20. help="FastAPI 服务的基础地址",
  21. )
  22. # Initialize SoraWM 仅用于本地模式
  23. if not use_backend:
  24. if "sora_wm" not in st.session_state:
  25. with st.spinner("Loading AI models..."):
  26. st.session_state.sora_wm = SoraWM()
  27. st.markdown("---")
  28. # File uploader
  29. uploaded_file = st.file_uploader(
  30. "Upload your video",
  31. type=["mp4", "avi", "mov", "mkv"],
  32. help="Select a video file to remove watermarks",
  33. )
  34. if uploaded_file is not None:
  35. # Display video info
  36. st.success(f"✅ Uploaded: {uploaded_file.name}")
  37. st.video(uploaded_file)
  38. # Process button
  39. if st.button("🚀 Remove Watermark", type="primary", use_container_width=True):
  40. try:
  41. # Create progress bar and status text
  42. progress_bar = st.progress(0)
  43. status_text = st.empty()
  44. if use_backend:
  45. # 后端模式:提交任务
  46. submit_url = f"{backend_url}/submit_remove_task"
  47. video_bytes = uploaded_file.getvalue()
  48. files = {"video": (uploaded_file.name, video_bytes, "video/mp4")}
  49. status_text.text("📤 正在提交到后端...")
  50. resp = requests.post(submit_url, files=files, timeout=60)
  51. resp.raise_for_status()
  52. data = resp.json()
  53. task_id = data.get("task_id")
  54. if not task_id:
  55. raise RuntimeError("后端未返回 task_id")
  56. # 轮询任务状态
  57. status_url = f"{backend_url}/get_results"
  58. download_url = None
  59. while True:
  60. r = requests.get(status_url, params={"remove_task_id": task_id}, timeout=30)
  61. r.raise_for_status()
  62. info = r.json()
  63. percentage = int(info.get("percentage", 0))
  64. status = info.get("status", "UNKNOWN")
  65. progress_bar.progress(percentage / 100)
  66. if percentage < 50:
  67. status_text.text(f"🔍 后端检测水印... {percentage}% ({status})")
  68. elif percentage < 95:
  69. status_text.text(f"🧹 后端移除水印... {percentage}% ({status})")
  70. else:
  71. status_text.text(f"🎵 后端合并音频... {percentage}% ({status})")
  72. if status == "FINISHED":
  73. download_url = info.get("download_url")
  74. break
  75. elif status == "ERROR":
  76. raise RuntimeError("后端处理失败")
  77. time.sleep(1)
  78. if not download_url:
  79. raise RuntimeError("后端未提供下载链接")
  80. final_url = f"{backend_url}{download_url}"
  81. progress_bar.progress(1.0)
  82. status_text.text("✅ 后端处理完成!")
  83. st.success("✅ 水印移除成功!")
  84. st.markdown("### 处理结果")
  85. st.video(final_url)
  86. # 下载按钮(拉取到本地提供下载)
  87. status_text.text("⬇️ 正在拉取处理后的视频用于下载...")
  88. dl_resp = requests.get(final_url, timeout=300)
  89. dl_resp.raise_for_status()
  90. st.download_button(
  91. label="⬇️ 下载处理后视频",
  92. data=dl_resp.content,
  93. file_name=f"cleaned_{uploaded_file.name}",
  94. mime="video/mp4",
  95. use_container_width=True,
  96. )
  97. else:
  98. # 本地模式处理
  99. with tempfile.TemporaryDirectory() as tmp_dir:
  100. tmp_path = Path(tmp_dir)
  101. # Save uploaded file
  102. input_path = tmp_path / uploaded_file.name
  103. with open(input_path, "wb") as f:
  104. f.write(uploaded_file.read())
  105. # Process video
  106. output_path = tmp_path / f"cleaned_{uploaded_file.name}"
  107. def update_progress(progress: int):
  108. progress_bar.progress(progress / 100)
  109. if progress < 50:
  110. status_text.text(f"🔍 Detecting watermarks... {progress}%")
  111. elif progress < 95:
  112. status_text.text(f"🧹 Removing watermarks... {progress}%")
  113. else:
  114. status_text.text(f"🎵 Merging audio... {progress}%")
  115. # Run the watermark removal with progress callback
  116. st.session_state.sora_wm.run(
  117. input_path, output_path, progress_callback=update_progress
  118. )
  119. # Complete the progress bar
  120. progress_bar.progress(1.0)
  121. status_text.text("✅ Processing complete!")
  122. st.success("✅ Watermark removed successfully!")
  123. # Display result
  124. st.markdown("### Result")
  125. st.video(str(output_path))
  126. # Download button
  127. with open(output_path, "rb") as f:
  128. st.download_button(
  129. label="⬇️ Download Cleaned Video",
  130. data=f,
  131. file_name=f"cleaned_{uploaded_file.name}",
  132. mime="video/mp4",
  133. use_container_width=True,
  134. )
  135. except Exception as e:
  136. st.error(f"❌ 处理失败: {str(e)}")
  137. # Footer
  138. st.markdown("---")
  139. st.markdown(
  140. """
  141. <div style='text-align: center'>
  142. <p>Built with ❤️ using Streamlit and AI</p>
  143. <p><a href='https://github.com/linkedlist771/SoraWatermarkCleaner'>GitHub Repository</a></p>
  144. </div>
  145. """,
  146. unsafe_allow_html=True,
  147. )
  148. if __name__ == "__main__":
  149. main()