bi_encoder_losses.py

import torch
import torch.nn.functional as F  # noqa: N812
from torch.nn import CrossEntropyLoss


class BiEncoderModule(torch.nn.Module):
    """
    Base module for bi-encoder losses, handling buffer indexing and filtering hyperparameters.

    Args:
        max_batch_size (int): Maximum batch size for the pre-allocated index buffer.
        temperature (float): Scaling factor for logits (must be > 0).
        filter_threshold (float): Fraction of the positive score above which negatives are down-weighted.
        filter_factor (float): Multiplicative factor applied to filtered negative scores.
    """

    def __init__(
        self,
        max_batch_size: int = 1024,
        temperature: float = 0.02,
        filter_threshold: float = 0.95,
        filter_factor: float = 0.5,
    ):
        super().__init__()
        if temperature <= 0:
            raise ValueError("Temperature must be strictly positive")
        self.register_buffer("idx_buffer", torch.arange(max_batch_size), persistent=False)
        self.temperature = temperature
        self.filter_threshold = filter_threshold
        self.filter_factor = filter_factor

    def _get_idx(self, batch_size: int, offset: int, device: torch.device):
        """
        Generate index tensors for in-batch cross-entropy.

        Args:
            batch_size (int): Number of queries/docs in the batch.
            offset (int): Offset to apply for multi-GPU indexing.
            device (torch.device): Target device of the indices.

        Returns:
            Tuple[Tensor, Tensor]: (idx, pos_idx), both of shape [batch_size].
        """
        idx = self.idx_buffer[:batch_size].to(device)
        return idx, idx + offset

    def _filter_high_negatives(self, scores: torch.Tensor, pos_idx: torch.Tensor):
        """
        In-place down-weighting of "too-high" in-batch negative scores.

        Args:
            scores (Tensor[B, B]): In-batch similarity matrix.
            pos_idx (Tensor[B]): Positive index for each query.
        """
        batch_size = scores.size(0)
        idx = self.idx_buffer[:batch_size].to(scores.device)
        pos_scores = scores[idx, pos_idx]
        thresh = self.filter_threshold * pos_scores.unsqueeze(1)
        mask = scores > thresh
        mask[idx, pos_idx] = False
        scores[mask] *= self.filter_factor
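
# Minimal numeric sketch of `_filter_high_negatives` (the values below are
# illustrative assumptions, not taken from this module). With
# filter_threshold=0.95 and filter_factor=0.5:
#
#     scores = torch.tensor([[1.00, 0.98, 0.40],
#                            [0.30, 0.90, 0.88],
#                            [0.10, 0.20, 0.50]])
#     # Row 0: 0.98 > 0.95 * 1.00 -> halved to 0.49.
#     # Row 1: 0.88 > 0.95 * 0.90 (= 0.855) -> halved to 0.44.
#     # Diagonal positives are masked out and never down-weighted.
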
class BiEncoderLoss(BiEncoderModule):
    """
    InfoNCE loss for bi-encoders without explicit negatives.

    Args:
        temperature (float): Scaling factor for logits.
        pos_aware_negative_filtering (bool): Apply in-batch negative filtering if True.
        max_batch_size (int): Max batch size for index buffer caching.
        filter_threshold (float): Threshold ratio for negative filtering.
        filter_factor (float): Factor to down-weight filtered negatives.
    """

    def __init__(
        self,
        temperature: float = 0.02,
        pos_aware_negative_filtering: bool = False,
        max_batch_size: int = 1024,
        filter_threshold: float = 0.95,
        filter_factor: float = 0.5,
    ):
        super().__init__(max_batch_size, temperature, filter_threshold, filter_factor)
        self.pos_aware_negative_filtering = pos_aware_negative_filtering
        self.ce_loss = CrossEntropyLoss()

    def forward(
        self,
        query_embeddings: torch.Tensor,
        doc_embeddings: torch.Tensor,
        offset: int = 0,
    ) -> torch.Tensor:
        """
        Compute the InfoNCE loss over a batch of bi-encoder embeddings.

        Args:
            query_embeddings (Tensor[B, D]): Query vectors.
            doc_embeddings (Tensor[B, D]): Document vectors.
            offset (int): Offset for positive indices (multi-GPU).

        Returns:
            Tensor: Scalar cross-entropy loss.
        """
        # Compute in-batch similarity matrix
        scores = torch.einsum("bd,cd->bc", query_embeddings, doc_embeddings)
        batch_size = scores.size(0)
        idx, pos_idx = self._get_idx(batch_size, offset, scores.device)
        if self.pos_aware_negative_filtering:
            self._filter_high_negatives(scores, pos_idx)
        return self.ce_loss(scores / self.temperature, pos_idx)
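
# Minimal usage sketch (shapes and values are illustrative assumptions):
#
#     loss_fn = BiEncoderLoss(temperature=0.02)
#     q = torch.randn(8, 128)   # [B, D] query embeddings
#     d = torch.randn(8, 128)   # [B, D] document embeddings; row i is query i's positive
#     loss = loss_fn(q, d)      # in-batch InfoNCE, positives on the diagonal (offset=0)
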
class BiPairedEncoderLoss(BiEncoderModule):
    """
    Symmetric (bidirectional) InfoNCE loss for bi-encoders without explicit negatives:
    averages the query-to-document and document-to-query cross-entropy terms.

    Args:
        temperature (float): Scaling factor for logits.
        pos_aware_negative_filtering (bool): Apply in-batch negative filtering if True.
        max_batch_size (int): Max batch size for index buffer caching.
        filter_threshold (float): Threshold ratio for negative filtering.
        filter_factor (float): Factor to down-weight filtered negatives.
    """

    def __init__(
        self,
        temperature: float = 0.02,
        pos_aware_negative_filtering: bool = False,
        max_batch_size: int = 1024,
        filter_threshold: float = 0.95,
        filter_factor: float = 0.5,
    ):
        super().__init__(max_batch_size, temperature, filter_threshold, filter_factor)
        self.pos_aware_negative_filtering = pos_aware_negative_filtering
        self.ce_loss = CrossEntropyLoss()

    def forward(
        self,
        query_embeddings: torch.Tensor,
        doc_embeddings: torch.Tensor,
        offset: int = 0,
    ) -> torch.Tensor:
        """
        Compute the symmetric InfoNCE loss over a batch of bi-encoder embeddings.

        Args:
            query_embeddings (Tensor[B, D]): Query vectors.
            doc_embeddings (Tensor[B, D]): Document vectors.
            offset (int): Offset for positive indices (multi-GPU).

        Returns:
            Tensor: Scalar loss, the mean of the two cross-entropy directions.
        """
        # Compute in-batch similarity matrix
        scores = torch.einsum("bd,cd->bc", query_embeddings, doc_embeddings)
        batch_size = scores.size(0)
        idx, pos_idx = self._get_idx(batch_size, offset, scores.device)
        if self.pos_aware_negative_filtering:
            self._filter_high_negatives(scores, pos_idx)
        q2t = self.ce_loss(scores / self.temperature, pos_idx)
        # Reverse direction: classify each positive document against the local queries.
        # Selecting rows of scores.T with pos_idx keeps this well-defined when
        # offset != 0 (i.e. when doc_embeddings were gathered across devices).
        t2q = self.ce_loss(scores.T[pos_idx] / self.temperature, idx)
        return (q2t + t2q) / 2.0
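
# Sketch of the two directions for the single-device case (offset=0, so
# pos_idx == idx and scores is square); `ce` is shorthand for CrossEntropyLoss:
#
#     scores = q @ d.T               # [B, B]
#     q2t = ce(scores / T, idx)      # query -> document direction
#     t2q = ce(scores.T / T, idx)    # document -> query direction
#     loss = (q2t + t2q) / 2
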
class BiNegativeCELoss(BiEncoderModule):
    """
    InfoNCE loss with explicit negative samples and an optional in-batch term.

    Args:
        temperature (float): Scaling factor for logits.
        in_batch_term_weight (float): Weight for the in-batch cross-entropy term (0 to 1).
        pos_aware_negative_filtering (bool): Apply in-batch negative filtering.
        max_batch_size (int): Max batch size for index buffer.
        filter_threshold (float): Threshold ratio for filtering.
        filter_factor (float): Factor to down-weight filtered negatives.
    """

    def __init__(
        self,
        temperature: float = 0.02,
        in_batch_term_weight: float = 0.5,
        pos_aware_negative_filtering: bool = False,
        max_batch_size: int = 1024,
        filter_threshold: float = 0.95,
        filter_factor: float = 0.5,
    ):
        super().__init__(max_batch_size, temperature, filter_threshold, filter_factor)
        if not 0 <= in_batch_term_weight <= 1:
            raise ValueError("in_batch_term_weight must be between 0 and 1")
        self.in_batch_term_weight = in_batch_term_weight
        self.pos_aware_negative_filtering = pos_aware_negative_filtering
        self.ce_loss = CrossEntropyLoss()
        # Inner InfoNCE for the in-batch term
        self.inner_loss = BiEncoderLoss(
            temperature=temperature,
            pos_aware_negative_filtering=pos_aware_negative_filtering,
            max_batch_size=max_batch_size,
            filter_threshold=filter_threshold,
            filter_factor=filter_factor,
        )

    def forward(
        self,
        query_embeddings: torch.Tensor,
        doc_embeddings: torch.Tensor,
        neg_doc_embeddings: torch.Tensor,
        offset: int = 0,
    ) -> torch.Tensor:
        """
        Compute softplus(neg_score - pos_score) plus the optional in-batch CE term.

        Args:
            query_embeddings (Tensor[B, D]): Query vectors.
            doc_embeddings (Tensor[B, D]): Positive document vectors.
            neg_doc_embeddings (Tensor[B, N, D]): Negative document vectors.
            offset (int): Offset for in-batch CE positives.

        Returns:
            Tensor: Scalar loss value.
        """
        # Dot product only for matching (query, positive-doc) pairs
        pos_scores = (query_embeddings * doc_embeddings[offset : offset + neg_doc_embeddings.size(0)]).sum(dim=1)
        pos_scores /= self.temperature
        neg_scores = torch.einsum("bd,bnd->bn", query_embeddings, neg_doc_embeddings) / self.temperature
        loss = F.softplus(neg_scores - pos_scores.unsqueeze(1)).mean()
        if self.in_batch_term_weight > 0:
            loss_ib = self.inner_loss(query_embeddings, doc_embeddings, offset)
            loss = loss * (1 - self.in_batch_term_weight) + loss_ib * self.in_batch_term_weight
        return loss
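
# Minimal usage sketch (shapes are illustrative assumptions):
#
#     loss_fn = BiNegativeCELoss(in_batch_term_weight=0.5)
#     q = torch.randn(8, 128)       # [B, D] queries
#     d = torch.randn(8, 128)       # [B, D] positives
#     neg = torch.randn(8, 4, 128)  # [B, N, D] explicit (e.g. mined) negatives
#     loss = loss_fn(q, d, neg)     # softplus margin term blended with in-batch InfoNCE
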
class BiPairwiseCELoss(BiEncoderModule):
    """
    Pairwise softplus loss mining the hardest in-batch negative.

    Args:
        temperature (float): Scaling factor for logits.
        pos_aware_negative_filtering (bool): Filter high negatives before mining.
        max_batch_size (int): Maximum batch size for indexing.
        filter_threshold (float): Threshold for pos-aware filtering.
        filter_factor (float): Factor to down-weight filtered negatives.
    """

    def __init__(
        self,
        temperature: float = 0.02,
        pos_aware_negative_filtering: bool = False,
        max_batch_size: int = 1024,
        filter_threshold: float = 0.95,
        filter_factor: float = 0.5,
    ):
        super().__init__(max_batch_size, temperature, filter_threshold, filter_factor)
        self.pos_aware_negative_filtering = pos_aware_negative_filtering

    def forward(
        self,
        query_embeddings: torch.Tensor,
        doc_embeddings: torch.Tensor,
        offset: int = 0,
    ) -> torch.Tensor:
        """
        Compute softplus(hardest_neg - pos), where hardest_neg is the highest-scoring non-positive document.

        Args:
            query_embeddings (Tensor[B, D]): Query vectors.
            doc_embeddings (Tensor[B, D]): Document vectors.
            offset (int): Offset for positive indices (multi-GPU).

        Returns:
            Tensor: Scalar loss value.
        """
        scores = torch.einsum("bd,cd->bc", query_embeddings, doc_embeddings)
        batch_size = scores.size(0)
        idx, pos_idx = self._get_idx(batch_size, offset, scores.device)
        pos = scores[idx, pos_idx]
        if self.pos_aware_negative_filtering:
            self._filter_high_negatives(scores, pos_idx)
        # Top-2 trick: if the best score is the positive itself, the hardest
        # negative is the runner-up; otherwise it is the best score.
        top2 = scores.topk(2, dim=1).values
        neg = torch.where(top2[:, 0] == pos, top2[:, 1], top2[:, 0])
        return F.softplus((neg - pos) / self.temperature).mean()
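
# Numeric sketch of the hardest-negative mining (offset=0; values are
# illustrative assumptions):
#
#     scores = torch.tensor([[0.9, 0.7, 0.2],   # pos = 0.9 is the max -> neg = 0.7 (runner-up)
#                            [0.8, 0.5, 0.6],   # pos = 0.5 is not the max -> neg = 0.8 (best)
#                            [0.1, 0.3, 0.6]])  # pos = 0.6 is the max -> neg = 0.3
#     # loss = softplus((neg - pos) / temperature).mean()
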
class BiPairwiseNegativeCELoss(BiEncoderModule):
    """
    Pairwise softplus loss with explicit negatives and an optional in-batch term.

    Args:
        temperature (float): Scaling factor for logits.
        in_batch_term_weight (float): Weight for the in-batch pairwise term (0 to 1).
        max_batch_size (int): Maximum batch size for indexing.
        filter_threshold (float): Threshold for pos-aware filtering.
        filter_factor (float): Factor to down-weight filtered negatives.
    """

    def __init__(
        self,
        temperature: float = 0.02,
        in_batch_term_weight: float = 0.5,
        max_batch_size: int = 1024,
        filter_threshold: float = 0.95,
        filter_factor: float = 0.5,
    ):
        super().__init__(max_batch_size, temperature, filter_threshold, filter_factor)
        if not 0 <= in_batch_term_weight <= 1:
            raise ValueError("in_batch_term_weight must be between 0 and 1")
        self.in_batch_term_weight = in_batch_term_weight
        self.inner_pairwise = BiPairwiseCELoss(
            temperature=temperature,
            pos_aware_negative_filtering=False,
            max_batch_size=max_batch_size,
            filter_threshold=filter_threshold,
            filter_factor=filter_factor,
        )

    def forward(
        self,
        query_embeddings: torch.Tensor,
        doc_embeddings: torch.Tensor,
        neg_doc_embeddings: torch.Tensor,
        offset: int = 0,
    ) -> torch.Tensor:
        """
        Compute softplus(explicit_neg - pos) plus the optional pairwise in-batch loss.

        Args:
            query_embeddings (Tensor[B, D]): Query vectors.
            doc_embeddings (Tensor[B, D]): Positive document vectors.
            neg_doc_embeddings (Tensor[B, N, D]): Negative document vectors.
            offset (int): Offset for in-batch positives (multi-GPU).

        Returns:
            Tensor: Scalar loss value.
        """
        # Dot product for matching pairs only
        pos = (query_embeddings * doc_embeddings[offset : offset + query_embeddings.size(0)]).sum(dim=1)  # B
        neg = (query_embeddings.unsqueeze(1) * neg_doc_embeddings).sum(dim=2)  # B x N
        loss = F.softplus((neg - pos.unsqueeze(1)) / self.temperature).mean()
        if self.in_batch_term_weight > 0:
            loss_ib = self.inner_pairwise(query_embeddings, doc_embeddings, offset=offset)
            loss = loss * (1 - self.in_batch_term_weight) + loss_ib * self.in_batch_term_weight
        return loss


class BiSigmoidLoss(BiEncoderModule):
    """
    Sigmoid (SigLIP-style) loss for bi-encoders with in-batch negatives.

    Args:
        temperature (float): Scaling factor for logits.
        pos_aware_negative_filtering (bool): Apply in-batch negative filtering if True.
        max_batch_size (int): Max batch size for index buffer caching.
        filter_threshold (float): Threshold ratio for negative filtering.
        filter_factor (float): Factor to down-weight filtered negatives.
    """

    def __init__(
        self,
        temperature: float = 0.02,
        pos_aware_negative_filtering: bool = False,
        max_batch_size: int = 1024,
        filter_threshold: float = 0.95,
        filter_factor: float = 0.5,
    ):
        super().__init__(max_batch_size, temperature, filter_threshold, filter_factor)
        self.pos_aware_negative_filtering = pos_aware_negative_filtering

    def forward(self, query_embeddings: torch.Tensor, doc_embeddings: torch.Tensor, offset: int = 0) -> torch.Tensor:
        """
        Compute the sigmoid loss for a batch of bi-encoder embeddings.

        Args:
            query_embeddings (Tensor[B, D]): Query vectors.
            doc_embeddings (Tensor[B, D]): Document vectors.
            offset (int): Offset for positive indices (multi-GPU).

        Returns:
            Tensor: Scalar sigmoid (pairwise logistic) loss.
        """
        # Compute in-batch similarity matrix
        scores = torch.einsum("bd,cd->bc", query_embeddings, doc_embeddings)
        batch_size, num_targets = scores.shape
        device = scores.device
        _, pos_idx = self._get_idx(batch_size, offset, device)
        if self.pos_aware_negative_filtering:
            self._filter_high_negatives(scores, pos_idx)
        all_losses = []
        for k in range(num_targets // batch_size):
            # Select the column block offset -> offset + batch_size
            curr_idx = torch.arange(offset, offset + batch_size, device=device)
            curr_scores = scores[:, curr_idx].view(-1) / self.temperature
            # Labels: -1 everywhere, +1 on the positive pairs (only in the local block, k == 0)
            labels = -torch.ones(batch_size * batch_size, device=device)
            if k == 0:
                flat_pos = (pos_idx - offset) * (batch_size + 1)
                labels[flat_pos] = 1.0
            # -log sigmoid(label * score) == softplus(-label * score)
            block_loss = F.softplus(-curr_scores * labels)
            all_losses.append(block_loss)
            # Shift the offset to the next column block (wrapping around)
            offset = (offset + batch_size) % num_targets
        return torch.stack(all_losses, dim=0).mean()
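
# Minimal usage sketch (shapes are illustrative assumptions):
#
#     loss_fn = BiSigmoidLoss(temperature=0.02)
#     q = torch.randn(8, 128)  # [B, D]
#     d = torch.randn(8, 128)  # [B, D]; with offset=0, pair (i, i) is the positive
#     loss = loss_fn(q, d)     # each (query, doc) pair is scored independently:
#                              # softplus(-s/T) for positives, softplus(+s/T) for negatives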