test_li_losses.py 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181
  1. # ruff: noqa: N806, N812
  2. import pytest
  3. import torch
  4. import torch.nn.functional as F
  5. from colpali_engine.loss import (
  6. ColbertLoss,
  7. ColbertModule,
  8. ColbertNegativeCELoss,
  9. ColbertPairwiseCELoss,
  10. ColbertPairwiseNegativeCELoss,
  11. )
  12. class TestColbertModule:
  13. def test_get_idx(self):
  14. module = ColbertModule(max_batch_size=5)
  15. idx, pos_idx = module._get_idx(batch_size=3, offset=2, device=torch.device("cpu"))
  16. assert torch.equal(idx, torch.tensor([0, 1, 2]))
  17. assert torch.equal(pos_idx, torch.tensor([2, 3, 4]))
  18. def test_smooth_max(self):
  19. module = ColbertModule(tau=2.0)
  20. scores = torch.tensor([[0.0, 2.0]])
  21. out = module._smooth_max(scores, dim=1)
  22. expected = 2.0 * torch.log(torch.tensor(1.0 + torch.exp(torch.tensor(1.0))))
  23. assert torch.allclose(out, expected)
  24. def test_apply_normalization_within_bounds(self):
  25. module = ColbertModule(norm_tol=1e-3)
  26. scores = torch.tensor([[0.5, 1.0], [0.2, 0.8]])
  27. lengths = torch.tensor([2.0, 4.0])
  28. normalized = module._apply_normalization(scores, lengths)
  29. expected = scores / lengths.unsqueeze(1)
  30. assert torch.allclose(normalized, expected)
  31. # def test_apply_normalization_out_of_bounds(self):
  32. # module = ColbertModule(norm_tol=1e-3)
  33. # scores = torch.tensor([[2.0, 0.0], [0.0, 0.0]])
  34. # lengths = torch.tensor([1.0, 1.0])
  35. # with pytest.raises(ValueError) as excinfo:
  36. # module._apply_normalization(scores, lengths)
  37. # assert "Scores out of bounds after normalization" in str(excinfo.value)
  38. def test_aggregate_max(self):
  39. module = ColbertModule()
  40. raw = torch.tensor(
  41. [
  42. [[1.0, 2.0], [3.0, 4.0]],
  43. [[5.0, 6.0], [7.0, 8.0]],
  44. ]
  45. )
  46. out = module._aggregate(raw, use_smooth_max=False, dim_max=2, dim_sum=1)
  47. assert torch.allclose(out, torch.tensor([6.0, 14.0]))
  48. def test_aggregate_smooth_max(self):
  49. module = ColbertModule(tau=1.0)
  50. raw = torch.zeros(1, 2, 2)
  51. out = module._aggregate(raw, use_smooth_max=True, dim_max=2, dim_sum=1)
  52. assert torch.allclose(out, 2 * torch.log(torch.tensor(2.0)))
  53. def test_filter_high_negatives(self):
  54. module = ColbertModule(filter_threshold=0.95, filter_factor=0.5)
  55. scores = torch.tensor([[1.0, 0.96], [0.5, 1.0]])
  56. original = scores.clone()
  57. pos_idx = torch.tensor([0, 1])
  58. module._filter_high_negatives(scores, pos_idx)
  59. assert scores[0, 1] == pytest.approx(0.48)
  60. # other entries unchanged
  61. assert scores[0, 0] == original[0, 0]
  62. assert scores[1, 0] == original[1, 0]
  63. assert scores[1, 1] == original[1, 1]
  64. class TestColbertLoss:
  65. def test_zero_embeddings(self):
  66. loss_fn = ColbertLoss(
  67. temperature=1.0,
  68. normalize_scores=False,
  69. use_smooth_max=False,
  70. pos_aware_negative_filtering=False,
  71. )
  72. B, Nq, D = 3, 1, 4
  73. query = torch.zeros(B, Nq, D)
  74. doc = torch.zeros(B, Nq, D)
  75. loss = loss_fn(query, doc)
  76. expected = torch.log(torch.tensor(float(B)))
  77. assert torch.allclose(loss, expected)
  78. def test_with_and_without_filtering(self):
  79. base = ColbertLoss(
  80. temperature=1.0, normalize_scores=False, use_smooth_max=False, pos_aware_negative_filtering=False
  81. )
  82. filt = ColbertLoss(
  83. temperature=1.0, normalize_scores=False, use_smooth_max=False, pos_aware_negative_filtering=True
  84. )
  85. B, Nq, D = 2, 1, 3
  86. query = torch.zeros(B, Nq, D)
  87. doc = torch.zeros(B, Nq, D)
  88. assert torch.allclose(base(query, doc), filt(query, doc))
  89. class TestColbertNegativeCELoss:
  90. def test_no_inbatch(self):
  91. loss_fn = ColbertNegativeCELoss(
  92. temperature=1.0,
  93. normalize_scores=False,
  94. use_smooth_max=False,
  95. pos_aware_negative_filtering=False,
  96. in_batch_term_weight=0,
  97. )
  98. B, Lq, D, Lneg, Nneg = 2, 1, 3, 1, 1
  99. query = torch.zeros(B, Lq, D)
  100. doc = torch.zeros(B, Lq, D)
  101. neg = torch.zeros(B, Nneg, Lneg, D)
  102. loss = loss_fn(query, doc, neg)
  103. expected = F.softplus(torch.tensor(0.0))
  104. assert torch.allclose(loss, expected)
  105. def test_with_inbatch(self):
  106. loss_fn = ColbertNegativeCELoss(
  107. temperature=1.0,
  108. normalize_scores=False,
  109. use_smooth_max=False,
  110. pos_aware_negative_filtering=False,
  111. in_batch_term_weight=0.5,
  112. )
  113. B, Lq, D, Lneg, Nneg = 2, 1, 3, 1, 1
  114. query = torch.zeros(B, Lq, D)
  115. doc = torch.zeros(B, Lq, D)
  116. neg = torch.zeros(B, Nneg, Lneg, D)
  117. loss = loss_fn(query, doc, neg)
  118. expected = F.softplus(torch.tensor(0.0))
  119. assert torch.allclose(loss, expected)
  120. class TestColbertPairwiseCELoss:
  121. def test_zero_embeddings(self):
  122. loss_fn = ColbertPairwiseCELoss(
  123. temperature=1.0, normalize_scores=False, use_smooth_max=False, pos_aware_negative_filtering=False
  124. )
  125. B, Nq, D = 2, 1, 3
  126. query = torch.zeros(B, Nq, D)
  127. doc = torch.zeros(B, Nq, D)
  128. loss = loss_fn(query, doc)
  129. expected = F.softplus(torch.tensor(0.0))
  130. assert torch.allclose(loss, expected)
  131. class TestColbertPairwiseNegativeCELoss:
  132. def test_no_inbatch(self):
  133. loss_fn = ColbertPairwiseNegativeCELoss(
  134. temperature=1.0,
  135. normalize_scores=False,
  136. use_smooth_max=False,
  137. pos_aware_negative_filtering=False,
  138. in_batch_term_weight=0,
  139. )
  140. B, Lq, D, Lneg, Nneg = 2, 1, 3, 1, 1
  141. query = torch.zeros(B, Lq, D)
  142. doc = torch.zeros(B, Lq, D)
  143. neg = torch.zeros(B, Nneg, Lneg, D)
  144. loss = loss_fn(query, doc, neg)
  145. expected = F.softplus(torch.tensor(0.0))
  146. assert torch.allclose(loss, expected)
  147. def test_with_inbatch(self):
  148. loss_fn = ColbertPairwiseNegativeCELoss(
  149. temperature=1.0,
  150. normalize_scores=False,
  151. use_smooth_max=False,
  152. pos_aware_negative_filtering=False,
  153. in_batch_term_weight=0.5,
  154. )
  155. B, Lq, D, Lneg, Nneg = 2, 1, 3, 1, 1
  156. query = torch.zeros(B, Lq, D)
  157. doc = torch.zeros(B, Lq, D)
  158. neg = torch.zeros(B, Nneg, Lneg, D)
  159. loss = loss_fn(query, doc, neg)
  160. expected = F.softplus(torch.tensor(0.0))
  161. assert torch.allclose(loss, expected)