|
|
@@ -1,3 +1,4 @@
|
|
|
+import math
|
|
|
from dataclasses import dataclass
|
|
|
from typing import Optional
|
|
|
|
|
|
@@ -44,6 +45,9 @@ class ModelArgs:
|
|
|
# Gradient checkpointing
|
|
|
use_gradient_checkpointing: bool = True
|
|
|
|
|
|
+ # NEFT
|
|
|
+ neft_alpha: float = 0
|
|
|
+
|
|
|
def __post_init__(self):
|
|
|
if self.n_local_heads == -1:
|
|
|
self.n_local_heads = self.n_head
|
|
|
@@ -156,7 +160,14 @@ class Transformer(nn.Module):
|
|
|
vocab_embeds.append(emb)
|
|
|
|
|
|
x = torch.stack(vocab_embeds, dim=3)
|
|
|
- return x.sum(dim=3)
|
|
|
+ x = x.sum(dim=3)
|
|
|
+
|
|
|
+ if self.config.neft_alpha > 0:
|
|
|
+ # alpha / sqrt(L * D)
|
|
|
+ scaled_alpha = self.config.neft_alpha / math.sqrt(
|
|
|
+ self.config.dim * x.shape[2]
|
|
|
+ )
|
|
|
+ x += torch.rand_like(x) * scaled_alpha
|
|
|
|
|
|
def compute(
|
|
|
self,
|