dataset.py 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162
  1. import random
  2. from typing import Any, Dict, List, Optional, Union
  3. from datasets import Dataset as HFDataset
  4. from PIL import Image
  5. from torch.utils.data import Dataset
# Type of a corpus document, as returned by ``Corpus.retrieve``: either text (``str``)
# or a PIL image.
Document = Union[str, Image.Image]
  7. class Corpus:
  8. """
  9. Corpus class for handling retrieving with simple mapping.
  10. This class is meant to be overridden by the user to handle their own corpus.
  11. Args:
  12. corpus_data (List[Dict[str, Any]]): List of dictionaries containing doc data.
  13. docid_to_idx_mapping (Optional[Dict[str, int]]): Optional mapping from doc IDs to indices.
  14. """
  15. def __init__(
  16. self,
  17. corpus_data: List[Dict[str, Any]],
  18. docid_to_idx_mapping: Optional[Dict[str, int]] = None,
  19. doc_column_name: str = "doc",
  20. ):
  21. """
  22. Initialize the corpus with the provided data.
  23. """
  24. self.corpus_data = corpus_data
  25. self.docid_to_idx_mapping = docid_to_idx_mapping
  26. self.doc_column_name = doc_column_name
  27. assert isinstance(
  28. self.corpus_data,
  29. (list, Dataset, HFDataset),
  30. ), "Corpus data must be a map-style dataset"
  31. assert self.doc_column_name in self.corpus_data[0], f"Corpus data must contain a column {self.doc_column_name}."
  32. def __len__(self) -> int:
  33. """
  34. Return the number of docs in the corpus.
  35. Returns:
  36. int: The number of docs in the corpus.
  37. """
  38. return len(self.corpus_data)
  39. def retrieve(self, docid: Any) -> Document:
  40. """
  41. Get the corpus row from the given Doc ID.
  42. Args:
  43. docid (str): The id of the document.
  44. Returns:
  45. Document: The document retrieved from the corpus.
  46. """
  47. if self.docid_to_idx_mapping is not None:
  48. doc_idx = self.docid_to_idx_mapping[docid]
  49. else:
  50. doc_idx = docid
  51. return self.corpus_data[doc_idx][self.doc_column_name]
  52. class ColPaliEngineDataset(Dataset):
  53. # Output keys
  54. QUERY_KEY = "query"
  55. POS_TARGET_KEY = "pos_target"
  56. NEG_TARGET_KEY = "neg_target"
  57. def __init__(
  58. self,
  59. data: List[Dict[str, Any]],
  60. corpus: Optional[Corpus] = None,
  61. query_column_name: str = "query",
  62. pos_target_column_name: str = "pos_target",
  63. neg_target_column_name: str = None,
  64. num_negatives: int = 3,
  65. ):
  66. """
  67. Initialize the dataset with the provided data and external document corpus.
  68. Args:
  69. data (Dict[str, List[Any]]): A dictionary containing the dataset samples.
  70. corpus (Optional[Corpus]): An optional external document corpus to retrieve
  71. documents (images) from.
  72. """
  73. self.data = data
  74. self.corpus = corpus
  75. # Column args
  76. self.query_column_name = query_column_name
  77. self.pos_target_column_name = pos_target_column_name
  78. self.neg_target_column_name = neg_target_column_name
  79. self.num_negatives = num_negatives
  80. assert isinstance(
  81. self.data,
  82. (list, Dataset, HFDataset),
  83. ), "Data must be a map-style dataset"
  84. assert self.query_column_name in self.data[0], f"Data must contain the {self.query_column_name} column"
  85. assert self.pos_target_column_name in self.data[0], f"Data must contain a {self.pos_target_column_name} column"
  86. if self.neg_target_column_name is not None:
  87. assert self.neg_target_column_name in self.data[0], (
  88. f"Data must contain a {self.neg_target_column_name} column"
  89. )
  90. def __len__(self) -> int:
  91. """Return the number of samples in the dataset."""
  92. return len(self.data)
  93. def __getitem__(self, idx: int) -> Dict[str, Any]:
  94. sample = self.data[idx]
  95. query = sample[self.query_column_name]
  96. pos_targets = sample[self.pos_target_column_name]
  97. if not isinstance(pos_targets, list):
  98. pos_targets = [pos_targets]
  99. if self.neg_target_column_name is not None:
  100. neg_targets = sample[self.neg_target_column_name]
  101. if not isinstance(neg_targets, list):
  102. neg_targets = [neg_targets]
  103. else:
  104. neg_targets = None
  105. # If an external document corpus is provided, retrieve the documents from it.
  106. if self.corpus is not None:
  107. pos_targets = [self.corpus.retrieve(doc_id) for doc_id in pos_targets]
  108. if neg_targets is not None:
  109. # to avoid oveflowing CPU memory
  110. if len(neg_targets) > self.num_negatives:
  111. neg_targets = random.sample(neg_targets, self.num_negatives)
  112. neg_targets = [self.corpus.retrieve(doc_id) for doc_id in neg_targets]
  113. return {
  114. self.QUERY_KEY: query,
  115. self.POS_TARGET_KEY: pos_targets,
  116. self.NEG_TARGET_KEY: neg_targets,
  117. }
  118. def take(self, n: int) -> "ColPaliEngineDataset":
  119. """
  120. Take the first n samples from the dataset.
  121. Args:
  122. n (int): The number of samples to take.
  123. Returns:
  124. ColPaliEngineDataset: A new dataset containing the first n samples.
  125. """
  126. return self.__class__(
  127. self.data.take(n),
  128. self.corpus,
  129. self.query_column_name,
  130. self.pos_target_column_name,
  131. self.neg_target_column_name,
  132. )