from typing import List, Tuple, Union

import torch
from einops import rearrange

EPSILON = 1e-10


def get_similarity_maps_from_embeddings(
    image_embeddings: torch.Tensor,
    query_embeddings: torch.Tensor,
    n_patches: Union[Tuple[int, int], List[Tuple[int, int]]],
    image_mask: torch.Tensor,
) -> List[torch.Tensor]:
    """
    Get the batched similarity maps between the query embeddings and the image embeddings.

    Each element in the returned list is a tensor of shape (query_tokens, n_patches_x, n_patches_y).

    Args:
        image_embeddings: tensor of shape (batch_size, image_tokens, dim)
        query_embeddings: tensor of shape (batch_size, query_tokens, dim)
        n_patches: number of patches per dimension for each image in the batch. If a single tuple is
            provided, the same number of patches is used for all images in the batch (broadcasted).
        image_mask: tensor of shape (batch_size, image_tokens). Used to filter out the embeddings
            that are not related to the image.
    """
    if isinstance(n_patches, tuple):
        n_patches = [n_patches] * image_embeddings.size(0)

    similarity_maps: List[torch.Tensor] = []

    for idx in range(image_embeddings.size(0)):
        # Sanity check: the non-padded image tokens must exactly fill the patch grid.
        if image_mask[idx].sum() != n_patches[idx][0] * n_patches[idx][1]:
            raise ValueError(
                f"The number of patches ({n_patches[idx][0]} x {n_patches[idx][1]} = "
                f"{n_patches[idx][0] * n_patches[idx][1]}) "
                f"does not match the number of non-padded image tokens ({image_mask[idx].sum()})."
            )

        # Rearrange the image embeddings to explicitly represent the 2D grid of patches.
        image_embedding_grid = rearrange(
            image_embeddings[idx][image_mask[idx]],  # (n_patches_x * n_patches_y, dim)
            "(h w) c -> w h c",
            w=n_patches[idx][0],
            h=n_patches[idx][1],
        )  # (n_patches_x, n_patches_y, dim)
        # Dot-product similarity between every query token and every patch embedding.
        similarity_map = torch.einsum(
            "nk,ijk->nij", query_embeddings[idx], image_embedding_grid
        )  # (query_tokens, n_patches_x, n_patches_y)
        similarity_maps.append(similarity_map)

    return similarity_maps
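

# Illustrative usage sketch (not part of the original module): how the helper above
# could be called with random tensors. The `_example_*` name, the tensor shapes, and
# the (4, 3) patch grid are assumptions chosen only to satisfy the sanity check, not
# values taken from a real retrieval model.
def _example_get_similarity_maps() -> List[torch.Tensor]:
    batch_size, image_tokens, query_tokens, dim = 2, 12, 5, 128
    image_embeddings = torch.randn(batch_size, image_tokens, dim)
    query_embeddings = torch.randn(batch_size, query_tokens, dim)
    # Keep all image tokens, so each image has exactly 4 x 3 = 12 non-padded patches.
    image_mask = torch.ones(batch_size, image_tokens, dtype=torch.bool)

    maps = get_similarity_maps_from_embeddings(
        image_embeddings, query_embeddings, n_patches=(4, 3), image_mask=image_mask
    )
    assert maps[0].shape == (query_tokens, 4, 3)
    return maps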


def normalize_similarity_map(similarity_map: torch.Tensor) -> torch.Tensor:
    """
    Normalize the similarity map to have values in the range [0, 1].

    Args:
        similarity_map: tensor of shape (n_patch_x, n_patch_y) or (batch_size, n_patch_x, n_patch_y)
    """
    if similarity_map.ndim not in [2, 3]:
        raise ValueError(
            "The input tensor must have 2 dimensions (n_patch_x, n_patch_y) or "
            "3 dimensions (batch_size, n_patch_x, n_patch_y)."
        )

    # Compute the minimum values along the last two dimensions (n_patch_x, n_patch_y)
    min_vals = similarity_map.min(dim=-1, keepdim=True)[0].min(dim=-2, keepdim=True)[0]  # (1, 1) or (batch_size, 1, 1)

    # Compute the maximum values along the last two dimensions (n_patch_x, n_patch_y)
    max_vals = similarity_map.max(dim=-1, keepdim=True)[0].max(dim=-2, keepdim=True)[0]  # (1, 1) or (batch_size, 1, 1)

    # Normalize the tensor.
    # NOTE: Add a small epsilon to avoid division by zero.
    similarity_map_normalized = (similarity_map - min_vals) / (
        max_vals - min_vals + EPSILON
    )  # (n_patch_x, n_patch_y) or (batch_size, n_patch_x, n_patch_y)

    return similarity_map_normalized
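

# Illustrative usage sketch (not part of the original module): min-max normalize the
# maps from the example above, e.g. before rendering them as heatmap overlays. The
# `_example_*` name is an assumption; both the single-map (2D) and batched (3D) call
# forms shown here follow the docstring above.
def _example_normalize_maps() -> torch.Tensor:
    maps = _example_get_similarity_maps()
    single_map = normalize_similarity_map(maps[0][0])  # one query token: (4, 3)
    batched_maps = normalize_similarity_map(maps[0])  # all query tokens: (5, 4, 3)
    assert single_map.min().item() >= 0.0 and single_map.max().item() <= 1.0
    assert batched_maps.shape == maps[0].shape
    return batched_maps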