# test_dataset.py
  1. import pytest
  2. from torch.utils.data import Dataset
  3. from colpali_engine.data import ColPaliEngineDataset, Corpus
# --------------------------------------------------------------------------- #
# Helper utilities                                                            #
# --------------------------------------------------------------------------- #
  7. class DummyMapDataset(Dataset):
  8. """
  9. Minimal map‑style dataset that includes a `.take()` method so we can
  10. exercise ColPaliEngineDataset.take() without depending on HF datasets.
  11. """
  12. def __init__(self, samples):
  13. self._samples = list(samples)
  14. def __len__(self):
  15. return len(self._samples)
  16. def __getitem__(self, idx):
  17. return self._samples[idx]
  18. def take(self, n):
  19. return DummyMapDataset(self._samples[:n])
# --------------------------------------------------------------------------- #
# Fixtures & samples                                                          #
# --------------------------------------------------------------------------- #
  23. @pytest.fixture
  24. def corpus():
  25. data = [{"doc": f"doc_{i}"} for i in range(3)]
  26. return Corpus(corpus_data=data)
  27. @pytest.fixture
  28. def data_no_neg():
  29. """3 samples – *no* neg_target column at all."""
  30. return [
  31. {"query": "q0", "pos_target": 0},
  32. {"query": "q1", "pos_target": [1]},
  33. {"query": "q2", "pos_target": [2]},
  34. ]
  35. @pytest.fixture
  36. def data_with_neg():
  37. """2 samples – every sample has a neg_target column."""
  38. return [
  39. {"query": "q0", "pos_target": 1, "neg_target": 0},
  40. {"query": "q1", "pos_target": [2], "neg_target": [0, 1]},
  41. ]
# --------------------------------------------------------------------------- #
# Tests – NO negatives case                                                   #
# --------------------------------------------------------------------------- #
  45. def test_no_negatives_basic(data_no_neg):
  46. ds = ColPaliEngineDataset(data_no_neg) # neg_target_column_name defaults to None
  47. assert len(ds) == 3
  48. sample = ds[0]
  49. assert sample[ColPaliEngineDataset.QUERY_KEY] == "q0"
  50. assert sample[ColPaliEngineDataset.POS_TARGET_KEY] == [0]
  51. # NEG_TARGET_KEY should be None
  52. assert sample[ColPaliEngineDataset.NEG_TARGET_KEY] is None
  53. def test_no_negatives_with_corpus_resolution(data_no_neg, corpus):
  54. ds = ColPaliEngineDataset(data_no_neg, corpus=corpus)
  55. s1 = ds[1]
  56. # pos_target indices 1 should be resolved to the actual doc string
  57. assert s1[ColPaliEngineDataset.POS_TARGET_KEY] == ["doc_1"]
  58. # still no negatives
  59. assert s1[ColPaliEngineDataset.NEG_TARGET_KEY] is None
# --------------------------------------------------------------------------- #
# Tests – WITH negatives case                                                 #
# --------------------------------------------------------------------------- #
  63. def test_with_negatives_basic(data_with_neg):
  64. ds = ColPaliEngineDataset(
  65. data_with_neg,
  66. neg_target_column_name="neg_target",
  67. )
  68. assert len(ds) == 2
  69. s0 = ds[0]
  70. assert s0[ColPaliEngineDataset.POS_TARGET_KEY] == [1]
  71. assert s0[ColPaliEngineDataset.NEG_TARGET_KEY] == [0]
  72. def test_with_negatives_and_corpus(data_with_neg, corpus):
  73. ds = ColPaliEngineDataset(
  74. data_with_neg,
  75. corpus=corpus,
  76. neg_target_column_name="neg_target",
  77. )
  78. s1 = ds[1]
  79. # pos 2 -> "doc_2", negs 0,1 -> "doc_0", "doc_1"
  80. assert s1[ColPaliEngineDataset.POS_TARGET_KEY] == ["doc_2"]
  81. assert s1[ColPaliEngineDataset.NEG_TARGET_KEY] == ["doc_0", "doc_1"]
# --------------------------------------------------------------------------- #
# Tests for mixed / inconsistent scenarios                                    #
# --------------------------------------------------------------------------- #
  85. def test_error_if_neg_column_specified_but_missing(data_no_neg):
  86. """All samples must include the column when neg_target_column_name is given."""
  87. with pytest.raises(AssertionError):
  88. ds = ColPaliEngineDataset( # noqa: F841
  89. data_no_neg,
  90. neg_target_column_name="neg_target",
  91. )
  92. _ = ds[0] # force __getitem__
  93. def test_error_if_data_mix_neg_and_non_neg(data_with_neg, data_no_neg):
  94. """A mixed dataset (some samples without neg_target) should fail."""
  95. mixed = data_with_neg + data_no_neg
  96. # The first sample *does* have neg_target, so __init__ succeeds.
  97. ds = ColPaliEngineDataset(
  98. mixed,
  99. neg_target_column_name="neg_target",
  100. )
  101. # Accessing a sample lacking the column should raise.
  102. with pytest.raises(KeyError):
  103. _ = ds[len(data_with_neg)] # first sample from the 'no_neg' part
# --------------------------------------------------------------------------- #
# .take() works in both modes                                                 #
# --------------------------------------------------------------------------- #
  107. def test_take_returns_subset(data_no_neg):
  108. wrapped = DummyMapDataset(data_no_neg)
  109. ds = ColPaliEngineDataset(wrapped)
  110. sub_ds = ds.take(1)
  111. assert isinstance(sub_ds, ColPaliEngineDataset)
  112. assert len(sub_ds) == 1
  113. # Make sure we can still index
  114. _ = sub_ds[0]