|
@@ -76,19 +76,23 @@ def train(cfg: DictConfig) -> tuple[dict, dict]:
|
|
|
log.info("Starting training!")
|
|
log.info("Starting training!")
|
|
|
|
|
|
|
|
ckpt_path = cfg.get("ckpt_path")
|
|
ckpt_path = cfg.get("ckpt_path")
|
|
|
|
|
+ auto_resume = False
|
|
|
|
|
|
|
|
if ckpt_path is None:
|
|
if ckpt_path is None:
|
|
|
ckpt_path = utils.get_latest_checkpoint(cfg.paths.ckpt_dir)
|
|
ckpt_path = utils.get_latest_checkpoint(cfg.paths.ckpt_dir)
|
|
|
|
|
+ auto_resume = True
|
|
|
|
|
|
|
|
if ckpt_path is not None:
|
|
if ckpt_path is not None:
|
|
|
log.info(f"Resuming from checkpoint: {ckpt_path}")
|
|
log.info(f"Resuming from checkpoint: {ckpt_path}")
|
|
|
|
|
|
|
|
- if cfg.get("resume_weights_only"):
|
|
|
|
|
|
|
+ # resume weights only is disabled for auto-resume
|
|
|
|
|
+ if cfg.get("resume_weights_only") and auto_resume is False:
|
|
|
log.info("Resuming weights only!")
|
|
log.info("Resuming weights only!")
|
|
|
ckpt = torch.load(ckpt_path, map_location=model.device)
|
|
ckpt = torch.load(ckpt_path, map_location=model.device)
|
|
|
- model.load_state_dict(
|
|
|
|
|
- ckpt["state_dict"] if "state_dict" in ckpt else ckpt, strict=False
|
|
|
|
|
- )
|
|
|
|
|
|
|
+ if "state_dict" in ckpt:
|
|
|
|
|
+ ckpt = ckpt["state_dict"]
|
|
|
|
|
+ err = model.load_state_dict(ckpt, strict=False)
|
|
|
|
|
+ log.info(f"Error loading state dict: {err}")
|
|
|
ckpt_path = None
|
|
ckpt_path = None
|
|
|
|
|
|
|
|
trainer.fit(model=model, datamodule=datamodule, ckpt_path=ckpt_path)
|
|
trainer.fit(model=model, datamodule=datamodule, ckpt_path=ckpt_path)
|