@@ -6,6 +6,11 @@ import torch.nn as nn
 from einops import rearrange
 from torch import Tensor
 from torch.nn import functional as F
+from transformers.utils import is_flash_attn_2_available
+
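+# flash-attn is an optional dependency; only import its kernels when the package is installed.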
+if is_flash_attn_2_available():
+    from flash_attn import flash_attn_func, flash_attn_varlen_func
+    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
 
 
 def find_multiple(n: int, k: int) -> int:
@@ -32,6 +37,12 @@ class ModelArgs:
     num_codebooks: int = 4
     codebook_padding_idx: int = 0
 
+    # Use flash attention
+    use_flash_attention: bool = is_flash_attn_2_available()
+
+    # Gradient checkpointing
+    use_gradient_checkpointing: bool = True
+
     def __post_init__(self):
         if self.n_local_heads == -1:
             self.n_local_heads = self.n_head
@@ -154,7 +165,12 @@ class Transformer(nn.Module):
         input_pos: Optional[Tensor] = None,
     ) -> TransformerForwardResult:
         for layer in self.layers:
-            x = layer(x, freqs_cis, mask, input_pos=input_pos)
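+            # Checkpointing recomputes each layer's activations in the backward pass,
+            # trading extra compute for lower activation memory (training only).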
+            if self.config.use_gradient_checkpointing and self.training:
+                x = torch.utils.checkpoint.checkpoint(
+                    layer, x, freqs_cis, mask, input_pos, use_reentrant=True
+                )
+            else:
+                x = layer(x, freqs_cis, mask, input_pos=input_pos)
 
         x = self.norm(x)
         logits = self.output(x)
@@ -191,8 +207,10 @@ class Transformer(nn.Module):
         # Note that the causal mask here follows the definition of scaled_dot_product_attention
         # That is, FALSE means masked out
         # To maintain consistency, key_padding_mask uses TRUE to mask out
-        if key_padding_mask is not None:
+        if self.config.use_flash_attention is False and key_padding_mask is not None:
             mask = mask & key_padding_mask[:, None, None, :].logical_not()
+        elif self.config.use_flash_attention is True and key_padding_mask is not None:
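+            # The flash-attention path keeps a (batch, seq_len) bool mask with True
+            # marking real tokens; _upad_input later uses it to drop padded positions.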
+            mask = key_padding_mask.logical_not()
 
         return self.compute(x, freqs_cis, mask)
 
@@ -246,6 +264,7 @@ class Attention(nn.Module):
         self.head_dim = config.head_dim
         self.n_local_heads = config.n_local_heads
         self.dim = config.dim
+        self.use_flash_attention = config.use_flash_attention
         self._register_load_state_dict_pre_hook(self.load_hook)
 
     def load_hook(self, state_dict, prefix, *args):
@@ -274,19 +293,165 @@ class Attention(nn.Module):
         q = apply_rotary_emb(q, freqs_cis)
         k = apply_rotary_emb(k, freqs_cis)
 
-        q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
+        if self.use_flash_attention is False:
+            q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
+
+            if self.kv_cache is not None:
+                k, v = self.kv_cache.update(input_pos, k, v)
+
+            k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
+            v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
+            y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
+
+            y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
+        else:
+            assert (
+                self.kv_cache is None
+            ), "kv_cache is not supported for flash attention for now"
+
+            # We don't need to transpose q, k, v here: flash attention expects inputs
+            # of shape (bsz, seqlen, n_heads, head_dim)
+            attn_output = self._flash_attention_forward(
+                q, k, v, mask, seqlen, dropout=0.0
+            )
+
+            y = attn_output.reshape(bsz, seqlen, self.dim).contiguous()
+
+        return self.wo(y)
+
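+    # The helpers below follow the unpad -> flash_attn_varlen_func -> pad flow used by
+    # Hugging Face transformers' flash-attention integration.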
+    def _flash_attention_forward(
+        self,
+        query_states,
+        key_states,
+        value_states,
+        attention_mask,
+        query_length,
+        dropout=0.0,
+        softmax_scale=None,
+    ):
+        """
+        Calls the forward method of Flash Attention. If the input hidden states contain at least one
+        padding token, the input is first unpadded, attention is computed on the packed sequences, and
+        the output is padded back to the original shape.
+
+        Args:
+            query_states (`torch.Tensor`):
+                Input query states to be passed to Flash Attention API
+            key_states (`torch.Tensor`):
+                Input key states to be passed to Flash Attention API
+            value_states (`torch.Tensor`):
+                Input value states to be passed to Flash Attention API
+            attention_mask (`torch.Tensor`):
+                The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
+                position of padding tokens and 1 for the position of non-padding tokens.
+            query_length (`int`):
+                Length of the query sequence.
+            dropout (`float`, *optional*):
+                Attention dropout probability.
+            softmax_scale (`float`, *optional*):
+                The scaling of QK^T before applying softmax. Defaults to 1 / sqrt(head_dim).
+        """
+
+        # Contains at least one padding token in the sequence
+        if attention_mask is not None:
+            batch_size = query_states.shape[0]
+            (
+                query_states,
+                key_states,
+                value_states,
+                indices_q,
+                cu_seq_lens,
+                max_seq_lens,
+            ) = self._upad_input(
+                query_states, key_states, value_states, attention_mask, query_length
+            )
+
+            cu_seqlens_q, cu_seqlens_k = cu_seq_lens
+            max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
+
+            attn_output_unpad = flash_attn_varlen_func(
+                query_states,
+                key_states,
+                value_states,
+                cu_seqlens_q=cu_seqlens_q,
+                cu_seqlens_k=cu_seqlens_k,
+                max_seqlen_q=max_seqlen_in_batch_q,
+                max_seqlen_k=max_seqlen_in_batch_k,
+                dropout_p=dropout,
+                softmax_scale=softmax_scale,
+                causal=True,
+            )
 
-        if self.kv_cache is not None:
-            k, v = self.kv_cache.update(input_pos, k, v)
+            attn_output = pad_input(
+                attn_output_unpad, indices_q, batch_size, query_length
+            )
+        else:
+            attn_output = flash_attn_func(
+                query_states,
+                key_states,
+                value_states,
+                dropout,
+                softmax_scale=softmax_scale,
+                causal=True,
+            )
 
-        k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
-        v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
-        y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
+        return attn_output
 
-        y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
+    def _get_unpad_data(self, attention_mask):
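+        # cu_seqlens holds cumulative sequence lengths, shape (batch + 1,), which is
+        # the packed-layout metadata expected by flash_attn_varlen_func.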
+        seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
+        indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
+        max_seqlen_in_batch = seqlens_in_batch.max().item()
+        cu_seqlens = F.pad(
+            torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
+        )
+        return (
+            indices,
+            cu_seqlens,
+            max_seqlen_in_batch,
+        )
 
-        y = self.wo(y)
-        return y
+    def _upad_input(
+        self, query_layer, key_layer, value_layer, attention_mask, query_length
+    ):
+        indices_k, cu_seqlens_k, max_seqlen_in_batch_k = self._get_unpad_data(
+            attention_mask
+        )
+        batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
+
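+        # index_first_axis gathers only the rows whose mask bit is set, packing the
+        # flattened (batch * seq_len, heads, head_dim) layout down to (total_tokens, heads, head_dim).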
+        key_layer = index_first_axis(
+            key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
+            indices_k,
+        )
+        value_layer = index_first_axis(
+            value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
+            indices_k,
+        )
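+        # Queries fall into three cases: full-length prefill, a single decoded token,
+        # or a shorter (left-padded) query window.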
+        if query_length == kv_seq_len:
+            query_layer = index_first_axis(
+                query_layer.reshape(batch_size * kv_seq_len, self.n_head, head_dim),
+                indices_k,
+            )
+            cu_seqlens_q = cu_seqlens_k
+            max_seqlen_in_batch_q = max_seqlen_in_batch_k
+            indices_q = indices_k
+        elif query_length == 1:
+            max_seqlen_in_batch_q = 1
+            cu_seqlens_q = torch.arange(
+                batch_size + 1, dtype=torch.int32, device=query_layer.device
+            )  # There is a memcpy here, that is very bad.
+            indices_q = cu_seqlens_q[:-1]
+            query_layer = query_layer.squeeze(1)
+        else:
+            # The -q_len: slice assumes left padding.
+            attention_mask = attention_mask[:, -query_length:]
+            query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(
+                query_layer, attention_mask
+            )
+
+        return (
+            query_layer,
+            key_layer,
+            value_layer,
+            indices_q,
+            (cu_seqlens_q, cu_seqlens_k),
+            (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
+        )
 
 
 class FeedForward(nn.Module):