import torch
import torch.nn.functional as F  # noqa: N812
from torch.nn import CrossEntropyLoss


class BiEncoderModule(torch.nn.Module):
    """
    Base module for bi-encoder losses, handling buffer indexing and filtering hyperparameters.

    Args:
        max_batch_size (int): Maximum batch size for the pre-allocated index buffer.
        temperature (float): Scaling factor for logits (must be > 0).
        filter_threshold (float): Fraction of the positive score above which negatives are down-weighted.
        filter_factor (float): Multiplicative factor applied to filtered negative scores.
    """

    def __init__(
        self,
        max_batch_size: int = 1024,
        temperature: float = 0.02,
        filter_threshold: float = 0.95,
        filter_factor: float = 0.5,
    ):
        super().__init__()
        if temperature <= 0:
            raise ValueError("Temperature must be strictly positive")
        self.register_buffer("idx_buffer", torch.arange(max_batch_size), persistent=False)
        self.temperature = temperature
        self.filter_threshold = filter_threshold
        self.filter_factor = filter_factor

    def _get_idx(self, batch_size: int, offset: int, device: torch.device):
        """
        Generate index tensors for in-batch cross-entropy.

        Args:
            batch_size (int): Number of queries/docs in the batch.
            offset (int): Offset to apply for multi-GPU indexing.
            device (torch.device): Target device of the indices.

        Returns:
            Tuple[Tensor, Tensor]: (idx, pos_idx), both of shape [batch_size].
        """
        idx = self.idx_buffer[:batch_size].to(device)
        return idx, idx + offset

    def _filter_high_negatives(self, scores: torch.Tensor, pos_idx: torch.Tensor):
        """
        In-place down-weighting of "too-high" in-batch negative scores.

        Args:
            scores (Tensor[B, B]): In-batch similarity matrix.
            pos_idx (Tensor[B]): Positive index for each query.
        """
        batch_size = scores.size(0)
        idx = self.idx_buffer[:batch_size].to(scores.device)
        pos_scores = scores[idx, pos_idx]
        thresh = self.filter_threshold * pos_scores.unsqueeze(1)
        mask = scores > thresh
        mask[idx, pos_idx] = False
        scores[mask] *= self.filter_factor
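

# Illustrative sketch (an addition for demonstration, not part of the original API):
# shows the effect of `_filter_high_negatives` on a hand-built 3x3 score matrix.
# The values and the helper name `_example_filter_high_negatives` are assumptions.
def _example_filter_high_negatives() -> torch.Tensor:
    module = BiEncoderModule(max_batch_size=3, filter_threshold=0.95, filter_factor=0.5)
    scores = torch.tensor(
        [
            [1.00, 0.98, 0.10],  # 0.98 > 0.95 * 1.00 -> down-weighted to 0.49
            [0.20, 1.00, 0.30],
            [0.96, 0.40, 1.00],  # 0.96 > 0.95 * 1.00 -> down-weighted to 0.48
        ]
    )
    pos_idx = torch.arange(3)  # positives on the diagonal
    module._filter_high_negatives(scores, pos_idx)  # modifies `scores` in place
    return scores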


class BiEncoderLoss(BiEncoderModule):
    """
    InfoNCE loss for bi-encoders without explicit negatives.

    Args:
        temperature (float): Scaling factor for logits.
        pos_aware_negative_filtering (bool): Apply in-batch negative filtering if True.
        max_batch_size (int): Max batch size for index buffer caching.
        filter_threshold (float): Threshold ratio for negative filtering.
        filter_factor (float): Factor to down-weight filtered negatives.
    """

    def __init__(
        self,
        temperature: float = 0.02,
        pos_aware_negative_filtering: bool = False,
        max_batch_size: int = 1024,
        filter_threshold: float = 0.95,
        filter_factor: float = 0.5,
    ):
        super().__init__(max_batch_size, temperature, filter_threshold, filter_factor)
        self.pos_aware_negative_filtering = pos_aware_negative_filtering
        self.ce_loss = CrossEntropyLoss()

    def forward(
        self,
        query_embeddings: torch.Tensor,
        doc_embeddings: torch.Tensor,
        offset: int = 0,
    ) -> torch.Tensor:
        """
        Compute the InfoNCE loss over a batch of bi-encoder embeddings.

        Args:
            query_embeddings (Tensor[B, D]): Query vectors.
            doc_embeddings (Tensor[B, D]): Document vectors.
            offset (int): Offset for positive indices (multi-GPU).

        Returns:
            Tensor: Scalar cross-entropy loss.
        """
        # Compute in-batch similarity matrix
        scores = torch.einsum("bd,cd->bc", query_embeddings, doc_embeddings)
        batch_size = scores.size(0)
        _, pos_idx = self._get_idx(batch_size, offset, scores.device)
        if self.pos_aware_negative_filtering:
            self._filter_high_negatives(scores, pos_idx)
        return self.ce_loss(scores / self.temperature, pos_idx)
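

# Illustrative usage sketch (an addition; batch size 8 and dimension 128 are
# arbitrary demo assumptions, not defaults of this module): row i of
# `doc_embeddings` is the positive document for query i, so offset stays 0.
def _example_bi_encoder_loss() -> torch.Tensor:
    loss_fn = BiEncoderLoss(temperature=0.02, pos_aware_negative_filtering=True)
    query_embeddings = F.normalize(torch.randn(8, 128), dim=-1)
    doc_embeddings = F.normalize(torch.randn(8, 128), dim=-1)
    return loss_fn(query_embeddings, doc_embeddings)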


class BiPairedEncoderLoss(BiEncoderModule):
    """
    Symmetric (paired) InfoNCE loss for bi-encoders without explicit negatives,
    averaging the query-to-document and document-to-query cross-entropy terms.

    Args:
        temperature (float): Scaling factor for logits.
        pos_aware_negative_filtering (bool): Apply in-batch negative filtering if True.
        max_batch_size (int): Max batch size for index buffer caching.
        filter_threshold (float): Threshold ratio for negative filtering.
        filter_factor (float): Factor to down-weight filtered negatives.
    """

    def __init__(
        self,
        temperature: float = 0.02,
        pos_aware_negative_filtering: bool = False,
        max_batch_size: int = 1024,
        filter_threshold: float = 0.95,
        filter_factor: float = 0.5,
    ):
        super().__init__(max_batch_size, temperature, filter_threshold, filter_factor)
        self.pos_aware_negative_filtering = pos_aware_negative_filtering
        self.ce_loss = CrossEntropyLoss()

    def forward(
        self,
        query_embeddings: torch.Tensor,
        doc_embeddings: torch.Tensor,
        offset: int = 0,
    ) -> torch.Tensor:
        """
        Compute the InfoNCE loss over a batch of bi-encoder embeddings.

        Args:
            query_embeddings (Tensor[B, D]): Query vectors.
            doc_embeddings (Tensor[B, D]): Document vectors.
            offset (int): Offset for positive indices (multi-GPU).

        Returns:
            Tensor: Scalar cross-entropy loss.
        """
        # Compute in-batch similarity matrix
        scores = torch.einsum("bd,cd->bc", query_embeddings, doc_embeddings)
        batch_size = scores.size(0)
        idx, pos_idx = self._get_idx(batch_size, offset, scores.device)
        if self.pos_aware_negative_filtering:
            self._filter_high_negatives(scores, pos_idx)
        q2t = self.ce_loss(scores / self.temperature, pos_idx)
        # Document-to-query direction. Assumption: the score matrix is square and the
        # positive for document column j is query j (i.e. offset == 0 on this shard),
        # so the targets for the transposed matrix are simply `idx`.
        t2q = self.ce_loss(scores.T / self.temperature, idx)
        return (q2t + t2q) / 2.0


class BiNegativeCELoss(BiEncoderModule):
    """
    InfoNCE loss with explicit negative samples and optional in-batch term.

    Args:
        temperature (float): Scaling factor for logits.
        in_batch_term_weight (float): Weight for in-batch cross-entropy term (0 to 1).
        pos_aware_negative_filtering (bool): Apply in-batch negative filtering.
        max_batch_size (int): Max batch size for index buffer.
        filter_threshold (float): Threshold ratio for filtering.
        filter_factor (float): Factor to down-weight filtered negatives.
    """

    def __init__(
        self,
        temperature: float = 0.02,
        in_batch_term_weight: float = 0.5,
        pos_aware_negative_filtering: bool = False,
        max_batch_size: int = 1024,
        filter_threshold: float = 0.95,
        filter_factor: float = 0.5,
    ):
        super().__init__(max_batch_size, temperature, filter_threshold, filter_factor)
        assert 0 <= in_batch_term_weight <= 1, "in_batch_term_weight must be between 0 and 1"
        self.in_batch_term_weight = in_batch_term_weight
        self.pos_aware_negative_filtering = pos_aware_negative_filtering
        self.ce_loss = CrossEntropyLoss()
        # Inner InfoNCE for the in-batch term
        self.inner_loss = BiEncoderLoss(
            temperature=temperature,
            pos_aware_negative_filtering=pos_aware_negative_filtering,
            max_batch_size=max_batch_size,
            filter_threshold=filter_threshold,
            filter_factor=filter_factor,
        )

    def forward(
        self,
        query_embeddings: torch.Tensor,
        doc_embeddings: torch.Tensor,
        neg_doc_embeddings: torch.Tensor,
        offset: int = 0,
    ) -> torch.Tensor:
        """
        Compute softplus(neg_score - pos_score) plus optional in-batch CE.

        Args:
            query_embeddings (Tensor[B, D]): Query vectors.
            doc_embeddings (Tensor[B, D]): Positive document vectors.
            neg_doc_embeddings (Tensor[B, N, D]): Negative document vectors.
            offset (int): Offset for in-batch CE positives.

        Returns:
            Tensor: Scalar loss value.
        """
        # Dot product only for matching pairs
        pos_scores = (query_embeddings * doc_embeddings[offset : offset + neg_doc_embeddings.size(0)]).sum(dim=1)
        pos_scores /= self.temperature
        neg_scores = torch.einsum("bd,bnd->bn", query_embeddings, neg_doc_embeddings) / self.temperature
        loss = F.softplus(neg_scores - pos_scores.unsqueeze(1)).mean()
        if self.in_batch_term_weight > 0:
            loss_ib = self.inner_loss(query_embeddings, doc_embeddings, offset)
            loss = loss * (1 - self.in_batch_term_weight) + loss_ib * self.in_batch_term_weight
        return loss
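

# Illustrative usage sketch with explicit hard negatives (an addition; shapes are
# demo assumptions): each of the 8 queries gets one aligned positive document and
# 4 explicit negative documents.
def _example_bi_negative_ce_loss() -> torch.Tensor:
    loss_fn = BiNegativeCELoss(temperature=0.02, in_batch_term_weight=0.5)
    query_embeddings = F.normalize(torch.randn(8, 128), dim=-1)
    doc_embeddings = F.normalize(torch.randn(8, 128), dim=-1)
    neg_doc_embeddings = F.normalize(torch.randn(8, 4, 128), dim=-1)
    return loss_fn(query_embeddings, doc_embeddings, neg_doc_embeddings)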


class BiPairwiseCELoss(BiEncoderModule):
    """
    Pairwise softplus loss mining the hardest in-batch negative.

    Args:
        temperature (float): Scaling factor for logits.
        pos_aware_negative_filtering (bool): Filter high negatives before mining.
        max_batch_size (int): Maximum batch size for indexing.
        filter_threshold (float): Threshold for pos-aware filtering.
        filter_factor (float): Factor to down-weight filtered negatives.
    """

    def __init__(
        self,
        temperature: float = 0.02,
        pos_aware_negative_filtering: bool = False,
        max_batch_size: int = 1024,
        filter_threshold: float = 0.95,
        filter_factor: float = 0.5,
    ):
        super().__init__(max_batch_size, temperature, filter_threshold, filter_factor)
        self.pos_aware_negative_filtering = pos_aware_negative_filtering

    def forward(
        self,
        query_embeddings: torch.Tensor,
        doc_embeddings: torch.Tensor,
        offset: int = 0,
    ) -> torch.Tensor:
        """
        Compute softplus(hardest_neg - pos), where hardest_neg is the highest off-diagonal score.

        Args:
            query_embeddings (Tensor[B, D]): Query vectors.
            doc_embeddings (Tensor[B, D]): Document vectors.

        Returns:
            Tensor: Scalar loss value.
        """
        scores = torch.einsum("bd,cd->bc", query_embeddings, doc_embeddings)
        batch_size = scores.size(0)
        idx = self.idx_buffer[:batch_size].to(scores.device)
        # Positives are assumed to lie on the diagonal; `offset` is not applied here.
        pos = scores.diagonal()
        if self.pos_aware_negative_filtering:
            self._filter_high_negatives(scores, idx)
        # Hardest in-batch negative: the highest score per row that is not the positive score.
        top2 = scores.topk(2, dim=1).values
        neg = torch.where(top2[:, 0] == pos, top2[:, 1], top2[:, 0])
        return F.softplus((neg - pos) / self.temperature).mean()
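

# Illustrative sketch of the hardest-negative mining rule above (an addition; the
# score values are made-up assumptions): for each row, the mined negative is the
# largest score that is not the diagonal positive.
def _example_hardest_negative_mining() -> torch.Tensor:
    scores = torch.tensor(
        [
            [0.90, 0.70, 0.20],
            [0.10, 0.80, 0.85],
            [0.30, 0.40, 0.60],
        ]
    )
    pos = scores.diagonal()  # [0.90, 0.80, 0.60]
    top2 = scores.topk(2, dim=1).values
    neg = torch.where(top2[:, 0] == pos, top2[:, 1], top2[:, 0])  # [0.70, 0.85, 0.40]
    return F.softplus((neg - pos) / 0.02).mean()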


class BiPairwiseNegativeCELoss(BiEncoderModule):
    """
    Pairwise softplus loss with explicit negatives and optional in-batch term.

    Args:
        temperature (float): Scaling factor for logits.
        in_batch_term_weight (float): Weight for in-batch cross-entropy term (0 to 1).
        max_batch_size (int): Maximum batch size for indexing.
        filter_threshold (float): Threshold for pos-aware filtering.
        filter_factor (float): Factor to down-weight filtered negatives.
    """

    def __init__(
        self,
        temperature: float = 0.02,
        in_batch_term_weight: float = 0.5,
        max_batch_size: int = 1024,
        filter_threshold: float = 0.95,
        filter_factor: float = 0.5,
    ):
        super().__init__(max_batch_size, temperature, filter_threshold, filter_factor)
        assert 0 <= in_batch_term_weight <= 1, "in_batch_term_weight must be between 0 and 1"
        self.in_batch_term_weight = in_batch_term_weight
        self.inner_pairwise = BiPairwiseCELoss(
            temperature=temperature,
            pos_aware_negative_filtering=False,
            max_batch_size=max_batch_size,
            filter_threshold=filter_threshold,
            filter_factor=filter_factor,
        )

    def forward(
        self,
        query_embeddings: torch.Tensor,
        doc_embeddings: torch.Tensor,
        neg_doc_embeddings: torch.Tensor,
        offset: int = 0,
    ) -> torch.Tensor:
        """
        Compute softplus(explicit_neg - pos), plus an optional pairwise in-batch term.

        Args:
            query_embeddings (Tensor[B, D]): Query vectors.
            doc_embeddings (Tensor[B, D]): Positive document vectors.
            neg_doc_embeddings (Tensor[B, N, D]): Negative document vectors.

        Returns:
            Tensor: Scalar loss value.
        """
        # Dot product for matching pairs only
        pos = (query_embeddings * doc_embeddings[offset : offset + query_embeddings.size(0)]).sum(dim=1)  # B
        neg = (query_embeddings.unsqueeze(1) * neg_doc_embeddings).sum(dim=2)  # B x N
        loss = F.softplus((neg - pos.unsqueeze(1)) / self.temperature).mean()
        if self.in_batch_term_weight > 0:
            loss_ib = self.inner_pairwise(query_embeddings, doc_embeddings, offset=offset)
            loss = loss * (1 - self.in_batch_term_weight) + loss_ib * self.in_batch_term_weight
        return loss


class BiSigmoidLoss(BiEncoderModule):
    """
    Sigmoid loss for bi-encoders with in-batch negatives.

    Args:
        temperature (float): Scaling factor for logits.
        pos_aware_negative_filtering (bool): Apply in-batch negative filtering if True.
        max_batch_size (int): Max batch size for index buffer caching.
        filter_threshold (float): Threshold ratio for negative filtering.
        filter_factor (float): Factor to down-weight filtered negatives.
    """

    def __init__(
        self,
        temperature: float = 0.02,
        pos_aware_negative_filtering: bool = False,
        max_batch_size: int = 1024,
        filter_threshold: float = 0.95,
        filter_factor: float = 0.5,
    ):
        super().__init__(max_batch_size, temperature, filter_threshold, filter_factor)
        self.pos_aware_negative_filtering = pos_aware_negative_filtering

    def forward(self, query_embeddings: torch.Tensor, doc_embeddings: torch.Tensor, offset: int = 0) -> torch.Tensor:
        """
        Compute the sigmoid loss for a batch of bi-encoder embeddings.

        Args:
            query_embeddings (Tensor[B, D]): Query vectors.
            doc_embeddings (Tensor[B, D]): Document vectors.
            offset (int): Offset for positive indices (multi-GPU).

        Returns:
            Tensor: Scalar sigmoid loss.
        """
        # Compute in-batch similarity matrix
        scores = torch.einsum("bd,cd->bc", query_embeddings, doc_embeddings)
        batch_size, num_targets = scores.shape
        device = scores.device
        _, pos_idx = self._get_idx(batch_size, offset, device)
        if self.pos_aware_negative_filtering:
            self._filter_high_negatives(scores, pos_idx)
        all_losses = []
        for k in range(num_targets // batch_size):
            # Column indices for the current block: offset -> offset + batch_size
            curr_idx = torch.arange(offset, offset + batch_size, device=device)
            # Keep only the scores for the current block
            curr_scores = scores[:, curr_idx].view(-1) / self.temperature
            # Labels: -1 everywhere, +1 on the aligned positives of the first block
            labels = -torch.ones(batch_size * batch_size, device=device)
            if k == 0:
                flat_pos = (pos_idx - offset) * (batch_size + 1)
                labels[flat_pos] = 1.0
            # Pairwise sigmoid loss: softplus(-label * score) == -log(sigmoid(label * score))
            block_loss = F.softplus(-curr_scores * labels)
            all_losses.append(block_loss)
            # Shift the offset to the next block of columns
            offset = (offset + batch_size) % num_targets
        return torch.stack(all_losses, dim=0).mean()
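

# Illustrative usage sketch (an addition; shapes and the two-shard setup are demo
# assumptions): with documents gathered from two shards, `num_targets` is twice the
# local batch size and `offset` points at this shard's block of positive columns.
def _example_bi_sigmoid_loss() -> torch.Tensor:
    loss_fn = BiSigmoidLoss(temperature=0.02)
    query_embeddings = F.normalize(torch.randn(8, 128), dim=-1)
    doc_embeddings = F.normalize(torch.randn(16, 128), dim=-1)  # e.g. gathered across 2 shards
    return loss_fn(query_embeddings, doc_embeddings, offset=8)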