# test_sampler.py — unit tests for SingleDatasetBatchSampler
  1. import pytest
  2. import torch
  3. from torch.utils.data import Dataset
  4. from colpali_engine.data.sampler import SingleDatasetBatchSampler
  5. class DummyDataset(Dataset):
  6. """
  7. Minimal PyTorch dataset that also supports `.take()`.
  8. The values it returns are irrelevant to the sampler; we only care about length.
  9. """
  10. def __init__(self, size: int, start: int = 0):
  11. self._data = list(range(start, start + size))
  12. def __len__(self):
  13. return len(self._data)
  14. def __getitem__(self, idx):
  15. return self._data[idx]
  16. # Simulate Arrow / HF dataset API used by the sampler
  17. def take(self, total_samples: int):
  18. # Keep the same starting offset so global indices stay monotonic
  19. return DummyDataset(total_samples, start=self._data[0])
  20. # --------------------------------------------------------------------------- #
  21. # Test helpers #
  22. # --------------------------------------------------------------------------- #
  23. def dataset_boundaries(sampler):
  24. """Return a list of (lo, hi) index ranges, one per dataset, in global space."""
  25. cs = sampler.cumsum_sizes # cumsum has an extra leading 0
  26. return [(cs[i], cs[i + 1]) for i in range(len(cs) - 1)]
  27. def which_dataset(idx, boundaries):
  28. """Given a global idx, tell which dataset it belongs to (0‑based)."""
  29. for d, (lo, hi) in enumerate(boundaries):
  30. if lo <= idx < hi:
  31. return d
  32. raise ValueError(f"idx {idx} out of bounds")
  33. # --------------------------------------------------------------------------- #
  34. # Tests #
  35. # --------------------------------------------------------------------------- #
  36. def test_basic_iteration_and_len():
  37. """
  38. Two datasets, lengths 10 and 6, global batch size 4.
  39. Both datasets should be truncated (10→8, 6→4). Expect 3 batches.
  40. """
  41. ds = [DummyDataset(10), DummyDataset(6)]
  42. gen = torch.Generator().manual_seed(123)
  43. sampler = SingleDatasetBatchSampler(ds, global_batch_size=4, generator=gen)
  44. batches = list(iter(sampler))
  45. # 1) __len__ matches actual number of batches
  46. assert len(batches) == len(sampler) == 3
  47. # 2) All samples are unique and count equals truncated total
  48. flat = [i for b in batches for i in b]
  49. assert len(flat) == len(set(flat)) == 12 # 8 + 4
  50. # 3) Every batch is exactly global_batch_size long
  51. assert all(len(b) == 4 for b in batches)
  52. def test_single_dataset_per_batch():
  53. """
  54. Ensure that every yielded batch contains indices drawn from
  55. *one—and only one—dataset*.
  56. """
  57. ds = [DummyDataset(8), DummyDataset(8), DummyDataset(16)]
  58. sampler = SingleDatasetBatchSampler(ds, global_batch_size=4, generator=torch.Generator())
  59. boundaries = dataset_boundaries(sampler)
  60. for batch in sampler:
  61. d0 = which_dataset(batch[0], boundaries)
  62. # All indices in the batch must map to the same dataset ID
  63. assert all(which_dataset(i, boundaries) == d0 for i in batch)
  64. def test_epoch_based_reshuffle_changes_order():
  65. """
  66. Calling set_epoch should reshuffle the internal order so that
  67. consecutive epochs produce different batch orderings.
  68. """
  69. ds = [DummyDataset(8), DummyDataset(8)]
  70. gen = torch.Generator().manual_seed(999)
  71. sampler = SingleDatasetBatchSampler(ds, global_batch_size=4, generator=gen)
  72. first_epoch = list(iter(sampler))
  73. sampler.set_epoch(1)
  74. second_epoch = list(iter(sampler))
  75. # Pure order comparison; contents are the same but order should differ
  76. assert first_epoch != second_epoch
  77. # Same epoch again → deterministic repeat
  78. sampler.set_epoch(1)
  79. repeat_epoch = list(iter(sampler))
  80. assert second_epoch == repeat_epoch
  81. @pytest.mark.parametrize(
  82. "lengths,batch_size,expected_batches",
  83. [
  84. ([12], 4, 3), # single dataset, perfect fit
  85. ([13], 4, 3), # single dataset, truncated down
  86. ([7, 9], 4, 3), # truncates both
  87. ([4, 4, 4], 4, 3), # multiple, exact fit
  88. ],
  89. )
  90. def test_len_property_various_lengths(lengths, batch_size, expected_batches):
  91. datasets = [DummyDataset(n) for n in lengths]
  92. sampler = SingleDatasetBatchSampler(datasets, global_batch_size=batch_size)
  93. assert len(sampler) == expected_batches