colmodel_training.py

import os
from dataclasses import dataclass
from typing import Callable, Dict, List, Optional, Union

from peft import LoraConfig, PeftModel, get_peft_model
from transformers import (
    PreTrainedModel,
    TrainingArguments,
)

from colpali_engine.collators import VisualRetrieverCollator
from colpali_engine.data.dataset import ColPaliEngineDataset
from colpali_engine.loss.late_interaction_losses import (
    ColbertLoss,
)
from colpali_engine.trainer.contrastive_trainer import ContrastiveTrainer
from colpali_engine.utils.gpu_stats import print_gpu_utilization, print_summary
from colpali_engine.utils.processing_utils import BaseVisualRetrieverProcessor


@dataclass
class ColModelTrainingConfig:
    """
    Config class used for training a ColVision model.
    """

    model: Union[PreTrainedModel, PeftModel]
    processor: BaseVisualRetrieverProcessor
    train_dataset: Union[ColPaliEngineDataset, List[ColPaliEngineDataset]]
    eval_dataset: Optional[Union[ColPaliEngineDataset, Dict[str, ColPaliEngineDataset]]] = None
    tr_args: Optional[TrainingArguments] = None
    output_dir: Optional[str] = None
    max_length: int = 256
    run_eval: bool = True
    run_train: bool = True
    peft_config: Optional[LoraConfig] = None
    loss_func: Optional[Callable] = ColbertLoss()
    pretrained_peft_model_name_or_path: Optional[str] = None

    def __post_init__(self):
        """
        Fill in derived defaults (output directory, training arguments) and set up PEFT if requested.
        """
        # Derive a default output directory from the model name if none was given.
        if self.output_dir is None:
            sanitized_name = str(self.model.name_or_path).replace("/", "_")
            self.output_dir = f"./models/{sanitized_name}"

        if self.tr_args is None:
            print("No training arguments provided. Using default.")
            self.tr_args = TrainingArguments(output_dir=self.output_dir)
        elif self.tr_args.output_dir is None or self.tr_args.output_dir == "trainer_output":
            self.tr_args.output_dir = self.output_dir

        # Learning rates coming from YAML/JSON configs may be parsed as strings.
        if isinstance(self.tr_args.learning_rate, str):
            print("Casting learning rate to float")
            self.tr_args.learning_rate = float(self.tr_args.learning_rate)
        self.tr_args.remove_unused_columns = False

        if self.pretrained_peft_model_name_or_path is not None:
            print("Loading pretrained PEFT model")
            self.model.load_adapter(self.pretrained_peft_model_name_or_path, is_trainable=True)

        if self.peft_config is not None:
            print("Configuring PEFT model")
            if self.pretrained_peft_model_name_or_path is None:
                self.model = get_peft_model(self.model, self.peft_config)
                self.model.print_trainable_parameters()
            else:
                print(f"Adapter already loaded from {self.pretrained_peft_model_name_or_path}. Not overwriting.")

        print_gpu_utilization()


class ColModelTraining:
    """
    Class that contains the training and evaluation logic for a ColVision model.
    """

    def __init__(self, config: ColModelTrainingConfig) -> None:
        self.config = config
        self.model = self.config.model
        # Record the current git commit so saved checkpoints can be traced back to the code version.
        self.current_git_hash = os.popen("git rev-parse HEAD").read().strip()
        self.train_dataset = self.config.train_dataset
        self.eval_dataset = self.config.eval_dataset
        self.collator = VisualRetrieverCollator(
            processor=self.config.processor,
            max_length=self.config.max_length,
        )

    def train(self) -> None:
        trainer = ContrastiveTrainer(
            model=self.model,
            train_dataset=self.train_dataset,
            eval_dataset=self.eval_dataset,
            args=self.config.tr_args,
            data_collator=self.collator,
            loss_func=self.config.loss_func,
            is_vision_model=self.config.processor is not None,
        )
        trainer.args.remove_unused_columns = False

        result = trainer.train(resume_from_checkpoint=self.config.tr_args.resume_from_checkpoint)
        print_summary(result)

    def eval(self) -> None:
        raise NotImplementedError("Evaluation is not implemented yet.")

    def save(self):
        """
        Save the model and processor to the output directory, along with the git hash
        of the commit recorded at the beginning of training.
        """
        self.model.save_pretrained(self.config.output_dir)
        self.config.processor.save_pretrained(self.config.output_dir)

        # Save the git hash of the commit at the beginning of training.
        with open(f"{self.config.output_dir}/git_hash.txt", "w") as f:
            f.write(self.current_git_hash)
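

if __name__ == "__main__":
    # Illustrative usage sketch only, not part of the original module: it assumes the
    # ColPali / ColPaliProcessor classes from `colpali_engine.models` and the public
    # "vidore/colpali-v1.3" checkpoint; swap in your own model, processor, and dataset.
    # The dataset construction is left as a placeholder because it depends on your corpus.
    import torch

    from colpali_engine.models import ColPali, ColPaliProcessor

    model = ColPali.from_pretrained("vidore/colpali-v1.3", torch_dtype=torch.bfloat16)
    processor = ColPaliProcessor.from_pretrained("vidore/colpali-v1.3")

    train_dataset = ...  # build a ColPaliEngineDataset from your query/document pairs

    config = ColModelTrainingConfig(
        model=model,
        processor=processor,
        train_dataset=train_dataset,
        tr_args=TrainingArguments(
            output_dir="./models/colpali-finetune",
            per_device_train_batch_size=4,
            learning_rate=5e-5,
            num_train_epochs=1,
        ),
    )

    training = ColModelTraining(config)
    if config.run_train:
        training.train()
        training.save()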