import gc
import logging
from typing import List, TypeVar

import torch
from torch.utils.data import Dataset

logger = logging.getLogger(__name__)

T = TypeVar("T")

def get_torch_device(device: str = "auto") -> str:
    """
    Return the device (string) to be used by PyTorch.

    `device` arg defaults to "auto" which will use:
    - "cuda:0" if available
    - else "mps" if available
    - else "cpu".
    """
    if device == "auto":
        if torch.cuda.is_available():
            device = "cuda:0"
        elif torch.backends.mps.is_available():  # for Apple Silicon
            device = "mps"
        else:
            device = "cpu"
        logger.info(f"Using device: {device}")
    return device
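
# Usage sketch for `get_torch_device` (illustrative only; `MyModel` is a
# hypothetical nn.Module, not defined in this module):
#
#     device = get_torch_device()      # e.g. "cuda:0", "mps", or "cpu"
#     model = MyModel().to(device)
#     inputs = inputs.to(device)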

def tear_down_torch():
    """
    Teardown for PyTorch.
    Clears the GPU cache for both CUDA and MPS.
    """
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    if torch.backends.mps.is_available():
        torch.mps.empty_cache()
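
# Usage sketch for `tear_down_torch`: call it after dropping the last
# reference to a large model so cached GPU memory is released back to the
# driver (`model` here is illustrative):
#
#     del model
#     tear_down_torch()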

class ListDataset(Dataset[T]):
    """A minimal map-style `Dataset` backed by an in-memory list of elements."""

    def __init__(self, elements: List[T]):
        self.elements = elements

    def __len__(self) -> int:
        return len(self.elements)

    def __getitem__(self, idx: int) -> T:
        return self.elements[idx]
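
# Usage sketch for `ListDataset`: wrap an in-memory list so it can feed a
# standard DataLoader (batching and shuffling then come for free):
#
#     from torch.utils.data import DataLoader
#
#     dataset = ListDataset(["doc 1", "doc 2", "doc 3"])
#     loader = DataLoader(dataset, batch_size=2)
#     for batch in loader:
#         ...  # batch is a list of up to 2 strings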

def unbind_padded_multivector_embeddings(
    embeddings: torch.Tensor,
    padding_value: float = 0.0,
    padding_side: str = "left",
) -> List[torch.Tensor]:
    """
    Remove padding elements from a batch of multivector embeddings.

    Args:
        embeddings (torch.Tensor): A tensor of shape (batch_size, seq_length, dim) with padding.
        padding_value (float): The value used for padding. Each padded token is assumed
            to be a vector where every element equals this value.
        padding_side (str): Either "left" or "right". Indicates whether the padded
            elements appear at the beginning (left) or end (right) of the sequence.

    Returns:
        List[torch.Tensor]: A list of tensors, one per sequence in the batch, where
            each tensor has shape (new_seq_length, dim) and contains only the
            non-padding elements.
    """
    results: List[torch.Tensor] = []

    for seq in embeddings:
        # A token counts as padding iff every element of its vector equals `padding_value`.
        is_padding = torch.all(seq.eq(padding_value), dim=-1)
        non_padding_indices = (~is_padding).nonzero(as_tuple=False)

        if padding_side == "left":
            if non_padding_indices.numel() == 0:
                valid_seq = seq[:0]  # fully padded: keep an empty sequence
            else:
                # Keep everything from the first non-padding token onward.
                first_valid_idx = non_padding_indices[0].item()
                valid_seq = seq[first_valid_idx:]
        elif padding_side == "right":
            if non_padding_indices.numel() == 0:
                valid_seq = seq[:0]  # fully padded: keep an empty sequence
            else:
                # Keep everything up to and including the last non-padding token.
                last_valid_idx = non_padding_indices[-1].item()
                valid_seq = seq[: last_valid_idx + 1]
        else:
            raise ValueError("padding_side must be either 'left' or 'right'.")

        results.append(valid_seq)

    return results
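
# Usage sketch for `unbind_padded_multivector_embeddings` with made-up values:
#
#     batch = torch.zeros(2, 4, 3)   # 2 left-padded sequences of length 4, dim 3
#     batch[0, 2:] = 1.0             # sequence 0: 2 valid tokens
#     batch[1, 1:] = 2.0             # sequence 1: 3 valid tokens
#     seqs = unbind_padded_multivector_embeddings(batch, padding_side="left")
#     assert [s.shape[0] for s in seqs] == [2, 3]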