file_manager.py 7.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220
  1. import os
  2. from io import BytesIO
  3. from pathlib import Path
  4. from typing import List
  5. from fastapi import FastAPI, HTTPException
  6. from PIL import Image, ImageOps, PngImagePlugin
  7. from starlette.responses import FileResponse
  8. from ..schema import MediasResponse, MediaTab
  9. LARGE_ENOUGH_NUMBER = 100
  10. PngImagePlugin.MAX_TEXT_CHUNK = LARGE_ENOUGH_NUMBER * (1024**2)
  11. from .storage_backends import FilesystemStorageBackend
  12. from .utils import aspect_to_string, generate_filename, glob_img
  13. class FileManager:
  14. def __init__(self, app: FastAPI, input_dir: Path, mask_dir: Path, output_dir: Path):
  15. self.app = app
  16. self.input_dir: Path = input_dir
  17. self.mask_dir: Path = mask_dir
  18. self.output_dir: Path = output_dir
  19. self.image_dir_filenames = []
  20. self.output_dir_filenames = []
  21. if not self.thumbnail_directory.exists():
  22. self.thumbnail_directory.mkdir(parents=True)
  23. # fmt: off
  24. self.app.add_api_route("/api/v1/medias", self.api_medias, methods=["GET"], response_model=List[MediasResponse])
  25. self.app.add_api_route("/api/v1/media_file", self.api_media_file, methods=["GET"])
  26. self.app.add_api_route("/api/v1/media_thumbnail_file", self.api_media_thumbnail_file, methods=["GET"])
  27. # fmt: on
  28. def api_medias(self, tab: MediaTab) -> List[MediasResponse]:
  29. img_dir = self._get_dir(tab)
  30. return self._media_names(img_dir)
  31. def api_media_file(self, tab: MediaTab, filename: str) -> FileResponse:
  32. file_path = self._get_file(tab, filename)
  33. return FileResponse(file_path, media_type="image/png")
  34. # tab=${tab}?filename=${filename.name}?width=${width}&height=${height}
  35. def api_media_thumbnail_file(
  36. self, tab: MediaTab, filename: str, width: int, height: int
  37. ) -> FileResponse:
  38. img_dir = self._get_dir(tab)
  39. thumb_filename, (width, height) = self.get_thumbnail(
  40. img_dir, filename, width=width, height=height
  41. )
  42. thumbnail_filepath = self.thumbnail_directory / thumb_filename
  43. return FileResponse(
  44. thumbnail_filepath,
  45. headers={
  46. "X-Width": str(width),
  47. "X-Height": str(height),
  48. },
  49. media_type="image/jpeg",
  50. )
  51. def _get_dir(self, tab: MediaTab) -> Path:
  52. if tab == "input":
  53. return self.input_dir
  54. elif tab == "output":
  55. return self.output_dir
  56. elif tab == "mask":
  57. return self.mask_dir
  58. else:
  59. raise HTTPException(status_code=422, detail=f"tab not found: {tab}")
  60. def _get_file(self, tab: MediaTab, filename: str) -> Path:
  61. file_path = self._get_dir(tab) / filename
  62. if not file_path.exists():
  63. raise HTTPException(status_code=422, detail=f"file not found: {file_path}")
  64. return file_path
  65. @property
  66. def thumbnail_directory(self) -> Path:
  67. return self.output_dir / "thumbnails"
  68. @staticmethod
  69. def _media_names(directory: Path) -> List[MediasResponse]:
  70. if directory is None:
  71. return []
  72. names = sorted([it.name for it in glob_img(directory)])
  73. res = []
  74. for name in names:
  75. path = os.path.join(directory, name)
  76. img = Image.open(path)
  77. res.append(
  78. MediasResponse(
  79. name=name,
  80. height=img.height,
  81. width=img.width,
  82. ctime=os.path.getctime(path),
  83. mtime=os.path.getmtime(path),
  84. )
  85. )
  86. return res
  87. def get_thumbnail(
  88. self, directory: Path, original_filename: str, width, height, **options
  89. ):
  90. directory = Path(directory)
  91. storage = FilesystemStorageBackend(self.app)
  92. crop = options.get("crop", "fit")
  93. background = options.get("background")
  94. quality = options.get("quality", 90)
  95. original_path, original_filename = os.path.split(original_filename)
  96. original_filepath = os.path.join(directory, original_path, original_filename)
  97. image = Image.open(BytesIO(storage.read(original_filepath)))
  98. # keep ratio resize
  99. if not width and not height:
  100. width = 256
  101. if width != 0:
  102. height = int(image.height * width / image.width)
  103. else:
  104. width = int(image.width * height / image.height)
  105. thumbnail_size = (width, height)
  106. thumbnail_filename = generate_filename(
  107. directory,
  108. original_filename,
  109. aspect_to_string(thumbnail_size),
  110. crop,
  111. background,
  112. quality,
  113. )
  114. thumbnail_filepath = os.path.join(
  115. self.thumbnail_directory, original_path, thumbnail_filename
  116. )
  117. if storage.exists(thumbnail_filepath):
  118. return thumbnail_filepath, (width, height)
  119. try:
  120. image.load()
  121. except (IOError, OSError):
  122. self.app.logger.warning("Thumbnail not load image: %s", original_filepath)
  123. return thumbnail_filepath, (width, height)
  124. # get original image format
  125. options["format"] = options.get("format", image.format)
  126. image = self._create_thumbnail(
  127. image, thumbnail_size, crop, background=background
  128. )
  129. raw_data = self.get_raw_data(image, **options)
  130. storage.save(thumbnail_filepath, raw_data)
  131. return thumbnail_filepath, (width, height)
  132. def get_raw_data(self, image, **options):
  133. data = {
  134. "format": self._get_format(image, **options),
  135. "quality": options.get("quality", 90),
  136. }
  137. _file = BytesIO()
  138. image.save(_file, **data)
  139. return _file.getvalue()
  140. @staticmethod
  141. def colormode(image, colormode="RGB"):
  142. if colormode == "RGB" or colormode == "RGBA":
  143. if image.mode == "RGBA":
  144. return image
  145. if image.mode == "LA":
  146. return image.convert("RGBA")
  147. return image.convert(colormode)
  148. if colormode == "GRAY":
  149. return image.convert("L")
  150. return image.convert(colormode)
  151. @staticmethod
  152. def background(original_image, color=0xFF):
  153. size = (max(original_image.size),) * 2
  154. image = Image.new("L", size, color)
  155. image.paste(
  156. original_image,
  157. tuple(map(lambda x: (x[0] - x[1]) / 2, zip(size, original_image.size))),
  158. )
  159. return image
  160. def _get_format(self, image, **options):
  161. if options.get("format"):
  162. return options.get("format")
  163. if image.format:
  164. return image.format
  165. return "JPEG"
  166. def _create_thumbnail(self, image, size, crop="fit", background=None):
  167. try:
  168. resample = Image.Resampling.LANCZOS
  169. except AttributeError: # pylint: disable=raise-missing-from
  170. resample = Image.ANTIALIAS
  171. if crop == "fit":
  172. image = ImageOps.fit(image, size, resample)
  173. else:
  174. image = image.copy()
  175. image.thumbnail(size, resample=resample)
  176. if background is not None:
  177. image = self.background(image)
  178. image = self.colormode(image)
  179. return image