import gc
import logging
from typing import List, TypeVar

import torch
from torch.utils.data import Dataset

logger = logging.getLogger(__name__)

T = TypeVar("T")


def get_torch_device(device: str = "auto") -> str:
    """
    Returns the device (string) to be used by PyTorch.

    `device` arg defaults to "auto" which will use:
    - "cuda:0" if available
    - else "mps" if available
    - else "cpu".

    Any value other than "auto" is returned unchanged (and logged).
    """
    if device == "auto":
        if torch.cuda.is_available():
            device = "cuda:0"
        elif torch.backends.mps.is_available():  # for Apple Silicon
            device = "mps"
        else:
            device = "cpu"
    # Lazy %-style args: the message is only formatted if INFO is enabled.
    logger.info("Using device: %s", device)
    return device


def tear_down_torch():
    """
    Teardown for PyTorch.

    Runs the Python garbage collector first (so unreferenced tensors are
    actually freed), then clears the allocator cache for CUDA and/or MPS
    when the respective backend is available.
    """
    gc.collect()

    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    if torch.backends.mps.is_available():
        torch.mps.empty_cache()


class ListDataset(Dataset[T]):
    """Minimal map-style `Dataset` backed by an in-memory list."""

    def __init__(self, elements: List[T]):
        # The list is stored by reference, not copied.
        self.elements = elements

    def __len__(self) -> int:
        return len(self.elements)

    def __getitem__(self, idx: int) -> T:
        return self.elements[idx]


def unbind_padded_multivector_embeddings(
    embeddings: torch.Tensor,
    padding_value: float = 0.0,
    padding_side: str = "left",
) -> List[torch.Tensor]:
    """
    Removes padding elements from a batch of multivector embeddings.

    Args:
        embeddings (torch.Tensor): A tensor of shape (batch_size, seq_length, dim) with padding.
        padding_value (float): The value used for padding. Each padded token is assumed
            to be a vector where every element equals this value.
        padding_side (str): Either "left" or "right". This indicates whether the padded
            elements appear at the beginning (left) or end (right) of the sequence.
            Only the leading (left) or trailing (right) padding run is removed; tokens
            that look like padding in the interior of a sequence are kept.

    Returns:
        List[torch.Tensor]: A list of tensors, one per sequence in the batch, where each tensor
            has shape (new_seq_length, dim) and contains only the non-padding elements.
            A sequence made entirely of padding yields an empty (0, dim) tensor.

    Raises:
        ValueError: If `padding_side` is not "left" or "right" (checked even for an empty batch).
    """
    # Validate once, before iterating, so an invalid argument is rejected even when
    # the batch is empty (the loop body would never run in that case).
    if padding_side not in ("left", "right"):
        raise ValueError("padding_side must be either 'left' or 'right'.")

    results: List[torch.Tensor] = []
    for seq in embeddings:
        # A token counts as padding only if *every* element of its vector equals padding_value.
        is_padding = torch.all(seq.eq(padding_value), dim=-1)
        non_padding_indices = (~is_padding).nonzero(as_tuple=False)

        if non_padding_indices.numel() == 0:
            # Entire sequence is padding: keep an empty slice so shape[1:] is preserved.
            valid_seq = seq[:0]
        elif padding_side == "left":
            # Drop everything before the first real token.
            first_valid_idx = non_padding_indices[0].item()
            valid_seq = seq[first_valid_idx:]
        else:
            # Drop everything after the last real token.
            last_valid_idx = non_padding_indices[-1].item()
            valid_seq = seq[: last_valid_idx + 1]

        results.append(valid_seq)

    return results