similarity_map_utils.py

from typing import List, Tuple, Union

import torch
from einops import rearrange

EPSILON = 1e-10


def get_similarity_maps_from_embeddings(
    image_embeddings: torch.Tensor,
    query_embeddings: torch.Tensor,
    n_patches: Union[Tuple[int, int], List[Tuple[int, int]]],
    image_mask: torch.Tensor,
) -> List[torch.Tensor]:
    """
    Get the batched similarity maps between the query embeddings and the image embeddings.
    Each element in the returned list is a tensor of shape (query_tokens, n_patches_x, n_patches_y).

    Args:
        image_embeddings: tensor of shape (batch_size, image_tokens, dim)
        query_embeddings: tensor of shape (batch_size, query_tokens, dim)
        n_patches: number of patches per dimension for each image in the batch. If a single tuple is
            provided, the same number of patches is used for all images in the batch (broadcasted).
        image_mask: tensor of shape (batch_size, image_tokens). Used to filter out the embeddings
            that are not related to the image.
    """
    if isinstance(n_patches, tuple):
        n_patches = [n_patches] * image_embeddings.size(0)

    similarity_maps: List[torch.Tensor] = []

    for idx in range(image_embeddings.size(0)):
        # Sanity check
        if image_mask[idx].sum() != n_patches[idx][0] * n_patches[idx][1]:
            raise ValueError(
                f"The number of patches ({n_patches[idx][0]} x {n_patches[idx][1]} = "
                f"{n_patches[idx][0] * n_patches[idx][1]}) "
                f"does not match the number of non-padded image tokens ({image_mask[idx].sum()})."
            )

        # Rearrange the output image tensor to explicitly represent the 2D grid of patches
        image_embedding_grid = rearrange(
            image_embeddings[idx][image_mask[idx]],  # (n_patches_x * n_patches_y, dim)
            "(h w) c -> w h c",
            w=n_patches[idx][0],
            h=n_patches[idx][1],
        )  # (n_patches_x, n_patches_y, dim)
        # Dot-product similarity between each query token and each patch embedding
        similarity_map = torch.einsum(
            "nk,ijk->nij", query_embeddings[idx], image_embedding_grid
        )  # (query_tokens, n_patches_x, n_patches_y)
        similarity_maps.append(similarity_map)

    return similarity_maps
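

# A minimal usage sketch (illustrative only: the shapes, patch-grid size, and
# random tensors below are assumptions, not outputs of a real model).
def _demo_get_similarity_maps() -> None:
    batch_size, image_tokens, query_tokens, dim = 2, 1030, 16, 128
    n_patches = (32, 32)  # 32 x 32 = 1024 patch embeddings per image
    image_embeddings = torch.randn(batch_size, image_tokens, dim)
    query_embeddings = torch.randn(batch_size, query_tokens, dim)
    # Mark the first 1024 tokens as image patches; the remaining tokens
    # (e.g. special tokens) are masked out.
    image_mask = torch.zeros(batch_size, image_tokens, dtype=torch.bool)
    image_mask[:, :1024] = True
    similarity_maps = get_similarity_maps_from_embeddings(
        image_embeddings, query_embeddings, n_patches, image_mask
    )
    # One map per image, one 2D grid per query token.
    assert similarity_maps[0].shape == (query_tokens, 32, 32)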


def normalize_similarity_map(similarity_map: torch.Tensor) -> torch.Tensor:
    """
    Normalize the similarity map to have values in the range [0, 1].

    Args:
        similarity_map: tensor of shape (n_patch_x, n_patch_y) or (batch_size, n_patch_x, n_patch_y)
    """
    if similarity_map.ndim not in [2, 3]:
        raise ValueError(
            "The input tensor must have 2 dimensions (n_patch_x, n_patch_y) or "
            "3 dimensions (batch_size, n_patch_x, n_patch_y)."
        )

    # Compute the minimum values along the last two dimensions (n_patch_x, n_patch_y)
    min_vals = similarity_map.min(dim=-1, keepdim=True)[0].min(dim=-2, keepdim=True)[0]  # (1, 1) or (batch_size, 1, 1)

    # Compute the maximum values along the last two dimensions (n_patch_x, n_patch_y)
    max_vals = similarity_map.max(dim=-1, keepdim=True)[0].max(dim=-2, keepdim=True)[0]  # (1, 1) or (batch_size, 1, 1)

    # Normalize the tensor
    # NOTE: Add a small epsilon to avoid division by zero.
    similarity_map_normalized = (similarity_map - min_vals) / (
        max_vals - min_vals + EPSILON
    )  # (n_patch_x, n_patch_y) or (batch_size, n_patch_x, n_patch_y)

    return similarity_map_normalized
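

# A minimal sketch of the min-max normalization on a random batched map (the
# random values are assumptions for illustration).
def _demo_normalize_similarity_map() -> None:
    similarity_map = torch.randn(2, 32, 32)  # (batch_size, n_patch_x, n_patch_y)
    normalized = normalize_similarity_map(similarity_map)
    # Each map in the batch is normalized independently to the range [0, 1].
    assert normalized.shape == similarity_map.shape
    assert normalized.min() >= 0 and normalized.max() <= 1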