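"""Mine hard negatives for training a BiQwen2 retriever.

The script runs in three stages, gated by the two flags below:
  1. COMPUTE_EMBEDDINGS: deduplicate the training documents by image
     filename and embed the resulting document pool.
  2. COMPUTE_HARDNEGS: for every training query, retrieve the top-100
     documents from the pool and store their indices as hard negatives.
  3. Always: join the mined indices back onto the training set and save
     it with positive/negative passage indices instead of raw images.
"""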
from typing import cast

import datasets
import numpy as np
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

from colpali_engine.models import BiQwen2, BiQwen2Processor
from colpali_engine.utils.dataset_transformation import load_train_set

train_set = load_train_set()

COMPUTE_EMBEDDINGS = False
COMPUTE_HARDNEGS = False

if COMPUTE_HARDNEGS or COMPUTE_EMBEDDINGS:
    print("Loading base model")
    model = BiQwen2.from_pretrained(
        "./models/biqwen2-warmup-256-newpad-0e",
        torch_dtype=torch.bfloat16,
        device_map="cuda",
        attn_implementation="flash_attention_2" if torch.cuda.is_available() else None,
    ).eval()
    print("Loading processor")
    processor = BiQwen2Processor.from_pretrained("./models/biqwen2-warmup-256-newpad-0e")

if COMPUTE_EMBEDDINGS:
    print("Loading images")
    document_set = train_set["train"]
    print("Images loaded")
    print("Filtering dataset")
    print(document_set)
    # Deduplicate the document pool: several queries can share the same image.
    initial_list = document_set["image_filename"]
    _, unique_indices = np.unique(initial_list, return_index=True, axis=0)
    filtered_dataset = document_set.select(unique_indices.tolist())
    filtered_dataset = filtered_dataset.map(
        lambda example: {"image": example["image"], "image_filename": example["image_filename"]},
        num_proc=16,
    )
    # Keep only the "image" and "image_filename" columns.
    cols_to_remove = [col for col in filtered_dataset.column_names if col not in ["image", "image_filename"]]
    filtered_dataset = filtered_dataset.remove_columns(cols_to_remove)
    print("Saving filtered dataset")
    print(filtered_dataset)
    filtered_dataset.save_to_disk("data_dir/filtered_dataset", max_shard_size="200MB")
- print("Processing images")
- # run inference - docs
- dataloader = DataLoader(
- filtered_dataset,
- batch_size=8,
- shuffle=False,
- collate_fn=lambda x: processor.process_images([a["image"] for a in x]),
- )
- print("Computing embeddings")
- ds = []
- for batch_doc in tqdm(dataloader):
- with torch.no_grad():
- batch_doc = {k: v.to(model.device) for k, v in batch_doc.items()}
- embeddings_doc = model(**batch_doc)
- ds.extend(list(torch.unbind(embeddings_doc.to("cpu"))))
- ds = torch.stack(ds)
- # save embeddings
- torch.save(ds, "data_dir/filtered_dataset_embeddings.pt")

if not COMPUTE_EMBEDDINGS:
    ds = torch.load("data_dir/filtered_dataset_embeddings.pt")

if COMPUTE_HARDNEGS:
    # Mine hard negatives: for each training query, keep the indices of the
    # 100 highest-scoring documents in the deduplicated pool.
    ds = cast(torch.Tensor, ds).to("cuda")
    mined_hardnegs = []
    # Iterate over the train set in batches of 8 queries.
    for i in tqdm(range(0, len(train_set["train"]), 8)):
        samples = train_set["train"][i : i + 8]
        batch_query = processor.process_queries(samples["query"])
        with torch.no_grad():
            batch_query = {k: v.to(model.device) for k, v in batch_query.items()}
            embeddings_query = model(**batch_query)
        # Dot-product scores between query and document embeddings.
        scores = torch.einsum("bd,cd->bc", embeddings_query, ds)
        # Top-100 document indices per query, as plain Python lists.
        top100 = scores.topk(100, dim=1).indices.tolist()
        mined_hardnegs.extend(top100)
    # Save the mined hard negatives, one "[i1, i2, ...]" list per line.
    with open("data_dir/mined_hardnegs_filtered.txt", "w") as f:
        for item in mined_hardnegs:
            f.write(f"{item}\n")

with open("data_dir/mined_hardnegs_filtered.txt") as f:
    mined_hardnegs = f.readlines()

filtered_dataset = datasets.load_from_disk("data_dir/filtered_dataset")
filenames = list(filtered_dataset["image_filename"])
# Map filenames to pool indices once, instead of calling list.index()
# (an O(n) scan) for every example inside the mapper.
filename_to_index = {filename: idx for idx, filename in enumerate(filenames)}


def mapper_fn(example, idx):
    # Each line of the negatives file looks like "[i1, i2, ...]\n";
    # strip the brackets and trailing newline, then parse the indices.
    tmp = {
        "negative_passages": [int(x) for x in mined_hardnegs[idx][1:-2].strip().split(",")],
        "query": example["query"],
        "positive_passages": [filename_to_index[example["image_filename"]]],
    }
    tmp["gold_in_top_100"] = tmp["positive_passages"][0] in tmp["negative_passages"]
    # Drop the gold document from the negatives if it was retrieved.
    if tmp["gold_in_top_100"]:
        tmp["negative_passages"].remove(tmp["positive_passages"][0])
    return tmp


final_dataset = train_set["train"].map(mapper_fn, with_indices=True, num_proc=16)
# Drop the raw images; the final dataset references documents by pool index.
final_dataset = final_dataset.remove_columns("image")
final_dataset.save_to_disk("data_dir/final_dataset")
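
# Optional sanity check (a minimal sketch, not part of the original pipeline):
# reload the mined dataset and confirm the gold index was removed from the
# negatives of the first example.
check_ds = datasets.load_from_disk("data_dir/final_dataset")
example = check_ds[0]
print("query:", example["query"])
print("positive index:", example["positive_passages"][0])
print("gold was in top-100:", example["gold_in_top_100"])
assert example["positive_passages"][0] not in example["negative_passages"]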