# flash_whisper.py
  1. # A whisper that supports flash-attention and dynamic input length.
  2. from typing import Optional, Tuple, Union
  3. import numpy as np
  4. import torch
  5. import torch.nn.functional as F
  6. from torch import nn
  7. from transformers.modeling_outputs import BaseModelOutput
  8. from transformers.models.whisper.modeling_whisper import (
  9. WhisperAttention,
  10. WhisperConfig,
  11. WhisperDecoder,
  12. WhisperDecoderLayer,
  13. WhisperEncoder,
  14. WhisperEncoderLayer,
  15. WhisperForConditionalGeneration,
  16. WhisperModel,
  17. )
  18. from transformers.utils import logging
  19. logger = logging.get_logger(__name__)
  20. class FlashWhisperAttention(WhisperAttention):
  21. """Multi-headed attention from 'Attention Is All You Need' paper"""
  22. # Copied from transformers.models.bart.modeling_bart.BartAttention.forward with BART->whisper
  23. def forward(
  24. self,
  25. hidden_states: torch.Tensor,
  26. key_value_states: Optional[torch.Tensor] = None,
  27. past_key_value: Optional[Tuple[torch.Tensor]] = None,
  28. attention_mask: Optional[torch.Tensor] = None,
  29. layer_head_mask: Optional[torch.Tensor] = None,
  30. output_attentions: bool = False,
  31. ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
  32. """Input shape: Batch x Time x Channel"""
  33. # if key_value_states are provided this layer is used as a cross-attention layer
  34. # for the decoder
  35. is_cross_attention = key_value_states is not None
  36. bsz, tgt_len, _ = hidden_states.size()
  37. # get query proj - don't scale here since Flash Attention performs this under the hood
  38. query_states = self._shape(self.q_proj(hidden_states), -1, bsz)
  39. # get key, value proj
  40. # `past_key_value[0].shape[2] == key_value_states.shape[1]`
  41. # is checking that the `sequence_length` of the `past_key_value` is the same as
  42. # the provided `key_value_states` to support prefix tuning
  43. if (
  44. is_cross_attention
  45. and past_key_value is not None
  46. and past_key_value[0].shape[2] == key_value_states.shape[1]
  47. ):
  48. # reuse k,v, cross_attentions
  49. key_states = past_key_value[0]
  50. value_states = past_key_value[1]
  51. elif is_cross_attention:
  52. # cross_attentions
  53. key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
  54. value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
  55. elif past_key_value is not None:
  56. # reuse k, v, self_attention
  57. key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
  58. value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
  59. key_states = torch.cat([past_key_value[0], key_states], dim=2)
  60. value_states = torch.cat([past_key_value[1], value_states], dim=2)
  61. else:
  62. # self_attention
  63. key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
  64. value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
  65. if self.is_decoder:
  66. # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
  67. # Further calls to cross_attention layer can then reuse all cross-attention
  68. # key/value_states (first "if" case)
  69. # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
  70. # all previous decoder key/value_states. Further calls to uni-directional self-attention
  71. # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
  72. # if encoder bi-directional self-attention `past_key_value` is always `None`
  73. past_key_value = (key_states, value_states)
  74. attn_output = F.scaled_dot_product_attention(
  75. query=query_states,
  76. key=key_states,
  77. value=value_states,
  78. attn_mask=attention_mask,
  79. scale=self.scaling,
  80. )
  81. attn_output = attn_output.transpose(1, 2)
  82. # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
  83. # partitioned across GPUs when using tensor-parallelism.
  84. attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
  85. attn_output = self.out_proj(attn_output)
  86. return attn_output, None, past_key_value
  87. # Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->Whisper
  88. class FlashWhisperEncoderLayer(WhisperEncoderLayer):
  89. def __init__(self, config: WhisperConfig):
  90. super().__init__(config)
  91. self.self_attn = FlashWhisperAttention(
  92. embed_dim=self.embed_dim,
  93. num_heads=config.encoder_attention_heads,
  94. dropout=config.attention_dropout,
  95. )
  96. class FlashWhisperDecoderLayer(WhisperDecoderLayer):
  97. def __init__(self, config: WhisperConfig):
  98. super().__init__(config)
  99. self.self_attn = FlashWhisperAttention(
  100. embed_dim=self.embed_dim,
  101. num_heads=config.decoder_attention_heads,
  102. dropout=config.attention_dropout,
  103. is_decoder=True,
  104. )
  105. class FlashWhisperEncoder(WhisperEncoder):
  106. """
  107. Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
  108. [`WhisperEncoderLayer`].
  109. Args:
  110. config: WhisperConfig
  111. """
  112. def __init__(self, config: WhisperConfig):
  113. super().__init__(config)
  114. self.layers = nn.ModuleList(
  115. [FlashWhisperEncoderLayer(config) for _ in range(config.encoder_layers)]
  116. )
  117. def forward(
  118. self,
  119. input_features,
  120. attention_mask=None,
  121. head_mask=None,
  122. output_attentions=None,
  123. output_hidden_states=None,
  124. return_dict=None,
  125. ):
  126. r"""
  127. Args:
  128. input_features (`torch.LongTensor` of shape `(batch_size, feature_size, sequence_length)`):
  129. Float values of mel features extracted from the raw speech waveform. Raw speech waveform can be
  130. obtained by loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a
  131. `numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array into
  132. `input_features`, the [`AutoFeatureExtractor`] should be used for extracting the mel features, padding
  133. and conversion into a tensor of type `torch.FloatTensor`. See [`~WhisperFeatureExtractor.__call__`]
  134. attention_mask (`torch.Tensor`)`, *optional*):
  135. Whisper does not support masking of the `input_features`, this argument is preserved for compatibility,
  136. but it is not used. By default the silence in the input log mel spectrogram are ignored.
  137. head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
  138. Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
  139. - 1 indicates the head is **not masked**,
  140. - 0 indicates the head is **masked**.
  141. output_attentions (`bool`, *optional*):
  142. Whether or not to return the attentions tensors of all attention layers. See `attentions` under
  143. returned tensors for more detail.
  144. output_hidden_states (`bool`, *optional*):
  145. Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
  146. for more detail.
  147. return_dict (`bool`, *optional*):
  148. Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
  149. """
  150. # If we receive the output of input feature directly, just return it
  151. if input_features.shape[-2:] == (1500, 1024):
  152. if not return_dict:
  153. return (input_features,)
  154. return BaseModelOutput(last_hidden_state=input_features)
  155. output_attentions = (
  156. output_attentions
  157. if output_attentions is not None
  158. else self.config.output_attentions
  159. )
  160. output_hidden_states = (
  161. output_hidden_states
  162. if output_hidden_states is not None
  163. else self.config.output_hidden_states
  164. )
  165. return_dict = (
  166. return_dict if return_dict is not None else self.config.use_return_dict
  167. )
  168. inputs_embeds = nn.functional.gelu(self.conv1(input_features))
  169. inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds))
  170. inputs_embeds = inputs_embeds.permute(0, 2, 1)
  171. embed_pos = self.embed_positions.weight
  172. hidden_states = inputs_embeds + embed_pos[None, : inputs_embeds.size(1), :]
  173. hidden_states = nn.functional.dropout(
  174. hidden_states, p=self.dropout, training=self.training
  175. )
  176. encoder_states = () if output_hidden_states else None
  177. all_attentions = () if output_attentions else None
  178. # check if head_mask has a correct number of layers specified if desired
  179. if head_mask is not None:
  180. assert head_mask.size()[0] == (
  181. len(self.layers)
  182. ), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
  183. for idx, encoder_layer in enumerate(self.layers):
  184. if output_hidden_states:
  185. encoder_states = encoder_states + (hidden_states,)
  186. # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
  187. to_drop = False
  188. if self.training:
  189. dropout_probability = torch.rand([])
  190. if dropout_probability < self.layerdrop: # skip the layer
  191. to_drop = True
  192. if to_drop:
  193. layer_outputs = (None, None)
  194. else:
  195. if self.gradient_checkpointing and self.training:
  196. def create_custom_forward(module):
  197. def custom_forward(*inputs):
  198. return module(*inputs, output_attentions)
  199. return custom_forward
  200. layer_outputs = torch.utils.checkpoint.checkpoint(
  201. create_custom_forward(encoder_layer),
  202. hidden_states,
  203. None,
  204. (head_mask[idx] if head_mask is not None else None),
  205. )
  206. else:
  207. layer_outputs = encoder_layer(
  208. hidden_states,
  209. None,
  210. layer_head_mask=(
  211. head_mask[idx] if head_mask is not None else None
  212. ),
  213. output_attentions=output_attentions,
  214. )
  215. hidden_states = layer_outputs[0]
  216. if output_attentions:
  217. all_attentions = all_attentions + (layer_outputs[1],)
  218. hidden_states = self.layer_norm(hidden_states)
  219. # Simply set states to zero for attention_mask
  220. # hidden_states[:, 40:, :] = 0
  221. if output_hidden_states:
  222. encoder_states = encoder_states + (hidden_states,)
  223. if not return_dict:
  224. return tuple(
  225. v
  226. for v in [hidden_states, encoder_states, all_attentions]
  227. if v is not None
  228. )
  229. return BaseModelOutput(
  230. last_hidden_state=hidden_states,
  231. hidden_states=encoder_states,
  232. attentions=all_attentions,
  233. )
  234. class FlashWhisperDecoder(WhisperDecoder):
  235. """
  236. Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a
  237. [`WhisperDecoderLayer`]
  238. Args:
  239. config: WhisperConfig
  240. """
  241. def __init__(self, config: WhisperConfig):
  242. super().__init__(config)
  243. self.layers = nn.ModuleList(
  244. [FlashWhisperDecoderLayer(config) for _ in range(config.decoder_layers)]
  245. )
  246. class FlashWhisperModel(WhisperModel):
  247. def __init__(self, config: WhisperConfig):
  248. super().__init__(config)
  249. self.encoder = FlashWhisperEncoder(config)
  250. self.decoder = FlashWhisperDecoder(config)
  251. self.post_init()
  252. class FlashWhisperForConditionalGeneration(WhisperForConditionalGeneration):
  253. def __init__(self, config: WhisperConfig):
  254. super().__init__(config)
  255. self.model = FlashWhisperModel(config)
  256. self.post_init()