# train_colqwen25_model.py
  1. import argparse
  2. import shutil
  3. from pathlib import Path
  4. import torch
  5. from datasets import load_dataset
  6. from peft import LoraConfig
  7. from transformers import TrainingArguments
  8. from colpali_engine.data.dataset import ColPaliEngineDataset
  9. from colpali_engine.loss.late_interaction_losses import ColbertLoss, ColbertPairwiseCELoss
  10. from colpali_engine.models import ColQwen2_5, ColQwen2_5_Processor
  11. from colpali_engine.trainer.colmodel_torch_training import ColModelTorchTraining
  12. from colpali_engine.trainer.colmodel_training import ColModelTraining, ColModelTrainingConfig
  13. from colpali_engine.utils.dataset_transformation import load_train_set
  14. def parse_args():
  15. p = argparse.ArgumentParser()
  16. p.add_argument("--output-dir", type=str, required=True, help="where to write model + script copy")
  17. p.add_argument("--lr", type=float, default=2e-4, help="learning rate")
  18. p.add_argument("--tau", type=float, default=0.02, help="temperature for loss function")
  19. p.add_argument("--trainer", type=str, default="hf", choices=["torch", "hf"], help="trainer to use")
  20. p.add_argument("--loss", type=str, default="ce", choices=["ce", "pairwise"], help="loss function to use")
  21. p.add_argument("--peft", action="store_true", help="use PEFT for training")
  22. return p.parse_args()
  23. if __name__ == "__main__":
  24. args = parse_args()
  25. if args.loss == "ce":
  26. loss_func = ColbertLoss(
  27. temperature=args.tau,
  28. normalize_scores=True,
  29. use_smooth_max=False,
  30. pos_aware_negative_filtering=False,
  31. )
  32. elif args.loss == "pairwise":
  33. loss_func = ColbertPairwiseCELoss(
  34. normalize_scores=False,
  35. )
  36. else:
  37. raise ValueError(f"Unknown loss function: {args.loss}")
  38. config = ColModelTrainingConfig(
  39. output_dir=args.output_dir,
  40. processor=ColQwen2_5_Processor.from_pretrained(
  41. pretrained_model_name_or_path="./models/base_models/colqwen2.5-base",
  42. max_num_visual_tokens=768,
  43. ),
  44. model=ColQwen2_5.from_pretrained(
  45. pretrained_model_name_or_path="./models/base_models/colqwen2.5-base",
  46. torch_dtype=torch.bfloat16,
  47. use_cache=False,
  48. attn_implementation="flash_attention_2",
  49. ),
  50. train_dataset=load_train_set(),
  51. eval_dataset=ColPaliEngineDataset(
  52. load_dataset("./data_dir/colpali_train_set", split="test"), pos_target_column_name="image"
  53. ),
  54. run_eval=True,
  55. loss_func=loss_func,
  56. tr_args=TrainingArguments(
  57. output_dir=None,
  58. overwrite_output_dir=True,
  59. num_train_epochs=5,
  60. per_device_train_batch_size=64,
  61. gradient_checkpointing=True,
  62. gradient_checkpointing_kwargs={"use_reentrant": False},
  63. per_device_eval_batch_size=16,
  64. eval_strategy="steps",
  65. dataloader_num_workers=8,
  66. save_steps=500,
  67. logging_steps=10,
  68. eval_steps=100,
  69. warmup_steps=100,
  70. learning_rate=args.lr,
  71. save_total_limit=1,
  72. ),
  73. peft_config=LoraConfig(
  74. r=32,
  75. lora_alpha=32,
  76. lora_dropout=0.1,
  77. init_lora_weights="gaussian",
  78. bias="none",
  79. task_type="FEATURE_EXTRACTION",
  80. target_modules="(.*(model)(?!.*visual).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$|.*(custom_text_proj).*$)",
  81. )
  82. if args.peft
  83. else None,
  84. )
  85. # make sure output_dir exists and copy script for provenance
  86. Path(config.output_dir).mkdir(parents=True, exist_ok=True)
  87. shutil.copy(Path(__file__), Path(config.output_dir) / Path(__file__).name)
  88. trainer = ColModelTraining(config) if args.trainer == "hf" else ColModelTorchTraining(config)
  89. trainer.train()
  90. trainer.save()