dataset_transformation.py

import os
from typing import List, Tuple, cast

from datasets import Dataset, DatasetDict, concatenate_datasets, load_dataset
from PIL import Image

from colpali_engine.data.dataset import ColPaliEngineDataset, Corpus

USE_LOCAL_DATASET = os.environ.get("USE_LOCAL_DATASET", "1") == "1"
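# Illustrative note: datasets are read from ./data_dir/ by default; to pull them from the
# Hugging Face Hub instead, set the environment variable before running, e.g.:
#   USE_LOCAL_DATASET=0 python dataset_transformation.py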

def load_train_set() -> ColPaliEngineDataset:
    base_path = "./data_dir/" if USE_LOCAL_DATASET else "vidore/"
    dataset = load_dataset(base_path + "colpali_train_set", split="train")
    train_dataset = ColPaliEngineDataset(dataset, pos_target_column_name="image")
    return train_dataset

def load_eval_set(dataset_path: str) -> Dataset:
    dataset = load_dataset(dataset_path, split="test")
    return dataset

def load_train_set_ir(num_negs: int = 0) -> ColPaliEngineDataset:
    """Build the IR-style training set: queries paired with a document corpus and optional hard negatives."""
    base_path = "./data_dir/" if USE_LOCAL_DATASET else "manu/"
    corpus_data = load_dataset(base_path + "colpali-corpus", split="train")
    corpus = Corpus(corpus_data=corpus_data, doc_column_name="image")
    dataset = load_dataset(base_path + "colpali-queries", split="train")
    print("Dataset size:", len(dataset))
    # filter out queries with "gold_in_top_100" == False
    dataset = dataset.filter(lambda x: x["gold_in_top_100"], num_proc=16)
    if num_negs > 0:
        # keep only the top `num_negs` negative passages per query
        dataset = dataset.map(lambda x: {"negative_passages": x["negative_passages"][:num_negs]})
    print("Dataset size after filtering:", len(dataset))
    train_dataset = ColPaliEngineDataset(
        data=dataset,
        corpus=corpus,
        pos_target_column_name="positive_passages",
        neg_target_column_name="negative_passages" if num_negs else None,
    )
    return train_dataset

def load_train_set_detailed() -> DatasetDict:
    ds_paths = [
        "infovqa_train",
        "docvqa_train",
        "arxivqa_train",
        "tatdqa_train",
        "syntheticDocQA_government_reports_train",
        "syntheticDocQA_healthcare_industry_train",
        "syntheticDocQA_artificial_intelligence_train",
        "syntheticDocQA_energy_train",
    ]
    base_path = "./data_dir/" if USE_LOCAL_DATASET else "vidore/"
    ds_tot = []
    for path in ds_paths:
        cpath = base_path + path
        ds = cast(Dataset, load_dataset(cpath, split="train"))
        if "arxivqa" in path:
            # subsample 10k
            ds = ds.shuffle(42).select(range(10000))
        ds_tot.append(ds)
    dataset = cast(Dataset, concatenate_datasets(ds_tot))
    dataset = dataset.shuffle(seed=42)
    # split into train and test
    dataset_eval = dataset.select(range(500))
    dataset = dataset.select(range(500, len(dataset)))
    ds_dict = DatasetDict({"train": dataset, "test": dataset_eval})
    return ds_dict

def load_train_set_with_tabfquad() -> DatasetDict:
    ds_paths = [
        "infovqa_train",
        "docvqa_train",
        "arxivqa_train",
        "tatdqa_train",
        "tabfquad_train_subsampled",
        "syntheticDocQA_government_reports_train",
        "syntheticDocQA_healthcare_industry_train",
        "syntheticDocQA_artificial_intelligence_train",
        "syntheticDocQA_energy_train",
    ]
    base_path = "./data_dir/" if USE_LOCAL_DATASET else "vidore/"
    ds_tot = []
    for path in ds_paths:
        cpath = base_path + path
        ds = cast(Dataset, load_dataset(cpath, split="train"))
        if "arxivqa" in path:
            # subsample 10k
            ds = ds.shuffle(42).select(range(10000))
        ds_tot.append(ds)
    dataset = cast(Dataset, concatenate_datasets(ds_tot))
    dataset = dataset.shuffle(seed=42)
    # split into train and test
    dataset_eval = dataset.select(range(500))
    dataset = dataset.select(range(500, len(dataset)))
    ds_dict = DatasetDict({"train": dataset, "test": dataset_eval})
    return ds_dict

def load_docmatix_ir_negs() -> Tuple[DatasetDict, Dataset, str]:
    """Returns the query dataset, then the anchor dataset with the documents, then the dataset type."""
    base_path = "./data_dir/" if USE_LOCAL_DATASET else "Tevatron/"
    dataset = cast(Dataset, load_dataset(base_path + "docmatix-ir", split="train"))
    # dataset = dataset.select(range(100500))
    dataset_eval = dataset.select(range(500))
    dataset = dataset.select(range(500, len(dataset)))
    ds_dict = DatasetDict({"train": dataset, "test": dataset_eval})
    base_path = "./data_dir/" if USE_LOCAL_DATASET else "HuggingFaceM4/"
    anchor_ds = cast(Dataset, load_dataset(base_path + "Docmatix", "images", split="train"))
    return ds_dict, anchor_ds, "docmatix"

def load_wikiss() -> Tuple[DatasetDict, Dataset, str]:
    """Returns the query dataset, then the anchor dataset with the documents, then the dataset type."""
    base_path = "./data_dir/" if USE_LOCAL_DATASET else "Tevatron/"
    dataset = cast(Dataset, load_dataset(base_path + "wiki-ss-nq", data_files="train.jsonl", split="train"))
    # dataset = dataset.select(range(400500))
    dataset_eval = dataset.select(range(500))
    dataset = dataset.select(range(500, len(dataset)))
    ds_dict = DatasetDict({"train": dataset, "test": dataset_eval})
    base_path = "./data_dir/" if USE_LOCAL_DATASET else "HuggingFaceM4/"
    anchor_ds = cast(Dataset, load_dataset(base_path + "wiki-ss-corpus", split="train"))
    return ds_dict, anchor_ds, "wikiss"
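
# Illustrative usage of the two IR-style loaders above; the variable names are hypothetical and
# shown only to make the returned tuple layout explicit (kept as comments to avoid downloads):
#   queries_dict, anchor_ds, ds_type = load_docmatix_ir_negs()  # or load_wikiss()
#   print(ds_type, queries_dict["train"].num_rows, anchor_ds.num_rows)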

def load_train_set_with_docmatix() -> DatasetDict:
    ds_paths = [
        "infovqa_train",
        "docvqa_train",
        "arxivqa_train",
        "tatdqa_train",
        "tabfquad_train_subsampled",
        "syntheticDocQA_government_reports_train",
        "syntheticDocQA_healthcare_industry_train",
        "syntheticDocQA_artificial_intelligence_train",
        "syntheticDocQA_energy_train",
        "Docmatix_filtered_train",
    ]
    base_path = "./data_dir/" if USE_LOCAL_DATASET else "vidore/"
    ds_tot: List[Dataset] = []
    for path in ds_paths:
        cpath = base_path + path
        ds = cast(Dataset, load_dataset(cpath, split="train"))
        if "arxivqa" in path:
            # subsample 10k
            ds = ds.shuffle(42).select(range(10000))
        ds_tot.append(ds)
    dataset = concatenate_datasets(ds_tot)
    dataset = dataset.shuffle(seed=42)
    # split into train and test
    dataset_eval = dataset.select(range(500))
    dataset = dataset.select(range(500, len(dataset)))
    ds_dict = DatasetDict({"train": dataset, "test": dataset_eval})
    return ds_dict

def load_docvqa_dataset() -> DatasetDict:
    if USE_LOCAL_DATASET:
        dataset_doc = cast(Dataset, load_dataset("./data_dir/DocVQA", "DocVQA", split="validation"))
        dataset_doc_eval = cast(Dataset, load_dataset("./data_dir/DocVQA", "DocVQA", split="test"))
        dataset_info = cast(Dataset, load_dataset("./data_dir/DocVQA", "InfographicVQA", split="validation"))
        dataset_info_eval = cast(Dataset, load_dataset("./data_dir/DocVQA", "InfographicVQA", split="test"))
    else:
        dataset_doc = cast(Dataset, load_dataset("lmms-lab/DocVQA", "DocVQA", split="validation"))
        dataset_doc_eval = cast(Dataset, load_dataset("lmms-lab/DocVQA", "DocVQA", split="test"))
        dataset_info = cast(Dataset, load_dataset("lmms-lab/DocVQA", "InfographicVQA", split="validation"))
        dataset_info_eval = cast(Dataset, load_dataset("lmms-lab/DocVQA", "InfographicVQA", split="test"))

    # concatenate the two datasets
    dataset = concatenate_datasets([dataset_doc, dataset_info])
    dataset_eval = concatenate_datasets([dataset_doc_eval, dataset_info_eval])
    # sample 200 examples from the eval dataset
    dataset_eval = dataset_eval.shuffle(seed=42).select(range(200))
    # rename question as query
    dataset = dataset.rename_column("question", "query")
    dataset_eval = dataset_eval.rename_column("question", "query")
    # create new column image_filename that corresponds to ucsf_document_id if not None, else image_url
    dataset = dataset.map(
        lambda x: {"image_filename": x["ucsf_document_id"] if x["ucsf_document_id"] is not None else x["image_url"]}
    )
    dataset_eval = dataset_eval.map(
        lambda x: {"image_filename": x["ucsf_document_id"] if x["ucsf_document_id"] is not None else x["image_url"]}
    )
    ds_dict = DatasetDict({"train": dataset, "test": dataset_eval})
    return ds_dict

def load_dummy_dataset() -> DatasetDict:
    # create a dataset from the queries and images
    queries_1 = ["What is the capital of France?", "What is the capital of Germany?"]
    queries_2 = ["What is the capital of Italy?", "What is the capital of Spain?"]
    images_1 = [Image.new("RGB", (100, 100)) for _ in range(2)]
    images_2 = [Image.new("RGB", (120, 120)) for _ in range(2)]
    dataset_1 = Dataset.from_list([{"query": q, "image": i} for q, i in zip(queries_1, images_1)])
    dataset_2 = Dataset.from_list([{"query": q, "image": i} for q, i in zip(queries_2, images_2)])
    return DatasetDict(
        {
            "train": DatasetDict({"dataset_1": dataset_1, "dataset_2": dataset_2}),
            "test": DatasetDict({"dataset_1": dataset_2, "dataset_2": dataset_1}),
        }
    )

def load_multi_qa_datasets() -> DatasetDict:
    dataset_args = [
        ("vidore/colpali_train_set",),
        ("llamaindex/vdr-multilingual-train", "de"),
        ("llamaindex/vdr-multilingual-train", "en"),
        ("llamaindex/vdr-multilingual-train", "es"),
        ("llamaindex/vdr-multilingual-train", "fr"),
        ("llamaindex/vdr-multilingual-train", "it"),
    ]
    train_datasets = {}
    test_datasets = {}
    for args in dataset_args:
        dataset_name = "_".join(args)
        dataset = load_dataset(*args)
        if "test" in dataset:
            train_datasets[dataset_name] = dataset["train"]
            test_datasets[dataset_name] = dataset["test"]
        else:
            # no test split available: hold out 200 examples from the train split
            split = dataset["train"].train_test_split(test_size=200, seed=42)
            train_datasets[dataset_name] = split["train"]
            test_datasets[dataset_name] = split["test"]
    return DatasetDict({"train": DatasetDict(train_datasets), "test": DatasetDict(test_datasets)})

class TestSetFactory:
    def __init__(self, dataset_path):
        self.dataset_path = dataset_path

    def __call__(self, *args, **kwargs):
        dataset = load_dataset(self.dataset_path, split="test")
        return dataset


if __name__ == "__main__":
    ds = TestSetFactory("vidore/tabfquad_test_subsampled")()
    print(ds)
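
    # Optional smoke tests for the loaders above, kept as comments to avoid triggering large
    # downloads; they assume the datasets are reachable either locally under ./data_dir/
    # (USE_LOCAL_DATASET=1) or from the Hugging Face Hub otherwise:
    # train_ds = load_train_set()
    # print(train_ds)
    # ir_ds = load_train_set_ir(num_negs=5)  # cap hard negatives at 5 per query
    # print(ir_ds)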