# configuration_modernvbert.py
  1. import copy
  2. import os
  3. from typing import Any, Dict, Union
  4. from transformers import AutoConfig
  5. from transformers.configuration_utils import PretrainedConfig
  6. from transformers.utils import logging
  7. logger = logging.get_logger(__name__)
  8. DEFAULT_TEXT_MODEL_NAME = "jhu-clsp/ettin-encoder-150m"
  9. DEFAULT_VISION_MODEL_NAME = "google/siglip2-base-patch16-512"
  10. def collect_arg_in_candidates(config, candidates, default=None) -> Any:
  11. """Gets the first available argument in a config given a list of candidate names."""
  12. for c in candidates:
  13. if hasattr(config, c):
  14. return getattr(config, c)
  15. elif c in config:
  16. return config[c]
  17. if default is not None:
  18. return default
  19. raise ValueError(f"No matching arguments found in candidates. Candidates: {candidates}, Config: {config}")
  20. class ModernVBertTextConfig(PretrainedConfig):
  21. r"""
  22. This is the configuration class to store the configuration of a [`ModernBERT`].
  23. It is used to instantiate an ModernBERT
  24. model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
  25. defaults will yield a similar configuration to that of the
  26. [jhu-clsp/ettin-encoder-150m](https://huggingface.co/jhu-clsp/ettin-encoder-150m) architecture.
  27. Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
  28. documentation from [`PretrainedConfig`] for more information.
  29. """
  30. model_type = "modernvbert_text"
  31. def __init__(
  32. self,
  33. text_model_name=DEFAULT_TEXT_MODEL_NAME,
  34. hidden_size=768,
  35. num_hidden_layers=22,
  36. intermediate_size=1152,
  37. mlp_bias=False,
  38. vocab_size=50368,
  39. **kwargs,
  40. ):
  41. super().__init__(
  42. text_model_name=text_model_name,
  43. hidden_size=hidden_size,
  44. num_hidden_layers=num_hidden_layers,
  45. intermediate_size=intermediate_size,
  46. mlp_bias=mlp_bias,
  47. vocab_size=vocab_size,
  48. **kwargs,
  49. )
  50. @classmethod
  51. def from_base_model(
  52. cls,
  53. text_model_name=DEFAULT_TEXT_MODEL_NAME,
  54. **kwargs,
  55. ):
  56. text_config = AutoConfig.from_pretrained(text_model_name, trust_remote_code=True)
  57. if hasattr(text_config, "text_config"):
  58. text_config = text_config.text_config
  59. hidden_size = collect_arg_in_candidates(text_config, ["hidden_size", "embed_dim"])
  60. num_hidden_layers = collect_arg_in_candidates(text_config, ["num_hidden_layers", "num_hidden_blocks"])
  61. intermediate_size = collect_arg_in_candidates(text_config, ["intermediate_size", "mlp_dim"])
  62. mlp_bias = collect_arg_in_candidates(text_config, ["mlp_bias", "mlp_hidden_bias"], default=False)
  63. vocab_size = collect_arg_in_candidates(text_config, ["vocab_size"])
  64. return cls(
  65. text_model_name=text_model_name,
  66. hidden_size=hidden_size,
  67. num_hidden_layers=num_hidden_layers,
  68. intermediate_size=intermediate_size,
  69. mlp_bias=mlp_bias,
  70. vocab_size=vocab_size,
  71. **kwargs,
  72. )
  73. class ModernVBertVisionConfig(PretrainedConfig):
  74. r"""
  75. This is the configuration class to store the configuration of a [`SigLIP`]. It is used to instantiate
  76. the vision encoder part of the ModernVBERT.
  77. model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
  78. defaults will yield a similar configuration to that of the SigLIP.
  79. Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
  80. documentation from [`PretrainedConfig`] for more information.
  81. """
  82. model_type = "modernvbert_vision"
  83. attribute_map = {
  84. "hidden_size": "embed_dim",
  85. }
  86. def __init__(
  87. self,
  88. vision_model_name=DEFAULT_VISION_MODEL_NAME,
  89. embed_dim=768,
  90. image_size=512,
  91. patch_size=16,
  92. num_hidden_layers=12,
  93. intermediate_size=3072,
  94. **kwargs,
  95. ):
  96. super().__init__(
  97. vision_model_name=vision_model_name,
  98. embed_dim=embed_dim,
  99. image_size=image_size,
  100. patch_size=patch_size,
  101. num_hidden_layers=num_hidden_layers,
  102. intermediate_size=intermediate_size,
  103. **kwargs,
  104. )
  105. @classmethod
  106. def from_base_model(
  107. cls,
  108. vision_model_name=DEFAULT_VISION_MODEL_NAME,
  109. **kwargs,
  110. ):
  111. vision_config = AutoConfig.from_pretrained(vision_model_name, trust_remote_code=True)
  112. if hasattr(vision_config, "vision_config"):
  113. vision_config = vision_config.vision_config
  114. embed_dim = collect_arg_in_candidates(vision_config, ["embed_dim", "hidden_size"])
  115. image_size = collect_arg_in_candidates(vision_config, ["image_size", "img_size"])
  116. patch_size = collect_arg_in_candidates(vision_config, ["patch_size"])
  117. num_hidden_layers = collect_arg_in_candidates(vision_config, ["num_hidden_layers", "num_hidden_blocks"])
  118. intermediate_size = collect_arg_in_candidates(vision_config, ["intermediate_size", "mlp_dim"])
  119. return cls(
  120. vision_model_name=vision_model_name,
  121. embed_dim=embed_dim,
  122. image_size=image_size,
  123. patch_size=patch_size,
  124. num_hidden_layers=num_hidden_layers,
  125. intermediate_size=intermediate_size,
  126. **kwargs,
  127. )
  128. class ModernVBertConfig(PretrainedConfig):
  129. r"""
  130. This is the configuration class to store the configuration of a `ModernVBert` model. It is used to
  131. instantiate a ModernVBert model according to the specified arguments and defines the model architecture.
  132. Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs.
  133. See the documentation for [`PretrainedConfig`] for more details.
  134. Args:
  135. text_config (`PretrainedConfig` or `dict`, optional):
  136. Custom text config or a dict with a `text_model_name` key for the text encoder. If `None`, the
  137. default text backbone defined by `DEFAULT_TEXT_MODEL_NAME` is used.
  138. vision_config (`PretrainedConfig` or `dict`, optional):
  139. Custom vision config or a dict with a `vision_model_name` key for the vision encoder. If `None`, the
  140. default vision backbone defined by `DEFAULT_VISION_MODEL_NAME` is used.
  141. image_token_id (`int`, optional, defaults to 128257):
  142. Token id reserved for image tokens inserted into the text stream.
  143. vocab_size (`int`, optional, defaults to 128256):
  144. Vocabulary size used by the text embeddings.
  145. use_cache (`bool`, optional, defaults to `True`):
  146. Whether to cache key/value tensors for attention (relevant for decoder architectures).
  147. tie_word_embeddings (`bool`, optional, defaults to `False`):
  148. Whether to tie input token embeddings and output token embeddings.
  149. pixel_shuffle_factor (`int`, optional, defaults to 4):
  150. Scale factor used by any pixel-shuffle / upsampling operations in the vision head.
  151. additional_vocab_size (`int`, optional, defaults to 0):
  152. Number of extra tokens appended to the base vocabulary (useful for adapters / special tokens).
  153. pad_token_id (`int`, optional):
  154. Padding token id.
  155. initializer_range (`float`, optional, defaults to 0.02):
  156. Stddev used for weight initialization.
  157. freeze_config (`Any`, optional):
  158. Optional config describing which submodules to freeze during training.
  159. use_resampler (`bool`, optional, defaults to `False`):
  160. Whether to enable an additional resampler on visual features.
  161. neftune_noise_alpha (`float`, optional, defaults to 0.0):
  162. Alpha parameter for neftune noise injection.
  163. Example:
  164. ```python
  165. >>> from modernvbert import ModernVBertConfig
  166. >>> # Initializing configuration
  167. >>> configuration = ModernVBertConfig()
  168. >>> # Initializing a model from the configuration (model class is implemented in
  169. >>> # `modernvbert.modeling_modernvbert`)
  170. >>> # from modernvbert import ModernVBertModel
  171. >>> # model = ModernVBertModel(configuration)
  172. >>> # Accessing the model configuration
  173. >>> # cfg = model.config
  174. ```"""
  175. model_type = "modernvbert"
  176. is_composition = True
  177. def __init__(
  178. self,
  179. text_config: Union[PretrainedConfig, Dict[str, Any]] = None,
  180. vision_config: Union[PretrainedConfig, Dict[str, Any]] = None,
  181. image_token_id: int = 50407,
  182. vocab_size=50368,
  183. use_cache=True,
  184. tie_word_embeddings=False,
  185. freeze_config=None,
  186. pad_token_id=None,
  187. initializer_range=0.02,
  188. pixel_shuffle_factor=4,
  189. use_resampler=False,
  190. additional_vocab_size=0,
  191. neftune_noise_alpha=0.0,
  192. **kwargs,
  193. ):
  194. self.image_token_id = image_token_id
  195. self.use_cache = use_cache
  196. self.tie_word_embeddings = tie_word_embeddings
  197. self.scale_factor = pixel_shuffle_factor
  198. self.additional_vocab_size = additional_vocab_size
  199. if text_config is None:
  200. base_text_config = AutoConfig.from_pretrained(DEFAULT_TEXT_MODEL_NAME, trust_remote_code=True)
  201. text_config = ModernVBertTextConfig(base_text_config)
  202. elif isinstance(text_config, dict):
  203. text_config = ModernVBertTextConfig.from_dict(text_config)
  204. self.text_config = text_config
  205. if vision_config is None:
  206. base_vision_config = AutoConfig.from_pretrained(DEFAULT_VISION_MODEL_NAME, trust_remote_code=True)
  207. vision_config = ModernVBertVisionConfig(base_vision_config)
  208. elif isinstance(vision_config, dict):
  209. vision_config = ModernVBertVisionConfig.from_dict(vision_config)
  210. self.vision_config = vision_config
  211. self.freeze_config = freeze_config
  212. self.pixel_shuffle_factor = pixel_shuffle_factor
  213. self.use_resampler = use_resampler
  214. self.neftune_noise_alpha = neftune_noise_alpha
  215. self.initializer_range = initializer_range
  216. hidden_size = kwargs.pop("hidden_size", self.text_config.hidden_size)
  217. super().__init__(
  218. **kwargs,
  219. pad_token_id=pad_token_id,
  220. tie_word_embeddings=tie_word_embeddings,
  221. vocab_size=vocab_size,
  222. hidden_size=hidden_size,
  223. )
  224. def to_dict(self):
  225. output = copy.deepcopy(self.__dict__)
  226. output["model_type"] = self.__class__.model_type
  227. output["vision_config"] = self.vision_config.to_dict()
  228. output["text_config"] = self.text_config.to_dict()
  229. return output
  230. @classmethod
  231. def from_pretrained_models(
  232. cls,
  233. text_model_name: Union[str, os.PathLike],
  234. vision_model_name: Union[str, os.PathLike],
  235. **kwargs,
  236. ) -> "PretrainedConfig":
  237. text_model_config = ModernVBertTextConfig.from_base_model(text_model_name)
  238. vision_model_config = ModernVBertVisionConfig.from_base_model(vision_model_name)
  239. return cls(
  240. text_config=text_model_config,
  241. vision_config=vision_model_config,
  242. **kwargs,
  243. )