import os
from dataclasses import dataclass, field
from typing import Callable, Dict, List, Optional, Union

from peft import LoraConfig, PeftModel, get_peft_model
from transformers import (
    PreTrainedModel,
    TrainingArguments,
)

from colpali_engine.collators import VisualRetrieverCollator
from colpali_engine.data.dataset import ColPaliEngineDataset
from colpali_engine.loss.late_interaction_losses import (
    ColbertLoss,
)
from colpali_engine.trainer.contrastive_trainer import ContrastiveTrainer
from colpali_engine.utils.gpu_stats import print_gpu_utilization, print_summary
from colpali_engine.utils.processing_utils import BaseVisualRetrieverProcessor
@dataclass
class ColModelTrainingConfig:
    """
    Config class used for training a ColVision model.

    Bundles everything `ColModelTraining` needs: the model and processor,
    train/eval datasets, HF `TrainingArguments`, the contrastive loss, and
    optional PEFT (LoRA) settings. `__post_init__` fills in defaults and
    wires up adapters.
    """

    # Model to train; may already be PEFT-wrapped.
    model: Union[PreTrainedModel, PeftModel]
    # Processor that turns images/queries into model inputs (used by the collator).
    processor: BaseVisualRetrieverProcessor
    train_dataset: Union[ColPaliEngineDataset, List[ColPaliEngineDataset]]
    eval_dataset: Optional[Union[ColPaliEngineDataset, Dict[str, ColPaliEngineDataset]]] = None
    tr_args: Optional[TrainingArguments] = None
    output_dir: Optional[str] = None
    max_length: int = 256
    run_eval: bool = True
    run_train: bool = True
    peft_config: Optional[LoraConfig] = None
    # default_factory gives each config its own loss instance; the previous
    # `= ColbertLoss()` default built ONE object at class-definition time that
    # was silently shared by every config instance.
    loss_func: Optional[Callable] = field(default_factory=ColbertLoss)
    pretrained_peft_model_name_or_path: Optional[str] = None

    def __post_init__(self):
        """
        Fill in defaults (output dir, trainer args) and set up PEFT adapters.
        """
        if self.output_dir is None:
            # Derive a filesystem-safe directory name from the model id.
            sanitized_name = str(self.model.name_or_path).replace("/", "_")
            self.output_dir = f"./models/{sanitized_name}"

        if self.tr_args is None:
            print("No training arguments provided. Using default.")
            self.tr_args = TrainingArguments(output_dir=self.output_dir)
        elif self.tr_args.output_dir is None or self.tr_args.output_dir == "trainer_output":
            # "trainer_output" is presumably the transformers default — treat it as unset.
            self.tr_args.output_dir = self.output_dir

        # Learning rates coming from YAML/CLI configs may arrive as strings.
        if isinstance(self.tr_args.learning_rate, str):
            print("Casting learning rate to float")
            self.tr_args.learning_rate = float(self.tr_args.learning_rate)

        # Dataset columns are consumed by the collator, not by model.forward,
        # so the Trainer must not drop them.
        self.tr_args.remove_unused_columns = False

        if self.pretrained_peft_model_name_or_path is not None:
            print("Loading pretrained PEFT model")
            self.model.load_adapter(self.pretrained_peft_model_name_or_path, is_trainable=True)

        if self.peft_config is not None:
            print("Configurating PEFT model")
            if self.pretrained_peft_model_name_or_path is None:
                self.model = get_peft_model(self.model, self.peft_config)
                self.model.print_trainable_parameters()
            else:
                # An adapter was already attached above; don't wrap again.
                print(f"Adapter already loaded from {self.pretrained_peft_model_name_or_path}. Not overwriting.")

        print_gpu_utilization()
class ColModelTraining:
    """
    Class that contains the training and evaluation logic for a ColVision model.
    """

    def __init__(self, config: ColModelTrainingConfig) -> None:
        self.config = config
        self.model = config.model
        self.train_dataset = config.train_dataset
        self.eval_dataset = config.eval_dataset
        self.collator = VisualRetrieverCollator(
            processor=config.processor,
            max_length=config.max_length,
        )
        # Record the current commit so training runs are reproducible.
        self.current_git_hash = os.popen("git rev-parse HEAD").read().strip()

    def train(self) -> None:
        """Run contrastive training with the settings held in `self.config`."""
        tr_args = self.config.tr_args
        trainer = ContrastiveTrainer(
            model=self.model,
            train_dataset=self.train_dataset,
            eval_dataset=self.eval_dataset,
            args=tr_args,
            data_collator=self.collator,
            loss_func=self.config.loss_func,
            is_vision_model=self.config.processor is not None,
        )
        # Re-assert after Trainer init: collator-only columns must survive.
        trainer.args.remove_unused_columns = False

        result = trainer.train(resume_from_checkpoint=tr_args.resume_from_checkpoint)
        print_summary(result)

    def eval(self) -> None:
        raise NotImplementedError("Evaluation is not implemented yet.")

    def save(self):
        """
        Save the model with its training config, as well as the tokenizer and processor if provided.
        """
        out_dir = self.config.output_dir
        self.model.save_pretrained(out_dir)
        self.config.processor.save_pretrained(out_dir)

        # Save git hash of the commit at beginning of training
        with open(f"{self.config.output_dir}/git_hash.txt", "w") as f:
            f.write(self.current_git_hash)