@@ -96,7 +96,27 @@ class TextToSemantic(L.LightningModule):
                 state_dict.pop(name)

     def configure_optimizers(self) -> OptimizerLRScheduler:
-        optimizer = self.optimizer_builder(self.parameters())
+        # Exclude biases, normalization weights, and embeddings from weight decay
+        weight_decay_parameters, other_parameters = [], []
+        for name, param in self.named_parameters():
+            if ".bias" in name or "norm.weight" in name or ".embeddings." in name:
+                other_parameters.append(param)
+            else:
+                weight_decay_parameters.append(param)
+
+        optimizer = self.optimizer_builder(
+            [
+                {"params": weight_decay_parameters},
+                {"params": other_parameters, "weight_decay": 0.0},
+            ]
+        )
+
+        # Log each parameter group's weight decay and size
+        for group in optimizer.param_groups:
+            log.info(
+                f"Set weight decay: {group['weight_decay']} for {len(group['params'])} parameters"
+            )
+
         lr_scheduler = self.lr_scheduler_builder(optimizer)

         return {
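For reference, here is a minimal, self-contained sketch of how these parameter groups behave once they reach a torch optimizer. The `Block` module and the `partial`-based stand-in for `optimizer_builder` are illustrative assumptions, not code from the patched module; the real builder is configured elsewhere. The key mechanism is that groups omitting `"weight_decay"` inherit the optimizer's default, which is what lets the first group decay while the second stays exempt:

```python
# Sketch only: `Block` and the AdamW-based stand-in for `optimizer_builder`
# are assumptions for illustration, not part of the patched module.
from functools import partial

import torch
from torch import nn


class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4)
        self.norm = nn.LayerNorm(4)


model = Block()

# Same split as the diff: biases and norm weights are exempt from decay.
weight_decay_parameters, other_parameters = [], []
for name, param in model.named_parameters():
    if ".bias" in name or "norm.weight" in name:
        other_parameters.append(param)
    else:
        weight_decay_parameters.append(param)

# Stand-in builder: groups that omit "weight_decay" inherit the default.
optimizer_builder = partial(torch.optim.AdamW, lr=1e-3, weight_decay=0.01)
optimizer = optimizer_builder(
    [
        {"params": weight_decay_parameters},  # decays at the default 0.01
        {"params": other_parameters, "weight_decay": 0.0},  # exempt
    ]
)

for group in optimizer.param_groups:
    print(f"weight_decay={group['weight_decay']} for {len(group['params'])} parameters")
# -> weight_decay=0.01 for 1 parameters
# -> weight_decay=0.0 for 3 parameters
```

Exempting biases, normalization weights, and embeddings from weight decay is a common convention in transformer training, where decaying these parameters tends to hurt rather than regularize.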