|
|
@@ -7,11 +7,6 @@ import torch.nn as nn
|
|
|
from einops import rearrange
|
|
|
from torch import Tensor
|
|
|
from torch.nn import functional as F
|
|
|
-from transformers.utils import is_flash_attn_2_available
|
|
|
-
|
|
|
-if is_flash_attn_2_available():
|
|
|
- from flash_attn import flash_attn_func, flash_attn_varlen_func
|
|
|
- from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
|
|
|
|
|
|
|
|
|
def find_multiple(n: int, k: int) -> int:
|
|
|
@@ -40,9 +35,6 @@ class ModelArgs:
|
|
|
num_in_codebooks: Optional[int] = None
|
|
|
codebook_padding_idx: int = 0
|
|
|
|
|
|
- # Use flash attention
|
|
|
- use_flash_attention: bool = False
|
|
|
-
|
|
|
# Gradient checkpointing
|
|
|
use_gradient_checkpointing: bool = True
|
|
|
|
|
|
@@ -225,10 +217,8 @@ class Transformer(nn.Module):
|
|
|
# Not that the causal mask here follows the definition of scaled_dot_product_attention
|
|
|
# That is, FALSE means masked out
|
|
|
# To maintain consistency, key_padding_mask use TRUE to mask out
|
|
|
- if self.config.use_flash_attention is False and key_padding_mask is not None:
|
|
|
+ if key_padding_mask is not None:
|
|
|
mask = mask & key_padding_mask[:, None, None, :].logical_not()
|
|
|
- elif self.config.use_flash_attention is True and key_padding_mask is not None:
|
|
|
- mask = key_padding_mask.logical_not()
|
|
|
|
|
|
return self.compute(x, freqs_cis, mask)
|
|
|
|
|
|
@@ -283,7 +273,6 @@ class Attention(nn.Module):
|
|
|
self.head_dim = config.head_dim
|
|
|
self.n_local_heads = config.n_local_heads
|
|
|
self.dim = config.dim
|
|
|
- self.use_flash_attention = config.use_flash_attention
|
|
|
self._register_load_state_dict_pre_hook(self.load_hook)
|
|
|
|
|
|
def load_hook(self, state_dict, prefix, *args):
|
|
|
@@ -312,171 +301,24 @@ class Attention(nn.Module):
|
|
|
q = apply_rotary_emb(q, freqs_cis)
|
|
|
k = apply_rotary_emb(k, freqs_cis)
|
|
|
|
|
|
- if self.use_flash_attention is False:
|
|
|
- q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
|
|
|
-
|
|
|
- if self.kv_cache is not None:
|
|
|
- k, v = self.kv_cache.update(input_pos, k, v)
|
|
|
-
|
|
|
- k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
|
|
|
- v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
|
|
|
- y = F.scaled_dot_product_attention(
|
|
|
- q,
|
|
|
- k,
|
|
|
- v,
|
|
|
- attn_mask=mask,
|
|
|
- dropout_p=self.dropout if self.training else 0.0,
|
|
|
- )
|
|
|
-
|
|
|
- y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
|
|
|
- else:
|
|
|
- assert (
|
|
|
- self.kv_cache is None
|
|
|
- ), "kv_cache is not supported for flash attention for now"
|
|
|
-
|
|
|
- # We don't need to transpose q, k, v here because flash_attn_varlen_func
|
|
|
- attn_output = self._flash_attention_forward(
|
|
|
- q, k, v, mask, seqlen, dropout=self.dropout if self.training else 0.0
|
|
|
- )
|
|
|
-
|
|
|
- y = attn_output.reshape(bsz, seqlen, self.dim).contiguous()
|
|
|
-
|
|
|
- return self.wo(y)
|
|
|
-
|
|
|
- def _flash_attention_forward(
|
|
|
- self,
|
|
|
- query_states,
|
|
|
- key_states,
|
|
|
- value_states,
|
|
|
- attention_mask,
|
|
|
- query_length,
|
|
|
- dropout=0.0,
|
|
|
- softmax_scale=None,
|
|
|
- ):
|
|
|
- """
|
|
|
- Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
|
|
|
- first unpad the input, then computes the attention scores and pad the final attention scores.
|
|
|
-
|
|
|
- Args:
|
|
|
- query_states (`torch.Tensor`):
|
|
|
- Input query states to be passed to Flash Attention API
|
|
|
- key_states (`torch.Tensor`):
|
|
|
- Input key states to be passed to Flash Attention API
|
|
|
- value_states (`torch.Tensor`):
|
|
|
- Input value states to be passed to Flash Attention API
|
|
|
- attention_mask (`torch.Tensor`):
|
|
|
- The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
|
|
|
- position of padding tokens and 1 for the position of non-padding tokens.
|
|
|
- dropout (`int`, *optional*):
|
|
|
- Attention dropout
|
|
|
- softmax_scale (`float`, *optional*):
|
|
|
- The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
|
|
|
- """
|
|
|
-
|
|
|
- # Contains at least one padding token in the sequence
|
|
|
- if attention_mask is not None:
|
|
|
- batch_size = query_states.shape[0]
|
|
|
- (
|
|
|
- query_states,
|
|
|
- key_states,
|
|
|
- value_states,
|
|
|
- indices_q,
|
|
|
- cu_seq_lens,
|
|
|
- max_seq_lens,
|
|
|
- ) = self._upad_input(
|
|
|
- query_states, key_states, value_states, attention_mask, query_length
|
|
|
- )
|
|
|
-
|
|
|
- cu_seqlens_q, cu_seqlens_k = cu_seq_lens
|
|
|
- max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
|
|
|
-
|
|
|
- attn_output_unpad = flash_attn_varlen_func(
|
|
|
- query_states,
|
|
|
- key_states,
|
|
|
- value_states,
|
|
|
- cu_seqlens_q=cu_seqlens_q,
|
|
|
- cu_seqlens_k=cu_seqlens_k,
|
|
|
- max_seqlen_q=max_seqlen_in_batch_q,
|
|
|
- max_seqlen_k=max_seqlen_in_batch_k,
|
|
|
- dropout_p=dropout,
|
|
|
- softmax_scale=softmax_scale,
|
|
|
- causal=True,
|
|
|
- )
|
|
|
-
|
|
|
- attn_output = pad_input(
|
|
|
- attn_output_unpad, indices_q, batch_size, query_length
|
|
|
- )
|
|
|
- else:
|
|
|
- attn_output = flash_attn_func(
|
|
|
- query_states,
|
|
|
- key_states,
|
|
|
- value_states,
|
|
|
- dropout,
|
|
|
- softmax_scale=softmax_scale,
|
|
|
- causal=True,
|
|
|
- )
|
|
|
+ q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
|
|
|
|
|
|
- return attn_output
|
|
|
+ if self.kv_cache is not None:
|
|
|
+ k, v = self.kv_cache.update(input_pos, k, v)
|
|
|
|
|
|
- def _get_unpad_data(self, attention_mask):
|
|
|
- seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
|
|
|
- indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
|
|
|
- max_seqlen_in_batch = seqlens_in_batch.max().item()
|
|
|
- cu_seqlens = F.pad(
|
|
|
- torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)
|
|
|
- )
|
|
|
- return (
|
|
|
- indices,
|
|
|
- cu_seqlens,
|
|
|
- max_seqlen_in_batch,
|
|
|
+ k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
|
|
|
+ v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
|
|
|
+ y = F.scaled_dot_product_attention(
|
|
|
+ q,
|
|
|
+ k,
|
|
|
+ v,
|
|
|
+ attn_mask=mask,
|
|
|
+ dropout_p=self.dropout if self.training else 0.0,
|
|
|
)
|
|
|
|
|
|
- def _upad_input(
|
|
|
- self, query_layer, key_layer, value_layer, attention_mask, query_length
|
|
|
- ):
|
|
|
- indices_k, cu_seqlens_k, max_seqlen_in_batch_k = self._get_unpad_data(
|
|
|
- attention_mask
|
|
|
- )
|
|
|
- batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
|
|
|
+ y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
|
|
|
|
|
|
- key_layer = index_first_axis(
|
|
|
- key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
|
|
|
- indices_k,
|
|
|
- )
|
|
|
- value_layer = index_first_axis(
|
|
|
- value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
|
|
|
- indices_k,
|
|
|
- )
|
|
|
- if query_length == kv_seq_len:
|
|
|
- query_layer = index_first_axis(
|
|
|
- query_layer.reshape(batch_size * kv_seq_len, self.n_head, head_dim),
|
|
|
- indices_k,
|
|
|
- )
|
|
|
- cu_seqlens_q = cu_seqlens_k
|
|
|
- max_seqlen_in_batch_q = max_seqlen_in_batch_k
|
|
|
- indices_q = indices_k
|
|
|
- elif query_length == 1:
|
|
|
- max_seqlen_in_batch_q = 1
|
|
|
- cu_seqlens_q = torch.arange(
|
|
|
- batch_size + 1, dtype=torch.int32, device=query_layer.device
|
|
|
- ) # There is a memcpy here, that is very bad.
|
|
|
- indices_q = cu_seqlens_q[:-1]
|
|
|
- query_layer = query_layer.squeeze(1)
|
|
|
- else:
|
|
|
- # The -q_len: slice assumes left padding.
|
|
|
- attention_mask = attention_mask[:, -query_length:]
|
|
|
- query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(
|
|
|
- query_layer, attention_mask
|
|
|
- )
|
|
|
-
|
|
|
- return (
|
|
|
- query_layer,
|
|
|
- key_layer,
|
|
|
- value_layer,
|
|
|
- indices_q,
|
|
|
- (cu_seqlens_q, cu_seqlens_k),
|
|
|
- (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
|
|
|
- )
|
|
|
+ return self.wo(y)
|
|
|
|
|
|
|
|
|
class FeedForward(nn.Module):
|