test_processing_colmodernvbert.py 2.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566
  1. from typing import Generator, cast
  2. import pytest
  3. import torch
  4. from PIL import Image
  5. from colpali_engine.models import ColModernVBertProcessor
  6. @pytest.fixture(scope="module")
  7. def model_name() -> str:
  8. return "ModernVBERT/colmodernvbert"
  9. @pytest.fixture(scope="module")
  10. def processor_from_pretrained(model_name: str) -> Generator[ColModernVBertProcessor, None, None]:
  11. yield cast(ColModernVBertProcessor, ColModernVBertProcessor.from_pretrained(model_name))
  12. def test_load_processor_from_pretrained(processor_from_pretrained: ColModernVBertProcessor):
  13. assert isinstance(processor_from_pretrained, ColModernVBertProcessor)
  14. def test_process_images(processor_from_pretrained: ColModernVBertProcessor):
  15. # Create a dummy image
  16. image_size = (64, 32)
  17. image = Image.new("RGB", image_size, color="black")
  18. images = [image]
  19. # Process the image
  20. batch_feature = processor_from_pretrained.process_images(images)
  21. # Assertions
  22. assert "pixel_values" in batch_feature
  23. assert isinstance(batch_feature["pixel_values"], torch.Tensor)
  24. assert batch_feature["pixel_values"].shape[0] == 1
  25. def test_process_texts(processor_from_pretrained: ColModernVBertProcessor):
  26. queries = [
  27. "Is attention really all you need?",
  28. "Are Benjamin, Antoine, Merve, and Jo best friends?",
  29. ]
  30. # Process the queries
  31. batch_encoding = processor_from_pretrained.process_texts(queries)
  32. # Assertions
  33. assert "input_ids" in batch_encoding
  34. assert isinstance(batch_encoding["input_ids"], torch.Tensor)
  35. assert cast(torch.Tensor, batch_encoding["input_ids"]).shape[0] == len(queries)
  36. def test_process_queries(processor_from_pretrained: ColModernVBertProcessor):
  37. queries = [
  38. "Is attention really all you need?",
  39. "Are Benjamin, Antoine, Merve, and Jo best friends?",
  40. ]
  41. # Process the queries
  42. batch_encoding = processor_from_pretrained.process_queries(queries)
  43. # Assertions
  44. assert "input_ids" in batch_encoding
  45. assert isinstance(batch_encoding["input_ids"], torch.Tensor)
  46. assert cast(torch.Tensor, batch_encoding["input_ids"]).shape[0] == len(queries)