# test_modeling_colmodernvbert.py
  1. import logging
  2. from typing import Generator, cast
  3. import pytest
  4. import torch
  5. from datasets import load_dataset
  6. from PIL import Image
  7. from colpali_engine.models import ColModernVBert, ColModernVBertProcessor
  8. from colpali_engine.utils.torch_utils import get_torch_device, tear_down_torch
  9. logger = logging.getLogger(__name__)
  10. @pytest.fixture(scope="module")
  11. def model_name() -> str:
  12. return "ModernVBERT/colmodernvbert"
  13. @pytest.fixture(scope="module")
  14. def model_without_mask(model_name: str) -> Generator[ColModernVBert, None, None]:
  15. device = get_torch_device("auto")
  16. logger.info(f"Device used: {device}")
  17. yield cast(
  18. ColModernVBert,
  19. ColModernVBert.from_pretrained(
  20. model_name,
  21. torch_dtype=torch.float32,
  22. device_map=device,
  23. attn_implementation="eager",
  24. mask_non_image_embeddings=False,
  25. ).eval(),
  26. )
  27. tear_down_torch()
  28. @pytest.fixture(scope="module")
  29. def model_with_mask(model_name: str) -> Generator[ColModernVBert, None, None]:
  30. device = get_torch_device("auto")
  31. logger.info(f"Device used: {device}")
  32. yield cast(
  33. ColModernVBert,
  34. ColModernVBert.from_pretrained(
  35. model_name,
  36. torch_dtype=torch.float32,
  37. device_map=device,
  38. attn_implementation="eager",
  39. mask_non_image_embeddings=True,
  40. ).eval(),
  41. )
  42. tear_down_torch()
  43. @pytest.fixture(scope="module")
  44. def processor(model_name: str) -> Generator[ColModernVBertProcessor, None, None]:
  45. yield cast(ColModernVBertProcessor, ColModernVBertProcessor.from_pretrained(model_name))
  46. class TestColModernVBert_Model: # noqa N801
  47. @pytest.mark.slow
  48. def test_load_model_from_pretrained(self, model_without_mask: ColModernVBert):
  49. assert isinstance(model_without_mask, ColModernVBert)
  50. class TestColModernVBert_ModelIntegration: # noqa N801
  51. @pytest.mark.slow
  52. def test_forward_images_integration(
  53. self,
  54. model_without_mask: ColModernVBert,
  55. processor: ColModernVBertProcessor,
  56. ):
  57. # Create a batch of dummy images
  58. images = [
  59. Image.new("RGB", (64, 64), color="white"),
  60. Image.new("RGB", (32, 32), color="black"),
  61. ]
  62. # Process the image
  63. batch_images = processor.process_images(images).to(model_without_mask.device)
  64. # Forward pass
  65. with torch.no_grad():
  66. outputs = model_without_mask(**batch_images)
  67. # Assertions
  68. assert isinstance(outputs, torch.Tensor)
  69. assert outputs.dim() == 3
  70. batch_size, n_visual_tokens, emb_dim = outputs.shape
  71. assert batch_size == len(images)
  72. assert emb_dim == model_without_mask.dim
  73. @pytest.mark.slow
  74. def test_forward_queries_integration(
  75. self,
  76. model_without_mask: ColModernVBert,
  77. processor: ColModernVBertProcessor,
  78. ):
  79. queries = [
  80. "Is attention really all you need?",
  81. "Are Benjamin, Antoine, Merve, and Jo best friends?",
  82. ]
  83. # Process the queries
  84. batch_queries = processor.process_queries(queries).to(model_without_mask.device).to(torch.float32)
  85. # Forward pass
  86. with torch.no_grad():
  87. outputs = model_without_mask(**batch_queries)
  88. # Assertions
  89. assert isinstance(outputs, torch.Tensor)
  90. assert outputs.dim() == 3
  91. batch_size, n_query_tokens, emb_dim = outputs.shape
  92. assert batch_size == len(queries)
  93. assert emb_dim == model_without_mask.dim
  94. @pytest.mark.slow
  95. def test_retrieval_integration(
  96. self,
  97. model_without_mask: ColModernVBert,
  98. processor: ColModernVBertProcessor,
  99. ):
  100. # Load the test dataset
  101. ds = load_dataset("hf-internal-testing/document-visual-retrieval-test", split="test")
  102. # Preprocess the examples
  103. batch_images = processor.process_images(images=ds["image"]).to(model_without_mask.device).to(torch.float32)
  104. batch_queries = processor.process_queries(queries=ds["query"]).to(model_without_mask.device).to(torch.float32)
  105. # Run inference
  106. with torch.inference_mode():
  107. image_embeddings = model_without_mask(**batch_images)
  108. query_embeddings = model_without_mask(**batch_queries)
  109. # Compute retrieval scores
  110. scores = processor.score_multi_vector(
  111. qs=query_embeddings,
  112. ps=image_embeddings,
  113. ) # (len(qs), len(ps))
  114. assert scores.ndim == 2, f"Expected 2D tensor, got {scores.ndim}"
  115. assert scores.shape == (len(ds), len(ds)), f"Expected shape {(len(ds), len(ds))}, got {scores.shape}"