|
|
@@ -0,0 +1,299 @@
|
|
|
+# A whisper that supports flash-attention and dynamic input length.
|
|
|
+import torch
|
|
|
+from transformers.models.whisper.modeling_whisper import (
|
|
|
+ WhisperPreTrainedModel,
|
|
|
+ WhisperConfig,
|
|
|
+ WHISPER_START_DOCSTRING,
|
|
|
+ WHISPER_INPUTS_DOCSTRING,
|
|
|
+ WhisperModel,
|
|
|
+ shift_tokens_right,
|
|
|
+ _dynamic_time_warping,
|
|
|
+ _median_filter,
|
|
|
+ WhisperAttention,
|
|
|
+ WhisperEncoder,
|
|
|
+ WhisperModel,
|
|
|
+ WhisperPreTrainedModel,
|
|
|
+ WhisperEncoderLayer,
|
|
|
+ WhisperForConditionalGeneration,
|
|
|
+ WhisperDecoder,
|
|
|
+ WhisperDecoderLayer
|
|
|
+)
|
|
|
+from transformers.models.whisper.tokenization_whisper import TO_LANGUAGE_CODE, TASK_IDS
|
|
|
+from torch.nn import CrossEntropyLoss
|
|
|
+from torch import nn
|
|
|
+from typing import Optional, Tuple, Union
|
|
|
+from transformers.utils import (
|
|
|
+ add_start_docstrings,
|
|
|
+ add_start_docstrings_to_model_forward,
|
|
|
+ logging,
|
|
|
+ replace_return_docstrings,
|
|
|
+)
|
|
|
+from transformers.modeling_outputs import (
|
|
|
+ BaseModelOutput,
|
|
|
+)
|
|
|
+from transformers.generation.logits_process import WhisperTimeStampLogitsProcessor
|
|
|
+import numpy as np
|
|
|
+import torch.nn.functional as F
|
|
|
+
|
|
|
+logger = logging.get_logger(__name__)
|
|
|
+
|
|
|
+_CONFIG_FOR_DOC = "WhisperConfig"
|
|
|
+
|
|
|
class FlashWhisperAttention(WhisperAttention):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    # Copied from transformers.models.bart.modeling_bart.BartAttention.forward with BART->whisper
    def forward(
        self,
        hidden_states: torch.Tensor,
        key_value_states: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Input shape: Batch x Time x Channel

        Drop-in replacement for ``WhisperAttention.forward`` that routes the
        attention computation through ``F.scaled_dot_product_attention``.

        NOTE(review): ``layer_head_mask`` and ``output_attentions`` are
        accepted for interface compatibility but are ignored here — the fused
        kernel does not expose per-head masking or the attention-weight
        matrix, so the second element of the returned tuple is always ``None``.
        Callers that rely on attention weights (e.g. cross-attention-based
        timestamp extraction) must not use this class.
        """

        # if key_value_states are provided this layer is used as a cross-attention layer
        # for the decoder
        is_cross_attention = key_value_states is not None

        bsz, tgt_len, _ = hidden_states.size()

        # get query proj - don't scale here since Flash Attention performs this under the hood
        # (the `scale=self.scaling` argument below replaces the stock `* self.scaling`)
        query_states = self._shape(self.q_proj(hidden_states), -1, bsz)

        # get key, value proj
        # `past_key_value[0].shape[2] == key_value_states.shape[1]`
        # is checking that the `sequence_length` of the `past_key_value` is the same as
        # the provided `key_value_states` to support prefix tuning
        if (
            is_cross_attention
            and past_key_value is not None
            and past_key_value[0].shape[2] == key_value_states.shape[1]
        ):
            # reuse k,v, cross_attentions
            key_states = past_key_value[0]
            value_states = past_key_value[1]
        elif is_cross_attention:
            # cross_attentions
            key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
            value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
        elif past_key_value is not None:
            # reuse k, v, self_attention: project only the new positions and
            # append them to the cache along the sequence axis (dim=2)
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)
        else:
            # self_attention
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)

        if self.is_decoder:
            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
            # Further calls to cross_attention layer can then reuse all cross-attention
            # key/value_states (first "if" case)
            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
            # all previous decoder key/value_states. Further calls to uni-directional self-attention
            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
            # if encoder bi-directional self-attention `past_key_value` is always `None`
            past_key_value = (key_states, value_states)

        # Fused attention kernel. Causality/padding must be fully encoded in
        # `attention_mask` by the caller (no `is_causal` flag is passed).
        # NOTE(review): the `scale=` kwarg requires a recent PyTorch (added in
        # 2.1) — confirm the project's minimum torch version.
        attn_output = F.scaled_dot_product_attention(
            query=query_states,
            key=key_states,
            value=value_states,
            attn_mask=attention_mask,
            scale=self.scaling
        )

        # (bsz, num_heads, tgt_len, head_dim) -> (bsz, tgt_len, num_heads, head_dim)
        attn_output = attn_output.transpose(1, 2)

        # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
        # partitioned across GPUs when using tensor-parallelism.
        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)

        attn_output = self.out_proj(attn_output)

        # Attention weights are not available from the fused kernel -> None.
        return attn_output, None, past_key_value
|
|
|
+
|
|
|
+
|
|
|
# Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->Whisper
class FlashWhisperEncoderLayer(WhisperEncoderLayer):
    """Whisper encoder layer whose self-attention uses the fused SDPA kernel.

    Everything except ``self_attn`` (feed-forward, layer norms, dropout) is
    inherited unchanged from ``WhisperEncoderLayer``.
    """

    def __init__(self, config: WhisperConfig):
        super().__init__(config)

        # Swap the stock self-attention built by the parent constructor for
        # the SDPA-backed variant with identical hyper-parameters.
        attn_kwargs = dict(
            embed_dim=self.embed_dim,
            num_heads=config.encoder_attention_heads,
            dropout=config.attention_dropout,
        )
        self.self_attn = FlashWhisperAttention(**attn_kwargs)
|
|
|
+
|
|
|
class FlashWhisperDecoderLayer(WhisperDecoderLayer):
    """Whisper decoder layer whose *self*-attention uses the fused SDPA kernel.

    NOTE(review): only ``self_attn`` is swapped here; ``encoder_attn`` (the
    cross-attention) keeps the stock implementation — presumably so its
    attention weights remain available to callers; confirm this is intentional.
    """

    def __init__(self, config: WhisperConfig):
        super().__init__(config)

        # Replace the causal self-attention built by the parent constructor.
        attn_kwargs = dict(
            embed_dim=self.embed_dim,
            num_heads=config.decoder_attention_heads,
            dropout=config.attention_dropout,
            is_decoder=True,
        )
        self.self_attn = FlashWhisperAttention(**attn_kwargs)
|
|
|
+
|
|
|
class FlashWhisperEncoder(WhisperEncoder):
    """
    Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
    [`WhisperEncoderLayer`].

    This subclass differs from the stock ``WhisperEncoder`` in two ways:
    it builds its layer stack from ``FlashWhisperEncoderLayer`` and its
    ``forward`` slices the positional embeddings to the actual input length,
    so inputs shorter than the fixed 30 s window are accepted.

    Args:
        config: WhisperConfig
    """

    def __init__(self, config: WhisperConfig):
        super().__init__(config)

        # Replace the stock layer stack with flash-attention encoder layers;
        # all other submodules (convs, positional embeddings, final layer
        # norm) are inherited from the parent constructor.
        self.layers = nn.ModuleList([FlashWhisperEncoderLayer(config) for _ in range(config.encoder_layers)])

    def forward(
        self,
        input_features,
        attention_mask=None,
        head_mask=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        Args:
            input_features (`torch.LongTensor` of shape `(batch_size, feature_size, sequence_length)`):
                Float values of mel features extracted from the raw speech waveform. Raw speech waveform can be
                obtained by loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a
                `numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array into
                `input_features`, the [`AutoFeatureExtractor`] should be used for extracting the mel features, padding
                and conversion into a tensor of type `torch.FloatTensor`. See [`~WhisperFeatureExtractor.__call__`]
            attention_mask (`torch.Tensor`)`, *optional*):
                Whisper does not support masking of the `input_features`, this argument is preserved for compatibility,
                but it is not used. By default the silence in the input log mel spectrogram are ignored.
            head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
                Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:

                - 1 indicates the head is **not masked**,
                - 0 indicates the head is **masked**.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more detail.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        # Two conv front-end layers with GELU, as in stock Whisper.
        inputs_embeds = nn.functional.gelu(self.conv1(input_features))
        inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds))

        # (batch, channels, frames) -> (batch, frames, channels)
        inputs_embeds = inputs_embeds.permute(0, 2, 1)
        embed_pos = self.embed_positions.weight

        # Dynamic input length: add only as many positional embeddings as
        # there are frames after the convolutions, instead of requiring the
        # full fixed-length (30 s) input the stock encoder expects. This is
        # the key change of this subclass.
        hidden_states = inputs_embeds + embed_pos[None, :inputs_embeds.size(1), :]
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)

        encoder_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        # check if head_mask has a correct number of layers specified if desired
        if head_mask is not None:
            assert head_mask.size()[0] == (
                len(self.layers)
            ), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."

        for idx, encoder_layer in enumerate(self.layers):
            if output_hidden_states:
                encoder_states = encoder_states + (hidden_states,)
            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
            to_drop = False
            if self.training:
                dropout_probability = torch.rand([])
                if dropout_probability < self.layerdrop:  # skip the layer
                    to_drop = True

            if to_drop:
                layer_outputs = (None, None)
            else:
                if self.gradient_checkpointing and self.training:

                    # Closure that re-runs the layer during the backward pass
                    # (activation checkpointing to save memory).
                    def create_custom_forward(module):
                        def custom_forward(*inputs):
                            return module(*inputs, output_attentions)

                        return custom_forward

                    layer_outputs = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(encoder_layer),
                        hidden_states,
                        None,  # attention_mask is never forwarded (see docstring)
                        (head_mask[idx] if head_mask is not None else None),
                    )
                else:
                    layer_outputs = encoder_layer(
                        hidden_states,
                        None,  # attention_mask is never forwarded (see docstring)
                        layer_head_mask=(head_mask[idx] if head_mask is not None else None),
                        output_attentions=output_attentions,
                    )

            hidden_states = layer_outputs[0]

            if output_attentions:
                # NOTE(review): flash-attention layers return `None` for the
                # attention weights, so this tuple is kept only for interface
                # compatibility and will contain `None` entries.
                all_attentions = all_attentions + (layer_outputs[1],)

        hidden_states = self.layer_norm(hidden_states)

        # Simply set states to zero for attention_mask
        # hidden_states[:, 40:, :] = 0

        if output_hidden_states:
            encoder_states = encoder_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
        )
|
|
|
+
|
|
|
+
|
|
|
class FlashWhisperDecoder(WhisperDecoder):
    """
    Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a
    [`FlashWhisperDecoderLayer`].

    Only the layer stack is replaced; embeddings, positional embeddings and
    the final layer norm come from the parent constructor.

    Args:
        config: WhisperConfig
    """

    def __init__(self, config: WhisperConfig):
        super().__init__(config)

        # Rebuild the layer stack with the flash-attention decoder layers.
        flash_layers = [FlashWhisperDecoderLayer(config) for _ in range(config.decoder_layers)]
        self.layers = nn.ModuleList(flash_layers)
|
|
|
+
|
|
|
+
|
|
|
class FlashWhisperModel(WhisperModel):
    """`WhisperModel` with both halves swapped for flash-attention variants.

    The encoder additionally supports dynamic (shorter-than-30 s) inputs via
    `FlashWhisperEncoder`.
    """

    def __init__(self, config: WhisperConfig):
        super().__init__(config)

        # Replace the stock encoder/decoder, then re-run `post_init` so the
        # freshly constructed submodules get weight init / tying applied.
        self.encoder = FlashWhisperEncoder(config)
        self.decoder = FlashWhisperDecoder(config)
        self.post_init()
|
|
|
+
|
|
|
+
|
|
|
class FlashWhisperForConditionalGeneration(WhisperForConditionalGeneration):
    """Whisper speech-to-text generation head on top of `FlashWhisperModel`."""

    def __init__(self, config: WhisperConfig):
        super().__init__(config)

        # Swap the stock backbone for the flash-attention one and re-run
        # `post_init` to re-initialise / re-tie the new submodules' weights.
        self.model = FlashWhisperModel(config)
        self.post_init()
|