test_similarity_map_utils.py 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188
  1. import pytest
  2. import torch
  3. from colpali_engine.interpretability.similarity_map_utils import get_similarity_maps_from_embeddings
  4. from colpali_engine.interpretability.similarity_maps import normalize_similarity_map
  5. class TestNormalizeSimilarityMap:
  6. def test_normalize_similarity_map_2d_ones(self):
  7. similarity_map = torch.tensor(
  8. [
  9. [1.0, 1.0],
  10. [1.0, 1.0],
  11. ]
  12. )
  13. normalized_map = normalize_similarity_map(similarity_map)
  14. expected_map = torch.zeros_like(similarity_map)
  15. assert torch.allclose(normalized_map, expected_map, atol=1e-6)
  16. def test_normalize_similarity_map_2d(self):
  17. similarity_map = torch.tensor(
  18. [
  19. [1.0, 1.0],
  20. [0.0, -1.0],
  21. ]
  22. )
  23. normalized_map = normalize_similarity_map(similarity_map)
  24. expected_map = torch.tensor(
  25. [
  26. [1.0, 1.0],
  27. [0.5, 0.0],
  28. ]
  29. )
  30. assert torch.allclose(normalized_map, expected_map, atol=1e-6)
  31. def test_normalize_similarity_map_3d_ones(self):
  32. similarity_map = torch.tensor(
  33. [
  34. [
  35. [1.0, 1.0],
  36. [1.0, 1.0],
  37. ],
  38. [
  39. [2.0, 2.0],
  40. [2.0, 2.0],
  41. ],
  42. ]
  43. )
  44. normalized_map = normalize_similarity_map(similarity_map)
  45. expected_map = torch.zeros_like(similarity_map)
  46. assert torch.allclose(normalized_map, expected_map, atol=1e-6)
  47. class TestGetSimilarityMapsFromEmbeddings:
  48. def test_get_similarity_maps_from_embeddings(self):
  49. # Define test parameters
  50. batch_size = 2
  51. image_tokens = 6 # Total number of image tokens
  52. query_tokens = 3
  53. dim = 4 # Embedding dimension
  54. # Create dummy image embeddings and query embeddings
  55. image_embeddings = torch.randn(batch_size, image_tokens, dim)
  56. query_embeddings = torch.randn(batch_size, query_tokens, dim)
  57. # Define n_patches as a tuple (h, w), ensuring h * w equals image_tokens
  58. n_patches = (2, 3) # For instance, 2 rows and 3 columns
  59. # Create an optional image attention mask (all ones, no padding)
  60. image_mask = torch.ones(batch_size, image_tokens, dtype=torch.bool)
  61. # Call the function under test
  62. similarity_maps = get_similarity_maps_from_embeddings(
  63. image_embeddings=image_embeddings,
  64. query_embeddings=query_embeddings,
  65. n_patches=n_patches,
  66. image_mask=image_mask,
  67. )
  68. # Assertions to validate the output
  69. assert isinstance(similarity_maps, list), "Output should be a list of tensors."
  70. assert len(similarity_maps) == batch_size, "Output list length should match batch size."
  71. for idx, similarity_map in enumerate(similarity_maps):
  72. expected_shape = (query_tokens, n_patches[0], n_patches[1])
  73. assert similarity_map.shape == expected_shape, (
  74. f"Similarity map at index {idx} has shape {similarity_map.shape}, expected {expected_shape}."
  75. )
  76. def test_get_similarity_maps_with_varied_n_patches(self):
  77. # Define test parameters
  78. batch_size = 2
  79. image_tokens_list = [6, 8] # Different number of tokens for each image
  80. query_tokens = 3
  81. dim = 4 # Embedding dimension
  82. # Create dummy image embeddings with padding to match the maximum tokens
  83. max_image_tokens = max(image_tokens_list)
  84. image_embeddings = torch.randn(batch_size, max_image_tokens, dim)
  85. query_embeddings = torch.randn(batch_size, query_tokens, dim)
  86. # Define n_patches as a list of tuples
  87. n_patches = [(2, 3), (2, 4)] # Different for each image
  88. # Create image attention masks for variable image tokens
  89. image_mask = torch.zeros(batch_size, max_image_tokens, dtype=torch.bool)
  90. for idx, tokens in enumerate(image_tokens_list):
  91. image_mask[idx, :tokens] = 1
  92. # Call the function under test
  93. similarity_maps = get_similarity_maps_from_embeddings(
  94. image_embeddings=image_embeddings,
  95. query_embeddings=query_embeddings,
  96. n_patches=n_patches,
  97. image_mask=image_mask,
  98. )
  99. # Assertions to validate the output
  100. assert isinstance(similarity_maps, list), "Output should be a list of tensors."
  101. assert len(similarity_maps) == batch_size, "Output list length should match batch size."
  102. for idx, similarity_map in enumerate(similarity_maps):
  103. expected_shape = (query_tokens, n_patches[idx][0], n_patches[idx][1])
  104. assert similarity_map.shape == expected_shape, (
  105. f"Similarity map at index {idx} has shape {similarity_map.shape}, expected {expected_shape}."
  106. )
  107. def test_get_similarity_maps_with_incorrect_n_patches(self):
  108. # Define test parameters
  109. batch_size = 1
  110. image_tokens = 6 # Total number of image tokens
  111. query_tokens = 2
  112. dim = 5 # Embedding dimension
  113. # Create dummy image embeddings and query embeddings
  114. image_embeddings = torch.randn(batch_size, image_tokens, dim)
  115. query_embeddings = torch.randn(batch_size, query_tokens, dim)
  116. # Define incorrect n_patches that do not match image_tokens
  117. n_patches = (2, 2) # 2*2 != 6
  118. # Create image attention masks for variable image tokens
  119. image_mask = torch.ones(batch_size, image_tokens, dtype=torch.bool)
  120. # Expect an error due to shape mismatch
  121. with pytest.raises(ValueError):
  122. get_similarity_maps_from_embeddings(
  123. image_embeddings=image_embeddings,
  124. query_embeddings=query_embeddings,
  125. n_patches=n_patches,
  126. image_mask=image_mask,
  127. )
  128. def test_get_similarity_maps_with_padding(self):
  129. # Define test parameters
  130. batch_size = 1
  131. image_tokens = 9 # Total number of image tokens
  132. query_tokens = 2
  133. dim = 5 # Embedding dimension
  134. # Create dummy image embeddings and query embeddings
  135. image_embeddings = torch.randn(batch_size, image_tokens, dim)
  136. query_embeddings = torch.randn(batch_size, query_tokens, dim)
  137. # Define n_patches as a tuple
  138. n_patches = (3, 2)
  139. # Create an image attention mask with padding
  140. image_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 0, 0, 0]], dtype=torch.bool)
  141. # Call the function under test
  142. similarity_maps = get_similarity_maps_from_embeddings(
  143. image_embeddings=image_embeddings,
  144. query_embeddings=query_embeddings,
  145. n_patches=n_patches,
  146. image_mask=image_mask,
  147. )
  148. # Assertions to validate the output
  149. assert isinstance(similarity_maps, list), "Output should be a list of tensors."
  150. assert len(similarity_maps) == batch_size, "Output list length should match batch size."
  151. similarity_map = similarity_maps[0]
  152. expected_shape = (query_tokens, n_patches[0], n_patches[1])
  153. assert similarity_map.shape == expected_shape, (
  154. f"Similarity map has shape {similarity_map.shape}, expected {expected_shape}."
  155. )