torch_utils.py

import gc
import logging
from typing import List, TypeVar

import torch
from torch.utils.data import Dataset

logger = logging.getLogger(__name__)

T = TypeVar("T")
def get_torch_device(device: str = "auto") -> str:
    """
    Returns the device (string) to be used by PyTorch.

    `device` arg defaults to "auto" which will use:
    - "cuda:0" if available
    - else "mps" if available
    - else "cpu".
    """
    if device == "auto":
        if torch.cuda.is_available():
            device = "cuda:0"
        elif torch.backends.mps.is_available():  # for Apple Silicon
            device = "mps"
        else:
            device = "cpu"
        logger.info(f"Using device: {device}")
    return device
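# Example (illustrative, not part of the original file): resolve the device once
# at startup and move a model to it. `MyModel` here is a hypothetical nn.Module.
#
#     device = get_torch_device()  # "cuda:0", "mps", or "cpu" depending on the host
#     model = MyModel().to(device)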
def tear_down_torch():
    """
    Teardown for PyTorch.
    Clears GPU cache for both CUDA and MPS.
    """
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    if torch.backends.mps.is_available():
        torch.mps.empty_cache()
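# Example (illustrative): release accelerator memory after dropping the last
# reference to a model, e.g. between evaluation runs.
#
#     del model
#     tear_down_torch()  # gc.collect() + empty_cache() on CUDA/MPS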
class ListDataset(Dataset[T]):
    """Simple map-style dataset wrapping a plain Python list."""

    def __init__(self, elements: List[T]):
        self.elements = elements

    def __len__(self) -> int:
        return len(self.elements)

    def __getitem__(self, idx: int) -> T:
        return self.elements[idx]
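# Example (illustrative): ListDataset lets a DataLoader batch over an ad-hoc
# list of inputs without writing a custom Dataset subclass.
#
#     from torch.utils.data import DataLoader
#
#     loader = DataLoader(ListDataset(["a query", "another query"]), batch_size=2)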
def unbind_padded_multivector_embeddings(
    embeddings: torch.Tensor,
    padding_value: float = 0.0,
    padding_side: str = "left",
) -> List[torch.Tensor]:
    """
    Removes padding elements from a batch of multivector embeddings.

    Args:
        embeddings (torch.Tensor): A tensor of shape (batch_size, seq_length, dim) with padding.
        padding_value (float): The value used for padding. Each padded token is assumed
            to be a vector where every element equals this value.
        padding_side (str): Either "left" or "right". This indicates whether the padded
            elements appear at the beginning (left) or end (right) of the sequence.

    Returns:
        List[torch.Tensor]: A list of tensors, one per sequence in the batch, where
            each tensor has shape (new_seq_length, dim) and contains only the non-padding elements.
    """
    results: List[torch.Tensor] = []

    for seq in embeddings:
        # A token counts as padding if every element of its vector equals `padding_value`.
        is_padding = torch.all(seq.eq(padding_value), dim=-1)
        non_padding_indices = (~is_padding).nonzero(as_tuple=False)

        if padding_side == "left":
            if non_padding_indices.numel() == 0:
                valid_seq = seq[:0]  # fully padded: keep an empty (0, dim) slice
            else:
                first_valid_idx = non_padding_indices[0].item()
                valid_seq = seq[first_valid_idx:]
        elif padding_side == "right":
            if non_padding_indices.numel() == 0:
                valid_seq = seq[:0]  # fully padded: keep an empty (0, dim) slice
            else:
                last_valid_idx = non_padding_indices[-1].item()
                valid_seq = seq[: last_valid_idx + 1]
        else:
            raise ValueError("padding_side must be either 'left' or 'right'.")

        results.append(valid_seq)

    return results
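# Minimal smoke test (illustrative, not part of the original file): build a
# small left-padded batch by hand and check that the padding rows are stripped.
if __name__ == "__main__":
    batch = torch.tensor(
        [
            # First sequence: one padding token on the left, two real tokens.
            [[0.0, 0.0], [1.0, 2.0], [3.0, 4.0]],
            # Second sequence: two padding tokens on the left, one real token.
            [[0.0, 0.0], [0.0, 0.0], [5.0, 6.0]],
        ]
    )
    unbound = unbind_padded_multivector_embeddings(batch, padding_value=0.0, padding_side="left")
    assert [t.shape[0] for t in unbound] == [2, 1]
    print([t.shape for t in unbound])  # [torch.Size([2, 2]), torch.Size([1, 2])]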