similarity_maps.py

from typing import List, Tuple

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
from einops import rearrange
from PIL import Image

from colpali_engine.interpretability.similarity_map_utils import normalize_similarity_map


def plot_similarity_map(
    image: Image.Image,
    similarity_map: torch.Tensor,
    figsize: Tuple[int, int] = (8, 8),
    show_colorbar: bool = False,
) -> Tuple[plt.Figure, plt.Axes]:
    """
    Plot and overlay a similarity map over the input image.

    A similarity map is a 2D tensor where each element (i, j) represents the similarity score between a chosen query
    token and the associated image patch at position (i, j). Thus, the higher the similarity score, the brighter the
    color of the patch.

    To show the returned similarity map, use:

    ```python
    >>> fig, ax = plot_similarity_map(image, similarity_map)
    >>> fig.show()
    ```

    Args:
        image: PIL image
        similarity_map: tensor of shape (n_patches_x, n_patches_y)
        figsize: size of the figure
        show_colorbar: whether to show a colorbar
    """
    # Convert the image to an array
    img_array = np.array(image.convert("RGBA"))  # (height, width, channels)

    # Normalize the similarity map and convert it to a Pillow image
    similarity_map_array = (
        normalize_similarity_map(similarity_map).to(torch.float32).cpu().numpy()
    )  # (n_patches_x, n_patches_y)

    # Reshape the similarity map to match the PIL shape convention
    similarity_map_array = rearrange(similarity_map_array, "h w -> w h")  # (n_patches_y, n_patches_x)

    similarity_map_image = Image.fromarray((similarity_map_array * 255).astype("uint8")).resize(
        image.size, Image.Resampling.BICUBIC
    )

    # Create the figure
    with plt.style.context("dark_background"):
        fig, ax = plt.subplots(figsize=figsize)
        ax.imshow(img_array)
        im = ax.imshow(
            similarity_map_image,
            cmap=sns.color_palette("mako", as_cmap=True),
            alpha=0.5,
        )
        if show_colorbar:
            fig.colorbar(im)
        ax.set_axis_off()
        fig.tight_layout()

    return fig, ax
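

# A minimal usage sketch for `plot_similarity_map` on its own (assumption: a random tensor and a
# blank image stand in for real ColPali outputs; in practice the map comes from
# `get_similarity_maps_from_embeddings`, as shown in the docstring of `plot_all_similarity_maps`):
#
#     dummy_map = torch.rand(16, 16)                      # (n_patches_x, n_patches_y)
#     dummy_image = Image.new("RGB", (448, 448), "white")
#     fig, ax = plot_similarity_map(dummy_image, dummy_map, show_colorbar=True)
#     fig.savefig("similarity_map.png", bbox_inches="tight")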


def plot_all_similarity_maps(
    image: Image.Image,
    query_tokens: List[str],
    similarity_maps: torch.Tensor,
    figsize: Tuple[int, int] = (8, 8),
    show_colorbar: bool = False,
    add_title: bool = True,
) -> List[Tuple[plt.Figure, plt.Axes]]:
    """
    For each token in the query, plot and overlay a similarity map over the input image.

    A similarity map is a 2D tensor where each element (i, j) represents the similarity score between a chosen query
    token and the associated image patch at position (i, j). Thus, the higher the similarity score, the brighter the
    color of the patch.

    Args:
        image: PIL image
        query_tokens: list of query tokens
        similarity_maps: tensor of shape (query_length, n_patches_x, n_patches_y)
        figsize: size of the figure
        show_colorbar: whether to show a colorbar
        add_title: whether to add a title with the token and the max similarity score

    Example usage for one query-image pair:

    ```python
    >>> from colpali_engine.interpretability.similarity_map_utils import get_similarity_maps_from_embeddings

    >>> batch_images = processor.process_images([image]).to(device)
    >>> batch_queries = processor.process_queries([query]).to(device)

    >>> with torch.no_grad():
    ...     image_embeddings = model.forward(**batch_images)
    ...     query_embeddings = model.forward(**batch_queries)

    >>> n_patches = processor.get_n_patches(
    ...     image_size=image.size,
    ...     patch_size=model.patch_size,
    ... )
    >>> image_mask = processor.get_image_mask(batch_images)

    >>> batched_similarity_maps = get_similarity_maps_from_embeddings(
    ...     image_embeddings=image_embeddings,
    ...     query_embeddings=query_embeddings,
    ...     n_patches=n_patches,
    ...     image_mask=image_mask,
    ... )
    >>> similarity_maps = batched_similarity_maps[0]  # (query_length, n_patches_x, n_patches_y)

    >>> plots = plot_all_similarity_maps(
    ...     image=image,
    ...     query_tokens=query_tokens,
    ...     similarity_maps=similarity_maps,
    ... )
    >>> for fig, ax in plots:
    ...     fig.show()
    ```
    """
    plots: List[Tuple[plt.Figure, plt.Axes]] = []

    for idx, token in enumerate(query_tokens):
        fig, ax = plot_similarity_map(
            image=image,
            similarity_map=similarity_maps[idx],
            figsize=figsize,
            show_colorbar=show_colorbar,
        )
        if add_title:
            max_sim_score = similarity_maps[idx].max().item()
            ax.set_title(f"Token #{idx}: `{token}`. MaxSim score: {max_sim_score:.2f}", fontsize=14)
        plots.append((fig, ax))

    return plots
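

if __name__ == "__main__":
    # Quick visual smoke test (a sketch, not part of the library API). Assumption: random
    # similarity scores and a blank white image stand in for real model outputs; see the
    # docstring of `plot_all_similarity_maps` for the full ColPali pipeline.
    dummy_image = Image.new("RGB", (448, 448), color="white")
    dummy_tokens = ["token_0", "token_1", "token_2"]
    dummy_maps = torch.rand(len(dummy_tokens), 16, 16)  # (query_length, n_patches_x, n_patches_y)

    for idx, (fig, ax) in enumerate(
        plot_all_similarity_maps(image=dummy_image, query_tokens=dummy_tokens, similarity_maps=dummy_maps)
    ):
        # Save each token's overlay to disk instead of calling fig.show()
        fig.savefig(f"similarity_map_token_{idx}.png", bbox_inches="tight")
        plt.close(fig)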