# img_util.py (6.2 KB)
  1. import math
  2. import os
  3. import cv2
  4. import numpy as np
  5. import torch
  6. from torchvision.utils import make_grid
  7. def img2tensor(imgs, bgr2rgb=True, float32=True):
  8. """Numpy array to tensor.
  9. Args:
  10. imgs (list[ndarray] | ndarray): Input images.
  11. bgr2rgb (bool): Whether to change bgr to rgb.
  12. float32 (bool): Whether to change to float32.
  13. Returns:
  14. list[tensor] | tensor: Tensor images. If returned results only have
  15. one element, just return tensor.
  16. """
  17. def _totensor(img, bgr2rgb, float32):
  18. if img.shape[2] == 3 and bgr2rgb:
  19. if img.dtype == "float64":
  20. img = img.astype("float32")
  21. img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
  22. img = torch.from_numpy(img.transpose(2, 0, 1))
  23. if float32:
  24. img = img.float()
  25. return img
  26. if isinstance(imgs, list):
  27. return [_totensor(img, bgr2rgb, float32) for img in imgs]
  28. else:
  29. return _totensor(imgs, bgr2rgb, float32)
  30. def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)):
  31. """Convert torch Tensors into image numpy arrays.
  32. After clamping to [min, max], values will be normalized to [0, 1].
  33. Args:
  34. tensor (Tensor or list[Tensor]): Accept shapes:
  35. 1) 4D mini-batch Tensor of shape (B x 3/1 x H x W);
  36. 2) 3D Tensor of shape (3/1 x H x W);
  37. 3) 2D Tensor of shape (H x W).
  38. Tensor channel should be in RGB order.
  39. rgb2bgr (bool): Whether to change rgb to bgr.
  40. out_type (numpy type): output types. If ``np.uint8``, transform outputs
  41. to uint8 type with range [0, 255]; otherwise, float type with
  42. range [0, 1]. Default: ``np.uint8``.
  43. min_max (tuple[int]): min and max values for clamp.
  44. Returns:
  45. (Tensor or list): 3D ndarray of shape (H x W x C) OR 2D ndarray of
  46. shape (H x W). The channel order is BGR.
  47. """
  48. if not (
  49. torch.is_tensor(tensor)
  50. or (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))
  51. ):
  52. raise TypeError(f"tensor or list of tensors expected, got {type(tensor)}")
  53. if torch.is_tensor(tensor):
  54. tensor = [tensor]
  55. result = []
  56. for _tensor in tensor:
  57. _tensor = _tensor.squeeze(0).float().detach().cpu().clamp_(*min_max)
  58. _tensor = (_tensor - min_max[0]) / (min_max[1] - min_max[0])
  59. n_dim = _tensor.dim()
  60. if n_dim == 4:
  61. img_np = make_grid(
  62. _tensor, nrow=int(math.sqrt(_tensor.size(0))), normalize=False
  63. ).numpy()
  64. img_np = img_np.transpose(1, 2, 0)
  65. if rgb2bgr:
  66. img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
  67. elif n_dim == 3:
  68. img_np = _tensor.numpy()
  69. img_np = img_np.transpose(1, 2, 0)
  70. if img_np.shape[2] == 1: # gray image
  71. img_np = np.squeeze(img_np, axis=2)
  72. else:
  73. if rgb2bgr:
  74. img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
  75. elif n_dim == 2:
  76. img_np = _tensor.numpy()
  77. else:
  78. raise TypeError(
  79. f"Only support 4D, 3D or 2D tensor. But received with dimension: {n_dim}"
  80. )
  81. if out_type == np.uint8:
  82. # Unlike MATLAB, numpy.unit8() WILL NOT round by default.
  83. img_np = (img_np * 255.0).round()
  84. img_np = img_np.astype(out_type)
  85. result.append(img_np)
  86. if len(result) == 1:
  87. result = result[0]
  88. return result
  89. def tensor2img_fast(tensor, rgb2bgr=True, min_max=(0, 1)):
  90. """This implementation is slightly faster than tensor2img.
  91. It now only supports torch tensor with shape (1, c, h, w).
  92. Args:
  93. tensor (Tensor): Now only support torch tensor with (1, c, h, w).
  94. rgb2bgr (bool): Whether to change rgb to bgr. Default: True.
  95. min_max (tuple[int]): min and max values for clamp.
  96. """
  97. output = tensor.squeeze(0).detach().clamp_(*min_max).permute(1, 2, 0)
  98. output = (output - min_max[0]) / (min_max[1] - min_max[0]) * 255
  99. output = output.type(torch.uint8).cpu().numpy()
  100. if rgb2bgr:
  101. output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
  102. return output
  103. def imfrombytes(content, flag="color", float32=False):
  104. """Read an image from bytes.
  105. Args:
  106. content (bytes): Image bytes got from files or other streams.
  107. flag (str): Flags specifying the color type of a loaded image,
  108. candidates are `color`, `grayscale` and `unchanged`.
  109. float32 (bool): Whether to change to float32., If True, will also norm
  110. to [0, 1]. Default: False.
  111. Returns:
  112. ndarray: Loaded image array.
  113. """
  114. img_np = np.frombuffer(content, np.uint8)
  115. imread_flags = {
  116. "color": cv2.IMREAD_COLOR,
  117. "grayscale": cv2.IMREAD_GRAYSCALE,
  118. "unchanged": cv2.IMREAD_UNCHANGED,
  119. }
  120. img = cv2.imdecode(img_np, imread_flags[flag])
  121. if float32:
  122. img = img.astype(np.float32) / 255.0
  123. return img
  124. def imwrite(img, file_path, params=None, auto_mkdir=True):
  125. """Write image to file.
  126. Args:
  127. img (ndarray): Image array to be written.
  128. file_path (str): Image file path.
  129. params (None or list): Same as opencv's :func:`imwrite` interface.
  130. auto_mkdir (bool): If the parent folder of `file_path` does not exist,
  131. whether to create it automatically.
  132. Returns:
  133. bool: Successful or not.
  134. """
  135. if auto_mkdir:
  136. dir_name = os.path.abspath(os.path.dirname(file_path))
  137. os.makedirs(dir_name, exist_ok=True)
  138. ok = cv2.imwrite(file_path, img, params)
  139. if not ok:
  140. raise IOError("Failed in writing images.")
  141. def crop_border(imgs, crop_border):
  142. """Crop borders of images.
  143. Args:
  144. imgs (list[ndarray] | ndarray): Images with shape (h, w, c).
  145. crop_border (int): Crop border for each end of height and weight.
  146. Returns:
  147. list[ndarray]: Cropped images.
  148. """
  149. if crop_border == 0:
  150. return imgs
  151. else:
  152. if isinstance(imgs, list):
  153. return [
  154. v[crop_border:-crop_border, crop_border:-crop_border, ...] for v in imgs
  155. ]
  156. else:
  157. return imgs[crop_border:-crop_border, crop_border:-crop_border, ...]