| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566 |
- from typing import Generator, cast
- import pytest
- import torch
- from PIL import Image
- from colpali_engine.models import ColModernVBertProcessor
@pytest.fixture(scope="module")
def model_name() -> str:
    """Name of the pretrained checkpoint used by the tests in this module."""
    checkpoint_id = "ModernVBERT/colmodernvbert"
    return checkpoint_id
@pytest.fixture(scope="module")
def processor_from_pretrained(model_name: str) -> Generator[ColModernVBertProcessor, None, None]:
    """Load the processor for `model_name` once per module and yield it to the tests."""
    loaded = ColModernVBertProcessor.from_pretrained(model_name)
    # `cast` does nothing at runtime; it only narrows the type for static checkers.
    processor = cast(ColModernVBertProcessor, loaded)
    yield processor
def test_load_processor_from_pretrained(processor_from_pretrained: ColModernVBertProcessor):
    """The fixture must provide an actual `ColModernVBertProcessor` instance."""
    loaded_processor = processor_from_pretrained
    assert isinstance(loaded_processor, ColModernVBertProcessor)
def test_process_images(processor_from_pretrained: ColModernVBertProcessor):
    """`process_images` must return tensor `pixel_values` with one leading entry per input image."""
    # A single small all-black RGB image is enough to exercise the call.
    dummy = Image.new("RGB", (64, 32), color="black")
    batch = [dummy]

    features = processor_from_pretrained.process_images(batch)

    assert "pixel_values" in features
    pixel_values = features["pixel_values"]
    assert isinstance(pixel_values, torch.Tensor)
    # Leading dimension is the batch size (exactly one image here).
    assert pixel_values.shape[0] == len(batch)
def test_process_texts(processor_from_pretrained: ColModernVBertProcessor):
    """`process_texts` must return tensor `input_ids` with one leading entry per input text."""
    texts = [
        "Is attention really all you need?",
        "Are Benjamin, Antoine, Merve, and Jo best friends?",
    ]

    encoded = processor_from_pretrained.process_texts(texts)

    assert "input_ids" in encoded
    input_ids = encoded["input_ids"]
    assert isinstance(input_ids, torch.Tensor)
    # Leading dimension must equal the number of texts submitted.
    assert input_ids.shape[0] == len(texts)
def test_process_queries(processor_from_pretrained: ColModernVBertProcessor):
    """`process_queries` must return tensor `input_ids` with one leading entry per query."""
    query_batch = [
        "Is attention really all you need?",
        "Are Benjamin, Antoine, Merve, and Jo best friends?",
    ]

    encoded = processor_from_pretrained.process_queries(query_batch)

    assert "input_ids" in encoded
    input_ids = encoded["input_ids"]
    assert isinstance(input_ids, torch.Tensor)
    # Leading dimension must equal the number of queries submitted.
    assert input_ids.shape[0] == len(query_batch)
|