# flash_whisper.py
  1. # A whisper that supports flash-attention and dynamic input length.
  2. from typing import Optional, Tuple, Union
  3. import numpy as np
  4. import torch
  5. import torch.nn.functional as F
  6. from torch import nn
  7. from torch.nn import CrossEntropyLoss
  8. from transformers.generation.logits_process import WhisperTimeStampLogitsProcessor
  9. from transformers.modeling_outputs import BaseModelOutput
  10. from transformers.models.whisper.modeling_whisper import (
  11. WHISPER_INPUTS_DOCSTRING,
  12. WHISPER_START_DOCSTRING,
  13. WhisperAttention,
  14. WhisperConfig,
  15. WhisperDecoder,
  16. WhisperDecoderLayer,
  17. WhisperEncoder,
  18. WhisperEncoderLayer,
  19. WhisperForConditionalGeneration,
  20. WhisperModel,
  21. WhisperPreTrainedModel,
  22. _dynamic_time_warping,
  23. _median_filter,
  24. shift_tokens_right,
  25. )
  26. from transformers.models.whisper.tokenization_whisper import TASK_IDS, TO_LANGUAGE_CODE
  27. from transformers.utils import (
  28. add_start_docstrings,
  29. add_start_docstrings_to_model_forward,
  30. logging,
  31. replace_return_docstrings,
  32. )
  33. logger = logging.get_logger(__name__)
  34. _CONFIG_FOR_DOC = "WhisperConfig"
  35. class FlashWhisperAttention(WhisperAttention):
  36. """Multi-headed attention from 'Attention Is All You Need' paper"""
  37. # Copied from transformers.models.bart.modeling_bart.BartAttention.forward with BART->whisper
  38. def forward(
  39. self,
  40. hidden_states: torch.Tensor,
  41. key_value_states: Optional[torch.Tensor] = None,
  42. past_key_value: Optional[Tuple[torch.Tensor]] = None,
  43. attention_mask: Optional[torch.Tensor] = None,
  44. layer_head_mask: Optional[torch.Tensor] = None,
  45. output_attentions: bool = False,
  46. ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
  47. """Input shape: Batch x Time x Channel"""
  48. # if key_value_states are provided this layer is used as a cross-attention layer
  49. # for the decoder
  50. is_cross_attention = key_value_states is not None
  51. bsz, tgt_len, _ = hidden_states.size()
  52. # get query proj - don't scale here since Flash Attention performs this under the hood
  53. query_states = self._shape(self.q_proj(hidden_states), -1, bsz)
  54. # get key, value proj
  55. # `past_key_value[0].shape[2] == key_value_states.shape[1]`
  56. # is checking that the `sequence_length` of the `past_key_value` is the same as
  57. # the provided `key_value_states` to support prefix tuning
  58. if (
  59. is_cross_attention
  60. and past_key_value is not None
  61. and past_key_value[0].shape[2] == key_value_states.shape[1]
  62. ):
  63. # reuse k,v, cross_attentions
  64. key_states = past_key_value[0]
  65. value_states = past_key_value[1]
  66. elif is_cross_attention:
  67. # cross_attentions
  68. key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
  69. value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
  70. elif past_key_value is not None:
  71. # reuse k, v, self_attention
  72. key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
  73. value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
  74. key_states = torch.cat([past_key_value[0], key_states], dim=2)
  75. value_states = torch.cat([past_key_value[1], value_states], dim=2)
  76. else:
  77. # self_attention
  78. key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
  79. value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
  80. if self.is_decoder:
  81. # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
  82. # Further calls to cross_attention layer can then reuse all cross-attention
  83. # key/value_states (first "if" case)
  84. # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
  85. # all previous decoder key/value_states. Further calls to uni-directional self-attention
  86. # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
  87. # if encoder bi-directional self-attention `past_key_value` is always `None`
  88. past_key_value = (key_states, value_states)
  89. attn_output = F.scaled_dot_product_attention(
  90. query=query_states,
  91. key=key_states,
  92. value=value_states,
  93. attn_mask=attention_mask,
  94. scale=self.scaling,
  95. )
  96. attn_output = attn_output.transpose(1, 2)
  97. # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
  98. # partitioned across GPUs when using tensor-parallelism.
  99. attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
  100. attn_output = self.out_proj(attn_output)
  101. return attn_output, None, past_key_value
  102. # Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->Whisper
  103. class FlashWhisperEncoderLayer(WhisperEncoderLayer):
  104. def __init__(self, config: WhisperConfig):
  105. super().__init__(config)
  106. self.self_attn = FlashWhisperAttention(
  107. embed_dim=self.embed_dim,
  108. num_heads=config.encoder_attention_heads,
  109. dropout=config.attention_dropout,
  110. )
  111. class FlashWhisperDecoderLayer(WhisperDecoderLayer):
  112. def __init__(self, config: WhisperConfig):
  113. super().__init__(config)
  114. self.self_attn = FlashWhisperAttention(
  115. embed_dim=self.embed_dim,
  116. num_heads=config.decoder_attention_heads,
  117. dropout=config.attention_dropout,
  118. is_decoder=True,
  119. )
  120. class FlashWhisperEncoder(WhisperEncoder):
  121. """
  122. Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
  123. [`WhisperEncoderLayer`].
  124. Args:
  125. config: WhisperConfig
  126. """
  127. def __init__(self, config: WhisperConfig):
  128. super().__init__(config)
  129. self.layers = nn.ModuleList(
  130. [FlashWhisperEncoderLayer(config) for _ in range(config.encoder_layers)]
  131. )
  132. def forward(
  133. self,
  134. input_features,
  135. attention_mask=None,
  136. head_mask=None,
  137. output_attentions=None,
  138. output_hidden_states=None,
  139. return_dict=None,
  140. ):
  141. r"""
  142. Args:
  143. input_features (`torch.LongTensor` of shape `(batch_size, feature_size, sequence_length)`):
  144. Float values of mel features extracted from the raw speech waveform. Raw speech waveform can be
  145. obtained by loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a
  146. `numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array into
  147. `input_features`, the [`AutoFeatureExtractor`] should be used for extracting the mel features, padding
  148. and conversion into a tensor of type `torch.FloatTensor`. See [`~WhisperFeatureExtractor.__call__`]
  149. attention_mask (`torch.Tensor`)`, *optional*):
  150. Whisper does not support masking of the `input_features`, this argument is preserved for compatibility,
  151. but it is not used. By default the silence in the input log mel spectrogram are ignored.
  152. head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
  153. Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
  154. - 1 indicates the head is **not masked**,
  155. - 0 indicates the head is **masked**.
  156. output_attentions (`bool`, *optional*):
  157. Whether or not to return the attentions tensors of all attention layers. See `attentions` under
  158. returned tensors for more detail.
  159. output_hidden_states (`bool`, *optional*):
  160. Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
  161. for more detail.
  162. return_dict (`bool`, *optional*):
  163. Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
  164. """
  165. output_attentions = (
  166. output_attentions
  167. if output_attentions is not None
  168. else self.config.output_attentions
  169. )
  170. output_hidden_states = (
  171. output_hidden_states
  172. if output_hidden_states is not None
  173. else self.config.output_hidden_states
  174. )
  175. return_dict = (
  176. return_dict if return_dict is not None else self.config.use_return_dict
  177. )
  178. inputs_embeds = nn.functional.gelu(self.conv1(input_features))
  179. inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds))
  180. inputs_embeds = inputs_embeds.permute(0, 2, 1)
  181. embed_pos = self.embed_positions.weight
  182. hidden_states = inputs_embeds + embed_pos[None, : inputs_embeds.size(1), :]
  183. hidden_states = nn.functional.dropout(
  184. hidden_states, p=self.dropout, training=self.training
  185. )
  186. encoder_states = () if output_hidden_states else None
  187. all_attentions = () if output_attentions else None
  188. # check if head_mask has a correct number of layers specified if desired
  189. if head_mask is not None:
  190. assert head_mask.size()[0] == (
  191. len(self.layers)
  192. ), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
  193. for idx, encoder_layer in enumerate(self.layers):
  194. if output_hidden_states:
  195. encoder_states = encoder_states + (hidden_states,)
  196. # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
  197. to_drop = False
  198. if self.training:
  199. dropout_probability = torch.rand([])
  200. if dropout_probability < self.layerdrop: # skip the layer
  201. to_drop = True
  202. if to_drop:
  203. layer_outputs = (None, None)
  204. else:
  205. if self.gradient_checkpointing and self.training:
  206. def create_custom_forward(module):
  207. def custom_forward(*inputs):
  208. return module(*inputs, output_attentions)
  209. return custom_forward
  210. layer_outputs = torch.utils.checkpoint.checkpoint(
  211. create_custom_forward(encoder_layer),
  212. hidden_states,
  213. None,
  214. (head_mask[idx] if head_mask is not None else None),
  215. )
  216. else:
  217. layer_outputs = encoder_layer(
  218. hidden_states,
  219. None,
  220. layer_head_mask=(
  221. head_mask[idx] if head_mask is not None else None
  222. ),
  223. output_attentions=output_attentions,
  224. )
  225. hidden_states = layer_outputs[0]
  226. if output_attentions:
  227. all_attentions = all_attentions + (layer_outputs[1],)
  228. hidden_states = self.layer_norm(hidden_states)
  229. # Simply set states to zero for attention_mask
  230. # hidden_states[:, 40:, :] = 0
  231. if output_hidden_states:
  232. encoder_states = encoder_states + (hidden_states,)
  233. if not return_dict:
  234. return tuple(
  235. v
  236. for v in [hidden_states, encoder_states, all_attentions]
  237. if v is not None
  238. )
  239. return BaseModelOutput(
  240. last_hidden_state=hidden_states,
  241. hidden_states=encoder_states,
  242. attentions=all_attentions,
  243. )
  244. class FlashWhisperDecoder(WhisperDecoder):
  245. """
  246. Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a
  247. [`WhisperDecoderLayer`]
  248. Args:
  249. config: WhisperConfig
  250. """
  251. def __init__(self, config: WhisperConfig):
  252. super().__init__(config)
  253. self.layers = nn.ModuleList(
  254. [FlashWhisperDecoderLayer(config) for _ in range(config.decoder_layers)]
  255. )
  256. class FlashWhisperModel(WhisperModel):
  257. def __init__(self, config: WhisperConfig):
  258. super().__init__(config)
  259. self.encoder = FlashWhisperEncoder(config)
  260. self.decoder = FlashWhisperDecoder(config)
  261. self.post_init()
  262. class FlashWhisperForConditionalGeneration(WhisperForConditionalGeneration):
  263. def __init__(self, config: WhisperConfig):
  264. super().__init__(config)
  265. self.model = FlashWhisperModel(config)
  266. self.post_init()