Przeglądaj źródła

Fix convnext init

Lengyue 2 lat temu
rodzic
commit
03a4247d68
1 zmienionych plików z 9 dodań i 4 usunięć
  1. 9 4
      fish_speech/models/vq_diffusion/convnext_1d.py

+ 9 - 4
fish_speech/models/vq_diffusion/convnext_1d.py

@@ -163,10 +163,6 @@ class ConvNext1DModel(ModelMixin, ConfigMixin):
         self.in_proj = nn.Conv1d(in_channels, intermediate_dim, 1)
         self.out_proj = nn.Conv1d(intermediate_dim, out_channels, 1)
 
-        # Initialize weights
-        nn.init.normal_(self.out_proj.weight, mean=0, std=0.01)
-        nn.init.zeros_(self.out_proj.bias)
-
         # Blocks
         self.blocks = nn.ModuleList(
             [
@@ -179,11 +175,20 @@ class ConvNext1DModel(ModelMixin, ConfigMixin):
             ]
         )
 
+        # Initialize weights
+        self.apply(self._init_weights)
+
         self.gradient_checkpointing = False
 
     def _set_gradient_checkpointing(self, module, value: bool = False):
         self.gradient_checkpointing = value
 
+    def _init_weights(self, m):
+        if isinstance(m, (nn.Conv2d, nn.Linear, nn.Conv1d)):
+            nn.init.trunc_normal_(m.weight, mean=0, std=0.02)
+            if m.bias is not None:
+                nn.init.zeros_(m.bias)
+
     def forward(
         self,
         sample: torch.FloatTensor,