# test_bi_losses.py

# ruff: noqa: N806, N812
import pytest
import torch
import torch.nn.functional as F

from colpali_engine.loss import (
    BiEncoderLoss,
    BiEncoderModule,
    BiNegativeCELoss,
    BiPairwiseCELoss,
    BiPairwiseNegativeCELoss,
)

class TestBiEncoderModule:
    def test_init_invalid_temperature(self):
        with pytest.raises(ValueError) as excinfo:
            BiEncoderModule(temperature=0.0)
        assert "Temperature must be strictly positive" in str(excinfo.value)

    def test_get_idx(self):
        module = BiEncoderModule(max_batch_size=5, temperature=0.1)
        idx, pos_idx = module._get_idx(batch_size=3, offset=2, device=torch.device("cpu"))
        assert torch.equal(idx, torch.tensor([0, 1, 2]))
        assert torch.equal(pos_idx, torch.tensor([2, 3, 4]))

    def test_filter_high_negatives(self):
        module = BiEncoderModule(filter_threshold=0.95, filter_factor=0.5, temperature=0.1)
        # Create a 2×2 score matrix where scores[0, 1] > 0.95 * pos_score[0]
        scores = torch.tensor([[1.0, 0.96], [0.5, 1.0]])
        original = scores.clone()
        pos_idx = torch.tensor([0, 1])
        module._filter_high_negatives(scores, pos_idx)
        # Only scores[0, 1] should be down-weighted: 0.96 * filter_factor = 0.48
        assert scores[0, 1] == pytest.approx(0.48)
        # Other entries unchanged
        assert scores[0, 0] == original[0, 0]
        assert scores[1, 0] == original[1, 0]
        assert scores[1, 1] == original[1, 1]

class TestBiEncoderLoss:
    def test_forward_zero_embeddings(self):
        loss_fn = BiEncoderLoss(temperature=1.0, pos_aware_negative_filtering=False)
        B, D = 4, 5
        query = torch.zeros(B, D)
        doc = torch.zeros(B, D)
        loss = loss_fn(query, doc)
        # scores are all zeros => uniform softmax => loss = log(B)
        expected = torch.log(torch.tensor(float(B)))
        assert torch.allclose(loss, expected)

    def test_forward_with_filtering(self):
        loss_fn = BiEncoderLoss(temperature=1.0, pos_aware_negative_filtering=True)
        B, D = 3, 2
        query = torch.zeros(B, D)
        doc = torch.zeros(B, D)
        # Filtering on zero scores should have no effect
        loss1 = loss_fn(query, doc)
        loss2 = BiEncoderLoss(temperature=1.0, pos_aware_negative_filtering=False)(query, doc)
        assert torch.allclose(loss1, loss2)
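
# A hedged sanity-check sketch. It assumes BiEncoderLoss scores queries against
# documents with a dot product and applies in-batch cross-entropy with positives
# on the diagonal (offset 0 in a single-process run); none of that is asserted by
# the checks above. Under that assumption, one-hot embeddings give an identity
# score matrix, so the loss should fall strictly below the uniform log(B) baseline.
class TestBiEncoderLossOneHotSketch:
    def test_forward_one_hot_embeddings(self):
        loss_fn = BiEncoderLoss(temperature=1.0, pos_aware_negative_filtering=False)
        B = 4
        query = torch.eye(B)  # each query is a distinct one-hot vector
        doc = torch.eye(B)  # each positive document matches its query exactly
        loss = loss_fn(query, doc)
        # Matching pairs score 1.0 while all in-batch negatives score 0.0, so the
        # loss should be lower than the log(B) obtained from all-zero embeddings.
        assert loss < torch.log(torch.tensor(float(B)))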

class TestBiNegativeCELoss:
    def test_forward_no_inbatch(self):
        loss_fn = BiNegativeCELoss(temperature=1.0, in_batch_term_weight=0, pos_aware_negative_filtering=False)
        B, D, Nneg = 3, 4, 1
        query = torch.zeros(B, D)
        pos = torch.zeros(B, D)
        neg = torch.zeros(B, Nneg, D)
        loss = loss_fn(query, pos, neg)
        # softplus(0 - 0) = ln(2)
        expected = F.softplus(torch.tensor(0.0))
        assert torch.allclose(loss, expected)

    def test_forward_with_inbatch(self):
        loss_fn = BiNegativeCELoss(temperature=1.0, in_batch_term_weight=0.5, pos_aware_negative_filtering=False)
        B, D, Nneg = 2, 3, 1
        query = torch.zeros(B, D)
        pos = torch.zeros(B, D)
        neg = torch.zeros(B, Nneg, D)
        loss = loss_fn(query, pos, neg)
        # in-batch CE on zeros: log(B)
        ce = torch.log(torch.tensor(float(B)))
        sp = F.softplus(torch.tensor(0.0))
        expected = (sp + ce) / 2
        assert torch.allclose(loss, expected)

class TestBiPairwiseCELoss:
    def test_forward_zero_embeddings(self):
        loss_fn = BiPairwiseCELoss(temperature=1.0, pos_aware_negative_filtering=False)
        B, D = 4, 6
        query = torch.zeros(B, D)
        doc = torch.zeros(B, D)
        loss = loss_fn(query, doc)
        # hardest neg = 0, pos = 0 => softplus(0) = ln(2)
        expected = F.softplus(torch.tensor(0.0))
        assert torch.allclose(loss, expected)

    def test_forward_with_filtering(self):
        loss_fn = BiPairwiseCELoss(temperature=1.0, pos_aware_negative_filtering=True)
        B, D = 3, 5
        query = torch.zeros(B, D)
        doc = torch.zeros(B, D)
        # Filtering on zero scores should not change the result
        assert torch.allclose(loss_fn(query, doc), BiPairwiseCELoss(temperature=1.0)(query, doc))
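
# A hedged sketch in the same spirit: the comment above implies the pairwise loss
# is softplus(hardest_negative_score - positive_score). Assuming dot-product
# scoring, positives on the diagonal, and the positive being excluded from the
# hardest-negative search, one-hot embeddings give a positive score of 1 and a
# hardest in-batch negative of 0, so the loss should drop below ln(2).
class TestBiPairwiseCELossOneHotSketch:
    def test_forward_one_hot_embeddings(self):
        loss_fn = BiPairwiseCELoss(temperature=1.0, pos_aware_negative_filtering=False)
        B = 4
        query = torch.eye(B)
        doc = torch.eye(B)
        loss = loss_fn(query, doc)
        # softplus(0 - 1) < softplus(0) = ln(2)
        assert loss < F.softplus(torch.tensor(0.0))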

class TestBiPairwiseNegativeCELoss:
    def test_forward_no_inbatch(self):
        loss_fn = BiPairwiseNegativeCELoss(temperature=1.0, in_batch_term_weight=0)
        B, Nneg, D = 5, 2, 4
        query = torch.zeros(B, D)
        pos = torch.zeros(B, D)
        neg = torch.zeros(B, Nneg, D)
        loss = loss_fn(query, pos, neg)
        # softplus(0 - 0) = ln(2)
        expected = F.softplus(torch.tensor(0.0))
        assert torch.allclose(loss, expected)

    def test_forward_with_inbatch(self):
        loss_fn = BiPairwiseNegativeCELoss(temperature=1.0, in_batch_term_weight=0.5)
        B, Nneg, D = 2, 3, 4
        query = torch.zeros(B, D)
        pos = torch.zeros(B, D)
        neg = torch.zeros(B, Nneg, D)
        loss = loss_fn(query, pos, neg)
        # Both the explicit-negative and in-batch pairwise terms yield ln(2), so their average remains ln(2)
        expected = F.softplus(torch.tensor(0.0))
        assert torch.allclose(loss, expected)
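
# A hedged follow-up sketch: the ln(2) checks above suggest that, with the
# in-batch term disabled, BiNegativeCELoss reduces to softplus(neg_score - pos_score)
# under dot-product scoring; that is an assumption about its internals. Under it,
# separating positives from negatives should push the loss below ln(2). The
# embeddings below are chosen purely for illustration.
class TestBiNegativeCELossSeparationSketch:
    def test_forward_separated_pairs(self):
        loss_fn = BiNegativeCELoss(temperature=1.0, in_batch_term_weight=0, pos_aware_negative_filtering=False)
        B, D, Nneg = 2, 3, 1
        query = torch.zeros(B, D)
        query[:, 0] = 1.0  # unit queries along the first dimension
        pos = query.clone()  # positives identical to their queries -> pos_score = 1
        neg = torch.zeros(B, Nneg, D)  # orthogonal (zero) negatives -> neg_score = 0
        loss = loss_fn(query, pos, neg)
        # softplus(0 - 1) < softplus(0), i.e. below the ln(2) value from the zero-embedding case
        assert loss < F.softplus(torch.tensor(0.0))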