# A whisper that supports flash-attention and dynamic input length.
from typing import Optional, Tuple, Union

import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from transformers.modeling_outputs import BaseModelOutput
from transformers.models.whisper.modeling_whisper import (
    WhisperAttention,
    WhisperConfig,
    WhisperDecoder,
    WhisperDecoderLayer,
    WhisperEncoder,
    WhisperEncoderLayer,
    WhisperForConditionalGeneration,
    WhisperModel,
)
from transformers.utils import logging

logger = logging.get_logger(__name__)


class FlashWhisperAttention(WhisperAttention):
    """Whisper multi-head attention computed with `F.scaled_dot_product_attention`.

    The projections and the key/value caching rules are the ones inherited from
    [`WhisperAttention`]; only the attention product itself is delegated to
    PyTorch's fused kernel, which never materializes the attention weights.
    """

    # Adapted from transformers.models.bart.modeling_bart.BartAttention.forward with BART->whisper
    def forward(
        self,
        hidden_states: torch.Tensor,
        key_value_states: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Input shape: Batch x Time x Channel

        Args:
            hidden_states: Query source.
            key_value_states: When given, this layer acts as cross-attention and
                keys/values are projected from this tensor instead of
                `hidden_states`.
            past_key_value: Cached `(key_states, value_states)` from earlier calls.
            attention_mask: Forwarded unchanged as `attn_mask` to the fused kernel.
            layer_head_mask: Per-head mask. Not expressible in the fused kernel, so
                passing it routes the call to the parent implementation.
            output_attentions: Attention weights are not available from the fused
                kernel, so requesting them routes the call to the parent
                implementation.

        Returns:
            `(attn_output, attn_weights, past_key_value)`. `attn_weights` is `None`
            on the fused path. `past_key_value` is refreshed with the current
            `(key_states, value_states)` only when `self.is_decoder` is set.
        """
        if output_attentions or layer_head_mask is not None:
            # The fused kernel can neither return attention weights nor apply a
            # head mask. Defer to the parent implementation instead of silently
            # ignoring the request.
            return super().forward(
                hidden_states,
                key_value_states=key_value_states,
                past_key_value=past_key_value,
                attention_mask=attention_mask,
                layer_head_mask=layer_head_mask,
                output_attentions=output_attentions,
            )

        # if key_value_states are provided this layer is used as a cross-attention layer
        # for the decoder
        is_cross_attention = key_value_states is not None

        bsz, tgt_len, _ = hidden_states.size()

        # get query proj - don't scale here since the fused kernel applies `scale` itself
        query_states = self._shape(self.q_proj(hidden_states), -1, bsz)

        # get key, value proj
        # `past_key_value[0].shape[2] == key_value_states.shape[1]`
        # is checking that the `sequence_length` of the `past_key_value` is the same as
        # the provided `key_value_states` to support prefix tuning
        if (
            is_cross_attention
            and past_key_value is not None
            and past_key_value[0].shape[2] == key_value_states.shape[1]
        ):
            # reuse k,v, cross_attentions
            key_states = past_key_value[0]
            value_states = past_key_value[1]
        elif is_cross_attention:
            # cross_attentions
            key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
            value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
        elif past_key_value is not None:
            # reuse k, v, self_attention
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)
        else:
            # self_attention
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)

        if self.is_decoder:
            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
            # Further calls to cross_attention layer can then reuse all cross-attention
            # key/value_states (first "if" case)
            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
            # all previous decoder key/value_states. Further calls to uni-directional self-attention
            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
            # if encoder bi-directional self-attention `past_key_value` is always `None`
            past_key_value = (key_states, value_states)

        attn_output = F.scaled_dot_product_attention(
            query=query_states,
            key=key_states,
            value=value_states,
            attn_mask=attention_mask,
            # Attention dropout must only be active while training.
            dropout_p=self.dropout if self.training else 0.0,
            scale=self.scaling,
        )

        attn_output = attn_output.transpose(1, 2)

        # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
        # partitioned across GPUs when using tensor-parallelism.
        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)

        attn_output = self.out_proj(attn_output)

        return attn_output, None, past_key_value


# Adapted from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->Whisper
class FlashWhisperEncoderLayer(WhisperEncoderLayer):
    """[`WhisperEncoderLayer`] whose `self_attn` is a [`FlashWhisperAttention`]."""

    def __init__(self, config: WhisperConfig):
        super().__init__(config)
        # Replace the attention module created by the parent constructor.
        self.self_attn = FlashWhisperAttention(
            embed_dim=self.embed_dim,
            num_heads=config.encoder_attention_heads,
            dropout=config.attention_dropout,
        )


class FlashWhisperDecoderLayer(WhisperDecoderLayer):
    """[`WhisperDecoderLayer`] whose `self_attn` is a [`FlashWhisperAttention`].

    Only `self_attn` is replaced here; every other sub-module is left as the
    parent constructor created it.
    """

    def __init__(self, config: WhisperConfig):
        super().__init__(config)
        self.self_attn = FlashWhisperAttention(
            embed_dim=self.embed_dim,
            num_heads=config.decoder_attention_heads,
            dropout=config.attention_dropout,
            is_decoder=True,
        )


class FlashWhisperEncoder(WhisperEncoder):
    """
    Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
    [`FlashWhisperEncoderLayer`].

    Args:
        config: WhisperConfig
    """

    def __init__(self, config: WhisperConfig):
        super().__init__(config)
        self.layers = nn.ModuleList(
            [FlashWhisperEncoderLayer(config) for _ in range(config.encoder_layers)]
        )

    def forward(
        self,
        input_features: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutput]:
        r"""
        Args:
            input_features (`torch.FloatTensor` of shape `(batch_size, feature_size, sequence_length)`):
                Float values of mel features extracted from the raw speech waveform. Raw speech waveform can be
                obtained by loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a
                `numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array into
                `input_features`, the [`AutoFeatureExtractor`] should be used for extracting the mel features, padding
                and conversion into a tensor of type `torch.FloatTensor`. See [`~WhisperFeatureExtractor.__call__`].

                The sequence length may vary: only as many positional embeddings as there are frames after the
                convolutions are used. A tensor whose last two dimensions are
                `(config.max_source_positions, config.d_model)` is treated as an already encoded output and is
                returned unchanged.
            attention_mask (`torch.Tensor`, *optional*):
                Whisper does not support masking of the `input_features`, this argument is preserved for compatibility,
                but it is not used. By default the silence in the input log mel spectrogram are ignored.
            head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
                Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:

                - 1 indicates the head is **not masked**,
                - 0 indicates the head is **masked**.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more detail.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        """
        # Resolve the config defaults first so that every return path, including
        # the passthrough shortcut below, honours them.
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        # If we receive the output of the encoder directly, just return it.
        # The expected shape comes from the config, so this works for every model size.
        encoded_shape = (self.config.max_source_positions, self.config.d_model)
        if tuple(input_features.shape[-2:]) == encoded_shape:
            if not return_dict:
                return (input_features,)
            return BaseModelOutput(last_hidden_state=input_features)

        inputs_embeds = nn.functional.gelu(self.conv1(input_features))
        inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds))

        inputs_embeds = inputs_embeds.permute(0, 2, 1)
        embed_pos = self.embed_positions.weight

        # Slice the positional table to the actual number of frames (dynamic input length).
        hidden_states = inputs_embeds + embed_pos[None, : inputs_embeds.size(1), :]
        hidden_states = nn.functional.dropout(
            hidden_states, p=self.dropout, training=self.training
        )

        encoder_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        # check if head_mask has a correct number of layers specified if desired
        if head_mask is not None:
            assert head_mask.size()[0] == (
                len(self.layers)
            ), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."

        for idx, encoder_layer in enumerate(self.layers):
            if output_hidden_states:
                encoder_states = encoder_states + (hidden_states,)

            layer_head_mask = head_mask[idx] if head_mask is not None else None

            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
            # The random draw only happens in training mode.
            to_drop = self.training and bool(torch.rand([]) < self.layerdrop)

            if to_drop:
                # skip the layer: `hidden_states` is carried over unchanged
                layer_outputs = (None, None)
            else:
                if self.gradient_checkpointing and self.training:

                    def create_custom_forward(module):
                        def custom_forward(*inputs):
                            return module(*inputs, output_attentions)

                        return custom_forward

                    layer_outputs = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(encoder_layer),
                        hidden_states,
                        None,
                        layer_head_mask,
                    )
                else:
                    layer_outputs = encoder_layer(
                        hidden_states,
                        None,
                        layer_head_mask=layer_head_mask,
                        output_attentions=output_attentions,
                    )

                hidden_states = layer_outputs[0]

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

        hidden_states = self.layer_norm(hidden_states)

        if output_hidden_states:
            encoder_states = encoder_states + (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, encoder_states, all_attentions]
                if v is not None
            )
        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=encoder_states,
            attentions=all_attentions,
        )


class FlashWhisperDecoder(WhisperDecoder):
    """
    Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`FlashWhisperDecoderLayer`]

    Args:
        config: WhisperConfig
    """

    def __init__(self, config: WhisperConfig):
        super().__init__(config)
        self.layers = nn.ModuleList(
            [FlashWhisperDecoderLayer(config) for _ in range(config.decoder_layers)]
        )


class FlashWhisperModel(WhisperModel):
    """[`WhisperModel`] built from [`FlashWhisperEncoder`] and [`FlashWhisperDecoder`]."""

    def __init__(self, config: WhisperConfig):
        super().__init__(config)
        # Replace the encoder/decoder created by the parent constructor, then
        # re-run `post_init` for the new sub-modules.
        self.encoder = FlashWhisperEncoder(config)
        self.decoder = FlashWhisperDecoder(config)
        self.post_init()


class FlashWhisperForConditionalGeneration(WhisperForConditionalGeneration):
    """[`WhisperForConditionalGeneration`] whose inner `model` is a [`FlashWhisperModel`]."""

    def __init__(self, config: WhisperConfig):
        super().__init__(config)
        # Replace the inner model created by the parent constructor, then
        # re-run `post_init` for the new sub-modules.
        self.model = FlashWhisperModel(config)
        self.post_init()