train_colqwen2_model.yaml

config:
  (): colpali_engine.trainer.colmodel_training.ColModelTrainingConfig
  output_dir: !path ../../../models/colqwen2-cesmoothmax-5e-2604
  processor:
    (): colpali_engine.utils.transformers_wrappers.AllPurposeWrapper
    class_to_instanciate: !ext colpali_engine.models.ColQwen2Processor
    pretrained_model_name_or_path: "./models/base_models/colqwen2-base"
    max_num_visual_tokens: 1024
  model:
    (): colpali_engine.utils.transformers_wrappers.AllPurposeWrapper
    class_to_instanciate: !ext colpali_engine.models.ColQwen2
    pretrained_model_name_or_path: "./models/base_models/colqwen2-base"
    torch_dtype: !ext torch.bfloat16
    use_cache: false
    attn_implementation: "flash_attention_2"

  train_dataset:
    (): colpali_engine.utils.dataset_transformation.load_train_set
  eval_dataset: !import ../data/test_data.yaml
  # max_length: 50
  run_eval: true

  loss_func:
    (): colpali_engine.loss.late_interaction_losses.ColbertPairwiseCELoss

  tr_args:
    (): transformers.training_args.TrainingArguments
    output_dir: null
    overwrite_output_dir: true
    num_train_epochs: 5
    per_device_train_batch_size: 64
    gradient_checkpointing: true
    gradient_checkpointing_kwargs: { "use_reentrant": false }
    # 6 x 8 gpus = 48 batch size
    # gradient_accumulation_steps: 4
    per_device_eval_batch_size: 8
    eval_strategy: "steps"
    dataloader_num_workers: 16
    # bf16: true
    save_steps: 500
    logging_steps: 10
    eval_steps: 100
    warmup_steps: 100
    learning_rate: 2e-4
    save_total_limit: 1
    # resume_from_checkpoint: true
    # optim: "paged_adamw_8bit"
    # wandb logging
    # wandb_project: "colqwen2"
    # run_name: "colqwen2-ba32-nolora"
    report_to: "wandb"

  peft_config:
    (): peft.LoraConfig
    r: 32
    lora_alpha: 32
    lora_dropout: 0.1
    init_lora_weights: "gaussian"
    bias: "none"
    task_type: "FEATURE_EXTRACTION"
    target_modules: '(.*(model).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$|.*(custom_text_proj).*$)'
    # target_modules: '(.*(language_model).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$|.*(custom_text_proj).*$)'
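
The `()` keys and the `!path`, `!ext`, and `!import` tags follow the configue YAML-loading convention that colpali-engine's training scripts build on. A minimal sketch of how the file could be consumed, assuming a configue install and a colpali-engine checkout (the `ColModelTraining` entry point and its method names are assumptions and may differ between colpali-engine versions):

# Sketch: load the config and run training. Assumes the configue package;
# the ColModelTraining entry point and method names are not verified here.
import configue
from colpali_engine.trainer.colmodel_training import ColModelTraining

# configue resolves the (), !ext, !path and !import directives while the
# YAML is parsed: the processor, the ColQwen2 model, the TrainingArguments
# and the LoraConfig come back as live objects.
loaded = configue.load("train_colqwen2_model.yaml")
training_config = loaded["config"]  # ColModelTrainingConfig instance

app = ColModelTraining(training_config)
app.train()
app.save()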
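
The processor and model entries amount to two `from_pretrained` calls with the listed kwargs. A rough hand-written equivalent, assuming `AllPurposeWrapper` simply forwards the remaining keys to `from_pretrained` (an assumption, not checked against the wrapper itself):

# Rough equivalent of the processor/model blocks above; kwargs mirror the
# YAML and are assumed to be passed through to from_pretrained.
import torch
from colpali_engine.models import ColQwen2, ColQwen2Processor

processor = ColQwen2Processor.from_pretrained(
    "./models/base_models/colqwen2-base",
    max_num_visual_tokens=1024,
)
model = ColQwen2.from_pretrained(
    "./models/base_models/colqwen2-base",
    torch_dtype=torch.bfloat16,
    use_cache=False,
    attn_implementation="flash_attention_2",
)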
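
Because `target_modules` is given as a string, PEFT treats it as a regular expression and applies it with `re.fullmatch` to each module's qualified name, so LoRA is intended to land on the language-model projection layers and `custom_text_proj` (the commented-out alternative spells out the same intent), leaving the vision tower untouched. A quick sanity check of the pattern; the module names below are illustrative examples, not read from the actual model:

# Sketch: check which module names the LoRA target_modules regex captures.
import re

pattern = (
    r"(.*(model).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$"
    r"|.*(custom_text_proj).*$)"
)

candidates = [
    "model.layers.0.self_attn.q_proj",  # language-model attention -> LoRA
    "model.layers.0.mlp.down_proj",     # language-model MLP       -> LoRA
    "custom_text_proj",                 # late-interaction head    -> LoRA
    "visual.blocks.0.attn.qkv",         # vision tower             -> frozen
]

for name in candidates:
    print(f"{name}: {'LoRA' if re.fullmatch(pattern, name) else 'frozen'}")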