# modeling_modernvbert.py

from dataclasses import dataclass
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F  # noqa: N812
from torch.nn import CrossEntropyLoss
from transformers import AutoConfig, AutoModel, AutoModelForMaskedLM, PreTrainedModel, logging
from transformers.modeling_outputs import BaseModelOutput
from transformers.models.bert.modeling_bert import BaseModelOutputWithPoolingAndCrossAttentions, MaskedLMOutput

from .configuration_modernvbert import ModernVBertConfig

logger = logging.get_logger(__name__)


class DecoupledEmbedding(nn.Embedding):
    # Derived from https://pytorch.org/docs/stable/_modules/torch/nn/modules/sparse.html#Embedding
    """
    Implements a decoupling of parameters to allow freezing (or not) a subset of the embeddings.
    In practice, the regular `weight` can be trained or frozen (i.e. `partially_freeze=True`), and if
    `num_additional_embeddings` > 0, then it will create `num_additional_embeddings` additional parameters that are
    always trained.
    If `num_additional_embeddings=0`, then the module defaults back to the regular behavior of `nn.Embedding`.
    """

    def __init__(
        self,
        num_embeddings,
        num_additional_embeddings,
        embedding_dim,
        partially_freeze=False,
        device=None,
        dtype=None,
        padding_idx=None,
        **kwargs,
    ) -> None:
        """
        num_additional_embeddings: int. Number of additional embeddings. Only useful when `partially_freeze=True`.
        partially_freeze: bool. If True, the regular `weight` will be frozen. The additional embedding is never frozen.

        Note: there are a lot of other parameters to initialize a standard `nn.Embedding` such as `padding_idx`,
        `max_norm` or `norm_type`. We are not supporting these.
        """
        if padding_idx is not None and padding_idx > num_embeddings:
            raise ValueError(f"padding_idx must be within num_embeddings. Got {padding_idx} and {num_embeddings}")
        super().__init__(
            num_embeddings=num_embeddings,
            embedding_dim=embedding_dim,
            device=device,
            dtype=dtype,
            padding_idx=padding_idx,
            **kwargs,
        )
        self.num_embeddings = num_embeddings
        self.num_additional_embeddings = num_additional_embeddings
        self.partially_freeze = partially_freeze

        if partially_freeze:
            self.weight.requires_grad_(False)

        if self.num_additional_embeddings > 0:
            self.additional_embedding = nn.Embedding(
                num_embeddings=num_additional_embeddings,
                embedding_dim=embedding_dim,
                device=device,
                dtype=dtype,
            )

    def forward(self, input_ids):
        """
        We have 2 embeddings with different indices: the pretrained `self.weight` and
        `self.additional_embedding.weight`, which is always trained.

        In order to make a lookup of the input ids, we:
        1. find out the indices of the entries belonging to the 2nd embedding
        2. extract those values while subtracting the size of the first embedding (num_embeddings), since the 2nd
           embedding starts from 0 and not num_embeddings
        3. perform the 2nd embedding lookup
        4. now we handle the 1st embedding: we overwrite the indices belonging to the 2nd embedding with a padding
           index
        5. perform the 1st embedding lookup
        6. finally, we overwrite the values in the 1st embedding lookup with the values of the 2nd embedding lookup

        Note: for the 1st embedding lookup we could have looked up only the low indices and not done the padding, but
        then we would have to create a new tensor and populate it with 2 tensors that are spread out across various
        indices - i.e. not a simple concat. We haven't benchmarked whether the complex case is any faster; given that
        seqlens are usually relatively short, it's probably not faster, or if faster, not by much - but it might be a
        good idea to measure.
        """
        if self.num_additional_embeddings == 0:
            return super().forward(input_ids)

        input_ids = input_ids.clone()
        additional_vocab_indices = torch.where(input_ids >= self.num_embeddings)
        input_ids_additional_vocab = input_ids[additional_vocab_indices]
        additional_embeddings = self.additional_embedding(input_ids_additional_vocab - self.num_embeddings)

        # for a successful lookup, replace these input_ids with 0; the results will be discarded anyway
        input_ids[additional_vocab_indices] = 0
        full_vector = F.embedding(input_ids, self.weight)

        # overwrite the records with high indices
        full_vector[additional_vocab_indices] = additional_embeddings
        return full_vector
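

# Usage sketch for DecoupledEmbedding (illustrative only; the sizes are assumptions, not values from any ModernVBERT
# config): with num_embeddings=8, num_additional_embeddings=2 and embedding_dim=4, ids 0-7 hit the (optionally
# frozen) pretrained table, while ids 8 and 9 are routed to the always-trainable `additional_embedding` as indices
# 0 and 1.
#
#   emb = DecoupledEmbedding(num_embeddings=8, num_additional_embeddings=2, embedding_dim=4, partially_freeze=True)
#   out = emb(torch.tensor([[1, 7, 8, 9]]))  # shape (1, 4, 4); the last two rows come from the additional table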


@dataclass
class ModernVBertBaseModelOutput(BaseModelOutput):
    """
    Base class for ModernVBERT model's outputs that may also contain a past key/values (to speed up sequential
    decoding).

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
            If `past_key_values` is used, only the last hidden-state of the sequences of shape `(batch_size, 1,
            hidden_size)` is output.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or
            when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer,
            + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when
            `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.
            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        image_hidden_states (`tuple(torch.FloatTensor)`, *optional*):
            Tuple of `torch.FloatTensor` (one for the output of the image embeddings) of shape `(batch_size,
            num_images, sequence_length, hidden_size)`.
            Image hidden-states of the model produced by the vision encoder.
    """

    last_hidden_state: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    image_hidden_states: Optional[Tuple[torch.FloatTensor]] = None


@dataclass
class ModernVBertMaskedLMOutput(MaskedLMOutput):
    """
    Base class for ModernVBERT masked language modeling outputs, extended with the image hidden states produced by
    the vision encoder.

    Args:
        loss (`torch.FloatTensor`, *optional*, returned when `labels` is provided):
            Masked language modeling (MLM) loss.
        logits (`torch.FloatTensor`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or
            when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer,
            + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when
            `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.
            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        image_hidden_states (`tuple(torch.FloatTensor)`, *optional*):
            Tuple of `torch.FloatTensor` (one for the output of the image embeddings) of shape `(batch_size,
            num_images, sequence_length, hidden_size)`.
            Image hidden-states of the model produced by the vision encoder.
    """

    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    image_hidden_states: Optional[torch.FloatTensor] = None


class ModernVBertSimpleMLP(nn.Module):
    """A simple linear projection layer to project the vision hidden states to the text hidden states."""

    def __init__(self, input_size, output_size):
        super().__init__()
        self.proj = nn.Linear(input_size, output_size, bias=False)

    def forward(self, x):
        return self.proj(x)


class ModernVBertConnector(nn.Module):
    """
    Connector module for ModernVBERT. It performs a pixel shuffle operation followed by a linear projection to match
    the text model's hidden size.
    Based on https://pytorch.org/docs/stable/generated/torch.nn.PixelShuffle.html
    """

    def __init__(self, config):
        super().__init__()
        self.scale_factor = config.pixel_shuffle_factor
        self.modality_projection = ModernVBertSimpleMLP(
            input_size=config.vision_config.hidden_size * (config.scale_factor**2),
            output_size=config.text_config.hidden_size,
        )

    def pixel_shuffle(self, x, scale_factor):
        bsz, seq, embed_dim = x.size()
        height = width = int(seq**0.5)
        x = x.view(bsz, height, width, embed_dim)
        x = x.view(bsz, height, int(width / scale_factor), embed_dim * scale_factor)
        x = x.permute(0, 2, 1, 3)
        x = x.reshape(bsz, int(width / scale_factor), int(height / scale_factor), embed_dim * (scale_factor**2))
        x = x.permute(0, 2, 1, 3)
        return x.reshape(bsz, int(seq / (scale_factor**2)), embed_dim * (scale_factor**2))

    def forward(self, image_hidden_states):
        image_hidden_states = self.pixel_shuffle(image_hidden_states, self.scale_factor)
        return self.modality_projection(image_hidden_states)
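

# Shape walkthrough for the connector (illustrative only; the concrete sizes are assumptions, not config values):
# a 512x512 image with 16x16 patches yields a (bsz, 1024, vision_hidden) sequence from the vision encoder. With a
# shuffle factor of 2, pixel_shuffle folds each 2x2 block of neighbouring patches into the channel dimension,
# producing (bsz, 256, vision_hidden * 4), which the modality projection then maps to (bsz, 256, text_hidden):
#
#   connector = ModernVBertConnector(config)              # config carries vision_config / text_config sub-configs
#   out = connector(torch.randn(2, 1024, vision_hidden))  # -> (2, 256, text_hidden)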


class ModernVBertPreTrainedModel(PreTrainedModel):
    config_class = ModernVBertConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _supports_flash_attn_2 = True
    _supports_sdpa = True

    def _init_weights(self, module):
        std = getattr(self.config, "initializer_range", 0.02)
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()


class ModernVBertModel(ModernVBertPreTrainedModel):
    def __init__(self, config: ModernVBertConfig):
        super().__init__(config)
        self.vision_model = ModernVBertModel.init_vision_model(config)
        self.connector = ModernVBertConnector(config)
        self.text_model = ModernVBertModel.init_language_model(config)
        self.image_seq_len = int(
            ((config.vision_config.image_size // config.vision_config.patch_size) ** 2) / (config.scale_factor**2)
        )
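        # Worked example (values assumed for illustration): image_size=512, patch_size=16 gives (512 // 16) ** 2 =
        # 1024 patches per image; a shuffle factor of 2 merges 2x2 patch blocks, leaving 1024 / 4 = 256 image tokens.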
        self.image_token_id = config.image_token_id
        self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
        # set the correct dtype for the vision and text models
        self.vision_model.to(self.dtype)
        self.text_model.to(self.dtype)
        self.post_init()
    @staticmethod
    def init_vision_model(config: ModernVBertConfig):
        vision_model_config = AutoConfig.from_pretrained(
            config.vision_config.vision_model_name,
            _attn_implementation=config._attn_implementation,
        )
        vision_model = AutoModel.from_config(
            vision_model_config,
            trust_remote_code=True,
        )
        return getattr(vision_model, "vision_model", vision_model)

    @staticmethod
    def init_language_model(config: ModernVBertConfig):
        text_model_config = AutoConfig.from_pretrained(
            config.text_config.text_model_name,
            _attn_implementation=config._attn_implementation,
            trust_remote_code=True,
        )
        text_model = AutoModel.from_config(text_model_config, trust_remote_code=True)
        embed_layer = DecoupledEmbedding(
            num_embeddings=text_model_config.vocab_size,
            num_additional_embeddings=config.additional_vocab_size,
            embedding_dim=config.hidden_size,
            partially_freeze=config.freeze_config["freeze_text_layers"],
            padding_idx=config.pad_token_id,
        )
        text_model.set_input_embeddings(embed_layer)
        return text_model
    def enable_input_require_grads(self):
        """
        Enables the gradients for the input embeddings.
        This is useful for LoRA when using gradient checkpointing.
        c.f. https://github.com/huggingface/peft/issues/1402#issuecomment-1913675032
        Override to set output.requires_grad = True for both the decoder's and vision model's embeddings.
        """

        def get_lowest_module(module):
            if len(list(module.children())) == 0:
                # If the module has no children, it is a leaf module (e.g., Linear, Conv2d, etc.)
                return module
            else:
                # Recursively call the function on the first child module
                return get_lowest_module(list(module.children())[0])

        def make_inputs_require_grads(module, input, output):
            output.requires_grad_(True)

        self._text_require_grads_hook = self.get_input_embeddings().register_forward_hook(make_inputs_require_grads)
        self._vision_require_grads_hook = get_lowest_module(self.vision_model).register_forward_hook(
            make_inputs_require_grads
        )
    def get_input_embeddings(self):
        return self.text_model.get_input_embeddings()

    def set_input_embeddings(self, value):
        self.text_model.set_input_embeddings(value)
    def inputs_merger(self, input_ids, inputs_embeds, image_hidden_states):
        """
        Adapted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/smolvlm/modeling_smolvlm.py

        This method merges the token embeddings with the image hidden states into a single sequence of vectors that
        is fed to the transformer LM.
        The merging happens as follows:
        - The text token sequence is:
          `tok_1 tok_2 tok_3 <fake_token_around_image> <image> <image> ... <image> <fake_token_around_image> tok_4`.
        - We get the image hidden states for the image through the vision encoder, and that hidden state, after a
          pixel shuffle operation, is projected into the text embedding space.
          We thus have a sequence of image hidden states of size (1, image_seq_len, hidden_dim), where 1 is for a
          batch size of 1 image and hidden_dim is the hidden_dim of the LM transformer.
        - The merging happens so that we obtain the following sequence:
          `vector_tok_1 vector_tok_2 vector_tok_3 vector_fake_tok_around_image
          {sequence of image_seq_len image hidden states}
          vector_fake_tok_around_image vector_tok_4`.
          That sequence is fed to the LM.
        - To fit the format of that sequence, `input_ids`, `inputs_embeds` and `attention_mask` are all adapted to
          insert the image hidden states.
        """
        _, patch_size, _ = image_hidden_states.shape
        image_mask = input_ids == self.image_token_id
        num_image_tokens = image_mask.sum(dim=1)
        if not torch.all(num_image_tokens % patch_size == 0):
            raise ValueError("Number of <image> tokens not divisible by patch_size.")
        blocks_per_sample = num_image_tokens // patch_size
        offsets = torch.nn.functional.pad(blocks_per_sample.cumsum(dim=0), (1, 0), value=0)
        block_offset = offsets[:-1]
        row_cum = image_mask.cumsum(dim=-1)
        chunk_idx = (row_cum - 1) // patch_size
        local_idx = (row_cum - 1) % patch_size
        block_idx = block_offset.unsqueeze(1) + chunk_idx
        image_embeds = torch.zeros_like(inputs_embeds)
        image_embeds[image_mask] = image_hidden_states[block_idx[image_mask], local_idx[image_mask], :]
        return torch.where(image_mask.unsqueeze(-1), image_embeds, inputs_embeds)
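    # Merging sketch (illustrative only; sizes assumed): with image_seq_len=4 and an input row such as
    #   [tok, <image>, <image>, <image>, <image>, tok]
    # the four <image> positions are replaced by the four rows of that image's projected hidden states, and the
    # embeddings of the ordinary text tokens are left untouched. Each image in the batch contributes one contiguous
    # block of image_seq_len rows to `image_hidden_states`, in the order the images appear in the batch.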
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        pixel_attention_mask: Optional[torch.BoolTensor] = None,
        image_hidden_states: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPoolingAndCrossAttentions]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if inputs_embeds is None:
            inputs_embeds = self.text_model.get_input_embeddings()(input_ids).to(input_ids.device)
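        # Input convention (as used below): `pixel_values` is expected as (batch_size, num_images, channels, height,
        # width). It is flattened to (batch_size * num_images, C, H, W) before the vision tower, and all-zero
        # "padding" images are filtered out so only real images are encoded (keeping at least one so the vision
        # tower always runs).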
        if pixel_values is not None:
            batch_size, num_images, _, _, _ = pixel_values.shape
            pixel_values = pixel_values.view(batch_size * num_images, *pixel_values.shape[2:])
            nb_values_per_image = pixel_values.shape[1:].numel()
            real_images_inds = (pixel_values == 0.0).sum(dim=(-1, -2, -3)) != nb_values_per_image
            if not any(real_images_inds):
                real_images_inds[0] = True
            pixel_values = pixel_values[real_images_inds].contiguous()
            image_hidden_states = self.vision_model(pixel_values=pixel_values).last_hidden_state
            image_hidden_states = self.connector(image_hidden_states)
        elif image_hidden_states is not None:
            image_hidden_states = image_hidden_states.to(dtype=self.dtype, device=input_ids.device)

        if inputs_embeds is not None and image_hidden_states is not None:
            inputs_embeds = self.inputs_merger(input_ids, inputs_embeds, image_hidden_states)

        outputs = self.text_model(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if not return_dict:
            return tuple(v for v in [*outputs, image_hidden_states] if v is not None)

        return ModernVBertBaseModelOutput(
            last_hidden_state=outputs.last_hidden_state,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            image_hidden_states=image_hidden_states,
        )


class ModernVBertLMHead(nn.Module):
    """Masked-LM head reusing the `head` (pre-classifier transform) and `decoder` (vocabulary projection) of the
    pretrained text model's masked-LM variant."""

    def __init__(self, config):
        super().__init__()
        pretrained_config = AutoConfig.from_pretrained(config.text_config.text_model_name, trust_remote_code=True)
        pretrained_model = AutoModelForMaskedLM.from_config(pretrained_config, trust_remote_code=True)
        self.head = pretrained_model.head
        self.decoder = pretrained_model.decoder

    def forward(self, hidden_states):
        return self.decoder(self.head(hidden_states))


class ModernVBertForMaskedLM(ModernVBertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.image_token_id = config.image_token_id
        self.in_features = config.hidden_size
        self.out_additional_features = config.additional_vocab_size
        self.vocab_size = config.vocab_size
        self.model = ModernVBertModel(config)
        self.lm_head = ModernVBertLMHead(config)
        if self.out_additional_features > 0:
            self.additional_fc = nn.Linear(self.in_features, self.out_additional_features, bias=False)
        self.lm_head.to(self.dtype)
        self.loss_function = CrossEntropyLoss()
        self.post_init()
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        pixel_attention_mask: Optional[torch.BoolTensor] = None,
        image_hidden_states: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        labels: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, ModernVBertMaskedLMOutput]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            pixel_values=pixel_values,
            pixel_attention_mask=pixel_attention_mask,
            image_hidden_states=image_hidden_states,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states)
        if self.out_additional_features > 0:
            proj_states = self.lm_head.head(hidden_states)
            additional_features = self.additional_fc(proj_states)
            logits = torch.cat((logits, additional_features), -1)
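        # The concatenated logits cover the base vocabulary plus the additional always-trainable tokens, so their
        # last dimension is vocab_size + additional_vocab_size, mirroring the DecoupledEmbedding on the input side.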
        loss = None
        if labels is not None:
            loss = self.loss_function(
                logits.view(-1, self.vocab_size + self.out_additional_features), labels.view(-1)
            )

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return ModernVBertMaskedLMOutput(
            loss=loss,
            logits=logits.float(),
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            image_hidden_states=outputs.image_hidden_states,
        )
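

# Usage sketch (illustrative only; the checkpoint path, tensor shapes and preprocessing are assumptions, not part of
# this file - inputs would normally come from the model's processor):
#
#   config = ModernVBertConfig.from_pretrained("path/to/modernvbert-checkpoint")
#   model = ModernVBertForMaskedLM(config)
#   outputs = model(
#       input_ids=input_ids,            # (batch, seq_len), with <image> placeholder tokens already expanded
#       attention_mask=attention_mask,  # (batch, seq_len)
#       pixel_values=pixel_values,      # (batch, num_images, channels, height, width)
#       labels=labels,                  # (batch, seq_len); -100 (the CrossEntropyLoss default) masks positions out
#   )
#   loss, logits = outputs.loss, outputs.logits  # logits: (batch, seq_len, vocab_size + additional_vocab_size)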