# A Whisper implementation that supports flash attention (via torch SDPA) and dynamic input lengths.
- from typing import Optional, Tuple, Union
- import numpy as np
- import torch
- import torch.nn.functional as F
- from torch import nn
- from transformers.modeling_outputs import BaseModelOutput
- from transformers.models.whisper.modeling_whisper import (
- WhisperAttention,
- WhisperConfig,
- WhisperDecoder,
- WhisperDecoderLayer,
- WhisperEncoder,
- WhisperEncoderLayer,
- WhisperForConditionalGeneration,
- WhisperModel,
- )
- from transformers.utils import logging
- logger = logging.get_logger(__name__)
class FlashWhisperAttention(WhisperAttention):
    """Multi-headed attention from 'Attention Is All You Need' paper.

    Variant of ``WhisperAttention`` that computes attention through
    ``torch.nn.functional.scaled_dot_product_attention`` (which can dispatch to
    flash / memory-efficient kernels) instead of the eager matmul-softmax path.

    NOTE(review): ``layer_head_mask`` and ``output_attentions`` are accepted only
    for interface compatibility and are ignored — SDPA does not expose per-head
    attention probabilities, so the second element of the return tuple is always
    ``None``. Callers relying on attention weights must use the eager attention.
    """

    # Copied from transformers.models.bart.modeling_bart.BartAttention.forward with BART->whisper
    def forward(
        self,
        hidden_states: torch.Tensor,
        key_value_states: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Input shape: Batch x Time x Channel

        Returns:
            Tuple of ``(attn_output, attn_weights, past_key_value)`` where
            ``attn_weights`` is always ``None`` (see class docstring) and
            ``past_key_value`` is populated only when ``self.is_decoder``.
        """
        # if key_value_states are provided this layer is used as a cross-attention layer
        # for the decoder
        is_cross_attention = key_value_states is not None
        bsz, tgt_len, _ = hidden_states.size()

        # get query proj - the query is deliberately NOT pre-scaled here (unlike the
        # eager implementation): the softmax scaling is handed to SDPA below via
        # the `scale=` argument instead.
        query_states = self._shape(self.q_proj(hidden_states), -1, bsz)

        # get key, value proj
        # `past_key_value[0].shape[2] == key_value_states.shape[1]`
        # is checking that the `sequence_length` of the `past_key_value` is the same as
        # the provided `key_value_states` to support prefix tuning
        if (
            is_cross_attention
            and past_key_value is not None
            and past_key_value[0].shape[2] == key_value_states.shape[1]
        ):
            # reuse k,v, cross_attentions — the encoder output is static during
            # decoding, so the cached projections can be used verbatim
            key_states = past_key_value[0]
            value_states = past_key_value[1]
        elif is_cross_attention:
            # cross_attentions — first decoding step (or cache-length mismatch):
            # project the encoder hidden states fresh
            key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
            value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
        elif past_key_value is not None:
            # reuse k, v, self_attention — project only the new positions and
            # append them to the cached keys/values along the sequence dim
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)
        else:
            # self_attention (no cache, e.g. encoder or first decoder step)
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)

        if self.is_decoder:
            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
            # Further calls to cross_attention layer can then reuse all cross-attention
            # key/value_states (first "if" case)
            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
            # all previous decoder key/value_states. Further calls to uni-directional self-attention
            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
            # if encoder bi-directional self-attention `past_key_value` is always `None`
            past_key_value = (key_states, value_states)

        # NOTE(review): the explicit `scale=` keyword requires torch >= 2.1 — confirm
        # the project's minimum torch version.
        attn_output = F.scaled_dot_product_attention(
            query=query_states,
            key=key_states,
            value=value_states,
            attn_mask=attention_mask,
            scale=self.scaling,
        )

        # (bsz, num_heads, tgt_len, head_dim) -> (bsz, tgt_len, num_heads, head_dim)
        attn_output = attn_output.transpose(1, 2)
        # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
        # partitioned across GPUs when using tensor-parallelism.
        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)

        attn_output = self.out_proj(attn_output)

        return attn_output, None, past_key_value
# Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->Whisper
class FlashWhisperEncoderLayer(WhisperEncoderLayer):
    """Whisper encoder layer whose self-attention is the SDPA-backed
    ``FlashWhisperAttention`` rather than the stock eager attention."""

    def __init__(self, config: WhisperConfig):
        super().__init__(config)
        # Only the self-attention sub-module is swapped out; everything else
        # built by the parent constructor (feed-forward, layer norms) is kept.
        flash_attention = FlashWhisperAttention(
            embed_dim=self.embed_dim,
            num_heads=config.encoder_attention_heads,
            dropout=config.attention_dropout,
        )
        self.self_attn = flash_attention
class FlashWhisperDecoderLayer(WhisperDecoderLayer):
    """Whisper decoder layer with SDPA-backed (causal) self-attention.

    Only ``self_attn`` is replaced; the cross-attention module built by the
    parent constructor is left untouched.
    """

    def __init__(self, config: WhisperConfig):
        super().__init__(config)
        attn_kwargs = dict(
            embed_dim=self.embed_dim,
            num_heads=config.decoder_attention_heads,
            dropout=config.attention_dropout,
            is_decoder=True,
        )
        self.self_attn = FlashWhisperAttention(**attn_kwargs)
class FlashWhisperEncoder(WhisperEncoder):
    """
    Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
    [`FlashWhisperEncoderLayer`] whose self-attention runs through torch SDPA. Unlike the stock
    encoder, variable-length inputs are supported: the positional embedding table is sliced to the
    actual post-conv sequence length instead of requiring the fixed 30s length.

    Args:
        config: WhisperConfig
    """

    def __init__(self, config: WhisperConfig):
        super().__init__(config)
        # Only the layer stack is replaced; the conv front-end, positional
        # embedding table and final layer norm are inherited from WhisperEncoder.
        self.layers = nn.ModuleList(
            [FlashWhisperEncoderLayer(config) for _ in range(config.encoder_layers)]
        )

    def forward(
        self,
        input_features,
        attention_mask=None,
        head_mask=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        Args:
            input_features (`torch.LongTensor` of shape `(batch_size, feature_size, sequence_length)`):
                Float values of mel features extracted from the raw speech waveform. Raw speech waveform can be
                obtained by loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a
                `numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array into
                `input_features`, the [`AutoFeatureExtractor`] should be used for extracting the mel features, padding
                and conversion into a tensor of type `torch.FloatTensor`. See [`~WhisperFeatureExtractor.__call__`]
            attention_mask (`torch.Tensor`)`, *optional*):
                Whisper does not support masking of the `input_features`, this argument is preserved for compatibility,
                but it is not used. By default the silence in the input log mel spectrogram are ignored.
            head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
                Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
                - 1 indicates the head is **not masked**,
                - 0 indicates the head is **masked**.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more detail.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        """
        # Pass-through shortcut: if the caller hands us a tensor that already looks
        # like a full-length encoder output (max_source_positions x d_model), return
        # it unchanged instead of re-encoding. The expected shape is derived from the
        # config rather than the previous hard-coded (1500, 1024), which only matched
        # whisper-medium. NOTE(review): a *shorter* precomputed hidden state will not
        # be detected by this heuristic — same limitation as the hard-coded check.
        if input_features.shape[-2:] == (
            self.config.max_source_positions,
            self.config.d_model,
        ):
            if not return_dict:
                return (input_features,)
            return BaseModelOutput(last_hidden_state=input_features)

        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        # Conv front-end: two GELU-activated convolutions, then move the channel
        # dimension last -> (batch, seq, d_model).
        inputs_embeds = nn.functional.gelu(self.conv1(input_features))
        inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds))
        inputs_embeds = inputs_embeds.permute(0, 2, 1)

        # Slice the positional table to the actual sequence length — this is what
        # allows dynamic (shorter-than-30s) inputs, where the stock encoder insists
        # on the full max_source_positions length.
        embed_pos = self.embed_positions.weight
        hidden_states = inputs_embeds + embed_pos[None, : inputs_embeds.size(1), :]
        hidden_states = nn.functional.dropout(
            hidden_states, p=self.dropout, training=self.training
        )

        encoder_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        # check if head_mask has a correct number of layers specified if desired
        if head_mask is not None:
            assert head_mask.size()[0] == (
                len(self.layers)
            ), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."

        for idx, encoder_layer in enumerate(self.layers):
            if output_hidden_states:
                encoder_states = encoder_states + (hidden_states,)
            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
            to_drop = False
            if self.training:
                dropout_probability = torch.rand([])
                if dropout_probability < self.layerdrop:  # skip the layer
                    to_drop = True

            if to_drop:
                layer_outputs = (None, None)
            else:
                if self.gradient_checkpointing and self.training:

                    def create_custom_forward(module):
                        def custom_forward(*inputs):
                            return module(*inputs, output_attentions)

                        return custom_forward

                    layer_outputs = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(encoder_layer),
                        hidden_states,
                        None,
                        (head_mask[idx] if head_mask is not None else None),
                    )
                else:
                    # attention_mask is deliberately passed as None — Whisper does
                    # not mask the mel input (see docstring above).
                    layer_outputs = encoder_layer(
                        hidden_states,
                        None,
                        layer_head_mask=(
                            head_mask[idx] if head_mask is not None else None
                        ),
                        output_attentions=output_attentions,
                    )

                hidden_states = layer_outputs[0]

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

        hidden_states = self.layer_norm(hidden_states)
        if output_hidden_states:
            encoder_states = encoder_states + (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, encoder_states, all_attentions]
                if v is not None
            )
        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=encoder_states,
            attentions=all_attentions,
        )
