# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, Dict, List, Tuple

import torch
from torch import nn
from torch.nn import functional as F

from .image_encoder import ImageEncoderViT
from .mask_decoder import MaskDecoder
from .prompt_encoder import PromptEncoder


class SamHQ(nn.Module):
    mask_threshold: float = 0.0
    image_format: str = "RGB"

    def __init__(
        self,
        image_encoder: ImageEncoderViT,
        prompt_encoder: PromptEncoder,
        mask_decoder: MaskDecoder,
        pixel_mean: List[float] = [123.675, 116.28, 103.53],
        pixel_std: List[float] = [58.395, 57.12, 57.375],
    ) -> None:
        """
        SAM predicts object masks from an image and input prompts.

        Arguments:
          image_encoder (ImageEncoderViT): The backbone used to encode the
            image into image embeddings that allow for efficient mask prediction.
          prompt_encoder (PromptEncoder): Encodes various types of input prompts.
          mask_decoder (MaskDecoder): Predicts masks from the image embeddings
            and encoded prompts.
          pixel_mean (list(float)): Mean values for normalizing pixels in the input image.
          pixel_std (list(float)): Std values for normalizing pixels in the input image.
        """
        super().__init__()
        self.image_encoder = image_encoder
        self.prompt_encoder = prompt_encoder
        self.mask_decoder = mask_decoder
        # The normalization constants are buffers, not parameters: they move
        # with the module across devices but are never trained, and
        # persistent=False keeps them out of state_dict() checkpoints.
        self.register_buffer(
            "pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), persistent=False
        )
        self.register_buffer(
            "pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), persistent=False
        )

    @property
    def device(self) -> Any:
        return self.pixel_mean.device

    def forward(
        self,
        batched_input: List[Dict[str, Any]],
        multimask_output: bool,
        hq_token_only: bool = False,
    ) -> List[Dict[str, torch.Tensor]]:
        """
        Predicts masks end-to-end from provided images and prompts.
        If prompts are not known in advance, using SamPredictor is
        recommended over calling the model directly.

        Arguments:
          batched_input (list(dict)): A list over input images, each a
            dictionary with the following keys. A prompt key can be
            excluded if it is not present.
              'image': The image as a torch tensor in 3xHxW format,
                already transformed for input to the model.
              'original_size': (tuple(int, int)) The original size of
                the image before transformation, as (H, W).
              'point_coords': (torch.Tensor) Batched point prompts for
                this image, with shape BxNx2. Already transformed to the
                input frame of the model.
              'point_labels': (torch.Tensor) Batched labels for point prompts,
                with shape BxN.
              'boxes': (torch.Tensor) Batched box inputs, with shape Bx4.
                Already transformed to the input frame of the model.
              'mask_inputs': (torch.Tensor) Batched mask inputs to the model,
                in the form Bx1xHxW.
          multimask_output (bool): Whether the model should predict multiple
            disambiguating masks, or return a single mask.
          hq_token_only (bool): If True, the returned masks come only from
            the HQ output token; if False, they are combined with the
            standard SAM mask outputs.

        Returns:
          (list(dict)): A list over input images, where each element is
            a dictionary with the following keys.
              'masks': (torch.Tensor) Batched binary mask predictions,
                with shape BxCxHxW, where B is the number of input prompts,
                C is determined by multimask_output, and (H, W) is the
                original size of the image.
              'iou_predictions': (torch.Tensor) The model's predictions
                of mask quality, in shape BxC.
              'low_res_logits': (torch.Tensor) Low resolution logits with
                shape BxCxHxW, where H=W=256. Can be passed as mask input
                to subsequent iterations of prediction.
        """
        # Batch the preprocessed images and encode them once; the HQ variant
        # of the image encoder also returns intermediate ViT features.
        input_images = torch.stack(
            [self.preprocess(x["image"]) for x in batched_input], dim=0
        )
        image_embeddings, interm_embeddings = self.image_encoder(input_images)
        interm_embeddings = interm_embeddings[0]  # features from an early ViT layer

        outputs = []
        for image_record, curr_embedding, curr_interm in zip(
            batched_input, image_embeddings, interm_embeddings
        ):
            if "point_coords" in image_record:
                points = (image_record["point_coords"], image_record["point_labels"])
            else:
                points = None
            sparse_embeddings, dense_embeddings = self.prompt_encoder(
                points=points,
                boxes=image_record.get("boxes", None),
                masks=image_record.get("mask_inputs", None),
            )
            low_res_masks, iou_predictions = self.mask_decoder(
                image_embeddings=curr_embedding.unsqueeze(0),
                image_pe=self.prompt_encoder.get_dense_pe(),
                sparse_prompt_embeddings=sparse_embeddings,
                dense_prompt_embeddings=dense_embeddings,
                multimask_output=multimask_output,
                hq_token_only=hq_token_only,
                interm_embeddings=curr_interm.unsqueeze(0).unsqueeze(0),
            )
            # Upscale the low-resolution logits to the original image size,
            # then threshold to binary masks.
            masks = self.postprocess_masks(
                low_res_masks,
                input_size=image_record["image"].shape[-2:],
                original_size=image_record["original_size"],
            )
            masks = masks > self.mask_threshold
            outputs.append(
                {
                    "masks": masks,
                    "iou_predictions": iou_predictions,
                    "low_res_logits": low_res_masks,
                }
            )
        return outputs
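
    # A hedged sketch of a direct forward call (names and sizes below are
    # illustrative assumptions, not taken from this file):
    #
    #   batched_input = [{
    #       "image": transformed_image,                 # 3xHxW, model input frame
    #       "original_size": (720, 1280),               # (H, W) before resizing
    #       "point_coords": torch.tensor([[[500.0, 375.0]]], device=sam.device),
    #       "point_labels": torch.tensor([[1]], device=sam.device),
    #   }]
    #   outputs = sam(batched_input, multimask_output=False)
    #   mask = outputs[0]["masks"][0, 0]                # bool tensor at (720, 1280)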

    def postprocess_masks(
        self,
        masks: torch.Tensor,
        input_size: Tuple[int, ...],
        original_size: Tuple[int, ...],
    ) -> torch.Tensor:
        """
        Remove padding and upscale masks to the original image size.

        Arguments:
          masks (torch.Tensor): Batched masks from the mask_decoder,
            in BxCxHxW format.
          input_size (tuple(int, int)): The size of the image input to the
            model, in (H, W) format. Used to remove padding.
          original_size (tuple(int, int)): The original size of the image
            before resizing for input to the model, in (H, W) format.

        Returns:
          (torch.Tensor): Batched masks in BxCxHxW format, where (H, W)
            is given by original_size.
        """
        # Upscale to the padded square frame the image encoder saw.
        masks = F.interpolate(
            masks,
            (self.image_encoder.img_size, self.image_encoder.img_size),
            mode="bilinear",
            align_corners=False,
        )
        # Strip the bottom/right padding added in preprocess().
        masks = masks[..., : input_size[0], : input_size[1]]
        # Resize to the image's original resolution.
        masks = F.interpolate(
            masks, original_size, mode="bilinear", align_corners=False
        )
        return masks
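
    # Shape walk-through (assuming img_size=1024 and a 768x1024 input that
    # was padded to 1024x1024): BxCx256x256 logits are first upscaled to
    # BxCx1024x1024, cropped to BxCx768x1024 to drop the padding, then
    # resized to the original (H, W) recorded before preprocessing.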

    def preprocess(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize pixel values and pad to a square input."""
        # Normalize colors
        x = (x - self.pixel_mean) / self.pixel_std

        # Pad on the bottom and right so the result is img_size x img_size
        h, w = x.shape[-2:]
        padh = self.image_encoder.img_size - h
        padw = self.image_encoder.img_size - w
        x = F.pad(x, (0, padw, 0, padh))
        return x
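
# Example end-to-end usage (a sketch; `sam_model_registry` and the checkpoint
# filename are assumptions based on the repo's build helpers, not this file):
#
#   from segment_anything import sam_model_registry
#   sam = sam_model_registry["vit_l"](checkpoint="sam_hq_vit_l.pth").to("cuda")
#   with torch.no_grad():
#       outputs = sam(batched_input, multimask_output=False, hq_token_only=True)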
|