compute_hardnegs.py

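# Hard-negative mining for a BiQwen2 bi-encoder retriever, in three stages:
#   1. COMPUTE_EMBEDDINGS: deduplicate the training documents and cache one
#      embedding per unique image.
#   2. COMPUTE_HARDNEGS: embed every training query, score it against all
#      document embeddings, and keep the top-100 document indices.
#   3. Always run: merge the mined indices back into the train set and save it.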
from typing import cast

import datasets
import numpy as np
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

from colpali_engine.models import BiQwen2, BiQwen2Processor
from colpali_engine.utils.dataset_transformation import load_train_set

train_set = load_train_set()

COMPUTE_EMBEDDINGS = False
COMPUTE_HARDNEGS = False
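
# With both flags set to False, the model is never loaded and the script only
# rebuilds the final dataset from artifacts cached in data_dir/ by earlier runs.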
if COMPUTE_HARDNEGS or COMPUTE_EMBEDDINGS:
    print("Loading base model")
    model = BiQwen2.from_pretrained(
        "./models/biqwen2-warmup-256-newpad-0e",
        torch_dtype=torch.bfloat16,
        device_map="cuda",
        attn_implementation="flash_attention_2" if torch.cuda.is_available() else None,
    ).eval()
    print("Loading processor")
    processor = BiQwen2Processor.from_pretrained("./models/biqwen2-warmup-256-newpad-0e")
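
# Stage 1: deduplicate the document images and cache one embedding per document.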
if COMPUTE_EMBEDDINGS:
    print("Loading images")
    print("Images loaded")
    document_set = train_set["train"]
    print("Filtering dataset")
    print(document_set)
    # deduplicate: keep the first occurrence of each image_filename
    initial_list = document_set["image_filename"]
    _, unique_indices = np.unique(initial_list, return_index=True, axis=0)
    filtered_dataset = document_set.select(unique_indices.tolist())
    # identity map copies the selected rows into a fresh cache file
    filtered_dataset = filtered_dataset.map(
        lambda example: {"image": example["image"], "image_filename": example["image_filename"]}, num_proc=16
    )
    # keep only the image and image_filename columns
    cols_to_remove = [col for col in filtered_dataset.column_names if col not in ["image", "image_filename"]]
    filtered_dataset = filtered_dataset.remove_columns(cols_to_remove)
    # save the deduplicated document set for the final-dataset build below
    print("Saving filtered dataset")
    print(filtered_dataset)
    filtered_dataset.save_to_disk("data_dir/filtered_dataset", max_shard_size="200MB")
    print("Processing images")
    # run inference on the documents
    dataloader = DataLoader(
        filtered_dataset,
        batch_size=8,
        shuffle=False,
        collate_fn=lambda x: processor.process_images([a["image"] for a in x]),
    )
    print("Computing embeddings")
    ds = []
    for batch_doc in tqdm(dataloader):
        with torch.no_grad():
            batch_doc = {k: v.to(model.device) for k, v in batch_doc.items()}
            embeddings_doc = model(**batch_doc)
        ds.extend(list(torch.unbind(embeddings_doc.to("cpu"))))
    # stack into a single (num_docs, dim) tensor and save the embeddings
    ds = torch.stack(ds)
    torch.save(ds, "data_dir/filtered_dataset_embeddings.pt")
if not COMPUTE_EMBEDDINGS:
    # reuse document embeddings cached by a previous run
    ds = torch.load("data_dir/filtered_dataset_embeddings.pt")
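
# Stage 2: embed the training queries in batches and keep the indices of the
# 100 highest-scoring documents for each one.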
if COMPUTE_HARDNEGS:
    ds = cast(torch.Tensor, ds).to("cuda")
    # iterate over the train set in batches of 8 queries
    mined_hardnegs = []
    for i in tqdm(range(0, len(train_set["train"]), 8)):
        samples = train_set["train"][i : i + 8]
        batch_query = processor.process_queries(samples["query"])
        with torch.no_grad():
            batch_query = {k: v.to(model.device) for k, v in batch_query.items()}
            embeddings_query = model(**batch_query)
        # dot-product scores: (batch, dim) x (num_docs, dim) -> (batch, num_docs)
        scores = torch.einsum("bd,cd->bc", embeddings_query, ds)
        # top-100 document indices per query, as plain lists
        top100 = scores.topk(100, dim=1).indices.tolist()
        mined_hardnegs.extend(top100)
    # save the mined hard negatives, one "[i1, i2, ...]" line per query
    with open("data_dir/mined_hardnegs_filtered.txt", "w") as f:
        for item in mined_hardnegs:
            f.write("%s\n" % item)
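
# Stage 3: attach the mined indices to each training example as
# negative_passages, with the gold document index as positive_passages.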
with open("data_dir/mined_hardnegs_filtered.txt") as f:
    mined_hardnegs = f.readlines()

filtered_dataset = datasets.load_from_disk("data_dir/filtered_dataset")
filenames = list(filtered_dataset["image_filename"])


def mapper_fn(example, idx):
    # each line reads "[i1, i2, ...]\n"; [1:-2] strips the brackets and newline
    tmp = {
        "negative_passages": [int(x) for x in mined_hardnegs[idx][1:-2].strip().split(",")],
        "query": example["query"],
        "positive_passages": [filenames.index(example["image_filename"])],
    }
    tmp["gold_in_top_100"] = tmp["positive_passages"][0] in tmp["negative_passages"]
    # remove the gold index from the negatives if it was retrieved
    if tmp["gold_in_top_100"]:
        tmp["negative_passages"].remove(tmp["positive_passages"][0])
    return tmp


final_dataset = train_set["train"].map(mapper_fn, with_indices=True, num_proc=16)
# drop the raw images; the final dataset keeps only queries and indices
final_dataset = final_dataset.remove_columns("image")
final_dataset.save_to_disk("data_dir/final_dataset")