contrastive_trainer.py

from functools import partial
from typing import Optional

import datasets
import torch
from torch.distributed.nn.functional import all_gather  # PyTorch ≥ 2.1
from torch.utils.data import ConcatDataset, DataLoader, Dataset
from transformers import Trainer, is_datasets_available
from transformers.trainer_utils import seed_worker

from colpali_engine.data.sampler import SingleDatasetBatchSampler


def concat_all_gather(t: torch.Tensor) -> torch.Tensor:
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        return torch.cat(all_gather(t), dim=0)  # keeps the autograd graph intact
    return t
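
# Note: unlike `torch.distributed.all_gather`, the differentiable variant imported above
# (`torch.distributed.nn.functional.all_gather`) returns tensors that stay connected to the
# autograd graph, so gradients flow back to each rank's local shard. Rough sanity check
# (hypothetical sketch, not executed here):
#
#   t = local_embeddings.clone().requires_grad_()   # placeholder local tensor
#   gathered = concat_all_gather(t)                 # (world_size * local_batch, ...)
#   gathered.sum().backward()                       # t.grad is populated on every rank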


def concat_datasets(datasets: list[Dataset], batch_size: int) -> Dataset:
    """
    Concatenate a list of datasets into a single dataset.

    Each dataset is truncated to a multiple of the global batch size so that every batch can
    be drawn from a single dataset.
    """
    # Round each dataset down if its length is not divisible by the global batch size.
    for i in range(len(datasets)):
        if len(datasets[i]) % batch_size != 0:
            total_samples = (len(datasets[i]) // batch_size) * batch_size
            datasets[i] = datasets[i].take(total_samples)
    return ConcatDataset(datasets)
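
# Note: `datasets[i].take(total_samples)` assumes each dataset exposes a `take` method (e.g. a
# Hugging Face `datasets.Dataset`); a plain `torch.utils.data.Dataset` would need something like
# `torch.utils.data.Subset` instead.
# Illustrative example: with a global batch size of 8, datasets of lengths 21 and 10 are
# truncated to 16 and 8 samples respectively before concatenation.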


class ContrastiveTrainer(Trainer):
    def __init__(self, loss_func, is_vision_model, compute_symetric_loss=False, *args, **kwargs):
        if isinstance(kwargs["train_dataset"], list):
            train_dataset_list = kwargs["train_dataset"]
            kwargs["train_dataset"] = concat_datasets(train_dataset_list, batch_size=kwargs["args"].train_batch_size)
        else:
            train_dataset_list = None

        if isinstance(kwargs["eval_dataset"], list):
            eval_dataset_list = kwargs["eval_dataset"]
            kwargs["eval_dataset"] = concat_datasets(eval_dataset_list, batch_size=kwargs["args"].eval_batch_size)
        else:
            eval_dataset_list = None

        super().__init__(*args, **kwargs)
        self.loss_func = loss_func
        self.is_vision_model = is_vision_model  # Unused argument, will be removed in 0.4.0
        self.args.remove_unused_columns = False  # Safety: don't remove dataset columns in the dataloader
        self.train_dataset_list = train_dataset_list
        self.eval_dataset_list = eval_dataset_list
        self.compute_symetric_loss = compute_symetric_loss

    def get_train_dataloader(self) -> DataLoader:
        """
        Returns the training [`~torch.utils.data.DataLoader`].

        Will use no sampler if `train_dataset` does not implement `__len__`, a random sampler (adapted to distributed
        training if necessary) otherwise.

        Subclass and override this method if you want to inject some custom behavior.
        """
        if self.train_dataset is None:
            raise ValueError("Trainer: training requires a train_dataset.")

        if self.train_dataset_list is None:
            # If no dataset list was provided, fall back to the default behavior.
            return super().get_train_dataloader()

        dataset = self.train_dataset
        description = "Training"
        sampler_fn = self._get_train_sampler
        is_training = True
        dataloader_key = None
        data_collator = self.data_collator

        if is_datasets_available() and isinstance(dataset, datasets.Dataset):
            dataset = self._remove_unused_columns(dataset, description=description)
        else:
            data_collator = self._get_collator_with_removed_columns(self.data_collator, description=description)

        self.query_prefix = data_collator.query_prefix
        self.pos_prefix = data_collator.pos_doc_prefix
        self.neg_prefix = data_collator.neg_doc_prefix

        dataloader_params = {
            # Do not set `batch_size` here: it is mutually exclusive with `batch_sampler`.
            "collate_fn": data_collator,
            "num_workers": self.args.dataloader_num_workers,
            "pin_memory": self.args.dataloader_pin_memory,
            "persistent_workers": self.args.dataloader_persistent_workers,
        }

        if not isinstance(dataset, torch.utils.data.IterableDataset):
            if sampler_fn is not None:
                # A `batch_sampler` is set instead of a `sampler` (unlike the upstream Trainer code).
                dataloader_params["batch_sampler"] = sampler_fn()
            else:
                # `drop_last` is mutually exclusive with `batch_sampler`; only set it when no
                # batch sampler controls the batch boundaries.
                dataloader_params["drop_last"] = self.args.dataloader_drop_last
            dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
            if is_training:
                dataloader_params["worker_init_fn"] = partial(
                    seed_worker, num_workers=self.args.dataloader_num_workers, rank=self.args.process_index
                )

        dataloader = DataLoader(dataset, **dataloader_params)

        # Accelerator.free_memory() will destroy the references, so
        # we need to store the non-prepared version for eval dataloaders.
        if dataloader_key is not None and self.args.dataloader_persistent_workers:
            if hasattr(self, "_eval_dataloaders"):
                self._eval_dataloaders[dataloader_key] = dataloader
            else:
                self._eval_dataloaders = {dataloader_key: dataloader}

        return self.accelerator.prepare(dataloader)

    def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
        if self.train_dataset_list is None:
            return super()._get_train_sampler()

        # Use SingleDatasetBatchSampler to ensure that each dataset in the list is sampled independently.
        # Note: this most likely breaks in distributed training.
        # TODO: fix this
        generator = torch.Generator()
        generator.manual_seed(self.args.seed)
        return SingleDatasetBatchSampler(
            self.train_dataset_list,
            self.args.train_batch_size,
            drop_last=self.args.dataloader_drop_last,
            generator=generator,
        )
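
    # SingleDatasetBatchSampler (from colpali_engine.data.sampler) is expected to yield index
    # batches in which every batch comes from exactly one of the concatenated datasets, which
    # keeps in-batch negatives homogeneous. Only its (datasets, batch_size, drop_last, generator)
    # constructor signature is assumed here; the sampling logic lives in that class.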

    def _compute_loss_from_outputs(
        self,
        query_outputs,
        pos_target_outputs,
        neg_target_outputs=None,
    ):
        offset = 0
        batch_size = query_outputs.size(0)

        if self.accelerator.num_processes > 1 and self.accelerator.sync_gradients:
            # Gather positive docs across all processes.
            pos_target_outputs = self.accelerator.pad_across_processes(
                pos_target_outputs, dim=1, pad_index=0, pad_first=True
            )
            pos_target_outputs = concat_all_gather(pos_target_outputs)
            rank = self.accelerator.process_index
            offset = rank * batch_size

        if neg_target_outputs is not None:
            loss = self.loss_func(
                query_embeddings=query_outputs,
                doc_embeddings=pos_target_outputs,
                neg_doc_embeddings=neg_target_outputs,
                offset=offset,
            )
        else:
            loss = self.loss_func(query_embeddings=query_outputs, doc_embeddings=pos_target_outputs, offset=offset)
        return loss
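
    # A minimal sketch (illustrative, not part of this module) of a loss compatible with the
    # keyword signature used above, assuming single-vector embeddings of shape (B, D); `offset`
    # locates the local queries' positives inside the gathered document matrix. The actual
    # colpali_engine losses additionally handle `neg_doc_embeddings` and multi-vector scoring.
    #
    #   def in_batch_loss(query_embeddings, doc_embeddings, offset=0):
    #       scores = query_embeddings @ doc_embeddings.T                     # (B, B_global)
    #       labels = torch.arange(query_embeddings.size(0), device=scores.device) + offset
    #       return torch.nn.functional.cross_entropy(scores, labels)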

    def _reshape_neg_doc_inputs(self, inputs):
        """
        Helper function to reshape negative doc inputs to (batch_size * num_neg_docs, ...).
        """
        neg_doc_inputs = {k[len(self.neg_prefix) :]: v for k, v in inputs.items() if k.startswith(self.neg_prefix)}
        for k in neg_doc_inputs:
            # Go from (batch_size, num_neg_docs, ...) to (batch_size * num_neg_docs, ...).
            neg_doc_inputs[k] = neg_doc_inputs[k].view(-1, *neg_doc_inputs[k].shape[2:])
        return neg_doc_inputs

    def _reshape_neg_doc_outputs(self, neg_doc_outputs, num_neg_docs):
        """
        Helper function to reshape negative doc outputs to (batch_size, num_neg_docs, ...).
        """
        neg_doc_outputs = neg_doc_outputs.view(-1, num_neg_docs, *neg_doc_outputs.shape[1:])
        return neg_doc_outputs
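
    # Shape example (illustrative): negative-document token IDs of shape (B, N, L), or pixel
    # values of shape (B, N, C, H, W) assuming an image collator, are flattened to (B * N, ...)
    # before the forward pass, and the resulting embeddings are folded back to (B, N, ...) for
    # the loss.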

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        query_inputs = {k[len(self.query_prefix) :]: v for k, v in inputs.items() if k.startswith(self.query_prefix)}
        query_outputs = model(**query_inputs)

        # Feed only the kwargs carrying the positive-document prefix.
        doc_inputs = {k[len(self.pos_prefix) :]: v for k, v in inputs.items() if k.startswith(self.pos_prefix)}
        doc_outputs = model(**doc_inputs)

        if "neg_doc_input_ids" in inputs:
            # Negative docs are not gathered across processes, so they can be used without an offset.
            num_negs = inputs["neg_doc_input_ids"].size(1)
            neg_doc_inputs = self._reshape_neg_doc_inputs(inputs)
            neg_doc_outputs = model(**neg_doc_inputs)
            neg_doc_outputs = self._reshape_neg_doc_outputs(neg_doc_outputs, num_negs)
        else:
            neg_doc_outputs = None

        # query -> doc loss
        loss = self._compute_loss_from_outputs(query_outputs, doc_outputs, neg_doc_outputs)

        if self.compute_symetric_loss:
            assert neg_doc_outputs is None, "Symmetric loss is not compatible with negative documents."
            # doc -> query loss
            sym_loss = self._compute_loss_from_outputs(doc_outputs, query_outputs)
            loss = (loss + sym_loss) / 2

        return (loss, (query_outputs, doc_outputs)) if return_outputs else loss

    def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=True):
        """Generate predictions and return the loss for the given inputs."""
        if not prediction_loss_only:
            raise ValueError("prediction_step is only called with prediction_loss_only=True")

        with torch.no_grad():
            # Feed only the kwargs with the hard-coded 'doc_' / 'query_' / 'neg_doc_' prefixes.
            doc_outputs = model(**{k[4:]: v for k, v in inputs.items() if k.startswith("doc")})
            query_outputs = model(input_ids=inputs["query_input_ids"], attention_mask=inputs["query_attention_mask"])
            if "neg_doc_input_ids" in inputs:
                neg_doc_outputs = model(**{k[8:]: v for k, v in inputs.items() if k.startswith("neg_doc")})
                loss = self.loss_func(query_outputs, doc_outputs, neg_doc_outputs)
                return loss, None, None

            loss = self.loss_func(query_outputs, doc_outputs)
            return loss, None, None
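

# Minimal usage sketch (illustrative; `MyColModel`, `my_collator`, and `my_loss_fn` are
# placeholders). The data collator is assumed to expose `query_prefix`, `pos_doc_prefix`,
# and `neg_doc_prefix` attributes, as read in `get_train_dataloader`:
#
#   trainer = ContrastiveTrainer(
#       loss_func=my_loss_fn,
#       is_vision_model=True,                    # unused, kept for backward compatibility
#       model=MyColModel(...),
#       args=training_args,                      # transformers.TrainingArguments
#       train_dataset=[train_ds_a, train_ds_b],  # a list triggers SingleDatasetBatchSampler
#       eval_dataset=eval_ds,
#       data_collator=my_collator,
#   )
#   trainer.train()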