@@ -1,44 +1,42 @@
 # A whisper that supports flash-attention and dynamic input length.
+from typing import Optional, Tuple, Union
+
+import numpy as np
 import torch
+import torch.nn.functional as F
+from torch import nn
+from torch.nn import CrossEntropyLoss
+from transformers.generation.logits_process import WhisperTimeStampLogitsProcessor
+from transformers.modeling_outputs import BaseModelOutput
 from transformers.models.whisper.modeling_whisper import (
-    WhisperPreTrainedModel,
-    WhisperConfig,
-    WHISPER_START_DOCSTRING,
     WHISPER_INPUTS_DOCSTRING,
-    WhisperModel,
-    shift_tokens_right,
-    _dynamic_time_warping,
-    _median_filter,
+    WHISPER_START_DOCSTRING,
     WhisperAttention,
+    WhisperConfig,
+    WhisperDecoder,
+    WhisperDecoderLayer,
     WhisperEncoder,
-    WhisperModel,
-    WhisperPreTrainedModel,
     WhisperEncoderLayer,
     WhisperForConditionalGeneration,
-    WhisperDecoder,
-    WhisperDecoderLayer
+    WhisperModel,
+    WhisperPreTrainedModel,
+    _dynamic_time_warping,
+    _median_filter,
+    shift_tokens_right,
 )
-from transformers.models.whisper.tokenization_whisper import TO_LANGUAGE_CODE, TASK_IDS
-from torch.nn import CrossEntropyLoss
-from torch import nn
-from typing import Optional, Tuple, Union
+from transformers.models.whisper.tokenization_whisper import TASK_IDS, TO_LANGUAGE_CODE
 from transformers.utils import (
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
     logging,
     replace_return_docstrings,
 )
-from transformers.modeling_outputs import (
-    BaseModelOutput,
-)
-from transformers.generation.logits_process import WhisperTimeStampLogitsProcessor
-import numpy as np
-import torch.nn.functional as F
 
 logger = logging.get_logger(__name__)
 
 _CONFIG_FOR_DOC = "WhisperConfig"
 
+
 class FlashWhisperAttention(WhisperAttention):
     """Multi-headed attention from 'Attention Is All You Need' paper"""
 
@@ -105,7 +103,7 @@ class FlashWhisperAttention(WhisperAttention):
             key=key_states,
             value=value_states,
             attn_mask=attention_mask,
-            scale=self.scaling
+            scale=self.scaling,
         )
 
         attn_output = attn_output.transpose(1, 2)
@@ -130,6 +128,7 @@ class FlashWhisperEncoderLayer(WhisperEncoderLayer):
             dropout=config.attention_dropout,
         )
 
+
 class FlashWhisperDecoderLayer(WhisperDecoderLayer):
     def __init__(self, config: WhisperConfig):
         super().__init__(config)
@@ -141,6 +140,7 @@ class FlashWhisperDecoderLayer(WhisperDecoderLayer):
             is_decoder=True,
         )
 
+
 class FlashWhisperEncoder(WhisperEncoder):
     """
     Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
@@ -152,8 +152,10 @@ class FlashWhisperEncoder(WhisperEncoder):
 
     def __init__(self, config: WhisperConfig):
         super().__init__(config)
-
-        self.layers = nn.ModuleList([FlashWhisperEncoderLayer(config) for _ in range(config.encoder_layers)])
+
+        self.layers = nn.ModuleList(
+            [FlashWhisperEncoderLayer(config) for _ in range(config.encoder_layers)]
+        )
 
     def forward(
         self,
@@ -189,19 +191,29 @@ class FlashWhisperEncoder(WhisperEncoder):
             return_dict (`bool`, *optional*):
                 Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
         """
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_attentions = (
+            output_attentions
+            if output_attentions is not None
+            else self.config.output_attentions
+        )
         output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+            output_hidden_states
+            if output_hidden_states is not None
+            else self.config.output_hidden_states
+        )
+        return_dict = (
+            return_dict if return_dict is not None else self.config.use_return_dict
         )
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
         inputs_embeds = nn.functional.gelu(self.conv1(input_features))
         inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds))
 
         inputs_embeds = inputs_embeds.permute(0, 2, 1)
         embed_pos = self.embed_positions.weight
 
-        hidden_states = inputs_embeds + embed_pos[None, :inputs_embeds.size(1), :]
-        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+        hidden_states = inputs_embeds + embed_pos[None, : inputs_embeds.size(1), :]
+        hidden_states = nn.functional.dropout(
+            hidden_states, p=self.dropout, training=self.training
+        )
 
         encoder_states = () if output_hidden_states else None
         all_attentions = () if output_attentions else None
@@ -243,7 +255,9 @@ class FlashWhisperEncoder(WhisperEncoder):
                 layer_outputs = encoder_layer(
                     hidden_states,
                     None,
-                    layer_head_mask=(head_mask[idx] if head_mask is not None else None),
+                    layer_head_mask=(
+                        head_mask[idx] if head_mask is not None else None
+                    ),
                     output_attentions=output_attentions,
                 )
 
@@ -261,9 +275,15 @@ class FlashWhisperEncoder(WhisperEncoder):
             encoder_states = encoder_states + (hidden_states,)
 
         if not return_dict:
-            return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
+            return tuple(
+                v
+                for v in [hidden_states, encoder_states, all_attentions]
+                if v is not None
+            )
         return BaseModelOutput(
-            last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
+            last_hidden_state=hidden_states,
+            hidden_states=encoder_states,
+            attentions=all_attentions,
         )
 
 
@@ -279,7 +299,9 @@ class FlashWhisperDecoder(WhisperDecoder):
     def __init__(self, config: WhisperConfig):
         super().__init__(config)
 
-        self.layers = nn.ModuleList([FlashWhisperDecoderLayer(config) for _ in range(config.decoder_layers)])
+        self.layers = nn.ModuleList(
+            [FlashWhisperDecoderLayer(config) for _ in range(config.decoder_layers)]
+        )
 
 
 class FlashWhisperModel(WhisperModel):