import copy
import os
from typing import Any, Dict, Optional, Union

from transformers import AutoConfig
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging

logger = logging.get_logger(__name__)

DEFAULT_TEXT_MODEL_NAME = "jhu-clsp/ettin-encoder-150m"
DEFAULT_VISION_MODEL_NAME = "google/siglip2-base-patch16-512"


def collect_arg_in_candidates(config, candidates, default=None) -> Any:
    """Gets the first available argument in a config given a list of candidate names.

    Each candidate is looked up first as an attribute of `config` and then, if
    `config` supports membership tests, as a key of `config`.

    Args:
        config: Object or mapping to read the value from.
        candidates: Candidate names, tried in order.
        default: Value returned when no candidate matches. `None` means
            "no default", in which case a `ValueError` is raised instead.

    Returns:
        The value of the first candidate found, or `default`.

    Raises:
        ValueError: If no candidate matches and `default` is `None`.
    """
    for name in candidates:
        if hasattr(config, name):
            return getattr(config, name)
        # Plain objects without `__contains__`/`__iter__` raise TypeError on
        # `in`; treat that as "not present" so later candidates are still tried.
        try:
            present = name in config
        except TypeError:
            present = False
        if present:
            return config[name]
    if default is not None:
        return default
    raise ValueError(f"No matching arguments found in candidates. Candidates: {candidates}, Config: {config}")


class ModernVBertTextConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`ModernBERT`]. It is used to instantiate a
    ModernBERT model according to the specified arguments, defining the model architecture. Instantiating a
    configuration with the defaults will yield a similar configuration to that of the
    [jhu-clsp/ettin-encoder-150m](https://huggingface.co/jhu-clsp/ettin-encoder-150m) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        text_model_name (`str`, optional, defaults to `DEFAULT_TEXT_MODEL_NAME`):
            Name of the base text model this configuration refers to.
        hidden_size (`int`, optional, defaults to 768):
            In `from_base_model`, read from `hidden_size` or `embed_dim` of the base config.
        num_hidden_layers (`int`, optional, defaults to 22):
            In `from_base_model`, read from `num_hidden_layers` or `num_hidden_blocks` of the base config.
        intermediate_size (`int`, optional, defaults to 1152):
            In `from_base_model`, read from `intermediate_size` or `mlp_dim` of the base config.
        mlp_bias (`bool`, optional, defaults to `False`):
            In `from_base_model`, read from `mlp_bias` or `mlp_hidden_bias` of the base config.
        vocab_size (`int`, optional, defaults to 50368):
            In `from_base_model`, read from `vocab_size` of the base config.
    """

    model_type = "modernvbert_text"

    def __init__(
        self,
        text_model_name=DEFAULT_TEXT_MODEL_NAME,
        hidden_size=768,
        num_hidden_layers=22,
        intermediate_size=1152,
        mlp_bias=False,
        vocab_size=50368,
        **kwargs,
    ):
        super().__init__(
            text_model_name=text_model_name,
            hidden_size=hidden_size,
            num_hidden_layers=num_hidden_layers,
            intermediate_size=intermediate_size,
            mlp_bias=mlp_bias,
            vocab_size=vocab_size,
            **kwargs,
        )

    @classmethod
    def from_base_model(
        cls,
        text_model_name=DEFAULT_TEXT_MODEL_NAME,
        **kwargs,
    ):
        """Build a text config from the configuration of a base model.

        Loads the configuration of `text_model_name` with `AutoConfig`, descends into its `text_config`
        attribute when there is one, and copies the relevant hyper-parameters.

        Args:
            text_model_name: Model identifier or path passed to `AutoConfig.from_pretrained`.
            **kwargs: Extra keyword arguments forwarded to the class constructor.

        Returns:
            A new instance of `cls`.
        """
        # NOTE(review): trust_remote_code=True lets the referenced repository run its own
        # configuration code; confirm this is intended for arbitrary model names.
        text_config = AutoConfig.from_pretrained(text_model_name, trust_remote_code=True)
        if hasattr(text_config, "text_config"):
            text_config = text_config.text_config

        hidden_size = collect_arg_in_candidates(text_config, ["hidden_size", "embed_dim"])
        num_hidden_layers = collect_arg_in_candidates(text_config, ["num_hidden_layers", "num_hidden_blocks"])
        intermediate_size = collect_arg_in_candidates(text_config, ["intermediate_size", "mlp_dim"])
        mlp_bias = collect_arg_in_candidates(text_config, ["mlp_bias", "mlp_hidden_bias"], default=False)
        vocab_size = collect_arg_in_candidates(text_config, ["vocab_size"])

        return cls(
            text_model_name=text_model_name,
            hidden_size=hidden_size,
            num_hidden_layers=num_hidden_layers,
            intermediate_size=intermediate_size,
            mlp_bias=mlp_bias,
            vocab_size=vocab_size,
            **kwargs,
        )


class ModernVBertVisionConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`SigLIP`]. It is used to instantiate the vision
    encoder part of the ModernVBERT model according to the specified arguments, defining the model architecture.
    Instantiating a configuration with the defaults will yield a similar configuration to that of the SigLIP.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        vision_model_name (`str`, optional, defaults to `DEFAULT_VISION_MODEL_NAME`):
            Name of the base vision model this configuration refers to.
        embed_dim (`int`, optional, defaults to 768):
            In `from_base_model`, read from `embed_dim` or `hidden_size` of the base config. Also reachable as
            `hidden_size` through `attribute_map`.
        image_size (`int`, optional, defaults to 512):
            In `from_base_model`, read from `image_size` or `img_size` of the base config.
        patch_size (`int`, optional, defaults to 16):
            In `from_base_model`, read from `patch_size` of the base config.
        num_hidden_layers (`int`, optional, defaults to 12):
            In `from_base_model`, read from `num_hidden_layers` or `num_hidden_blocks` of the base config.
        intermediate_size (`int`, optional, defaults to 3072):
            In `from_base_model`, read from `intermediate_size` or `mlp_dim` of the base config.
    """

    model_type = "modernvbert_vision"

    attribute_map = {
        "hidden_size": "embed_dim",
    }

    def __init__(
        self,
        vision_model_name=DEFAULT_VISION_MODEL_NAME,
        embed_dim=768,
        image_size=512,
        patch_size=16,
        num_hidden_layers=12,
        intermediate_size=3072,
        **kwargs,
    ):
        super().__init__(
            vision_model_name=vision_model_name,
            embed_dim=embed_dim,
            image_size=image_size,
            patch_size=patch_size,
            num_hidden_layers=num_hidden_layers,
            intermediate_size=intermediate_size,
            **kwargs,
        )

    @classmethod
    def from_base_model(
        cls,
        vision_model_name=DEFAULT_VISION_MODEL_NAME,
        **kwargs,
    ):
        """Build a vision config from the configuration of a base model.

        Loads the configuration of `vision_model_name` with `AutoConfig`, descends into its `vision_config`
        attribute when there is one, and copies the relevant hyper-parameters.

        Args:
            vision_model_name: Model identifier or path passed to `AutoConfig.from_pretrained`.
            **kwargs: Extra keyword arguments forwarded to the class constructor.

        Returns:
            A new instance of `cls`.
        """
        # NOTE(review): trust_remote_code=True lets the referenced repository run its own
        # configuration code; confirm this is intended for arbitrary model names.
        vision_config = AutoConfig.from_pretrained(vision_model_name, trust_remote_code=True)
        if hasattr(vision_config, "vision_config"):
            vision_config = vision_config.vision_config

        embed_dim = collect_arg_in_candidates(vision_config, ["embed_dim", "hidden_size"])
        image_size = collect_arg_in_candidates(vision_config, ["image_size", "img_size"])
        patch_size = collect_arg_in_candidates(vision_config, ["patch_size"])
        num_hidden_layers = collect_arg_in_candidates(vision_config, ["num_hidden_layers", "num_hidden_blocks"])
        intermediate_size = collect_arg_in_candidates(vision_config, ["intermediate_size", "mlp_dim"])

        return cls(
            vision_model_name=vision_model_name,
            embed_dim=embed_dim,
            image_size=image_size,
            patch_size=patch_size,
            num_hidden_layers=num_hidden_layers,
            intermediate_size=intermediate_size,
            **kwargs,
        )


class ModernVBertConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a `ModernVBert` model. It is used to
    instantiate a ModernVBert model according to the specified arguments and defines the model architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs.
    See the documentation for [`PretrainedConfig`] for more details.

    Args:
        text_config (`PretrainedConfig` or `dict`, optional):
            Custom text config or a dict with a `text_model_name` key for the text encoder. If `None`, the
            default text backbone defined by `DEFAULT_TEXT_MODEL_NAME` is used.
        vision_config (`PretrainedConfig` or `dict`, optional):
            Custom vision config or a dict with a `vision_model_name` key for the vision encoder. If `None`, the
            default vision backbone defined by `DEFAULT_VISION_MODEL_NAME` is used.
        image_token_id (`int`, optional, defaults to 50407):
            Token id reserved for image tokens inserted into the text stream.
        vocab_size (`int`, optional, defaults to 50368):
            Vocabulary size used by the text embeddings.
        use_cache (`bool`, optional, defaults to `True`):
            Whether to cache key/value tensors for attention (relevant for decoder architectures).
        tie_word_embeddings (`bool`, optional, defaults to `False`):
            Whether to tie input token embeddings and output token embeddings.
        pixel_shuffle_factor (`int`, optional, defaults to 4):
            Scale factor used by any pixel-shuffle / upsampling operations in the vision head. Stored both as
            `pixel_shuffle_factor` and as `scale_factor`.
        additional_vocab_size (`int`, optional, defaults to 0):
            Number of extra tokens appended to the base vocabulary (useful for adapters / special tokens).
        pad_token_id (`int`, optional):
            Padding token id.
        initializer_range (`float`, optional, defaults to 0.02):
            Stddev used for weight initialization.
        freeze_config (`Any`, optional):
            Optional config describing which submodules to freeze during training.
        use_resampler (`bool`, optional, defaults to `False`):
            Whether to enable an additional resampler on visual features.
        neftune_noise_alpha (`float`, optional, defaults to 0.0):
            Alpha parameter for neftune noise injection.
        hidden_size (`int`, optional, passed through `kwargs`):
            Defaults to `text_config.hidden_size` when not given.

    Example:
    ```python
    >>> from modernvbert import ModernVBertConfig

    >>> # Initializing configuration
    >>> configuration = ModernVBertConfig()

    >>> # Initializing a model from the configuration (model class is implemented in
    >>> # `modernvbert.modeling_modernvbert`)
    >>> # from modernvbert import ModernVBertModel
    >>> # model = ModernVBertModel(configuration)

    >>> # Accessing the model configuration
    >>> # cfg = model.config
    ```"""

    model_type = "modernvbert"
    is_composition = True

    def __init__(
        self,
        text_config: Optional[Union[PretrainedConfig, Dict[str, Any]]] = None,
        vision_config: Optional[Union[PretrainedConfig, Dict[str, Any]]] = None,
        image_token_id: int = 50407,
        vocab_size=50368,
        use_cache=True,
        tie_word_embeddings=False,
        freeze_config=None,
        pad_token_id=None,
        initializer_range=0.02,
        pixel_shuffle_factor=4,
        use_resampler=False,
        additional_vocab_size=0,
        neftune_noise_alpha=0.0,
        **kwargs,
    ):
        self.image_token_id = image_token_id
        self.use_cache = use_cache
        self.tie_word_embeddings = tie_word_embeddings
        # Kept alongside `pixel_shuffle_factor` below; both names hold the same value.
        self.scale_factor = pixel_shuffle_factor
        self.additional_vocab_size = additional_vocab_size

        if text_config is None:
            # Derive the sub-config from the default backbone. Passing the loaded AutoConfig
            # object straight to the constructor would bind it to `text_model_name`.
            text_config = ModernVBertTextConfig.from_base_model(DEFAULT_TEXT_MODEL_NAME)
        elif isinstance(text_config, dict):
            text_config = ModernVBertTextConfig.from_dict(text_config)
        self.text_config = text_config

        if vision_config is None:
            # Same reasoning as for the text backbone above.
            vision_config = ModernVBertVisionConfig.from_base_model(DEFAULT_VISION_MODEL_NAME)
        elif isinstance(vision_config, dict):
            vision_config = ModernVBertVisionConfig.from_dict(vision_config)
        self.vision_config = vision_config

        self.freeze_config = freeze_config
        self.pixel_shuffle_factor = pixel_shuffle_factor
        self.use_resampler = use_resampler
        self.neftune_noise_alpha = neftune_noise_alpha
        self.initializer_range = initializer_range

        # An explicit `hidden_size` kwarg wins; otherwise mirror the text encoder's size.
        hidden_size = kwargs.pop("hidden_size", self.text_config.hidden_size)

        super().__init__(
            **kwargs,
            pad_token_id=pad_token_id,
            tie_word_embeddings=tie_word_embeddings,
            vocab_size=vocab_size,
            hidden_size=hidden_size,
        )

    def to_dict(self):
        """Return a deep copy of the instance attributes as a dict.

        The nested `vision_config` and `text_config` entries are replaced by their own `to_dict()` output,
        and `model_type` is set from the class attribute.
        """
        output = copy.deepcopy(self.__dict__)
        output["model_type"] = self.__class__.model_type
        output["vision_config"] = self.vision_config.to_dict()
        output["text_config"] = self.text_config.to_dict()
        return output

    @classmethod
    def from_pretrained_models(
        cls,
        text_model_name: Union[str, os.PathLike],
        vision_model_name: Union[str, os.PathLike],
        **kwargs,
    ) -> "ModernVBertConfig":
        """Build a config from the names of a base text model and a base vision model.

        Args:
            text_model_name: Identifier or path given to `ModernVBertTextConfig.from_base_model`.
            vision_model_name: Identifier or path given to `ModernVBertVisionConfig.from_base_model`.
            **kwargs: Extra keyword arguments forwarded to the class constructor.

        Returns:
            A new instance of `cls` wrapping the two derived sub-configs.
        """
        text_model_config = ModernVBertTextConfig.from_base_model(text_model_name)
        vision_model_config = ModernVBertVisionConfig.from_base_model(vision_model_name)
        return cls(
            text_config=text_model_config,
            vision_config=vision_model_config,
            **kwargs,
        )