# late_interaction_losses.py

import warnings

import torch
import torch.nn.functional as F  # noqa: N812
from torch.nn import CrossEntropyLoss


class ColbertModule(torch.nn.Module):
    """
    Base module for ColBERT losses, handling shared utilities and hyperparameters.

    Args:
        max_batch_size (int): Maximum batch size used to pre-allocate the index buffer.
        tau (float): Temperature for the smooth-max (log-sum-exp) approximation.
        norm_tol (float): Tolerance for the score-normalization bounds.
        filter_threshold (float): Ratio threshold for pos-aware negative filtering.
        filter_factor (float): Multiplicative factor used to down-weight high negatives.
    """

    def __init__(
        self,
        max_batch_size: int = 1024,
        tau: float = 0.1,
        norm_tol: float = 1e-3,
        filter_threshold: float = 0.95,
        filter_factor: float = 0.5,
    ):
        super().__init__()
        # Non-persistent buffer: follows the module across devices but is not saved in checkpoints.
        self.register_buffer("idx_buffer", torch.arange(max_batch_size), persistent=False)
        self.tau = tau
        self.norm_tol = norm_tol
        self.filter_threshold = filter_threshold
        self.filter_factor = filter_factor

    def _get_idx(self, batch_size: int, offset: int, device: torch.device):
        """
        Retrieve the index and positive-index tensors used by the in-batch losses.
        """
        idx = self.idx_buffer[:batch_size].to(device)
        return idx, idx + offset

    def _smooth_max(self, scores: torch.Tensor, dim: int) -> torch.Tensor:
        """
        Compute a smooth max via log-sum-exp along the given dimension.
        """
        return self.tau * torch.logsumexp(scores / self.tau, dim=dim)
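
    # The log-sum-exp "smooth max" upper-bounds the hard max and approaches it
    # as tau -> 0: max(x) <= tau * logsumexp(x / tau) <= max(x) + tau * log(n).
    # Worked example (values assumed for illustration): with tau = 0.1 and
    # x = [0.9, 0.5], 0.1 * log(exp(9.0) + exp(5.0)) ≈ 0.9018, vs. amax(x) = 0.9.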

    def _apply_normalization(self, scores: torch.Tensor, lengths: torch.Tensor) -> torch.Tensor:
        """
        Normalize scores by query lengths and check that they stay within bounds.

        Note that with smooth-max aggregation the log-sum-exp can exceed the hard
        max, so out-of-bounds scores only trigger a warning, not an error.

        Args:
            scores (Tensor): Unnormalized score matrix [B, C].
            lengths (Tensor): Query lengths [B].

        Returns:
            Tensor: Normalized scores.
        """
        if scores.ndim == 2:
            normalized = scores / lengths.unsqueeze(1)
        else:
            normalized = scores / lengths
        mn, mx = torch.aminmax(normalized)
        if mn < -self.norm_tol or mx > 1 + self.norm_tol:
            warnings.warn(
                f"Scores out of bounds after normalization: "
                f"min={mn.item():.4f}, max={mx.item():.4f}, tol={self.norm_tol}"
            )
        return normalized

    def _aggregate(
        self,
        scores_raw: torch.Tensor,
        use_smooth_max: bool,
        dim_max: int,
        dim_sum: int,
    ) -> torch.Tensor:
        """
        Aggregate token-level scores into document-level scores.

        Args:
            scores_raw (Tensor): Raw token-level score tensor.
            use_smooth_max (bool): Use smooth max (log-sum-exp) if True.
            dim_max (int): Dimension over which to take the max/logsumexp.
            dim_sum (int): Dimension to sum over after the max.
        """
        if use_smooth_max:
            return self._smooth_max(scores_raw, dim=dim_max).sum(dim=dim_sum)
        return scores_raw.amax(dim=dim_max).sum(dim=dim_sum)

    def _filter_high_negatives(self, scores: torch.Tensor, pos_idx: torch.Tensor) -> None:
        """
        Down-weight, in place, negatives whose score exceeds a fraction of the positive score.

        Args:
            scores (Tensor): In-batch score matrix [B, B].
            pos_idx (Tensor): Positive index for each query in the batch.
        """
        batch_size = scores.size(0)
        idx = self.idx_buffer[:batch_size].to(scores.device)
        pos_scores = scores[idx, pos_idx]
        thresh = self.filter_threshold * pos_scores.unsqueeze(1)
        mask = scores > thresh
        mask[idx, pos_idx] = False  # never down-weight the positives themselves
        scores[mask] *= self.filter_factor
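

# Illustrative sketch (an assumption for this write-up, not part of the API):
# how `_filter_high_negatives` rescales a toy in-batch score matrix in place.
# With the defaults (threshold 0.95, factor 0.5), any negative scoring above
# 0.95x its row's positive is halved.
def _demo_filter_high_negatives() -> torch.Tensor:
    module = ColbertModule()
    scores = torch.tensor(
        [
            [0.90, 0.88, 0.10],
            [0.20, 0.70, 0.69],
            [0.10, 0.20, 0.50],
        ]
    )
    pos_idx = torch.arange(3)  # positives on the diagonal (offset = 0)
    module._filter_high_negatives(scores, pos_idx)
    # Rows 0 and 1 each hold one near-positive negative (0.88 > 0.855 and
    # 0.69 > 0.665), which become 0.44 and 0.345; everything else is untouched.
    return scores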


class ColbertLoss(ColbertModule):
    """
    InfoNCE loss for late-interaction (ColBERT) models without explicit negatives.

    Args:
        temperature (float): Scaling factor for the logits.
        normalize_scores (bool): Normalize scores by query lengths.
        use_smooth_max (bool): Use log-sum-exp instead of amax.
        pos_aware_negative_filtering (bool): Apply pos-aware negative filtering.
    """

    def __init__(
        self,
        temperature: float = 0.02,
        normalize_scores: bool = True,
        use_smooth_max: bool = False,
        pos_aware_negative_filtering: bool = False,
        max_batch_size: int = 1024,
        tau: float = 0.1,
        norm_tol: float = 1e-3,
        filter_threshold: float = 0.95,
        filter_factor: float = 0.5,
    ):
        super().__init__(max_batch_size, tau, norm_tol, filter_threshold, filter_factor)
        self.temperature = temperature
        self.normalize_scores = normalize_scores
        self.use_smooth_max = use_smooth_max
        self.pos_aware_negative_filtering = pos_aware_negative_filtering
        self.ce_loss = CrossEntropyLoss()

    def forward(self, query_embeddings: torch.Tensor, doc_embeddings: torch.Tensor, offset: int = 0) -> torch.Tensor:
        """
        Compute the ColBERT InfoNCE loss over a batch of queries and documents.

        Args:
            query_embeddings (Tensor): (batch_size, query_length, dim)
            doc_embeddings (Tensor): positive docs (batch_size, pos_doc_length, dim)
            offset (int): Offset of the positive doc indices (multi-GPU training).

        Returns:
            Tensor: Scalar loss value.
        """
        # Infer per-query token counts from non-zero first components (padding is all-zero).
        lengths = (query_embeddings[:, :, 0] != 0).sum(dim=1)
        # Token-level similarities: [B, C, Nq, Nd] for B queries against C candidate docs.
        raw = torch.einsum("bnd,csd->bcns", query_embeddings, doc_embeddings)
        # MaxSim aggregation: max over doc tokens, then sum over query tokens -> [B, C].
        scores = self._aggregate(raw, self.use_smooth_max, dim_max=3, dim_sum=2)
        if self.normalize_scores:
            scores = self._apply_normalization(scores, lengths)
        batch_size = scores.size(0)
        _, pos_idx = self._get_idx(batch_size, offset, scores.device)
        if self.pos_aware_negative_filtering:
            self._filter_high_negatives(scores, pos_idx)
        return self.ce_loss(scores / self.temperature, pos_idx)
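

# A minimal usage sketch (shapes, seed, and normalization are assumptions for
# illustration, not part of the original file): 4 queries of 16 tokens against
# 4 positive documents of 32 tokens, dim 128, L2-normalized per token so that
# similarities behave like cosines.
def _demo_colbert_loss() -> torch.Tensor:
    torch.manual_seed(0)
    query_embeddings = F.normalize(torch.randn(4, 16, 128), dim=-1)
    doc_embeddings = F.normalize(torch.randn(4, 32, 128), dim=-1)
    loss_fn = ColbertLoss(temperature=0.02, normalize_scores=True)
    return loss_fn(query_embeddings, doc_embeddings)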


class ColbertNegativeCELoss(ColbertModule):
    """
    InfoNCE-style loss with explicit negative documents.

    Args:
        temperature (float): Scaling factor for the logits.
        normalize_scores (bool): Normalize scores by query lengths.
        use_smooth_max (bool): Use log-sum-exp instead of amax.
        pos_aware_negative_filtering (bool): Apply pos-aware negative filtering.
        in_batch_term_weight (float): Weight of the in-batch CE term (between 0 and 1);
            the explicit-negative term receives weight 1 - in_batch_term_weight.
    """

    def __init__(
        self,
        temperature: float = 0.02,
        normalize_scores: bool = True,
        use_smooth_max: bool = False,
        pos_aware_negative_filtering: bool = False,
        in_batch_term_weight: float = 0.5,
        max_batch_size: int = 1024,
        tau: float = 0.1,
        norm_tol: float = 1e-3,
        filter_threshold: float = 0.95,
        filter_factor: float = 0.5,
    ):
        super().__init__(max_batch_size, tau, norm_tol, filter_threshold, filter_factor)
        self.temperature = temperature
        self.normalize_scores = normalize_scores
        self.use_smooth_max = use_smooth_max
        self.pos_aware_negative_filtering = pos_aware_negative_filtering
        self.in_batch_term_weight = in_batch_term_weight
        self.ce_loss = CrossEntropyLoss()
        assert in_batch_term_weight >= 0, "in_batch_term_weight must be non-negative"
        assert in_batch_term_weight <= 1, "in_batch_term_weight must be at most 1"
        self.inner_loss = ColbertLoss(
            temperature=temperature,
            normalize_scores=normalize_scores,
            use_smooth_max=use_smooth_max,
            pos_aware_negative_filtering=pos_aware_negative_filtering,
            max_batch_size=max_batch_size,
            tau=tau,
            norm_tol=norm_tol,
            filter_threshold=filter_threshold,
            filter_factor=filter_factor,
        )

    def forward(
        self,
        query_embeddings: torch.Tensor,
        doc_embeddings: torch.Tensor,
        neg_doc_embeddings: torch.Tensor,
        offset: int = 0,
    ) -> torch.Tensor:
        """
        Compute the InfoNCE-style loss with explicit negatives and an optional in-batch term.

        Args:
            query_embeddings (Tensor): (batch_size, query_length, dim)
            doc_embeddings (Tensor): positive docs (batch_size, pos_doc_length, dim)
            neg_doc_embeddings (Tensor): negative docs (batch_size, num_negs, neg_doc_length, dim)
            offset (int): Positional offset of the positives (multi-GPU training).

        Returns:
            Tensor: Scalar loss value.
        """
        lengths = (query_embeddings[:, :, 0] != 0).sum(dim=1)
        # Each query is scored only against its own positive ([B, Nq, Nd])
        # and its own negatives ([B, num_negs, Nq, Nd]).
        pos_raw = torch.einsum(
            "bnd,bsd->bns", query_embeddings, doc_embeddings[offset : offset + neg_doc_embeddings.size(0)]
        )
        neg_raw = torch.einsum("bnd,blsd->blns", query_embeddings, neg_doc_embeddings)
        pos_scores = self._aggregate(pos_raw, self.use_smooth_max, dim_max=2, dim_sum=1)
        neg_scores = self._aggregate(neg_raw, self.use_smooth_max, dim_max=3, dim_sum=2)
        if self.normalize_scores:
            pos_scores = self._apply_normalization(pos_scores, lengths)
            neg_scores = self._apply_normalization(neg_scores, lengths)
        # Pairwise softplus margin between each negative and its query's positive.
        loss = F.softplus((neg_scores - pos_scores.unsqueeze(1)) / self.temperature).mean()
        if self.in_batch_term_weight > 0:
            loss_ib = self.inner_loss(query_embeddings, doc_embeddings, offset)
            loss = loss * (1 - self.in_batch_term_weight) + loss_ib * self.in_batch_term_weight
        return loss
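

# Usage sketch with explicit negatives (illustrative shapes: 3 hard negatives
# of 32 tokens per query; all names here are assumptions for the demo).
def _demo_negative_ce_loss() -> torch.Tensor:
    torch.manual_seed(0)
    queries = F.normalize(torch.randn(4, 16, 128), dim=-1)
    positives = F.normalize(torch.randn(4, 32, 128), dim=-1)
    negatives = F.normalize(torch.randn(4, 3, 32, 128), dim=-1)
    loss_fn = ColbertNegativeCELoss(in_batch_term_weight=0.5)
    return loss_fn(queries, positives, negatives)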


class ColbertPairwiseCELoss(ColbertModule):
    """
    Pairwise softplus loss for ColBERT (no explicit negatives).

    Args:
        temperature (float): Scaling factor for the logits.
        normalize_scores (bool): Normalize scores by query lengths.
        use_smooth_max (bool): Use log-sum-exp instead of amax.
        pos_aware_negative_filtering (bool): Apply pos-aware negative filtering.
    """

    def __init__(
        self,
        temperature: float = 1.0,
        normalize_scores: bool = True,
        use_smooth_max: bool = False,
        pos_aware_negative_filtering: bool = False,
        max_batch_size: int = 1024,
        tau: float = 0.1,
        norm_tol: float = 1e-3,
        filter_threshold: float = 0.95,
        filter_factor: float = 0.5,
    ):
        super().__init__(max_batch_size, tau, norm_tol, filter_threshold, filter_factor)
        self.temperature = temperature
        self.normalize_scores = normalize_scores
        self.use_smooth_max = use_smooth_max
        self.pos_aware_negative_filtering = pos_aware_negative_filtering

    def forward(self, query_embeddings: torch.Tensor, doc_embeddings: torch.Tensor, offset: int = 0) -> torch.Tensor:
        """
        Compute the pairwise softplus loss over in-batch document pairs.

        Args:
            query_embeddings (Tensor): (batch_size, query_length, dim)
            doc_embeddings (Tensor): positive docs (batch_size, pos_doc_length, dim)
            offset (int): Positional offset of the positives (multi-GPU training).

        Returns:
            Tensor: Scalar loss value.
        """
        lengths = (query_embeddings[:, :, 0] != 0).sum(dim=1)
        raw = torch.einsum("bnd,csd->bcns", query_embeddings, doc_embeddings)
        scores = self._aggregate(raw, self.use_smooth_max, dim_max=3, dim_sum=2)
        if self.normalize_scores:
            scores = self._apply_normalization(scores, lengths)
        batch_size = scores.size(0)
        _, pos_idx = self._get_idx(batch_size, offset, scores.device)
        if self.pos_aware_negative_filtering:
            self._filter_high_negatives(scores, pos_idx)
        # Hardest in-batch negative: the best-scoring non-positive per row,
        # obtained from the top-2 row values.
        pos_scores = scores.diagonal(offset=offset)
        top2 = scores.topk(2, dim=1).values
        neg_scores = torch.where(top2[:, 0] == pos_scores, top2[:, 1], top2[:, 0])
        return F.softplus((neg_scores - pos_scores) / self.temperature).mean()
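

# Toy sketch (values assumed) of the hardest-negative selection used in
# ColbertPairwiseCELoss.forward: top-2 over each row yields the best
# non-positive score whether or not the positive ranks first.
def _demo_pairwise_hard_negative() -> torch.Tensor:
    scores = torch.tensor(
        [
            [0.9, 0.6, 0.3],
            [0.2, 0.5, 0.8],
            [0.4, 0.1, 0.7],
        ]
    )
    pos_scores = scores.diagonal()  # [0.9, 0.5, 0.7]
    top2 = scores.topk(2, dim=1).values
    # Row 1's positive (0.5) is beaten by 0.8, so 0.8 is kept as the negative;
    # in rows 0 and 2 the runner-up is kept instead.
    neg_scores = torch.where(top2[:, 0] == pos_scores, top2[:, 1], top2[:, 0])
    return neg_scores  # tensor([0.6, 0.8, 0.4])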


class ColbertPairwiseNegativeCELoss(ColbertModule):
    """
    Pairwise softplus loss with explicit negatives and an optional in-batch term.

    Args:
        temperature (float): Scaling factor for the logits.
        normalize_scores (bool): Normalize scores by query lengths.
        use_smooth_max (bool): Use log-sum-exp instead of amax.
        pos_aware_negative_filtering (bool): Apply pos-aware negative filtering.
        in_batch_term_weight (float): Weight of the in-batch pairwise term (between 0 and 1);
            the explicit-negative term receives weight 1 - in_batch_term_weight.
    """

    def __init__(
        self,
        temperature: float = 0.02,
        normalize_scores: bool = True,
        use_smooth_max: bool = False,
        pos_aware_negative_filtering: bool = False,
        in_batch_term_weight: float = 0.5,
        max_batch_size: int = 1024,
        tau: float = 0.1,
        norm_tol: float = 1e-3,
        filter_threshold: float = 0.95,
        filter_factor: float = 0.5,
    ):
        super().__init__(max_batch_size, tau, norm_tol, filter_threshold, filter_factor)
        self.temperature = temperature
        self.normalize_scores = normalize_scores
        self.use_smooth_max = use_smooth_max
        self.pos_aware_negative_filtering = pos_aware_negative_filtering
        self.in_batch_term_weight = in_batch_term_weight
        assert in_batch_term_weight >= 0, "in_batch_term_weight must be non-negative"
        assert in_batch_term_weight <= 1, "in_batch_term_weight must be at most 1"
        self.inner_pairwise = ColbertPairwiseCELoss(
            temperature=temperature,
            normalize_scores=normalize_scores,
            use_smooth_max=use_smooth_max,
            pos_aware_negative_filtering=pos_aware_negative_filtering,
            max_batch_size=max_batch_size,
            tau=tau,
            norm_tol=norm_tol,
            filter_threshold=filter_threshold,
            filter_factor=filter_factor,
        )

    def forward(
        self,
        query_embeddings: torch.Tensor,
        doc_embeddings: torch.Tensor,
        neg_doc_embeddings: torch.Tensor,
        offset: int = 0,
    ) -> torch.Tensor:
        """
        Compute the pairwise softplus loss with explicit negatives and an optional in-batch term.

        Args:
            query_embeddings (Tensor): (batch_size, query_length, dim)
            doc_embeddings (Tensor): positive docs (batch_size, pos_doc_length, dim)
            neg_doc_embeddings (Tensor): negative docs (batch_size, num_negs, neg_doc_length, dim)
            offset (int): Positional offset of the positives (multi-GPU training).

        Returns:
            Tensor: Scalar loss value.
        """
        lengths = (query_embeddings[:, :, 0] != 0).sum(dim=1)
        pos_raw = torch.einsum(
            "bnd,bld->bnl", query_embeddings, doc_embeddings[offset : offset + query_embeddings.size(0)]
        )
        neg_raw = torch.einsum("bnd,bsld->bsnl", query_embeddings, neg_doc_embeddings)  # [B, num_negs, Nq, Lneg]
        pos_scores = self._aggregate(pos_raw, self.use_smooth_max, dim_max=2, dim_sum=1)
        neg_scores = self._aggregate(neg_raw, self.use_smooth_max, dim_max=3, dim_sum=2)
        if self.normalize_scores:
            pos_scores = self._apply_normalization(pos_scores, lengths)
            neg_scores = self._apply_normalization(neg_scores, lengths)
        loss = F.softplus((neg_scores - pos_scores.unsqueeze(1)) / self.temperature).mean()
        if self.in_batch_term_weight > 0:
            loss_ib = self.inner_pairwise(query_embeddings, doc_embeddings, offset)
            loss = loss * (1 - self.in_batch_term_weight) + loss_ib * self.in_batch_term_weight
        return loss


class ColbertSigmoidLoss(ColbertModule):
    """
    Sigmoid (pointwise binary) loss for ColBERT over all in-batch query-document pairs.

    Args:
        temperature (float): Scaling factor for the logits.
        normalize_scores (bool): Normalize scores by query lengths.
        use_smooth_max (bool): Use log-sum-exp instead of amax.
        pos_aware_negative_filtering (bool): Apply pos-aware negative filtering.
    """

    def __init__(
        self,
        temperature: float = 0.02,
        normalize_scores: bool = True,
        use_smooth_max: bool = False,
        pos_aware_negative_filtering: bool = False,
        max_batch_size: int = 1024,
        tau: float = 0.1,
        norm_tol: float = 1e-3,
        filter_threshold: float = 0.95,
        filter_factor: float = 0.5,
    ):
        super().__init__(max_batch_size, tau, norm_tol, filter_threshold, filter_factor)
        self.temperature = temperature
        self.normalize_scores = normalize_scores
        self.use_smooth_max = use_smooth_max
        self.pos_aware_negative_filtering = pos_aware_negative_filtering

    def forward(self, query_embeddings: torch.Tensor, doc_embeddings: torch.Tensor, offset: int = 0) -> torch.Tensor:
        """
        Compute the sigmoid loss over in-batch positive and negative pairs.

        Args:
            query_embeddings (Tensor): (batch_size, query_length, dim)
            doc_embeddings (Tensor): positive docs (batch_size, pos_doc_length, dim)
            offset (int): Positional offset of the positives (multi-GPU training).

        Returns:
            Tensor: Scalar loss value.
        """
        lengths = (query_embeddings[:, :, 0] != 0).sum(dim=1)
        raw = torch.einsum("bnd,csd->bcns", query_embeddings, doc_embeddings)
        scores = self._aggregate(raw, self.use_smooth_max, dim_max=3, dim_sum=2)
        if self.normalize_scores:
            scores = self._apply_normalization(scores, lengths)
        batch_size = scores.size(0)
        idx, pos_idx = self._get_idx(batch_size, offset, scores.device)
        if self.pos_aware_negative_filtering:
            self._filter_high_negatives(scores, pos_idx)
        # Build a flat label-sign vector: +1 at each positive pair (idx, pos_idx),
        # whose flat index is idx * num_docs + pos_idx, and -1 everywhere else.
        num_docs = scores.size(1)
        flat_pos = idx * num_docs + pos_idx
        labels = -torch.ones(scores.numel(), device=scores.device)
        labels[flat_pos] = 1.0
        # Flatten the scores and apply the binary logistic loss:
        # -log(sigmoid(label * score / temperature)) = softplus(-label * score / temperature).
        scores = scores.view(-1) / self.temperature
        return F.softplus(-scores * labels).mean()
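

# Hedged smoke test (illustrative only, an addition for this write-up): run the
# sketches above on random or toy data and check that the outputs are finite.
if __name__ == "__main__":
    for demo in (
        _demo_filter_high_negatives,
        _demo_colbert_loss,
        _demo_negative_ce_loss,
        _demo_pairwise_hard_negative,
    ):
        out = demo()
        assert torch.isfinite(out).all(), demo.__name__
        print(f"{demo.__name__}: {out}")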