# train_colqwenomni_model.py (3.9 KB)
  1. import argparse
  2. import shutil
  3. from pathlib import Path
  4. import torch
  5. from datasets import load_dataset
  6. from peft import LoraConfig
  7. from transformers import TrainingArguments
  8. from colpali_engine.data.dataset import ColPaliEngineDataset
  9. from colpali_engine.loss.late_interaction_losses import ColbertLoss, ColbertPairwiseCELoss
  10. from colpali_engine.models import ColQwen2_5Omni, ColQwen2_5OmniProcessor
  11. from colpali_engine.trainer.colmodel_torch_training import ColModelTorchTraining
  12. from colpali_engine.trainer.colmodel_training import ColModelTraining, ColModelTrainingConfig
  13. def parse_args():
  14. p = argparse.ArgumentParser()
  15. p.add_argument("--output-dir", type=str, required=True, help="where to write model + script copy")
  16. p.add_argument("--lr", type=float, default=2e-4, help="learning rate")
  17. p.add_argument("--tau", type=float, default=0.02, help="temperature for loss function")
  18. p.add_argument("--trainer", type=str, default="hf", choices=["torch", "hf"], help="trainer to use")
  19. p.add_argument("--loss", type=str, default="ce", choices=["ce", "pairwise"], help="loss function to use")
  20. p.add_argument("--peft", action="store_true", help="use PEFT for training")
  21. return p.parse_args()
  22. if __name__ == "__main__":
  23. args = parse_args()
  24. if args.loss == "ce":
  25. loss_func = ColbertLoss(
  26. temperature=args.tau,
  27. normalize_scores=True,
  28. use_smooth_max=False,
  29. pos_aware_negative_filtering=False,
  30. )
  31. elif args.loss == "pairwise":
  32. loss_func = ColbertPairwiseCELoss(
  33. normalize_scores=False,
  34. )
  35. else:
  36. raise ValueError(f"Unknown loss function: {args.loss}")
  37. config = ColModelTrainingConfig(
  38. output_dir=args.output_dir,
  39. processor=ColQwen2_5OmniProcessor.from_pretrained(
  40. pretrained_model_name_or_path="./models/base_models/colqwen2.5omni-base",
  41. ),
  42. model=ColQwen2_5Omni.from_pretrained(
  43. pretrained_model_name_or_path="./models/base_models/colqwen2.5omni-base",
  44. torch_dtype=torch.bfloat16,
  45. attn_implementation="flash_attention_2",
  46. ),
  47. train_dataset=ColPaliEngineDataset(
  48. load_dataset("./data_dir/colpali_train_set", split="train"), pos_target_column_name="image"
  49. ),
  50. eval_dataset=ColPaliEngineDataset(
  51. load_dataset("./data_dir/colpali_train_set", split="test"), pos_target_column_name="image"
  52. ),
  53. run_eval=True,
  54. loss_func=loss_func,
  55. tr_args=TrainingArguments(
  56. output_dir=None,
  57. overwrite_output_dir=True,
  58. num_train_epochs=5,
  59. per_device_train_batch_size=32,
  60. gradient_checkpointing=True,
  61. gradient_checkpointing_kwargs={"use_reentrant": False},
  62. per_device_eval_batch_size=16,
  63. eval_strategy="steps",
  64. dataloader_num_workers=8,
  65. save_steps=500,
  66. logging_steps=10,
  67. eval_steps=100,
  68. warmup_steps=100,
  69. learning_rate=args.lr,
  70. save_total_limit=1,
  71. dataloader_prefetch_factor=2,
  72. dataloader_pin_memory=True,
  73. dataloader_persistent_workers=True,
  74. ),
  75. peft_config=LoraConfig(
  76. r=32,
  77. lora_alpha=32,
  78. lora_dropout=0.1,
  79. init_lora_weights="gaussian",
  80. bias="none",
  81. task_type="FEATURE_EXTRACTION",
  82. target_modules="(.*(model)(?!.*visual).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$|.*(custom_text_proj).*$)",
  83. )
  84. if args.peft
  85. else None,
  86. )
  87. # make sure output_dir exists and copy script for provenance
  88. Path(config.output_dir).mkdir(parents=True, exist_ok=True)
  89. shutil.copy(Path(__file__), Path(config.output_dir) / Path(__file__).name)
  90. trainer = ColModelTraining(config) if args.trainer == "hf" else ColModelTorchTraining(config)
  91. trainer.train()
  92. trainer.save()