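# Training config for ColQwen2 (colpali_engine): LoRA fine-tune of the
# colqwen2-base checkpoint with a late-interaction pairwise CE loss,
# logging to Weights & Biases.
# Typical launch (assumed repo layout, adjust paths as needed):
#   accelerate launch scripts/train/train_colbert.py path/to/this_config.yaml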
config:
  (): colpali_engine.trainer.colmodel_training.ColModelTrainingConfig
  output_dir: !path ../../../models/colqwen2-cesmoothmax-5e-2604
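
  # Processor: ColQwen2Processor handles query tokenization and image patching;
  # max_num_visual_tokens caps the visual sequence length per page.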
  processor:
    (): colpali_engine.utils.transformers_wrappers.AllPurposeWrapper
    class_to_instanciate: !ext colpali_engine.models.ColQwen2Processor
    pretrained_model_name_or_path: "./models/base_models/colqwen2-base"
    max_num_visual_tokens: 1024
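
  # Model: ColQwen2 backbone loaded in bfloat16 with FlashAttention 2;
  # the KV cache is disabled since it is not needed during training.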
  model:
    (): colpali_engine.utils.transformers_wrappers.AllPurposeWrapper
    class_to_instanciate: !ext colpali_engine.models.ColQwen2
    pretrained_model_name_or_path: "./models/base_models/colqwen2-base"
    torch_dtype: !ext torch.bfloat16
    use_cache: false
    attn_implementation: "flash_attention_2"
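
  # Data: training set built by colpali_engine's load_train_set helper;
  # evaluation datasets are imported from a separate YAML manifest.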
  train_dataset:
    (): colpali_engine.utils.dataset_transformation.load_train_set
  eval_dataset: !import ../data/test_data.yaml
  # max_length: 50
  run_eval: true
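
  # Loss: ColBERT-style pairwise cross-entropy over late-interaction (MaxSim)
  # similarity scores, using in-batch negatives.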
  loss_func:
    (): colpali_engine.loss.late_interaction_losses.ColbertPairwiseCELoss
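
  # Trainer arguments (HuggingFace TrainingArguments). output_dir is left
  # null here; the top-level output_dir above is expected to take precedence.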
  tr_args:
    (): transformers.training_args.TrainingArguments
    output_dir: null
    overwrite_output_dir: true
    num_train_epochs: 5
    per_device_train_batch_size: 64
    gradient_checkpointing: true
    gradient_checkpointing_kwargs: { "use_reentrant": false }
    # 6 x 8 gpus = 48 batch size
    # gradient_accumulation_steps: 4
    per_device_eval_batch_size: 8
    eval_strategy: "steps"
    dataloader_num_workers: 16
    # bf16: true
    save_steps: 500
    logging_steps: 10
    eval_steps: 100
    warmup_steps: 100
    learning_rate: 2e-4
    save_total_limit: 1
    # resume_from_checkpoint: true
    # optim: "paged_adamw_8bit"
    # wandb logging
    # wandb_project: "colqwen2"
    # run_name: "colqwen2-ba32-nolora"
    report_to: "wandb"
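
  # PEFT: LoRA adapters on the attention and MLP projections plus the
  # custom_text_proj head, selected by the regex in target_modules.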
  peft_config:
    (): peft.LoraConfig
    r: 32
    lora_alpha: 32
    lora_dropout: 0.1
    init_lora_weights: "gaussian"
    bias: "none"
    task_type: "FEATURE_EXTRACTION"
    target_modules: '(.*(model).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$|.*(custom_text_proj).*$)'
    # target_modules: '(.*(language_model).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$|.*(custom_text_proj).*$)'