|
|
@@ -166,7 +166,7 @@ class Transformer(nn.Module):
|
|
|
x = torch.stack(vocab_embeds, dim=3)
|
|
|
x = x.sum(dim=3)
|
|
|
|
|
|
- if self.config.neft_alpha > 0:
|
|
|
+ if self.config.neft_alpha > 0 and self.training:
|
|
|
# alpha / sqrt(L * D)
|
|
|
scaled_alpha = self.config.neft_alpha / math.sqrt(
|
|
|
self.config.dim * x.shape[2]
|