misc.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120
  1. import os
  2. import os.path as osp
  3. from urllib.parse import urlparse
  4. import cv2
  5. import torch
  6. from torch.hub import download_url_to_file, get_dir
  7. ROOT_DIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
  8. def imwrite(img, file_path, params=None, auto_mkdir=True):
  9. """Write image to file.
  10. Args:
  11. img (ndarray): Image array to be written.
  12. file_path (str): Image file path.
  13. params (None or list): Same as opencv's :func:`imwrite` interface.
  14. auto_mkdir (bool): If the parent folder of `file_path` does not exist,
  15. whether to create it automatically.
  16. Returns:
  17. bool: Successful or not.
  18. """
  19. if auto_mkdir:
  20. dir_name = os.path.abspath(os.path.dirname(file_path))
  21. os.makedirs(dir_name, exist_ok=True)
  22. return cv2.imwrite(file_path, img, params)
  23. def img2tensor(imgs, bgr2rgb=True, float32=True):
  24. """Numpy array to tensor.
  25. Args:
  26. imgs (list[ndarray] | ndarray): Input images.
  27. bgr2rgb (bool): Whether to change bgr to rgb.
  28. float32 (bool): Whether to change to float32.
  29. Returns:
  30. list[tensor] | tensor: Tensor images. If returned results only have
  31. one element, just return tensor.
  32. """
  33. def _totensor(img, bgr2rgb, float32):
  34. if img.shape[2] == 3 and bgr2rgb:
  35. if img.dtype == "float64":
  36. img = img.astype("float32")
  37. img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
  38. img = torch.from_numpy(img.transpose(2, 0, 1))
  39. if float32:
  40. img = img.float()
  41. return img
  42. if isinstance(imgs, list):
  43. return [_totensor(img, bgr2rgb, float32) for img in imgs]
  44. else:
  45. return _totensor(imgs, bgr2rgb, float32)
  46. def load_file_from_url(
  47. url, model_dir=None, progress=True, file_name=None, save_dir=None
  48. ):
  49. """Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py"""
  50. if model_dir is None:
  51. hub_dir = get_dir()
  52. model_dir = os.path.join(hub_dir, "checkpoints")
  53. if save_dir is None:
  54. save_dir = os.path.join(ROOT_DIR, model_dir)
  55. os.makedirs(save_dir, exist_ok=True)
  56. parts = urlparse(url)
  57. filename = os.path.basename(parts.path)
  58. if file_name is not None:
  59. filename = file_name
  60. cached_file = os.path.abspath(os.path.join(save_dir, filename))
  61. if not os.path.exists(cached_file):
  62. print(f'Downloading: "{url}" to {cached_file}\n')
  63. download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
  64. return cached_file
  65. def scandir(dir_path, suffix=None, recursive=False, full_path=False):
  66. """Scan a directory to find the interested files.
  67. Args:
  68. dir_path (str): Path of the directory.
  69. suffix (str | tuple(str), optional): File suffix that we are
  70. interested in. Default: None.
  71. recursive (bool, optional): If set to True, recursively scan the
  72. directory. Default: False.
  73. full_path (bool, optional): If set to True, include the dir_path.
  74. Default: False.
  75. Returns:
  76. A generator for all the interested files with relative paths.
  77. """
  78. if (suffix is not None) and not isinstance(suffix, (str, tuple)):
  79. raise TypeError('"suffix" must be a string or tuple of strings')
  80. root = dir_path
  81. def _scandir(dir_path, suffix, recursive):
  82. for entry in os.scandir(dir_path):
  83. if not entry.name.startswith(".") and entry.is_file():
  84. if full_path:
  85. return_path = entry.path
  86. else:
  87. return_path = osp.relpath(entry.path, root)
  88. if suffix is None:
  89. yield return_path
  90. elif return_path.endswith(suffix):
  91. yield return_path
  92. else:
  93. if recursive:
  94. yield from _scandir(entry.path, suffix=suffix, recursive=recursive)
  95. else:
  96. continue
  97. return _scandir(dir_path, suffix=suffix, recursive=recursive)