|
|
@@ -163,10 +163,6 @@ class ConvNext1DModel(ModelMixin, ConfigMixin):
|
|
|
self.in_proj = nn.Conv1d(in_channels, intermediate_dim, 1)
|
|
|
self.out_proj = nn.Conv1d(intermediate_dim, out_channels, 1)
|
|
|
|
|
|
- # Initialize weights
|
|
|
- nn.init.normal_(self.out_proj.weight, mean=0, std=0.01)
|
|
|
- nn.init.zeros_(self.out_proj.bias)
|
|
|
-
|
|
|
# Blocks
|
|
|
self.blocks = nn.ModuleList(
|
|
|
[
|
|
|
@@ -179,11 +175,20 @@ class ConvNext1DModel(ModelMixin, ConfigMixin):
|
|
|
]
|
|
|
)
|
|
|
|
|
|
+ # Initialize weights
|
|
|
+ self.apply(self._init_weights)
|
|
|
+
|
|
|
self.gradient_checkpointing = False
|
|
|
|
|
|
def _set_gradient_checkpointing(self, module, value: bool = False):
|
|
|
self.gradient_checkpointing = value
|
|
|
|
|
|
+ def _init_weights(self, m):
|
|
|
+ if isinstance(m, (nn.Conv2d, nn.Linear, nn.Conv1d)):
|
|
|
+ nn.init.trunc_normal_(m.weight, mean=0, std=0.02)
|
|
|
+ if m.bias is not None:
|
|
|
+ nn.init.zeros_(m.bias)
|
|
|
+
|
|
|
def forward(
|
|
|
self,
|
|
|
sample: torch.FloatTensor,
|