@@ -31,6 +31,7 @@ class ModelArgs:
     rope_base: float = 10000
     norm_eps: float = 1e-5
     max_seq_len: int = 2048
+    dropout: float = 0.0

     # Additional decoding heads
     codebook_size: int = 160
@@ -260,6 +261,7 @@ class Attention(nn.Module):
         self.wo = nn.Linear(config.dim, config.dim, bias=False)
         self.kv_cache = None

+        self.dropout = config.dropout
         self.n_head = config.n_head
         self.head_dim = config.head_dim
         self.n_local_heads = config.n_local_heads
@@ -301,7 +303,13 @@ class Attention(nn.Module):

             k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
             v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
-            y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
+            y = F.scaled_dot_product_attention(
+                q,
+                k,
+                v,
+                attn_mask=mask,
+                dropout_p=self.dropout if self.training else 0.0,
+            )

             y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
         else:
@@ -311,7 +319,7 @@ class Attention(nn.Module):

             # We don't need to transpose q, k, v here because flash_attn_varlen_func
             attn_output = self._flash_attention_forward(
-                q, k, v, mask, seqlen, dropout=0.0
+                q, k, v, mask, seqlen, dropout=self.dropout if self.training else 0.0
             )

             y = attn_output.reshape(bsz, seqlen, self.dim).contiguous()
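
A minimal standalone sketch of what the guarded dropout in the diff does; the tensor shapes and the dropout/training values below are illustrative assumptions, not taken from the repository:

    import torch
    import torch.nn.functional as F

    # Illustrative shapes: (batch, heads, seq_len, head_dim); not from the repo.
    q = torch.randn(1, 8, 16, 64)
    k = torch.randn(1, 8, 16, 64)
    v = torch.randn(1, 8, 16, 64)

    dropout = 0.1    # would come from ModelArgs.dropout in the patched code
    training = True  # nn.Module.training flag; False after model.eval()

    # Mirrors the diff: dropout is applied to the attention weights only while
    # training, and is forced to 0.0 at inference time.
    y = F.scaled_dot_product_attention(
        q, k, v, attn_mask=None, dropout_p=dropout if training else 0.0
    )
    print(y.shape)  # torch.Size([1, 8, 16, 64])

Because the check uses `self.training`, no config change is needed to disable dropout for generation; calling `model.eval()` is enough.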