- import logging
- from typing import Generator, cast
- import pytest
- import torch
- from datasets import load_dataset
- from PIL import Image
- from colpali_engine.models import ColModernVBert, ColModernVBertProcessor
- from colpali_engine.utils.torch_utils import get_torch_device, tear_down_torch
- logger = logging.getLogger(__name__)
@pytest.fixture(scope="module")
def model_name() -> str:
    """Hugging Face checkpoint identifier shared by every fixture in this module."""
    return "ModernVBERT/colmodernvbert"
@pytest.fixture(scope="module")
def model_without_mask(model_name: str) -> Generator[ColModernVBert, None, None]:
    """Yield an eval-mode ColModernVBert loaded with `mask_non_image_embeddings=False`."""
    device = get_torch_device("auto")
    logger.info(f"Device used: {device}")
    model = ColModernVBert.from_pretrained(
        model_name,
        torch_dtype=torch.float32,
        device_map=device,
        attn_implementation="eager",
        mask_non_image_embeddings=False,
    ).eval()
    yield cast(ColModernVBert, model)
    # Module-scope teardown: release torch resources after all tests finish.
    tear_down_torch()
@pytest.fixture(scope="module")
def model_with_mask(model_name: str) -> Generator[ColModernVBert, None, None]:
    """Yield an eval-mode ColModernVBert loaded with `mask_non_image_embeddings=True`."""
    device = get_torch_device("auto")
    logger.info(f"Device used: {device}")
    model = ColModernVBert.from_pretrained(
        model_name,
        torch_dtype=torch.float32,
        device_map=device,
        attn_implementation="eager",
        mask_non_image_embeddings=True,
    ).eval()
    yield cast(ColModernVBert, model)
    # Module-scope teardown: release torch resources after all tests finish.
    tear_down_torch()
@pytest.fixture(scope="module")
def processor(model_name: str) -> Generator[ColModernVBertProcessor, None, None]:
    """Yield the processor paired with the checkpoint under test."""
    loaded = ColModernVBertProcessor.from_pretrained(model_name)
    yield cast(ColModernVBertProcessor, loaded)
class TestColModernVBert_Model:  # noqa: N801
    """Smoke tests for model loading."""

    @pytest.mark.slow
    def test_load_model_from_pretrained(self, model_without_mask: ColModernVBert):
        # Loading the checkpoint must produce an instance of the expected class.
        assert isinstance(model_without_mask, ColModernVBert)
class TestColModernVBert_ModelIntegration:  # noqa: N801
    """End-to-end forward-pass and retrieval tests against the real checkpoint."""

    @pytest.mark.slow
    def test_forward_images_integration(
        self,
        model_without_mask: ColModernVBert,
        processor: ColModernVBertProcessor,
    ):
        # Two tiny synthetic images of different sizes form the batch.
        dummy_images = [
            Image.new("RGB", (64, 64), color="white"),
            Image.new("RGB", (32, 32), color="black"),
        ]
        image_inputs = processor.process_images(dummy_images).to(model_without_mask.device)
        with torch.no_grad():
            embeddings = model_without_mask(**image_inputs)
        # Output must be a (batch, tokens, dim) multi-vector tensor.
        assert isinstance(embeddings, torch.Tensor)
        assert embeddings.dim() == 3
        batch_size, _, emb_dim = embeddings.shape
        assert batch_size == len(dummy_images)
        assert emb_dim == model_without_mask.dim

    @pytest.mark.slow
    def test_forward_queries_integration(
        self,
        model_without_mask: ColModernVBert,
        processor: ColModernVBertProcessor,
    ):
        text_queries = [
            "Is attention really all you need?",
            "Are Benjamin, Antoine, Merve, and Jo best friends?",
        ]
        query_inputs = processor.process_queries(text_queries).to(model_without_mask.device).to(torch.float32)
        with torch.no_grad():
            embeddings = model_without_mask(**query_inputs)
        # Output must be a (batch, tokens, dim) multi-vector tensor.
        assert isinstance(embeddings, torch.Tensor)
        assert embeddings.dim() == 3
        batch_size, _, emb_dim = embeddings.shape
        assert batch_size == len(text_queries)
        assert emb_dim == model_without_mask.dim

    @pytest.mark.slow
    def test_retrieval_integration(
        self,
        model_without_mask: ColModernVBert,
        processor: ColModernVBertProcessor,
    ):
        # Small internal HF dataset with paired (image, query) examples.
        ds = load_dataset("hf-internal-testing/document-visual-retrieval-test", split="test")

        image_inputs = processor.process_images(images=ds["image"]).to(model_without_mask.device).to(torch.float32)
        query_inputs = processor.process_queries(queries=ds["query"]).to(model_without_mask.device).to(torch.float32)

        with torch.inference_mode():
            image_embeddings = model_without_mask(**image_inputs)
            query_embeddings = model_without_mask(**query_inputs)

        # Late-interaction scoring yields one score per (query, passage) pair.
        scores = processor.score_multi_vector(
            qs=query_embeddings,
            ps=image_embeddings,
        )  # (len(qs), len(ps))

        assert scores.ndim == 2, f"Expected 2D tensor, got {scores.ndim}"
        assert scores.shape == (len(ds), len(ds)), f"Expected shape {(len(ds), len(ds))}, got {scores.shape}"