@@ -84,7 +84,7 @@ def train(
{
"model": model,
"optimizer": optimizer,
- "scheduler": scheduler,
+ "scheduler": scheduler.state_dict(),
"global_step": global_step,
},
)