
Add a multimodal embedding project

lookathis@163.com, 1 month ago
parent
commit
0888b7c23d
changed 100 files with 8296 additions and 0 deletions
  1. 289 0
      deconstruct_SQI/colpali/CHANGELOG.md
  2. 37 0
      deconstruct_SQI/colpali/CITATION.cff
  3. 21 0
      deconstruct_SQI/colpali/LICENSE
  4. 452 0
      deconstruct_SQI/colpali/README.md
  5. binary
      deconstruct_SQI/colpali/assets/colpali_architecture.webp
  6. 22 0
      deconstruct_SQI/colpali/colpali_engine/__init__.py
  7. 1 0
      deconstruct_SQI/colpali/colpali_engine/collators/__init__.py
  8. 128 0
      deconstruct_SQI/colpali/colpali_engine/collators/visual_retriever_collator.py
  9. 6 0
      deconstruct_SQI/colpali/colpali_engine/compression/__init__.py
  10. 3 0
      deconstruct_SQI/colpali/colpali_engine/compression/token_pooling/__init__.py
  11. 164 0
      deconstruct_SQI/colpali/colpali_engine/compression/token_pooling/base_token_pooling.py
  12. 146 0
      deconstruct_SQI/colpali/colpali_engine/compression/token_pooling/hierarchical_token_pooling.py
  13. 89 0
      deconstruct_SQI/colpali/colpali_engine/compression/token_pooling/lambda_token_pooling.py
  14. 2 0
      deconstruct_SQI/colpali/colpali_engine/data/__init__.py
  15. 162 0
      deconstruct_SQI/colpali/colpali_engine/data/dataset.py
  16. 107 0
      deconstruct_SQI/colpali/colpali_engine/data/sampler.py
  17. 8 0
      deconstruct_SQI/colpali/colpali_engine/interpretability/__init__.py
  18. 84 0
      deconstruct_SQI/colpali/colpali_engine/interpretability/similarity_map_utils.py
  19. 150 0
      deconstruct_SQI/colpali/colpali_engine/interpretability/similarity_maps.py
  20. 16 0
      deconstruct_SQI/colpali/colpali_engine/loss/__init__.py
  21. 418 0
      deconstruct_SQI/colpali/colpali_engine/loss/bi_encoder_losses.py
  22. 465 0
      deconstruct_SQI/colpali/colpali_engine/loss/late_interaction_losses.py
  23. 5 0
      deconstruct_SQI/colpali/colpali_engine/models/__init__.py
  24. 2 0
      deconstruct_SQI/colpali/colpali_engine/models/idefics3/__init__.py
  25. 2 0
      deconstruct_SQI/colpali/colpali_engine/models/idefics3/biidefics3/__init__.py
  26. 57 0
      deconstruct_SQI/colpali/colpali_engine/models/idefics3/biidefics3/modeling_biidefics3.py
  27. 40 0
      deconstruct_SQI/colpali/colpali_engine/models/idefics3/biidefics3/processing_biidefics3.py
  28. 2 0
      deconstruct_SQI/colpali/colpali_engine/models/idefics3/colidefics3/__init__.py
  29. 46 0
      deconstruct_SQI/colpali/colpali_engine/models/idefics3/colidefics3/modeling_colidefics3.py
  30. 76 0
      deconstruct_SQI/colpali/colpali_engine/models/idefics3/colidefics3/processing_colidefics3.py
  31. 2 0
      deconstruct_SQI/colpali/colpali_engine/models/modernvbert/__init__.py
  32. 2 0
      deconstruct_SQI/colpali/colpali_engine/models/modernvbert/bivbert/__init__.py
  33. 66 0
      deconstruct_SQI/colpali/colpali_engine/models/modernvbert/bivbert/modeling_bimodernvbert.py
  34. 36 0
      deconstruct_SQI/colpali/colpali_engine/models/modernvbert/bivbert/processing_bimodernvbert.py
  35. 2 0
      deconstruct_SQI/colpali/colpali_engine/models/modernvbert/colvbert/__init__.py
  36. 52 0
      deconstruct_SQI/colpali/colpali_engine/models/modernvbert/colvbert/modeling_colmodernvbert.py
  37. 78 0
      deconstruct_SQI/colpali/colpali_engine/models/modernvbert/colvbert/processing_colmodernvbert.py
  38. 279 0
      deconstruct_SQI/colpali/colpali_engine/models/modernvbert/configuration_modernvbert.py
  39. 476 0
      deconstruct_SQI/colpali/colpali_engine/models/modernvbert/modeling_modernvbert.py
  40. 2 0
      deconstruct_SQI/colpali/colpali_engine/models/paligemma/__init__.py
  41. 2 0
      deconstruct_SQI/colpali/colpali_engine/models/paligemma/bipali/__init__.py
  42. 144 0
      deconstruct_SQI/colpali/colpali_engine/models/paligemma/bipali/modeling_bipali.py
  43. 26 0
      deconstruct_SQI/colpali/colpali_engine/models/paligemma/bipali/processing_bipali.py
  44. 2 0
      deconstruct_SQI/colpali/colpali_engine/models/paligemma/colpali/__init__.py
  45. 114 0
      deconstruct_SQI/colpali/colpali_engine/models/paligemma/colpali/modeling_colpali.py
  46. 89 0
      deconstruct_SQI/colpali/colpali_engine/models/paligemma/colpali/processing_colpali.py
  47. 2 0
      deconstruct_SQI/colpali/colpali_engine/models/qwen2/__init__.py
  48. 2 0
      deconstruct_SQI/colpali/colpali_engine/models/qwen2/biqwen2/__init__.py
  49. 76 0
      deconstruct_SQI/colpali/colpali_engine/models/qwen2/biqwen2/modeling_biqwen2.py
  50. 43 0
      deconstruct_SQI/colpali/colpali_engine/models/qwen2/biqwen2/processing_biqwen2.py
  51. 2 0
      deconstruct_SQI/colpali/colpali_engine/models/qwen2/colqwen2/__init__.py
  52. 71 0
      deconstruct_SQI/colpali/colpali_engine/models/qwen2/colqwen2/modeling_colqwen2.py
  53. 149 0
      deconstruct_SQI/colpali/colpali_engine/models/qwen2/colqwen2/processing_colqwen2.py
  54. 2 0
      deconstruct_SQI/colpali/colpali_engine/models/qwen2_5/__init__.py
  55. 2 0
      deconstruct_SQI/colpali/colpali_engine/models/qwen2_5/biqwen2_5/__init__.py
  56. 86 0
      deconstruct_SQI/colpali/colpali_engine/models/qwen2_5/biqwen2_5/modeling_biqwen2_5.py
  57. 40 0
      deconstruct_SQI/colpali/colpali_engine/models/qwen2_5/biqwen2_5/processing_biqwen2_5.py
  58. 2 0
      deconstruct_SQI/colpali/colpali_engine/models/qwen2_5/colqwen2_5/__init__.py
  59. 73 0
      deconstruct_SQI/colpali/colpali_engine/models/qwen2_5/colqwen2_5/modeling_colqwen2_5.py
  60. 146 0
      deconstruct_SQI/colpali/colpali_engine/models/qwen2_5/colqwen2_5/processing_colqwen2_5.py
  61. 65 0
      deconstruct_SQI/colpali/colpali_engine/models/qwen_omni/colqwen_omni/modeling_colqwen_omni.py
  62. 229 0
      deconstruct_SQI/colpali/colpali_engine/models/qwen_omni/colqwen_omni/processing_colqwen_omni.py
  63. 3 0
      deconstruct_SQI/colpali/colpali_engine/trainer/__init__.py
  64. 261 0
      deconstruct_SQI/colpali/colpali_engine/trainer/colmodel_torch_training.py
  65. 118 0
      deconstruct_SQI/colpali/colpali_engine/trainer/colmodel_training.py
  66. 225 0
      deconstruct_SQI/colpali/colpali_engine/trainer/contrastive_trainer.py
  67. 0 0
      deconstruct_SQI/colpali/colpali_engine/utils/__init__.py
  68. 268 0
      deconstruct_SQI/colpali/colpali_engine/utils/dataset_transformation.py
  69. 24 0
      deconstruct_SQI/colpali/colpali_engine/utils/gpu_stats.py
  70. 256 0
      deconstruct_SQI/colpali/colpali_engine/utils/processing_utils.py
  71. 99 0
      deconstruct_SQI/colpali/colpali_engine/utils/torch_utils.py
  72. 20 0
      deconstruct_SQI/colpali/colpali_engine/utils/transformers_wrappers.py
  73. 86 0
      deconstruct_SQI/colpali/pyproject.toml
  74. 109 0
      deconstruct_SQI/colpali/scripts/api_call.py
  75. 131 0
      deconstruct_SQI/colpali/scripts/compute_hardnegs.py
  76. 3 0
      deconstruct_SQI/colpali/scripts/configs/data/debug_data.yaml
  77. 31 0
      deconstruct_SQI/colpali/scripts/configs/data/test_data.yaml
  78. 72 0
      deconstruct_SQI/colpali/scripts/configs/idefics/train_colsmolvlm_model.yaml
  79. 41 0
      deconstruct_SQI/colpali/scripts/configs/pali/train_bipali_all_model.yaml
  80. 41 0
      deconstruct_SQI/colpali/scripts/configs/pali/train_bipali_model.yaml
  81. 65 0
      deconstruct_SQI/colpali/scripts/configs/pali/train_bipali_pairwise_256_model.yaml
  82. 42 0
      deconstruct_SQI/colpali/scripts/configs/pali/train_bipali_pairwise_hardneg_model.yaml
  83. 42 0
      deconstruct_SQI/colpali/scripts/configs/pali/train_bipali_pairwise_model.yaml
  84. 65 0
      deconstruct_SQI/colpali/scripts/configs/pali/train_colpali2_pt_model.yaml
  85. 40 0
      deconstruct_SQI/colpali/scripts/configs/pali/train_colpali_all_model.yaml
  86. 42 0
      deconstruct_SQI/colpali/scripts/configs/pali/train_colpali_docmatix_hardneg_model.yaml
  87. 39 0
      deconstruct_SQI/colpali/scripts/configs/pali/train_colpali_docmatix_model.yaml
  88. 42 0
      deconstruct_SQI/colpali/scripts/configs/pali/train_colpali_hardneg_debug_model.yaml
  89. 62 0
      deconstruct_SQI/colpali/scripts/configs/pali/train_colpali_hardneg_model.yaml
  90. 41 0
      deconstruct_SQI/colpali/scripts/configs/pali/train_colpali_model.yaml
  91. 41 0
      deconstruct_SQI/colpali/scripts/configs/pali/train_colpali_pt_model.yaml
  92. 65 0
      deconstruct_SQI/colpali/scripts/configs/qwen2/deprecated/train_biqwen2_docmatix_model.yaml
  93. 66 0
      deconstruct_SQI/colpali/scripts/configs/qwen2/deprecated/train_biqwen2_warmup_model.yaml
  94. 67 0
      deconstruct_SQI/colpali/scripts/configs/qwen2/deprecated/train_colqwen2_docmatix_model.yaml
  95. 66 0
      deconstruct_SQI/colpali/scripts/configs/qwen2/deprecated/train_colqwen2_hardneg_model.yaml
  96. 65 0
      deconstruct_SQI/colpali/scripts/configs/qwen2/deprecated/train_colqwen2_wikiss_model.yaml
  97. 68 0
      deconstruct_SQI/colpali/scripts/configs/qwen2/train_biqwen2_hardneg_model.py
  98. 66 0
      deconstruct_SQI/colpali/scripts/configs/qwen2/train_biqwen2_hardneg_model.yaml
  99. 62 0
      deconstruct_SQI/colpali/scripts/configs/qwen2/train_biqwen2_model.yaml
  100. 100 0
      deconstruct_SQI/colpali/scripts/configs/qwen2/train_colqwen25_model.py

+ 289 - 0
deconstruct_SQI/colpali/CHANGELOG.md

@@ -0,0 +1,289 @@
+# Changelog
+
+All notable changes to this project will be documented in this file.
+
+The format is based on [Keep a Changelog](http://keepachangelog.com/)
+and this project adheres to [Semantic Versioning](http://semver.org/).
+
+## Unreleased
+
+- Add ModernVBERT to the list of supported models
+- Fix training with multiple hard negatives
+- Bump transformers, torch, and peft support
+- Fix multi-dataset sampling so that each dataset's probability of being picked is weighted by its size
+
+## [0.3.12] - 2025-07-16
+
+### Added
+- Video processing for ColQwen-Omni
+
+### Fixed
+- Fixed loading of PaliGemma and ColPali checkpoints (bug introduced in transformers 4.52)
+- Fixed loading of SmolVLM (Idefics3) processors that didn't transmit image_seq_len (bug introduced in transformers 4.52)
+
+## [0.3.11] - 2025-07-04
+
+### Added
+
+- Added BiIdefics3 modeling and processor.
+- [Breaking] (minor) Remove support for context-augmented queries and images
+- Uniform processor docstring
+- Update the collator to align with the new function signatures
+- Add a `process_text` method to replace `process_query`. The latter is still supported for now but will be deprecated later
+- Introduce the `ColPaliEngineDataset` and `Corpus` classes to delegate all data loading to a standard format before training. Users can override the dataset class if needed for their specific use cases.
+- Added smooth_max option to loss functions
+- Added weighted in_batch terms for losses with hard negatives
+- Added an option to filter out (presumably) false negatives during online training
+- Added a training script in pure torch without the HF trainer
+- Added a sampler to train with multiple datasets at once, with each batch coming from the same source. (experimental, might still need testing on multi-GPU)
+- Add score normalization to late-interaction (LI) models (dividing by token length) for better performance with the CE loss
+- Add experimental PLAID support
+
+### Changed
+
+- Stop pooling queries across GPUs and instead pool only documents, enabling training with much larger batch sizes. We now recommend training with `accelerate launch`.
+- Updated loss functions for better abstractions and coherence between the various loss functions, with small speedups and lower memory requirements.
+
+
+## [0.3.10] - 2025-04-18
+
+### Added
+
+- Add `LambdaTokenPooler` to allow for custom token pooling functions.
+- Added training losses with negatives to InfoNCE type losses
+
+### Changed
+
+- Fix similarity map helpers for ColQwen2 and ColQwen2.5.
+- [Breaking] (minor) Remove support for Idefics2-based models.
+- Disable multithreading in `HierarchicalTokenPooler` if `num_workers` is not provided or is 1.
+- [Breaking] (minor) Make `pool_factor` an argument of `pool_embeddings` instead of a `HierarchicalTokenPooler` class attribute
+- Bump dependencies for transformers, torch, peft, pillow, accelerate, etc...
+
+## [0.3.9] - 2025-04-03
+
+### Added
+
+- Allow user to pass custom textual context for passage inference
+- Add ColQwen2.5 support and BiQwen2.5 support
+- Add support for token pooling with `HierarchicalTokenPooler`.
+- Allow user to specify the maximum number of image tokens in the resized images in `ColQwen2Processor` and `ColQwen2_5_Processor`.
+
+### Changed
+
+- Warn about evaluation being different from Vidore, and do not store results to prevent confusion.
+- Remove duplicate resize code in `ColQwen2Processor` and `ColQwen2_5_Processor`.
+- Simplify sequence padding for pixel values in `ColQwen2Processor` and `ColQwen2_5_Processor`.
+- Remove deprecated evaluation (`CustomRetrievalEvaluator`) from trainer
+- Refactor the collator classes
+- Make `processor` input compulsory in `ColModelTrainingConfig`
+- Make `BaseVisualRetrieverProcessor` inherit from `ProcessorMixin`
+- Remove unused `tokenizer` field from `ColModelTrainingConfig`
+- Bump transformers to `4.50.0` and torch to `2.6.0` to keep up with the latest versions. Note that this leads to errors on mps until transformers 4.50.4 is released.
+
+## [0.3.8] - 2025-01-29
+
+### Fixed
+
+- Fix peft version in `colpali-engine[train]`
+- Loosen upper bound for `accelerate`
+
+### Tests
+
+- Reorganize modeling tests
+- Add test for ColIdefics3 (and ColSmol)
+
+## [0.3.7] - 2025-01-28
+
+### Changed
+
+- Bump transformers to `4.47` to support `colSmol-256M` and `colSmol-500M`
+
+### Fixed
+
+- Fix checkpoints used for ColQwen2 tests
+
+## [0.3.6] - 2025-01-10
+
+### Added
+
+- Add expected scores in ColPali E2E test
+
+### Changed
+
+- Loosen package dependencies
+
+## [0.3.5] - 2024-12-13
+
+### Added
+
+- Added support for Idefics3 (and SmolVLM)
+
+### Fixed
+
+- Fix typing for `processor.score_multi_vector` (allow for both list and tensor inputs). This does not change how the scores are computed.
+- Fix `tear_down_torch` when used on a non-MPS machine
+
+## [0.3.4] - 2024-11-07
+
+### Added
+
+- General `CorpusQueryCollator` for BEIR-style dataset training or hard-negative training. This deprecates `HardNegCollator`, but all changes to the training loop are made for a seamless update.
+
+### Changed
+
+- Updates BiPali config files
+- Removed query augmentation tokens from BiQwen2Processor
+- Modified XQwen2Processor to place `<|endoftext|>` token at the end of the document prompt (non-breaking for ColQwen but helps BiQwen).
+- Removed `add_suffix` in the VisualRetrieverCollator and let the `suffix` be added in the individual processors.
+- Changed the incorrect `<pad>` token to `<|endoftext|>` for query augmentation in `ColQwen2Processor`. Note that previous models were trained with `<|endoftext|>`, so this is simply a non-breaking inference upgrade patch.
+
+## [0.3.3] - 2024-10-29
+
+### Added
+
+- Add BiQwen2 model
+
+### Changed
+
+- Modified ColQwen and BiQwen to prevent the useless forward pass in the last layer of the original model (classification head)
+- Bumped "breaking" dependency versions for MTEB and transformers and made the corresponding changes in the code
+- Cast image dtype in ColPali to handle the breaking transformers 4.46 update
+- Added a "num_image_tokens" kwarg to the `ColQwen2Processor` to allow for different image resolutions
+
+### Fixed
+
+- Fix wrong variable name for `ColPaliProcessor`'s prefixes
+
+## [0.3.2] - 2024-10-17
+
+### Added
+
+- Restore, refactor, and improve `interpretability` module for generating similarity maps
+
+### Changed
+
+- Remove dummy image from `ColPaliProcessor.process_queries`
+
+### Fixed
+
+- Fix the `compute_hardnegs.py` script
+
+### Tests
+
+- Add missing `model.eval()` in tests
+- Add tests for ColQwen2
+
+## [0.3.1] - 2024-09-27
+
+### Added
+
+- Add module-level imports for collators
+- Add sanity check in the run inference example script
+- Add E2E test for ColPali
+- Add Qwen2-VL support
+
+### Changed
+
+- Improve code clarity in the run inference example script
+- Subset the example dataset in the run inference example script
+- Rename scorer test to `test_processing_utils`
+- Greatly simplify routing logic in Trainer selection and when feeding arguments to the model forward pass (refactor)
+- Removed the `ContrastiveNegativeTrainer` class, which is now integrated into `ContrastiveTrainer`. This should not affect the user-facing API.
+- Bumped transformers version to 4.45.0 to get Qwen2-VL support
+
+### Fixed
+
+- Import HardNegCollator at module-level if and only if datasets is available
+- Remove the need for `typer` in the run inference example script
+- Fix edge case when empty suffix `""` given to processor
+- Fix bug in HardNegCollator since 0.3.0
+
+## [0.3.0] - 2024-09-10
+
+✨ This release is an exhaustive package refactor, making ColPali more modular and easier to use.
+
+🚨 It is **NOT** backward-compatible with previous versions.
+
+### Added
+
+- Restructure the `utils` module
+- Restructure the model training code
+- Add custom `Processor` classes to easily process images and/or queries
+- Enable module-level imports
+- Add scoring to processor
+- Add `CustomRetrievalEvaluator`
+- Add missing typing
+- Add tests for model, processor, scorer, and collator
+- Lint `Changelog`
+- Add missing docstrings
+- Add "Ruff" and "Test" CI pipelines
+
+### Changed
+
+- Restructure all modules to closely follow the [`transformers`](https://github.com/huggingface/transformers) architecture
+- Hugely simplify the collator implementation to make it model-agnostic
+- `ColPaliProcessor`'s `process_queries` doesn't need a mock image input anymore
+- Clean `pyproject.toml`
+- Loosen the required dependencies
+- Replace `black` with the `ruff` linter
+
+### Removed
+
+- Remove `interpretability` and `eval_manager` modules
+- Remove unused utils
+- Remove `TextRetrieverCollator`
+- Remove `HardNegDocmatixCollator`
+
+### Fixed
+
+- Fix wrong PIL import
+- Fix dependency issues
+
+## [0.2.2] - 2024-09-06
+
+### Fixed
+
+- Remove forced "cuda" usage in Retrieval Evaluator
+
+## [0.2.1] - 2024-09-02
+
+Patch a misalignment between the query preprocessing helper function and the training scheme.
+
+### Fixed
+
+- Add 10 extra pad tokens by default to the query to act as reasoning buffers. This was added in the collator but not in the external helper function used for inference.
+
+## [0.2.0] - 2024-08-29
+
+Large refactoring to address several issues and add features. This release is not backward compatible with previous versions.
+The models trained under this version will exhibit degraded performance if used with the previous version of the code and vice versa.
+
+[Branch](https://github.com/illuin-tech/colpali/pull/23)
+
+### Added
+
+- Added multiple training options for training with hard negatives. This leads to better model performance!
+- Added options for restarting training from a checkpoint.
+
+### Changed
+
+- Optionally load ColPali models from pre-initialized backbones of the same shape to remove any stochastic initialization when loading adapters. This fixes [11](https://github.com/illuin-tech/colpali/issues/11) and [17](https://github.com/illuin-tech/colpali/issues/17).
+
+### Fixed
+
+- Set padding side to right in the tokenizer to fix a misalignment issue between queries of different lengths in the same batch. Fixes [12](https://github.com/illuin-tech/colpali/issues/12)
+- Add 10 extra pad tokens by default to the query to act as reasoning buffers. This enables the above fix to be made without degrading performance and cleans up the old technique of using `<unused>` tokens.
+
+## [0.1.1] - 2024-08-28
+  
+Minor patch release to fix packaging issues.
+
+### Fixed
+
+- [Branch](https://github.com/illuin-tech/colpali/commit/bd55e88c7af7069dde943f00665181fb94631cdd)
+  Fix .gitignore to include all necessary files in the package.
+
+## [0.1.0] - 2024-08-28
+
+Initial code release corresponding to the paper.

+ 37 - 0
deconstruct_SQI/colpali/CITATION.cff

@@ -0,0 +1,37 @@
+cff-version: 1.2.0
+message: "If you use this software, please cite it as below."
+authors:
+- family-names: "Faysse"
+  given-names: "Manuel"
+  email: "manuel.faysse@illuin.tech"
+- family-names: "Sibille"
+  given-names: "Hugues"
+  email: "hugues.sibille@illuin.tech"
+- family-names: "Wu"
+  given-names: "Tony"
+  email: "tony.wu@illuin.tech"
+title: "Vision Document Retrieval (ViDoRe): Benchmark"
+date-released: 2024-06-26
+url: "https://github.com/illuin-tech/vidore-benchmark"
+preferred-citation:
+  type: article
+  authors:
+  - family-names: "Faysse"
+    given-names: "Manuel"
+  - family-names: "Sibille"
+    given-names: "Hugues"
+  - family-names: "Wu"
+    given-names: "Tony"
+  - family-names: "Omrani"
+    given-names: "Bilel"
+  - family-names: "Viaud"
+    given-names: "Gautier"
+  - family-names: "Hudelot"
+    given-names: "Céline"
+  - family-names: "Colombo"
+    given-names: "Pierre"
+  doi: "arXiv.2407.01449"
+  month: 6
+  title: "ColPali: Efficient Document Retrieval with Vision Language Models"
+  year: 2024
+  url: "https://arxiv.org/abs/2407.01449"

+ 21 - 0
deconstruct_SQI/colpali/LICENSE

@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2024 Manuel Faysse, Hugues Sibille, Tony Wu
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.

+ 452 - 0
deconstruct_SQI/colpali/README.md

@@ -0,0 +1,452 @@
+# ColPali: Efficient Document Retrieval with Vision Language Models 👀
+
+[![arXiv](https://img.shields.io/badge/arXiv-2407.01449-b31b1b.svg?style=for-the-badge)](https://arxiv.org/abs/2407.01449)
+[![GitHub](https://img.shields.io/badge/ViDoRe_Benchmark-100000?style=for-the-badge&logo=github&logoColor=white)](https://github.com/illuin-tech/vidore-benchmark)
+[![Hugging Face](https://img.shields.io/badge/Vidore_Hf_Space-FFD21E?style=for-the-badge&logo=huggingface&logoColor=000)](https://huggingface.co/vidore)
+[![GitHub](https://img.shields.io/badge/Cookbooks-100000?style=for-the-badge&logo=github&logoColor=white)](https://github.com/tonywu71/colpali-cookbooks)
+
+[![Test](https://github.com/illuin-tech/colpali/actions/workflows/test.yml/badge.svg?branch=main)](https://github.com/illuin-tech/colpali/actions/workflows/test.yml)
+[![Version](https://img.shields.io/pypi/v/colpali-engine?color=%2334D058&label=pypi%20package)](https://pypi.org/project/colpali-engine/)
+[![Downloads](https://static.pepy.tech/badge/colpali-engine)](https://pepy.tech/project/colpali-engine)
+
+---
+
+[[Model card]](https://huggingface.co/vidore/colpali)
+[[ViDoRe Leaderboard]](https://huggingface.co/spaces/vidore/vidore-leaderboard)
+[[Demo]](https://huggingface.co/spaces/manu/ColPali-demo)
+[[Blog Post]](https://huggingface.co/blog/manu/colpali)
+
+## Associated Paper
+
+This repository contains the code used for training the vision retrievers in the [*ColPali: Efficient Document Retrieval with Vision Language Models*](https://arxiv.org/abs/2407.01449) paper. In particular, it contains the code for training the ColPali model, which is a vision retriever based on the ColBERT architecture and the PaliGemma model.
+
+## Introduction
+
+With our new model *ColPali*, we propose to leverage VLMs to construct efficient multi-vector embeddings in the visual space for document retrieval. By feeding the ViT output patches from PaliGemma-3B to a linear projection, we create a multi-vector representation of documents. We train the model to maximize the similarity between these document embeddings and the query embeddings, following the ColBERT method.
+
+Using ColPali removes the need for potentially complex and brittle layout recognition and OCR pipelines with a single model that can take into account both the textual and visual content (layout, charts, ...) of a document.
+
+![ColPali Architecture](assets/colpali_architecture.webp)
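+
+For intuition, the ColBERT-style late-interaction scoring used to compare query and document embeddings can be written in a few lines of PyTorch. This is a minimal illustrative sketch of the MaxSim operator, not the engine's actual implementation (in practice, use `processor.score_multi_vector` as shown in the quick start below):
+
+```python
+import torch
+
+
+def maxsim_score(query_embeddings: torch.Tensor, doc_embeddings: torch.Tensor) -> torch.Tensor:
+    """Late-interaction (MaxSim) score between one query and one document.
+
+    `query_embeddings` has shape (num_query_tokens, dim) and `doc_embeddings` has
+    shape (num_doc_patches, dim); both are assumed to be L2-normalized, so dot
+    products are cosine similarities.
+    """
+    # Token-level similarity matrix of shape (num_query_tokens, num_doc_patches).
+    similarities = query_embeddings @ doc_embeddings.T
+    # For each query token, keep its best-matching document patch, then sum over query tokens.
+    return similarities.max(dim=1).values.sum()
+```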
+
+## List of ColVision models
+
+| Model                                                               | Score on [ViDoRe](https://huggingface.co/spaces/vidore/vidore-leaderboard) 🏆 | License    | Comments                                                                                                                                                       | Currently supported |
+|---------------------------------------------------------------------|-------------------------------------------------------------------------------|------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------|---------------------|
+| [vidore/colpali](https://huggingface.co/vidore/colpali)             | 81.3                                                                          | Gemma      | • Based on `google/paligemma-3b-mix-448`.<br />• Checkpoint used in the ColPali paper.                                                                         | ❌                   |
+| [vidore/colpali-v1.1](https://huggingface.co/vidore/colpali-v1.1)   | 81.5                                                                          | Gemma      | • Based on `google/paligemma-3b-mix-448`.<br />• Fix right padding for queries.                                                                                | ✅                   |
+| [vidore/colpali-v1.2](https://huggingface.co/vidore/colpali-v1.2)   | 83.9                                                                          | Gemma      | • Similar to `vidore/colpali-v1.1`.                                                                                                                            | ✅                   |
+| [vidore/colpali-v1.3](https://huggingface.co/vidore/colpali-v1.3)   | 84.8                                                                          | Gemma      | • Similar to `vidore/colpali-v1.2`.<br />• Trained with a larger effective batch size of 256 for 3 epochs.                                                     | ✅                   |
+| [vidore/colqwen2-v0.1](https://huggingface.co/vidore/colqwen2-v0.1) | 87.3                                                                          | Apache 2.0 | • Based on `Qwen/Qwen2-VL-2B-Instruct`.<br />• Supports dynamic resolution.<br />• Trained using 768 image patches per page and an effective batch size of 32. | ✅                   |
+| [vidore/colqwen2-v1.0](https://huggingface.co/vidore/colqwen2-v1.0) | 89.3                                                                          | Apache 2.0 | • Similar to `vidore/colqwen2-v0.1`, but trained with more powerful GPUs and with a larger effective batch size (256).                                         | ✅                   |
+| [vidore/colqwen2.5-v0.1](https://huggingface.co/vidore/colqwen2.5-v0.1) | 88.8                                                                          | Apache 2.0 | • Based on `Qwen/Qwen2.5-VL-3B-Instruct`.<br />• Supports dynamic resolution.<br />• Trained using 768 image patches per page and an effective batch size of 32.                                       | ✅                   |
+| [vidore/colqwen2.5-v0.2](https://huggingface.co/vidore/colqwen2.5-v0.2) | 89.4                                                                          | Apache 2.0 | • Similar to `vidore/colqwen2.5-v0.1`, but trained with slightly different hyperparameters.                                                                    | ✅                   |
+| [vidore/colSmol-256M](https://huggingface.co/vidore/colSmol-256M)   | 80.1                                                                          | Apache 2.0 | • Based on `HuggingFaceTB/SmolVLM-256M-Instruct`.                                                                                                              | ✅                   |
+| [vidore/colSmol-500M](https://huggingface.co/vidore/colSmol-500M)   | 82.3                                                                          | Apache 2.0 | • Based on `HuggingFaceTB/SmolVLM-500M-Instruct`.                                                                                                              | ✅                   |
+
+## Setup
+
+We used Python 3.11.6 and PyTorch 2.4 to train and test our models, but the codebase is compatible with Python >=3.9 and recent PyTorch versions. To install the package, run:
+
+```bash
+pip install colpali-engine # from PyPI
+pip install git+https://github.com/illuin-tech/colpali # from source
+```
+
+Mac users using MPS with the ColQwen models have reported errors with torch 2.6.0. These errors are fixed by downgrading to torch 2.5.1.
+
+> [!WARNING]
+> For ColPali versions above v1.0, make sure to install the `colpali-engine` package from source or with a version above v0.2.0.
+
+## Usage
+
+### Quick start
+
+```python
+import torch
+from PIL import Image
+from transformers.utils.import_utils import is_flash_attn_2_available
+
+from colpali_engine.models import ColQwen2, ColQwen2Processor
+
+model_name = "vidore/colqwen2-v1.0"
+
+model = ColQwen2.from_pretrained(
+    model_name,
+    torch_dtype=torch.bfloat16,
+    device_map="cuda:0",  # or "mps" if on Apple Silicon
+    attn_implementation="flash_attention_2" if is_flash_attn_2_available() else None,
+).eval()
+
+processor = ColQwen2Processor.from_pretrained(model_name)
+
+# Your inputs
+images = [
+    Image.new("RGB", (128, 128), color="white"),
+    Image.new("RGB", (64, 32), color="black"),
+]
+queries = [
+    "What is the organizational structure for our R&D department?",
+    "Can you provide a breakdown of last year’s financial performance?",
+]
+
+# Process the inputs
+batch_images = processor.process_images(images).to(model.device)
+batch_queries = processor.process_queries(queries).to(model.device)
+
+# Forward pass
+with torch.no_grad():
+    image_embeddings = model(**batch_images)
+    query_embeddings = model(**batch_queries)
+
+scores = processor.score_multi_vector(query_embeddings, image_embeddings)
+```
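+
+The returned `scores` is a tensor of shape `(num_queries, num_images)`, where higher values indicate better matches. For example, to pick the best-matching page for each query:
+
+```python
+# Illustrative follow-up: `scores` is the dense score matrix returned above.
+best_page_indices = scores.argmax(dim=1)
+```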
+
+We now support `fast-plaid` experimentally to make matching quicker for larger corpus sizes:
+
+```python
+# !pip install --no-deps fast-plaid fastkmeans
+
+from torch.utils.data import DataLoader
+from tqdm import tqdm
+
+# Process the inputs in batches of 4 (reuses `model`, `processor`, `images`, and
+# `query_embeddings` from the quick-start snippet above).
+dataloader = DataLoader(
+    dataset=images,
+    batch_size=4,
+    shuffle=False,
+    collate_fn=lambda x: processor.process_images(x),
+)
+
+ds = []
+for batch_doc in tqdm(dataloader):
+    with torch.no_grad():
+        batch_doc = {k: v.to(model.device) for k, v in batch_doc.items()}
+        embeddings_doc = model(**batch_doc)
+    ds.extend(list(torch.unbind(embeddings_doc.to("cpu"))))
+
+plaid_index = processor.create_plaid_index(ds)
+
+scores = processor.get_topk_plaid(query_embeddings, plaid_index, k=10)
+```
+
+### Benchmarking
+
+To benchmark ColPali on the [ViDoRe leaderboard](https://huggingface.co/spaces/vidore/vidore-leaderboard), use the [`vidore-benchmark`](https://github.com/illuin-tech/vidore-benchmark) package.
+
+### Interpretability with similarity maps
+
+By superimposing the late interaction similarity maps on top of the original image, we can visualize the most salient image patches with respect to each term of the query, yielding interpretable insights into model focus zones.
+
+To use the `interpretability` module, you need to install the `colpali-engine[interpretability]` package:
+
+```bash
+pip install colpali-engine[interpretability]
+```
+
+Then, after generating your embeddings with ColPali, use the following code to plot the similarity maps for each query token:
+
+<details>
+<summary><strong>🔽 Click to expand code snippet</strong></summary>
+
+```python
+import torch
+from PIL import Image
+
+from colpali_engine.interpretability import (
+    get_similarity_maps_from_embeddings,
+    plot_all_similarity_maps,
+)
+from colpali_engine.models import ColPali, ColPaliProcessor
+from colpali_engine.utils.torch_utils import get_torch_device
+
+model_name = "vidore/colpali-v1.3"
+device = get_torch_device("auto")
+
+# Load the model
+model = ColPali.from_pretrained(
+    model_name,
+    torch_dtype=torch.bfloat16,
+    device_map=device,
+).eval()
+
+# Load the processor
+processor = ColPaliProcessor.from_pretrained(model_name)
+
+# Load the image and query
+image = Image.open("shift_kazakhstan.jpg")
+query = "Quelle partie de la production pétrolière du Kazakhstan provient de champs en mer ?"
+
+# Preprocess inputs
+batch_images = processor.process_images([image]).to(device)
+batch_queries = processor.process_queries([query]).to(device)
+
+# Forward passes
+with torch.no_grad():
+    image_embeddings = model.forward(**batch_images)
+    query_embeddings = model.forward(**batch_queries)
+
+# Get the number of image patches
+n_patches = processor.get_n_patches(image_size=image.size, patch_size=model.patch_size)
+
+# Get the tensor mask to filter out the embeddings that are not related to the image
+image_mask = processor.get_image_mask(batch_images)
+
+# Generate the similarity maps
+batched_similarity_maps = get_similarity_maps_from_embeddings(
+    image_embeddings=image_embeddings,
+    query_embeddings=query_embeddings,
+    n_patches=n_patches,
+    image_mask=image_mask,
+)
+
+# Get the similarity map for our (only) input image
+similarity_maps = batched_similarity_maps[0]  # (query_length, n_patches_x, n_patches_y)
+
+# Tokenize the query
+query_tokens = processor.tokenizer.tokenize(query)
+
+# Plot and save the similarity maps for each query token
+plots = plot_all_similarity_maps(
+    image=image,
+    query_tokens=query_tokens,
+    similarity_maps=similarity_maps,
+)
+for idx, (fig, ax) in enumerate(plots):
+    fig.savefig(f"similarity_map_{idx}.png")
+```
+
+</details>
+
+For a more detailed example, you can refer to the interpretability notebooks from the [ColPali Cookbooks 👨🏻‍🍳](https://github.com/tonywu71/colpali-cookbooks) repository.
+
+### Token pooling
+
+[Token pooling](https://doi.org/10.48550/arXiv.2409.14683) is a CRUDE-compliant method (document addition/deletion-friendly) that aims at reducing the sequence length of multi-vector embeddings. For ColPali, many image patches share redundant information, e.g. white background patches. By pooling these patches together, we can reduce the amount of embeddings while retaining most of the page's signal. Retrieval performance with hierarchical mean token pooling on image embeddings can be found in the [ColPali paper](https://doi.org/10.48550/arXiv.2407.01449). In our experiments, we found that a pool factor of 3 offered the optimal trade-off: the total number of vectors is reduced by $66.7\%$ while $97.8\%$ of the original performance is maintained.
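+
+Concretely, a pool factor of $k$ merges each page's $N$ patch embeddings into roughly $\lceil N / k \rceil$ pooled vectors, so $k = 3$ keeps about one third of the vectors, i.e. a reduction of $1 - \tfrac{1}{3} \approx 66.7\%$.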
+
+To use token pooling, you can use the `HierarchicalTokenPooler` class from the `colpali-engine` package:
+
+<details>
+<summary><strong>🔽 Click to expand code snippet</strong></summary>
+
+```python
+import torch
+
+from colpali_engine.compression.token_pooling import HierarchicalTokenPooler
+
+# Dummy multivector embeddings
+list_embeddings = [
+    torch.rand(10, 768),
+    torch.rand(20, 768),
+]
+
+# Define the pooler with the desired level of compression
+pooler = HierarchicalTokenPooler()
+
+# Pool the embeddings
+outputs = pooler.pool_embeddings(list_embeddings, pool_factor=2)
+```
+
+If your inputs are padded 3D tensor embeddings instead of lists of 2D tensors, use `padding=True` and specify the padding used by your tokenizer to make sure the `HierarchicalTokenPooler` correctly removes the padding values before pooling:
+
+```python
+import torch
+from PIL import Image
+from transformers.utils.import_utils import is_flash_attn_2_available
+
+from colpali_engine.compression.token_pooling import HierarchicalTokenPooler
+from colpali_engine.models import ColQwen2, ColQwen2Processor
+
+model_name = "vidore/colqwen2-v1.0"
+model = ColQwen2.from_pretrained(
+    model_name,
+    torch_dtype=torch.bfloat16,
+    device_map="cuda:0",  # or "mps" if on Apple Silicon
+    attn_implementation="flash_attention_2" if is_flash_attn_2_available() else None,
+).eval()
+processor = ColQwen2Processor.from_pretrained(model_name)
+
+token_pooler = HierarchicalTokenPooler()
+
+# Your page images
+images = [
+    Image.new("RGB", (128, 128), color="white"),
+    Image.new("RGB", (32, 32), color="black"),
+]
+
+# Process the inputs
+batch_images = processor.process_images(images).to(model.device)
+
+# Forward pass
+with torch.no_grad():
+    image_embeddings = model(**batch_images)
+
+# Apply token pooling (reduces the sequence length of the multi-vector embeddings)
+image_embeddings = token_pooler.pool_embeddings(
+    image_embeddings,
+    pool_factor=2,
+    padding=True,
+    padding_side=processor.tokenizer.padding_side,
+)
+```
+
+</details>
+
+### Training
+
+To keep the repository lightweight, only the essential packages are installed by default. In particular, you must install the training dependencies to use the ColPali training scripts. You can do this with the following command:
+
+```bash
+pip install "colpali-engine[train]"
+```
+
+All the model configs used can be found in `scripts/configs/` and rely on the [configue](https://github.com/illuin-tech/configue) package for straightforward configuration. They should be used with the `train_colbert.py` script.
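+
+For illustration, here is a minimal sketch of what such an entrypoint can look like when built on `configue`. The `ColModelTraining` wrapper class and the top-level `config` key in the YAML files are assumptions based on this repository's layout; refer to `train_colbert.py` for the actual script:
+
+```python
+import sys
+
+import configue
+
+# Assumed import path, based on colpali_engine/trainer/colmodel_training.py.
+from colpali_engine.trainer.colmodel_training import ColModelTraining
+
+# configue parses the YAML file and instantiates the referenced Python objects
+# (model, processor, datasets, trainer arguments, ...).
+config = configue.load(sys.argv[1])["config"]
+
+app = ColModelTraining(config)  # assumed training wrapper class
+app.train()
+app.save()  # persist the trained model/adapter (exact signature may differ)
+```
+
+It would be invoked with a YAML config path (the `train.py` filename here is just a placeholder), e.g. `python train.py scripts/configs/pali/train_colpali_model.yaml`.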
+
+<details>
+<summary><strong>🔽 Example 1: Local training</strong></summary>
+
+
+```bash
+accelerate launch --multi-gpu scripts/configs/qwen2/train_colqwen25_model.py
+```
+
+</details>
+
+<details>
+<summary><strong>🔽 Example 2: Training on a SLURM cluster</strong></summary>
+
+```bash
+sbatch --nodes=1 --cpus-per-task=16 --mem-per-cpu=32GB --time=20:00:00 --gres=gpu:1  -p gpua100 --job-name=colidefics --output=colidefics.out --error=colidefics.err --wrap="accelerate launch scripts/train/train_colbert.py scripts/configs/pali/train_colpali_docmatix_hardneg_model.yaml"
+
+sbatch --nodes=1  --time=5:00:00 -A cad15443 --gres=gpu:8  --constraint=MI250 --job-name=colpali --wrap="accelerate launch --multi-gpu scripts/configs/qwen2/train_colqwen25_model.py"
+```
+
+</details>
+
+## Contributing
+
+We welcome contributions to ColPali! 🤗
+
+To contribute to ColPali, first install the development dependencies for proper testing/linting:
+
+```bash
+pip install "colpali-engine[dev]"
+```
+
+To run all the tests, you will have to install all optional dependencies (or you'll get an error in test discovery):
+
+```bash
+pip install "colpali-engine[all]"
+```
+
+When your PR is ready, ping one of the repository maintainers. We will do our best to review it as soon as possible!
+
+## Community Projects
+
+Several community projects and resources have been developed around ColPali to facilitate its usage. Feel free to reach out if you want to add your project to this list!
+
+<details>
+<summary><strong>🔽 Libraries 📚</strong></summary>
+
+| Library Name  | Description                                                                                                                                                                                                                                          |
+|---------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------  |
+| Byaldi        | [`Byaldi`](https://github.com/AnswerDotAI/byaldi) is [RAGatouille](https://github.com/AnswerDotAI/RAGatouille)'s equivalent for ColPali, leveraging the `colpali-engine` package to facilitate indexing and storing embeddings.                      |
+| PyVespa       | [`PyVespa`](https://pyvespa.readthedocs.io/en/latest/examples/colpali-document-retrieval-vision-language-models-cloud.html) allows interaction with [Vespa](https://vespa.ai/), a production-grade vector database, with detailed ColPali support.   |
+| Qdrant | Tutorial about using ColQwen2 with the [Qdrant](https://qdrant.tech/documentation/advanced-tutorials/pdf-retrieval-at-scale/) vector database. |
+| Elasticsearch     | Tutorial about using ColPali with the [Elasticsearch](https://www.elastic.co/search-labs/blog/elastiacsearch-colpali-document-search) vector database. |
+| Weaviate | Tutorial about using multi-vector embeddings with the [Weaviate](https://weaviate.io/developers/weaviate/tutorials/multi-vector-embeddings) vector database. |
+| Candle        | [Candle](https://github.com/huggingface/candle/tree/main/candle-examples/examples/colpali) enables ColPali inference with an efficient ML framework for Rust.                                                                                        |
+| EmbedAnything | [`EmbedAnything`](https://github.com/StarlightSearch/EmbedAnything) allows end-to-end ColPali inference with both Candle and ONNX backends.                                                                                                          |
+| DocAI         | [DocAI](https://github.com/PragmaticMachineLearning/docai) uses ColPali with GPT-4o and Langchain to extract structured information from documents.                                                                                                  |
+| VARAG         | [VARAG](https://github.com/adithya-s-k/VARAG) uses ColPali in a vision-only and a hybrid RAG pipeline.                                                                                                                                               |
+| ColBERT Live! | [`ColBERT Live!`](https://github.com/jbellis/colbert-live/) enables ColPali usage with vector databases supporting large datasets, compression, and non-vector predicates.                                                                           |
+| ColiVara      | [`ColiVara`](https://github.com/tjmlabs/ColiVara/) is a retrieval API that allows you to store, search, and retrieve documents based on their visual embeddings. It is a web-first implementation of the ColPali paper using ColQwen2 as the model. |
+| BentoML       | Deploy ColPali easily with BentoML using [this example repository](https://github.com/bentoml/BentoColPali). BentoML features adaptive batching and zero-copy I/O to minimize overhead.                                                              |
+| NoOCR       | NoOCR is an end-to-end, [open-source](https://github.com/kyryl-opens-ml/no-ocr) solution for complex PDFs, powered by ColPali embeddings. |
+| Astra Multi-vector     | [`Astra-multivector`](https://github.com/brian-ogrady/astradb-multivector) provides enterprise-grade integration with AstraDB for late-interaction models like ColPali, ColQwen2, and ColBERT. It implements efficient token pooling and embedding caching strategies to dramatically reduce latency and index size while maintaining retrieval quality. The library leverages Cassandra's distributed architecture for high-throughput vector search at scale. |
+| Mixpeek       | [Mixpeek](https://docs.mixpeek.com/processing/feature-extractors) is a production platform for multimodal late-interaction retrieval. It supports models like ColBERT, ColPali, and ColQwen2 with built-in indexing, versioning, A/B testing, and explainability across image, text, video, and PDF pipelines. |
+
+
+</details>
+
+<details>
+<summary><strong>🔽 Notebooks 📙</strong></summary>
+
+| Notebook Title                                               | Author & Link                                                |
+| ------------------------------------------------------------ | ------------------------------------------------------------ |
+| ColPali Cookbooks                                            | [Tony's Cookbooks (ILLUIN)](https://github.com/tonywu71/colpali-cookbooks) 🙋🏻 |
+| Vision RAG Tutorial                                          | [Manu's Vision Rag Tutorial (ILLUIN)](https://github.com/ManuelFay/Tutorials/blob/main/Tuesday_Practical_2_Vision_RAG.ipynb) 🙋🏻 |
+| ColPali (Byaldi) + Qwen2-VL for RAG                          | [Merve's Notebook (HuggingFace 🤗)](https://github.com/merveenoyan/smol-vision/blob/main/ColPali_%2B_Qwen2_VL.ipynb) |
+| Indexing ColPali with Qdrant                                 | [Daniel's Notebook (HuggingFace 🤗)](https://danielvanstrien.xyz/posts/post-with-code/colpali-qdrant/2024-10-02_using_colpali_with_qdrant.html) |
+| Weaviate Tutorial                                            | [Connor's ColPali POC (Weaviate)](https://github.com/weaviate/recipes/blob/main/weaviate-features/named-vectors/NamedVectors-ColPali-POC.ipynb) |
+| Use ColPali for Multi-Modal Retrieval with Milvus            | [Milvus Documentation](https://milvus.io/docs/use_ColPali_with_milvus.md) |
+| Data Generation                                              | [Daniel's Notebook (HuggingFace 🤗)](https://danielvanstrien.xyz/posts/post-with-code/colpali/2024-09-23-generate_colpali_dataset.html) |
+| Finance Report Analysis with ColPali and Gemini              | [Jaykumaran (LearnOpenCV)](https://github.com/spmallick/learnopencv/tree/master/Multimodal-RAG-with-ColPali-Gemini) |
+| Multimodal Retrieval-Augmented Generation (RAG) with Document Retrieval (ColPali) and Vision Language Models (VLMs) | [Sergio Paniego](https://huggingface.co/learn/cookbook/multimodal_rag_using_document_retrieval_and_vlms) |
+| Document Similarity Search with ColPali                      | [Frank Sommers](https://colab.research.google.com/github/fsommers/documentai/blob/main/Document_Similarity_with_ColPali_0_2_2_version.ipynb) |
+| End-to-end ColPali inference with EmbedAnything              | [Akshay Ballal (EmbedAnything)](https://colab.research.google.com/drive/1-Eiaw8wMm8I1n69N1uKOHkmpw3yV22w8?usp=sharing) |
+| ColiVara: A ColPali Retrieval API                            | [A simple RAG Example](https://github.com/tjmlabs/ColiVara-docs/blob/main/cookbook/RAG.ipynb) |
+| Multimodal RAG with Document Retrieval (ColPali), Vision Language Model (ColQwen2) and Amazon Nova | [Suman's Notebook (AWS)](https://github.com/debnsuma/fcc-ai-engineering-aws/blob/main/05-multimodal-rag-with-colpali/01-multimodal-retrival-with-colpali-retreve-gen.ipynb) |
+| Multi-vector RAG: Using Weaviate to search a collection of PDF documents | [Weaviate's Notebook](https://github.com/weaviate/recipes/blob/main/weaviate-features/multi-vector/multi-vector-colipali-rag.ipynb) |
+
+</details>
+
+<details>
+<summary><strong>🔽 Other resources</strong></summary>
+
+- 📝 = blog post
+- 📋 = PDF / slides
+- 📹 = video
+
+| Title                                                                                    | Author & Link                                                                                                                                                 |
+|------------------------------------------------------------------------------------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------|
+| State of AI report 2024                                                                  | [Nathan's report](https://www.stateof.ai/) 📋                                                                                                                 |
+| Technology Radar Volume 31 (October 2024)                                                | [thoughtworks's report](https://www.thoughtworks.com/radar) 📋                                                                                                |
+| LlamaIndex Webinar: ColPali - Efficient Document Retrieval with Vision Language Models   | [LlamaIndex's Youtube video](https://youtu.be/nzcBvba7mzI?si=WL9MsyiAFJMyEolz) 📹                                                                             |
+| PDF Retrieval with Vision Language Models                                                | [Jo's blog post #1 (Vespa)](https://blog.vespa.ai/retrieval-with-vision-language-models-colpali/) 📝                                                          |
+| Scaling ColPali to billions of PDFs with Vespa                                           | [Jo's blog post #2 (Vespa)](https://blog.vespa.ai/scaling-colpali-to-billions/) 📝                                                                            |
+| Neural Search Talks: ColPali (with Manuel Faysse)                                        | [Zeta Alpha's Podcast](https://open.spotify.com/episode/2s6ljhd6VQTL2mIU9cFzCb) 📹                                                                            |
+| Multimodal Document RAG with Llama 3.2 Vision and ColQwen2                               | [Zain's blog post (Together AI)](https://www.together.ai/blog/multimodal-document-rag-with-llama-3-2-vision-and-colqwen2) 📝                                  |
+| ColPali: Document Retrieval with Vision Language Models                                  | [Antaripa Saha](https://antaripasaha.notion.site/ColPali-Efficient-Document-Retrieval-with-Vision-Language-Models-10f5314a5639803d94d0d7ac191bb5b1) 📝        |
+| Minimalist diagrams explaining ColPali                                                   | [Leonie's ColPali diagrams on X](https://twitter.com/helloiamleonie/status/1839321865195851859)📝                                                            |
+| Multimodal RAG with ColPali and Gemini : Financial Report Analysis Application           | [Jaykumaran's blog post (LearnOpenCV)](https://learnopencv.com/multimodal-rag-with-colpali/) 📝                                                               |
+| Implement Multimodal RAG with ColPali and Vision Language Model Groq(Llava) and Qwen2-VL | [Plaban's blog post](https://medium.com/the-ai-forum/implement-multimodal-rag-with-colpali-and-vision-language-model-groq-llava-and-qwen2-vl-5c113b8c08fd) 📝 |
+| multimodal AI. open-source. in a nutshell.                                               | [Merve's Youtube video](https://youtu.be/IoGaGfU1CIg?si=yEhxMqJYxvMzGyUm) 📹                                                                                  |
+| Remove Complexity from Your RAG Applications                                             | [Kyryl's blog post (KOML)](https://kyrylai.com/2024/09/09/remove-complexity-from-your-rag-applications/) 📝                                                   |
+| Late interaction & efficient Multi-modal retrievers need more than a vector index        | [Ayush Chaurasia (LanceDB)](https://blog.lancedb.com/late-interaction-efficient-multi-modal-retrievers-need-more-than-just-a-vector-index/) 📝                |
+| Optimizing Document Retrieval with ColPali and Qdrant's Binary Quantization              | [Sabrina Aquino (Qdrant)]( https://youtu.be/_A90A-grwIc?si=MS5RV17D6sgirCRm)  📹                                                                              |
+| Hands-On Multimodal Retrieval and Interpretability (ColQwen + Vespa)                     | [Antaripa Saha](https://www.analyticsvidhya.com/blog/2024/10/multimodal-retrieval-with-colqwen-vespa/) 📝                                                     |
+
+</details>
+
+## Paper result reproduction
+
+To reproduce the results from the paper, you should check out the `v0.1.1` tag or install the corresponding `colpali-engine` package release:
+
+```bash
+pip install colpali-engine==0.1.1
+```
+
+## Citation
+
+**ColPali: Efficient Document Retrieval with Vision Language Models**  
+
+Authors: **Manuel Faysse**\*, **Hugues Sibille**\*, **Tony Wu**\*, Bilel Omrani, Gautier Viaud, Céline Hudelot, Pierre Colombo (\* denotes equal contribution)
+
+```latex
+@misc{faysse2024colpaliefficientdocumentretrieval,
+      title={ColPali: Efficient Document Retrieval with Vision Language Models}, 
+      author={Manuel Faysse and Hugues Sibille and Tony Wu and Bilel Omrani and Gautier Viaud and Céline Hudelot and Pierre Colombo},
+      year={2024},
+      eprint={2407.01449},
+      archivePrefix={arXiv},
+      primaryClass={cs.IR},
+      url={https://arxiv.org/abs/2407.01449}, 
+}
+
+@misc{macé2025vidorebenchmarkv2raising,
+      title={ViDoRe Benchmark V2: Raising the Bar for Visual Retrieval}, 
+      author={Quentin Macé and António Loison and Manuel Faysse},
+      year={2025},
+      eprint={2505.17166},
+      archivePrefix={arXiv},
+      primaryClass={cs.IR},
+      url={https://arxiv.org/abs/2505.17166}, 
+}
+```

binary
deconstruct_SQI/colpali/assets/colpali_architecture.webp


+ 22 - 0
deconstruct_SQI/colpali/colpali_engine/__init__.py

@@ -0,0 +1,22 @@
+from .models import (
+    BiModernVBert,
+    BiModernVBertProcessor,
+    BiPali,
+    BiPaliProj,
+    BiQwen2,
+    BiQwen2_5,
+    BiQwen2_5_Processor,
+    BiQwen2Processor,
+    ColIdefics3,
+    ColIdefics3Processor,
+    ColModernVBert,
+    ColModernVBertProcessor,
+    ColPali,
+    ColPaliProcessor,
+    ColQwen2,
+    ColQwen2_5,
+    ColQwen2_5_Processor,
+    # ColQwen2_5Omni,
+    # ColQwen2_5OmniProcessor,
+    ColQwen2Processor,
+)

+ 1 - 0
deconstruct_SQI/colpali/colpali_engine/collators/__init__.py

@@ -0,0 +1 @@
+from .visual_retriever_collator import VisualRetrieverCollator

+ 128 - 0
deconstruct_SQI/colpali/colpali_engine/collators/visual_retriever_collator.py

@@ -0,0 +1,128 @@
+import random
+from typing import Any, Dict, List, Union
+
+import torch
+from PIL.Image import Image
+
+from colpali_engine.data.dataset import ColPaliEngineDataset
+from colpali_engine.models.paligemma import ColPaliProcessor
+from colpali_engine.utils.processing_utils import BaseVisualRetrieverProcessor
+
+N_AUGMENTATION_TOKENS = 10
+
+
+def prefix_keys(data: Dict[str, Any], prefix: str) -> Dict[str, Any]:
+    """
+    Prefix all keys in a dictionary with the given prefix.
+    """
+    return {f"{prefix}{k}": v for k, v in data.items()}
+
+
+class VisualRetrieverCollator:
+    """
+    Collator for training vision retrieval models.
+    """
+
+    # Prefixes
+    query_prefix = "query_"
+    pos_doc_prefix = "doc_"
+    neg_doc_prefix = "neg_doc_"
+
+    def __init__(
+        self,
+        processor: BaseVisualRetrieverProcessor,
+        max_length: int = 2048,
+    ):
+        self.processor = processor
+        self.max_length = max_length
+        self.image_token_id = None
+
+        # If processor is one of the supported types, extract the <image> token id.
+        if isinstance(self.processor, (ColPaliProcessor,)):
+            image_token = "<image>"
+            try:
+                idx = self.processor.tokenizer.additional_special_tokens.index(image_token)
+                self.image_token_id = self.processor.tokenizer.additional_special_tokens_ids[idx]
+            except ValueError:
+                self.image_token_id = None
+
+        # Force padding to be on the right for ColPaliProcessor.
+        if isinstance(self.processor, ColPaliProcessor) and self.processor.tokenizer.padding_side != "right":
+            print("Setting padding side to right")
+            self.processor.tokenizer.padding_side = "right"
+
+    def __call__(self, examples: List[Dict[str, Any]]) -> Dict[str, Any]:
+        queries: List[Union[None, str, Image]] = []
+        pos_targets: List[Union[str, Image]] = []
+        neg_targets: List[Union[str, Image]] = []
+
+        # Parse the examples.
+        for example in examples:
+            assert ColPaliEngineDataset.QUERY_KEY in example, f"Missing {ColPaliEngineDataset.QUERY_KEY} in example."
+            query = example[ColPaliEngineDataset.QUERY_KEY]
+            sampled_query = random.choice(query) if isinstance(query, list) else query
+            queries.append(sampled_query)
+
+            assert ColPaliEngineDataset.POS_TARGET_KEY in example, (
+                f"Missing {ColPaliEngineDataset.POS_TARGET_KEY} in example."
+            )
+            pos_tgt = example[ColPaliEngineDataset.POS_TARGET_KEY]
+            sample_pos = random.choice(pos_tgt) if isinstance(pos_tgt, list) else pos_tgt
+            pos_targets.append(sample_pos)
+
+            neg_tgt = example.get(ColPaliEngineDataset.NEG_TARGET_KEY, None)
+            if neg_tgt is not None:
+                neg_targets.append(neg_tgt)
+
+        # Ensure all queries are strings or images.
+        assert all(isinstance(q, str) for q in queries), (
+            "All queries must be strings, this collator does not support images in queries."
+        )
+
+        # Process queries.
+        queries = [
+            self.processor.query_prefix + q + self.processor.query_augmentation_token * N_AUGMENTATION_TOKENS
+            for q in queries
+        ]
+        batch_query = self.auto_collate(queries, key_prefix=self.query_prefix)
+
+        # Process targets.
+        batch_pos_target = self.auto_collate(pos_targets, key_prefix=self.pos_doc_prefix)
+        batch_neg_target = self.auto_collate(neg_targets, key_prefix=self.neg_doc_prefix) if neg_targets else {}
+
+        return {
+            **batch_query,
+            **batch_pos_target,
+            **batch_neg_target,
+        }
+
+    def auto_collate(self, batch: List[Union[str, Image]], key_prefix: str = "") -> Dict[str, Any]:
+        """Automatically collate a batch of documents."""
+        # If modalities are mixed across the batch, raise an error.
+        # NOTE: isinstance() is used because loaded PIL images are subclasses of PIL.Image.Image.
+        has_text = any(isinstance(item, str) for item in batch)
+        has_image = any(isinstance(item, Image) for item in batch)
+        if has_text and has_image:
+            raise ValueError("Batch contains mixed str and Image items. Expected all items to be of the same type.")
+        if isinstance(batch[0], str):
+            proc_batch = self.processor.process_texts(texts=batch)
+        elif isinstance(batch[0], Image):
+            proc_batch = self.processor.process_images(images=batch)
+        elif isinstance(batch[0], list):
+            if isinstance(batch[0][0], str):
+                batch_size = len(batch)
+                all_texts = [text for texts in batch for text in texts]
+                num_negatives = len(all_texts) // batch_size
+                proc_batch = self.processor.process_texts(texts=all_texts)
+            elif isinstance(batch[0][0], Image):
+                batch_size = len(batch)
+                all_imgs = [img for imgs in batch for img in imgs]
+                num_negatives = len(all_imgs) // batch_size
+                proc_batch = self.processor.process_images(images=all_imgs)
+            else:
+                raise ValueError(f"Unsupported batch type: {type(batch[0][0])}. Expected str or Image.")
+            for k, v in proc_batch.items():
+                if isinstance(v, torch.Tensor):
+                    proc_batch[k] = v.view(batch_size, num_negatives, *v.shape[1:])
+        else:
+            raise ValueError(f"Unsupported batch type: {type(batch[0])}. Expected str or Image.")
+        return prefix_keys(proc_batch, key_prefix)
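
A short sketch of feeding the collator examples shaped like `ColPaliEngineDataset` items (assumes a ColPali checkpoint such as `vidore/colpali-v1.2` can be loaded for its processor):

```python
from PIL import Image

from colpali_engine.collators import VisualRetrieverCollator
from colpali_engine.models import ColPaliProcessor

processor = ColPaliProcessor.from_pretrained("vidore/colpali-v1.2")  # assumed checkpoint
collator = VisualRetrieverCollator(processor=processor)

examples = [
    {"query": "What is the revenue for 2023?", "pos_target": [Image.new("RGB", (448, 448), "white")]},
    {"query": "Where is the methods section?", "pos_target": [Image.new("RGB", (448, 448), "white")]},
]
batch = collator(examples)
# Query tensors are prefixed with "query_", positive document tensors with "doc_".
print(sorted(batch.keys()))
```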

+ 6 - 0
deconstruct_SQI/colpali/colpali_engine/compression/__init__.py

@@ -0,0 +1,6 @@
+from .token_pooling import (
+    BaseTokenPooler,
+    HierarchicalTokenPooler,
+    LambdaTokenPooler,
+    TokenPoolingOutput,
+)

+ 3 - 0
deconstruct_SQI/colpali/colpali_engine/compression/token_pooling/__init__.py

@@ -0,0 +1,3 @@
+from .base_token_pooling import BaseTokenPooler, TokenPoolingOutput
+from .hierarchical_token_pooling import HierarchicalTokenPooler
+from .lambda_token_pooling import LambdaTokenPooler

+ 164 - 0
deconstruct_SQI/colpali/colpali_engine/compression/token_pooling/base_token_pooling.py

@@ -0,0 +1,164 @@
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from typing import Dict, List, Optional, Tuple, Union, cast
+
+import torch
+
+from colpali_engine.utils.torch_utils import unbind_padded_multivector_embeddings
+
+
+@dataclass
+class TokenPoolingOutput:
+    """
+    Token pooling outputs:
+    - pooled_embeddings: A list of 2D tensors (token_length, embedding_dim) where each tensor can have its own
+                         token_length, or a 3D tensor of shape (batch_size, token_length, embedding_dim) with
+                         optional padding.
+    - cluster_id_to_indices (optional): A list of dictionaries. The i-th dictionary maps the cluster id to token indices
+                                        for the i-th embedding in `pooled_embeddings`.
+    """
+
+    pooled_embeddings: Union[List[torch.Tensor], torch.Tensor]
+    cluster_id_to_indices: Optional[List[Dict[int, Tuple[torch.Tensor]]]] = None
+
+
+class BaseTokenPooler(ABC):
+    """
+    Abstract class for token pooling multi-vector embeddings.
+    """
+
+    @abstractmethod
+    def _pool_embeddings_impl(
+        self,
+        embeddings: List[torch.Tensor],
+        num_workers: Optional[int] = None,
+        *args,
+        **kwargs,
+    ) -> Tuple[
+        List[torch.Tensor],
+        Optional[List[Dict[int, Tuple[torch.Tensor]]]],
+    ]:
+        """
+        Implementation of pooling logic for a list of 2D embeddings.
+
+        Args:
+            embeddings: A list of 2D tensors (token_length, embedding_dim)
+            num_workers: Number of workers for parallel processing
+
+        Returns:
+            Tuple containing:
+            - List of pooled embeddings
+            - (Optional) List of dictionaries mapping cluster IDs to token indices
+        """
+        pass
+
+    def _validate_embeddings(self, embeddings: Union[List[torch.Tensor], torch.Tensor]) -> None:
+        """
+        Validate input embeddings and determine their type.
+
+        Args:
+            embeddings: Input embeddings to validate
+
+        Raises:
+            ValueError: If embeddings are empty or have invalid dimensions
+        """
+        if isinstance(embeddings, list) and not embeddings:
+            raise ValueError("Empty embeddings list provided")
+
+        is_list_of_2d_tensors = isinstance(embeddings, list) and embeddings[0].dim() == 2
+        is_3d_tensor = isinstance(embeddings, torch.Tensor) and embeddings.dim() == 3
+
+        if not is_list_of_2d_tensors and not is_3d_tensor:
+            raise ValueError("The input tensor must be a list of 2D tensors or a 3D tensor.")
+
+    def _prepare_embeddings(
+        self,
+        embeddings: Union[List[torch.Tensor], torch.Tensor],
+        padding: bool = False,
+        padding_side: str = "left",
+    ) -> List[torch.Tensor]:
+        """
+        Prepare embeddings for pooling by converting to a list of 2D tensors.
+
+        Args:
+            embeddings: Input embeddings
+            padding: Whether to unbind padded 3D tensor
+            padding_side: Side where padding was applied
+
+        Returns:
+            List of 2D tensors ready for pooling
+        """
+        is_3d_tensor = isinstance(embeddings, torch.Tensor) and embeddings.dim() == 3
+        if is_3d_tensor:
+            if padding:
+                return unbind_padded_multivector_embeddings(
+                    embeddings=cast(torch.Tensor, embeddings),
+                    padding_value=0.0,
+                    padding_side=padding_side,
+                )
+            else:
+                return list(cast(torch.Tensor, embeddings).unbind(dim=0))
+
+        return cast(List[torch.Tensor], embeddings)
+
+    def pool_embeddings(
+        self,
+        embeddings: Union[torch.Tensor, List[torch.Tensor]],
+        return_dict: bool = False,
+        padding: bool = False,
+        padding_side: str = "left",
+        num_workers: Optional[int] = None,
+        **pool_kwargs,
+    ) -> Union[Union[torch.Tensor, List[torch.Tensor]], TokenPoolingOutput]:
+        """
+        Return the pooled multi-vector embeddings and the mapping from cluster id to token indices.
+
+        Args:
+            embeddings: A list of 2D tensors (token_length, embedding_dim) where each tensor can have its own token
+                        length, or a 3D tensor of shape (batch_size, token_length, embedding_dim) with 0-padding.
+            return_dict: Whether or not to return a `TokenPoolingOutput` object (with the cluster id to token indices
+                         mapping) instead of just the pooled embeddings.
+            padding: Whether or not to unbind the padded 3D tensor into a list of 2D tensors. Does nothing if the input
+                     is a list of 2D tensors.
+            padding_side: The side where the padding was applied in the 3D tensor.
+            num_workers: Number of workers for parallel processing. If None, processing is done sequentially.
+
+        Returns:
+            If the `embeddings` input is:
+            - A list of 2D tensors: Returns a list of 2D tensors (token_length, embedding_dim) where each tensor can
+                                    have its own token_length.
+            - A 3D tensor: A 3D tensor of shape (batch_size, token_length, embedding_dim) with 0-padding.
+
+            If `return_dict` is True, the pooled embeddings are returned within a `TokenPoolingOutput` object, along
+            with the cluster id to token indices mapping.
+        """
+        if isinstance(embeddings, list) and not embeddings:
+            if not return_dict:
+                return []
+            return TokenPoolingOutput(pooled_embeddings=[], cluster_id_to_indices=[])
+
+        self._validate_embeddings(embeddings)
+        prepared_embeddings = self._prepare_embeddings(embeddings, padding, padding_side)
+
+        # Apply pooling implementation
+        pooled_embeddings, cluster_id_to_indices = self._pool_embeddings_impl(
+            prepared_embeddings,
+            num_workers=num_workers,
+            **pool_kwargs,
+        )
+
+        # If the input was a 3D tensor, we need to repad the pooled embeddings for the output to be a 3D
+        # tensor as well.
+        if isinstance(embeddings, torch.Tensor) and embeddings.dim() == 3:
+            pooled_embeddings = torch.nn.utils.rnn.pad_sequence(
+                pooled_embeddings,
+                batch_first=True,
+                padding_value=0.0,
+                padding_side=padding_side,
+            )
+
+        if not return_dict:
+            return pooled_embeddings
+
+        return TokenPoolingOutput(
+            pooled_embeddings=pooled_embeddings,
+            cluster_id_to_indices=cluster_id_to_indices,
+        )
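
The base class only asks subclasses for `_pool_embeddings_impl`; a toy custom pooler (a fixed-window mean pooler, purely illustrative and not part of the library) might look like:

```python
from typing import Dict, List, Optional, Tuple

import torch

from colpali_engine.compression.token_pooling import BaseTokenPooler


class WindowTokenPooler(BaseTokenPooler):
    """Toy pooler: mean-pool non-overlapping windows of `window` consecutive tokens."""

    def __init__(self, window: int = 2):
        self.window = window

    def _pool_embeddings_impl(
        self,
        embeddings: List[torch.Tensor],
        num_workers: Optional[int] = None,
        **kwargs,
    ) -> Tuple[List[torch.Tensor], Optional[List[Dict[int, Tuple[torch.Tensor]]]]]:
        pooled = []
        for emb in embeddings:  # emb: (token_length, embedding_dim)
            chunks = emb.split(self.window, dim=0)
            pooled.append(torch.stack([chunk.mean(dim=0) for chunk in chunks], dim=0))
        return pooled, None  # no cluster-id mapping in this toy example


pooler = WindowTokenPooler(window=4)
outputs = pooler.pool_embeddings([torch.rand(16, 128), torch.rand(9, 128)])
print([tuple(t.shape) for t in outputs])  # [(4, 128), (3, 128)]
```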

+ 146 - 0
deconstruct_SQI/colpali/colpali_engine/compression/token_pooling/hierarchical_token_pooling.py

@@ -0,0 +1,146 @@
+from concurrent.futures import ThreadPoolExecutor
+from typing import Dict, List, Optional, Tuple, cast
+
+import numpy as np
+import torch
+from numpy.typing import NDArray
+from scipy.cluster.hierarchy import fcluster, linkage
+
+from colpali_engine.compression.token_pooling.base_token_pooling import BaseTokenPooler
+
+
+class HierarchicalTokenPooler(BaseTokenPooler):
+    """
+    Hierarchical token pooling of multi-vector embeddings based on the similarity between token embeddings.
+
+    Example with a list of 2D tensors:
+
+    ```python
+    list_embeddings = [torch.rand(10, 768), torch.rand(20, 768)]
+    pooler = HierarchicalTokenPooler()
+    outputs = pooler.pool_embeddings(list_embeddings, pool_factor=2)
+    ```
+
+    Example with a 0-padded 3D tensor:
+
+    ```python
+    list_embeddings = [torch.rand(10, 768), torch.rand(20, 768)]
+    padded_embeddings = torch.nn.utils.rnn.pad_sequence(
+            list_embeddings,
+            batch_first=True,
+            padding_value=0.0,
+            padding_side="left",
+        )
+    pooler = HierarchicalTokenPooler()
+    outputs = pooler.pool_embeddings(padded_embeddings, pool_factor=2, padding=True, padding_side="left")
+    ```
+    """
+
+    def _pool_embeddings_impl(
+        self,
+        embeddings: List[torch.Tensor],
+        pool_factor: int,
+        num_workers: Optional[int] = None,
+    ) -> Tuple[
+        List[torch.Tensor],
+        List[Dict[int, Tuple[torch.Tensor]]],
+    ]:
+        """
+        Apply hierarchical pooling to each embedding in the list.
+
+        Args:
+            embeddings: A list of 2D tensors (token_length, embedding_dim) where each tensor can have its own token
+                        length, or a 3D tensor of shape (batch_size, token_length, embedding_dim) with 0-padding.
+            pool_factor: An integer factor that determines the maximum number of clusters defined as
+                         `max_clusters = max(token_length // pool_factor, 1)`.
+            num_workers: The number of workers to use for parallel processing. If not provided or set to 1,
+                         the embeddings are processed sequentially.
+
+        Returns:
+            Tuple containing:
+            - List of pooled embeddings
+            - List of dictionaries mapping cluster IDs to token indices
+        """
+        if num_workers and num_workers > 1:
+            with ThreadPoolExecutor(num_workers) as executor:
+                # NOTE: We opted for a thread-based pool because most of the heavy lifting is done in C-level libraries
+                # (NumPy, Torch, and SciPy) which usually release the GIL.
+                results = list(
+                    executor.map(lambda x: self._pool_single_embedding(x, pool_factor=pool_factor), embeddings)
+                )
+        elif num_workers is None or num_workers == 1:
+            # Process embeddings sequentially
+            results = [self._pool_single_embedding(embedding, pool_factor=pool_factor) for embedding in embeddings]
+        else:
+            raise ValueError(f"Invalid number of workers: {num_workers}")
+
+        # Unpack the results
+        pooled_embeddings = [result[0] for result in results]
+        cluster_id_to_indices = [result[1] for result in results]
+
+        return pooled_embeddings, cluster_id_to_indices
+
+    def _pool_single_embedding(
+        self,
+        embedding: torch.Tensor,
+        pool_factor: int,
+    ) -> Tuple[torch.Tensor, Dict[int, Tuple[torch.Tensor]]]:
+        """
+        Return the pooled embedding and the mapping from cluster id to token indices.
+
+        Args:
+            embedding: A tensor of shape (token_length, embedding_dim).
+            pool_factor: An integer factor that determines the maximum number of clusters defined as
+                         `max_clusters = max(token_length // pool_factor, 1)`.
+
+        Returns:
+            pooled_embedding: A tensor of shape (num_clusters, embedding_dim).
+            cluster_id_to_indices: A dictionary mapping the cluster id to token indices.
+        """
+        if embedding.dim() != 2:
+            raise ValueError("The input tensor must be a 2D tensor.")
+
+        token_length = embedding.size(0)
+        if token_length == 1:
+            raise ValueError("The input tensor must have more than one token.")
+
+        if pool_factor == 1:
+            cluster_id_to_indices = {0: (torch.arange(token_length),)}
+            return embedding, cluster_id_to_indices
+
+        # Move the embedding to CPU for better multi-threading performance
+        dtype = embedding.dtype
+        device = embedding.device
+        embedding = embedding.to(torch.float32).cpu()
+
+        list_pooled_embeddings: List[torch.Tensor] = []
+
+        similarities = torch.mm(embedding, embedding.t())
+        distances = 1 - similarities.numpy()
+
+        Z = linkage(distances, metric="euclidean", method="ward")  # noqa: N806
+        max_clusters = max(token_length // pool_factor, 1)
+        cluster_labels: NDArray[np.int32] = fcluster(Z, t=max_clusters, criterion="maxclust") - 1
+        # NOTE: The scipy cluster labels start from 1, so we subtract 1 to start from 0.
+
+        cluster_id_to_indices: Dict[int, Tuple[torch.Tensor]] = {}
+
+        with torch.no_grad():
+            for cluster_id in range(max_clusters):
+                cluster_indices = cast(
+                    Tuple[torch.Tensor],  # we know it is a 1-tuple
+                    torch.where(torch.tensor(cluster_labels == cluster_id)),
+                )
+                cluster_id_to_indices[cluster_id] = cluster_indices
+
+                if cluster_indices[0].numel() > 0:
+                    pooled_embedding = embedding[cluster_indices].mean(dim=0)  # (embedding_dim,)
+                    pooled_embedding = torch.nn.functional.normalize(pooled_embedding, p=2, dim=-1)
+                    list_pooled_embeddings.append(pooled_embedding)
+
+            pooled_embeddings = torch.stack(list_pooled_embeddings, dim=0)  # (num_clusters, embedding_dim)
+
+        # Restore the original device and dtype
+        pooled_embeddings = pooled_embeddings.to(device).to(dtype)
+
+        return pooled_embeddings, cluster_id_to_indices
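
A short sketch of retrieving the cluster-to-token mapping alongside the pooled vectors via `return_dict=True`:

```python
import torch

from colpali_engine.compression.token_pooling import HierarchicalTokenPooler

embeddings = [torch.nn.functional.normalize(torch.rand(30, 128), dim=-1)]
pooler = HierarchicalTokenPooler()
output = pooler.pool_embeddings(embeddings, pool_factor=3, return_dict=True)

pooled = output.pooled_embeddings[0]        # at most 30 // 3 = 10 pooled vectors
clusters = output.cluster_id_to_indices[0]  # cluster id -> indices of the original tokens
print(pooled.shape, len(clusters))
```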

+ 89 - 0
deconstruct_SQI/colpali/colpali_engine/compression/token_pooling/lambda_token_pooling.py

@@ -0,0 +1,89 @@
+from concurrent.futures import ThreadPoolExecutor
+from typing import Callable, Dict, List, Optional, Tuple
+
+import torch
+
+from colpali_engine.compression.token_pooling.base_token_pooling import BaseTokenPooler
+
+
+class LambdaTokenPooler(BaseTokenPooler):
+    """
+    Token pooler that applies a user-defined pooling function to multi-vector embeddings.
+
+    This pooler allows users to define custom pooling methods rather than relying on pre-defined pooling strategies.
+
+    Example:
+
+    ```python
+    # Define a custom pooling function that reduces sequence length by half
+    def custom_pooling(embedding: torch.Tensor) -> torch.Tensor:
+        token_length = embedding.size(0)
+        # Resize to half the original length by averaging pairs of tokens
+        half_length = token_length // 2 + (token_length % 2)
+        pooled_embeddings = torch.zeros(
+            (half_length, embedding.size(1)),
+            dtype=embedding.dtype,
+            device=embedding.device,
+        )
+
+        for i in range(half_length):
+            start_idx = i * 2
+            end_idx = min(start_idx + 2, token_length)
+            cluster_indices = torch.arange(start_idx, end_idx)
+            pooled_embeddings[i] = embedding[cluster_indices].mean(dim=0)
+            pooled_embeddings[i] = torch.nn.functional.normalize(pooled_embeddings[i], p=2, dim=-1)
+
+        return pooled_embeddings
+
+
+    # Create a LambdaTokenPooler with the custom function
+    pooler = LambdaTokenPooler(pool_func=custom_pooling)
+    outputs = pooler.pool_embeddings(embeddings)
+    ```
+    """
+
+    def __init__(
+        self,
+        pool_func: Callable[[torch.Tensor], torch.Tensor],
+    ):
+        """
+        Initialize the LambdaTokenPooler with a custom pooling function.
+
+        Args:
+            pool_func: A function that takes a 2D tensor (token_length, embedding_dim) and returns pooled
+                       embeddings, i.e. a tensor of shape (num_clusters, embedding_dim).
+        """
+        self.pool_func = pool_func
+
+    def _pool_embeddings_impl(
+        self,
+        embeddings: List[torch.Tensor],
+        num_workers: Optional[int] = None,
+    ) -> Tuple[
+        List[torch.Tensor],
+        Optional[List[Dict[int, Tuple[torch.Tensor]]]],
+    ]:
+        """
+        Apply the custom pooling function to each embedding in the list.
+
+        Args:
+            embeddings: List of 2D tensors to pool
+            num_workers: Number of workers for parallel processing
+
+        Returns:
+            Tuple containing:
+            - List of pooled embeddings
+            - None (no cluster ID mapping in this implementation)
+        """
+        if num_workers and num_workers > 1:
+            with ThreadPoolExecutor(num_workers) as executor:
+                # NOTE: A thread-based pool only helps if `pool_func` spends most of its time in C-level
+                # libraries (e.g. Torch or NumPy) that release the GIL.
+                pooled_embeddings = list(executor.map(self.pool_func, embeddings))
+        elif num_workers is None or num_workers == 1:
+            # Process embeddings sequentially
+            pooled_embeddings = [self.pool_func(emb) for emb in embeddings]
+        else:
+            raise ValueError(f"Invalid number of workers: {num_workers}")
+
+        return pooled_embeddings, None
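
Any callable mapping `(token_length, dim)` to `(n, dim)` works; for instance, collapsing each document to a single normalized vector:

```python
import torch

from colpali_engine.compression.token_pooling import LambdaTokenPooler

# Collapse each multi-vector embedding into a single normalized vector (bi-encoder style).
pooler = LambdaTokenPooler(
    pool_func=lambda emb: torch.nn.functional.normalize(emb.mean(dim=0, keepdim=True), dim=-1)
)
pooled = pooler.pool_embeddings([torch.rand(10, 768), torch.rand(20, 768)])
print([tuple(t.shape) for t in pooled])  # [(1, 768), (1, 768)]
```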

+ 2 - 0
deconstruct_SQI/colpali/colpali_engine/data/__init__.py

@@ -0,0 +1,2 @@
+from .dataset import ColPaliEngineDataset, Corpus
+from .sampler import SingleDatasetBatchSampler

+ 162 - 0
deconstruct_SQI/colpali/colpali_engine/data/dataset.py

@@ -0,0 +1,162 @@
+import random
+from typing import Any, Dict, List, Optional, Union
+
+from datasets import Dataset as HFDataset
+from PIL import Image
+from torch.utils.data import Dataset
+
+Document = Union[str, Image.Image]
+
+
+class Corpus:
+    """
+    Corpus class that maps document IDs to documents so that datasets can reference documents by ID.
+    This class is meant to be overridden by the user to handle their own corpus.
+
+    Args:
+        corpus_data (List[Dict[str, Any]]): List of dictionaries containing doc data.
+        docid_to_idx_mapping (Optional[Dict[str, int]]): Optional mapping from doc IDs to indices. If omitted,
+            doc IDs are assumed to be integer indices into `corpus_data`.
+        doc_column_name (str): Name of the column holding the document payload. Defaults to "doc".
+    """
+
+    def __init__(
+        self,
+        corpus_data: List[Dict[str, Any]],
+        docid_to_idx_mapping: Optional[Dict[str, int]] = None,
+        doc_column_name: str = "doc",
+    ):
+        """
+        Initialize the corpus with the provided data.
+        """
+        self.corpus_data = corpus_data
+        self.docid_to_idx_mapping = docid_to_idx_mapping
+        self.doc_column_name = doc_column_name
+
+        assert isinstance(
+            self.corpus_data,
+            (list, Dataset, HFDataset),
+        ), "Corpus data must be a map-style dataset"
+
+        assert self.doc_column_name in self.corpus_data[0], f"Corpus data must contain a column {self.doc_column_name}."
+
+    def __len__(self) -> int:
+        """
+        Return the number of docs in the corpus.
+
+        Returns:
+            int: The number of docs in the corpus.
+        """
+        return len(self.corpus_data)
+
+    def retrieve(self, docid: Any) -> Document:
+        """
+        Get the corpus row from the given Doc ID.
+
+        Args:
+            docid (str): The id of the document.
+
+        Returns:
+            Document: The document retrieved from the corpus.
+        """
+        if self.docid_to_idx_mapping is not None:
+            doc_idx = self.docid_to_idx_mapping[docid]
+        else:
+            doc_idx = docid
+        return self.corpus_data[doc_idx][self.doc_column_name]
+
+
+class ColPaliEngineDataset(Dataset):
+    # Output keys
+    QUERY_KEY = "query"
+    POS_TARGET_KEY = "pos_target"
+    NEG_TARGET_KEY = "neg_target"
+
+    def __init__(
+        self,
+        data: List[Dict[str, Any]],
+        corpus: Optional[Corpus] = None,
+        query_column_name: str = "query",
+        pos_target_column_name: str = "pos_target",
+        neg_target_column_name: Optional[str] = None,
+        num_negatives: int = 3,
+    ):
+        """
+        Initialize the dataset with the provided data and optional external document corpus.
+
+        Args:
+            data (List[Dict[str, Any]]): A map-style dataset of samples (one dict per sample).
+            corpus (Optional[Corpus]): An optional external document corpus used to resolve doc IDs into
+                documents (e.g. images).
+            query_column_name (str): Column holding the query (or list of queries).
+            pos_target_column_name (str): Column holding the positive target(s).
+            neg_target_column_name (Optional[str]): Column holding the negative target(s), if any.
+            num_negatives (int): Maximum number of negatives sampled per example when a corpus is used.
+        """
+        self.data = data
+        self.corpus = corpus
+
+        # Column args
+        self.query_column_name = query_column_name
+        self.pos_target_column_name = pos_target_column_name
+        self.neg_target_column_name = neg_target_column_name
+
+        self.num_negatives = num_negatives
+        assert isinstance(
+            self.data,
+            (list, Dataset, HFDataset),
+        ), "Data must be a map-style dataset"
+
+        assert self.query_column_name in self.data[0], f"Data must contain the {self.query_column_name} column"
+        assert self.pos_target_column_name in self.data[0], f"Data must contain a {self.pos_target_column_name} column"
+        if self.neg_target_column_name is not None:
+            assert self.neg_target_column_name in self.data[0], (
+                f"Data must contain a {self.neg_target_column_name} column"
+            )
+
+    def __len__(self) -> int:
+        """Return the number of samples in the dataset."""
+        return len(self.data)
+
+    def __getitem__(self, idx: int) -> Dict[str, Any]:
+        sample = self.data[idx]
+
+        query = sample[self.query_column_name]
+
+        pos_targets = sample[self.pos_target_column_name]
+        if not isinstance(pos_targets, list):
+            pos_targets = [pos_targets]
+
+        if self.neg_target_column_name is not None:
+            neg_targets = sample[self.neg_target_column_name]
+            if not isinstance(neg_targets, list):
+                neg_targets = [neg_targets]
+        else:
+            neg_targets = None
+
+        # If an external document corpus is provided, retrieve the documents from it.
+        if self.corpus is not None:
+            pos_targets = [self.corpus.retrieve(doc_id) for doc_id in pos_targets]
+            if neg_targets is not None:
+                # Subsample negatives to avoid overflowing CPU memory.
+                if len(neg_targets) > self.num_negatives:
+                    neg_targets = random.sample(neg_targets, self.num_negatives)
+                neg_targets = [self.corpus.retrieve(doc_id) for doc_id in neg_targets]
+
+        return {
+            self.QUERY_KEY: query,
+            self.POS_TARGET_KEY: pos_targets,
+            self.NEG_TARGET_KEY: neg_targets,
+        }
+
+    def take(self, n: int) -> "ColPaliEngineDataset":
+        """
+        Take the first n samples from the dataset.
+
+        Args:
+            n (int): The number of samples to take.
+
+        Returns:
+            ColPaliEngineDataset: A new dataset containing the first n samples.
+        """
+        return self.__class__(
+            self.data.take(n) if hasattr(self.data, "take") else self.data[:n],
+            self.corpus,
+            self.query_column_name,
+            self.pos_target_column_name,
+            self.neg_target_column_name,
+            self.num_negatives,
+        )
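
A minimal sketch of wiring a dataset to an external corpus keyed by document ID (toy in-memory data, illustrative only):

```python
from colpali_engine.data import ColPaliEngineDataset, Corpus

corpus = Corpus(
    corpus_data=[{"doc": "First page text"}, {"doc": "Second page text"}],
    docid_to_idx_mapping={"doc-a": 0, "doc-b": 1},
)
dataset = ColPaliEngineDataset(
    data=[{"query": "Which page mentions the methods?", "pos_target": ["doc-b"]}],
    corpus=corpus,
)
sample = dataset[0]
print(sample["query"], sample["pos_target"])  # ... ['Second page text']
```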

+ 107 - 0
deconstruct_SQI/colpali/colpali_engine/data/sampler.py

@@ -0,0 +1,107 @@
+from typing import Iterator, List, Optional
+
+import numpy as np
+import torch
+from torch.utils.data import BatchSampler, Dataset
+
+
+class SingleDatasetBatchSampler(BatchSampler):
+    """
+    A batch sampler that samples from a single dataset per batch and handles distribution across GPUs.
+
+    Args:
+        datasets (List[Dataset]): List of datasets to sample from.
+        global_batch_size (int): Global batch size (divided across GPUs by the trainer).
+        drop_last (bool): Whether to drop the last incomplete batch.
+        generator (Optional[torch.Generator]): Random number generator.
+    """
+
+    def __init__(
+        self,
+        datasets: List[Dataset],
+        global_batch_size: int,
+        drop_last: bool = True,
+        generator: Optional[torch.Generator] = None,
+    ):
+        self.datasets = datasets
+        self.global_batch_size = global_batch_size
+        self.drop_last = drop_last
+        self.generator = generator or torch.Generator()
+        self.initial_seed = self.generator.initial_seed()
+
+        # Calculate dataset sizes and create index mappings
+        self.dataset_sizes = [len(dataset) for dataset in datasets]
+        # Compute the start offset of each dataset in the concatenated index space.
+        self.cumsum_sizes = np.cumsum([0] + self.dataset_sizes).tolist()
+        self.total_size = sum(self.dataset_sizes)
+
+        # Create shuffled indices for each dataset
+        self.indices_per_dataset = [
+            torch.randperm(size, generator=self.generator).tolist() for size in self.dataset_sizes
+        ]
+        self.current_positions = [0] * len(datasets)
+
+        self.available_datasets = list(range(len(datasets)))
+        self.max_positions = [(size // self.global_batch_size) * self.global_batch_size for size in self.dataset_sizes]
+
+    def __iter__(self) -> Iterator[List[int]]:
+        # Reset state
+        self.current_positions = [0] * len(self.datasets)
+        self.available_datasets = list(range(len(self.datasets)))
+        self.current_data_lengths = list(self.dataset_sizes)  # remaining samples per dataset (shrinks as batches are drawn)
+
+        while self.available_datasets:
+            # Build probabilities for available datasets only
+            lengths = [self.current_data_lengths[i] for i in self.available_datasets]
+            total_length = sum(lengths)
+            if total_length <= 0:
+                break  # nothing left to sample
+
+            probs = torch.tensor(lengths, dtype=torch.float) / total_length
+
+            # Pick dataset
+            dataset_idx_in_available = torch.multinomial(probs, num_samples=1, generator=self.generator).item()
+            dataset_idx = self.available_datasets[dataset_idx_in_available]
+
+            # Fetch batch
+            dataset_indices = self.indices_per_dataset[dataset_idx]
+            current_pos = self.current_positions[dataset_idx]
+            end_pos = current_pos + self.global_batch_size
+
+            if end_pos <= self.max_positions[dataset_idx]:
+                batch_indices = [idx + self.cumsum_sizes[dataset_idx] for idx in dataset_indices[current_pos:end_pos]]
+                self.current_positions[dataset_idx] = end_pos
+                self.current_data_lengths[dataset_idx] = self.dataset_sizes[dataset_idx] - end_pos
+
+                # Remove if exhausted
+                if end_pos >= self.max_positions[dataset_idx]:
+                    self.available_datasets.remove(dataset_idx)
+
+                yield batch_indices
+            else:
+                # Not enough for a full batch
+                self.available_datasets.remove(dataset_idx)
+
+    def set_epoch(self, epoch):
+        """
+        Sets the epoch for this sampler.
+
+        Args:
+            epoch (int): Epoch number
+        """
+        torch_gen = torch.Generator()
+
+        # Set seed based on epoch to ensure different shuffling each epoch
+        new_seed = self.initial_seed + epoch
+        torch_gen.manual_seed(new_seed)
+        self.generator.manual_seed(new_seed)
+
+        # Reshuffle indices for each dataset
+        self.indices_per_dataset = [torch.randperm(size, generator=torch_gen).tolist() for size in self.dataset_sizes]
+
+    @property
+    def batch_size(self) -> int:
+        return self.global_batch_size
+
+    def __len__(self) -> int:
+        return sum(size // self.global_batch_size for size in self.dataset_sizes)
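
The sampler yields global indices computed from the cumulative dataset sizes, so it is meant to be paired with a `torch.utils.data.ConcatDataset` over the same list of datasets; a small sketch:

```python
import torch
from torch.utils.data import ConcatDataset, DataLoader

from colpali_engine.data import ColPaliEngineDataset, SingleDatasetBatchSampler

datasets = [
    ColPaliEngineDataset(data=[{"query": f"q{i}", "pos_target": [f"doc {i}"]} for i in range(64)]),
    ColPaliEngineDataset(data=[{"query": f"q{i}", "pos_target": [f"doc {i}"]} for i in range(32)]),
]
sampler = SingleDatasetBatchSampler(
    datasets, global_batch_size=16, generator=torch.Generator().manual_seed(0)
)
loader = DataLoader(ConcatDataset(datasets), batch_sampler=sampler, collate_fn=lambda x: x)

sampler.set_epoch(0)
batch = next(iter(loader))       # 16 samples, all drawn from the same dataset
print(len(batch), len(sampler))  # 16 samples per batch; 64 // 16 + 32 // 16 = 6 batches per epoch
```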

+ 8 - 0
deconstruct_SQI/colpali/colpali_engine/interpretability/__init__.py

@@ -0,0 +1,8 @@
+from .similarity_map_utils import (
+    get_similarity_maps_from_embeddings,
+    normalize_similarity_map,
+)
+from .similarity_maps import (
+    plot_all_similarity_maps,
+    plot_similarity_map,
+)

+ 84 - 0
deconstruct_SQI/colpali/colpali_engine/interpretability/similarity_map_utils.py

@@ -0,0 +1,84 @@
+from typing import List, Tuple, Union
+
+import torch
+from einops import rearrange
+
+EPSILON = 1e-10
+
+
+def get_similarity_maps_from_embeddings(
+    image_embeddings: torch.Tensor,
+    query_embeddings: torch.Tensor,
+    n_patches: Union[Tuple[int, int], List[Tuple[int, int]]],
+    image_mask: torch.Tensor,
+) -> List[torch.Tensor]:
+    """
+    Get the batched similarity maps between the query embeddings and the image embeddings.
+    Each element in the returned list is a tensor of shape (query_tokens, n_patches_x, n_patches_y).
+
+    Args:
+        image_embeddings: tensor of shape (batch_size, image_tokens, dim)
+        query_embeddings: tensor of shape (batch_size, query_tokens, dim)
+        n_patches: number of patches per dimension for each image in the batch. If a single tuple is provided,
+            the same number of patches is used for all images in the batch (broadcasted).
+        image_mask: tensor of shape (batch_size, image_tokens). Used to filter out the embeddings
+            that are not related to the image
+    """
+
+    if isinstance(n_patches, tuple):
+        n_patches = [n_patches] * image_embeddings.size(0)
+
+    similarity_maps: List[torch.Tensor] = []
+
+    for idx in range(image_embeddings.size(0)):
+        # Sanity check
+        if image_mask[idx].sum() != n_patches[idx][0] * n_patches[idx][1]:
+            raise ValueError(
+                f"The number of patches ({n_patches[idx][0]} x {n_patches[idx][1]} = "
+                f"{n_patches[idx][0] * n_patches[idx][1]}) "
+                f"does not match the number of non-padded image tokens ({image_mask[idx].sum()})."
+            )
+
+        # Rearrange the output image tensor to explicitly represent the 2D grid of patches
+        image_embedding_grid = rearrange(
+            image_embeddings[idx][image_mask[idx]],  # (n_patches_x * n_patches_y, dim)
+            "(h w) c -> w h c",
+            w=n_patches[idx][0],
+            h=n_patches[idx][1],
+        )  # (n_patches_x, n_patches_y, dim)
+
+        similarity_map = torch.einsum(
+            "nk,ijk->nij", query_embeddings[idx], image_embedding_grid
+        )  # (query_tokens, n_patches_x, n_patches_y)
+
+        similarity_maps.append(similarity_map)
+
+    return similarity_maps
+
+
+def normalize_similarity_map(similarity_map: torch.Tensor) -> torch.Tensor:
+    """
+    Normalize the similarity map to have values in the range [0, 1].
+
+    Args:
+        similarity_map: tensor of shape (n_patch_x, n_patch_y) or (batch_size, n_patch_x, n_patch_y)
+    """
+    if similarity_map.ndim not in [2, 3]:
+        raise ValueError(
+            "The input tensor must have 2 dimensions (n_patch_x, n_patch_y) or "
+            "3 dimensions (batch_size, n_patch_x, n_patch_y)."
+        )
+
+    # Compute the minimum values along the last two dimensions (n_patch_x, n_patch_y)
+    min_vals = similarity_map.min(dim=-1, keepdim=True)[0].min(dim=-2, keepdim=True)[0]  # (1, 1) or (batch_size, 1, 1)
+
+    # Compute the maximum values along the last two dimensions (n_patch_x, n_patch_y)
+    max_vals = similarity_map.max(dim=-1, keepdim=True)[0].max(dim=-2, keepdim=True)[0]  # (1, 1) or (batch_size, 1, 1)
+
+    # Normalize the tensor
+    # NOTE: Add a small epsilon to avoid division by zero.
+    similarity_map_normalized = (similarity_map - min_vals) / (
+        max_vals - min_vals + EPSILON
+    )  # (n_patch_x, n_patch_y) or (batch_size, n_patch_x, n_patch_y)
+
+    return similarity_map_normalized
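
A self-contained toy example with random embeddings and a mask marking the real (non-padded) image patches:

```python
import torch

from colpali_engine.interpretability import get_similarity_maps_from_embeddings

batch_size, n_patches, dim = 2, (2, 2), 16
image_tokens = 2 * 2 + 2  # 4 patch tokens plus 2 padding tokens per image
query_tokens = 5

image_embeddings = torch.rand(batch_size, image_tokens, dim)
query_embeddings = torch.rand(batch_size, query_tokens, dim)
image_mask = torch.zeros(batch_size, image_tokens, dtype=torch.bool)
image_mask[:, : 2 * 2] = True  # mark the real image patches

maps = get_similarity_maps_from_embeddings(image_embeddings, query_embeddings, n_patches, image_mask)
print(maps[0].shape)  # (query_tokens, n_patches_x, n_patches_y) = (5, 2, 2)
```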

+ 150 - 0
deconstruct_SQI/colpali/colpali_engine/interpretability/similarity_maps.py

@@ -0,0 +1,150 @@
+from typing import List, Tuple
+
+import matplotlib.pyplot as plt
+import numpy as np
+import seaborn as sns
+import torch
+from einops import rearrange
+from PIL import Image
+
+from colpali_engine.interpretability.similarity_map_utils import normalize_similarity_map
+
+
+def plot_similarity_map(
+    image: Image.Image,
+    similarity_map: torch.Tensor,
+    figsize: Tuple[int, int] = (8, 8),
+    show_colorbar: bool = False,
+) -> Tuple[plt.Figure, plt.Axes]:
+    """
+    Plot and overlay a similarity map over the input image.
+
+    A similarity map is a 2D tensor where each element (i, j) represents the similarity score between a chosen query
+    token and the associated image patch at position (i, j). Thus, the higher the similarity score, the brighter the
+    color of the patch.
+
+    To show the returned similarity map, use:
+
+    ```python
+    >>> fig, ax = plot_similarity_map(image, similarity_map)
+    >>> fig.show()
+    ```
+
+    Args:
+        image: PIL image
+        similarity_map: tensor of shape (n_patches_x, n_patches_y)
+        figsize: size of the figure
+        show_colorbar: whether to show a colorbar
+    """
+
+    # Convert the image to an array
+    img_array = np.array(image.convert("RGBA"))  # (height, width, channels)
+
+    # Normalize the similarity map and convert it to Pillow image
+    similarity_map_array = (
+        normalize_similarity_map(similarity_map).to(torch.float32).cpu().numpy()
+    )  # (n_patches_x, n_patches_y)
+
+    # Reshape the similarity map to match the PIL shape convention
+    similarity_map_array = rearrange(similarity_map_array, "h w -> w h")  # (n_patches_y, n_patches_x)
+
+    similarity_map_image = Image.fromarray((similarity_map_array * 255).astype("uint8")).resize(
+        image.size, Image.Resampling.BICUBIC
+    )
+
+    # Create the figure
+    with plt.style.context("dark_background"):
+        fig, ax = plt.subplots(figsize=figsize)
+
+        ax.imshow(img_array)
+        im = ax.imshow(
+            similarity_map_image,
+            cmap=sns.color_palette("mako", as_cmap=True),
+            alpha=0.5,
+        )
+
+        if show_colorbar:
+            fig.colorbar(im)
+        ax.set_axis_off()
+        fig.tight_layout()
+
+    return fig, ax
+
+
+def plot_all_similarity_maps(
+    image: Image.Image,
+    query_tokens: List[str],
+    similarity_maps: torch.Tensor,
+    figsize: Tuple[int, int] = (8, 8),
+    show_colorbar: bool = False,
+    add_title: bool = True,
+) -> List[Tuple[plt.Figure, plt.Axes]]:
+    """
+    For each token in the query, plot and overlay a similarity map over the input image.
+
+    A similarity map is a 2D tensor where each element (i, j) represents the similarity score between a chosen query
+    token and the associated image patch at position (i, j). Thus, the higher the similarity score, the brighter the
+    color of the patch.
+
+    Args:
+        image: PIL image
+        query_tokens: list of query tokens
+        similarity_maps: tensor of shape (query_tokens, n_patches_x, n_patches_y)
+        figsize: size of the figure
+        show_colorbar: whether to show a colorbar
+        add_title: whether to add a title with the token and the max similarity score
+
+    Example usage for one query-image pair:
+
+    ```python
+    >>> from colpali_engine.interpretability.similarity_map_utils import get_similarity_maps_from_embeddings
+
+    >>> batch_images = processor.process_images([image]).to(device)
+    >>> batch_queries = processor.process_queries([query]).to(device)
+
+    >>> with torch.no_grad():
+            image_embeddings = model.forward(**batch_images)
+            query_embeddings = model.forward(**batch_queries)
+
+    >>> n_patches = processor.get_n_patches(
+            image_size=image.size,
+            patch_size=model.patch_size
+        )
+    >>> image_mask = processor.get_image_mask(batch_images)
+
+    >>> batched_similarity_maps = get_similarity_maps_from_embeddings(
+            image_embeddings=image_embeddings,
+            query_embeddings=query_embeddings,
+            n_patches=n_patches,
+            image_mask=image_mask,
+        )
+    >>> similarity_maps = batched_similarity_maps[0]  # (query_length, n_patches_x, n_patches_y)
+
+    >>> plots = plot_all_similarity_maps(
+            image=image,
+            query_tokens=query_tokens,
+            similarity_maps=similarity_maps,
+        )
+
+    >>> for fig, ax in plots:
+            fig.show()
+    ```
+    """
+
+    plots: List[Tuple[plt.Figure, plt.Axes]] = []
+
+    for idx, token in enumerate(query_tokens):
+        fig, ax = plot_similarity_map(
+            image=image,
+            similarity_map=similarity_maps[idx],
+            figsize=figsize,
+            show_colorbar=show_colorbar,
+        )
+
+        if add_title:
+            max_sim_score = similarity_maps[idx].max().item()
+            ax.set_title(f"Token #{idx}: `{token}`. MaxSim score: {max_sim_score:.2f}", fontsize=14)
+
+        plots.append((fig, ax))
+
+    return plots

+ 16 - 0
deconstruct_SQI/colpali/colpali_engine/loss/__init__.py

@@ -0,0 +1,16 @@
+from .bi_encoder_losses import (
+    BiEncoderLoss,
+    BiEncoderModule,
+    BiNegativeCELoss,
+    BiPairwiseCELoss,
+    BiPairwiseNegativeCELoss,
+    BiSigmoidLoss,
+)
+from .late_interaction_losses import (
+    ColbertLoss,
+    ColbertModule,
+    ColbertNegativeCELoss,
+    ColbertPairwiseCELoss,
+    ColbertPairwiseNegativeCELoss,
+    ColbertSigmoidLoss,
+)

+ 418 - 0
deconstruct_SQI/colpali/colpali_engine/loss/bi_encoder_losses.py

@@ -0,0 +1,418 @@
+import torch
+import torch.nn.functional as F  # noqa: N812
+from torch.nn import CrossEntropyLoss
+
+
+class BiEncoderModule(torch.nn.Module):
+    """
+    Base module for bi-encoder losses, handling buffer indexing and filtering hyperparameters.
+
+    Args:
+        max_batch_size (int): Maximum batch size for the pre-allocated index buffer.
+        temperature (float): Scaling factor for logits (must be > 0).
+        filter_threshold (float): Fraction of positive score above which negatives are down-weighted.
+        filter_factor (float): Multiplicative factor applied to filtered negative scores.
+    """
+
+    def __init__(
+        self,
+        max_batch_size: int = 1024,
+        temperature: float = 0.02,
+        filter_threshold: float = 0.95,
+        filter_factor: float = 0.5,
+    ):
+        super().__init__()
+        if temperature <= 0:
+            raise ValueError("Temperature must be strictly positive")
+        self.register_buffer("idx_buffer", torch.arange(max_batch_size), persistent=False)
+        self.temperature = temperature
+        self.filter_threshold = filter_threshold
+        self.filter_factor = filter_factor
+
+    def _get_idx(self, batch_size: int, offset: int, device: torch.device):
+        """
+        Generate index tensors for in-batch cross-entropy.
+
+        Args:
+            batch_size (int): Number of queries/docs in the batch.
+            offset (int): Offset to apply for multi-GPU indexing.
+            device (torch.device): Target device of the indices.
+
+        Returns:
+            Tuple[Tensor, Tensor]: (idx, pos_idx) both shape [batch_size].
+        """
+        idx = self.idx_buffer[:batch_size].to(device)
+        return idx, idx + offset
+
+    def _filter_high_negatives(self, scores: torch.Tensor, pos_idx: torch.Tensor):
+        """
+        In-place down-weighting of "too-high" in-batch negative scores.
+
+        Args:
+            scores (Tensor[B, B]): In-batch similarity matrix.
+            pos_idx (Tensor[B]): Positive index for each query.
+        """
+        batch_size = scores.size(0)
+        idx = self.idx_buffer[:batch_size].to(scores.device)
+        pos_scores = scores[idx, pos_idx]
+        thresh = self.filter_threshold * pos_scores.unsqueeze(1)
+        mask = scores > thresh
+        mask[idx, pos_idx] = False
+        scores[mask] *= self.filter_factor
+
+
+class BiEncoderLoss(BiEncoderModule):
+    """
+    InfoNCE loss for bi-encoders without explicit negatives.
+
+    Args:
+        temperature (float): Scaling factor for logits.
+        pos_aware_negative_filtering (bool): Apply in-batch negative filtering if True.
+        max_batch_size (int): Max batch size for index buffer caching.
+        filter_threshold (float): Threshold ratio for negative filtering.
+        filter_factor (float): Factor to down-weight filtered negatives.
+    """
+
+    def __init__(
+        self,
+        temperature: float = 0.02,
+        pos_aware_negative_filtering: bool = False,
+        max_batch_size: int = 1024,
+        filter_threshold: float = 0.95,
+        filter_factor: float = 0.5,
+    ):
+        super().__init__(max_batch_size, temperature, filter_threshold, filter_factor)
+        self.pos_aware_negative_filtering = pos_aware_negative_filtering
+        self.ce_loss = CrossEntropyLoss()
+
+    def forward(
+        self,
+        query_embeddings: torch.Tensor,
+        doc_embeddings: torch.Tensor,
+        offset: int = 0,
+    ) -> torch.Tensor:
+        """
+        Compute the InfoNCE loss over a batch of bi-encoder embeddings.
+
+        Args:
+            query_embeddings (Tensor[B, D]): Query vectors.
+            doc_embeddings (Tensor[B, D]): Document vectors.
+            offset (int): Offset for positive indices (multi-GPU).
+
+        Returns:
+            Tensor: Scalar cross-entropy loss.
+        """
+        # Compute in-batch similarity matrix
+        scores = torch.einsum("bd,cd->bc", query_embeddings, doc_embeddings)
+        batch_size = scores.size(0)
+        idx, pos_idx = self._get_idx(batch_size, offset, scores.device)
+
+        if self.pos_aware_negative_filtering:
+            self._filter_high_negatives(scores, pos_idx)
+
+        return self.ce_loss(scores / self.temperature, pos_idx)
+
+
+class BiPairedEncoderLoss(BiEncoderModule):
+    """
+    Symmetric (paired) InfoNCE loss for bi-encoders without explicit negatives: the query-to-doc and
+    doc-to-query cross-entropy terms are averaged.
+
+    Args:
+        temperature (float): Scaling factor for logits.
+        pos_aware_negative_filtering (bool): Apply in-batch negative filtering if True.
+        max_batch_size (int): Max batch size for index buffer caching.
+        filter_threshold (float): Threshold ratio for negative filtering.
+        filter_factor (float): Factor to down-weight filtered negatives.
+    """
+
+    def __init__(
+        self,
+        temperature: float = 0.02,
+        pos_aware_negative_filtering: bool = False,
+        max_batch_size: int = 1024,
+        filter_threshold: float = 0.95,
+        filter_factor: float = 0.5,
+    ):
+        super().__init__(max_batch_size, temperature, filter_threshold, filter_factor)
+        self.pos_aware_negative_filtering = pos_aware_negative_filtering
+        self.ce_loss = CrossEntropyLoss()
+
+    def forward(
+        self,
+        query_embeddings: torch.Tensor,
+        doc_embeddings: torch.Tensor,
+        offset: int = 0,
+    ) -> torch.Tensor:
+        """
+        Compute the symmetric InfoNCE loss (query-to-doc and doc-to-query) over a batch of bi-encoder
+        embeddings.
+
+        Args:
+            query_embeddings (Tensor[B, D]): Query vectors.
+            doc_embeddings (Tensor[B, D]): Document vectors.
+            offset (int): Offset for positive indices (multi-GPU).
+
+        Returns:
+            Tensor: Scalar cross-entropy loss.
+        """
+        # Compute in-batch similarity matrix
+        scores = torch.einsum("bd,cd->bc", query_embeddings, doc_embeddings)
+        batch_size = scores.size(0)
+        idx, pos_idx = self._get_idx(batch_size, offset, scores.device)
+
+        if self.pos_aware_negative_filtering:
+            self._filter_high_negatives(scores, pos_idx)
+
+        q2t = self.ce_loss(scores / self.temperature, pos_idx)
+        # Doc-to-query direction: select the columns holding the local positives so the transposed
+        # matrix stays square even when documents are gathered across GPUs (offset > 0).
+        t2q = self.ce_loss(scores[:, pos_idx].t() / self.temperature, idx)
+
+        return (q2t + t2q) / 2.0
+
+
+class BiNegativeCELoss(BiEncoderModule):
+    """
+    InfoNCE loss with explicit negative samples and optional in-batch term.
+
+    Args:
+        temperature (float): Scaling factor for logits.
+        in_batch_term_weight (float): Weight for in-batch cross-entropy term (0 to 1).
+        pos_aware_negative_filtering (bool): Apply in-batch negative filtering.
+        max_batch_size (int): Max batch size for index buffer.
+        filter_threshold (float): Threshold ratio for filtering.
+        filter_factor (float): Factor to down-weight filtered negatives.
+    """
+
+    def __init__(
+        self,
+        temperature: float = 0.02,
+        in_batch_term_weight: float = 0.5,
+        pos_aware_negative_filtering: bool = False,
+        max_batch_size: int = 1024,
+        filter_threshold: float = 0.95,
+        filter_factor: float = 0.5,
+    ):
+        super().__init__(max_batch_size, temperature, filter_threshold, filter_factor)
+        self.in_batch_term_weight = in_batch_term_weight
+        assert 0 <= in_batch_term_weight <= 1, "in_batch_term_weight must be between 0 and 1"
+        self.pos_aware_negative_filtering = pos_aware_negative_filtering
+        self.ce_loss = CrossEntropyLoss()
+        # Inner InfoNCE for in-batch
+        self.inner_loss = BiEncoderLoss(
+            temperature=temperature,
+            pos_aware_negative_filtering=pos_aware_negative_filtering,
+            max_batch_size=max_batch_size,
+            filter_threshold=filter_threshold,
+            filter_factor=filter_factor,
+        )
+
+    def forward(
+        self,
+        query_embeddings: torch.Tensor,
+        doc_embeddings: torch.Tensor,
+        neg_doc_embeddings: torch.Tensor,
+        offset: int = 0,
+    ) -> torch.Tensor:
+        """
+        Compute softplus(neg_score - pos_score) plus optional in-batch CE.
+
+        Args:
+            query_embeddings (Tensor[B, D]): Query vectors.
+            doc_embeddings (Tensor[B, D]): Positive document vectors.
+            neg_doc_embeddings (Tensor[B, N, D]): Negative document vectors.
+            offset (int): Offset for in-batch CE positives.
+
+        Returns:
+            Tensor: Scalar loss value.
+        """
+        # Dot-product only for matching pairs
+        pos_scores = (query_embeddings * doc_embeddings[offset : offset + neg_doc_embeddings.size(0)]).sum(dim=1)
+        pos_scores /= self.temperature
+        neg_scores = torch.einsum("bd,bnd->bn", query_embeddings, neg_doc_embeddings) / self.temperature
+
+        loss = F.softplus(neg_scores - pos_scores.unsqueeze(1)).mean()
+
+        if self.in_batch_term_weight > 0:
+            loss_ib = self.inner_loss(query_embeddings, doc_embeddings, offset)
+            loss = loss * (1 - self.in_batch_term_weight) + loss_ib * self.in_batch_term_weight
+        return loss
+
+
+class BiPairwiseCELoss(BiEncoderModule):
+    """
+    Pairwise softplus loss mining the hardest in-batch negative.
+
+    Args:
+        temperature (float): Scaling factor for logits.
+        pos_aware_negative_filtering (bool): Filter high negatives before mining.
+        max_batch_size (int): Maximum batch size for indexing.
+        filter_threshold (float): Threshold for pos-aware filtering.
+        filter_factor (float): Factor to down-weight filtered negatives.
+    """
+
+    def __init__(
+        self,
+        temperature: float = 0.02,
+        pos_aware_negative_filtering: bool = False,
+        max_batch_size: int = 1024,
+        filter_threshold: float = 0.95,
+        filter_factor: float = 0.5,
+    ):
+        super().__init__(max_batch_size, temperature, filter_threshold, filter_factor)
+        self.pos_aware_negative_filtering = pos_aware_negative_filtering
+
+    def forward(
+        self,
+        query_embeddings: torch.Tensor,
+        doc_embeddings: torch.Tensor,
+        offset: int = 0,
+    ) -> torch.Tensor:
+        """
+        Compute softplus(hardest_neg - pos) where hardest_neg is the highest off-diagonal score.
+
+        Args:
+            query_embeddings (Tensor[B, D]): Query vectors.
+            doc_embeddings (Tensor[B, D]): Document vectors.
+
+        Returns:
+            Tensor: Scalar loss value.
+        """
+        scores = torch.einsum("bd,cd->bc", query_embeddings, doc_embeddings)
+        batch_size = scores.size(0)
+        idx, pos_idx = self._get_idx(batch_size, offset, scores.device)
+        pos = scores[idx, pos_idx]
+
+        if self.pos_aware_negative_filtering:
+            self._filter_high_negatives(scores, pos_idx)
+
+        top2 = scores.topk(2, dim=1).values
+        neg = torch.where(top2[:, 0] == pos, top2[:, 1], top2[:, 0])
+
+        return torch.nn.functional.softplus((neg - pos) / self.temperature).mean()
+
+
+class BiPairwiseNegativeCELoss(BiEncoderModule):
+    """
+    Pairwise softplus loss with explicit negatives and optional in-batch term.
+
+    Args:
+        temperature (float): Scaling factor for logits.
+        in_batch_term_weight (float): Weight for in-batch cross-entropy term (0 to 1).
+        max_batch_size (int): Maximum batch size for indexing.
+        filter_threshold (float): Threshold for pos-aware filtering.
+        filter_factor (float): Factor to down-weight filtered negatives.
+    """
+
+    def __init__(
+        self,
+        temperature: float = 0.02,
+        in_batch_term_weight: float = 0.5,
+        max_batch_size: int = 1024,
+        filter_threshold: float = 0.95,
+        filter_factor: float = 0.5,
+    ):
+        super().__init__(max_batch_size, temperature, filter_threshold, filter_factor)
+        self.in_batch_term_weight = in_batch_term_weight
+        assert 0 <= in_batch_term_weight <= 1, "in_batch_term_weight must be between 0 and 1"
+        self.inner_pairwise = BiPairwiseCELoss(
+            temperature=temperature,
+            pos_aware_negative_filtering=False,
+            max_batch_size=max_batch_size,
+            filter_threshold=filter_threshold,
+            filter_factor=filter_factor,
+        )
+
+    def forward(
+        self,
+        query_embeddings: torch.Tensor,
+        doc_embeddings: torch.Tensor,
+        neg_doc_embeddings: torch.Tensor,
+        offset: int = 0,
+    ) -> torch.Tensor:
+        """
+        Compute softplus(neg-explicit - pos) plus optional pairwise in-batch loss.
+
+        Args:
+            query_embeddings (Tensor[B, D]): Query vectors.
+            doc_embeddings (Tensor[B, D]): Positive document vectors.
+            neg_doc_embeddings (Tensor[B, N, D]): Negative document vectors.
+
+        Returns:
+            Tensor: Scalar loss value.
+        """
+        # dot product for matching pairs only
+        pos = (query_embeddings * doc_embeddings[offset : offset + query_embeddings.size(0)]).sum(dim=1)  # B
+        neg = (query_embeddings.unsqueeze(1) * neg_doc_embeddings).sum(dim=2)  # B x N
+
+        loss = torch.nn.functional.softplus((neg - pos.unsqueeze(1)) / self.temperature).mean()
+
+        if self.in_batch_term_weight > 0:
+            loss_ib = self.inner_pairwise(query_embeddings, doc_embeddings, offset=offset)
+            loss = loss * (1 - self.in_batch_term_weight) + loss_ib * self.in_batch_term_weight
+
+        return loss
+
+
+class BiSigmoidLoss(BiEncoderModule):
+    """
+    Sigmoid loss for bi-encoders with in-batch negatives.
+
+    Args:
+        temperature (float): Scaling factor for logits.
+        pos_aware_negative_filtering (bool): Apply in-batch negative filtering if True.
+        max_batch_size (int): Max batch size for index buffer caching.
+        filter_threshold (float): Threshold ratio for negative filtering.
+        filter_factor (float): Factor to down-weight filtered negatives.
+    """
+
+    def __init__(
+        self,
+        temperature: float = 0.02,
+        pos_aware_negative_filtering: bool = False,
+        max_batch_size: int = 1024,
+        filter_threshold: float = 0.95,
+        filter_factor: float = 0.5,
+    ):
+        super().__init__(max_batch_size, temperature, filter_threshold, filter_factor)
+        self.pos_aware_negative_filtering = pos_aware_negative_filtering
+
+    def forward(self, query_embeddings: torch.Tensor, doc_embeddings: torch.Tensor, offset: int = 0) -> torch.Tensor:
+        """
+        Compute the sigmoid loss for a batch of bi-encoder embeddings.
+
+        Args:
+            query_embeddings (Tensor[B, D]): Query vectors.
+            doc_embeddings (Tensor[B, D]): Document vectors.
+            offset (int): Offset for positive indices (multi-GPU).
+
+        Returns:
+            Tensor: Scalar sigmoid contrastive loss.
+        """
+
+        # Compute in-batch similarity matrix
+        scores = torch.einsum("bd,cd->bc", query_embeddings, doc_embeddings)
+
+        batch_size, num_targets = scores.shape
+        device = scores.device
+
+        _, pos_idx = self._get_idx(batch_size, offset, device)
+
+        if self.pos_aware_negative_filtering:
+            self._filter_high_negatives(scores, pos_idx)
+
+        all_losses = []
+        for k in range(num_targets // batch_size):
+            # mask equal to 1 on offset -> offset + batch_size
+            curr_idx = torch.arange(offset, offset + batch_size, device=device)
+            # keep only the scores for the current batch
+            curr_scores = scores[:, curr_idx].view(-1) / self.temperature
+            # compute the labels
+            labels = -torch.ones(batch_size * batch_size, device=device)
+            if k == 0:
+                flat_pos = (pos_idx - offset) * (batch_size + 1)
+                labels[flat_pos] = 1.0
+            # compute the loss: -log sigmoid(label * score) = softplus(-label * score)
+            block_loss = F.softplus(-curr_scores * labels)
+            all_losses.append(block_loss)
+            # shift the offset for the next batch
+            offset = (offset + batch_size) % num_targets
+
+        return torch.stack(all_losses, dim=0).mean()
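
A quick sanity check of the losses above on random unit-normalized embeddings (sketch only):

```python
import torch
import torch.nn.functional as F

from colpali_engine.loss import BiEncoderLoss, BiNegativeCELoss

batch_size, num_negatives, dim = 8, 3, 128
queries = F.normalize(torch.randn(batch_size, dim), dim=-1)
docs = F.normalize(torch.randn(batch_size, dim), dim=-1)
neg_docs = F.normalize(torch.randn(batch_size, num_negatives, dim), dim=-1)

in_batch_loss = BiEncoderLoss(temperature=0.02)(queries, docs)
explicit_neg_loss = BiNegativeCELoss(in_batch_term_weight=0.5)(queries, docs, neg_docs)
print(in_batch_loss.item(), explicit_neg_loss.item())
```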

+ 465 - 0
deconstruct_SQI/colpali/colpali_engine/loss/late_interaction_losses.py

@@ -0,0 +1,465 @@
+import torch
+import torch.nn.functional as F  # noqa: N812
+from torch.nn import CrossEntropyLoss
+
+
+class ColbertModule(torch.nn.Module):
+    """
+    Base module for ColBERT losses, handling shared utilities and hyperparameters.
+
+    Args:
+        max_batch_size (int): Maximum batch size for pre-allocating index buffer.
+        tau (float): Temperature for smooth-max approximation.
+        norm_tol (float): Tolerance for score normalization bounds.
+        filter_threshold (float): Ratio threshold for pos-aware negative filtering.
+        filter_factor (float): Multiplicative factor to down-weight high negatives.
+    """
+
+    def __init__(
+        self,
+        max_batch_size: int = 1024,
+        tau: float = 0.1,
+        norm_tol: float = 1e-3,
+        filter_threshold: float = 0.95,
+        filter_factor: float = 0.5,
+    ):
+        super().__init__()
+        self.register_buffer("idx_buffer", torch.arange(max_batch_size), persistent=False)
+        self.tau = tau
+        self.norm_tol = norm_tol
+        self.filter_threshold = filter_threshold
+        self.filter_factor = filter_factor
+
+    def _get_idx(self, batch_size: int, offset: int, device: torch.device):
+        """
+        Retrieve index and positive index tensors for in-batch losses.
+        """
+        idx = self.idx_buffer[:batch_size].to(device)
+        return idx, idx + offset
+
+    def _smooth_max(self, scores: torch.Tensor, dim: int) -> torch.Tensor:
+        """
+        Compute smooth max via log-sum-exp along a given dimension.
+        """
+        return self.tau * torch.logsumexp(scores / self.tau, dim=dim)
+
+    def _apply_normalization(self, scores: torch.Tensor, lengths: torch.Tensor) -> torch.Tensor:
+        """
+        Normalize scores by query lengths and enforce bounds.
+
+        Args:
+            scores (Tensor): Unnormalized score matrix [B, C].
+            lengths (Tensor): Query lengths [B].
+
+        Returns:
+            Tensor: Normalized scores.
+
+        Note:
+            If normalized scores fall outside [0, 1] by more than `norm_tol`, a warning message is printed
+            (no exception is raised).
+        """
+        if scores.ndim == 2:
+            normalized = scores / lengths.unsqueeze(1)
+        else:
+            normalized = scores / lengths
+
+        mn, mx = torch.aminmax(normalized)
+        if mn < -self.norm_tol or mx > 1 + self.norm_tol:
+            print(
+                f"Scores out of bounds after normalization: "
+                f"min={mn.item():.4f}, max={mx.item():.4f}, tol={self.norm_tol}"
+            )
+        return normalized
+
+    def _aggregate(
+        self,
+        scores_raw: torch.Tensor,
+        use_smooth_max: bool,
+        dim_max: int,
+        dim_sum: int,
+    ) -> torch.Tensor:
+        """
+        Aggregate token-level similarity scores into document-level scores.
+
+        Args:
+            scores_raw (Tensor): Raw scores tensor.
+            use_smooth_max (bool): Use smooth-max if True.
+            dim_max (int): Dimension to perform max/logsumexp.
+            dim_sum (int): Dimension to sum over after max.
+        """
+        if use_smooth_max:
+            return self._smooth_max(scores_raw, dim=dim_max).sum(dim=dim_sum)
+        return scores_raw.amax(dim=dim_max).sum(dim=dim_sum)
+
+    def _filter_high_negatives(self, scores: torch.Tensor, pos_idx: torch.Tensor) -> None:
+        """
+        Down-weight negatives whose score exceeds a fraction of the positive score.
+
+        Args:
+            scores (Tensor): In-batch score matrix [B, B].
+            pos_idx (Tensor): Positive indices for each query in batch.
+        """
+        batch_size = scores.size(0)
+        idx = self.idx_buffer[:batch_size].to(scores.device)
+        pos_scores = scores[idx, pos_idx]
+        thresh = self.filter_threshold * pos_scores.unsqueeze(1)
+        mask = scores > thresh
+        mask[idx, pos_idx] = False
+        scores[mask] *= self.filter_factor
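A quick numeric illustration of the _smooth_max helper above: tau * logsumexp(scores / tau) upper-bounds the hard maximum and approaches it as tau shrinks (the numbers here are arbitrary):

    import torch

    scores = torch.tensor([0.2, 0.9, 0.4])
    for tau in (1.0, 0.1, 0.01):
        smooth = tau * torch.logsumexp(scores / tau, dim=-1)
        print(f"tau={tau}: smooth_max={smooth.item():.4f}, hard_max={scores.amax().item():.4f}")
    # As tau -> 0 the smooth max converges to the hard max (0.9) while staying differentiable everywhere.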
+
+
+class ColbertLoss(ColbertModule):
+    """
+    InfoNCE loss for late interaction (ColBERT) without explicit negatives.
+
+    Args:
+        temperature (float): Scaling factor for logits.
+        normalize_scores (bool): Normalize scores by query lengths.
+        use_smooth_max (bool): Use log-sum-exp instead of amax.
+        pos_aware_negative_filtering (bool): Apply pos-aware negative filtering.
+    """
+
+    def __init__(
+        self,
+        temperature: float = 0.02,
+        normalize_scores: bool = True,
+        use_smooth_max: bool = False,
+        pos_aware_negative_filtering: bool = False,
+        max_batch_size: int = 1024,
+        tau: float = 0.1,
+        norm_tol: float = 1e-3,
+        filter_threshold: float = 0.95,
+        filter_factor: float = 0.5,
+    ):
+        super().__init__(max_batch_size, tau, norm_tol, filter_threshold, filter_factor)
+        self.temperature = temperature
+        self.normalize_scores = normalize_scores
+        self.use_smooth_max = use_smooth_max
+        self.pos_aware_negative_filtering = pos_aware_negative_filtering
+        self.ce_loss = CrossEntropyLoss()
+
+    def forward(self, query_embeddings: torch.Tensor, doc_embeddings: torch.Tensor, offset: int = 0) -> torch.Tensor:
+        """
+        Compute ColBERT InfoNCE loss over a batch of queries and documents.
+
+        Args:
+            query_embeddings (Tensor): (batch_size, query_length, dim)
+            doc_embeddings (Tensor): positive docs (batch_size, pos_doc_length, dim)
+            offset (int): Offset for positive doc indices (multi-GPU).
+
+        Returns:
+            Tensor: Scalar loss value.
+        """
+        lengths = (query_embeddings[:, :, 0] != 0).sum(dim=1)
+        raw = torch.einsum("bnd,csd->bcns", query_embeddings, doc_embeddings)
+        scores = self._aggregate(raw, self.use_smooth_max, dim_max=3, dim_sum=2)
+        if self.normalize_scores:
+            scores = self._apply_normalization(scores, lengths)
+
+        batch_size = scores.size(0)
+        idx, pos_idx = self._get_idx(batch_size, offset, scores.device)
+
+        if self.pos_aware_negative_filtering:
+            self._filter_high_negatives(scores, pos_idx)
+
+        return self.ce_loss(scores / self.temperature, pos_idx)
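A minimal sketch of how ColbertLoss might be exercised with multi-vector (late interaction) embeddings; the shapes follow the forward docstring and the random tensors are placeholders:

    import torch
    import torch.nn.functional as F

    from colpali_engine.loss.late_interaction_losses import ColbertLoss  # module path taken from this diff

    loss_fn = ColbertLoss(temperature=0.02, normalize_scores=True)

    # 4 queries with 16 token embeddings each, 4 positive documents with 512 token embeddings each (dim 128)
    query_embeddings = F.normalize(torch.randn(4, 16, 128), dim=-1)
    doc_embeddings = F.normalize(torch.randn(4, 512, 128), dim=-1)

    loss = loss_fn(query_embeddings, doc_embeddings)  # in-batch InfoNCE over the 4 x 4 MaxSim score matrix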
+
+
+class ColbertNegativeCELoss(ColbertModule):
+    """
+    InfoNCE loss with explicit negative documents.
+
+    Args:
+        temperature (float): Scaling for logits.
+        normalize_scores (bool): Normalize scores by query lengths.
+        use_smooth_max (bool): Use log-sum-exp instead of amax.
+        pos_aware_negative_filtering (bool): Apply pos-aware negative filtering.
+        in_batch_term_weight (float): Weight of the in-batch InfoNCE term, between 0 and 1 (0 disables it).
+    """
+
+    def __init__(
+        self,
+        temperature: float = 0.02,
+        normalize_scores: bool = True,
+        use_smooth_max: bool = False,
+        pos_aware_negative_filtering: bool = False,
+        in_batch_term_weight: float = 0.5,
+        max_batch_size: int = 1024,
+        tau: float = 0.1,
+        norm_tol: float = 1e-3,
+        filter_threshold: float = 0.95,
+        filter_factor: float = 0.5,
+    ):
+        super().__init__(max_batch_size, tau, norm_tol, filter_threshold, filter_factor)
+        self.temperature = temperature
+        self.normalize_scores = normalize_scores
+        self.use_smooth_max = use_smooth_max
+        self.pos_aware_negative_filtering = pos_aware_negative_filtering
+        self.in_batch_term_weight = in_batch_term_weight
+        self.ce_loss = CrossEntropyLoss()
+
+        assert in_batch_term_weight >= 0, "in_batch_term_weight must be non-negative"
+        assert in_batch_term_weight <= 1, "in_batch_term_weight must not exceed 1"
+
+        self.inner_loss = ColbertLoss(
+            temperature=temperature,
+            normalize_scores=normalize_scores,
+            use_smooth_max=use_smooth_max,
+            pos_aware_negative_filtering=pos_aware_negative_filtering,
+            max_batch_size=max_batch_size,
+            tau=tau,
+            norm_tol=norm_tol,
+            filter_threshold=filter_threshold,
+            filter_factor=filter_factor,
+        )
+
+    def forward(
+        self,
+        query_embeddings: torch.Tensor,
+        doc_embeddings: torch.Tensor,
+        neg_doc_embeddings: torch.Tensor,
+        offset: int = 0,
+    ) -> torch.Tensor:
+        """
+        Compute InfoNCE loss with explicit negatives and optional in-batch term.
+
+        Args:
+            query_embeddings (Tensor): (batch_size, query_length, dim)
+            doc_embeddings (Tensor): positive docs (batch_size, pos_doc_length, dim)
+            neg_doc_embeddings (Tensor): negative docs (batch_size, num_negs, neg_doc_length, dim)
+            offset (int): Positional offset for in-batch CE.
+
+        Returns:
+            Tensor: Scalar loss.
+        """
+        lengths = (query_embeddings[:, :, 0] != 0).sum(dim=1)
+        pos_raw = torch.einsum(
+            "bnd,bsd->bns", query_embeddings, doc_embeddings[offset : offset + neg_doc_embeddings.size(0)]
+        )
+        neg_raw = torch.einsum("bnd,blsd->blns", query_embeddings, neg_doc_embeddings)
+        pos_scores = self._aggregate(pos_raw, self.use_smooth_max, dim_max=2, dim_sum=1)
+        neg_scores = self._aggregate(neg_raw, self.use_smooth_max, dim_max=3, dim_sum=2)
+
+        if self.normalize_scores:
+            pos_scores = self._apply_normalization(pos_scores, lengths)
+            neg_scores = self._apply_normalization(neg_scores, lengths)
+
+        loss = F.softplus((neg_scores - pos_scores.unsqueeze(1)) / self.temperature).mean()
+
+        if self.in_batch_term_weight > 0:
+            loss_ib = self.inner_loss(query_embeddings, doc_embeddings, offset)
+            loss = loss * (1 - self.in_batch_term_weight) + loss_ib * self.in_batch_term_weight
+
+        return loss
+
+
+class ColbertPairwiseCELoss(ColbertModule):
+    """
+    Pairwise loss for ColBERT (no explicit negatives).
+
+    Args:
+        temperature (float): Scaling for logits.
+        normalize_scores (bool): Normalize scores by query lengths.
+        use_smooth_max (bool): Use log-sum-exp instead of amax.
+        pos_aware_negative_filtering (bool): Apply pos-aware negative filtering.
+    """
+
+    def __init__(
+        self,
+        temperature: float = 1.0,
+        normalize_scores: bool = True,
+        use_smooth_max: bool = False,
+        pos_aware_negative_filtering: bool = False,
+        max_batch_size: int = 1024,
+        tau: float = 0.1,
+        norm_tol: float = 1e-3,
+        filter_threshold: float = 0.95,
+        filter_factor: float = 0.5,
+    ):
+        super().__init__(max_batch_size, tau, norm_tol, filter_threshold, filter_factor)
+        self.temperature = temperature
+        self.normalize_scores = normalize_scores
+        self.use_smooth_max = use_smooth_max
+        self.pos_aware_negative_filtering = pos_aware_negative_filtering
+
+    def forward(self, query_embeddings: torch.Tensor, doc_embeddings: torch.Tensor, offset: int = 0) -> torch.Tensor:
+        """
+        Compute pairwise softplus loss over in-batch document pairs.
+
+        Args:
+            query_embeddings (Tensor): (batch_size, query_length, dim)
+            doc_embeddings (Tensor): positive docs (batch_size, pos_doc_length, dim)
+            offset (int): Positional offset for positives.
+
+        Returns:
+            Tensor: Scalar loss value.
+        """
+        lengths = (query_embeddings[:, :, 0] != 0).sum(dim=1)
+        raw = torch.einsum("bnd,csd->bcns", query_embeddings, doc_embeddings)
+        scores = self._aggregate(raw, self.use_smooth_max, dim_max=3, dim_sum=2)
+
+        if self.normalize_scores:
+            scores = self._apply_normalization(scores, lengths)
+
+        batch_size = scores.size(0)
+        idx, pos_idx = self._get_idx(batch_size, offset, scores.device)
+
+        if self.pos_aware_negative_filtering:
+            self._filter_high_negatives(scores, pos_idx)
+
+        pos_scores = scores.diagonal(offset=offset)
+        top2 = scores.topk(2, dim=1).values
+        neg_scores = torch.where(top2[:, 0] == pos_scores, top2[:, 1], top2[:, 0])
+
+        return F.softplus((neg_scores - pos_scores) / self.temperature).mean()
+
+
+class ColbertPairwiseNegativeCELoss(ColbertModule):
+    """
+    Pairwise loss with explicit negatives and optional in-batch term.
+
+    Args:
+        temperature (float): Scaling for logits.
+        normalize_scores (bool): Normalize scores by query lengths.
+        use_smooth_max (bool): Use log-sum-exp instead of amax.
+        pos_aware_negative_filtering (bool): Apply pos-aware negative filtering.
+        in_batch_term_weight (float): Weight of the in-batch pairwise term, between 0 and 1 (0 disables it).
+    """
+
+    def __init__(
+        self,
+        temperature: float = 0.02,
+        normalize_scores: bool = True,
+        use_smooth_max: bool = False,
+        pos_aware_negative_filtering: bool = False,
+        in_batch_term_weight: float = 0.5,
+        max_batch_size: int = 1024,
+        tau: float = 0.1,
+        norm_tol: float = 1e-3,
+        filter_threshold: float = 0.95,
+        filter_factor: float = 0.5,
+    ):
+        super().__init__(max_batch_size, tau, norm_tol, filter_threshold, filter_factor)
+        self.temperature = temperature
+        self.normalize_scores = normalize_scores
+        self.use_smooth_max = use_smooth_max
+        self.pos_aware_negative_filtering = pos_aware_negative_filtering
+        self.in_batch_term_weight = in_batch_term_weight
+        assert in_batch_term_weight >= 0, "in_batch_term_weight must be non-negative"
+        assert in_batch_term_weight <= 1, "in_batch_term_weight must not exceed 1"
+        self.inner_pairwise = ColbertPairwiseCELoss(
+            temperature=temperature,
+            normalize_scores=normalize_scores,
+            use_smooth_max=use_smooth_max,
+            pos_aware_negative_filtering=pos_aware_negative_filtering,
+            max_batch_size=max_batch_size,
+            tau=tau,
+            norm_tol=norm_tol,
+            filter_threshold=filter_threshold,
+            filter_factor=filter_factor,
+        )
+
+    def forward(
+        self,
+        query_embeddings: torch.Tensor,
+        doc_embeddings: torch.Tensor,
+        neg_doc_embeddings: torch.Tensor,
+        offset: int = 0,
+    ) -> torch.Tensor:
+        """
+        Compute pairwise softplus loss with explicit negatives and optional in-batch term.
+
+        Args:
+            query_embeddings (Tensor): (batch_size, query_length, dim)
+            doc_embeddings (Tensor): positive docs (batch_size, pos_doc_length, dim)
+            neg_doc_embeddings (Tensor): negative docs (batch_size, num_negs, neg_doc_length, dim)
+            offset (int): Positional offset for positives.
+
+        Returns:
+            Tensor: Scalar loss value.
+        """
+        lengths = (query_embeddings[:, :, 0] != 0).sum(dim=1)
+        pos_raw = torch.einsum(
+            "bnd,bld->bnl", query_embeddings, doc_embeddings[offset : offset + query_embeddings.size(0)]
+        )
+        neg_raw = torch.einsum("bnd,bsld->bsnl", query_embeddings, neg_doc_embeddings)  # B x Nneg x Nq x Lneg
+        pos_scores = self._aggregate(pos_raw, self.use_smooth_max, dim_max=2, dim_sum=1)
+        neg_scores = self._aggregate(neg_raw, self.use_smooth_max, dim_max=3, dim_sum=2)
+
+        if self.normalize_scores:
+            pos_scores = self._apply_normalization(pos_scores, lengths)
+            neg_scores = self._apply_normalization(neg_scores, lengths)
+
+        loss = F.softplus((neg_scores - pos_scores.unsqueeze(1)) / self.temperature).mean()
+
+        if self.in_batch_term_weight > 0:
+            loss_ib = self.inner_pairwise(query_embeddings, doc_embeddings, offset)
+            loss = loss * (1 - self.in_batch_term_weight) + loss_ib * self.in_batch_term_weight
+
+        return loss
+
+
+class ColbertSigmoidLoss(ColbertModule):
+    """
+    Sigmoid loss for ColBERT with in-batch negatives.
+
+    Args:
+        temperature (float): Scaling for logits.
+        normalize_scores (bool): Normalize scores by query lengths.
+        use_smooth_max (bool): Use log-sum-exp instead of amax.
+        pos_aware_negative_filtering (bool): Apply pos-aware negative filtering.
+    """
+
+    def __init__(
+        self,
+        temperature: float = 0.02,
+        normalize_scores: bool = True,
+        use_smooth_max: bool = False,
+        pos_aware_negative_filtering: bool = False,
+        max_batch_size: int = 1024,
+        tau: float = 0.1,
+        norm_tol: float = 1e-3,
+        filter_threshold: float = 0.95,
+        filter_factor: float = 0.5,
+    ):
+        super().__init__(max_batch_size, tau, norm_tol, filter_threshold, filter_factor)
+        self.temperature = temperature
+        self.normalize_scores = normalize_scores
+        self.use_smooth_max = use_smooth_max
+        self.pos_aware_negative_filtering = pos_aware_negative_filtering
+        self.ce_loss = CrossEntropyLoss()
+
+    def forward(self, query_embeddings: torch.Tensor, doc_embeddings: torch.Tensor, offset: int = 0) -> torch.Tensor:
+        """
+        Compute the sigmoid loss over in-batch positive and negative document pairs.
+
+        Args:
+            query_embeddings (Tensor): (batch_size, query_length, dim)
+            doc_embeddings (Tensor): positive docs (batch_size, pos_doc_length, dim)
+            offset (int): Offset for positive doc indices (multi-GPU).
+
+        Returns:
+            Tensor: Scalar loss value.
+        """
+
+        lengths = (query_embeddings[:, :, 0] != 0).sum(dim=1)
+        raw = torch.einsum("bnd,csd->bcns", query_embeddings, doc_embeddings)
+        scores = self._aggregate(raw, self.use_smooth_max, dim_max=3, dim_sum=2)
+
+        if self.normalize_scores:
+            scores = self._apply_normalization(scores, lengths)
+
+        batch_size = scores.size(0)
+        idx, pos_idx = self._get_idx(batch_size, offset, scores.device)
+
+        if self.pos_aware_negative_filtering:
+            self._filter_high_negatives(scores, pos_idx)
+
+        # the positive for row i sits at column pos_idx[i]; in the flattened score matrix this is
+        # flat index = i * num_targets + pos_idx[i] (which reduces to i * (B + 1) when offset == 0)
+        num_targets = scores.size(1)
+        flat_pos = idx * num_targets + pos_idx
+        pos_mask = -torch.ones(batch_size * num_targets, device=scores.device)
+        pos_mask[flat_pos] = 1.0
+
+        # flatten the scores to [B * num_targets]
+        scores = scores.view(-1) / self.temperature
+
+        return F.softplus(scores * pos_mask).mean()
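All losses in this file build on the same late-interaction (MaxSim) score: each query token takes its maximum similarity over the document's tokens, and those maxima are summed. A standalone sketch of that scoring step for one query/document pair (placeholder tensors):

    import torch
    import torch.nn.functional as F

    def maxsim_score(query_tokens: torch.Tensor, doc_tokens: torch.Tensor) -> torch.Tensor:
        # query_tokens: (Nq, D), doc_tokens: (Nd, D) -> scalar late-interaction score
        sim = query_tokens @ doc_tokens.T   # (Nq, Nd) token-to-token similarities
        return sim.amax(dim=1).sum()        # best document token per query token, summed over query tokens

    q = F.normalize(torch.randn(16, 128), dim=-1)
    d = F.normalize(torch.randn(512, 128), dim=-1)
    score = maxsim_score(q, d)              # the batched einsum versions above compute this for every (q, d) pair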

+ 5 - 0
deconstruct_SQI/colpali/colpali_engine/models/__init__.py

@@ -0,0 +1,5 @@
+from .idefics3 import BiIdefics3, BiIdefics3Processor, ColIdefics3, ColIdefics3Processor
+from .modernvbert import BiModernVBert, BiModernVBertProcessor, ColModernVBert, ColModernVBertProcessor
+from .paligemma import BiPali, BiPaliProcessor, BiPaliProj, ColPali, ColPaliProcessor
+from .qwen2 import BiQwen2, BiQwen2Processor, ColQwen2, ColQwen2Processor
+from .qwen2_5 import BiQwen2_5, BiQwen2_5_Processor, ColQwen2_5, ColQwen2_5_Processor

+ 2 - 0
deconstruct_SQI/colpali/colpali_engine/models/idefics3/__init__.py

@@ -0,0 +1,2 @@
+from .biidefics3 import BiIdefics3, BiIdefics3Processor
+from .colidefics3 import ColIdefics3, ColIdefics3Processor

+ 2 - 0
deconstruct_SQI/colpali/colpali_engine/models/idefics3/biidefics3/__init__.py

@@ -0,0 +1,2 @@
+from .modeling_biidefics3 import BiIdefics3
+from .processing_biidefics3 import BiIdefics3Processor

+ 57 - 0
deconstruct_SQI/colpali/colpali_engine/models/idefics3/biidefics3/modeling_biidefics3.py

@@ -0,0 +1,57 @@
+from typing import Literal
+
+import torch
+from transformers import Idefics3Config, Idefics3Model, Idefics3PreTrainedModel
+
+
+class BiIdefics3(Idefics3PreTrainedModel):
+    """
+    Initializes the BiIdefics3 model.
+
+    Args:
+        config : The model configuration.
+    """
+
+    def __init__(self, config: Idefics3Config):
+        super(BiIdefics3, self).__init__(config=config)
+        self.model: Idefics3Model = Idefics3Model(config)
+        self.padding_side = "left"
+        self.post_init()
+
+    def forward(
+        self,
+        pooling_strategy: Literal["cls", "last", "mean"] = "last",
+        *args,
+        **kwargs,
+    ) -> torch.Tensor:
+        """
+        Forward pass through model and pooling.
+
+        Args:
+        - pooling_strategy (str): The pooling strategy to use. Options are "cls", "last", or "mean".
+        - input_ids (torch.LongTensor): The input tokens tensor.
+        - attention_mask (torch.LongTensor): The attention mask tensor.
+
+        Returns:
+        - torch.Tensor: Embeddings of shape (batch_size, dim)
+        """
+        outputs = self.model(*args, **kwargs)
+        last_hidden_states = outputs[0]  # (batch_size, sequence_length, hidden_size)
+
+        # Get CLS token embedding, last token, or mean pool over sequence
+        if pooling_strategy == "cls":
+            # Use CLS token (first token) embedding
+            pooled_output = last_hidden_states[:, 0]  # (batch_size, hidden_size)
+        elif pooling_strategy == "last":
+            # use last token since we are left padding
+            pooled_output = last_hidden_states[:, -1]  # (batch_size, hidden_size)
+        elif pooling_strategy == "mean":
+            # Mean pooling over sequence length
+            mask = kwargs["attention_mask"].unsqueeze(-1)  # (batch_size, sequence_length, 1)
+            pooled_output = (last_hidden_states * mask).sum(dim=1) / mask.sum(dim=1)  # (batch_size, hidden_size)
+        else:
+            raise ValueError(f"Invalid pooling strategy: {pooling_strategy}")
+
+        # L2 normalization
+        pooled_output = pooled_output / pooled_output.norm(dim=-1, keepdim=True)
+        return pooled_output
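The "mean" pooling branch above averages only over non-padding positions. A small numeric sketch of that masked mean pooling (toy tensors, not model outputs):

    import torch

    hidden = torch.tensor([[[1.0, 1.0], [3.0, 3.0], [5.0, 5.0]]])  # (batch=1, seq=3, hidden=2)
    attention_mask = torch.tensor([[0, 1, 1]])                     # first position is (left) padding

    mask = attention_mask.unsqueeze(-1)                            # (1, 3, 1)
    pooled = (hidden * mask).sum(dim=1) / mask.sum(dim=1)          # tensor([[4., 4.]]): padding is excluded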

+ 40 - 0
deconstruct_SQI/colpali/colpali_engine/models/idefics3/biidefics3/processing_biidefics3.py

@@ -0,0 +1,40 @@
+from typing import List, Optional, Union
+
+import torch
+from transformers import BatchEncoding, BatchFeature
+
+from colpali_engine.models.idefics3.colidefics3 import ColIdefics3Processor
+
+
+class BiIdefics3Processor(ColIdefics3Processor):
+    """
+    Processor for BiIdefics3.
+    """
+
+    def process_texts(self, texts: List[str]) -> Union[BatchFeature, BatchEncoding]:
+        """
+        Process texts for BiIdefics3.
+
+        Args:
+            texts: List of input texts.
+
+        Returns:
+            Union[BatchFeature, BatchEncoding]: Processed texts.
+        """
+        return self(
+            text=texts,
+            return_tensors="pt",
+            padding="longest",
+        )
+
+    def score(
+        self,
+        qs: List[torch.Tensor],
+        ps: List[torch.Tensor],
+        device: Optional[Union[str, torch.device]] = None,
+        **kwargs,
+    ) -> torch.Tensor:
+        """
+        Compute the cosine similarity for the given query and passage embeddings.
+        """
+        return self.score_single_vector(qs, ps, device=device)

+ 2 - 0
deconstruct_SQI/colpali/colpali_engine/models/idefics3/colidefics3/__init__.py

@@ -0,0 +1,2 @@
+from .modeling_colidefics3 import ColIdefics3
+from .processing_colidefics3 import ColIdefics3Processor

+ 46 - 0
deconstruct_SQI/colpali/colpali_engine/models/idefics3/colidefics3/modeling_colidefics3.py

@@ -0,0 +1,46 @@
+from torch import nn
+from transformers import Idefics3Model, Idefics3PreTrainedModel
+
+
+class ColIdefics3(Idefics3PreTrainedModel):
+    """
+    Initializes the ColIdefics3 model.
+
+    Args:
+        config : The model configuration.
+        mask_non_image_embeddings (Optional[bool]): Whether to ignore all tokens embeddings
+        except those of the image at inference.
+        Defaults to False --> Do not mask any embeddings during forward pass.
+    """
+
+    def __init__(self, config, mask_non_image_embeddings: bool = False):
+        super(ColIdefics3, self).__init__(config=config)
+        self.model: Idefics3Model = Idefics3Model(config)
+        self.dim = 128
+        self.linear = nn.Linear(self.model.config.text_config.hidden_size, self.dim)
+        self.mask_non_image_embeddings = mask_non_image_embeddings
+        self.main_input_name = "doc_input_ids"
+
+    def forward(self, *args, **kwargs):
+        """
+        Forward pass through the Idefics3 backbone and the linear layer for dimensionality reduction.
+
+        Args:
+        - input_ids (torch.LongTensor): The input tokens tensor.
+        - attention_mask (torch.LongTensor): The attention mask tensor.
+
+        Returns:
+        - torch.Tensor: Embeddings of shape (batch_size, num_tokens, dim)
+        """
+        outputs = self.model(*args, **kwargs)
+        last_hidden_states = outputs[0]  # (batch_size, sequence_length, hidden_size)
+        proj = self.linear(last_hidden_states)
+        # L2-normalize the projected embeddings
+        proj = proj / proj.norm(dim=-1, keepdim=True)
+        proj = proj * kwargs["attention_mask"].unsqueeze(-1)
+
+        if "pixel_values" in kwargs and self.mask_non_image_embeddings:
+            # Pools only the image embeddings
+            image_mask = (kwargs["input_ids"] == self.config.image_token_id).unsqueeze(-1)
+            proj = proj * image_mask
+        return proj

+ 76 - 0
deconstruct_SQI/colpali/colpali_engine/models/idefics3/colidefics3/processing_colidefics3.py

@@ -0,0 +1,76 @@
+from typing import ClassVar, List, Optional, Tuple, Union
+
+import torch
+from PIL import Image
+from transformers import BatchEncoding, BatchFeature, Idefics3Processor
+
+from colpali_engine.utils.processing_utils import BaseVisualRetrieverProcessor
+
+
+class ColIdefics3Processor(BaseVisualRetrieverProcessor, Idefics3Processor):
+    """
+    Processor for ColIdefics3.
+    """
+
+    query_augmentation_token: ClassVar[str] = "<end_of_utterance>"
+    image_token: ClassVar[str] = "<image>"
+    visual_prompt_prefix: ClassVar[str] = "<|im_start|>User:<image>Describe the image.<end_of_utterance>\nAssistant:"
+
+    def __init__(self, *args, image_seq_len=64, **kwargs):
+        super().__init__(*args, image_seq_len=image_seq_len, **kwargs)
+        self.tokenizer.padding_side = "left"
+
+    def process_images(
+        self,
+        images: List[Image.Image],
+    ) -> Union[BatchFeature, BatchEncoding]:
+        """
+        Process images for ColIdefics3.
+
+        Args:
+            images: List of PIL images.
+        """
+        images = [image.convert("RGB") for image in images]
+
+        batch_doc = self(
+            text=[self.visual_prompt_prefix] * len(images),
+            images=images,
+            padding="longest",
+            return_tensors="pt",
+        )
+        return batch_doc
+
+    def process_texts(self, texts: List[str]) -> Union[BatchFeature, BatchEncoding]:
+        """
+        Process texts for ColIdefics3.
+
+        Args:
+            texts: List of input texts.
+
+        Returns:
+            Union[BatchFeature, BatchEncoding]: Processed texts.
+        """
+        return self(
+            text=texts,
+            return_tensors="pt",
+            padding="longest",
+        )
+
+    def score(
+        self,
+        qs: List[torch.Tensor],
+        ps: List[torch.Tensor],
+        device: Optional[Union[str, torch.device]] = None,
+        **kwargs,
+    ) -> torch.Tensor:
+        """
+        Compute the MaxSim score (ColBERT-like) for the given multi-vector query and passage embeddings.
+        """
+        return self.score_multi_vector(qs, ps, device=device, **kwargs)
+
+    def get_n_patches(
+        self,
+        image_size: Tuple[int, int],
+        patch_size: int,
+    ) -> Tuple[int, int]:
+        raise NotImplementedError("This method is not implemented for ColIdefics3.")
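Putting the ColIdefics3 model and processor together, a hedged end-to-end sketch of embedding a few page images and scoring a query against them. The checkpoint name is a placeholder, and score (via score_multi_vector) is inherited from BaseVisualRetrieverProcessor, which is not part of this diff:

    import torch
    from PIL import Image

    from colpali_engine.models import ColIdefics3, ColIdefics3Processor

    model_name = "your-org/colidefics3-checkpoint"  # placeholder: any ColIdefics3-compatible checkpoint
    model = ColIdefics3.from_pretrained(model_name).eval()
    processor = ColIdefics3Processor.from_pretrained(model_name)

    images = [Image.new("RGB", (512, 512), color="white")]  # stand-in for real document page images
    queries = ["What is the revenue reported for 2023?"]

    with torch.no_grad():
        doc_embeddings = model(**processor.process_images(images))    # (n_docs, n_tokens, 128)
        query_embeddings = model(**processor.process_texts(queries))  # (n_queries, n_tokens, 128)

    scores = processor.score(list(torch.unbind(query_embeddings)), list(torch.unbind(doc_embeddings)))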

+ 2 - 0
deconstruct_SQI/colpali/colpali_engine/models/modernvbert/__init__.py

@@ -0,0 +1,2 @@
+from .bivbert import BiModernVBert, BiModernVBertProcessor
+from .colvbert import ColModernVBert, ColModernVBertProcessor

+ 2 - 0
deconstruct_SQI/colpali/colpali_engine/models/modernvbert/bivbert/__init__.py

@@ -0,0 +1,2 @@
+from .modeling_bimodernvbert import BiModernVBert
+from .processing_bimodernvbert import BiModernVBertProcessor

+ 66 - 0
deconstruct_SQI/colpali/colpali_engine/models/modernvbert/bivbert/modeling_bimodernvbert.py

@@ -0,0 +1,66 @@
+from typing import Literal, Optional
+
+import torch
+
+from colpali_engine.models.modernvbert.modeling_modernvbert import ModernVBertModel, ModernVBertPreTrainedModel
+
+
+class BiModernVBert(ModernVBertPreTrainedModel):
+    """
+    Initializes the BiModernVBert model.
+
+    Args:
+        config : The model configuration.
+    """
+
+    supports_gradient_checkpointing = True
+    _supports_flash_attn_2 = True
+    _supports_sdpa = True
+    _supports_cache_class = True
+
+    def __init__(self, config, pooling_strategy="mean", **kwargs):
+        super().__init__(config=config)
+        self.model = ModernVBertModel(config, **kwargs)
+        self.pooling_strategy = pooling_strategy
+        self.eps = 1e-12
+        self.post_init()
+
+    def forward(
+        self,
+        pooling_strategy: Optional[Literal["cls", "last", "mean"]] = None,
+        *args,
+        **kwargs,
+    ) -> torch.Tensor:
+        """
+        Forward pass through model and pooling.
+
+        Args:
+        - pooling_strategy (str): The pooling strategy to use. Options are "cls", "last", or "mean".
+        - input_ids (torch.LongTensor): The input tokens tensor.
+        - attention_mask (torch.LongTensor): The attention mask tensor.
+
+        Returns:
+        - torch.Tensor: Embeddings of shape (batch_size, dim)
+        """
+        outputs = self.model(*args, **kwargs)
+        last_hidden_states = outputs[0]  # (batch_size, sequence_length, hidden_size)
+
+        pooling_strategy = pooling_strategy or self.pooling_strategy
+
+        # Get CLS token embedding, last token, or mean pool over sequence
+        if pooling_strategy == "cls":
+            # Use CLS token (first token) embedding
+            pooled_output = last_hidden_states[:, 0]  # (batch_size, hidden_size)
+        elif pooling_strategy == "last":
+            # Use last token
+            pooled_output = last_hidden_states[:, -1]  # (batch_size, hidden_size)
+        elif pooling_strategy == "mean":
+            # Mean pooling over sequence length
+            mask = kwargs["attention_mask"].unsqueeze(-1)  # (batch_size, sequence_length, 1)
+            pooled_output = (last_hidden_states * mask).sum(dim=1) / mask.sum(dim=1)  # (batch_size, hidden_size)
+        else:
+            raise ValueError(f"Invalid pooling strategy: {pooling_strategy}")
+
+        # L2 normalization
+        pooled_output = pooled_output / pooled_output.norm(dim=-1, keepdim=True).clamp_min(self.eps)
+        return pooled_output

+ 36 - 0
deconstruct_SQI/colpali/colpali_engine/models/modernvbert/bivbert/processing_bimodernvbert.py

@@ -0,0 +1,36 @@
+from typing import List, Optional, Union
+
+import torch
+from transformers import BatchEncoding, BatchFeature
+
+from colpali_engine.models.modernvbert.colvbert import ColModernVBertProcessor  # noqa: N801
+
+
+class BiModernVBertProcessor(ColModernVBertProcessor):  # noqa: N801
+    """
+    Processor for BiModernVBert.
+    """
+
+    def process_texts(self, texts: List[str]) -> Union[BatchFeature, BatchEncoding]:
+        """
+        Process texts for BiModernVBert.
+
+        Args:
+            texts: List of input texts.
+
+        Returns:
+            Union[BatchFeature, BatchEncoding]: Processed texts.
+        """
+        return self(text=texts, return_tensors="pt", padding="longest")
+
+    def score(
+        self,
+        qs: List[torch.Tensor],
+        ps: List[torch.Tensor],
+        device: Optional[Union[str, torch.device]] = None,
+        **kwargs,
+    ) -> torch.Tensor:
+        """
+        Compute the cosine similarity for the given query and passage embeddings.
+        """
+        return self.score_single_vector(qs, ps, device=device)
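The docstring above says cosine similarity; since the bi-encoder models already L2-normalize their pooled outputs, a roughly equivalent computation (a sketch, not the actual score_single_vector helper, which lives outside this diff) is a plain dot-product matrix:

    import torch
    import torch.nn.functional as F

    # query_embeddings: (n_queries, dim), doc_embeddings: (n_docs, dim), both L2-normalized by the model
    query_embeddings = F.normalize(torch.randn(2, 768), dim=-1)
    doc_embeddings = F.normalize(torch.randn(5, 768), dim=-1)

    scores = query_embeddings @ doc_embeddings.T  # (n_queries, n_docs) cosine similarities
    best = scores.argmax(dim=1)                   # index of the highest-scoring document per query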

+ 2 - 0
deconstruct_SQI/colpali/colpali_engine/models/modernvbert/colvbert/__init__.py

@@ -0,0 +1,2 @@
+from .modeling_colmodernvbert import ColModernVBert
+from .processing_colmodernvbert import ColModernVBertProcessor

+ 52 - 0
deconstruct_SQI/colpali/colpali_engine/models/modernvbert/colvbert/modeling_colmodernvbert.py

@@ -0,0 +1,52 @@
+from torch import nn
+
+from colpali_engine.models.modernvbert.modeling_modernvbert import ModernVBertModel, ModernVBertPreTrainedModel
+
+
+class ColModernVBert(ModernVBertPreTrainedModel):
+    """
+    Initializes the ColModernVBert model.
+
+    Args:
+        config : The model configuration.
+        mask_non_image_embeddings (Optional[bool]): Whether to ignore all tokens embeddings
+        except those of the image at inference.
+        Defaults to False --> Do not mask any embeddings during forward pass.
+    """
+
+    supports_gradient_checkpointing = True
+    _supports_flash_attn_2 = True
+    _supports_sdpa = True
+    _supports_cache_class = True
+
+    def __init__(self, config, mask_non_image_embeddings: bool = False, **kwargs):
+        super().__init__(config=config)
+        self.model = ModernVBertModel(config, **kwargs)
+        self.dim = 128
+        self.custom_text_proj = nn.Linear(self.model.config.text_config.hidden_size, self.dim)
+        self.mask_non_image_embeddings = mask_non_image_embeddings
+        self.main_input_name = "doc_input_ids"
+
+    def forward(self, *args, **kwargs):
+        """
+        Forward pass through the model and the linear layer for dimensionality reduction
+
+        Args:
+        - input_ids (torch.LongTensor): The input tokens tensor.
+        - attention_mask (torch.LongTensor): The attention mask tensor.
+
+        Returns:
+        - torch.Tensor: Embeddings of shape (batch_size, num_tokens, dim)
+        """
+        outputs = self.model(*args, **kwargs)
+        last_hidden_states = outputs[0]  # (batch_size, sequence_length, hidden_size)
+        proj = self.custom_text_proj(last_hidden_states)
+        # L2-normalize the projected embeddings
+        proj = proj / proj.norm(dim=-1, keepdim=True).clamp_min(1e-12)
+        proj = proj * kwargs["attention_mask"].unsqueeze(-1)
+
+        if "pixel_values" in kwargs and self.mask_non_image_embeddings:
+            # Pools only the image embeddings
+            image_mask = (kwargs["input_ids"] == self.config.image_token_id).unsqueeze(-1)
+            proj = proj * image_mask
+        return proj

+ 78 - 0
deconstruct_SQI/colpali/colpali_engine/models/modernvbert/colvbert/processing_colmodernvbert.py

@@ -0,0 +1,78 @@
+from typing import ClassVar, List, Optional, Tuple, Union
+
+import torch
+from PIL import Image
+from transformers import BatchEncoding, BatchFeature, Idefics3Processor
+
+from colpali_engine.utils.processing_utils import BaseVisualRetrieverProcessor
+
+
+class ColModernVBertProcessor(BaseVisualRetrieverProcessor, Idefics3Processor):
+    """
+    Processor for ColModernVBert.
+    """
+
+    query_augmentation_token: ClassVar[str] = "<end_of_utterance>"
+    image_token: ClassVar[str] = "<image>"
+    visual_prompt_prefix: ClassVar[str] = (
+        "<|begin_of_text|>User:<image>Describe the image.<end_of_utterance>\nAssistant:"
+    )
+
+    def __init__(self, *args, image_seq_len=64, **kwargs):
+        super().__init__(*args, image_seq_len=image_seq_len, **kwargs)
+        self.tokenizer.padding_side = "left"
+
+    def process_images(
+        self,
+        images: List[Image.Image],
+    ) -> Union[BatchFeature, BatchEncoding]:
+        """
+        Process images for ColModernVBert.
+
+        Args:
+            images: List of PIL images.
+        """
+        images = [image.convert("RGB") for image in images]
+
+        batch_doc = self(
+            text=[self.visual_prompt_prefix] * len(images),
+            images=images,
+            padding="longest",
+            return_tensors="pt",
+        )
+        return batch_doc
+
+    def process_texts(self, texts: List[str]) -> Union[BatchFeature, BatchEncoding]:
+        """
+        Process texts for ColModernVBert.
+
+        Args:
+            texts: List of input texts.
+
+        Returns:
+            Union[BatchFeature, BatchEncoding]: Processed texts.
+        """
+        return self(
+            text=texts,
+            return_tensors="pt",
+            padding="longest",
+        )
+
+    def score(
+        self,
+        qs: List[torch.Tensor],
+        ps: List[torch.Tensor],
+        device: Optional[Union[str, torch.device]] = None,
+        **kwargs,
+    ) -> torch.Tensor:
+        """
+        Compute the MaxSim score (ColBERT-like) for the given multi-vector query and passage embeddings.
+        """
+        return self.score_multi_vector(qs, ps, device=device, **kwargs)
+
+    def get_n_patches(
+        self,
+        image_size: Tuple[int, int],
+        patch_size: int,
+    ) -> Tuple[int, int]:
+        raise NotImplementedError("This method is not implemented for ColModernVBert.")

+ 279 - 0
deconstruct_SQI/colpali/colpali_engine/models/modernvbert/configuration_modernvbert.py

@@ -0,0 +1,279 @@
+import copy
+import os
+from typing import Any, Dict, Union
+
+from transformers import AutoConfig
+from transformers.configuration_utils import PretrainedConfig
+from transformers.utils import logging
+
+logger = logging.get_logger(__name__)
+
+DEFAULT_TEXT_MODEL_NAME = "jhu-clsp/ettin-encoder-150m"
+DEFAULT_VISION_MODEL_NAME = "google/siglip2-base-patch16-512"
+
+
+_MISSING = object()
+
+
+def collect_arg_in_candidates(config, candidates, default=_MISSING) -> Any:
+    """Gets the first available argument in a config given a list of candidate names."""
+    for c in candidates:
+        if hasattr(config, c):
+            return getattr(config, c)
+        elif isinstance(config, dict) and c in config:
+            return config[c]
+    # Use a sentinel so that falsy defaults (e.g. `False` for `mlp_bias`) are still honored.
+    if default is not _MISSING:
+        return default
+    raise ValueError(f"No matching arguments found in candidates. Candidates: {candidates}, Config: {config}")
+
+
+class ModernVBertTextConfig(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of the text encoder used by a [`ModernVBert`].
+    It is used to instantiate a ModernBERT-style text encoder according to the specified arguments, defining the
+    model architecture. Instantiating a configuration with the defaults will yield a configuration similar to that
+    of the [jhu-clsp/ettin-encoder-150m](https://huggingface.co/jhu-clsp/ettin-encoder-150m) architecture.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+    """
+
+    model_type = "modernvbert_text"
+
+    def __init__(
+        self,
+        text_model_name=DEFAULT_TEXT_MODEL_NAME,
+        hidden_size=768,
+        num_hidden_layers=22,
+        intermediate_size=1152,
+        mlp_bias=False,
+        vocab_size=50368,
+        **kwargs,
+    ):
+        super().__init__(
+            text_model_name=text_model_name,
+            hidden_size=hidden_size,
+            num_hidden_layers=num_hidden_layers,
+            intermediate_size=intermediate_size,
+            mlp_bias=mlp_bias,
+            vocab_size=vocab_size,
+            **kwargs,
+        )
+
+    @classmethod
+    def from_base_model(
+        cls,
+        text_model_name=DEFAULT_TEXT_MODEL_NAME,
+        **kwargs,
+    ):
+        text_config = AutoConfig.from_pretrained(text_model_name, trust_remote_code=True)
+        if hasattr(text_config, "text_config"):
+            text_config = text_config.text_config
+
+        hidden_size = collect_arg_in_candidates(text_config, ["hidden_size", "embed_dim"])
+        num_hidden_layers = collect_arg_in_candidates(text_config, ["num_hidden_layers", "num_hidden_blocks"])
+        intermediate_size = collect_arg_in_candidates(text_config, ["intermediate_size", "mlp_dim"])
+        mlp_bias = collect_arg_in_candidates(text_config, ["mlp_bias", "mlp_hidden_bias"], default=False)
+        vocab_size = collect_arg_in_candidates(text_config, ["vocab_size"])
+
+        return cls(
+            text_model_name=text_model_name,
+            hidden_size=hidden_size,
+            num_hidden_layers=num_hidden_layers,
+            intermediate_size=intermediate_size,
+            mlp_bias=mlp_bias,
+            vocab_size=vocab_size,
+            **kwargs,
+        )
+
+
+class ModernVBertVisionConfig(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of the vision encoder used by a [`ModernVBert`].
+    It is used to instantiate a SigLIP-style vision encoder according to the specified arguments, defining the
+    model architecture. Instantiating a configuration with the defaults will yield a configuration similar to that
+    of the [google/siglip2-base-patch16-512](https://huggingface.co/google/siglip2-base-patch16-512) vision encoder.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+    """
+
+    model_type = "modernvbert_vision"
+
+    attribute_map = {
+        "hidden_size": "embed_dim",
+    }
+
+    def __init__(
+        self,
+        vision_model_name=DEFAULT_VISION_MODEL_NAME,
+        embed_dim=768,
+        image_size=512,
+        patch_size=16,
+        num_hidden_layers=12,
+        intermediate_size=3072,
+        **kwargs,
+    ):
+        super().__init__(
+            vision_model_name=vision_model_name,
+            embed_dim=embed_dim,
+            image_size=image_size,
+            patch_size=patch_size,
+            num_hidden_layers=num_hidden_layers,
+            intermediate_size=intermediate_size,
+            **kwargs,
+        )
+
+    @classmethod
+    def from_base_model(
+        cls,
+        vision_model_name=DEFAULT_VISION_MODEL_NAME,
+        **kwargs,
+    ):
+        vision_config = AutoConfig.from_pretrained(vision_model_name, trust_remote_code=True)
+        if hasattr(vision_config, "vision_config"):
+            vision_config = vision_config.vision_config
+
+        embed_dim = collect_arg_in_candidates(vision_config, ["embed_dim", "hidden_size"])
+        image_size = collect_arg_in_candidates(vision_config, ["image_size", "img_size"])
+        patch_size = collect_arg_in_candidates(vision_config, ["patch_size"])
+        num_hidden_layers = collect_arg_in_candidates(vision_config, ["num_hidden_layers", "num_hidden_blocks"])
+        intermediate_size = collect_arg_in_candidates(vision_config, ["intermediate_size", "mlp_dim"])
+
+        return cls(
+            vision_model_name=vision_model_name,
+            embed_dim=embed_dim,
+            image_size=image_size,
+            patch_size=patch_size,
+            num_hidden_layers=num_hidden_layers,
+            intermediate_size=intermediate_size,
+            **kwargs,
+        )
+
+
+class ModernVBertConfig(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a `ModernVBert` model. It is used to
+    instantiate a ModernVBert model according to the specified arguments and defines the model architecture.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs.
+    See the documentation for [`PretrainedConfig`] for more details.
+
+    Args:
+        text_config (`PretrainedConfig` or `dict`, optional):
+            Custom text config or a dict with a `text_model_name` key for the text encoder. If `None`, the
+            default text backbone defined by `DEFAULT_TEXT_MODEL_NAME` is used.
+        vision_config (`PretrainedConfig` or `dict`, optional):
+            Custom vision config or a dict with a `vision_model_name` key for the vision encoder. If `None`, the
+            default vision backbone defined by `DEFAULT_VISION_MODEL_NAME` is used.
+        image_token_id (`int`, optional, defaults to 50407):
+            Token id reserved for image tokens inserted into the text stream.
+        vocab_size (`int`, optional, defaults to 50368):
+            Vocabulary size used by the text embeddings.
+        use_cache (`bool`, optional, defaults to `True`):
+            Whether to cache key/value tensors for attention (relevant for decoder architectures).
+        tie_word_embeddings (`bool`, optional, defaults to `False`):
+            Whether to tie input token embeddings and output token embeddings.
+        pixel_shuffle_factor (`int`, optional, defaults to 4):
+            Scale factor used by any pixel-shuffle / upsampling operations in the vision head.
+        additional_vocab_size (`int`, optional, defaults to 0):
+            Number of extra tokens appended to the base vocabulary (useful for adapters / special tokens).
+        pad_token_id (`int`, optional):
+            Padding token id.
+        initializer_range (`float`, optional, defaults to 0.02):
+            Stddev used for weight initialization.
+        freeze_config (`Any`, optional):
+            Optional config describing which submodules to freeze during training.
+        use_resampler (`bool`, optional, defaults to `False`):
+            Whether to enable an additional resampler on visual features.
+        neftune_noise_alpha (`float`, optional, defaults to 0.0):
+            Alpha parameter for neftune noise injection.
+
+    Example:
+    ```python
+    >>> from colpali_engine.models.modernvbert.configuration_modernvbert import ModernVBertConfig
+    >>> # Initializing configuration
+    >>> configuration = ModernVBertConfig()
+    >>> # Initializing a model from the configuration (model class is implemented in
+    >>> # `colpali_engine.models.modernvbert.modeling_modernvbert`)
+    >>> # from colpali_engine.models.modernvbert.modeling_modernvbert import ModernVBertModel
+    >>> # model = ModernVBertModel(configuration)
+    >>> # Accessing the model configuration
+    >>> # cfg = model.config
+    ```"""
+
+    model_type = "modernvbert"
+    is_composition = True
+
+    def __init__(
+        self,
+        text_config: Union[PretrainedConfig, Dict[str, Any]] = None,
+        vision_config: Union[PretrainedConfig, Dict[str, Any]] = None,
+        image_token_id: int = 50407,
+        vocab_size=50368,
+        use_cache=True,
+        tie_word_embeddings=False,
+        freeze_config=None,
+        pad_token_id=None,
+        initializer_range=0.02,
+        pixel_shuffle_factor=4,
+        use_resampler=False,
+        additional_vocab_size=0,
+        neftune_noise_alpha=0.0,
+        **kwargs,
+    ):
+        self.image_token_id = image_token_id
+        self.use_cache = use_cache
+        self.tie_word_embeddings = tie_word_embeddings
+        self.scale_factor = pixel_shuffle_factor
+        self.additional_vocab_size = additional_vocab_size
+
+        if text_config is None:
+            text_config = ModernVBertTextConfig.from_base_model(DEFAULT_TEXT_MODEL_NAME)
+        elif isinstance(text_config, dict):
+            text_config = ModernVBertTextConfig.from_dict(text_config)
+        self.text_config = text_config
+
+        if vision_config is None:
+            vision_config = ModernVBertVisionConfig.from_base_model(DEFAULT_VISION_MODEL_NAME)
+        elif isinstance(vision_config, dict):
+            vision_config = ModernVBertVisionConfig.from_dict(vision_config)
+        self.vision_config = vision_config
+
+        self.freeze_config = freeze_config
+        self.pixel_shuffle_factor = pixel_shuffle_factor
+        self.use_resampler = use_resampler
+        self.neftune_noise_alpha = neftune_noise_alpha
+        self.initializer_range = initializer_range
+
+        hidden_size = kwargs.pop("hidden_size", self.text_config.hidden_size)
+
+        super().__init__(
+            **kwargs,
+            pad_token_id=pad_token_id,
+            tie_word_embeddings=tie_word_embeddings,
+            vocab_size=vocab_size,
+            hidden_size=hidden_size,
+        )
+
+    def to_dict(self):
+        output = copy.deepcopy(self.__dict__)
+        output["model_type"] = self.__class__.model_type
+        output["vision_config"] = self.vision_config.to_dict()
+        output["text_config"] = self.text_config.to_dict()
+        return output
+
+    @classmethod
+    def from_pretrained_models(
+        cls,
+        text_model_name: Union[str, os.PathLike],
+        vision_model_name: Union[str, os.PathLike],
+        **kwargs,
+    ) -> "PretrainedConfig":
+        text_model_config = ModernVBertTextConfig.from_base_model(text_model_name)
+        vision_model_config = ModernVBertVisionConfig.from_base_model(vision_model_name)
+        return cls(
+            text_config=text_model_config,
+            vision_config=vision_model_config,
+            **kwargs,
+        )
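A minimal sketch of building a composite config from two existing backbones with from_pretrained_models. The backbone names are the defaults declared at the top of this file; fetching their configs requires network access or a local Hugging Face cache:

    from colpali_engine.models.modernvbert.configuration_modernvbert import ModernVBertConfig

    config = ModernVBertConfig.from_pretrained_models(
        text_model_name="jhu-clsp/ettin-encoder-150m",
        vision_model_name="google/siglip2-base-patch16-512",
        image_token_id=50407,
    )
    print(config.text_config.hidden_size, config.vision_config.image_size, config.pixel_shuffle_factor)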

+ 476 - 0
deconstruct_SQI/colpali/colpali_engine/models/modernvbert/modeling_modernvbert.py

@@ -0,0 +1,476 @@
+from dataclasses import dataclass
+from typing import Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F  # noqa: N812
+from torch.nn import CrossEntropyLoss
+from transformers import AutoConfig, AutoModel, AutoModelForMaskedLM, PreTrainedModel, logging
+from transformers.modeling_outputs import (
+    BaseModelOutput,
+    BaseModelOutputWithPoolingAndCrossAttentions,
+    MaskedLMOutput,
+)
+
+from .configuration_modernvbert import ModernVBertConfig
+
+logger = logging.get_logger(__name__)
+
+
+class DecoupledEmbedding(nn.Embedding):
+    # Derived from https://pytorch.org/docs/stable/_modules/torch/nn/modules/sparse.html#Embedding
+    """
+    Implements a decoupling of parameters to allow freezing (or not) a subset of the embeddings.
+    In practice, the regular `weight` can be trained or frozen (frozen when `partially_freeze=True`), and if
+    `num_additional_embeddings` > 0, the module creates `num_additional_embeddings` additional parameters that are
+    always trained.
+    If `num_additional_embeddings=0`, the module falls back to the regular behavior of `nn.Embedding`.
+    """
+
+    def __init__(
+        self,
+        num_embeddings,
+        num_additional_embeddings,
+        embedding_dim,
+        partially_freeze=False,
+        device=None,
+        dtype=None,
+        padding_idx=None,
+        **kwargs,
+    ) -> None:
+        """
+        num_additional_embeddings: int. Number of additional trainable embeddings appended after the regular
+            vocabulary (mainly useful together with `partially_freeze=True`).
+        partially_freeze: bool. If True, the regular `weight` is frozen. The additional embeddings are never frozen.
+
+        Note: there are a lot of other parameters to initialize a standard `nn.Embedding` such as `padding_idx`,
+            `max_norm` or `norm_type`. We are not supporting these.
+        """
+        if padding_idx is not None and padding_idx >= num_embeddings:
+            raise ValueError(
+                f"padding_idx must be strictly smaller than num_embeddings. Got {padding_idx} and {num_embeddings}"
+            )
+
+        super().__init__(
+            num_embeddings=num_embeddings,
+            embedding_dim=embedding_dim,
+            device=device,
+            dtype=dtype,
+            padding_idx=padding_idx,
+            **kwargs,
+        )
+        self.num_embeddings = num_embeddings
+        self.num_additional_embeddings = num_additional_embeddings
+        self.partially_freeze = partially_freeze
+
+        if partially_freeze:
+            self.weight.requires_grad_(False)
+
+        if self.num_additional_embeddings > 0:
+            self.additional_embedding = nn.Embedding(
+                num_embeddings=num_additional_embeddings,
+                embedding_dim=embedding_dim,
+                device=device,
+                dtype=dtype,
+            )
+
+    def forward(self, input_ids):
+        """
+        We have two embedding tables with disjoint index ranges: the pretrained `self.weight` and the always-trainable
+        `self.additional_embedding.weight`.
+
+        In order to look up the input ids, we:
+        1. find the indices of the entries belonging to the additional embedding table
+        2. shift those ids down by `num_embeddings`, since the additional table starts at index 0
+        3. perform the additional embedding lookup
+        4. overwrite those positions in the input ids with index 0 so the main lookup stays in range
+        5. perform the main embedding lookup
+        6. overwrite the main lookup results at those positions with the additional lookup results
+
+        Note: for the main lookup we could gather only the low indices instead of padding, but that would require
+        scattering two tensors back into a single output rather than a simple concat; for the typical (short)
+        sequence lengths this is unlikely to be meaningfully faster (not benchmarked).
+
+        """
+        if self.num_additional_embeddings == 0:
+            return super().forward(input_ids)
+
+        input_ids = input_ids.clone()
+        additional_vocab_indices = torch.where(input_ids >= self.num_embeddings)
+        input_ids_additional_vocab = input_ids[additional_vocab_indices]
+        additional_embeddings = self.additional_embedding(input_ids_additional_vocab - self.num_embeddings)
+
+        # for successful lookup replace input_ids with 0, the results of these will be discarded anyway
+        input_ids[additional_vocab_indices] = 0
+        full_vector = F.embedding(input_ids, self.weight)
+        full_vector[additional_vocab_indices] = additional_embeddings  # overwrite the records with high indices
+        return full_vector
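A small sketch of the decoupled lookup above: ids below num_embeddings hit the (optionally frozen) pretrained table, while ids at or above it are routed to the always-trainable additional table (toy sizes, class imported from this module):

    import torch

    from colpali_engine.models.modernvbert.modeling_modernvbert import DecoupledEmbedding

    emb = DecoupledEmbedding(
        num_embeddings=100,            # pretrained vocabulary
        num_additional_embeddings=4,   # extra tokens with ids 100..103, always trainable
        embedding_dim=8,
        partially_freeze=True,         # freeze the pretrained table
    )

    input_ids = torch.tensor([[1, 42, 100, 103]])
    out = emb(input_ids)               # (1, 4, 8); rows for ids >= 100 come from the additional table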
+
+
+@dataclass
+class ModernVBertBaseModelOutput(BaseModelOutput):
+    """
+    Base class for ModernVBERT model outputs, optionally carrying the image hidden states produced by the vision
+    encoder.
+    Args:
+        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+            Sequence of hidden-states at the output of the last layer of the model.
+        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed
+            or when `config.output_hidden_states=True`):
+            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed
+            or when `config.output_attentions=True`):
+            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+        image_hidden_states (`tuple(torch.FloatTensor)`, *optional*):
+            Tuple of `torch.FloatTensor` (one for the output of the image embeddings, `(batch_size, num_images,
+            sequence_length, hidden_size)`.
+            image_hidden_states of the model produced by the vision encoder
+    """
+
+    last_hidden_state: torch.FloatTensor = None
+    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+    attentions: Optional[Tuple[torch.FloatTensor]] = None
+    image_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+
+
+@dataclass
+class ModernVBertMaskedLMOutput(MaskedLMOutput):
+    """
+    Base class for ModernVBERT masked language modeling outputs, optionally carrying the image hidden states produced
+    by the vision encoder.
+    Args:
+        loss (`torch.FloatTensor`, *optional*, returned when `labels` is provided):
+            Masked language modeling (MLM) loss.
+        logits (`torch.FloatTensor`):
+            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed
+        or when `config.output_hidden_states=True`):
+            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed
+        or when `config.output_attentions=True`):
+            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+        image_hidden_states (`tuple(torch.FloatTensor)`, *optional*):
+            Tuple of `torch.FloatTensor` (one for the output of the image embeddings) of shape `(batch_size,
+            num_images, sequence_length, hidden_size)`.
+            Image hidden states of the model produced by the vision encoder.
+    """
+
+    loss: Optional[torch.FloatTensor] = None
+    logits: torch.FloatTensor = None
+    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
+    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
+    image_hidden_states: Optional[torch.FloatTensor] = None
+
+
+class ModernVBertSimpleMLP(nn.Module):
+    """A simple linear projection layer to project the vision hidden states to the text hidden states."""
+
+    def __init__(self, input_size, output_size):
+        super().__init__()
+        self.proj = nn.Linear(input_size, output_size, bias=False)
+
+    def forward(self, x):
+        return self.proj(x)
+
+
+class ModernVBertConnector(nn.Module):
+    """
+    Connector module for ModernVBERT. It performs a pixel shuffle operation
+    followed by a linear projection to match the text model's hidden size.
+    Based on https://pytorch.org/docs/stable/generated/torch.nn.PixelShuffle.html
+    """
+
+    def __init__(self, config):
+        super().__init__()
+        self.scale_factor = config.pixel_shuffle_factor
+        self.modality_projection = ModernVBertSimpleMLP(
+            input_size=config.vision_config.hidden_size * (config.scale_factor**2),
+            output_size=config.text_config.hidden_size,
+        )
+
+    def pixel_shuffle(self, x, scale_factor):
+        # x: (batch_size, seq_len, embed_dim), with seq_len assumed to form a square patch grid
+        bsz, seq, embed_dim = x.size()
+        height = width = int(seq**0.5)
+        x = x.view(bsz, height, width, embed_dim)
+        # Fold `scale_factor` neighboring patches along the width into the channel dimension
+        x = x.view(bsz, height, int(width / scale_factor), embed_dim * scale_factor)
+        x = x.permute(0, 2, 1, 3)
+        # Fold `scale_factor` neighboring patches along the height as well
+        x = x.reshape(bsz, int(width / scale_factor), int(height / scale_factor), embed_dim * (scale_factor**2))
+        x = x.permute(0, 2, 1, 3)
+        # Result: (batch_size, seq_len / scale_factor**2, embed_dim * scale_factor**2)
+        return x.reshape(bsz, int(seq / (scale_factor**2)), embed_dim * (scale_factor**2))
+
+    def forward(self, image_hidden_states):
+        image_hidden_states = self.pixel_shuffle(image_hidden_states, self.scale_factor)
+        return self.modality_projection(image_hidden_states)
+
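For readers skimming the connector above, here is a minimal, standalone sketch of the token-count reduction the pixel shuffle performs. The tensor sizes and the scale factor of 2 are illustrative assumptions, not values taken from the actual ModernVBERT config.

import torch

# Illustrative sizes only: a 16x16 grid of 256 vision patches with dimension 768, scale factor 2.
bsz, seq, dim, sf = 1, 256, 768, 2
x = torch.randn(bsz, seq, dim)

h = w = int(seq**0.5)
x = x.view(bsz, h, w, dim)
x = x.view(bsz, h, w // sf, dim * sf).permute(0, 2, 1, 3)
x = x.reshape(bsz, w // sf, h // sf, dim * sf**2).permute(0, 2, 1, 3)
x = x.reshape(bsz, seq // sf**2, dim * sf**2)

print(x.shape)  # torch.Size([1, 64, 3072]): 4x fewer tokens, 4x wider features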
+
+class ModernVBertPreTrainedModel(PreTrainedModel):
+    config_class = ModernVBertConfig
+    base_model_prefix = "model"
+    supports_gradient_checkpointing = True
+    _supports_flash_attn_2 = True
+    _supports_sdpa = True
+
+    def _init_weights(self, module):
+        std = getattr(self.config, "initializer_range", 0.02)
+        if isinstance(module, (nn.Linear, nn.Conv2d)):
+            module.weight.data.normal_(mean=0.0, std=std)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, nn.Embedding):
+            module.weight.data.normal_(mean=0.0, std=std)
+            if module.padding_idx is not None:
+                module.weight.data[module.padding_idx].zero_()
+
+
+class ModernVBertModel(ModernVBertPreTrainedModel):
+    def __init__(self, config: ModernVBertConfig):
+        super().__init__(config)
+        self.vision_model = ModernVBertModel.init_vision_model(config)
+        self.connector = ModernVBertConnector(config)
+        self.text_model = ModernVBertModel.init_language_model(config)
+        self.image_seq_len = int(
+            ((config.vision_config.image_size // config.vision_config.patch_size) ** 2) / (config.scale_factor**2)
+        )
+        self.image_token_id = config.image_token_id
+        self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
+        # set the correct dtype for vision and text models
+        self.vision_model.to(self.dtype)
+        self.text_model.to(self.dtype)
+        self.post_init()
+
+    @staticmethod
+    def init_vision_model(config: ModernVBertConfig):
+        vision_model_config = AutoConfig.from_pretrained(
+            config.vision_config.vision_model_name,
+            _attn_implementation=config._attn_implementation,
+        )
+        vision_model = AutoModel.from_config(
+            vision_model_config,
+            trust_remote_code=True,
+        )
+        return getattr(vision_model, "vision_model", vision_model)
+
+    @staticmethod
+    def init_language_model(config: ModernVBertConfig):
+        text_model_config = AutoConfig.from_pretrained(
+            config.text_config.text_model_name,
+            _attn_implementation=config._attn_implementation,
+            trust_remote_code=True,
+        )
+        text_model = AutoModel.from_config(text_model_config, trust_remote_code=True)
+        embed_layer = DecoupledEmbedding(
+            num_embeddings=text_model_config.vocab_size,
+            num_additional_embeddings=config.additional_vocab_size,
+            embedding_dim=config.hidden_size,
+            partially_freeze=config.freeze_config["freeze_text_layers"],
+            padding_idx=config.pad_token_id,
+        )
+        text_model.set_input_embeddings(embed_layer)
+        return text_model
+
+    def enable_input_require_grads(self):
+        """
+        Enables the gradients for the input embeddings.
+
+        This is useful for lora when using gradient checkpointing.
+        c.f. https://github.com/huggingface/peft/issues/1402#issuecomment-1913675032
+
+        Override to set output.requires_grad = True for both the decoder's and vision model's embeddings.
+        """
+
+        def get_lowest_module(module):
+            if len(list(module.children())) == 0:
+                # If the module has no children, it is a leaf module (e.g., Linear, Conv2d, etc.)
+                return module
+            else:
+                # Recursively call the function on each child module
+                return get_lowest_module(list(module.children())[0])
+
+        def make_inputs_require_grads(module, input, output):
+            output.requires_grad_(True)
+
+        self._text_require_grads_hook = self.get_input_embeddings().register_forward_hook(make_inputs_require_grads)
+        self._vision_require_grads_hook = get_lowest_module(self.vision_model).register_forward_hook(
+            make_inputs_require_grads
+        )
+
+    def get_input_embeddings(self):
+        return self.text_model.get_input_embeddings()
+
+    def set_input_embeddings(self, value):
+        self.text_model.set_input_embeddings(value)
+
+    def inputs_merger(self, input_ids, inputs_embeds, image_hidden_states):
+        """Adapted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/smolvlm/modeling_smolvlm.py
+
+        This method aims at merging the token embeddings with the image hidden states into one single
+        sequence of vectors that are fed to the transformer LM.
+        The merging happens as follows:
+        - The text token sequence is:
+            `tok_1 tok_2 tok_3 <fake_token_around_image> <image> <image> ... <image> <fake_token_around_image> tok_4`.
+        - We get the image hidden states for the image through the vision encoder and that hidden state,
+            after a pixel shuffle operation, is then projected into the text embedding space.
+            We thus have a sequence of image hidden states of size (1, image_seq_len, hidden_dim),
+            where 1 is for batch_size of 1 image and hidden_dim is the hidden_dim of the LM transformer.
+        - The merging happens so that we obtain the following sequence:
+            `vector_tok_1 vector_tok_2 vector_tok_3 vector_fake_tok_around_image
+            {sequence of image_seq_len image hidden states}
+            vector_fake_tok_around_image vector_tok_4`.
+            That sequence is fed to the LM.
+        - To fit the format of that sequence, `input_ids`, `input_embeds`, `attention_mask` are all 3 adapted to insert
+            the image hidden states.
+        """
+
+        _, patch_size, _ = image_hidden_states.shape
+        image_mask = input_ids == self.image_token_id
+        num_image_tokens = image_mask.sum(dim=1)
+        if not torch.all(num_image_tokens % patch_size == 0):
+            raise ValueError("Number of <image> tokens not divisible by patch_size.")
+        blocks_per_sample = num_image_tokens // patch_size
+        offsets = torch.nn.functional.pad(blocks_per_sample.cumsum(dim=0), (1, 0), value=0)
+        block_offset = offsets[:-1]
+        row_cum = image_mask.cumsum(dim=-1)
+        chunk_idx = (row_cum - 1) // patch_size
+        local_idx = (row_cum - 1) % patch_size
+        block_idx = block_offset.unsqueeze(1) + chunk_idx
+        image_embeds = torch.zeros_like(inputs_embeds)
+        image_embeds[image_mask] = image_hidden_states[block_idx[image_mask], local_idx[image_mask], :]
+        return torch.where(image_mask.unsqueeze(-1), image_embeds, inputs_embeds)
+
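As a rough illustration of the merging described in the docstring above, the toy example below replaces the embeddings at `<image>` positions with image hidden states. It is deliberately simplified to a single image and skips the block-offset bookkeeping; the token id 5 standing in for `<image>` and all tensor sizes are assumptions for illustration only.

import torch

image_token_id = 5            # hypothetical <image> id, illustration only
patch_size, hidden = 2, 4     # 2 image hidden states per image, hidden size 4

input_ids = torch.tensor([[10, 5, 5, 11]])   # tok <image> <image> tok
inputs_embeds = torch.zeros(1, 4, hidden)    # text embeddings (zeros for clarity)
image_hidden_states = torch.arange(patch_size * hidden, dtype=torch.float).view(1, patch_size, hidden)

image_mask = input_ids == image_token_id
merged = inputs_embeds.clone()
merged[image_mask] = image_hidden_states.reshape(-1, hidden)  # scatter image states into <image> slots

print(merged[0, 1])  # the first image hidden state now sits at the first <image> position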
+    def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        pixel_values: Optional[torch.FloatTensor] = None,
+        pixel_attention_mask: Optional[torch.BoolTensor] = None,
+        image_hidden_states: Optional[torch.FloatTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, BaseModelOutputWithPoolingAndCrossAttentions]:
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        if inputs_embeds is None:
+            inputs_embeds = self.text_model.get_input_embeddings()(input_ids).to(input_ids.device)
+        if pixel_values is not None:
+            batch_size, num_images, _, _, _ = pixel_values.shape
+            pixel_values = pixel_values.view(batch_size * num_images, *pixel_values.shape[2:])
+            nb_values_per_image = pixel_values.shape[1:].numel()
+            real_images_inds = (pixel_values == 0.0).sum(dim=(-1, -2, -3)) != nb_values_per_image
+            if not any(real_images_inds):
+                real_images_inds[0] = True
+            pixel_values = pixel_values[real_images_inds].contiguous()
+            image_hidden_states = self.vision_model(pixel_values=pixel_values).last_hidden_state
+            image_hidden_states = self.connector(image_hidden_states)
+        elif image_hidden_states is not None:
+            image_hidden_states = image_hidden_states.to(dtype=self.dtype, device=input_ids.device)
+        if inputs_embeds is not None and image_hidden_states is not None:
+            inputs_embeds = self.inputs_merger(input_ids, inputs_embeds, image_hidden_states)
+        outputs = self.text_model(
+            inputs_embeds=inputs_embeds,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+        if not return_dict:
+            return tuple(v for v in [*outputs, image_hidden_states] if v is not None)
+        return ModernVBertBaseModelOutput(
+            last_hidden_state=outputs.last_hidden_state,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+            image_hidden_states=image_hidden_states,
+        )
+
+
+class ModernVBertLMHead(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        pretrained_config = AutoConfig.from_pretrained(config.text_config.text_model_name, trust_remote_code=True)
+        pretrained_model = AutoModelForMaskedLM.from_config(pretrained_config, trust_remote_code=True)
+        self.head = pretrained_model.head
+        self.decoder = pretrained_model.decoder
+
+    def forward(self, hidden_states):
+        return self.decoder(self.head(hidden_states))
+
+
+class ModernVBertForMaskedLM(ModernVBertPreTrainedModel):
+    def __init__(self, config):
+        super().__init__(config)
+        self.image_token_id = config.image_token_id
+        self.in_features = config.hidden_size
+        self.out_additional_features = config.additional_vocab_size
+        self.vocab_size = config.vocab_size
+        self.model = ModernVBertModel(config)
+        self.lm_head = ModernVBertLMHead(config)
+        if self.out_additional_features > 0:
+            self.additional_fc = nn.Linear(self.in_features, self.out_additional_features, bias=False)
+        self.lm_head.to(self.dtype)
+        self.loss_function = CrossEntropyLoss()
+        self.post_init()
+
+    def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        pixel_values: Optional[torch.FloatTensor] = None,
+        pixel_attention_mask: Optional[torch.BoolTensor] = None,
+        image_hidden_states: Optional[torch.FloatTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        labels: Optional[torch.LongTensor] = None,
+    ) -> Union[Tuple, ModernVBertMaskedLMOutput]:
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        outputs = self.model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            inputs_embeds=inputs_embeds,
+            pixel_values=pixel_values,
+            pixel_attention_mask=pixel_attention_mask,
+            image_hidden_states=image_hidden_states,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+        hidden_states = outputs[0]
+        logits = self.lm_head(hidden_states)
+        if self.out_additional_features > 0:
+            proj_states = self.lm_head.head(hidden_states)
+            additional_features = self.additional_fc(proj_states)
+            logits = torch.cat((logits, additional_features), -1)
+        loss = None
+        if labels is not None:
+            loss = self.loss_function(logits.view(-1, self.vocab_size + self.out_additional_features), labels.view(-1))
+        if not return_dict:
+            output = (logits,) + outputs[2:]
+            return ((loss,) + output) if loss is not None else output
+        return ModernVBertMaskedLMOutput(
+            loss=loss,
+            logits=logits.float(),
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+            image_hidden_states=outputs.image_hidden_states,
+        )
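The extra-vocabulary head in `ModernVBertForMaskedLM` simply widens the logits: the pretrained decoder covers the base vocabulary and `additional_fc` covers the additional (e.g. image-related) tokens. A minimal sketch of that concatenation with made-up sizes, ignoring the shared projection (`lm_head.head`) used in the real code:

import torch
from torch import nn

hidden, vocab, extra = 8, 100, 3        # made-up sizes
hidden_states = torch.randn(2, 5, hidden)

decoder = nn.Linear(hidden, vocab)               # stands in for the pretrained MLM decoder
additional_fc = nn.Linear(hidden, extra, bias=False)

logits = torch.cat((decoder(hidden_states), additional_fc(hidden_states)), dim=-1)
print(logits.shape)  # torch.Size([2, 5, 103]) -> scores over vocab + additional tokens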

+ 2 - 0
deconstruct_SQI/colpali/colpali_engine/models/paligemma/__init__.py

@@ -0,0 +1,2 @@
+from .bipali import BiPali, BiPaliProcessor, BiPaliProj
+from .colpali import ColPali, ColPaliProcessor

+ 2 - 0
deconstruct_SQI/colpali/colpali_engine/models/paligemma/bipali/__init__.py

@@ -0,0 +1,2 @@
+from .modeling_bipali import BiPali, BiPaliProj
+from .processing_bipali import BiPaliProcessor

+ 144 - 0
deconstruct_SQI/colpali/colpali_engine/models/paligemma/bipali/modeling_bipali.py

@@ -0,0 +1,144 @@
+from typing import Optional
+
+import torch
+from torch import nn
+from transformers.models.paligemma.configuration_paligemma import PaliGemmaConfig
+from transformers.models.paligemma.modeling_paligemma import PaliGemmaForConditionalGeneration, PaliGemmaPreTrainedModel
+
+
+class BiPali(PaliGemmaPreTrainedModel):
+    """
+    BiPali is an implementation from the "ColPali: Efficient Document Retrieval with Vision Language Models" paper.
+    Representations are average pooled to obtain a single vector representation.
+    """
+
+    _checkpoint_conversion_mapping = {
+        "^model.language_model.model": "model.model.language_model",
+        "^model.vision_tower": "model.model.vision_tower",
+        "^model.multi_modal_projector": "model.model.multi_modal_projector",
+        "^model.language_model.lm_head": "model.lm_head",
+    }
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        key_mapping = kwargs.pop("key_mapping", None)
+        if key_mapping is None:
+            key_mapping = cls._checkpoint_conversion_mapping
+        return super().from_pretrained(*args, **kwargs, key_mapping=key_mapping)
+
+    def __init__(self, config: PaliGemmaConfig):
+        super(BiPali, self).__init__(config=config)
+        model: PaliGemmaForConditionalGeneration = PaliGemmaForConditionalGeneration(config)
+        if model.language_model._tied_weights_keys is not None:
+            self._tied_weights_keys = [f"model.language_model.{k}" for k in model.language_model._tied_weights_keys]
+        self.model: PaliGemmaForConditionalGeneration = model
+        self.model.lm_head = torch.nn.Identity()
+        self.main_input_name = "doc_input_ids"
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.model.language_model.get_input_embeddings()
+
+    def set_input_embeddings(self, value):
+        self.model.language_model.set_input_embeddings(value)
+
+    def get_output_embeddings(self):
+        return self.model.language_model.get_output_embeddings()
+
+    def set_output_embeddings(self, new_embeddings):
+        self.model.language_model.set_output_embeddings(new_embeddings)
+
+    def set_decoder(self, decoder):
+        self.model.language_model.set_decoder(decoder)
+
+    def get_decoder(self):
+        return self.model.language_model.get_decoder()
+
+    def tie_weights(self):
+        return self.model.language_model.tie_weights()
+
+    def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None) -> nn.Embedding:
+        model_embeds = self.model.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
+        # update vocab size
+        self.config.text_config.vocab_size = model_embeds.num_embeddings
+        self.config.vocab_size = model_embeds.num_embeddings
+        self.model.vocab_size = model_embeds.num_embeddings
+        return model_embeds
+
+    def forward(self, *args, **kwargs):
+        # delete output_hidden_states from kwargs
+        kwargs.pop("output_hidden_states", None)
+        if "pixel_values" in kwargs:
+            kwargs["pixel_values"] = kwargs["pixel_values"].to(dtype=self.dtype)
+
+        outputs = self.model(*args, output_hidden_states=True, **kwargs)
+        last_hidden_states = outputs.hidden_states[-1]  # (batch_size, sequence_length, hidden_size)
+        # Mean pooling over tokens where attention_mask == 1
+        proj = torch.sum(last_hidden_states * kwargs["attention_mask"].unsqueeze(-1), dim=1) / torch.sum(
+            kwargs["attention_mask"], dim=1, keepdim=True
+        )
+        proj = proj / proj.norm(dim=-1, keepdim=True)
+        return proj
+
+
+class BiPaliProj(PaliGemmaPreTrainedModel):
+    """
+    BiPaliProj is a BiPali model with a projection layer for dimensionality reduction.
+    """
+
+    def __init__(self, config: PaliGemmaConfig):
+        super(BiPaliProj, self).__init__(config=config)
+        model: PaliGemmaForConditionalGeneration = PaliGemmaForConditionalGeneration(config)
+        if model.language_model._tied_weights_keys is not None:
+            self._tied_weights_keys = [f"model.language_model.{k}" for k in model.language_model._tied_weights_keys]
+        self.model: PaliGemmaForConditionalGeneration = model
+        self.main_input_name = "doc_input_ids"
+        self.dim = 1024
+        self.custom_text_proj = nn.Linear(self.model.config.text_config.hidden_size, self.dim)
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.model.language_model.get_input_embeddings()
+
+    def set_input_embeddings(self, value):
+        self.model.language_model.set_input_embeddings(value)
+
+    def get_output_embeddings(self):
+        return self.model.language_model.get_output_embeddings()
+
+    def set_output_embeddings(self, new_embeddings):
+        self.model.language_model.set_output_embeddings(new_embeddings)
+
+    def set_decoder(self, decoder):
+        self.model.language_model.set_decoder(decoder)
+
+    def get_decoder(self):
+        return self.model.language_model.get_decoder()
+
+    def tie_weights(self):
+        return self.model.language_model.tie_weights()
+
+    def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None) -> nn.Embedding:
+        model_embeds = self.model.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
+        # update vocab size
+        self.config.text_config.vocab_size = model_embeds.num_embeddings
+        self.config.vocab_size = model_embeds.num_embeddings
+        self.model.vocab_size = model_embeds.num_embeddings
+        return model_embeds
+
+    def forward(self, *args, **kwargs):
+        # delete output_hidden_states from kwargs
+        kwargs.pop("output_hidden_states", None)
+        if "pixel_values" in kwargs:
+            kwargs["pixel_values"] = kwargs["pixel_values"].to(dtype=self.dtype)
+
+        outputs = self.model(*args, output_hidden_states=True, **kwargs)
+        last_hidden_states = outputs.hidden_states[-1]  # (batch_size, sequence_length, hidden_size)
+
+        # Mean pooling over tokens where attention_mask == 1
+        proj = torch.sum(last_hidden_states * kwargs["attention_mask"].unsqueeze(-1), dim=1) / torch.sum(
+            kwargs["attention_mask"], dim=1, keepdim=True
+        )
+        proj = self.custom_text_proj(proj)
+        proj = proj / proj.norm(dim=-1, keepdim=True)
+        return proj

+ 26 - 0
deconstruct_SQI/colpali/colpali_engine/models/paligemma/bipali/processing_bipali.py

@@ -0,0 +1,26 @@
+from typing import List, Optional, Union
+
+import torch
+
+from colpali_engine.models.paligemma.colpali.processing_colpali import ColPaliProcessor
+
+
+class BiPaliProcessor(ColPaliProcessor):
+    """
+    Processor for BiPali. Mirrors the `ColPaliProcessor` class.
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    def score(
+        self,
+        qs: List[torch.Tensor],
+        ps: List[torch.Tensor],
+        device: Optional[Union[str, torch.device]] = None,
+        **kwargs,
+    ) -> torch.Tensor:
+        """
+        Compute the dot product score for the given single-vector query and passage embeddings.
+        """
+        return self.score_single_vector(qs, ps, device=device)
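Because BiPali embeddings are already L2-normalized by the model, the single-vector score used here reduces to a cosine similarity. A minimal sketch of what a dot-product scorer computes (the real `score_single_vector` helper lives in `BaseVisualRetrieverProcessor` and may batch and move tensors differently; this is only illustrative):

import torch

# Toy embeddings: 2 queries and 3 passages, each a single L2-normalized vector of size 4.
qs = torch.nn.functional.normalize(torch.randn(2, 4), dim=-1)
ps = torch.nn.functional.normalize(torch.randn(3, 4), dim=-1)

scores = qs @ ps.T            # (num_queries, num_passages) cosine similarities
best_passage = scores.argmax(dim=1)
print(scores.shape, best_passage)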

+ 2 - 0
deconstruct_SQI/colpali/colpali_engine/models/paligemma/colpali/__init__.py

@@ -0,0 +1,2 @@
+from .modeling_colpali import ColPali
+from .processing_colpali import ColPaliProcessor

+ 114 - 0
deconstruct_SQI/colpali/colpali_engine/models/paligemma/colpali/modeling_colpali.py

@@ -0,0 +1,114 @@
+from typing import ClassVar, Optional
+
+import torch
+from torch import nn
+from transformers.models.paligemma.modeling_paligemma import (
+    PaliGemmaConfig,
+    PaliGemmaForConditionalGeneration,
+    PaliGemmaPreTrainedModel,
+)
+
+
+class ColPali(PaliGemmaPreTrainedModel):
+    """
+    ColPali model implementation from the "ColPali: Efficient Document Retrieval with Vision Language Models" paper.
+
+    Args:
+        config (PaliGemmaConfig): The model configuration.
+        mask_non_image_embeddings (Optional[bool]): Whether to ignore all tokens embeddings
+            except those of the image at inference.
+            Defaults to False --> Do not mask any embeddings during forward pass.
+    """
+
+    main_input_name: ClassVar[str] = "doc_input_ids"  # transformers-related
+    _checkpoint_conversion_mapping = {
+        "^model.language_model.model": "model.model.language_model",
+        "^model.vision_tower": "model.model.vision_tower",
+        "^model.multi_modal_projector": "model.model.multi_modal_projector",
+        "^model.language_model.lm_head": "model.lm_head",
+    }
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        key_mapping = kwargs.pop("key_mapping", None)
+        if key_mapping is None:
+            key_mapping = cls._checkpoint_conversion_mapping
+        return super().from_pretrained(*args, **kwargs, key_mapping=key_mapping)
+
+    def __init__(self, config: PaliGemmaConfig, mask_non_image_embeddings: bool = False):
+        super().__init__(config=config)
+
+        model = PaliGemmaForConditionalGeneration(config=config)
+        if model.language_model._tied_weights_keys is not None:
+            self._tied_weights_keys = [f"model.language_model.{k}" for k in model.language_model._tied_weights_keys]
+        self.model = model
+        self.model.lm_head = torch.nn.Identity()
+
+        # TODO: Wait for ColPali2 to create a ColPaliConfig to allow specifying the embedding dimension.
+        # We could do it now but it would break all the models trying to load the model from the checkpoint.
+        self.dim = 128
+        self.custom_text_proj = nn.Linear(self.model.config.text_config.hidden_size, self.dim)
+
+        self.mask_non_image_embeddings = mask_non_image_embeddings
+
+        self.post_init()
+
+    def forward(self, *args, **kwargs) -> torch.Tensor:
+        # Delete output_hidden_states from kwargs
+        kwargs.pop("output_hidden_states", None)
+        if "pixel_values" in kwargs:
+            kwargs["pixel_values"] = kwargs["pixel_values"].to(dtype=self.dtype)
+
+        outputs = self.model(*args, output_hidden_states=True, **kwargs)  # (batch_size, sequence_length, hidden_size)
+        last_hidden_states = outputs.hidden_states[-1]  # (batch_size, sequence_length, hidden_size)
+        proj = self.custom_text_proj(last_hidden_states)  # (batch_size, sequence_length, dim)
+
+        # L2 normalization
+        proj = proj / proj.norm(dim=-1, keepdim=True)  # (batch_size, sequence_length, dim)
+
+        proj = proj * kwargs["attention_mask"].unsqueeze(-1)  # (batch_size, sequence_length, dim)
+
+        if "pixel_values" in kwargs and self.mask_non_image_embeddings:
+            # Pools only the image embeddings
+            image_mask = (kwargs["input_ids"] == self.config.image_token_index).unsqueeze(-1)
+            proj = proj * image_mask
+        return proj
+
+    def get_input_embeddings(self):
+        return self.model.language_model.get_input_embeddings()
+
+    def set_input_embeddings(self, value):
+        self.model.language_model.set_input_embeddings(value)
+
+    def get_output_embeddings(self):
+        return self.model.language_model.get_output_embeddings()
+
+    def set_output_embeddings(self, new_embeddings):
+        self.model.language_model.set_output_embeddings(new_embeddings)
+
+    def set_decoder(self, decoder):
+        self.model.language_model.set_decoder(decoder)
+
+    def get_decoder(self):
+        return self.model.language_model.get_decoder()
+
+    def tie_weights(self):
+        return self.model.language_model.tie_weights()
+
+    def resize_token_embeddings(
+        self,
+        new_num_tokens: Optional[int] = None,
+        pad_to_multiple_of=None,
+    ) -> nn.Embedding:
+        model_embeds = self.model.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
+
+        # Update vocab size
+        self.config.text_config.vocab_size = model_embeds.num_embeddings
+        self.config.vocab_size = model_embeds.num_embeddings
+        self.model.vocab_size = model_embeds.num_embeddings
+
+        return model_embeds
+
+    @property
+    def patch_size(self) -> int:
+        return self.model.vision_tower.config.patch_size

+ 89 - 0
deconstruct_SQI/colpali/colpali_engine/models/paligemma/colpali/processing_colpali.py

@@ -0,0 +1,89 @@
+from typing import ClassVar, List, Optional, Tuple, Union
+
+import torch
+from PIL import Image
+from transformers import BatchEncoding, BatchFeature, PaliGemmaProcessor
+
+from colpali_engine.utils.processing_utils import BaseVisualRetrieverProcessor
+
+
+class ColPaliProcessor(BaseVisualRetrieverProcessor, PaliGemmaProcessor):
+    """
+    Processor for ColPali.
+    """
+
+    visual_prompt_prefix: ClassVar[str] = "<image><bos>Describe the image."
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    @property
+    def query_augmentation_token(self) -> str:
+        """
+        Return the query augmentation token.
+        Query augmentation buffers are used as reasoning buffers during inference.
+        """
+        return self.tokenizer.pad_token
+
+    def process_images(
+        self,
+        images: List[Image.Image],
+    ) -> Union[BatchFeature, BatchEncoding]:
+        """
+        Process images for ColPali.
+
+        Args:
+            images: List of PIL images.
+        """
+        images = [image.convert("RGB") for image in images]
+
+        batch_doc = self(
+            text=[self.visual_prompt_prefix] * len(images),
+            images=images,
+            return_tensors="pt",
+            padding="longest",
+        )
+        return batch_doc
+
+    def process_texts(self, texts: List[str]) -> Union[BatchFeature, BatchEncoding]:
+        """
+        Process texts for ColPali.
+
+        Args:
+            texts: List of input texts.
+
+        Returns:
+            Union[BatchFeature, BatchEncoding]: Processed texts.
+        """
+        return self.tokenizer(
+            [self.tokenizer.bos_token + text for text in texts],
+            text_pair=None,
+            return_token_type_ids=False,
+            return_tensors="pt",
+            padding="longest",
+        )
+
+    def score(
+        self,
+        qs: List[torch.Tensor],
+        ps: List[torch.Tensor],
+        device: Optional[Union[str, torch.device]] = None,
+        **kwargs,
+    ) -> torch.Tensor:
+        """
+        Compute the MaxSim score (ColBERT-like) for the given multi-vector query and passage embeddings.
+        """
+        return self.score_multi_vector(qs, ps, device=device, **kwargs)
+
+    def get_n_patches(
+        self,
+        image_size: Tuple[int, int],
+        patch_size: int,
+    ) -> Tuple[int, int]:
+        n_patches_x = self.image_processor.size["width"] // patch_size
+        n_patches_y = self.image_processor.size["height"] // patch_size
+
+        return n_patches_x, n_patches_y
+
+    def get_image_mask(self, batch_images: BatchFeature) -> torch.Tensor:
+        return batch_images.input_ids == self.image_token_id
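For reference, the MaxSim (late-interaction) score returned by `score_multi_vector` follows the ColBERT recipe: for every query token, take its maximum similarity over all passage tokens, then sum over query tokens. A rough sketch with toy tensors (the real helper also handles batching and padding):

import torch

# Toy multi-vector embeddings: 2 queries x 5 tokens, 3 passages x 7 tokens, dim 4.
qs = torch.nn.functional.normalize(torch.randn(2, 5, 4), dim=-1)
ps = torch.nn.functional.normalize(torch.randn(3, 7, 4), dim=-1)

# Token-level similarities: (num_queries, num_passages, query_tokens, passage_tokens)
sim = torch.einsum("qnd,pmd->qpnm", qs, ps)
scores = sim.max(dim=-1).values.sum(dim=-1)  # max over passage tokens, sum over query tokens
print(scores.shape)  # torch.Size([2, 3])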

+ 2 - 0
deconstruct_SQI/colpali/colpali_engine/models/qwen2/__init__.py

@@ -0,0 +1,2 @@
+from .biqwen2 import BiQwen2, BiQwen2Processor
+from .colqwen2 import ColQwen2, ColQwen2Processor

+ 2 - 0
deconstruct_SQI/colpali/colpali_engine/models/qwen2/biqwen2/__init__.py

@@ -0,0 +1,2 @@
+from .modeling_biqwen2 import BiQwen2
+from .processing_biqwen2 import BiQwen2Processor

+ 76 - 0
deconstruct_SQI/colpali/colpali_engine/models/qwen2/biqwen2/modeling_biqwen2.py

@@ -0,0 +1,76 @@
+from typing import ClassVar, Literal
+
+import torch
+from transformers.models.qwen2_vl import Qwen2VLConfig, Qwen2VLModel
+
+
+class BiQwen2(Qwen2VLModel):
+    """
+    BiQwen2 is an implementation from the "ColPali: Efficient Document Retrieval with Vision Language Models" paper.
+    Representations are pooled to obtain a single vector representation. Based on the Qwen2-VL backbone.
+    """
+
+    main_input_name: ClassVar[str] = "doc_input_ids"  # transformers-related
+
+    def __init__(self, config: Qwen2VLConfig):
+        super().__init__(config=config)
+        self.padding_side = "left"
+        self.post_init()
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        key_mapping = kwargs.pop("key_mapping", None)
+        if key_mapping is None:
+            key_mapping = super()._checkpoint_conversion_mapping
+        return super().from_pretrained(*args, **kwargs, key_mapping=key_mapping)
+
+    def forward(
+        self,
+        pooling_strategy: Literal["cls", "last", "mean"] = "last",
+        *args,
+        **kwargs,
+    ) -> torch.Tensor:
+        """
+        Forward pass for the BiQwen2 model.
+
+        Args:
+            pooling_strategy: The strategy to use for pooling the hidden states.
+            *args: Variable length argument list.
+            **kwargs: Additional keyword arguments.
+
+        Returns:
+            torch.Tensor: Dense embeddings (batch_size, hidden_size).
+        """
+        # Handle the custom "pixel_values" input obtained with `ColQwen2Processor` through unpadding
+        if "pixel_values" in kwargs:
+            offsets = kwargs["image_grid_thw"][:, 1] * kwargs["image_grid_thw"][:, 2]  # (batch_size,)
+            kwargs["pixel_values"] = torch.cat(
+                [pixel_sequence[:offset] for pixel_sequence, offset in zip(kwargs["pixel_values"], offsets)],
+                dim=0,
+            )
+        kwargs.pop("return_dict", True)
+        kwargs.pop("output_hidden_states", None)
+        kwargs.pop("use_cache", None)
+        last_hidden_states = (
+            super()
+            .forward(*args, **kwargs, use_cache=False, output_hidden_states=True, return_dict=True)
+            .last_hidden_state
+        )  # (batch_size, sequence_length, hidden_size)
+
+        # Get CLS token embedding, last token, or mean pool over sequence
+        if pooling_strategy == "cls":
+            # Use CLS token (first token) embedding
+            pooled_output = last_hidden_states[:, 0]  # (batch_size, hidden_size)
+        elif pooling_strategy == "last":
+            # use last token since we are left padding
+            pooled_output = last_hidden_states[:, -1]  # (batch_size, hidden_size)
+        elif pooling_strategy == "mean":
+            # Mean pooling over sequence length
+            mask = kwargs["attention_mask"].unsqueeze(-1)  # (batch_size, sequence_length, 1)
+            pooled_output = (last_hidden_states * mask).sum(dim=1) / mask.sum(dim=1)  # (batch_size, hidden_size)
+        else:
+            raise ValueError(f"Invalid pooling strategy: {pooling_strategy}")
+
+        # L2 normalization
+        pooled_output = pooled_output / pooled_output.norm(dim=-1, keepdim=True)
+        return pooled_output
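The three pooling strategies above differ only in how a (batch, seq, hidden) tensor is collapsed to (batch, hidden); with left padding, the last position is always a real token, which is why `last` is the default. A toy comparison with assumed shapes:

import torch

hidden = torch.randn(2, 4, 8)                         # (batch, seq, hidden)
mask = torch.tensor([[0, 1, 1, 1], [1, 1, 1, 1.0]])   # left-padded attention mask

cls_pooled = hidden[:, 0]                             # first position (may be padding with left padding)
last_pooled = hidden[:, -1]                           # last position, always a real token here
mean_pooled = (hidden * mask.unsqueeze(-1)).sum(1) / mask.sum(1, keepdim=True)

print(cls_pooled.shape, last_pooled.shape, mean_pooled.shape)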

+ 43 - 0
deconstruct_SQI/colpali/colpali_engine/models/qwen2/biqwen2/processing_biqwen2.py

@@ -0,0 +1,43 @@
+from typing import List, Optional, Union
+
+import torch
+from transformers import BatchEncoding, BatchFeature
+
+from colpali_engine.models.qwen2.colqwen2 import ColQwen2Processor
+
+
+class BiQwen2Processor(ColQwen2Processor):
+    """
+    Processor for BiQwen2.
+    """
+
+    def process_texts(
+        self,
+        texts: List[str],
+    ) -> Union[BatchFeature, BatchEncoding]:
+        """
+        Process texts for BiQwen2.
+
+        Args:
+            texts: List of input texts.
+
+        Returns:
+            Union[BatchFeature, BatchEncoding]: Processed texts.
+        """
+        return self(
+            text=texts,
+            return_tensors="pt",
+            padding="longest",
+        )
+
+    def score(
+        self,
+        qs: List[torch.Tensor],
+        ps: List[torch.Tensor],
+        device: Optional[Union[str, torch.device]] = None,
+        **kwargs,
+    ) -> torch.Tensor:
+        """
+        Compute the dot product score for the given single-vector query and passage embeddings.
+        """
+        return self.score_single_vector(qs, ps, device=device)

+ 2 - 0
deconstruct_SQI/colpali/colpali_engine/models/qwen2/colqwen2/__init__.py

@@ -0,0 +1,2 @@
+from .modeling_colqwen2 import ColQwen2
+from .processing_colqwen2 import ColQwen2Processor

+ 71 - 0
deconstruct_SQI/colpali/colpali_engine/models/qwen2/colqwen2/modeling_colqwen2.py

@@ -0,0 +1,71 @@
+from typing import ClassVar
+
+import torch
+from torch import nn
+from transformers.models.qwen2_vl import Qwen2VLConfig, Qwen2VLModel
+
+
+class ColQwen2(Qwen2VLModel):
+    """
+    ColQwen2 model implementation from the "ColPali: Efficient Document Retrieval with Vision Language Models" paper.
+
+    Args:
+        config (Qwen2VLConfig): The model configuration.
+        mask_non_image_embeddings (Optional[bool]): Whether to ignore all tokens embeddings
+            except those of the image at inference.
+            Defaults to False --> Do not mask any embeddings during forward pass.
+    """
+
+    main_input_name: ClassVar[str] = "doc_input_ids"  # transformers-related
+
+    def __init__(self, config: Qwen2VLConfig, mask_non_image_embeddings: bool = False):
+        super().__init__(config=config)
+        self.dim = 128
+        self.custom_text_proj = nn.Linear(self.config.hidden_size, self.dim)
+        self.padding_side = "left"
+        self.mask_non_image_embeddings = mask_non_image_embeddings
+        self.post_init()
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        key_mapping = kwargs.pop("key_mapping", None)
+        if key_mapping is None:
+            key_mapping = super()._checkpoint_conversion_mapping
+        return super().from_pretrained(*args, **kwargs, key_mapping=key_mapping)
+
+    def forward(self, *args, **kwargs) -> torch.Tensor:
+        # Handle the custom "pixel_values" input obtained with `ColQwen2Processor` through unpadding
+        if "pixel_values" in kwargs:
+            offsets = kwargs["image_grid_thw"][:, 1] * kwargs["image_grid_thw"][:, 2]  # (batch_size,)
+            kwargs["pixel_values"] = torch.cat(
+                [pixel_sequence[:offset] for pixel_sequence, offset in zip(kwargs["pixel_values"], offsets)],
+                dim=0,
+            )
+        kwargs.pop("return_dict", True)
+        kwargs.pop("output_hidden_states", None)
+        kwargs.pop("use_cache", None)
+        hidden_states = (
+            super()
+            .forward(*args, **kwargs, use_cache=False, output_hidden_states=True, return_dict=True)
+            .last_hidden_state
+        )  # (batch_size, sequence_length, hidden_size)
+
+        proj = self.custom_text_proj(hidden_states)  # (batch_size, sequence_length, dim)
+
+        # L2 normalization
+        proj = proj / proj.norm(dim=-1, keepdim=True)  # (batch_size, sequence_length, dim)
+        proj = proj * kwargs["attention_mask"].unsqueeze(-1)  # (batch_size, sequence_length, dim)
+
+        if "pixel_values" in kwargs and self.mask_non_image_embeddings:
+            # Pools only the image embeddings
+            image_mask = (kwargs["input_ids"] == self.config.image_token_id).unsqueeze(-1)
+            proj = proj * image_mask
+        return proj
+
+    @property
+    def patch_size(self) -> int:
+        return self.visual.config.patch_size
+
+    @property
+    def spatial_merge_size(self) -> int:
+        return self.visual.config.spatial_merge_size

+ 149 - 0
deconstruct_SQI/colpali/colpali_engine/models/qwen2/colqwen2/processing_colqwen2.py

@@ -0,0 +1,149 @@
+from typing import ClassVar, List, Optional, Tuple, Union
+
+import torch
+from PIL import Image
+from transformers import BatchEncoding, BatchFeature
+from transformers.models.qwen2_vl import Qwen2VLProcessor
+from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize
+
+from colpali_engine.utils.processing_utils import BaseVisualRetrieverProcessor
+
+
+class ColQwen2Processor(BaseVisualRetrieverProcessor, Qwen2VLProcessor):
+    """
+    Processor for ColQwen2.
+
+    Args:
+        *args: Variable length argument list to be passed to the parent `Qwen2VLProcessor` class.
+        max_num_visual_tokens: The maximum number of visual tokens that can be processed by the model.
+        **kwargs: Arbitrary keyword arguments to be passed to the parent `Qwen2VLProcessor` class.
+    """
+
+    visual_prompt_prefix: ClassVar[str] = (
+        "<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>Describe the image.<|im_end|><|endoftext|>"
+    )
+    query_augmentation_token: ClassVar[str] = "<|endoftext|>"
+    image_token: ClassVar[str] = "<|image_pad|>"
+
+    def __init__(
+        self,
+        *args,
+        **kwargs,
+    ):
+        super().__init__(*args, **kwargs)
+        self.tokenizer.padding_side = "left"
+
+    @classmethod
+    def from_pretrained(
+        cls,
+        *args,
+        device_map: Optional[str] = None,
+        **kwargs,
+    ):
+        instance = super().from_pretrained(
+            *args,
+            device_map=device_map,
+            **kwargs,
+        )
+
+        if "max_num_visual_tokens" in kwargs:
+            instance.image_processor.max_pixels = kwargs["max_num_visual_tokens"] * 28 * 28
+            instance.image_processor.size["longest_edge"] = instance.image_processor.max_pixels
+
+        return instance
+
+    def process_images(
+        self,
+        images: List[Image.Image],
+    ) -> Union[BatchFeature, BatchEncoding]:
+        """
+        Process images for ColQwen2.
+
+        Args:
+            images: List of PIL images.
+        """
+
+        images = [image.convert("RGB") for image in images]
+
+        batch_doc = self(
+            text=[self.visual_prompt_prefix] * len(images),
+            images=images,
+            padding="longest",
+            return_tensors="pt",
+        )
+
+        # NOTE: The following adjustment ensures correct behavior with DDP on multiple GPUs.
+        offsets = batch_doc["image_grid_thw"][:, 1] * batch_doc["image_grid_thw"][:, 2]  # (batch_size,)
+
+        # Split the pixel_values tensor into a list of tensors, one per image
+        pixel_values = list(
+            torch.split(batch_doc["pixel_values"], offsets.tolist())
+        )  # [(num_patches_image_0, pixel_values), ..., (num_patches_image_n, pixel_values)]
+
+        # Pad the list of pixel_value tensors to the same length along the sequence dimension
+        batch_doc["pixel_values"] = torch.nn.utils.rnn.pad_sequence(
+            pixel_values, batch_first=True
+        )  # (batch_size, max_num_patches, pixel_values)
+
+        return batch_doc
+
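The padding step above exists so that every sample in a batch carries a `pixel_values` tensor of the same length (which DDP expects); the model's `forward` then strips that padding again using `image_grid_thw`. A toy round trip under assumed patch counts and feature sizes:

import torch

# Assume two images producing 6 and 10 visual patches of dimension 8 (made-up numbers).
patches = [torch.randn(6, 8), torch.randn(10, 8)]
grid_thw = torch.tensor([[1, 2, 3], [1, 2, 5]])  # (t, h, w) with h * w matching each patch count

# Processor side: pad to a common length -> (batch_size, max_num_patches, dim)
padded = torch.nn.utils.rnn.pad_sequence(patches, batch_first=True)

# Model side: recover the original flat layout by dropping the padding rows
offsets = grid_thw[:, 1] * grid_thw[:, 2]
unpadded = torch.cat([seq[:off] for seq, off in zip(padded, offsets)], dim=0)
print(padded.shape, unpadded.shape)  # torch.Size([2, 10, 8]) torch.Size([16, 8])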
+    def process_texts(self, texts: List[str]) -> Union[BatchFeature, BatchEncoding]:
+        """
+        Process texts for ColQwen2.
+
+        Args:
+            texts: List of input texts.
+
+        Returns:
+            Union[BatchFeature, BatchEncoding]: Processed texts.
+        """
+        return self(
+            text=texts,
+            return_tensors="pt",
+            padding="longest",
+        )
+
+    def score(
+        self,
+        qs: List[torch.Tensor],
+        ps: List[torch.Tensor],
+        device: Optional[Union[str, torch.device]] = None,
+        **kwargs,
+    ) -> torch.Tensor:
+        """
+        Compute the MaxSim score (ColBERT-like) for the given multi-vector query and passage embeddings.
+        """
+        return self.score_multi_vector(qs, ps, device=device, **kwargs)
+
+    def get_n_patches(
+        self,
+        image_size: Tuple[int, int],
+        spatial_merge_size: int,
+    ) -> Tuple[int, int]:
+        """
+        Get the number of patches (n_patches_x, n_patches_y) that will be used to process an image of
+        size (height, width) with the given patch size.
+
+        The `spatial_merge_size` is the number of patches that will be merged spatially. It is stored as a
+        `Qwen2VLForConditionalGeneration` attribute under `model.spatial_merge_size`.
+        """
+        patch_size = self.image_processor.patch_size
+
+        height_new, width_new = smart_resize(
+            width=image_size[0],
+            height=image_size[1],
+            factor=patch_size * self.image_processor.merge_size,
+            min_pixels=self.image_processor.size["shortest_edge"],
+            max_pixels=self.image_processor.size["longest_edge"],
+        )
+
+        n_patches_x = width_new // patch_size // spatial_merge_size
+        n_patches_y = height_new // patch_size // spatial_merge_size
+
+        return n_patches_x, n_patches_y
+
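A hedged usage sketch of `get_n_patches`: the returned grid is what downstream similarity-map code typically needs to reshape per-token scores back into image space. The checkpoint id and the `spatial_merge_size` value of 2 below are assumptions for illustration; in practice the merge size should come from the loaded model.

# Illustrative only: checkpoint id and spatial_merge_size are assumed values.
from colpali_engine.models.qwen2 import ColQwen2Processor

processor = ColQwen2Processor.from_pretrained("vidore/colqwen2-v1.0")
n_x, n_y = processor.get_n_patches(image_size=(512, 512), spatial_merge_size=2)
print(n_x, n_y)  # number of (merged) patches along each axis after smart_resize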
+    def get_image_mask(self, batch_images: BatchFeature) -> torch.Tensor:
+        """
+        Get a tensor mask that identifies the image tokens in the batch.
+        """
+        return batch_images.input_ids == self.image_token_id

+ 2 - 0
deconstruct_SQI/colpali/colpali_engine/models/qwen2_5/__init__.py

@@ -0,0 +1,2 @@
+from .biqwen2_5 import BiQwen2_5, BiQwen2_5_Processor
+from .colqwen2_5 import ColQwen2_5, ColQwen2_5_Processor

+ 2 - 0
deconstruct_SQI/colpali/colpali_engine/models/qwen2_5/biqwen2_5/__init__.py

@@ -0,0 +1,2 @@
+from .modeling_biqwen2_5 import BiQwen2_5
+from .processing_biqwen2_5 import BiQwen2_5_Processor

+ 86 - 0
deconstruct_SQI/colpali/colpali_engine/models/qwen2_5/biqwen2_5/modeling_biqwen2_5.py

@@ -0,0 +1,86 @@
+from typing import ClassVar, Literal
+
+import torch
+from transformers.models.qwen2_5_vl import Qwen2_5_VLConfig, Qwen2_5_VLModel
+
+
+class BiQwen2_5(Qwen2_5_VLModel):  # noqa: N801
+    """
+    BiQwen2.5 is an implementation from the "ColPali: Efficient Document Retrieval with Vision Language Models" paper.
+    Representations are pooled to obtain a single vector representation. Based on the Qwen2.5-VL backbone.
+    """
+
+    main_input_name: ClassVar[str] = "doc_input_ids"  # transformers-related
+
+    def __init__(self, config: Qwen2_5_VLConfig):
+        super().__init__(config=config)
+        # self.dim = 128
+        # self.custom_text_proj = nn.Linear(self.model.config.hidden_size, self.dim)
+        self.padding_side = "left"
+        self.post_init()
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        key_mapping = kwargs.pop("key_mapping", None)
+        if key_mapping is None:
+            key_mapping = super()._checkpoint_conversion_mapping
+        return super().from_pretrained(*args, **kwargs, key_mapping=key_mapping)
+
+    def forward(
+        self,
+        pooling_strategy: Literal["cls", "last", "mean"] = "last",
+        *args,
+        **kwargs,
+    ) -> torch.Tensor:
+        """
+        Forward pass for BiQwen2.5 model.
+
+        Args:
+            pooling_strategy: The strategy to use for pooling the hidden states.
+            *args: Variable length argument list.
+            **kwargs: Additional keyword arguments.
+
+        Returns:
+            torch.Tensor: Dense embeddings (batch_size, hidden_size).
+        """
+        # Handle the custom "pixel_values" input obtained with `ColQwen2Processor` through unpadding
+        if "pixel_values" in kwargs:
+            offsets = kwargs["image_grid_thw"][:, 1] * kwargs["image_grid_thw"][:, 2]  # (batch_size,)
+            kwargs["pixel_values"] = torch.cat(
+                [pixel_sequence[:offset] for pixel_sequence, offset in zip(kwargs["pixel_values"], offsets)],
+                dim=0,
+            )
+        kwargs.pop("return_dict", True)
+        kwargs.pop("output_hidden_states", None)
+        kwargs.pop("use_cache", None)
+        last_hidden_states = (
+            super()
+            .forward(*args, **kwargs, use_cache=False, output_hidden_states=True, return_dict=True)
+            .last_hidden_state
+        )  # (batch_size, sequence_length, hidden_size)
+
+        # Get CLS token embedding, last token, or mean pool over sequence
+        if pooling_strategy == "cls":
+            # Use CLS token (first token) embedding
+            pooled_output = last_hidden_states[:, 0]  # (batch_size, hidden_size)
+        elif pooling_strategy == "last":
+            # use last token since we are left padding
+            pooled_output = last_hidden_states[:, -1]  # (batch_size, hidden_size)
+        elif pooling_strategy == "mean":
+            # Mean pooling over sequence length
+            mask = kwargs["attention_mask"].unsqueeze(-1)  # (batch_size, sequence_length, 1)
+            pooled_output = (last_hidden_states * mask).sum(dim=1) / mask.sum(dim=1)  # (batch_size, hidden_size)
+        else:
+            raise ValueError(f"Invalid pooling strategy: {pooling_strategy}")
+
+        # L2 normalization
+        pooled_output = pooled_output / pooled_output.norm(dim=-1, keepdim=True)
+        return pooled_output
+
+    @property
+    def patch_size(self) -> int:
+        return self.visual.config.patch_size
+
+    @property
+    def spatial_merge_size(self) -> int:
+        return self.visual.config.spatial_merge_size

+ 40 - 0
deconstruct_SQI/colpali/colpali_engine/models/qwen2_5/biqwen2_5/processing_biqwen2_5.py

@@ -0,0 +1,40 @@
+from typing import List, Optional, Union
+
+import torch
+from transformers import BatchEncoding, BatchFeature
+
+from colpali_engine.models.qwen2_5.colqwen2_5 import ColQwen2_5_Processor
+
+
+class BiQwen2_5_Processor(ColQwen2_5_Processor):  # noqa: N801
+    """
+    Processor for BiQwen2.5.
+    """
+
+    def process_texts(self, texts: List[str]) -> Union[BatchFeature, BatchEncoding]:
+        """
+        Process texts for BiQwen2.5.
+
+        Args:
+            texts: List of input texts.
+
+        Returns:
+            Union[BatchFeature, BatchEncoding]: Processed texts.
+        """
+        return self(
+            text=texts,
+            return_tensors="pt",
+            padding="longest",
+        )
+
+    def score(
+        self,
+        qs: List[torch.Tensor],
+        ps: List[torch.Tensor],
+        device: Optional[Union[str, torch.device]] = None,
+        **kwargs,
+    ) -> torch.Tensor:
+        """
+        Compute the cosine similarity for the given query and passage embeddings.
+        """
+        return self.score_single_vector(qs, ps, device=device)

+ 2 - 0
deconstruct_SQI/colpali/colpali_engine/models/qwen2_5/colqwen2_5/__init__.py

@@ -0,0 +1,2 @@
+from .modeling_colqwen2_5 import ColQwen2_5
+from .processing_colqwen2_5 import ColQwen2_5_Processor

+ 73 - 0
deconstruct_SQI/colpali/colpali_engine/models/qwen2_5/colqwen2_5/modeling_colqwen2_5.py

@@ -0,0 +1,73 @@
+from typing import ClassVar
+
+import torch
+from torch import nn
+from transformers.models.qwen2_5_vl import Qwen2_5_VLConfig, Qwen2_5_VLModel
+
+
+class ColQwen2_5(Qwen2_5_VLModel):  # noqa: N801
+    """
+    ColQwen2.5 model implementation, following the architecture from the "ColPali: Efficient Document Retrieval
+    with Vision Language Models" paper. Based on the Qwen2.5-VL backbone.
+
+    Args:
+        config (Qwen2_5_VLConfig): The model configuration.
+        mask_non_image_embeddings (Optional[bool]): Whether to ignore all tokens embeddings
+            except those of the image at inference.
+            Defaults to False --> Do not mask any embeddings during forward pass.
+    """
+
+    main_input_name: ClassVar[str] = "doc_input_ids"  # transformers-related
+
+    def __init__(self, config: Qwen2_5_VLConfig, mask_non_image_embeddings: bool = False):
+        super().__init__(config=config)
+        self.dim = 128
+        self.custom_text_proj = nn.Linear(self.config.hidden_size, self.dim)
+        self.padding_side = "left"
+        self.mask_non_image_embeddings = mask_non_image_embeddings
+        self.post_init()
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        key_mapping = kwargs.pop("key_mapping", None)
+        if key_mapping is None:
+            key_mapping = super()._checkpoint_conversion_mapping
+        return super().from_pretrained(*args, **kwargs, key_mapping=key_mapping)
+
+    def forward(self, *args, **kwargs) -> torch.Tensor:
+        # Handle the custom "pixel_values" input obtained with `ColQwen2Processor` through unpadding
+        if "pixel_values" in kwargs:
+            offsets = kwargs["image_grid_thw"][:, 1] * kwargs["image_grid_thw"][:, 2]  # (batch_size,)
+            kwargs["pixel_values"] = torch.cat(
+                [pixel_sequence[:offset] for pixel_sequence, offset in zip(kwargs["pixel_values"], offsets)],
+                dim=0,
+            )
+
+        kwargs.pop("return_dict", True)
+        kwargs.pop("output_hidden_states", None)
+        kwargs.pop("use_cache", None)
+        last_hidden_states = (
+            super()
+            .forward(*args, **kwargs, use_cache=False, output_hidden_states=True, return_dict=True)
+            .last_hidden_state
+        )  # (batch_size, sequence_length, hidden_size)
+
+        proj = self.custom_text_proj(last_hidden_states)  # (batch_size, sequence_length, dim)
+
+        # L2 normalization
+        proj = proj / proj.norm(dim=-1, keepdim=True)  # (batch_size, sequence_length, dim)
+        proj = proj * kwargs["attention_mask"].unsqueeze(-1)  # (batch_size, sequence_length, dim)
+
+        if "pixel_values" in kwargs and self.mask_non_image_embeddings:
+            # Pools only the image embeddings
+            image_mask = (kwargs["input_ids"] == self.config.image_token_id).unsqueeze(-1)
+            proj = proj * image_mask
+        return proj
+
+    @property
+    def patch_size(self) -> int:
+        return self.visual.config.patch_size
+
+    @property
+    def spatial_merge_size(self) -> int:
+        return self.visual.config.spatial_merge_size

+ 146 - 0
deconstruct_SQI/colpali/colpali_engine/models/qwen2_5/colqwen2_5/processing_colqwen2_5.py

@@ -0,0 +1,146 @@
+from typing import ClassVar, List, Optional, Tuple, Union
+
+import torch
+from PIL import Image
+from transformers import BatchEncoding, BatchFeature
+from transformers.models.qwen2_vl import Qwen2VLProcessor
+from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize
+
+from colpali_engine.utils.processing_utils import BaseVisualRetrieverProcessor
+
+
+class ColQwen2_5_Processor(BaseVisualRetrieverProcessor, Qwen2VLProcessor):  # noqa: N801
+    """
+    Processor for ColQwen2.5.
+
+    Args:
+        *args: Variable length argument list to be passed to the parent `Qwen2VLProcessor` class.
+        max_num_visual_tokens: The maximum number of visual tokens that can be processed by the model.
+        **kwargs: Arbitrary keyword arguments to be passed to the parent `Qwen2VLProcessor` class.
+    """
+
+    visual_prompt_prefix: ClassVar[str] = (
+        "<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>Describe the image.<|im_end|><|endoftext|>"
+    )
+    query_augmentation_token: ClassVar[str] = "<|endoftext|>"
+    image_token: ClassVar[str] = "<|image_pad|>"
+
+    def __init__(
+        self,
+        *args,
+        **kwargs,
+    ):
+        super().__init__(*args, **kwargs)
+        self.tokenizer.padding_side = "left"
+
+    @classmethod
+    def from_pretrained(
+        cls,
+        *args,
+        device_map: Optional[str] = None,
+        **kwargs,
+    ):
+        instance = super().from_pretrained(
+            *args,
+            device_map=device_map,
+            **kwargs,
+        )
+
+        if "max_num_visual_tokens" in kwargs:
+            instance.image_processor.max_pixels = kwargs["max_num_visual_tokens"] * 28 * 28
+            instance.image_processor.size["longest_edge"] = instance.image_processor.max_pixels
+
+        return instance
+
+    def process_images(
+        self,
+        images: List[Image.Image],
+    ) -> Union[BatchFeature, BatchEncoding]:
+        """
+        Process images for ColQwen2.5.
+
+        Args:
+            images: List of PIL images.
+        """
+
+        images = [image.convert("RGB") for image in images]
+
+        batch_doc = self(
+            text=[self.visual_prompt_prefix] * len(images),
+            images=images,
+            padding="longest",
+            return_tensors="pt",
+        )
+
+        # NOTE: The following adjustment ensures correct behavior with DDP on multiple GPUs.
+        offsets = batch_doc["image_grid_thw"][:, 1] * batch_doc["image_grid_thw"][:, 2]  # (batch_size,)
+
+        # Split the pixel_values tensor into a list of tensors, one per image
+        pixel_values = list(
+            torch.split(batch_doc["pixel_values"], offsets.tolist())
+        )  # [(num_patches_image_0, pixel_values), ..., (num_patches_image_n, pixel_values)]
+
+        # Pad the list of pixel_value tensors to the same length along the sequence dimension
+        batch_doc["pixel_values"] = torch.nn.utils.rnn.pad_sequence(
+            pixel_values, batch_first=True
+        )  # (batch_size, max_num_patches, pixel_values)
+
+        return batch_doc
+
+    def process_texts(self, texts: List[str]) -> Union[BatchFeature, BatchEncoding]:
+        """
+        Process texts for ColQwen2.5.
+
+        Args:
+            texts: List of input texts.
+
+        Returns:
+            Union[BatchFeature, BatchEncoding]: Processed texts.
+        """
+        return self(
+            text=texts,
+            return_tensors="pt",
+            padding="longest",
+        )
+
+    def score(
+        self,
+        qs: List[torch.Tensor],
+        ps: List[torch.Tensor],
+        device: Optional[Union[str, torch.device]] = None,
+        **kwargs,
+    ) -> torch.Tensor:
+        """
+        Compute the MaxSim score (ColBERT-like) for the given multi-vector query and passage embeddings.
+        """
+        return self.score_multi_vector(qs, ps, device=device, **kwargs)
+
+    def get_n_patches(
+        self,
+        image_size: Tuple[int, int],
+        spatial_merge_size: int,
+    ) -> Tuple[int, int]:
+        """
+        Get the number of patches (n_patches_x, n_patches_y) that will be used to process an image of
+        size (height, width) with the given patch size.
+
+        The `spatial_merge_size` is the number of patches that will be merged spatially. It is stored as a
+        `Qwen2VLForConditionalGeneration` attribute under `model.spatial_merge_size`.
+        """
+        patch_size = self.image_processor.patch_size
+
+        height_new, width_new = smart_resize(
+            width=image_size[0],
+            height=image_size[1],
+            factor=patch_size * self.image_processor.merge_size,
+            min_pixels=self.image_processor.size["shortest_edge"],
+            max_pixels=self.image_processor.size["longest_edge"],
+        )
+
+        n_patches_x = width_new // patch_size // spatial_merge_size
+        n_patches_y = height_new // patch_size // spatial_merge_size
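+        # e.g. with patch_size=14 and spatial_merge_size=2, each returned patch cell covers a 28x28 pixel block of the resized image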
+
+        return n_patches_x, n_patches_y
+
+    def get_image_mask(self, batch_images: BatchFeature) -> torch.Tensor:
+        return batch_images.input_ids == self.image_token_id
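
Taken together with the ColQwen2.5 model above, this processor supports a simple retrieval round trip. A minimal sketch, assuming the package exports `ColQwen2_5` / `ColQwen2_5_Processor`; the checkpoint name and image path are placeholders, not part of this commit:

import torch
from PIL import Image
from colpali_engine.models import ColQwen2_5, ColQwen2_5_Processor

model = ColQwen2_5.from_pretrained("org/colqwen2.5-checkpoint", torch_dtype=torch.bfloat16).eval()  # placeholder name
processor = ColQwen2_5_Processor.from_pretrained("org/colqwen2.5-checkpoint")

batch_images = processor.process_images([Image.open("page.png")]).to(model.device)
batch_queries = processor.process_queries(["What was the revenue in 2023?"]).to(model.device)

with torch.no_grad():
    image_embeddings = model(**batch_images)   # (1, n_image_tokens, 128)
    query_embeddings = model(**batch_queries)  # (1, n_query_tokens, 128)

scores = processor.score(list(torch.unbind(query_embeddings)), list(torch.unbind(image_embeddings)))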

+ 65 - 0
deconstruct_SQI/colpali/colpali_engine/models/qwen_omni/colqwen_omni/modeling_colqwen_omni.py

@@ -0,0 +1,65 @@
+from typing import ClassVar
+
+import torch
+from torch import nn
+from transformers.models.qwen2_5_omni import Qwen2_5OmniThinkerConfig, Qwen2_5OmniThinkerForConditionalGeneration
+
+
+class ColQwen2_5Omni(Qwen2_5OmniThinkerForConditionalGeneration):  # noqa: N801
+    """
+    ColQwen2.5 Omni model with a custom text projection layer.
+    This model is a modified version of the Qwen2.5 Omni Thinker model that projects the last hidden states
+    into low-dimensional multi-vector embeddings for late-interaction retrieval over multimodal inputs.
+    """
+
+    main_input_name: ClassVar[str] = "doc_input_ids"  # transformers-related
+
+    def __init__(self, config: Qwen2_5OmniThinkerConfig, mask_non_image_embeddings: bool = False):
+        super().__init__(config=config)
+        self.dim = 128
+        self.custom_text_proj = nn.Linear(self.model.config.hidden_size, self.dim)
+        self.lm_head = nn.Identity()  # Disable the original lm_head
+        self.padding_side = "left"
+        self.mask_non_image_embeddings = mask_non_image_embeddings
+        self.post_init()
+
+    def forward(self, *args, **kwargs) -> torch.Tensor:
+        # # Handle the custom "pixel_values" input obtained with `ColQwen2Processor` through unpadding
+        # if "pixel_values" in kwargs:
+        #     offsets = kwargs["image_grid_thw"][:, 1] * kwargs["image_grid_thw"][:, 2]  # (batch_size,)
+        #     kwargs["pixel_values"] = torch.cat(
+        #         [pixel_sequence[:offset] for pixel_sequence, offset in zip(kwargs["pixel_values"], offsets)],
+        #         dim=0,
+        #     )
+        # Drop any caller-supplied values for these flags; they are set explicitly below
+        kwargs.pop("return_dict", True)
+        kwargs.pop("output_hidden_states", None)
+        kwargs.pop("use_cache", None)
+        last_hidden_states = (
+            super().forward(*args, **kwargs, use_cache=False, output_hidden_states=True, return_dict=True).logits
+        )  # (batch_size, sequence_length, hidden_size)
+        proj = self.custom_text_proj(last_hidden_states)  # (batch_size, sequence_length, dim)
+
+        # L2 normalization
+        proj = proj / proj.norm(dim=-1, keepdim=True)  # (batch_size, sequence_length, dim)
+        proj = proj * kwargs["attention_mask"].unsqueeze(-1)  # (batch_size, sequence_length, dim)
+
+        if "pixel_values" in kwargs and self.mask_non_image_embeddings:
+            # Pools only the image embeddings
+            image_mask = (kwargs["input_ids"] == self.config.image_token_id).unsqueeze(-1)
+            proj = proj * image_mask
+        return proj
+
+    @property
+    def patch_size(self) -> int:
+        return self.visual.config.patch_size
+
+    @property
+    def spatial_merge_size(self) -> int:
+        return self.visual.config.spatial_merge_size
+
+    @spatial_merge_size.setter
+    def spatial_merge_size(self, value):
+        # allow assignment
+        self.visual.config.spatial_merge_size = value
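
The forward pass above reduces the Thinker's last hidden states to 128-dimensional multi-vectors. A standalone sketch of the same normalize-and-mask step, with illustrative shapes only:

import torch

hidden_size, dim = 3584, 128                             # illustrative sizes
hidden_states = torch.randn(2, 7, hidden_size)           # (batch, seq_len, hidden_size)
custom_text_proj = torch.nn.Linear(hidden_size, dim)

proj = custom_text_proj(hidden_states)                   # (batch, seq_len, dim)
proj = proj / proj.norm(dim=-1, keepdim=True)            # L2-normalize every token embedding
attention_mask = torch.tensor([[1, 1, 1, 1, 1, 0, 0],
                               [1, 1, 1, 1, 1, 1, 1]])
proj = proj * attention_mask.unsqueeze(-1)               # zero out padded positions
print(proj.shape)                                        # torch.Size([2, 7, 128])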

+ 229 - 0
deconstruct_SQI/colpali/colpali_engine/models/qwen_omni/colqwen_omni/processing_colqwen_omni.py

@@ -0,0 +1,229 @@
+from typing import ClassVar, List, Optional, Tuple, Union
+
+import torch
+from PIL import Image
+from transformers import BatchFeature
+from transformers.models.qwen2_5_omni import Qwen2_5OmniProcessor
+
+from colpali_engine.utils.processing_utils import BaseVisualRetrieverProcessor
+
+
+class ColQwen2_5OmniProcessor(BaseVisualRetrieverProcessor, Qwen2_5OmniProcessor):  # noqa: N801
+    """
+    Processor for ColQwen2.5 Omni.
+
+    Args:
+        *args: Variable length argument list to be passed to the parent `Qwen2_5OmniProcessor` class.
+        max_num_visual_tokens: The maximum number of visual tokens that can be processed by the model.
+        **kwargs: Arbitrary keyword arguments to be passed to the parent `Qwen2_5OmniProcessor` class.
+    """
+
+    query_prefix: ClassVar[str] = "Query: "
+    query_augmentation_token: ClassVar[str] = "<|endoftext|>"
+
+    def __init__(
+        self,
+        *args,
+        **kwargs,
+    ):
+        super().__init__(*args, **kwargs)
+        self.tokenizer.padding_side = "left"
+        self.chat_template = self.tokenizer.chat_template
+
+    @classmethod
+    def from_pretrained(
+        cls,
+        *args,
+        device_map: Optional[str] = None,
+        **kwargs,
+    ):
+        instance = super().from_pretrained(
+            *args,
+            device_map=device_map,
+            **kwargs,
+        )
+
+        # if "max_num_visual_tokens" in kwargs:
+        #     instance.image_processor.max_pixels = kwargs["max_num_visual_tokens"] * 28 * 28
+        #     instance.image_processor.size["longest_edge"] = instance.image_processor.max_pixels
+
+        return instance
+
+    def process_conversations(self, conversations: List[dict]) -> BatchFeature:
+        batch_doc = super().apply_chat_template(
+            conversations,
+            # NOTE: transformers currently does not support standalone audio inputs when this flag is True
+            load_audio_from_video=False,
+            add_generation_prompt=True,
+            tokenize=True,
+            return_dict=True,
+            return_tensors="pt",
+            # video_fps=1,
+            padding=True,
+            use_audio_in_video=False,
+        )
+
+        # if "pixel_values" in batch_doc:
+        #     # # NOTE: The following adjustment ensures correct behavior with DDP on multiple GPUs.
+        #     offsets = batch_doc["image_grid_thw"][:, 1] * batch_doc["image_grid_thw"][:, 2]  # (batch_size,)
+
+        #     # Split the pixel_values tensor into a list of tensors, one per image
+        #     pixel_values = list(
+        #         torch.split(batch_doc["pixel_values"], offsets.tolist())
+        #     )  # [(num_patches_image_0, pixel_values), ..., (num_patches_image_n, pixel_values)]
+
+        #     # Pad the list of pixel_value tensors to the same length along the sequence dimension
+        #     batch_doc["pixel_values"] = torch.nn.utils.rnn.pad_sequence(
+        #         pixel_values, batch_first=True
+        #     )  # (batch_size, max_num_patches, pixel_values)
+        return batch_doc
+
+    def process_images(self, images: List[Image.Image]) -> BatchFeature:
+        """
+        Process images for ColQwen2.5 Omni.
+
+        Args:
+            images: List of PIL images.
+        """
+
+        conversations = [
+            [
+                {
+                    "role": "system",
+                    "content": [
+                        {
+                            "type": "text",
+                            "text": "You are Qwen, a virtual human developed by the Qwen Team, Alibaba Group, capable "
+                            "of perceiving auditory and visual inputs, as well as generating text and speech.",
+                        }
+                    ],
+                },
+                {
+                    "role": "user",
+                    "content": [
+                        {"type": "image", "image": image.convert("RGB")},
+                        {"type": "text", "text": "Describe the content."},
+                    ],
+                },
+            ]
+            for image in images
+        ]
+        batch_doc = self.process_conversations(conversations)
+        return batch_doc
+
+    def process_audios(self, audios) -> BatchFeature:
+        """
+        Process audios for ColQwen2.5 Omni.
+
+        Args:
+            audios: List of audio inputs: NumPy waveform arrays, or paths/URLs to WAV files.
+        """
+
+        conversations = [
+            [
+                {
+                    "role": "system",
+                    "content": [
+                        {
+                            "type": "text",
+                            "text": "You are Qwen, a virtual human developed by the Qwen Team, Alibaba Group, capable "
+                            "of perceiving auditory and visual inputs, as well as generating text and speech.",
+                        }
+                    ],
+                },
+                {
+                    "role": "user",
+                    "content": [{"type": "audio", "path": audio}, {"type": "text", "text": "Describe the content."}],
+                },
+            ]
+            for audio in audios
+        ]
+        batch_doc = self.process_conversations(conversations)
+        return batch_doc
+
+    def process_videos(self, videos) -> BatchFeature:
+        """
+        Process videos for ColQwen2.5 Omni.
+
+        Args:
+            videos: List of videos or paths/urls to videos. Each video can be a 4D NumPy array or PyTorch
+            tensor, or a nested list of 3D frames. Both channels-first and channels-last formats are supported.
+        """
+
+        conversations = [
+            [
+                {
+                    "role": "system",
+                    "content": [
+                        {
+                            "type": "text",
+                            "text": "You are Qwen, a virtual human developed by the Qwen Team, Alibaba Group, capable "
+                            "of perceiving auditory and visual inputs, as well as generating text and speech.",
+                        }
+                    ],
+                },
+                {
+                    "role": "user",
+                    "content": [{"type": "video", "path": video}, {"type": "text", "text": "Describe the content."}],
+                },
+            ]
+            for video in videos
+        ]
+        batch_doc = self.process_conversations(conversations)
+        return batch_doc
+
+    def process_texts(
+        self,
+        texts: List[str],
+    ) -> BatchFeature:
+        """
+        Process texts for ColQwen2.5 Omni.
+        """
+
+        conversations = [
+            [
+                {
+                    "role": "system",
+                    "content": [
+                        {
+                            "type": "text",
+                            "text": "You are Qwen, a virtual human developed by the Qwen Team, Alibaba Group, capable "
+                            "of perceiving auditory and visual inputs, as well as generating text and speech.",
+                        }
+                    ],
+                },
+                {
+                    "role": "user",
+                    "content": [{"type": "text", "text": text}],
+                },
+            ]
+            for text in texts
+        ]
+        batch_query = self.process_conversations(conversations)
+        return batch_query
+
+    def score(
+        self,
+        qs: List[torch.Tensor],
+        ps: List[torch.Tensor],
+        device: Optional[Union[str, torch.device]] = None,
+        **kwargs,
+    ) -> torch.Tensor:
+        """
+        Compute the MaxSim score (ColBERT-like) for the given multi-vector query and passage embeddings.
+        """
+        return self.score_multi_vector(qs, ps, device=device, **kwargs)
+
+    def get_n_patches(
+        self,
+        image_size: Tuple[int, int],
+        spatial_merge_size: int,
+    ) -> Tuple[int, int]:
+        """
+        Get the number of patches (n_patches_x, n_patches_y) for an image of size (height, width).
+
+        Not supported for ColQwen2.5 Omni.
+        """
+        raise NotImplementedError("ColQwen2.5 Omni does not support the `get_n_patches` method.")
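
An end-to-end sketch with the Omni model and processor, assuming both classes are exported from `colpali_engine.models`; the checkpoint name and audio path are placeholders:

import torch
from colpali_engine.models import ColQwen2_5Omni, ColQwen2_5OmniProcessor

model = ColQwen2_5Omni.from_pretrained("org/colqwen-omni-checkpoint", torch_dtype=torch.bfloat16).eval()  # placeholder
processor = ColQwen2_5OmniProcessor.from_pretrained("org/colqwen-omni-checkpoint")

batch_audio = processor.process_audios(["meeting_recording.wav"]).to(model.device)
batch_queries = processor.process_texts(["Who presented the quarterly results?"]).to(model.device)

with torch.no_grad():
    audio_embeddings = model(**batch_audio)
    query_embeddings = model(**batch_queries)

scores = processor.score(list(torch.unbind(query_embeddings)), list(torch.unbind(audio_embeddings)))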

+ 3 - 0
deconstruct_SQI/colpali/colpali_engine/trainer/__init__.py

@@ -0,0 +1,3 @@
+from .colmodel_torch_training import ColModelTorchTraining
+from .colmodel_training import ColModelTraining, ColModelTrainingConfig
+from .contrastive_trainer import ContrastiveTrainer

+ 261 - 0
deconstruct_SQI/colpali/colpali_engine/trainer/colmodel_torch_training.py

@@ -0,0 +1,261 @@
+import os
+
+import torch
+import torch.distributed as dist
+from torch.distributed._functional_collectives import all_gather_tensor_autograd
+from torch.nn.parallel import DistributedDataParallel
+from torch.utils.data import DataLoader, DistributedSampler
+from tqdm.auto import tqdm
+
+from colpali_engine.collators import VisualRetrieverCollator
+from colpali_engine.trainer.colmodel_training import ColModelTrainingConfig
+from colpali_engine.utils.gpu_stats import print_gpu_utilization
+
+
+class ColModelTorchTraining:
+    """
+    Class that contains the training and evaluation logic for a ColVision model.
+    """
+
+    def __init__(self, config: ColModelTrainingConfig) -> None:
+        self.config = config
+        self.model = self.config.model
+        self.current_git_hash = os.popen("git rev-parse HEAD").read().strip()
+        self.train_dataset = self.config.train_dataset
+        self.eval_dataset = self.config.eval_dataset
+        self.collator = VisualRetrieverCollator(
+            processor=self.config.processor,
+            max_length=self.config.max_length,
+        )
+
+        # Initialize distributed if needed
+        if dist.is_available() and not dist.is_initialized():
+            dist.init_process_group(backend="nccl", init_method="env://")
+            print("Distributed process group initialized.")
+        self.local_rank = int(os.environ.get("LOCAL_RANK", 0))
+        self.world_size = dist.get_world_size() if dist.is_initialized() else 1
+        print(f"Local rank: {self.local_rank}, World size: {self.world_size}")
+
+        device = torch.device(f"cuda:{self.local_rank}" if torch.cuda.is_available() else "cpu")
+        torch.cuda.set_device(device)
+        self.model.to(device)
+
+        # Gradient checkpointing if supported
+        if getattr(self.config.tr_args, "gradient_checkpointing", False):
+            # huggingface models expose this
+            try:
+                self.model.gradient_checkpointing_enable(
+                    gradient_checkpointing_kwargs=self.config.tr_args.gradient_checkpointing_kwargs
+                )
+                if self._is_rank0():
+                    print("Gradient checkpointing enabled.")
+            except Exception as e:
+                if self._is_rank0():
+                    print("Warning: gradient_checkpointing_enable() not supported by model.")
+                    print(e)
+
+        self.model = DistributedDataParallel(self.model, device_ids=[self.local_rank], output_device=self.local_rank)
+
+        self.model = torch.compile(
+            self.model,
+            backend="inductor",
+            dynamic=True,  # sequence lengths vary across batches, so avoid recompiling for every new shape
+        )
+
+    def _is_rank0(self) -> bool:
+        return not dist.is_initialized() or dist.get_rank() == 0
+
+    def train(self) -> None:
+        # Mixed precision setup
+        use_amp = getattr(self.config, "use_amp", False)
+        scaler = torch.amp.GradScaler("cuda") if use_amp else None
+        max_grad_norm = getattr(self.config.tr_args, "max_grad_norm", None)
+        print(f"Using AMP: {use_amp}, Max grad norm: {max_grad_norm}")
+
+        sampler = DistributedSampler(self.train_dataset) if dist.is_initialized() else None
+        train_loader = DataLoader(
+            self.train_dataset,
+            batch_size=self.config.tr_args.per_device_train_batch_size,
+            sampler=sampler,
+            collate_fn=self.collator,
+            num_workers=self.config.tr_args.dataloader_num_workers,
+            prefetch_factor=2,
+            pin_memory=True,
+            drop_last=True,
+        )
+
+        # Evaluation loader
+        eval_loader = None
+        if self.eval_dataset is not None:
+            eval_loader = DataLoader(
+                self.eval_dataset,
+                batch_size=self.config.tr_args.per_device_eval_batch_size,
+                collate_fn=self.collator,
+            )
+        elif self._is_rank0():
+            print("No eval dataset provided. Skipping evaluation.")
+
+        optimizer = torch.optim.AdamW(
+            self.model.parameters(),
+            lr=self.config.tr_args.learning_rate,
+            weight_decay=self.config.tr_args.weight_decay,
+        )
+        num_training_steps = self.config.tr_args.num_train_epochs * len(train_loader)
+        warmup_steps = self.config.tr_args.warmup_steps
+
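+        # Linear warmup to the base learning rate, then linear decay down to 10% of it (floored at 0.1x)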
+        def lr_lambda(current_step):
+            if current_step < warmup_steps:
+                return float(current_step) / float(max(1, warmup_steps))
+            progress = float(current_step - warmup_steps) / float(max(1, num_training_steps - warmup_steps))
+            return max(0.1, 1.0 - (1.0 - 0.1) * progress)
+
+        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
+
+        loss_fn = self.config.loss_func
+
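+        # All-gather that keeps the autograd graph, so gradients flow back to each rank's local shard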
+        def gather_with_grad(x: torch.Tensor) -> torch.Tensor:
+            return all_gather_tensor_autograd(x, gather_dim=0, group=dist.group.WORLD)
+
+        # Training loop
+        # only rank0 should display
+        if self._is_rank0():
+            pbar = tqdm(total=num_training_steps, desc="Training", leave=True)
+        else:
+            pbar = None
+
+        for epoch in range(self.config.tr_args.num_train_epochs):
+            if sampler:
+                sampler.set_epoch(epoch)
+            self.model.train()
+
+            for step, batch in enumerate(train_loader):
+                # Move batch to device
+                batch = {k: v.to(self.model.device, non_blocking=True) for k, v in batch.items()}
+
+                # Forward with optional AMP
+                with torch.amp.autocast("cuda", enabled=use_amp):
+                    q_embed = self.model(
+                        input_ids=batch["query_input_ids"], attention_mask=batch["query_attention_mask"]
+                    )
+                    d_embed = self.model(**{k[4:]: v for k, v in batch.items() if k.startswith("doc_")})
+                    neg_embed = None
+                    if "neg_doc_input_ids" in batch:
+                        neg_embed = self.model(**{k[8:]: v for k, v in batch.items() if k.startswith("neg_doc")})
+
+                    def pad_to_max_len_right(x: torch.Tensor) -> torch.Tensor:
+                        """
+                        Right-pad x along dim=1 so that all ranks share the same length.
+
+                        Args:
+                            x: Tensor of shape [B, L, D] (or [B, L] if 2D)
+                        Returns:
+                            Padded tensor of shape [B, max_L, D], with zeros on the right.
+                        """
+                        # 1) local length
+                        local_len = x.size(1)
+                        # 2) get global max length
+                        len_tensor = torch.tensor(local_len, device=x.device)
+                        dist.all_reduce(len_tensor, op=dist.ReduceOp.MAX)
+                        max_len = len_tensor.item()
+
+                        # 3) if shorter, pad on the right of dim=1
+                        if local_len < max_len:
+                            pad_amount = max_len - local_len
+                            # torch.nn.functional.pad takes (D_left, D_right, L_left, L_right)
+                            x = torch.nn.functional.pad(x, (0, 0, 0, pad_amount), value=0.0)
+                        return x
+
+                    # Usage before gathering:
+                    d_embed = pad_to_max_len_right(d_embed)
+                    if neg_embed is not None:
+                        neg_embed = pad_to_max_len_right(neg_embed)
+
+                    # Now safe to all_gather:
+                    d_global = gather_with_grad(d_embed)
+                    n_global = gather_with_grad(neg_embed) if neg_embed is not None else None
+
+                    # loss = loss_fn(q_global, d_global) if n_global is None else loss_fn(q_global, d_global, n_global)
+                    loss = (
+                        loss_fn(q_embed, d_global, offset=(dist.get_rank() * batch["query_input_ids"].shape[0]))
+                        if n_global is None
+                        else loss_fn(
+                            q_embed, d_global, n_global, offset=(dist.get_rank() * batch["query_input_ids"].shape[0])
+                        )
+                    )
+
+                # Backward
+                if use_amp:
+                    scaler.scale(loss).backward()
+                    if max_grad_norm:
+                        scaler.unscale_(optimizer)
+                        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_grad_norm)
+                    scaler.step(optimizer)
+                    scaler.update()
+                else:
+                    loss.backward()
+                    if max_grad_norm:
+                        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_grad_norm)
+                    optimizer.step()
+
+                scheduler.step()
+                optimizer.zero_grad()
+
+                if pbar is not None:
+                    # Advance the global progress bar and surface the current epoch/step
+                    pbar.set_postfix(epoch=epoch + 1, step=step + 1, refresh=False)
+                    pbar.update(1)
+
+                if self._is_rank0() and step % 10 == 0:
+                    print(f"Step {step}/{len(train_loader)}")
+                    print(f"Query embedding shape: {q_embed.shape}")
+                    print(f"Document embedding shape: {d_embed.shape}")
+                    if neg_embed is not None:
+                        print(f"Negative document embedding shape: {neg_embed.shape}")
+
+                    # print(f"Gathered query embedding shape: {q_global.shape}")
+                    print(f"Gathered document embedding shape: {d_global.shape}")
+                    if neg_embed is not None:
+                        print(f"Gathered negative document embedding shape: {n_global.shape}")
+
+                    print(f"Batch size: {batch['query_input_ids'].shape[0]}")
+
+                    print_gpu_utilization()
+                    print(f"Local loss: {loss.item()}")
+                    print(f"Learning rate: {optimizer.param_groups[0]['lr']}")
+                    print(f"Epoch: {epoch + 1}/{self.config.tr_args.num_train_epochs}")
+                    print(f"World size: {dist.get_world_size()}")
+                    # with torch.no_grad():
+                    #     avg_loss = loss.detach()
+                    #     dist.all_reduce(avg_loss, op=dist.ReduceOp.SUM)
+                    #     avg_loss /= dist.get_world_size()
+                    #     print(f"Local loss: {avg_loss.item()}")
+
+            # Optional evaluation (the evaluation loop itself is not implemented yet; see `eval` below)
+            if eval_loader and self._is_rank0():
+                self.eval()
+
+        self.model = self.model.module if hasattr(self.model, "module") else self.model
+        # Final actions
+        if self._is_rank0():
+            pbar.close()
+            print("Training complete. Saving model.")
+            self.save()
+            print("Model saved.")
+
+        if dist.is_initialized():
+            dist.destroy_process_group()
+
+    def eval(self) -> None:
+        raise NotImplementedError("Evaluation is not implemented yet.")
+
+    def save(self):
+        """
+        Save the model with its training config, as well as the tokenizer and processor if provided.
+        """
+        self.model.save_pretrained(self.config.output_dir)
+        self.config.processor.save_pretrained(self.config.output_dir)
+
+        # Save git hash of the commit at beginning of training
+        with open(f"{self.config.output_dir}/git_hash.txt", "w") as f:
+            f.write(self.current_git_hash)
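
The `offset` passed to the loss above tells it where this rank's in-batch positives sit inside the globally gathered document embeddings. A toy illustration of that indexing with single-vector embeddings (not the actual loss implementation used by the trainer):

import torch
import torch.nn.functional as F

local_bs, world_size, rank = 4, 2, 1
q_local = torch.randn(local_bs, 16)                    # this rank's query embeddings
d_global = torch.randn(local_bs * world_size, 16)      # documents gathered from all ranks

scores = q_local @ d_global.T                          # (local_bs, local_bs * world_size)
offset = rank * local_bs
labels = offset + torch.arange(local_bs)               # each query's positive is its own rank's document
loss = F.cross_entropy(scores, labels)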

+ 118 - 0
deconstruct_SQI/colpali/colpali_engine/trainer/colmodel_training.py

@@ -0,0 +1,118 @@
+import os
+from dataclasses import dataclass
+from typing import Callable, Dict, List, Optional, Union
+
+from peft import LoraConfig, PeftModel, get_peft_model
+from transformers import (
+    PreTrainedModel,
+    TrainingArguments,
+)
+
+from colpali_engine.collators import VisualRetrieverCollator
+from colpali_engine.data.dataset import ColPaliEngineDataset
+from colpali_engine.loss.late_interaction_losses import (
+    ColbertLoss,
+)
+from colpali_engine.trainer.contrastive_trainer import ContrastiveTrainer
+from colpali_engine.utils.gpu_stats import print_gpu_utilization, print_summary
+from colpali_engine.utils.processing_utils import BaseVisualRetrieverProcessor
+
+
+@dataclass
+class ColModelTrainingConfig:
+    """
+    Config class used for training a ColVision model.
+    """
+
+    model: Union[PreTrainedModel, PeftModel]
+    processor: BaseVisualRetrieverProcessor
+    train_dataset: Union[ColPaliEngineDataset, List[ColPaliEngineDataset]]
+    eval_dataset: Optional[Union[ColPaliEngineDataset, Dict[str, ColPaliEngineDataset]]] = None
+    tr_args: Optional[TrainingArguments] = None
+    output_dir: Optional[str] = None
+    max_length: int = 256
+    run_eval: bool = True
+    run_train: bool = True
+    peft_config: Optional[LoraConfig] = None
+    loss_func: Optional[Callable] = ColbertLoss()
+    pretrained_peft_model_name_or_path: Optional[str] = None
+    """
+    Config class used for training a ColVision model.
+    """
+
+    def __post_init__(self):
+        """
+        Initialize the model and tokenizer if not provided
+        """
+        if self.output_dir is None:
+            sanitized_name = str(self.model.name_or_path).replace("/", "_")
+            self.output_dir = f"./models/{sanitized_name}"
+
+        if self.tr_args is None:
+            print("No training arguments provided. Using default.")
+            self.tr_args = TrainingArguments(output_dir=self.output_dir)
+        elif self.tr_args.output_dir is None or self.tr_args.output_dir == "trainer_output":
+            self.tr_args.output_dir = self.output_dir
+
+        if isinstance(self.tr_args.learning_rate, str):
+            print("Casting learning rate to float")
+            self.tr_args.learning_rate = float(self.tr_args.learning_rate)
+
+        self.tr_args.remove_unused_columns = False
+
+        if self.pretrained_peft_model_name_or_path is not None:
+            print("Loading pretrained PEFT model")
+            self.model.load_adapter(self.pretrained_peft_model_name_or_path, is_trainable=True)
+
+        if self.peft_config is not None:
+            print("Configurating PEFT model")
+            if self.pretrained_peft_model_name_or_path is None:
+                self.model = get_peft_model(self.model, self.peft_config)
+                self.model.print_trainable_parameters()
+            else:
+                print(f"Adapter already loaded from {self.pretrained_peft_model_name_or_path}. Not overwriting.")
+
+        print_gpu_utilization()
+
+
+class ColModelTraining:
+    """
+    Class that contains the training and evaluation logic for a ColVision model.
+    """
+
+    def __init__(self, config: ColModelTrainingConfig) -> None:
+        self.config = config
+        self.model = self.config.model
+        self.current_git_hash = os.popen("git rev-parse HEAD").read().strip()
+        self.train_dataset = self.config.train_dataset
+        self.eval_dataset = self.config.eval_dataset
+        self.collator = VisualRetrieverCollator(
+            processor=self.config.processor,
+            max_length=self.config.max_length,
+        )
+
+    def train(self) -> None:
+        trainer = ContrastiveTrainer(
+            model=self.model,
+            train_dataset=self.train_dataset,
+            eval_dataset=self.eval_dataset,
+            args=self.config.tr_args,
+            data_collator=self.collator,
+            loss_func=self.config.loss_func,
+            is_vision_model=self.config.processor is not None,
+        )
+
+        trainer.args.remove_unused_columns = False
+
+        result = trainer.train(resume_from_checkpoint=self.config.tr_args.resume_from_checkpoint)
+        print_summary(result)
+
+    def eval(self) -> None:
+        raise NotImplementedError("Evaluation is not implemented yet.")
+
+    def save(self):
+        """
+        Save the model with its training config, as well as the tokenizer and processor if provided.
+        """
+        self.model.save_pretrained(self.config.output_dir)
+        self.config.processor.save_pretrained(self.config.output_dir)
+
+        # Save git hash of the commit at beginning of training
+        with open(f"{self.config.output_dir}/git_hash.txt", "w") as f:
+            f.write(self.current_git_hash)
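
A minimal sketch of wiring the config and trainer above; the model and processor are assumed to be loaded beforehand (e.g. with the model classes added earlier in this commit), and the dataset loader comes from `colpali_engine.utils.dataset_transformation`:

from transformers import TrainingArguments

from colpali_engine.trainer import ColModelTraining, ColModelTrainingConfig
from colpali_engine.utils.dataset_transformation import load_train_set

config = ColModelTrainingConfig(
    model=model,            # a ColVision model loaded beforehand (assumption)
    processor=processor,    # its matching processor (assumption)
    train_dataset=load_train_set(),
    tr_args=TrainingArguments(output_dir="./models/my_run", per_device_train_batch_size=4, num_train_epochs=1),
)

training_app = ColModelTraining(config)
training_app.train()
training_app.save()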

+ 225 - 0
deconstruct_SQI/colpali/colpali_engine/trainer/contrastive_trainer.py

@@ -0,0 +1,225 @@
+from functools import partial
+from typing import Optional
+
+import datasets
+import torch
+from torch.distributed.nn.functional import all_gather  # PyTorch ≥ 2.1
+from torch.utils.data import ConcatDataset, DataLoader, Dataset
+from transformers import Trainer, is_datasets_available
+from transformers.trainer_utils import seed_worker
+
+from colpali_engine.data.sampler import SingleDatasetBatchSampler
+
+
+def concat_all_gather(t: torch.Tensor) -> torch.Tensor:
+    if torch.distributed.is_available() and torch.distributed.is_initialized():
+        return torch.cat(all_gather(t), dim=0)  # keeps grad graph
+    return t
+
+
+def concat_datasets(datasets: list[Dataset], batch_size: int) -> Dataset:
+    """
+    Concatenates a list of datasets into a single dataset.
+    This is a utility function to handle the case where multiple datasets are provided.
+    """
+    # Round down each dataset so its length is divisible by the global batch size
+    for i in range(len(datasets)):
+        if len(datasets[i]) % batch_size != 0:
+            total_samples = (len(datasets[i]) // batch_size) * batch_size
+            datasets[i] = datasets[i].take(total_samples)
+
+    return ConcatDataset(datasets)
+
+
+class ContrastiveTrainer(Trainer):
+    def __init__(self, loss_func, is_vision_model, compute_symetric_loss=False, *args, **kwargs):
+        if isinstance(kwargs["train_dataset"], list):
+            train_dataset_list = kwargs["train_dataset"]
+            kwargs["train_dataset"] = concat_datasets(train_dataset_list, batch_size=kwargs["args"].train_batch_size)
+        else:
+            train_dataset_list = None
+
+        if isinstance(kwargs["eval_dataset"], list):
+            eval_dataset_list = kwargs["eval_dataset"]
+            kwargs["eval_dataset"] = concat_datasets(eval_dataset_list)
+        else:
+            eval_dataset_list = None
+
+        super().__init__(*args, **kwargs)
+        self.loss_func = loss_func
+        self.is_vision_model = is_vision_model  # Unused argument, will be removed in 0.4.0
+        self.args.remove_unused_columns = False  # Safety, don't remove dataset columns from dataloader
+        # Cache the collator prefixes so `compute_loss` works even when `get_train_dataloader` returns early
+        self.query_prefix = self.data_collator.query_prefix
+        self.pos_prefix = self.data_collator.pos_doc_prefix
+        self.neg_prefix = self.data_collator.neg_doc_prefix
+        self.train_dataset_list = train_dataset_list
+        self.eval_dataset_list = eval_dataset_list
+        self.compute_symetric_loss = compute_symetric_loss
+
+    def get_train_dataloader(self) -> DataLoader:
+        """
+        Returns the training [`~torch.utils.data.DataLoader`].
+
+        Will use no sampler if `train_dataset` does not implement `__len__`, a random sampler (adapted to distributed
+        training if necessary) otherwise.
+
+        Subclass and override this method if you want to inject some custom behavior.
+        """
+        if self.train_dataset is None:
+            raise ValueError("Trainer: training requires a train_dataset.")
+
+        if self.train_dataset_list is None:
+            # If no dataset list, use the default behavior
+            return super().get_train_dataloader()
+
+        dataset = self.train_dataset
+        description = "Training"
+        sampler_fn = self._get_train_sampler
+        is_training = True
+        dataloader_key = None
+
+        data_collator = self.data_collator
+        if is_datasets_available() and isinstance(dataset, datasets.Dataset):
+            dataset = self._remove_unused_columns(dataset, description=description)
+        else:
+            data_collator = self._get_collator_with_removed_columns(self.data_collator, description=description)
+
+        self.query_prefix = data_collator.query_prefix
+        self.pos_prefix = data_collator.pos_doc_prefix
+        self.neg_prefix = data_collator.neg_doc_prefix
+
+        dataloader_params = {
+            ######### don't set batch size, mutually exclusive from batch sampler ######
+            "collate_fn": data_collator,
+            "num_workers": self.args.dataloader_num_workers,
+            "pin_memory": self.args.dataloader_pin_memory,
+            "persistent_workers": self.args.dataloader_persistent_workers,
+        }
+
+        if not isinstance(dataset, torch.utils.data.IterableDataset):
+            if sampler_fn is not None:
+                ###### batch_sampler set instead of sampler in trainer code #######
+                dataloader_params["batch_sampler"] = sampler_fn()
+            dataloader_params["drop_last"] = self.args.dataloader_drop_last
+            dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
+            if is_training:
+                dataloader_params["worker_init_fn"] = partial(
+                    seed_worker, num_workers=self.args.dataloader_num_workers, rank=self.args.process_index
+                )
+
+        dataloader = DataLoader(dataset, **dataloader_params)
+
+        # Accelerator.free_memory() will destroy the references, so
+        # we need to store the non-prepared version for eval dataloaders.
+        if dataloader_key is not None and self.args.dataloader_persistent_workers:
+            if hasattr(self, "_eval_dataloaders"):
+                self._eval_dataloaders[dataloader_key] = dataloader
+            else:
+                self._eval_dataloaders = {dataloader_key: dataloader}
+
+        return self.accelerator.prepare(dataloader)
+
+    def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
+        if self.train_dataset_list is None:
+            return super()._get_train_sampler()
+
+        # Use SingleDatasetBatchSampler to ensure that each dataset in the list is sampled independently
+        # Note: Surely breaks in distributed training
+        # TODO: fix this
+        generator = torch.Generator()
+        generator.manual_seed(self.args.seed)
+        return SingleDatasetBatchSampler(
+            self.train_dataset_list,
+            self.args.train_batch_size,
+            drop_last=self.args.dataloader_drop_last,
+            generator=generator,
+        )
+
+    def _compute_loss_from_outputs(
+        self,
+        query_outputs,
+        pos_target_outputs,
+        neg_target_outputs=None,
+    ):
+        offset = 0
+        batch_size = query_outputs.size(0)
+        if self.accelerator.num_processes > 1 and self.accelerator.sync_gradients:
+            # gather docs across all processes
+            pos_target_outputs = self.accelerator.pad_across_processes(
+                pos_target_outputs, dim=1, pad_index=0, pad_first=True
+            )
+            pos_target_outputs = concat_all_gather(pos_target_outputs)
+            rank = self.accelerator.process_index
+            offset = rank * batch_size
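+            # `offset` marks where this rank's own documents start inside the gathered positive embeddings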
+
+        if neg_target_outputs is not None:
+            loss = self.loss_func(
+                query_embeddings=query_outputs,
+                doc_embeddings=pos_target_outputs,
+                neg_doc_embeddings=neg_target_outputs,
+                offset=offset,
+            )
+        else:
+            loss = self.loss_func(query_embeddings=query_outputs, doc_embeddings=pos_target_outputs, offset=offset)
+
+        return loss
+
+    def _reshape_neg_doc_inputs(self, inputs):
+        """
+        Helper function to reshape negative doc inputs to (batch_size * num_neg_docs, ...)
+        """
+        neg_doc_inputs = {k[len(self.neg_prefix) :]: v for k, v in inputs.items() if k.startswith(self.neg_prefix)}
+
+        for k in neg_doc_inputs:
+            # go from (batch_size, num_neg_docs, ...) to (batch_size * num_neg_docs, ...)
+            neg_doc_inputs[k] = neg_doc_inputs[k].view(-1, *neg_doc_inputs[k].shape[2:])
+
+        return neg_doc_inputs
+
+    def _reshape_neg_doc_outputs(self, neg_doc_outputs, num_neg_docs):
+        """
+        Helper function to reshape negative doc outputs to (batch_size, num_neg_docs, ...)
+        """
+        neg_doc_outputs = neg_doc_outputs.view(-1, num_neg_docs, *neg_doc_outputs.shape[1:])
+
+        return neg_doc_outputs
+
+    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
+        query_inputs = {k[len(self.query_prefix) :]: v for k, v in inputs.items() if k.startswith(self.query_prefix)}
+        query_outputs = model(**query_inputs)
+        # feed only kwargs with 'doc_' prefix
+        doc_inputs = {k[len(self.pos_prefix) :]: v for k, v in inputs.items() if k.startswith(self.pos_prefix)}
+        doc_outputs = model(**doc_inputs)
+        if "neg_doc_input_ids" in inputs:
+            # Negative docs are not gathered across processes, so we can use them without offset
+            num_negs = inputs["neg_doc_input_ids"].size(1)
+            neg_doc_inputs = self._reshape_neg_doc_inputs(inputs)
+            neg_doc_outputs = model(**neg_doc_inputs)
+            neg_doc_outputs = self._reshape_neg_doc_outputs(neg_doc_outputs, num_negs)
+        else:
+            neg_doc_outputs = None
+
+        # query -> doc loss
+        loss = self._compute_loss_from_outputs(query_outputs, doc_outputs, neg_doc_outputs)
+
+        if self.compute_symetric_loss:
+            assert neg_doc_outputs is None, "Symmetric loss is not compatible with negative documents."
+            # doc -> query loss
+            sym_loss = self._compute_loss_from_outputs(doc_outputs, query_outputs)
+            loss = (loss + sym_loss) / 2
+
+        return (loss, (query_outputs, doc_outputs)) if return_outputs else loss
+
+    def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=True):
+        """This function is used to generate predictions and return the loss for the given inputs."""
+        if not prediction_loss_only:
+            raise ValueError("prediction_step is only called with prediction_loss_only=True")
+
+        with torch.no_grad():
+            # feed only kwargs with 'doc_' prefix
+            doc_outputs = model(**{k[4:]: v for k, v in inputs.items() if k.startswith("doc")})
+            query_outputs = model(input_ids=inputs["query_input_ids"], attention_mask=inputs["query_attention_mask"])
+            if "neg_doc_input_ids" in inputs:
+                neg_doc_outputs = model(**{k[8:]: v for k, v in inputs.items() if k.startswith("neg_doc")})
+                loss = self.loss_func(query_outputs, doc_outputs, neg_doc_outputs)
+                return loss, None, None
+
+            loss = self.loss_func(query_outputs, doc_outputs)
+            return loss, None, None
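
`compute_loss` above relies on the collator's key prefixes to route tensors into separate forward passes. A schematic of that convention with dummy tensors (the exact keys come from `VisualRetrieverCollator` and may include extra fields such as pixel values):

import torch

batch = {
    "query_input_ids": torch.randint(0, 100, (2, 8)),
    "query_attention_mask": torch.ones(2, 8, dtype=torch.long),
    "doc_input_ids": torch.randint(0, 100, (2, 32)),
    "doc_attention_mask": torch.ones(2, 32, dtype=torch.long),
    "neg_doc_input_ids": torch.randint(0, 100, (2, 3, 32)),        # (batch, num_negs, seq_len)
    "neg_doc_attention_mask": torch.ones(2, 3, 32, dtype=torch.long),
}

query_inputs = {k[len("query_"):]: v for k, v in batch.items() if k.startswith("query_")}
doc_inputs = {k[len("doc_"):]: v for k, v in batch.items() if k.startswith("doc_")}
neg_inputs = {k[len("neg_doc_"):]: v.reshape(-1, v.shape[-1]) for k, v in batch.items() if k.startswith("neg_doc_")}
print(neg_inputs["input_ids"].shape)                               # torch.Size([6, 32]) -> batch_size * num_negs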

+ 0 - 0
deconstruct_SQI/colpali/colpali_engine/utils/__init__.py


+ 268 - 0
deconstruct_SQI/colpali/colpali_engine/utils/dataset_transformation.py

@@ -0,0 +1,268 @@
+import os
+from typing import List, Tuple, cast
+
+from datasets import Dataset, DatasetDict, concatenate_datasets, load_dataset
+from PIL import Image
+
+from colpali_engine.data.dataset import ColPaliEngineDataset, Corpus
+
+USE_LOCAL_DATASET = os.environ.get("USE_LOCAL_DATASET", "1") == "1"
+
+
+def load_train_set() -> ColPaliEngineDataset:
+    base_path = "./data_dir/" if USE_LOCAL_DATASET else "vidore/"
+    dataset = load_dataset(base_path + "colpali_train_set", split="train")
+
+    train_dataset = ColPaliEngineDataset(dataset, pos_target_column_name="image")
+
+    return train_dataset
+
+
+def load_eval_set(dataset_path) -> Dataset:
+    dataset = load_dataset(dataset_path, split="test")
+
+    return dataset
+
+
+def load_train_set_ir(num_negs=0) -> ColPaliEngineDataset:
+    """Returns the query dataset, then the anchor dataset with the documents, then the dataset type"""
+    base_path = "./data_dir/" if USE_LOCAL_DATASET else "manu/"
+    corpus_data = load_dataset(base_path + "colpali-corpus", split="train")
+    corpus = Corpus(corpus_data=corpus_data, doc_column_name="image")
+
+    dataset = load_dataset(base_path + "colpali-queries", split="train")
+
+    print("Dataset size:", len(dataset))
+    # filter out queries with "gold_in_top_100" == False
+    dataset = dataset.filter(lambda x: x["gold_in_top_100"], num_proc=16)
+    if num_negs > 0:
+        # keep only the top `num_negs` negative passages
+        dataset = dataset.map(lambda x: {"negative_passages": x["negative_passages"][:num_negs]})
+    print("Dataset size after filtering:", len(dataset))
+
+    train_dataset = ColPaliEngineDataset(
+        data=dataset,
+        corpus=corpus,
+        pos_target_column_name="positive_passages",
+        neg_target_column_name="negative_passages" if num_negs else None,
+    )
+
+    return train_dataset
+
+
+def load_train_set_detailed() -> DatasetDict:
+    ds_paths = [
+        "infovqa_train",
+        "docvqa_train",
+        "arxivqa_train",
+        "tatdqa_train",
+        "syntheticDocQA_government_reports_train",
+        "syntheticDocQA_healthcare_industry_train",
+        "syntheticDocQA_artificial_intelligence_train",
+        "syntheticDocQA_energy_train",
+    ]
+    base_path = "./data_dir/" if USE_LOCAL_DATASET else "vidore/"
+    ds_tot = []
+    for path in ds_paths:
+        cpath = base_path + path
+        ds = cast(Dataset, load_dataset(cpath, split="train"))
+        if "arxivqa" in path:
+            # subsample 10k
+            ds = ds.shuffle(42).select(range(10000))
+        ds_tot.append(ds)
+
+    dataset = cast(Dataset, concatenate_datasets(ds_tot))
+    dataset = dataset.shuffle(seed=42)
+    # split into train and test
+    dataset_eval = dataset.select(range(500))
+    dataset = dataset.select(range(500, len(dataset)))
+    ds_dict = DatasetDict({"train": dataset, "test": dataset_eval})
+    return ds_dict
+
+
+def load_train_set_with_tabfquad() -> DatasetDict:
+    ds_paths = [
+        "infovqa_train",
+        "docvqa_train",
+        "arxivqa_train",
+        "tatdqa_train",
+        "tabfquad_train_subsampled",
+        "syntheticDocQA_government_reports_train",
+        "syntheticDocQA_healthcare_industry_train",
+        "syntheticDocQA_artificial_intelligence_train",
+        "syntheticDocQA_energy_train",
+    ]
+    base_path = "./data_dir/" if USE_LOCAL_DATASET else "vidore/"
+    ds_tot = []
+    for path in ds_paths:
+        cpath = base_path + path
+        ds = cast(Dataset, load_dataset(cpath, split="train"))
+        if "arxivqa" in path:
+            # subsample 10k
+            ds = ds.shuffle(42).select(range(10000))
+        ds_tot.append(ds)
+
+    dataset = cast(Dataset, concatenate_datasets(ds_tot))
+    dataset = dataset.shuffle(seed=42)
+    # split into train and test
+    dataset_eval = dataset.select(range(500))
+    dataset = dataset.select(range(500, len(dataset)))
+    ds_dict = DatasetDict({"train": dataset, "test": dataset_eval})
+    return ds_dict
+
+
+def load_docmatix_ir_negs() -> Tuple[DatasetDict, Dataset, str]:
+    """Returns the query dataset, then the anchor dataset with the documents, then the dataset type"""
+    base_path = "./data_dir/" if USE_LOCAL_DATASET else "Tevatron/"
+    dataset = cast(Dataset, load_dataset(base_path + "docmatix-ir", split="train"))
+    # dataset = dataset.select(range(100500))
+
+    dataset_eval = dataset.select(range(500))
+    dataset = dataset.select(range(500, len(dataset)))
+    ds_dict = DatasetDict({"train": dataset, "test": dataset_eval})
+
+    base_path = "./data_dir/" if USE_LOCAL_DATASET else "HuggingFaceM4/"
+    anchor_ds = cast(Dataset, load_dataset(base_path + "Docmatix", "images", split="train"))
+
+    return ds_dict, anchor_ds, "docmatix"
+
+
+def load_wikiss() -> Tuple[DatasetDict, Dataset, str]:
+    """Returns the query dataset, then the anchor dataset with the documents, then the dataset type"""
+    base_path = "./data_dir/" if USE_LOCAL_DATASET else "Tevatron/"
+    dataset = cast(Dataset, load_dataset(base_path + "wiki-ss-nq", data_files="train.jsonl", split="train"))
+    # dataset = dataset.select(range(400500))
+    dataset_eval = dataset.select(range(500))
+    dataset = dataset.select(range(500, len(dataset)))
+    ds_dict = DatasetDict({"train": dataset, "test": dataset_eval})
+
+    base_path = "./data_dir/" if USE_LOCAL_DATASET else "HuggingFaceM4/"
+    anchor_ds = cast(Dataset, load_dataset(base_path + "wiki-ss-corpus", split="train"))
+
+    return ds_dict, anchor_ds, "wikiss"
+
+
+def load_train_set_with_docmatix() -> DatasetDict:
+    ds_paths = [
+        "infovqa_train",
+        "docvqa_train",
+        "arxivqa_train",
+        "tatdqa_train",
+        "tabfquad_train_subsampled",
+        "syntheticDocQA_government_reports_train",
+        "syntheticDocQA_healthcare_industry_train",
+        "syntheticDocQA_artificial_intelligence_train",
+        "syntheticDocQA_energy_train",
+        "Docmatix_filtered_train",
+    ]
+    base_path = "./data_dir/" if USE_LOCAL_DATASET else "vidore/"
+    ds_tot: List[Dataset] = []
+    for path in ds_paths:
+        cpath = base_path + path
+        ds = cast(Dataset, load_dataset(cpath, split="train"))
+        if "arxivqa" in path:
+            # subsample 10k
+            ds = ds.shuffle(42).select(range(10000))
+        ds_tot.append(ds)
+
+    dataset = concatenate_datasets(ds_tot)
+    dataset = dataset.shuffle(seed=42)
+    # split into train and test
+    dataset_eval = dataset.select(range(500))
+    dataset = dataset.select(range(500, len(dataset)))
+    ds_dict = DatasetDict({"train": dataset, "test": dataset_eval})
+    return ds_dict
+
+
+def load_docvqa_dataset() -> DatasetDict:
+    if USE_LOCAL_DATASET:
+        dataset_doc = cast(Dataset, load_dataset("./data_dir/DocVQA", "DocVQA", split="validation"))
+        dataset_doc_eval = cast(Dataset, load_dataset("./data_dir/DocVQA", "DocVQA", split="test"))
+        dataset_info = cast(Dataset, load_dataset("./data_dir/DocVQA", "InfographicVQA", split="validation"))
+        dataset_info_eval = cast(Dataset, load_dataset("./data_dir/DocVQA", "InfographicVQA", split="test"))
+    else:
+        dataset_doc = cast(Dataset, load_dataset("lmms-lab/DocVQA", "DocVQA", split="validation"))
+        dataset_doc_eval = cast(Dataset, load_dataset("lmms-lab/DocVQA", "DocVQA", split="test"))
+        dataset_info = cast(Dataset, load_dataset("lmms-lab/DocVQA", "InfographicVQA", split="validation"))
+        dataset_info_eval = cast(Dataset, load_dataset("lmms-lab/DocVQA", "InfographicVQA", split="test"))
+
+    # concatenate the two datasets
+    dataset = concatenate_datasets([dataset_doc, dataset_info])
+    dataset_eval = concatenate_datasets([dataset_doc_eval, dataset_info_eval])
+    # sample 200 examples from the eval dataset
+    dataset_eval = dataset_eval.shuffle(seed=42).select(range(200))
+
+    # rename question as query
+    dataset = dataset.rename_column("question", "query")
+    dataset_eval = dataset_eval.rename_column("question", "query")
+
+    # create new column image_filename that corresponds to ucsf_document_id if not None, else image_url
+    dataset = dataset.map(
+        lambda x: {"image_filename": x["ucsf_document_id"] if x["ucsf_document_id"] is not None else x["image_url"]}
+    )
+    dataset_eval = dataset_eval.map(
+        lambda x: {"image_filename": x["ucsf_document_id"] if x["ucsf_document_id"] is not None else x["image_url"]}
+    )
+
+    ds_dict = DatasetDict({"train": dataset, "test": dataset_eval})
+
+    return ds_dict
+
+
+def load_dummy_dataset() -> DatasetDict:
+    # create a dataset from the queries and images
+    queries_1 = ["What is the capital of France?", "What is the capital of Germany?"]
+    queries_2 = ["What is the capital of Italy?", "What is the capital of Spain?"]
+
+    images_1 = [Image.new("RGB", (100, 100)) for _ in range(2)]
+    images_2 = [Image.new("RGB", (120, 120)) for _ in range(2)]
+
+    dataset_1 = Dataset.from_list([{"query": q, "image": i} for q, i in zip(queries_1, images_1)])
+    dataset_2 = Dataset.from_list([{"query": q, "image": i} for q, i in zip(queries_2, images_2)])
+
+    return DatasetDict(
+        {
+            "train": DatasetDict({"dataset_1": dataset_1, "dataset_2": dataset_2}),
+            "test": DatasetDict({"dataset_1": dataset_2, "dataset_2": dataset_1}),
+        }
+    )
+
+
+def load_multi_qa_datasets() -> DatasetDict:
+    dataset_args = [
+        ("vidore/colpali_train_set"),
+        ("llamaindex/vdr-multilingual-train", "de"),
+        ("llamaindex/vdr-multilingual-train", "en"),
+        ("llamaindex/vdr-multilingual-train", "es"),
+        ("llamaindex/vdr-multilingual-train", "fr"),
+        ("llamaindex/vdr-multilingual-train", "it"),
+    ]
+
+    train_datasets = {}
+    test_datasets = {}
+    for args in dataset_args:
+        dataset_name = args[0] + "_" + args[1]
+        dataset = load_dataset(*args)
+        if "test" in dataset:
+            train_datasets[dataset_name] = dataset["train"]
+            test_datasets[dataset_name] = dataset["test"]
+        else:
+            train_dataset, test_dataset = dataset.split_by_ratio(test_size=200)
+            train_datasets[dataset_name] = train_dataset
+            test_datasets[dataset_name] = test_dataset
+
+    return DatasetDict({"train": DatasetDict(train_datasets), "test": DatasetDict(test_datasets)})
+
+
+class TestSetFactory:
+    def __init__(self, dataset_path):
+        self.dataset_path = dataset_path
+
+    def __call__(self, *args, **kwargs):
+        dataset = load_dataset(self.dataset_path, split="test")
+        return dataset
+
+
+if __name__ == "__main__":
+    ds = TestSetFactory("vidore/tabfquad_test_subsampled")()
+    print(ds)
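
For example, the IR-style loader above plugs directly into the training config (dataset paths resolve to `./data_dir/` or the Hugging Face Hub depending on `USE_LOCAL_DATASET`):

from colpali_engine.utils.dataset_transformation import load_train_set_ir

train_dataset = load_train_set_ir(num_negs=5)  # queries + corpus, keeping 5 hard negatives per query
print(len(train_dataset))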

+ 24 - 0
deconstruct_SQI/colpali/colpali_engine/utils/gpu_stats.py

@@ -0,0 +1,24 @@
+# cond import
+try:
+    from pynvml import nvmlDeviceGetHandleByIndex, nvmlDeviceGetMemoryInfo, nvmlInit
+
+    def print_gpu_utilization():
+        nvmlInit()
+        handle = nvmlDeviceGetHandleByIndex(0)
+        info = nvmlDeviceGetMemoryInfo(handle)
+        print(f"GPU memory occupied: {info.used // 1024**2} MB.")
+
+    def print_summary(result):
+        print(f"Time: {result.metrics['train_runtime']:.2f}")
+        print(f"Samples/second: {result.metrics['train_samples_per_second']:.2f}")
+        print_gpu_utilization()
+
+except ImportError:
+    print("pynvml not found. GPU stats will not be printed.")
+
+    def print_summary(result):
+        print(f"Time: {result.metrics['train_runtime']:.2f}")
+        print(f"Samples/second: {result.metrics['train_samples_per_second']:.2f}")
+
+    def print_gpu_utilization():
+        pass

+ 256 - 0
deconstruct_SQI/colpali/colpali_engine/utils/processing_utils.py

@@ -0,0 +1,256 @@
+import importlib
+import logging
+from abc import ABC, abstractmethod
+from typing import ClassVar, List, Optional, Tuple, Union
+
+import torch
+from PIL import Image
+from transformers import BatchEncoding, BatchFeature
+
+try:
+    from fast_plaid import search
+except ImportError:
+    logging.info(
+        "FastPlaid is not installed.If you want to use it:Instal with `pip install --no-deps fast-plaid fastkmeans`"
+    )
+
+from colpali_engine.utils.torch_utils import get_torch_device
+
+
+class BaseVisualRetrieverProcessor(ABC):
+    """
+    Base class for visual retriever processors.
+    """
+
+    query_prefix: ClassVar[str] = ""  # Default prefix for queries. Override in subclasses if needed.
+
+    @abstractmethod
+    def process_images(
+        self,
+        images: List[Image.Image],
+    ) -> Union[BatchFeature, BatchEncoding]:
+        """
+        Process a list of images into a format suitable for the model.
+        Args:
+            images (List[Image.Image]): List of images to process.
+        Returns:
+            Union[BatchFeature, BatchEncoding]: Processed images.
+        """
+        pass
+
+    @abstractmethod
+    def process_texts(self, texts: List[str]) -> Union[BatchFeature, BatchEncoding]:
+        """
+        Process a list of texts into a format suitable for the model.
+
+        Args:
+            texts: List of input texts.
+
+        Returns:
+            Union[BatchFeature, BatchEncoding]: Processed texts.
+        """
+        pass
+
+    def process_queries(
+        self,
+        texts: Optional[List[str]] = None,
+        queries: Optional[List[str]] = None,
+        max_length: int = 50,
+        contexts: Optional[List[str]] = None,
+        suffix: Optional[str] = None,
+    ) -> Union[BatchFeature, BatchEncoding]:
+        """
+        Process a list of queries into a format suitable for the model.
+
+        Args:
+            texts: List of input texts.
+            queries: Alias for `texts`; only one of the two may be provided.
+            max_length: [DEPRECATED] Maximum length of the text.
+            contexts: Unused; kept for backward compatibility.
+            suffix: Suffix to append to each text. If None, the query augmentation token repeated 10 times is used.
+
+        Returns:
+            Union[BatchFeature, BatchEncoding]: Processed texts.
+
+        NOTE: This function will be deprecated. Use `process_texts` instead.
+        It is kept to maintain back-compatibility with vidore evaluator.
+        """
+
+        if texts and queries:
+            raise ValueError("Only one of 'texts' or 'queries' should be provided.")
+        if queries is not None:
+            texts = queries
+        elif texts is None:
+            raise ValueError("No texts or queries provided.")
+
+        if suffix is None:
+            suffix = self.query_augmentation_token * 10
+
+        # Add the query prefix and suffix to each text
+        texts = [self.query_prefix + text + suffix for text in texts]
+
+        return self.process_texts(texts=texts)
+
+    @abstractmethod
+    def score(
+        self,
+        qs: Union[torch.Tensor, List[torch.Tensor]],
+        ps: Union[torch.Tensor, List[torch.Tensor]],
+        device: Optional[Union[str, torch.device]] = None,
+        **kwargs,
+    ) -> torch.Tensor:
+        pass
+
+    @staticmethod
+    def score_single_vector(
+        qs: Union[torch.Tensor, List[torch.Tensor]],
+        ps: Union[torch.Tensor, List[torch.Tensor]],
+        device: Optional[Union[str, torch.device]] = None,
+    ) -> torch.Tensor:
+        """
+        Compute the dot product score for the given single-vector query and passage embeddings.
+        """
+        device = device or get_torch_device("auto")
+
+        if isinstance(qs, list) and isinstance(ps, list):
+            if len(qs) == 0:
+                raise ValueError("No queries provided")
+            if len(ps) == 0:
+                raise ValueError("No passages provided")
+
+            qs = torch.stack(qs).to(device)
+            ps = torch.stack(ps).to(device)
+        else:
+            qs = qs.to(device)
+            ps = ps.to(device)
+
+        scores = torch.einsum("bd,cd->bc", qs, ps)
+        assert scores.shape[0] == len(qs), f"Expected {len(qs)} scores, got {scores.shape[0]}"
+
+        scores = scores.to(torch.float32)
+        return scores
+
+    @staticmethod
+    def score_multi_vector(
+        qs: Union[torch.Tensor, List[torch.Tensor]],
+        ps: Union[torch.Tensor, List[torch.Tensor]],
+        batch_size: int = 128,
+        device: Optional[Union[str, torch.device]] = None,
+    ) -> torch.Tensor:
+        """
+        Compute the late-interaction/MaxSim score (ColBERT-like) for the given multi-vector
+        query embeddings (`qs`) and passage embeddings (`ps`). For ColPali, a passage is the
+        image of a document page.
+
+        Because the embedding tensors are multi-vector and can thus have different shapes, they
+        should be fed as:
+        (1) a list of tensors, where the i-th tensor is of shape (sequence_length_i, embedding_dim)
+        (2) a single tensor of shape (n_passages, max_sequence_length, embedding_dim) -> usually
+            obtained by padding the list of tensors.
+
+        Args:
+            qs (`Union[torch.Tensor, List[torch.Tensor]]`): Query embeddings.
+            ps (`Union[torch.Tensor, List[torch.Tensor]]`): Passage embeddings.
+            batch_size (`int`, *optional*, defaults to 128): Batch size for computing scores.
+            device (`Union[str, torch.device]`, *optional*): Device to use for computation. If not
+                provided, uses `get_torch_device("auto")`.
+
+        Returns:
+            `torch.Tensor`: A tensor of shape `(n_queries, n_passages)` containing the scores. The score
+            tensor is saved on the "cpu" device.
+        """
+        device = device or get_torch_device("auto")
+
+        if len(qs) == 0:
+            raise ValueError("No queries provided")
+        if len(ps) == 0:
+            raise ValueError("No passages provided")
+
+        scores_list: List[torch.Tensor] = []
+
+        for i in range(0, len(qs), batch_size):
+            scores_batch = []
+            qs_batch = torch.nn.utils.rnn.pad_sequence(qs[i : i + batch_size], batch_first=True, padding_value=0).to(
+                device
+            )
+            for j in range(0, len(ps), batch_size):
+                ps_batch = torch.nn.utils.rnn.pad_sequence(
+                    ps[j : j + batch_size], batch_first=True, padding_value=0
+                ).to(device)
+                scores_batch.append(torch.einsum("bnd,csd->bcns", qs_batch, ps_batch).max(dim=3)[0].sum(dim=2))
+            scores_batch = torch.cat(scores_batch, dim=1).cpu()
+            scores_list.append(scores_batch)
+
+        scores = torch.cat(scores_list, dim=0)
+        assert scores.shape[0] == len(qs), f"Expected {len(qs)} scores, got {scores.shape[0]}"
+
+        scores = scores.to(torch.float32)
+        return scores
+
+    @staticmethod
+    def get_topk_plaid(
+        qs: Union[torch.Tensor, List[torch.Tensor]],
+        plaid_index: "search.FastPlaid",
+        k: int = 10,
+        batch_size: int = 128,
+        device: Optional[Union[str, torch.device]] = None,
+    ) -> List:
+        """
+        Experimental: Compute the late-interaction/MaxSim (ColBERT-like) top-k results for the given
+        multi-vector query embeddings (`qs`) against passage embeddings encoded in a FastPlaid index.
+        For ColPali, a passage is the image of a document page.
+        Returns one `FastPlaid.search` result per query batch.
+        """
+        device = device or get_torch_device("auto")
+
+        if len(qs) == 0:
+            raise ValueError("No queries provided")
+
+        scores_list: List[torch.Tensor] = []
+
+        for i in range(0, len(qs), batch_size):
+            qs_batch = torch.nn.utils.rnn.pad_sequence(qs[i : i + batch_size], batch_first=True, padding_value=0).to(
+                device
+            )
+            # Use the plaid index to get the top-k results for this batch of queries
+            scores_batch = plaid_index.search(
+                queries_embeddings=qs_batch.to(torch.float32),
+                top_k=k,
+            )
+            scores_list.append(scores_batch)
+
+        return scores_list
+
+    @staticmethod
+    def create_plaid_index(
+        ps: Union[torch.Tensor, List[torch.Tensor]],
+        device: Optional[Union[str, torch.device]] = None,
+    ) -> "search.FastPlaid":
+        """
+        Experimental: Create a FastPlaid index from the given passage embeddings.
+        Args:
+            ps (`Union[torch.Tensor, List[torch.Tensor]]`): Passage embeddings. Should be a list of tensors,
+                where each tensor is of shape (sequence_length_i, embedding_dim).
+            device (`Optional[Union[str, torch.device]]`, *optional*): Device to use for computation. If not
+                provided, uses `get_torch_device("auto")`.
+        """
+        # Check that fast_plaid is installed
+        if not importlib.util.find_spec("fast_plaid"):
+            raise ImportError("FastPlaid is not installed. Please install it with `pip install fast-plaid`.")
+
+        device = device or get_torch_device("auto")
+        fast_plaid_index = search.FastPlaid(index="index")
+        fast_plaid_index.create(documents_embeddings=[d.to(device).to(torch.float32) for d in ps])
+        return fast_plaid_index
+
+    @abstractmethod
+    def get_n_patches(
+        self,
+        image_size: Tuple[int, int],
+        *args,
+        **kwargs,
+    ) -> Tuple[int, int]:
+        """
+        Get the number of patches (n_patches_x, n_patches_y) that will be used to process an
+        image of size (height, width) with the given patch size.
+        """
+        pass
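
The MaxSim reduction implemented in `score_multi_vector` above can be reproduced with plain PyTorch. A minimal sketch with random embeddings (all shapes and values are illustrative only):

import torch

# Two queries with 4 and 6 token embeddings, three passages with varying lengths (dim = 8).
qs = [torch.randn(4, 8), torch.randn(6, 8)]
ps = [torch.randn(100, 8), torch.randn(120, 8), torch.randn(90, 8)]

# Pad each side to a common sequence length, as score_multi_vector does per batch.
q = torch.nn.utils.rnn.pad_sequence(qs, batch_first=True, padding_value=0)
p = torch.nn.utils.rnn.pad_sequence(ps, batch_first=True, padding_value=0)

# MaxSim: for every query token take its best-matching passage token, then sum over query tokens.
scores = torch.einsum("bnd,csd->bcns", q, p).max(dim=3)[0].sum(dim=2)
print(scores.shape)  # torch.Size([2, 3]) -> (n_queries, n_passages)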

+ 99 - 0
deconstruct_SQI/colpali/colpali_engine/utils/torch_utils.py

@@ -0,0 +1,99 @@
+import gc
+import logging
+from typing import List, TypeVar
+
+import torch
+from torch.utils.data import Dataset
+
+logger = logging.getLogger(__name__)
+T = TypeVar("T")
+
+
+def get_torch_device(device: str = "auto") -> str:
+    """
+    Returns the device (string) to be used by PyTorch.
+
+    `device` arg defaults to "auto" which will use:
+    - "cuda:0" if available
+    - else "mps" if available
+    - else "cpu".
+    """
+
+    if device == "auto":
+        if torch.cuda.is_available():
+            device = "cuda:0"
+        elif torch.backends.mps.is_available():  # for Apple Silicon
+            device = "mps"
+        else:
+            device = "cpu"
+        logger.info(f"Using device: {device}")
+
+    return device
+
+
+def tear_down_torch():
+    """
+    Teardown for PyTorch.
+    Clears GPU cache for both CUDA and MPS.
+    """
+    gc.collect()
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+    if torch.backends.mps.is_available():
+        torch.mps.empty_cache()
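
A small usage sketch of the two helpers above (the device string depends on the host):

from colpali_engine.utils.torch_utils import get_torch_device, tear_down_torch

device = get_torch_device("auto")  # "cuda:0", "mps", or "cpu" depending on availability
# ... run inference on `device` ...
tear_down_torch()  # collect garbage and empty the CUDA/MPS caches between runs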
+
+
+class ListDataset(Dataset[T]):
+    def __init__(self, elements: List[T]):
+        self.elements = elements
+
+    def __len__(self) -> int:
+        return len(self.elements)
+
+    def __getitem__(self, idx: int) -> T:
+        return self.elements[idx]
+
+
+def unbind_padded_multivector_embeddings(
+    embeddings: torch.Tensor,
+    padding_value: float = 0.0,
+    padding_side: str = "left",
+) -> List[torch.Tensor]:
+    """
+    Removes padding elements from a batch of multivector embeddings.
+
+    Args:
+        embeddings (torch.Tensor): A tensor of shape (batch_size, seq_length, dim) with padding.
+        padding_value (float): The value used for padding. Each padded token is assumed
+            to be a vector where every element equals this value.
+        padding_side (str): Either "left" or "right". This indicates whether the padded
+            elements appear at the beginning (left) or end (right) of the sequence.
+
+    Returns:
+        List[torch.Tensor]: A list of tensors, one per sequence in the batch, where
+            each tensor has shape (new_seq_length, dim) and contains only the non-padding elements.
+    """
+    results: List[torch.Tensor] = []
+
+    for seq in embeddings:
+        is_padding = torch.all(seq.eq(padding_value), dim=-1)
+
+        if padding_side == "left":
+            non_padding_indices = (~is_padding).nonzero(as_tuple=False)
+            if non_padding_indices.numel() == 0:
+                valid_seq = seq[:0]
+            else:
+                first_valid_idx = non_padding_indices[0].item()
+                valid_seq = seq[first_valid_idx:]
+        elif padding_side == "right":
+            non_padding_indices = (~is_padding).nonzero(as_tuple=False)
+            if non_padding_indices.numel() == 0:
+                valid_seq = seq[:0]
+            else:
+                last_valid_idx = non_padding_indices[-1].item()
+                valid_seq = seq[: last_valid_idx + 1]
+        else:
+            raise ValueError("padding_side must be either 'left' or 'right'.")
+        results.append(valid_seq)
+
+    return results
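
To illustrate the left-padding case handled above, a self-contained check (values chosen only for illustration):

import torch
from colpali_engine.utils.torch_utils import unbind_padded_multivector_embeddings

# Batch of two sequences padded on the left to length 3, embedding dim 2.
batch = torch.tensor(
    [
        [[0.0, 0.0], [1.0, 1.0], [2.0, 2.0]],  # first row is padding
        [[3.0, 3.0], [4.0, 4.0], [5.0, 5.0]],  # no padding
    ]
)
seqs = unbind_padded_multivector_embeddings(batch, padding_value=0.0, padding_side="left")
print([s.shape for s in seqs])  # [torch.Size([2, 2]), torch.Size([3, 2])]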

+ 20 - 0
deconstruct_SQI/colpali/colpali_engine/utils/transformers_wrappers.py

@@ -0,0 +1,20 @@
+import importlib
+
+if importlib.util.find_spec("transformers") is not None:
+    from transformers import AutoProcessor, AutoTokenizer
+    from transformers.tokenization_utils import PreTrainedTokenizer
+
+    class AllPurposeWrapper:
+        def __new__(cls, class_to_instanciate, *args, **kwargs):
+            return class_to_instanciate.from_pretrained(*args, **kwargs)
+
+    class AutoProcessorWrapper:
+        def __new__(cls, *args, **kwargs):
+            return AutoProcessor.from_pretrained(*args, **kwargs)
+
+    class AutoTokenizerWrapper(PreTrainedTokenizer):
+        def __new__(cls, *args, **kwargs):
+            return AutoTokenizer.from_pretrained(*args, **kwargs)
+
+else:
+    raise ModuleNotFoundError("The `transformers` package must be installed to use these wrappers.")
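
These wrappers only forward to `from_pretrained`, which is how the YAML configs below instantiate arbitrary processor and model classes via `class_to_instanciate`. A minimal sketch (the checkpoint name is illustrative):

from transformers import AutoTokenizer

from colpali_engine.utils.transformers_wrappers import AllPurposeWrapper

# Equivalent calls: AllPurposeWrapper simply dispatches to `class_to_instanciate.from_pretrained(...)`.
tok_a = AllPurposeWrapper(AutoTokenizer, "bert-base-uncased")
tok_b = AutoTokenizer.from_pretrained("bert-base-uncased")
assert type(tok_a) is type(tok_b)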

+ 86 - 0
deconstruct_SQI/colpali/pyproject.toml

@@ -0,0 +1,86 @@
+[build-system]
+requires = ["hatchling", "hatch-vcs"]
+build-backend = "hatchling.build"
+
+[tool.hatch.version]
+source = "vcs"
+
+[tool.hatch.build.targets.wheel]
+include = ["colpali_engine"]
+  
+[project]
+name = "colpali_engine"
+dynamic = ["version"]
+description = "The code used to train and run inference with the ColPali architecture."
+authors = [
+    { name = "Manuel Faysse", email = "manuel.faysse@illuin.tech" },
+    { name = "Hugues Sibille", email = "hugues.sibille@illuin.tech" },
+    { name = "Tony Wu", email = "tony.wu@illuin.tech" },
+]
+maintainers = [
+    { name = "Manuel Faysse", email = "manuel.faysse@illuin.tech" },
+    { name = "Tony Wu", email = "tony.wu@illuin.tech" },
+]
+readme = "README.md"
+requires-python = ">=3.9"
+classifiers = [
+    "Programming Language :: Python :: 3",
+    "License :: OSI Approved :: MIT License",
+    "Intended Audience :: Science/Research",
+    "Intended Audience :: Developers",
+    "Operating System :: OS Independent",
+    "Topic :: Scientific/Engineering :: Artificial Intelligence",
+]
+
+dependencies = [
+    "numpy",
+    "peft>=0.14.0,<0.18.0",
+    "pillow>=10.0.0",
+    "requests",
+    "scipy",
+    "torch>=2.2.0,<2.9.0",
+    "torchvision",
+    "transformers>=4.53.1,<4.58.0",
+]
+
+[project.optional-dependencies]
+train = [
+    "accelerate>=0.34.0,<1.9.0",
+    "bitsandbytes",
+    "configue>=5.0.0",
+    "datasets>=2.19.1",
+    "mteb>=1.16.3,<2",
+    "pillow>=10.0.0,<11.4.0",
+    "typer>=0.15.1",
+]
+
+interpretability = [
+    "einops>=0.8.0,<1.0.0",
+    "matplotlib>=3.9.0,<4.0.0",
+    "seaborn>=0.13.2,<1.0.0",
+]
+
+dev = ["pytest>=8.0.0", "ruff>=0.4.0"]
+
+all = [
+    "colpali-engine[dev]",
+    "colpali-engine[interpretability]",
+    "colpali-engine[train]",
+]
+
+[project.urls]
+homepage = "https://github.com/illuin-tech/colpali"
+
+[tool.pytest.ini_options]
+filterwarnings = ["ignore::Warning"]
+markers = ["slow: marks test as slow"]
+testpaths = ["tests"]
+
+[tool.ruff]
+line-length = 120
+
+[tool.ruff.lint]
+select = ["E", "F", "W", "I", "N"]
+
+[tool.ruff.lint.per-file-ignores]
+"__init__.py" = ["F401"]

+ 109 - 0
deconstruct_SQI/colpali/scripts/api_call.py

@@ -0,0 +1,109 @@
+import asyncio
+import base64
+import os
+from io import BytesIO
+from typing import Any, List
+
+import aiohttp
+import torch
+from PIL import Image
+from tqdm.asyncio import tqdm_asyncio
+
+
+class IlluinAPIModelWrapper:
+    def __init__(
+        self,
+        model_name: str,
+        **kwargs: Any,
+    ):
+        """Wrapper for Illuin API embedding model"""
+        self.model_name = model_name
+        self.url = model_name
+        self.HEADERS = {
+            "Accept": "application/json",
+            "Authorization": f"Bearer {os.getenv('HF_TOKEN')}",
+            "Content-Type": "application/json",
+        }
+
+    @staticmethod
+    def convert_image_to_base64(image: Image.Image) -> str:
+        buffer = BytesIO()
+        image.save(buffer, format="JPEG")
+        return base64.b64encode(buffer.getvalue()).decode("utf-8")
+
+    async def post_images(self, session: aiohttp.ClientSession, encoded_images: List[str]):
+        payload = {"inputs": {"images": encoded_images}}
+        async with session.post(self.url, headers=self.HEADERS, json=payload) as response:
+            return await response.json()
+
+    async def post_queries(self, session: aiohttp.ClientSession, queries: List[str]):
+        payload = {"inputs": {"queries": queries}}
+        async with session.post(self.url, headers=self.HEADERS, json=payload) as response:
+            return await response.json()
+
+    async def call_api_queries(self, queries: List[str]):
+        embeddings = []
+        semaphore = asyncio.Semaphore(16)
+        async with aiohttp.ClientSession() as session:
+
+            async def sem_post(batch):
+                async with semaphore:
+                    return await self.post_queries(session, batch)
+
+            tasks = [asyncio.create_task(sem_post([batch])) for batch in queries]
+
+            # ORDER-PRESERVING
+            results = await tqdm_asyncio.gather(*tasks, desc="Query batches")
+
+            for result in results:
+                embeddings.extend(result.get("embeddings", []))
+
+        return embeddings
+
+    async def call_api_images(self, images_b64: List[str]):
+        embeddings = []
+        semaphore = asyncio.Semaphore(16)
+
+        async with aiohttp.ClientSession() as session:
+
+            async def sem_post(batch):
+                async with semaphore:
+                    return await self.post_images(session, batch)
+
+            tasks = [asyncio.create_task(sem_post([batch])) for batch in images_b64]
+
+            # ORDER-PRESERVING
+            results = await tqdm_asyncio.gather(*tasks, desc="Doc batches")
+
+            for result in results:
+                embeddings.extend(result.get("embeddings", []))
+
+        return embeddings
+
+    def forward_queries(self, queries: List[str]) -> torch.Tensor:
+        response = asyncio.run(self.call_api_queries(queries))
+        return response
+
+    def forward_passages(self, passages: List[Image.Image]) -> torch.Tensor:
+        response = asyncio.run(self.call_api_images([self.convert_image_to_base64(doc) for doc in passages]))
+        return response
+
+
+if __name__ == "__main__":
+    # Example usage
+
+    client = IlluinAPIModelWrapper(
+        model_name="https://sxeg6spz1yy8unh7.us-east-1.aws.endpoints.huggingface.cloud",
+    )
+
+    embed_queries = client.forward_queries(["What is the capital of France?", "Explain quantum computing."])
+
+    images = [
+        Image.new("RGB", (32, 32), color="white"),
+        Image.new("RGB", (128, 128), color="black"),
+    ]
+
+    embed_images = client.forward_passages(images)
+
+    print("Query embeddings shape:", len(embed_queries))
+    print("Image embeddings shape:", len(embed_images))

+ 131 - 0
deconstruct_SQI/colpali/scripts/compute_hardnegs.py

@@ -0,0 +1,131 @@
+from typing import cast
+
+import datasets
+import numpy as np
+import torch
+from torch.utils.data import DataLoader
+from tqdm import tqdm
+
+from colpali_engine.models import BiQwen2, BiQwen2Processor
+from colpali_engine.utils.dataset_transformation import load_train_set
+
+train_set = load_train_set()
+
+
+COMPUTE_EMBEDDINGS = False
+COMPUTE_HARDNEGS = False
+
+if COMPUTE_HARDNEGS or COMPUTE_EMBEDDINGS:
+    print("Loading base model")
+    model = BiQwen2.from_pretrained(
+        "./models/biqwen2-warmup-256-newpad-0e",
+        torch_dtype=torch.bfloat16,
+        device_map="cuda",
+        attn_implementation="flash_attention_2" if torch.cuda.is_available() else None,
+    ).eval()
+
+    print("Loading processor")
+    processor = BiQwen2Processor.from_pretrained("./models/biqwen2-warmup-256-newpad-0e")
+
+if COMPUTE_EMBEDDINGS:
+    print("Loading images")
+    print("Images loaded")
+
+    document_set = train_set["train"]
+    print("Filtering dataset")
+    print(document_set)
+    initial_list = document_set["image_filename"]
+    _, unique_indices = np.unique(initial_list, return_index=True, axis=0)
+    filtered_dataset = document_set.select(unique_indices.tolist())
+    filtered_dataset = filtered_dataset.map(
+        lambda example: {"image": example["image"], "image_filename": example["image_filename"]}, num_proc=16
+    )
+    # keep only column image and image_filename and source if it exists
+    cols_to_remove = [col for col in filtered_dataset.column_names if col not in ["image", "image_filename"]]
+    filtered_dataset = filtered_dataset.remove_columns(cols_to_remove)
+    # save it
+    print("Saving filtered dataset")
+    print(filtered_dataset)
+    filtered_dataset.save_to_disk("data_dir/filtered_dataset", max_shard_size="200MB")
+
+    print("Processing images")
+    # run inference - docs
+    dataloader = DataLoader(
+        filtered_dataset,
+        batch_size=8,
+        shuffle=False,
+        collate_fn=lambda x: processor.process_images([a["image"] for a in x]),
+    )
+    print("Computing embeddings")
+
+    ds = []
+    for batch_doc in tqdm(dataloader):
+        with torch.no_grad():
+            batch_doc = {k: v.to(model.device) for k, v in batch_doc.items()}
+            embeddings_doc = model(**batch_doc)
+        ds.extend(list(torch.unbind(embeddings_doc.to("cpu"))))
+
+    ds = torch.stack(ds)
+
+    # save embeddings
+    torch.save(ds, "data_dir/filtered_dataset_embeddings.pt")
+
+if not COMPUTE_EMBEDDINGS:
+    ds = torch.load("data_dir/filtered_dataset_embeddings.pt")
+
+
+if COMPUTE_HARDNEGS:
+    # compute hard negatives
+    ds = cast(torch.Tensor, ds).to("cuda")
+
+    # iterate on the train set
+    mined_hardnegs = []
+
+    for i in tqdm(range(0, len(train_set["train"]), 8)):
+        samples = train_set["train"][i : i + 8]
+        batch_query = processor.process_queries(samples["query"])
+        with torch.no_grad():
+            batch_query = {k: v.to(model.device) for k, v in batch_query.items()}
+            embeddings_query = model(**batch_query)
+
+        # compute scores
+        scores = torch.einsum("bd,cd->bc", embeddings_query, ds)
+        # get top 100 indexes
+        top100 = scores.topk(100, dim=1).indices
+        # indices to list
+        top100 = top100.tolist()
+        # append to mined_hardnegs
+        mined_hardnegs.extend(top100)
+
+    # save mined hardnegs as txt
+    with open("data_dir/mined_hardnegs_filtered.txt", "w") as f:
+        for item in mined_hardnegs:
+            f.write("%s\n" % item)
+
+
+with open("data_dir/mined_hardnegs_filtered.txt") as f:
+    mined_hardnegs = f.readlines()
+
+
+filtered_dataset = datasets.load_from_disk("data_dir/filtered_dataset")
+filenames = list(filtered_dataset["image_filename"])
+
+
+def mapper_fn(example, idx):
+    tmp = {
+        "negative_passages": [int(x) for x in mined_hardnegs[idx][1:-2].strip().split(",")],
+        "query": example["query"],
+        "positive_passages": [filenames.index(example["image_filename"])],
+    }
+
+    tmp["gold_in_top_100"] = tmp["positive_passages"][0] in tmp["negative_passages"]
+    # remove gold index from negs if it is there
+    if tmp["gold_in_top_100"]:
+        tmp["negative_passages"].remove(tmp["positive_passages"][0])
+    return tmp
+
+
+final_dataset = train_set["train"].map(mapper_fn, with_indices=True, num_proc=16)
+# drop image
+final_dataset = final_dataset.remove_columns("image")
+final_dataset.save_to_disk("data_dir/final_dataset")
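
The `[1:-2]` slice in `mapper_fn` simply undoes the `str(list)` serialization written to the txt file above; a quick round-trip check (indices are illustrative):

item = [12, 7, 93]
line = "%s\n" % item  # exactly what `f.write("%s\n" % item)` produces: "[12, 7, 93]\n"
parsed = [int(x) for x in line[1:-2].strip().split(",")]
assert parsed == item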

+ 3 - 0
deconstruct_SQI/colpali/scripts/configs/data/debug_data.yaml

@@ -0,0 +1,3 @@
+syntheticDocQA_energy:
+  (): colpali_engine.utils.dataset_transformation.load_eval_set
+  dataset_path: vidore/syntheticDocQA_energy_test
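
These config files rely on configue-style tags (`()` for instantiation, plus `!path`, `!ext`, and `!import` in the larger configs). Assuming the training entry point resolves them with configue's `load` helper (configue ships with the `train` extra), loading this debug config might look like the hedged sketch below; the exact return value depends on `load_eval_set`:

# Hedged sketch: resolving a configue-style YAML into Python objects.
import configue

eval_datasets = configue.load("scripts/configs/data/debug_data.yaml")
# Each top-level key maps to the object returned by the callable named under "()",
# e.g. eval_datasets["syntheticDocQA_energy"] holds the result of load_eval_set(dataset_path=...).
print(list(eval_datasets))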

+ 31 - 0
deconstruct_SQI/colpali/scripts/configs/data/test_data.yaml

@@ -0,0 +1,31 @@
+# eval_dataset_loader:
+syntheticDocQA_energy:
+  (): colpali_engine.utils.dataset_transformation.load_eval_set
+  dataset_path: !path ../../../data_dir/syntheticDocQA_energy_test
+syntheticDocQA_healthcare_industry:
+  (): colpali_engine.utils.dataset_transformation.load_eval_set
+  dataset_path: !path ../../../data_dir/syntheticDocQA_healthcare_industry_test
+syntheticDocQA_artificial_intelligence_test:
+  (): colpali_engine.utils.dataset_transformation.load_eval_set
+  dataset_path: !path ../../../data_dir/syntheticDocQA_artificial_intelligence_test
+syntheticDocQA_government_reports:
+  (): colpali_engine.utils.dataset_transformation.load_eval_set
+  dataset_path: !path ../../../data_dir/syntheticDocQA_government_reports_test
+infovqa_subsampled:
+  (): colpali_engine.utils.dataset_transformation.load_eval_set
+  dataset_path: !path ../../../data_dir/infovqa_test_subsampled
+docvqa_subsampled:
+  (): colpali_engine.utils.dataset_transformation.load_eval_set
+  dataset_path: !path ../../../data_dir/docvqa_test_subsampled
+arxivqa_subsampled:
+  (): colpali_engine.utils.dataset_transformation.load_eval_set
+  dataset_path: !path ../../../data_dir/arxivqa_test_subsampled
+tabfquad_subsampled:
+  (): colpali_engine.utils.dataset_transformation.load_eval_set
+  dataset_path: !path ../../../data_dir/tabfquad_test_subsampled
+tatdqa:
+  (): colpali_engine.utils.dataset_transformation.load_eval_set
+  dataset_path: !path ../../../data_dir/tatdqa_test
+shift_project:
+  (): colpali_engine.utils.dataset_transformation.load_eval_set
+  dataset_path: !path ../../../data_dir/shiftproject_test

+ 72 - 0
deconstruct_SQI/colpali/scripts/configs/idefics/train_colsmolvlm_model.yaml

@@ -0,0 +1,72 @@
+config:
+  (): colpali_engine.trainer.colmodel_training.ColModelTrainingConfig
+  output_dir: !path ../../../models/colsmolvlm
+  processor:
+    (): colpali_engine.utils.transformers_wrappers.AllPurposeWrapper
+    class_to_instanciate: !ext colpali_engine.models.ColIdefics3Processor
+    pretrained_model_name_or_path:  "./models/ColSmolVLM-base"
+    # num_image_tokens: 2048
+    # max_length: 50
+
+  model:
+    (): colpali_engine.utils.transformers_wrappers.AllPurposeWrapper
+    class_to_instanciate: !ext colpali_engine.models.ColIdefics3
+    pretrained_model_name_or_path: "./models/ColSmolVLM-base"
+    torch_dtype:  !ext torch.bfloat16
+    # use_cache: false
+    attn_implementation: "flash_attention_2"
+#    device_map: "auto"
+#    quantization_config:
+#      (): transformers.BitsAndBytesConfig
+#      load_in_4bit: true
+#      bnb_4bit_quant_type: "nf4"
+#      bnb_4bit_compute_dtype:  "bfloat16"
+#      bnb_4bit_use_double_quant: true
+
+  train_dataset:
+    (): colpali_engine.utils.dataset_transformation.load_train_set
+  eval_dataset: !import ../data/test_data.yaml
+
+  # max_length: 50
+  run_eval: true
+  loss_func:
+    (): colpali_engine.loss.late_interaction_losses.ColbertPairwiseCELoss
+  tr_args:
+    (): transformers.training_args.TrainingArguments
+    output_dir: null
+    overwrite_output_dir: true
+    num_train_epochs: 3
+    per_device_train_batch_size: 32
+    gradient_checkpointing: true
+    gradient_checkpointing_kwargs: { "use_reentrant": false }
+    # gradient_checkpointing: true
+    # 6 x 8 gpus = 48 batch size
+    # gradient_accumulation_steps: 4
+    per_device_eval_batch_size: 32
+    eval_strategy: "steps"
+    dataloader_num_workers: 4
+    # bf16: true
+    save_steps: 500
+    logging_steps: 10
+    eval_steps: 100
+    warmup_steps: 100
+    learning_rate: 5e-4
+    save_total_limit: 1
+    # resume_from_checkpoint: true
+    # optim: "paged_adamw_8bit"
+    # wandb logging
+    # wandb_project: "colqwen2"
+    # run_name: "colqwen2-ba32-nolora"
+    report_to: "wandb"
+
+
+  peft_config:
+    (): peft.LoraConfig
+    r: 32
+    lora_alpha: 32
+    lora_dropout: 0.1
+    init_lora_weights: "gaussian"
+    bias: "none"
+    task_type: "FEATURE_EXTRACTION"
+    target_modules: '(.*(model.text_model).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$|.*(custom_text_proj).*$)'
+    # target_modules: '(.*(language_model).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$|.*(custom_text_proj).*$)'

+ 41 - 0
deconstruct_SQI/colpali/scripts/configs/pali/train_bipali_all_model.yaml

@@ -0,0 +1,41 @@
+config:
+  (): colpali_engine.trainer.colmodel_training.ColModelTrainingConfig
+  output_dir: !path ../../../models/without_tabfquad_no_pairwise/train_bipali_all_mean-3b-mix-448
+  processor:
+    () : colpali_engine.utils.transformers_wrappers.AutoProcessorWrapper
+    pretrained_model_name_or_path: "./models/paligemma-3b-mix-448"
+    max_length: 50
+  model:
+    (): colpali_engine.utils.transformers_wrappers.AllPurposeWrapper
+    class_to_instanciate: !ext colpali_engine.models.BiPali
+    pretrained_model_name_or_path: "./models/paligemma-3b-mix-448-base"
+    torch_dtype:  !ext torch.bfloat16
+#    device_map: "auto"
+#    quantization_config:
+#      (): transformers.BitsAndBytesConfig
+#      load_in_4bit: true
+#      bnb_4bit_quant_type: "nf4"
+#      bnb_4bit_compute_dtype:  "bfloat16"
+#      bnb_4bit_use_double_quant: true
+
+  train_dataset: 
+    (): colpali_engine.utils.dataset_transformation.load_train_set
+  eval_dataset: !import ../data/test_data.yaml
+
+  max_length: 50
+  run_eval: true
+  
+  loss_func:
+    (): colpali_engine.loss.bi_encoder_losses.BiEncoderLoss
+  tr_args: !import ../tr_args/default_tr_args.yaml
+  peft_config:
+    (): peft.LoraConfig
+    r: 32
+    lora_alpha: 32
+    lora_dropout: 0.1
+    init_lora_weights: "gaussian"
+    bias: "none"
+    task_type: "FEATURE_EXTRACTION"
+    target_modules: '(.*(language_model|vision_model).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$|.*(multi_modal_projector\.linear).*$)'
+    # target_modules: '(.*(language_model).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$'
+

+ 41 - 0
deconstruct_SQI/colpali/scripts/configs/pali/train_bipali_model.yaml

@@ -0,0 +1,41 @@
+config:
+  (): colpali_engine.trainer.colmodel_training.ColModelTrainingConfig
+  output_dir: !path ../../../models/right_pad/train_bipali_reproduction
+  processor:
+    (): colpali_engine.utils.transformers_wrappers.AllPurposeWrapper
+    class_to_instanciate: !ext colpali_engine.models.BiPaliProcessor
+    pretrained_model_name_or_path: "./models/paligemma-3b-mix-448"
+  model:
+    (): colpali_engine.utils.transformers_wrappers.AllPurposeWrapper
+    class_to_instanciate: !ext colpali_engine.models.BiPali
+    pretrained_model_name_or_path: "./models/paligemma-3b-mix-448"
+    torch_dtype:  !ext torch.bfloat16
+#    device_map: "auto"
+#    quantization_config:
+#      (): transformers.BitsAndBytesConfig
+#      load_in_4bit: true
+#      bnb_4bit_quant_type: "nf4"
+#      bnb_4bit_compute_dtype:  "bfloat16"
+#      bnb_4bit_use_double_quant: true
+
+  train_dataset: 
+    (): colpali_engine.utils.dataset_transformation.load_train_set_detailed
+  eval_dataset: !import ../data/test_data.yaml
+
+  max_length: 50
+  run_eval: true
+  
+  loss_func:
+    (): colpali_engine.loss.bi_encoder_losses.BiEncoderLoss
+  tr_args: !import ../tr_args/default_tr_args.yaml
+  peft_config:
+    (): peft.LoraConfig
+    r: 32
+    lora_alpha: 32
+    lora_dropout: 0.1
+    init_lora_weights: "gaussian"
+    bias: "none"
+    task_type: "FEATURE_EXTRACTION"
+    target_modules: '(.*(language_model).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$)'
+    # target_modules: '(.*(language_model).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$'
+

+ 65 - 0
deconstruct_SQI/colpali/scripts/configs/pali/train_bipali_pairwise_256_model.yaml

@@ -0,0 +1,65 @@
+config:
+  (): colpali_engine.trainer.colmodel_training.ColModelTrainingConfig
+  output_dir: !path ../../../models/right_pad/train_bipali_pairwise_256_pt
+  processor:
+    (): colpali_engine.utils.transformers_wrappers.AllPurposeWrapper
+    class_to_instanciate: !ext colpali_engine.models.BiPaliProcessor
+    pretrained_model_name_or_path: "./models/colpaligemma-3b-pt-448-base"
+    max_length: 50
+  model:
+    (): colpali_engine.utils.transformers_wrappers.AllPurposeWrapper
+    class_to_instanciate: !ext colpali_engine.models.BiPali
+    pretrained_model_name_or_path: "./models/colpaligemma-3b-pt-448-base"
+    torch_dtype:  !ext torch.bfloat16
+    attn_implementation: "flash_attention_2"
+    # use_cache: false
+#    device_map: "auto"
+#    quantization_config:
+#      (): transformers.BitsAndBytesConfig
+#      load_in_4bit: true
+#      bnb_4bit_quant_type: "nf4"
+#      bnb_4bit_compute_dtype:  "bfloat16"
+#      bnb_4bit_use_double_quant: true
+
+  train_dataset: 
+    (): colpali_engine.utils.dataset_transformation.load_train_set
+  eval_dataset: !import ../data/test_data.yaml
+
+  max_length: 50
+  run_eval: true
+  
+  loss_func:
+    (): colpali_engine.loss.bi_encoder_losses.BiPairwiseCELoss
+  tr_args:
+    (): transformers.training_args.TrainingArguments
+    output_dir: null
+    overwrite_output_dir: true
+    num_train_epochs: 3
+    per_device_train_batch_size: 64
+    gradient_checkpointing: true
+    gradient_checkpointing_kwargs: { "use_reentrant": false }
+    # 6 x 8 gpus = 48 batch size
+    # gradient_accumulation_steps: 4
+    per_device_eval_batch_size: 64
+    eval_strategy: "steps"
+    dataloader_num_workers: 8
+    # bf16: true
+    save_steps: 500
+    logging_steps: 10
+    eval_steps: 100
+    warmup_steps: 100
+    learning_rate: 5e-4
+    save_total_limit: 1
+    resume_from_checkpoint: false
+    report_to: "wandb"
+  peft_config:
+    (): peft.LoraConfig
+    r: 32
+    lora_alpha: 32
+    lora_dropout: 0.1
+    init_lora_weights: "gaussian"
+    bias: "none"
+    task_type: "FEATURE_EXTRACTION"
+    target_modules: '(.*(language_model).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$)'
+    # target_modules: '(.*(language_model).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$'
+

+ 42 - 0
deconstruct_SQI/colpali/scripts/configs/pali/train_bipali_pairwise_hardneg_model.yaml

@@ -0,0 +1,42 @@
+config:
+  (): colpali_engine.trainer.colmodel_training.ColModelTrainingConfig
+  output_dir: !path ../../../models/right_pad/train_bipali_pairwise_hardneg_proj
+  processor:
+    () : colpali_engine.utils.transformers_wrappers.AutoProcessorWrapper
+    pretrained_model_name_or_path: "./models/paligemma-3b-mix-448"
+    max_length: 50
+  model:
+    (): colpali_engine.utils.transformers_wrappers.AllPurposeWrapper
+    class_to_instanciate: !ext colpali_engine.models.BiPaliProj
+    pretrained_model_name_or_path: "./models/paligemma-3b-mix-448"
+    torch_dtype:  !ext torch.bfloat16
+#    device_map: "auto"
+#    quantization_config:
+#      (): transformers.BitsAndBytesConfig
+#      load_in_4bit: true
+#      bnb_4bit_quant_type: "nf4"
+#      bnb_4bit_compute_dtype:  "bfloat16"
+#      bnb_4bit_use_double_quant: true
+
+  train_dataset: 
+    (): colpali_engine.utils.dataset_transformation.load_train_set_ir_negs
+  eval_dataset: !import ../data/test_data.yaml
+
+  max_length: 50
+  run_eval: true
+  
+  loss_func:
+    (): colpali_engine.loss.bi_encoder_losses.BiPairwiseNegativeCELoss
+    in_batch_term_weight: 0.5
+  tr_args: !import ../tr_args/default_neg_tr_args.yaml
+  peft_config:
+    (): peft.LoraConfig
+    r: 32
+    lora_alpha: 32
+    lora_dropout: 0.1
+    init_lora_weights: "gaussian"
+    bias: "none"
+    task_type: "FEATURE_EXTRACTION"
+    target_modules: '(.*(language_model).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$)'
+    # target_modules: '(.*(language_model).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$'
+

+ 42 - 0
deconstruct_SQI/colpali/scripts/configs/pali/train_bipali_pairwise_model.yaml

@@ -0,0 +1,42 @@
+config:
+  (): colpali_engine.trainer.colmodel_training.ColModelTrainingConfig
+  output_dir: !path ../../../models/right_pad/train_bipali_pairwise_reproduction
+  processor:
+    (): colpali_engine.utils.transformers_wrappers.AllPurposeWrapper
+    class_to_instanciate: !ext colpali_engine.models.BiPaliProcessor
+    pretrained_model_name_or_path: "./models/paligemma-3b-mix-448"
+    max_length: 50
+  model:
+    (): colpali_engine.utils.transformers_wrappers.AllPurposeWrapper
+    class_to_instanciate: !ext colpali_engine.models.BiPali
+    pretrained_model_name_or_path: "./models/paligemma-3b-mix-448"
+    torch_dtype:  !ext torch.bfloat16
+#    device_map: "auto"
+#    quantization_config:
+#      (): transformers.BitsAndBytesConfig
+#      load_in_4bit: true
+#      bnb_4bit_quant_type: "nf4"
+#      bnb_4bit_compute_dtype:  "bfloat16"
+#      bnb_4bit_use_double_quant: true
+
+  train_dataset: 
+    (): colpali_engine.utils.dataset_transformation.load_train_set_detailed
+  eval_dataset: !import ../data/test_data.yaml
+
+  max_length: 50
+  run_eval: true
+  
+  loss_func:
+    (): colpali_engine.loss.bi_encoder_losses.BiPairwiseCELoss
+  tr_args: !import ../tr_args/default_tr_args.yaml
+  peft_config:
+    (): peft.LoraConfig
+    r: 32
+    lora_alpha: 32
+    lora_dropout: 0.1
+    init_lora_weights: "gaussian"
+    bias: "none"
+    task_type: "FEATURE_EXTRACTION"
+    target_modules: '(.*(language_model).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$)'
+    # target_modules: '(.*(language_model).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$'
+

+ 65 - 0
deconstruct_SQI/colpali/scripts/configs/pali/train_colpali2_pt_model.yaml

@@ -0,0 +1,65 @@
+config:
+  (): colpali_engine.trainer.colmodel_training.ColModelTrainingConfig
+  output_dir: !path ../../../models/train_colpali2-3b-pt-448-5e5
+  processor:
+    (): colpali_engine.utils.transformers_wrappers.AllPurposeWrapper
+    class_to_instanciate: !ext colpali_engine.models.ColPaliProcessor
+    pretrained_model_name_or_path:  "./models/colpaligemma2-3b-pt-448-base" # "./models/paligemma-3b-mix-448"
+    max_length: 50
+  model:
+    (): colpali_engine.utils.transformers_wrappers.AllPurposeWrapper
+    class_to_instanciate: !ext colpali_engine.models.ColPali
+    pretrained_model_name_or_path: "./models/colpaligemma2-3b-pt-448-base"
+    torch_dtype:  !ext torch.bfloat16
+    attn_implementation: "flash_attention_2"
+  #    device_map: "auto"
+#    quantization_config:
+#      (): transformers.BitsAndBytesConfig
+#      load_in_4bit: true
+#      bnb_4bit_quant_type: "nf4"
+#      bnb_4bit_compute_dtype:  "bfloat16"
+#      bnb_4bit_use_double_quant: true
+
+  train_dataset: 
+    (): colpali_engine.utils.dataset_transformation.load_train_set
+  eval_dataset: !import ../data/test_data.yaml
+
+  max_length: 50
+  run_eval: true
+  
+  loss_func:
+    (): colpali_engine.loss.late_interaction_losses.ColbertPairwiseCELoss
+  tr_args:
+    (): transformers.training_args.TrainingArguments
+    output_dir: null
+    overwrite_output_dir: true
+    num_train_epochs: 5
+    per_device_train_batch_size: 32
+    gradient_checkpointing: true
+    gradient_checkpointing_kwargs: { "use_reentrant": false }
+    # 6 x 8 gpus = 48 batch size
+    # gradient_accumulation_steps: 4
+    per_device_eval_batch_size: 32
+    eval_strategy: "steps"
+    dataloader_num_workers: 16
+    # bf16: true
+    save_steps: 500
+    logging_steps: 10
+    eval_steps: 100
+    warmup_steps: 100
+    learning_rate: 5e-5
+    save_total_limit: 1
+    resume_from_checkpoint: false
+    report_to: "wandb"
+
+  peft_config:
+    (): peft.LoraConfig
+    r: 32
+    lora_alpha: 32
+    lora_dropout: 0.1
+    init_lora_weights: "gaussian"
+    bias: "none"
+    task_type: "FEATURE_EXTRACTION"
+    target_modules: '(.*(language_model).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$|.*(custom_text_proj).*$)'
+    # target_modules: '(.*(language_model).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$|.*(custom_text_proj).*$)'
+

+ 40 - 0
deconstruct_SQI/colpali/scripts/configs/pali/train_colpali_all_model.yaml

@@ -0,0 +1,40 @@
+config:
+  (): colpali_engine.trainer.colmodel_training.ColModelTrainingConfig
+  output_dir: !path ../../../models/without_tabfquad_no_pairwise/train_colpali_all-3b-mix-448
+  processor:
+    () : colpali_engine.utils.transformers_wrappers.AutoProcessorWrapper
+    pretrained_model_name_or_path:  "./models/colpaligemma-3b-mix-448-base" # "./models/paligemma-3b-mix-448"
+    max_length: 50
+  model:
+    (): colpali_engine.utils.transformers_wrappers.AllPurposeWrapper
+    class_to_instanciate: !ext colpali_engine.models.ColPali
+    pretrained_model_name_or_path: "./models/colpaligemma-3b-mix-448-base"
+    torch_dtype:  !ext torch.bfloat16
+#    device_map: "auto"
+#    quantization_config:
+#      (): transformers.BitsAndBytesConfig
+#      load_in_4bit: true
+#      bnb_4bit_quant_type: "nf4"
+#      bnb_4bit_compute_dtype:  "bfloat16"
+#      bnb_4bit_use_double_quant: true
+
+  train_dataset: 
+    (): colpali_engine.utils.dataset_transformation.load_train_set
+  eval_dataset: !import ../data/test_data.yaml
+
+  max_length: 50
+  run_eval: true
+  loss_func:
+    (): colpali_engine.loss.late_interaction_losses.ColbertLoss
+  tr_args: !import ../tr_args/default_tr_args.yaml
+  peft_config:
+    (): peft.LoraConfig
+    r: 32
+    lora_alpha: 32
+    lora_dropout: 0.1
+    init_lora_weights: "gaussian"
+    bias: "none"
+    task_type: "FEATURE_EXTRACTION"
+    target_modules: '(.*(language_model|vision_model).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$|.*(multi_modal_projector\.linear).*$|.*(custom_text_proj).*$)'
+    # target_modules: '(.*(language_model).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$|.*(custom_text_proj).*$)'
+

+ 42 - 0
deconstruct_SQI/colpali/scripts/configs/pali/train_colpali_docmatix_hardneg_model.yaml

@@ -0,0 +1,42 @@
+config:
+  (): colpali_engine.trainer.colmodel_training.ColModelTrainingConfig
+  output_dir: !path ../../../models/train_colpali_docmatix_hardneg_ib_3b-mix-448
+  processor:
+    () : colpali_engine.utils.transformers_wrappers.AutoProcessorWrapper
+    pretrained_model_name_or_path:  "./models/colpaligemma-3b-mix-448-base" # "./models/paligemma-3b-mix-448"
+    max_length: 50
+  model:
+    (): colpali_engine.utils.transformers_wrappers.AllPurposeWrapper
+    class_to_instanciate: !ext colpali_engine.models.ColPali
+    pretrained_model_name_or_path: "./models/colpaligemma-3b-mix-448-base"
+    torch_dtype:  !ext torch.bfloat16
+#    device_map: "auto"
+#    quantization_config:
+#      (): transformers.BitsAndBytesConfig
+#      load_in_4bit: true
+#      bnb_4bit_quant_type: "nf4"
+#      bnb_4bit_compute_dtype:  "bfloat16"
+#      bnb_4bit_use_double_quant: true
+
+  train_dataset: 
+    (): colpali_engine.utils.dataset_transformation.load_docmatix_ir_negs
+  eval_dataset: !import ../data/test_data.yaml
+
+  max_length: 50
+  run_eval: true
+  
+  loss_func:
+    (): colpali_engine.loss.late_interaction_losses.ColbertPairwiseNegativeCELoss
+    in_batch_term_weight: 0.5
+  tr_args: !import ../tr_args/default_neg_tr_args.yaml
+  peft_config:
+    (): peft.LoraConfig
+    r: 32
+    lora_alpha: 32
+    lora_dropout: 0.1
+    init_lora_weights: "gaussian"
+    bias: "none"
+    task_type: "FEATURE_EXTRACTION"
+    target_modules: '(.*(language_model).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$|.*(custom_text_proj).*$)'
+    # target_modules: '(.*(language_model).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$|.*(custom_text_proj).*$)'
+

+ 39 - 0
deconstruct_SQI/colpali/scripts/configs/pali/train_colpali_docmatix_model.yaml

@@ -0,0 +1,39 @@
+config:
+  (): colpali_engine.trainer.colmodel_training.ColModelTrainingConfig
+  output_dir: !path ../../../models/train_colpali-docmatix-3b-mix-448
+  processor:
+    () : colpali_engine.utils.transformers_wrappers.AutoProcessorWrapper
+    pretrained_model_name_or_path:  "./models/colpaligemma-3b-mix-448-base" # "./models/paligemma-3b-mix-448"
+    max_length: 50
+  model:
+    (): colpali_engine.utils.transformers_wrappers.AllPurposeWrapper
+    class_to_instanciate: !ext colpali_engine.models.ColPali
+    pretrained_model_name_or_path: "./models/colpaligemma-3b-mix-448-base"
+    torch_dtype:  !ext torch.bfloat16
+#    quantization_config:
+#      (): transformers.BitsAndBytesConfig
+#      load_in_4bit: true
+#      bnb_4bit_quant_type: "nf4"
+#      bnb_4bit_compute_dtype:  "bfloat16"
+#      bnb_4bit_use_double_quant: true
+
+  train_dataset: 
+    (): colpali_engine.utils.dataset_transformation.load_train_set_with_docmatix
+  eval_dataset: !import ../data/test_data.yaml
+
+  max_length: 50
+  run_eval: true
+  loss_func:
+    (): colpali_engine.loss.late_interaction_losses.ColbertPairwiseCELoss
+  tr_args: !import ../tr_args/default_tr_args.yaml
+  peft_config:
+    (): peft.LoraConfig
+    r: 32
+    lora_alpha: 32
+    lora_dropout: 0.1
+    init_lora_weights: "gaussian"
+    bias: "none"
+    task_type: "FEATURE_EXTRACTION"
+    target_modules: '(.*(language_model).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$|.*(custom_text_proj).*$)'
+    # target_modules: '(.*(language_model).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$|.*(custom_text_proj).*$)'
+

+ 42 - 0
deconstruct_SQI/colpali/scripts/configs/pali/train_colpali_hardneg_debug_model.yaml

@@ -0,0 +1,42 @@
+config:
+  (): colpali_engine.trainer.colmodel_training.ColModelTrainingConfig
+  output_dir: !path ../../../models/without_tabfquad/train_colpali-3b-mix-448-debug
+  processor:
+    () : colpali_engine.utils.transformers_wrappers.AutoProcessorWrapper
+    pretrained_model_name_or_path:  "./models/colpaligemma-3b-mix-448-base" # "./models/paligemma-3b-mix-448"
+    max_length: 50
+  model:
+    (): colpali_engine.utils.transformers_wrappers.AllPurposeWrapper
+    class_to_instanciate: !ext colpali_engine.models.ColPali
+    pretrained_model_name_or_path: "./models/colpaligemma-3b-mix-448-base"
+    torch_dtype:  !ext torch.bfloat16
+#    device_map: "auto"
+#    quantization_config:
+#      (): transformers.BitsAndBytesConfig
+#      load_in_4bit: true
+#      bnb_4bit_quant_type: "nf4"
+#      bnb_4bit_compute_dtype:  "bfloat16"
+#      bnb_4bit_use_double_quant: true
+
+  train_dataset: 
+    (): colpali_engine.utils.dataset_transformation.load_train_set_ir_negs
+  eval_dataset: !import ../data/test_data.yaml
+
+  max_length: 50
+  run_eval: true
+  
+  loss_func:
+    (): colpali_engine.loss.late_interaction_losses.ColbertPairwiseNegativeCELoss
+    in_batch_term_weight: 0.5
+  tr_args: !import ../tr_args/resume_neg_tr_args.yaml
+  peft_config:
+    (): peft.LoraConfig
+    r: 32
+    lora_alpha: 32
+    lora_dropout: 0.1
+    init_lora_weights: "gaussian"
+    bias: "none"
+    task_type: "FEATURE_EXTRACTION"
+    target_modules: '(.*(language_model).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$|.*(custom_text_proj).*$)'
+    # target_modules: '(.*(language_model).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$|.*(custom_text_proj).*$)'
+

+ 62 - 0
deconstruct_SQI/colpali/scripts/configs/pali/train_colpali_hardneg_model.yaml

@@ -0,0 +1,62 @@
+config:
+  (): colpali_engine.trainer.colmodel_training.ColModelTrainingConfig
+  output_dir: !path ../../../models/right_pad/train_colpali_hardneg_long
+  processor:
+    () : colpali_engine.utils.transformers_wrappers.AutoProcessorWrapper
+    pretrained_model_name_or_path:  "./models/colpaligemma-3b-pt-448-base" # "./models/paligemma-3b-mix-448"
+    max_length: 50
+  model:
+    (): colpali_engine.utils.transformers_wrappers.AllPurposeWrapper
+    class_to_instanciate: !ext colpali_engine.models.ColPali
+    pretrained_model_name_or_path: "./models/colpaligemma-3b-pt-448-base"
+    torch_dtype:  !ext torch.bfloat16
+#    device_map: "auto"
+#    quantization_config:
+#      (): transformers.BitsAndBytesConfig
+#      load_in_4bit: true
+#      bnb_4bit_quant_type: "nf4"
+#      bnb_4bit_compute_dtype:  "bfloat16"
+#      bnb_4bit_use_double_quant: true
+
+  train_dataset: 
+    (): colpali_engine.utils.dataset_transformation.load_train_set_ir_negs
+  eval_dataset: !import ../data/test_data.yaml
+
+  max_length: 50
+  run_eval: true
+  
+  loss_func:
+    (): colpali_engine.loss.late_interaction_losses.ColbertPairwiseNegativeCELoss
+    in_batch_term_weight: 0.5
+  tr_args:
+    (): transformers.training_args.TrainingArguments
+    output_dir: null
+    overwrite_output_dir: true
+    num_train_epochs: 5
+    per_device_train_batch_size: 4
+    # 6 x 8 gpus = 48 batch size
+    # gradient_accumulation_steps: 4
+    per_device_eval_batch_size: 4
+    eval_strategy: "steps"
+    # dataloader_num_workers: 8
+    # bf16: true
+    save_steps: 500
+    logging_steps: 10
+    eval_steps: 50
+    warmup_steps: 1000
+    learning_rate: 5e-5
+    save_total_limit: 1
+    resume_from_checkpoint: true
+    # optim: "paged_adamw_8bit"
+
+  peft_config:
+    (): peft.LoraConfig
+    r: 32
+    lora_alpha: 32
+    lora_dropout: 0.1
+    init_lora_weights: "gaussian"
+    bias: "none"
+    task_type: "FEATURE_EXTRACTION"
+    target_modules: '(.*(language_model).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$|.*(custom_text_proj).*$)'
+    # target_modules: '(.*(language_model).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$|.*(custom_text_proj).*$)'
+

+ 41 - 0
deconstruct_SQI/colpali/scripts/configs/pali/train_colpali_model.yaml

@@ -0,0 +1,41 @@
+config:
+  (): colpali_engine.trainer.colmodel_training.ColModelTrainingConfig
+  output_dir: !path ../../../models/right_pad/train_colpali-3b-mix-448
+  processor:
+    (): colpali_engine.utils.transformers_wrappers.AllPurposeWrapper
+    class_to_instanciate: !ext colpali_engine.models.ColPaliProcessor
+    pretrained_model_name_or_path:  "./models/colpaligemma-3b-mix-448-base" # "./models/paligemma-3b-mix-448"
+    max_length: 50
+  model:
+    (): colpali_engine.utils.transformers_wrappers.AllPurposeWrapper
+    class_to_instanciate: !ext colpali_engine.models.ColPali
+    pretrained_model_name_or_path: "./models/colpaligemma-3b-mix-448-base"
+    torch_dtype:  !ext torch.bfloat16
+#    device_map: "auto"
+#    quantization_config:
+#      (): transformers.BitsAndBytesConfig
+#      load_in_4bit: true
+#      bnb_4bit_quant_type: "nf4"
+#      bnb_4bit_compute_dtype:  "bfloat16"
+#      bnb_4bit_use_double_quant: true
+
+  train_dataset: 
+    (): colpali_engine.utils.dataset_transformation.load_train_set
+  eval_dataset: !import ../data/test_data.yaml
+
+  max_length: 50
+  run_eval: true
+  loss_func:
+    (): colpali_engine.loss.late_interaction_losses.ColbertPairwiseCELoss
+  tr_args: !import ../tr_args/default_tr_args.yaml
+  peft_config:
+    (): peft.LoraConfig
+    r: 32
+    lora_alpha: 32
+    lora_dropout: 0.1
+    init_lora_weights: "gaussian"
+    bias: "none"
+    task_type: "FEATURE_EXTRACTION"
+    target_modules: '(.*(language_model).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$|.*(custom_text_proj).*$)'
+    # target_modules: '(.*(language_model).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$|.*(custom_text_proj).*$)'
+

+ 41 - 0
deconstruct_SQI/colpali/scripts/configs/pali/train_colpali_pt_model.yaml

@@ -0,0 +1,41 @@
+config:
+  (): colpali_engine.trainer.colmodel_training.ColModelTrainingConfig
+  output_dir: !path ../../../models/right_pad/train_colpali-3b-pt-448
+  processor:
+    () : colpali_engine.utils.transformers_wrappers.AutoProcessorWrapper
+    pretrained_model_name_or_path:  "./models/colpaligemma-3b-pt-448-base" # "./models/paligemma-3b-mix-448"
+    max_length: 50
+  model:
+    (): colpali_engine.utils.transformers_wrappers.AllPurposeWrapper
+    class_to_instanciate: !ext colpali_engine.models.ColPali
+    pretrained_model_name_or_path: "./models/colpaligemma-3b-pt-448-base"
+    torch_dtype:  !ext torch.bfloat16
+#    device_map: "auto"
+#    quantization_config:
+#      (): transformers.BitsAndBytesConfig
+#      load_in_4bit: true
+#      bnb_4bit_quant_type: "nf4"
+#      bnb_4bit_compute_dtype:  "bfloat16"
+#      bnb_4bit_use_double_quant: true
+
+  train_dataset: 
+    (): colpali_engine.utils.dataset_transformation.load_train_set
+  eval_dataset: !import ../data/test_data.yaml
+
+  max_length: 50
+  run_eval: true
+  
+  loss_func:
+    (): colpali_engine.loss.late_interaction_losses.ColbertPairwiseCELoss
+  tr_args: !import ../tr_args/default_tr_args.yaml
+  peft_config:
+    (): peft.LoraConfig
+    r: 32
+    lora_alpha: 32
+    lora_dropout: 0.1
+    init_lora_weights: "gaussian"
+    bias: "none"
+    task_type: "FEATURE_EXTRACTION"
+    target_modules: '(.*(language_model).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$|.*(custom_text_proj).*$)'
+    # target_modules: '(.*(language_model).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$|.*(custom_text_proj).*$)'
+

+ 65 - 0
deconstruct_SQI/colpali/scripts/configs/qwen2/deprecated/train_biqwen2_docmatix_model.yaml

@@ -0,0 +1,65 @@
+config:
+  (): colpali_engine.trainer.colmodel_training.ColModelTrainingConfig
+  output_dir: !path ../../../models/biqwen2-docmatix-256
+  processor:
+    (): colpali_engine.utils.transformers_wrappers.AllPurposeWrapper
+    class_to_instanciate: !ext colpali_engine.models.BiQwen2Processor
+    pretrained_model_name_or_path:  "./models/colqwen2_base" # "./models/paligemma-3b-mix-448"
+    # max_length: 50
+
+  model:
+    (): colpali_engine.utils.transformers_wrappers.AllPurposeWrapper
+    class_to_instanciate: !ext colpali_engine.models.BiQwen2
+    pretrained_model_name_or_path: "./models/colqwen2_base"
+    torch_dtype:  !ext torch.bfloat16
+    use_cache: false
+    attn_implementation: "flash_attention_2"
+#    device_map: "auto"
+#    quantization_config:
+#      (): transformers.BitsAndBytesConfig
+#      load_in_4bit: true
+#      bnb_4bit_quant_type: "nf4"
+#      bnb_4bit_compute_dtype:  "bfloat16"
+#      bnb_4bit_use_double_quant: true
+
+  dataset_loading_func: !ext colpali_engine.utils.dataset_transformation.load_docmatix_ir_negs
+  eval_dataset_loader: !import ../data/test_data.yaml
+
+  # max_length: 50
+  run_eval: true
+  
+  loss_func:
+    (): colpali_engine.loss.bi_encoder_losses.BiPairwiseNegativeCELoss
+  tr_args:
+    (): transformers.training_args.TrainingArguments
+    output_dir: null
+    overwrite_output_dir: true
+    num_train_epochs: 1
+    per_device_train_batch_size: 64
+    gradient_checkpointing: true
+    gradient_checkpointing_kwargs: { "use_reentrant": false }
+    # 6 x 8 gpus = 48 batch size
+    # gradient_accumulation_steps: 4
+    per_device_eval_batch_size: 64
+    eval_strategy: "steps"
+    dataloader_num_workers: 8
+    # bf16: true
+    save_steps: 500
+    logging_steps: 10
+    eval_steps: 100
+    warmup_steps: 100
+    max_steps: 2000
+    learning_rate: 5e-5
+    save_total_limit: 1
+    # optim: "paged_adamw_8bit"
+  peft_config:
+    (): peft.LoraConfig
+    r: 32
+    lora_alpha: 32
+    lora_dropout: 0.1
+    init_lora_weights: "gaussian"
+    bias: "none"
+    task_type: "FEATURE_EXTRACTION"
+    target_modules: '(.*(model).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$)'
+    # target_modules: '(.*(language_model).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$|.*(custom_text_proj).*$)'
+

+ 66 - 0
deconstruct_SQI/colpali/scripts/configs/qwen2/deprecated/train_biqwen2_warmup_model.yaml

@@ -0,0 +1,66 @@
+config:
+  (): colpali_engine.trainer.colmodel_training.ColModelTrainingConfig
+  output_dir: !path ../../../models/biqwen2-warmup-256-newpad-0e
+  processor:
+    (): colpali_engine.utils.transformers_wrappers.AllPurposeWrapper
+    class_to_instanciate: !ext colpali_engine.models.BiQwen2Processor
+    pretrained_model_name_or_path:  "./models/colqwen2_base" # "./models/paligemma-3b-mix-448"
+    # max_length: 50
+
+  model:
+    (): colpali_engine.utils.transformers_wrappers.AllPurposeWrapper
+    class_to_instanciate: !ext colpali_engine.models.BiQwen2
+    pretrained_model_name_or_path: "./models/qwen2-warmup"
+    torch_dtype:  !ext torch.bfloat16
+    use_cache: false
+    attn_implementation: "flash_attention_2"
+#    device_map: "auto"
+#    quantization_config:
+#      (): transformers.BitsAndBytesConfig
+#      load_in_4bit: true
+#      bnb_4bit_quant_type: "nf4"
+#      bnb_4bit_compute_dtype:  "bfloat16"
+#      bnb_4bit_use_double_quant: true
+
+  dataset_loading_func: !ext colpali_engine.utils.dataset_transformation.load_train_set
+  eval_dataset_loader: !import ../data/test_data.yaml
+
+  # max_length: 50
+  run_eval: true
+  
+  loss_func:
+    (): colpali_engine.loss.bi_encoder_losses.BiPairwiseCELoss
+  tr_args:
+    (): transformers.training_args.TrainingArguments
+    output_dir: null
+    overwrite_output_dir: true
+    num_train_epochs: 1
+    per_device_train_batch_size: 64
+    gradient_checkpointing: true
+    gradient_checkpointing_kwargs: { "use_reentrant": false }
+    # 6 x 8 gpus = 48 batch size
+    # gradient_accumulation_steps: 4
+    per_device_eval_batch_size: 64
+    eval_strategy: "steps"
+    dataloader_num_workers: 8
+    # bf16: true
+    save_steps: 500
+    logging_steps: 10
+    max_steps: 1
+    eval_steps: 100
+    warmup_steps: 100
+    learning_rate: 5e-5
+    save_total_limit: 1
+    resume_from_checkpoint: false
+    # optim: "paged_adamw_8bit"
+  peft_config:
+    (): peft.LoraConfig
+    r: 32
+    lora_alpha: 32
+    lora_dropout: 0.1
+    init_lora_weights: "gaussian"
+    bias: "none"
+    task_type: "FEATURE_EXTRACTION"
+    target_modules: '(.*(model).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$)'
+    # target_modules: '(.*(language_model).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$|.*(custom_text_proj).*$)'
+

+ 67 - 0
deconstruct_SQI/colpali/scripts/configs/qwen2/deprecated/train_colqwen2_docmatix_model.yaml

@@ -0,0 +1,67 @@
+config:
+  (): colpali_engine.trainer.colmodel_training.ColModelTrainingConfig
+  output_dir: !path ../../../models/colqwen2-docmatix-ba256-ckpt-1000s
+  processor:
+    (): colpali_engine.utils.transformers_wrappers.AllPurposeWrapper
+    class_to_instanciate: !ext colpali_engine.models.ColQwen2Processor
+    pretrained_model_name_or_path:  "./models/colqwen2_base" # "./models/paligemma-3b-mix-448"
+    # num_image_tokens: 2048
+    # max_length: 50
+
+  model:
+    (): colpali_engine.utils.transformers_wrappers.AllPurposeWrapper
+    class_to_instanciate: !ext colpali_engine.models.ColQwen2
+    pretrained_model_name_or_path: "./models/colqwen2_base"
+    torch_dtype:  !ext torch.bfloat16
+    use_cache: false
+    attn_implementation: "flash_attention_2"
+#    device_map: "auto"
+#    quantization_config:
+#      (): transformers.BitsAndBytesConfig
+#      load_in_4bit: true
+#      bnb_4bit_quant_type: "nf4"
+#      bnb_4bit_compute_dtype:  "bfloat16"
+#      bnb_4bit_use_double_quant: true
+
+  dataset_loading_func: !ext colpali_engine.utils.dataset_transformation.load_docmatix_ir_negs
+  eval_dataset_loader: !import ../data/test_data.yaml
+
+  # max_length: 50
+  run_eval: true
+  
+  loss_func:
+    (): colpali_engine.loss.late_interaction_losses.ColbertPairwiseNegativeCELoss
+  tr_args:
+    (): transformers.training_args.TrainingArguments
+    output_dir: null
+    overwrite_output_dir: true
+    num_train_epochs: 1
+    per_device_train_batch_size: 64
+    gradient_checkpointing: true
+    gradient_checkpointing_kwargs: {"use_reentrant": false}
+    # 6 x 8 gpus = 48 batch size
+    # gradient_accumulation_steps: 4
+    per_device_eval_batch_size: 64
+    eval_strategy: "steps"
+    dataloader_num_workers: 8
+    # bf16: true
+    save_steps: 500
+    logging_steps: 10
+    eval_steps: 500
+    warmup_steps: 100
+    max_steps: 1000
+    learning_rate: 5e-5
+    save_total_limit: 1
+    # resume_from_checkpoint: true
+    # optim: "paged_adamw_8bit"
+  peft_config:
+    (): peft.LoraConfig
+    r: 32
+    lora_alpha: 32
+    lora_dropout: 0.1
+    init_lora_weights: "gaussian"
+    bias: "none"
+    task_type: "FEATURE_EXTRACTION"
+    target_modules: '(.*(model).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$|.*(custom_text_proj).*$)'
+    # target_modules: '(.*(language_model).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$|.*(custom_text_proj).*$)'
+
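For the batch-size bookkeeping hinted at by the `ba256` tag in the output_dir and the stale `6 x 8 gpus = 48` comment, a small hypothetical helper (not part of the repo) makes the arithmetic explicit:

# Hypothetical helper: effective global batch size for the HF Trainer settings above.
def global_batch_size(per_device: int, num_gpus: int, grad_accum: int = 1) -> int:
    return per_device * num_gpus * grad_accum

# e.g. per_device_train_batch_size=64 on 4 GPUs with no accumulation -> 256,
# matching the "ba256" tag in the output_dir above.
print(global_batch_size(64, 4))  # 256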

+ 66 - 0
deconstruct_SQI/colpali/scripts/configs/qwen2/deprecated/train_colqwen2_hardneg_model.yaml

@@ -0,0 +1,66 @@
+config:
+  (): colpali_engine.trainer.colmodel_training.ColModelTrainingConfig
+  output_dir: !path ../../../models/colqwen2-hardneg-256-5e
+  processor:
+    (): colpali_engine.utils.transformers_wrappers.AllPurposeWrapper
+    class_to_instanciate: !ext colpali_engine.models.ColQwen2Processor
+    pretrained_model_name_or_path:  "./models/base_models/colqwen2-base"
+    # max_length: 50
+
+  model:
+    (): colpali_engine.utils.transformers_wrappers.AllPurposeWrapper
+    class_to_instanciate: !ext colpali_engine.models.ColQwen2
+    pretrained_model_name_or_path: "./models/base_models/colqwen2-base"
+    torch_dtype:  !ext torch.bfloat16
+    use_cache: false
+    attn_implementation: "flash_attention_2"
+  #    device_map: "auto"
+#    quantization_config:
+#      (): transformers.BitsAndBytesConfig
+#      load_in_4bit: true
+#      bnb_4bit_quant_type: "nf4"
+#      bnb_4bit_compute_dtype:  "bfloat16"
+#      bnb_4bit_use_double_quant: true
+
+  dataset_loading_func: !ext colpali_engine.utils.dataset_transformation.load_train_set_ir_negs
+  eval_dataset_loader: !import ../data/test_data.yaml
+
+  # max_length: 50
+  run_eval: true
+  
+  loss_func:
+    (): colpali_engine.loss.late_interaction_losses.ColbertPairwiseNegativeCELoss
+    in_batch_term_weight: 0.5
+  tr_args:
+    (): transformers.training_args.TrainingArguments
+    output_dir: null
+    overwrite_output_dir: true
+    num_train_epochs: 5
+    per_device_train_batch_size: 64
+    gradient_checkpointing: true
+    gradient_checkpointing_kwargs: {"use_reentrant": false}
+    # 6 x 8 gpus = 48 batch size
+    # gradient_accumulation_steps: 4
+    per_device_eval_batch_size: 64
+    eval_strategy: "steps"
+    dataloader_num_workers: 8
+    # bf16: true
+    save_steps: 500
+    logging_steps: 10
+    eval_steps: 100
+    warmup_steps: 100
+    learning_rate: 2e-4
+    save_total_limit: 1
+    # resume_from_checkpoint: true
+    # optim: "paged_adamw_8bit"  peft_config:
+  peft_config:
+    (): peft.LoraConfig
+    r: 32
+    lora_alpha: 32
+    lora_dropout: 0.1
+    init_lora_weights: "gaussian"
+    bias: "none"
+    task_type: "FEATURE_EXTRACTION"
+    target_modules: '(.*(model).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$|.*(custom_text_proj).*$)'
+    # target_modules: '(.*(language_model).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$|.*(custom_text_proj).*$)'
+
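The `in_batch_term_weight: 0.5` knob suggests the hard-negative term and the in-batch contrastive term are blended; the snippet below is only a schematic of such a blend, and the actual weighting inside ColbertPairwiseNegativeCELoss may differ:

# Schematic only (assumption): one way a 0.5 in_batch_term_weight could combine two loss terms.
import torch

def blend_terms(hard_negative_term: torch.Tensor,
                in_batch_term: torch.Tensor,
                in_batch_term_weight: float = 0.5) -> torch.Tensor:
    return (1.0 - in_batch_term_weight) * hard_negative_term + in_batch_term_weight * in_batch_term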

+ 65 - 0
deconstruct_SQI/colpali/scripts/configs/qwen2/deprecated/train_colqwen2_wikiss_model.yaml

@@ -0,0 +1,65 @@
+config:
+  (): colpali_engine.trainer.colmodel_training.ColModelTrainingConfig
+  output_dir: !path ../../../models/colqwen2-wikiss-ba64
+  processor:
+    (): colpali_engine.utils.transformers_wrappers.AllPurposeWrapper
+    class_to_instanciate: !ext colpali_engine.models.ColQwen2Processor
+    pretrained_model_name_or_path:  "./models/colqwen2_base" # "./models/paligemma-3b-mix-448"
+    # num_image_tokens: 2048
+    # max_length: 50
+
+  model:
+    (): colpali_engine.utils.transformers_wrappers.AllPurposeWrapper
+    class_to_instanciate: !ext colpali_engine.models.ColQwen2
+    pretrained_model_name_or_path: "./models/colqwen2_base"
+    torch_dtype:  !ext torch.bfloat16
+    use_cache: false
+    attn_implementation: "flash_attention_2"
+#    device_map: "auto"
+#    quantization_config:
+#      (): transformers.BitsAndBytesConfig
+#      load_in_4bit: true
+#      bnb_4bit_quant_type: "nf4"
+#      bnb_4bit_compute_dtype:  "bfloat16"
+#      bnb_4bit_use_double_quant: true
+
+  dataset_loading_func: !ext colpali_engine.utils.dataset_transformation.load_wikiss
+  eval_dataset_loader: !import ../data/test_data.yaml
+
+  # max_length: 50
+  run_eval: true
+  
+  loss_func:
+    (): colpali_engine.loss.late_interaction_losses.ColbertPairwiseNegativeCELoss
+  tr_args:
+    (): transformers.training_args.TrainingArguments
+    output_dir: null
+    overwrite_output_dir: true
+    num_train_epochs: 1
+    per_device_train_batch_size: 16
+    gradient_checkpointing: true
+    # 6 x 8 gpus = 48 batch size
+    # gradient_accumulation_steps: 4
+    per_device_eval_batch_size: 16
+    eval_strategy: "steps"
+    dataloader_num_workers: 16
+    # bf16: true
+    save_steps: 500
+    logging_steps: 10
+    eval_steps: 500
+    warmup_steps: 1000
+    learning_rate: 5e-4
+    save_total_limit: 1
+    # resume_from_checkpoint: true
+    # optim: "paged_adamw_8bit"
+  peft_config:
+    (): peft.LoraConfig
+    r: 32
+    lora_alpha: 32
+    lora_dropout: 0.1
+    init_lora_weights: "gaussian"
+    bias: "none"
+    task_type: "FEATURE_EXTRACTION"
+    target_modules: '(.*(model).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$|.*(custom_text_proj).*$)'
+    # target_modules: '(.*(language_model).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$|.*(custom_text_proj).*$)'
+

+ 68 - 0
deconstruct_SQI/colpali/scripts/configs/qwen2/train_biqwen2_hardneg_model.py

@@ -0,0 +1,68 @@
+import os
+import shutil
+from pathlib import Path
+
+import torch
+from peft import LoraConfig
+from transformers import TrainingArguments
+
+from colpali_engine.loss.bi_encoder_losses import BiEncoderLoss
+from colpali_engine.models import BiQwen2, BiQwen2Processor
+from colpali_engine.trainer.colmodel_training import ColModelTraining, ColModelTrainingConfig
+from colpali_engine.utils.dataset_transformation import load_train_set
+
+config = ColModelTrainingConfig(
+    output_dir="./models/biqwen2-hardneg-5e-0304",
+    processor=BiQwen2Processor.from_pretrained(
+        pretrained_model_name_or_path="./models/base_models/colqwen2-base",
+    ),
+    model=BiQwen2.from_pretrained(
+        pretrained_model_name_or_path="./models/base_models/colqwen2-base",
+        torch_dtype=torch.bfloat16,
+        use_cache=False,
+        attn_implementation="flash_attention_2",
+    ),
+    dataset_loading_func=load_train_set,  # load_train_set_ir_negs,
+    eval_dataset_loader=None,
+    run_eval=True,
+    loss_func=BiEncoderLoss(),  # BiNegativeCELoss(in_batch_term_weight=0.5),
+    tr_args=TrainingArguments(
+        output_dir=None,
+        overwrite_output_dir=True,
+        num_train_epochs=5,
+        per_device_train_batch_size=64,
+        gradient_checkpointing=True,
+        gradient_checkpointing_kwargs={"use_reentrant": False},
+        per_device_eval_batch_size=16,
+        eval_strategy="steps",
+        dataloader_num_workers=2,
+        save_steps=500,
+        logging_steps=10,
+        eval_steps=100,
+        warmup_steps=100,
+        learning_rate=2e-4,
+        save_total_limit=1,
+    ),
+    peft_config=LoraConfig(
+        r=32,
+        lora_alpha=32,
+        lora_dropout=0.1,
+        init_lora_weights="gaussian",
+        bias="none",
+        task_type="FEATURE_EXTRACTION",
+        target_modules="(.*(model).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$)",
+    ),
+)
+
+
+if __name__ == "__main__":
+    # ensure output_dir exists
+    os.makedirs(config.output_dir, exist_ok=True)
+    # version this script by copying it into the output dir
+    current_script = Path(__file__)
+    shutil.copy(current_script, Path(config.output_dir) / current_script.name)
+
+    training_app = ColModelTraining(config)
+
+    training_app.train()
+    training_app.save()
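The commented-out alternatives in this script point to a hard-negative variant (the companion YAML config below uses the same pieces). As a sketch, the switch amounts to passing a different dataset loader and loss when building `ColModelTrainingConfig`:

# Hard-negative variant sketched in the comments above; pass these kwargs instead
# when constructing ColModelTrainingConfig.
from colpali_engine.loss.bi_encoder_losses import BiNegativeCELoss
from colpali_engine.utils.dataset_transformation import load_train_set_ir_negs

hardneg_kwargs = dict(
    dataset_loading_func=load_train_set_ir_negs,           # IR-mined negatives instead of load_train_set
    loss_func=BiNegativeCELoss(in_batch_term_weight=0.5),  # negative-aware loss with a 0.5 in-batch term
)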

+ 66 - 0
deconstruct_SQI/colpali/scripts/configs/qwen2/train_biqwen2_hardneg_model.yaml

@@ -0,0 +1,66 @@
+config:
+  (): colpali_engine.trainer.colmodel_training.ColModelTrainingConfig
+  output_dir: !path ../../../models/biqwen2-hardneg-5e-0304
+  processor:
+    (): colpali_engine.utils.transformers_wrappers.AllPurposeWrapper
+    class_to_instanciate: !ext colpali_engine.models.BiQwen2Processor
+    pretrained_model_name_or_path:  "./models/base_models/colqwen2-base"
+    # max_length: 50
+
+  model:
+    (): colpali_engine.utils.transformers_wrappers.AllPurposeWrapper
+    class_to_instanciate: !ext colpali_engine.models.BiQwen2
+    pretrained_model_name_or_path: "./models/base_models/colqwen2-base"
+    torch_dtype:  !ext torch.bfloat16
+    use_cache: false
+    attn_implementation: "flash_attention_2"
+#    device_map: "auto"
+#    quantization_config:
+#      (): transformers.BitsAndBytesConfig
+#      load_in_4bit: true
+#      bnb_4bit_quant_type: "nf4"
+#      bnb_4bit_compute_dtype:  "bfloat16"
+#      bnb_4bit_use_double_quant: true
+
+  train_dataset: 
+    (): colpali_engine.utils.dataset_transformation.load_train_set_ir_negs
+  eval_dataset: !import ../data/test_data.yaml
+
+  # max_length: 50
+  run_eval: true
+  
+  loss_func:
+    (): colpali_engine.loss.bi_encoder_losses.BiNegativeCELoss
+    in_batch_term_weight: 0.5
+  tr_args:
+    (): transformers.training_args.TrainingArguments
+    output_dir: null
+    overwrite_output_dir: true
+    num_train_epochs: 5
+    per_device_train_batch_size: 64
+    gradient_checkpointing: true
+    gradient_checkpointing_kwargs: { "use_reentrant": false }
+    # 6 x 8 gpus = 48 batch size
+    # gradient_accumulation_steps: 4
+    per_device_eval_batch_size: 64
+    eval_strategy: "steps"
+    dataloader_num_workers: 8
+    # bf16: true
+    save_steps: 500
+    logging_steps: 10
+    eval_steps: 100
+    warmup_steps: 100
+    learning_rate: 2e-4
+    save_total_limit: 1
+    # optim: "paged_adamw_8bit"
+  peft_config:
+    (): peft.LoraConfig
+    r: 32
+    lora_alpha: 32
+    lora_dropout: 0.1
+    init_lora_weights: "gaussian"
+    bias: "none"
+    task_type: "FEATURE_EXTRACTION"
+    target_modules: '(.*(model).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$)'
+    # target_modules: '(.*(language_model).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$|.*(custom_text_proj).*$)'
+

+ 62 - 0
deconstruct_SQI/colpali/scripts/configs/qwen2/train_biqwen2_model.yaml

@@ -0,0 +1,62 @@
+config:
+  (): colpali_engine.trainer.colmodel_training.ColModelTrainingConfig
+  output_dir: !path ../../../models/biqwen2-ba256-5e-0304
+  processor:
+    (): colpali_engine.utils.transformers_wrappers.AllPurposeWrapper
+    class_to_instanciate: !ext colpali_engine.models.BiQwen2Processor
+    pretrained_model_name_or_path:  "./models/base_models/colqwen2-base"
+    max_num_visual_tokens: 1024
+
+  model:
+    (): colpali_engine.utils.transformers_wrappers.AllPurposeWrapper
+    class_to_instanciate: !ext colpali_engine.models.BiQwen2
+    pretrained_model_name_or_path: "./models/base_models/colqwen2-base"
+    torch_dtype:  !ext torch.bfloat16
+    use_cache: false
+    attn_implementation: "flash_attention_2"
+
+
+  train_dataset: 
+    (): colpali_engine.utils.dataset_transformation.load_train_set
+  eval_dataset: !import ../data/test_data.yaml
+
+  # max_length: 50
+  run_eval: true
+  
+  loss_func:
+    (): colpali_engine.loss.bi_encoder_losses.BiEncoderLoss
+  tr_args:
+    (): transformers.training_args.TrainingArguments
+    output_dir: null
+    overwrite_output_dir: true
+    num_train_epochs: 5
+    per_device_train_batch_size: 64
+    gradient_checkpointing: true
+    gradient_checkpointing_kwargs: { "use_reentrant": false }
+    # 6 x 8 gpus = 48 batch size
+    # gradient_accumulation_steps: 4
+    per_device_eval_batch_size: 8
+    eval_strategy: "steps"
+    dataloader_num_workers: 4
+    # bf16: true
+    save_steps: 500
+    logging_steps: 10
+    eval_steps: 100
+    warmup_steps: 100
+    learning_rate: 2e-4
+    save_total_limit: 1
+    resume_from_checkpoint: false
+    report_to: "wandb"
+
+    # optim: "paged_adamw_8bit"
+  peft_config:
+    (): peft.LoraConfig
+    r: 32
+    lora_alpha: 32
+    lora_dropout: 0.1
+    init_lora_weights: "gaussian"
+    bias: "none"
+    task_type: "FEATURE_EXTRACTION"
+    target_modules: '(.*(model).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$)'
+    # target_modules: '(.*(language_model).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$|.*(custom_text_proj).*$)'
+
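The `max_num_visual_tokens: 1024` option is forwarded to the processor and presumably caps the number of image patch tokens per page. Expressed directly in Python (mirroring the `.py` configs in this diff; the checkpoint path is whatever local colqwen2 base copy these configs assume), the same processor would be built as:

from colpali_engine.models import BiQwen2Processor

processor = BiQwen2Processor.from_pretrained(
    "./models/base_models/colqwen2-base",
    max_num_visual_tokens=1024,  # assumed upper bound on image patch tokens per document page
)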

+ 100 - 0
deconstruct_SQI/colpali/scripts/configs/qwen2/train_colqwen25_model.py

@@ -0,0 +1,100 @@
+import argparse
+import shutil
+from pathlib import Path
+
+import torch
+from datasets import load_dataset
+from peft import LoraConfig
+from transformers import TrainingArguments
+
+from colpali_engine.data.dataset import ColPaliEngineDataset
+from colpali_engine.loss.late_interaction_losses import ColbertLoss, ColbertPairwiseCELoss
+from colpali_engine.models import ColQwen2_5, ColQwen2_5_Processor
+from colpali_engine.trainer.colmodel_torch_training import ColModelTorchTraining
+from colpali_engine.trainer.colmodel_training import ColModelTraining, ColModelTrainingConfig
+from colpali_engine.utils.dataset_transformation import load_train_set
+
+
+def parse_args():
+    p = argparse.ArgumentParser()
+    p.add_argument("--output-dir", type=str, required=True, help="where to write model + script copy")
+    p.add_argument("--lr", type=float, default=2e-4, help="learning rate")
+    p.add_argument("--tau", type=float, default=0.02, help="temperature for loss function")
+    p.add_argument("--trainer", type=str, default="hf", choices=["torch", "hf"], help="trainer to use")
+    p.add_argument("--loss", type=str, default="ce", choices=["ce", "pairwise"], help="loss function to use")
+    p.add_argument("--peft", action="store_true", help="use PEFT for training")
+    return p.parse_args()
+
+
+if __name__ == "__main__":
+    args = parse_args()
+
+    if args.loss == "ce":
+        loss_func = ColbertLoss(
+            temperature=args.tau,
+            normalize_scores=True,
+            use_smooth_max=False,
+            pos_aware_negative_filtering=False,
+        )
+    elif args.loss == "pairwise":
+        loss_func = ColbertPairwiseCELoss(
+            normalize_scores=False,
+        )
+    else:
+        raise ValueError(f"Unknown loss function: {args.loss}")
+
+    config = ColModelTrainingConfig(
+        output_dir=args.output_dir,
+        processor=ColQwen2_5_Processor.from_pretrained(
+            pretrained_model_name_or_path="./models/base_models/colqwen2.5-base",
+            max_num_visual_tokens=768,
+        ),
+        model=ColQwen2_5.from_pretrained(
+            pretrained_model_name_or_path="./models/base_models/colqwen2.5-base",
+            torch_dtype=torch.bfloat16,
+            use_cache=False,
+            attn_implementation="flash_attention_2",
+        ),
+        train_dataset=load_train_set(),
+        eval_dataset=ColPaliEngineDataset(
+            load_dataset("./data_dir/colpali_train_set", split="test"), pos_target_column_name="image"
+        ),
+        run_eval=True,
+        loss_func=loss_func,
+        tr_args=TrainingArguments(
+            output_dir=None,
+            overwrite_output_dir=True,
+            num_train_epochs=5,
+            per_device_train_batch_size=64,
+            gradient_checkpointing=True,
+            gradient_checkpointing_kwargs={"use_reentrant": False},
+            per_device_eval_batch_size=16,
+            eval_strategy="steps",
+            dataloader_num_workers=8,
+            save_steps=500,
+            logging_steps=10,
+            eval_steps=100,
+            warmup_steps=100,
+            learning_rate=args.lr,
+            save_total_limit=1,
+        ),
+        peft_config=LoraConfig(
+            r=32,
+            lora_alpha=32,
+            lora_dropout=0.1,
+            init_lora_weights="gaussian",
+            bias="none",
+            task_type="FEATURE_EXTRACTION",
+            target_modules="(.*(model)(?!.*visual).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$|.*(custom_text_proj).*$)",
+        )
+        if args.peft
+        else None,
+    )
+
+    # make sure output_dir exists and copy script for provenance
+    Path(config.output_dir).mkdir(parents=True, exist_ok=True)
+    shutil.copy(Path(__file__), Path(config.output_dir) / Path(__file__).name)
+
+    trainer = ColModelTraining(config) if args.trainer == "hf" else ColModelTorchTraining(config)
+    trainer.train()
+    trainer.save()
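The `target_modules` regex above uses a negative lookahead so that LoRA is applied to the language-model projection layers and `custom_text_proj` while the visual tower is left untouched. A quick standalone check (module names below are illustrative, not taken from the actual model):

import re

pattern = re.compile(
    r"(.*(model)(?!.*visual).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$"
    r"|.*(custom_text_proj).*$)"
)

print(bool(pattern.match("model.layers.0.mlp.down_proj")))       # True: language-model MLP projection
print(bool(pattern.match("model.visual.blocks.0.attn.q_proj")))  # False: blocked by the (?!.*visual) lookahead
print(bool(pattern.match("custom_text_proj")))                   # True: matched by the second alternative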

Some files were not shown because too many files changed in this diff.