train_colpali_model.yaml

config:
  (): colpali_engine.trainer.colmodel_training.ColModelTrainingConfig
  output_dir: !path ../../../models/right_pad/train_colpali-3b-mix-448
  processor:
    (): colpali_engine.utils.transformers_wrappers.AllPurposeWrapper
    class_to_instanciate: !ext colpali_engine.models.ColPaliProcessor
    pretrained_model_name_or_path: "./models/colpaligemma-3b-mix-448-base" # "./models/paligemma-3b-mix-448"
    max_length: 50
  model:
    (): colpali_engine.utils.transformers_wrappers.AllPurposeWrapper
    class_to_instanciate: !ext colpali_engine.models.ColPali
    pretrained_model_name_or_path: "./models/colpaligemma-3b-mix-448-base"
    torch_dtype: !ext torch.bfloat16
    # device_map: "auto"
    # quantization_config:
    #   (): transformers.BitsAndBytesConfig
    #   load_in_4bit: true
    #   bnb_4bit_quant_type: "nf4"
    #   bnb_4bit_compute_dtype: "bfloat16"
    #   bnb_4bit_use_double_quant: true

  train_dataset:
    (): colpali_engine.utils.dataset_transformation.load_train_set
  eval_dataset: !import ../data/test_data.yaml
  max_length: 50
  run_eval: true
  loss_func:
    (): colpali_engine.loss.late_interaction_losses.ColbertPairwiseCELoss
  tr_args: !import ../tr_args/default_tr_args.yaml
  peft_config:
    (): peft.LoraConfig
    r: 32
    lora_alpha: 32
    lora_dropout: 0.1
    init_lora_weights: "gaussian"
    bias: "none"
    task_type: "FEATURE_EXTRACTION"
    target_modules: '(.*(language_model).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$|.*(custom_text_proj).*$)'
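
For reference, here is a minimal Python sketch of the objects the leaf sections of this config map to, using the documented colpali_engine, transformers, and peft APIs. In practice the training engine builds these from ColModelTrainingConfig, so this is an illustration of the configured components rather than the engine's own loader; the local checkpoint path is the one referenced above and is assumed to exist.

```python
# Illustrative only: the objects the YAML sections above resolve to.
import torch
from peft import LoraConfig, get_peft_model
from transformers import BitsAndBytesConfig

from colpali_engine.models import ColPali, ColPaliProcessor

MODEL_PATH = "./models/colpaligemma-3b-mix-448-base"  # same checkpoint as in the config

processor = ColPaliProcessor.from_pretrained(MODEL_PATH)

# Optional 4-bit quantization, mirroring the commented-out block in the config.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)

model = ColPali.from_pretrained(
    MODEL_PATH,
    torch_dtype=torch.bfloat16,
    # quantization_config=bnb_config,  # uncomment to load the base model in 4-bit
)

# LoRA adapter matching peft_config: the language model's attention/MLP
# projections plus the custom text projection head.
lora_config = LoraConfig(
    r=32,
    lora_alpha=32,
    lora_dropout=0.1,
    init_lora_weights="gaussian",
    bias="none",
    task_type="FEATURE_EXTRACTION",
    target_modules=(
        r"(.*(language_model).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$"
        r"|.*(custom_text_proj).*$)"
    ),
)
model = get_peft_model(model, lora_config)
```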
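
The loss_func entry selects a ColBERT-style pairwise cross-entropy over late-interaction (MaxSim) scores. As a generic sketch of that scoring and loss shape, not the engine's own ColbertPairwiseCELoss implementation:

```python
# Generic sketch of ColBERT-style late-interaction scoring with an in-batch
# pairwise loss; the engine's ColbertPairwiseCELoss may differ in details.
import torch
import torch.nn.functional as F


def late_interaction_scores(q_emb: torch.Tensor, d_emb: torch.Tensor) -> torch.Tensor:
    """q_emb: (B, Nq, D) query tokens, d_emb: (B, Nd, D) document tokens.
    Returns a (B, B) MaxSim score matrix between every query and every document."""
    sim = torch.einsum("bnd,cmd->bcnm", q_emb, d_emb)  # all token-level dot products
    return sim.max(dim=3).values.sum(dim=2)            # best doc token per query token, summed


def pairwise_loss(scores: torch.Tensor) -> torch.Tensor:
    """Push each query's positive (diagonal) score above its hardest in-batch negative."""
    pos = scores.diagonal()
    mask = torch.eye(scores.size(0), dtype=torch.bool, device=scores.device)
    neg = scores.masked_fill(mask, float("-inf")).max(dim=1).values
    return F.softplus(neg - pos).mean()
```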