class FlashWhisperDecoder(WhisperDecoder):
    """
    Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a
    [`FlashWhisperDecoderLayer`] (SDPA-backed self-attention).

    Args:
        config: WhisperConfig
    """

    def __init__(self, config: WhisperConfig):
        super().__init__(config)
        # Swap the stock decoder stack for the flash-attention variant; all
        # other sub-modules from the parent constructor are kept.
        flash_layers = [
            FlashWhisperDecoderLayer(config) for _ in range(config.decoder_layers)
        ]
        self.layers = nn.ModuleList(flash_layers)
class FlashWhisperModel(WhisperModel):
    """`WhisperModel` wired with the flash-attention encoder and decoder."""

    def __init__(self, config: WhisperConfig):
        super().__init__(config)
        # Replace both halves built by the parent constructor, then re-run
        # post-init (weight init / tying) over the new sub-modules.
        self.decoder = FlashWhisperDecoder(config)
        self.encoder = FlashWhisperEncoder(config)
        self.post_init()
class FlashWhisperForConditionalGeneration(WhisperForConditionalGeneration):
    """Whisper speech-to-text generation model backed by `FlashWhisperModel`."""

    def __init__(self, config: WhisperConfig):
        super().__init__(config)
        # Substitute the flash-attention base model, then re-run post-init so
        # the freshly created modules go through weight init / tying.
        flash_model = FlashWhisperModel(config)
        self.model = flash_model
        self.post_init()