# runtime.py (2.4 KB)
  1. # https://github.com/huggingface/huggingface_hub/blob/5a12851f54bf614be39614034ed3a9031922d297/src/huggingface_hub/utils/_runtime.py
  2. import os
  3. import platform
  4. import sys
  5. from pathlib import Path
  6. from typing import Any, Dict
  7. import packaging.version
  8. from loguru import logger
  9. from rich import print
  10. from sorawm.iopaint.schema import Device
  11. _PY_VERSION: str = sys.version.split()[0].rstrip("+")
  12. if packaging.version.Version(_PY_VERSION) < packaging.version.Version("3.8.0"):
  13. import importlib_metadata # type: ignore
  14. else:
  15. import importlib.metadata as importlib_metadata # type: ignore
  16. _package_versions = {}
  17. _CANDIDATES = [
  18. "torch",
  19. "torchvision",
  20. "Pillow",
  21. "diffusers",
  22. "transformers",
  23. "opencv-python",
  24. "accelerate",
  25. "iopaint",
  26. "rembg",
  27. "onnxruntime",
  28. ]
  29. # Check once at runtime
  30. for name in _CANDIDATES:
  31. _package_versions[name] = "N/A"
  32. try:
  33. _package_versions[name] = importlib_metadata.version(name)
  34. except importlib_metadata.PackageNotFoundError:
  35. pass
  36. def dump_environment_info() -> Dict[str, str]:
  37. """Dump information about the machine to help debugging issues."""
  38. # Generic machine info
  39. info: Dict[str, Any] = {
  40. "Platform": platform.platform(),
  41. "Python version": platform.python_version(),
  42. }
  43. info.update(_package_versions)
  44. print("\n".join([f"- {prop}: {val}" for prop, val in info.items()]) + "\n")
  45. return info
  46. def check_device(device: Device) -> Device:
  47. if device == Device.cuda:
  48. import platform
  49. if platform.system() == "Darwin":
  50. logger.warning("MacOS does not support cuda, use cpu instead")
  51. return Device.cpu
  52. else:
  53. import torch
  54. if not torch.cuda.is_available():
  55. logger.warning("CUDA is not available, use cpu instead")
  56. return Device.cpu
  57. elif device == Device.mps:
  58. import torch
  59. if not torch.backends.mps.is_available():
  60. logger.warning("mps is not available, use cpu instead")
  61. return Device.cpu
  62. return device
  63. def setup_model_dir(model_dir: Path):
  64. model_dir = model_dir.expanduser().absolute()
  65. logger.info(f"Model directory: {model_dir}")
  66. os.environ["U2NET_HOME"] = str(model_dir)
  67. os.environ["XDG_CACHE_HOME"] = str(model_dir)
  68. if not model_dir.exists():
  69. logger.info(f"Create model directory: {model_dir}")
  70. model_dir.mkdir(exist_ok=True, parents=True)
  71. return model_dir