download_utils.py 1.4 KB

1234567891011121314151617181920212223242526272829303132333435
  1. from pathlib import Path
  2. import requests
  3. from loguru import logger
  4. from tqdm import tqdm
  5. from sorawm.configs import WATER_MARK_DETECT_YOLO_WEIGHTS
  6. DETECTOR_URL = "https://github.com/linkedlist771/SoraWatermarkCleaner/releases/download/V0.0.1/best.pt"
  7. def download_detector_weights():
  8. if not WATER_MARK_DETECT_YOLO_WEIGHTS.exists():
  9. logger.debug(f"llama weights not found, downloading from {DETECTOR_URL}")
  10. WATER_MARK_DETECT_YOLO_WEIGHTS.parent.mkdir(parents=True, exist_ok=True)
  11. try:
  12. response = requests.get(DETECTOR_URL, stream=True, timeout=300)
  13. response.raise_for_status()
  14. total_size = int(response.headers.get("content-length", 0))
  15. with open(WATER_MARK_DETECT_YOLO_WEIGHTS, "wb") as f:
  16. with tqdm(
  17. total=total_size, unit="B", unit_scale=True, desc="Downloading"
  18. ) as pbar:
  19. for chunk in response.iter_content(chunk_size=8192):
  20. if chunk:
  21. f.write(chunk)
  22. pbar.update(len(chunk))
  23. logger.success(f"✓ Weights downloaded: {WATER_MARK_DETECT_YOLO_WEIGHTS}")
  24. except requests.exceptions.RequestException as e:
  25. if WATER_MARK_DETECT_YOLO_WEIGHTS.exists():
  26. WATER_MARK_DETECT_YOLO_WEIGHTS.unlink()
  27. raise RuntimeError(f"Downing failed: {e}")