Lengyue 2 года назад
Родитель
Коммит
8f9299673d
1 изменённый файл: 21 добавлено и 1 удалено
  1. 21 1
      fish_speech/models/text2semantic/lit_module.py

+ 21 - 1
fish_speech/models/text2semantic/lit_module.py

@@ -96,7 +96,27 @@ class TextToSemantic(L.LightningModule):
                 state_dict.pop(name)
 
     def configure_optimizers(self) -> OptimizerLRScheduler:
-        optimizer = self.optimizer_builder(self.parameters())
+        # Get weight decay parameters
+        weight_decay_parameters, other_parameters = [], []
+        for name, param in self.named_parameters():
+            if ".bias" in name or "norm.weight" in name or ".embeddings." in name:
+                other_parameters.append(param)
+            else:
+                weight_decay_parameters.append(param)
+
+        optimizer = self.optimizer_builder(
+            [
+                {"params": weight_decay_parameters},
+                {"params": other_parameters, "weight_decay": 0.0},
+            ]
+        )
+
+        # Print the parameters and their weight decay
+        for i in optimizer.param_groups:
+            log.info(
+                f"Set weight decay: {i['weight_decay']} for {len(i['params'])} parameters"
+            )
+
         lr_scheduler = self.lr_scheduler_builder(optimizer)
 
         return {