# processing_utils.py

import importlib.util
import logging
from abc import ABC, abstractmethod
from typing import ClassVar, List, Optional, Tuple, Union

import torch
from PIL import Image
from transformers import BatchEncoding, BatchFeature

try:
    from fast_plaid import search
except ImportError:
    logging.info(
        "FastPlaid is not installed. If you want to use it, install it with `pip install --no-deps fast-plaid fastkmeans`."
    )

from colpali_engine.utils.torch_utils import get_torch_device


class BaseVisualRetrieverProcessor(ABC):
    """
    Base class for visual retriever processors.
    """

    query_prefix: ClassVar[str] = ""  # Default prefix for queries. Override in subclasses if needed.

    @abstractmethod
    def process_images(
        self,
        images: List[Image.Image],
    ) -> Union[BatchFeature, BatchEncoding]:
        """
        Process a list of images into a format suitable for the model.

        Args:
            images (List[Image.Image]): List of images to process.

        Returns:
            Union[BatchFeature, BatchEncoding]: Processed images.
        """
        pass

    @abstractmethod
    def process_texts(self, texts: List[str]) -> Union[BatchFeature, BatchEncoding]:
        """
        Process a list of texts into a format suitable for the model.

        Args:
            texts: List of input texts.

        Returns:
            Union[BatchFeature, BatchEncoding]: Processed texts.
        """
        pass

    def process_queries(
        self,
        texts: Optional[List[str]] = None,
        queries: Optional[List[str]] = None,
        max_length: int = 50,
        contexts: Optional[List[str]] = None,
        suffix: Optional[str] = None,
    ) -> Union[BatchFeature, BatchEncoding]:
        """
        Process a list of queries into a format suitable for the model.

        Args:
            texts: List of input texts.
            queries: Alias for `texts`. Only one of `texts` or `queries` may be provided.
            max_length: [DEPRECATED] Maximum length of the text. Currently unused.
            contexts: Optional contexts associated with the queries. Currently unused.
            suffix: Suffix to append to each text. If None, the default query augmentation token is used.

        Returns:
            Union[BatchFeature, BatchEncoding]: Processed texts.

        NOTE: This method is deprecated. Use `process_texts` instead. It is kept to maintain
        backward compatibility with the ViDoRe evaluator.
        """
        if texts and queries:
            raise ValueError("Only one of 'texts' or 'queries' should be provided.")
        if queries is not None:
            texts = queries
        elif texts is None:
            raise ValueError("No texts or queries provided.")

        if suffix is None:
            suffix = self.query_augmentation_token * 10

        # Add the query prefix and suffix to each text
        texts = [self.query_prefix + text + suffix for text in texts]

        return self.process_texts(texts=texts)
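
    # Illustrative sketch (hypothetical subclass values, for exposition only):
    # with `query_prefix = "Query: "` and `query_augmentation_token = "<unused0>"`,
    # `process_queries(queries=["climate change"])` tokenizes the augmented string
    #
    #     "Query: climate change" + "<unused0>" * 10
    #
    # i.e. ten augmentation tokens are appended as soft placeholders the model can
    # attend to at query time (ColBERT-style query augmentation).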

    @abstractmethod
    def score(
        self,
        qs: Union[torch.Tensor, List[torch.Tensor]],
        ps: Union[torch.Tensor, List[torch.Tensor]],
        device: Optional[Union[str, torch.device]] = None,
        **kwargs,
    ) -> torch.Tensor:
        pass

    @staticmethod
    def score_single_vector(
        qs: Union[torch.Tensor, List[torch.Tensor]],
        ps: Union[torch.Tensor, List[torch.Tensor]],
        device: Optional[Union[str, torch.device]] = None,
    ) -> torch.Tensor:
        """
        Compute the dot product score for the given single-vector query and passage embeddings.
        """
        device = device or get_torch_device("auto")

        if isinstance(qs, list) and isinstance(ps, list):
            if len(qs) == 0:
                raise ValueError("No queries provided")
            if len(ps) == 0:
                raise ValueError("No passages provided")
            qs = torch.stack(qs).to(device)
            ps = torch.stack(ps).to(device)
        else:
            qs = qs.to(device)
            ps = ps.to(device)

        scores = torch.einsum("bd,cd->bc", qs, ps)
        assert scores.shape[0] == len(qs), f"Expected {len(qs)} scores, got {scores.shape[0]}"

        scores = scores.to(torch.float32)
        return scores
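
    # Illustrative sketch (toy shapes, for exposition only): with single-vector
    # embeddings, the score matrix is a plain query x passage dot product.
    #
    #     qs = [torch.randn(128) for _ in range(4)]   # 4 query embeddings
    #     ps = [torch.randn(128) for _ in range(10)]  # 10 passage embeddings
    #     scores = BaseVisualRetrieverProcessor.score_single_vector(qs, ps, device="cpu")
    #     assert scores.shape == (4, 10)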

    @staticmethod
    def score_multi_vector(
        qs: Union[torch.Tensor, List[torch.Tensor]],
        ps: Union[torch.Tensor, List[torch.Tensor]],
        batch_size: int = 128,
        device: Optional[Union[str, torch.device]] = None,
    ) -> torch.Tensor:
        """
        Compute the late-interaction/MaxSim score (ColBERT-like) for the given multi-vector
        query embeddings (`qs`) and passage embeddings (`ps`). For ColPali, a passage is the
        image of a document page.

        Because the embedding tensors are multi-vector and can thus have different shapes, they
        should be fed as:
        (1) a list of tensors, where the i-th tensor is of shape (sequence_length_i, embedding_dim)
        (2) a single tensor of shape (n_passages, max_sequence_length, embedding_dim) -> usually
            obtained by padding the list of tensors.

        Args:
            qs (`Union[torch.Tensor, List[torch.Tensor]]`): Query embeddings.
            ps (`Union[torch.Tensor, List[torch.Tensor]]`): Passage embeddings.
            batch_size (`int`, *optional*, defaults to 128): Batch size for computing scores.
            device (`Union[str, torch.device]`, *optional*): Device to use for computation. If not
                provided, uses `get_torch_device("auto")`.

        Returns:
            `torch.Tensor`: A tensor of shape `(n_queries, n_passages)` containing the scores. The score
            tensor is saved on the "cpu" device.
        """
        device = device or get_torch_device("auto")

        if len(qs) == 0:
            raise ValueError("No queries provided")
        if len(ps) == 0:
            raise ValueError("No passages provided")

        scores_list: List[torch.Tensor] = []

        for i in range(0, len(qs), batch_size):
            scores_batch = []
            qs_batch = torch.nn.utils.rnn.pad_sequence(qs[i : i + batch_size], batch_first=True, padding_value=0).to(
                device
            )
            for j in range(0, len(ps), batch_size):
                ps_batch = torch.nn.utils.rnn.pad_sequence(
                    ps[j : j + batch_size], batch_first=True, padding_value=0
                ).to(device)
                # MaxSim: for each query token, take the max similarity over passage tokens,
                # then sum over query tokens.
                scores_batch.append(torch.einsum("bnd,csd->bcns", qs_batch, ps_batch).max(dim=3)[0].sum(dim=2))
            scores_batch = torch.cat(scores_batch, dim=1).cpu()
            scores_list.append(scores_batch)

        scores = torch.cat(scores_list, dim=0)
        assert scores.shape[0] == len(qs), f"Expected {len(qs)} scores, got {scores.shape[0]}"

        scores = scores.to(torch.float32)
        return scores
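
    # Illustrative sketch (toy shapes, for exposition only): variable-length
    # multi-vector embeddings are passed as lists and padded internally.
    #
    #     qs = [torch.randn(n, 128) for n in (14, 20)]          # 2 queries
    #     ps = [torch.randn(n, 128) for n in (1030, 900, 770)]  # 3 document pages
    #     scores = BaseVisualRetrieverProcessor.score_multi_vector(qs, ps, device="cpu")
    #     assert scores.shape == (2, 3)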

    @staticmethod
    def get_topk_plaid(
        qs: Union[torch.Tensor, List[torch.Tensor]],
        plaid_index: "search.FastPlaid",
        k: int = 10,
        batch_size: int = 128,
        device: Optional[Union[str, torch.device]] = None,
    ) -> List:
        """
        Experimental: Compute the late-interaction/MaxSim score (ColBERT-like) for the given multi-vector
        query embeddings (`qs`) against passage embeddings encoded in a PLAID index. For ColPali, a passage
        is the image of a document page.

        Returns a list with the top-k results for each batch of queries, as returned by
        `FastPlaid.search`.
        """
        device = device or get_torch_device("auto")

        if len(qs) == 0:
            raise ValueError("No queries provided")

        scores_list: List = []

        for i in range(0, len(qs), batch_size):
            qs_batch = torch.nn.utils.rnn.pad_sequence(qs[i : i + batch_size], batch_first=True, padding_value=0).to(
                device
            )
            # Use the PLAID index to get the top-k scores
            scores_batch = plaid_index.search(
                queries_embeddings=qs_batch.to(torch.float32),
                top_k=k,
            )
            scores_list.append(scores_batch)

        return scores_list

    @staticmethod
    def create_plaid_index(
        ps: Union[torch.Tensor, List[torch.Tensor]],
        device: Optional[Union[str, torch.device]] = None,
    ) -> "search.FastPlaid":
        """
        Experimental: Create a FastPlaid index from the given passage embeddings.

        Args:
            ps (`Union[torch.Tensor, List[torch.Tensor]]`): Passage embeddings. Should be a list of tensors,
                where each tensor is of shape (sequence_length_i, embedding_dim).
            device (`Optional[Union[str, torch.device]]`, *optional*): Device to use for computation. If not
                provided, uses `get_torch_device("auto")`.
        """
        # Make sure fast_plaid is installed before using it.
        if not importlib.util.find_spec("fast_plaid"):
            raise ImportError("FastPlaid is not installed. Please install it with `pip install fast-plaid`.")

        fast_plaid_index = search.FastPlaid(index="index")
        device = device or get_torch_device("auto")
        fast_plaid_index.create(documents_embeddings=[d.to(device).to(torch.float32) for d in ps])
        return fast_plaid_index
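
    # Illustrative sketch (toy shapes, for exposition only; note that FastPlaid
    # persists the index in the local "index" directory): build an index from
    # passage embeddings once, then query it with `get_topk_plaid`.
    #
    #     ps = [torch.randn(1030, 128) for _ in range(100)]  # 100 embedded pages
    #     index = BaseVisualRetrieverProcessor.create_plaid_index(ps, device="cpu")
    #     qs = [torch.randn(20, 128) for _ in range(4)]      # 4 embedded queries
    #     topk = BaseVisualRetrieverProcessor.get_topk_plaid(qs, index, k=10, device="cpu")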

    @abstractmethod
    def get_n_patches(
        self,
        image_size: Tuple[int, int],
        *args,
        **kwargs,
    ) -> Tuple[int, int]:
        """
        Get the number of patches (n_patches_x, n_patches_y) that will be used to process an
        image of size (height, width) with the given patch size.
        """
        pass
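
    # Illustrative sketch (hypothetical subclass, for exposition only): for a
    # plain ViT-style backbone with a fixed square patch size, a subclass could
    # implement `get_n_patches` as:
    #
    #     def get_n_patches(self, image_size: Tuple[int, int], patch_size: int = 14) -> Tuple[int, int]:
    #         height, width = image_size
    #         return width // patch_size, height // patch_size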