train_bipali_model.yaml

config:
  (): colpali_engine.trainer.colmodel_training.ColModelTrainingConfig
  output_dir: !path ../../../models/right_pad/train_bipali_reproduction
  processor:
    (): colpali_engine.utils.transformers_wrappers.AllPurposeWrapper
    class_to_instanciate: !ext colpali_engine.models.BiPaliProcessor
    pretrained_model_name_or_path: "./models/paligemma-3b-mix-448"
  model:
    (): colpali_engine.utils.transformers_wrappers.AllPurposeWrapper
    class_to_instanciate: !ext colpali_engine.models.BiPali
    pretrained_model_name_or_path: "./models/paligemma-3b-mix-448"
    torch_dtype: !ext torch.bfloat16
    # device_map: "auto"
    # quantization_config:
    #   (): transformers.BitsAndBytesConfig
    #   load_in_4bit: true
    #   bnb_4bit_quant_type: "nf4"
    #   bnb_4bit_compute_dtype: "bfloat16"
    #   bnb_4bit_use_double_quant: true

  train_dataset:
    (): colpali_engine.utils.dataset_transformation.load_train_set_detailed
  eval_dataset: !import ../data/test_data.yaml

  max_length: 50
  run_eval: true
  loss_func:
    (): colpali_engine.loss.bi_encoder_losses.BiEncoderLoss
  tr_args: !import ../tr_args/default_tr_args.yaml

  peft_config:
    (): peft.LoraConfig
    r: 32
    lora_alpha: 32
    lora_dropout: 0.1
    init_lora_weights: "gaussian"
    bias: "none"
    task_type: "FEATURE_EXTRACTION"
    target_modules: '(.*(language_model).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$)'
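
The "()" keys and the !path, !ext, and !import tags follow the conventions of Illuin's configue YAML loader: "()" names a class or factory to instantiate with the sibling keys as keyword arguments, !ext references an external Python object, !path resolves a path relative to this file, and !import inlines another config file. A minimal loading sketch, assuming configue's load entry point and that the relative paths referenced above exist:

import configue  # assumes configue is the loader used for these configs

# Loading resolves every "()" entry into a live Python object, so the
# top-level "config" key comes back as a ColModelTrainingConfig with the
# processor, model, loss function and PEFT config already instantiated.
training_config = configue.load("train_bipali_model.yaml")["config"]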
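
The commented-out block under model shows how to run the same fine-tuning with the base model quantized to 4 bits. In plain transformers code the equivalent configuration would be (a sketch; the from_pretrained call shown is illustrative):

import torch
from transformers import BitsAndBytesConfig

# Mirrors the commented-out quantization_config above: NF4 4-bit weights,
# bfloat16 compute, and nested ("double") quantization for extra memory savings.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)
# model = AutoModel.from_pretrained(..., quantization_config=bnb_config, device_map="auto")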
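
loss_func points at BiEncoderLoss, an in-batch contrastive objective over the pooled query and document embeddings produced by BiPali. A minimal sketch of that style of loss, assuming unnormalized dot-product scores (the library's exact implementation may differ in details such as temperature or normalization):

import torch
import torch.nn.functional as F

def in_batch_bi_encoder_loss(query_emb: torch.Tensor, doc_emb: torch.Tensor) -> torch.Tensor:
    # Cross-entropy over the (batch x batch) similarity matrix: each query
    # is trained to score its own document higher than the other in-batch docs.
    scores = query_emb @ doc_emb.T
    labels = torch.arange(scores.size(0), device=scores.device)
    return F.cross_entropy(scores, labels)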
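
target_modules is a string, which PEFT treats as a regular expression and matches against each module name with re.fullmatch, so LoRA adapters are injected only into the language model's attention and MLP projections while the vision tower stays untouched. A quick check against illustrative (hypothetical) PaliGemma module names:

import re

pattern = r'(.*(language_model).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$)'
names = [
    "language_model.model.layers.0.self_attn.q_proj",  # matched -> gets a LoRA adapter
    "language_model.model.layers.0.mlp.up_proj",       # matched
    "vision_tower.encoder.layers.0.self_attn.q_proj",  # not matched: outside language_model
]
for name in names:
    print(f"{name}: {bool(re.fullmatch(pattern, name))}")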