@@ -8,6 +8,7 @@ from omegaconf import DictConfig, OmegaConf
 from tqdm import tqdm
 from transformers import LlamaForCausalLM
 from transformers.utils import is_flash_attn_available
+from natsort import natsorted
 
 # Allow TF32 on Ampere GPUs
 torch.set_float32_matmul_precision("high")
@@ -125,8 +126,12 @@ def main(cfg: DictConfig):
     # Restore training from checkpoint
     checkpoint_dir = Path(cfg.paths.checkpoint_dir)
     checkpoint_dir.mkdir(parents=True, exist_ok=True)
-    checkpoint_path = checkpoint_dir / "last.ckpt"
-    if checkpoint_path.exists():
+
+    # Sort checkpoints in natural order and resume from the most recent
+    checkpoints = natsorted(checkpoint_dir.glob("*.ckpt"))
+    if len(checkpoints) > 0:
+        checkpoint_path = checkpoints[-1]
+
         log.info(f"Restoring checkpoint from {checkpoint_path}")
         remainder = fabric.load(
             checkpoint_path,
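
The switch from a hard-coded `last.ckpt` to `natsorted` matters because plain lexicographic sorting mis-orders numbered filenames. A minimal sketch of the difference, assuming checkpoints carry numeric suffixes like `epoch-2.ckpt` (the actual naming scheme depends on the training config):

```python
from natsort import natsorted

# Hypothetical checkpoint names; the real pattern comes from the run config.
names = ["epoch-10.ckpt", "epoch-2.ckpt", "epoch-9.ckpt"]

print(sorted(names))     # ['epoch-10.ckpt', 'epoch-2.ckpt', 'epoch-9.ckpt'] -- lexicographic: "10" < "2"
print(natsorted(names))  # ['epoch-2.ckpt', 'epoch-9.ckpt', 'epoch-10.ckpt'] -- numeric-aware

# With natural ordering, checkpoints[-1] in the diff above points at the
# latest checkpoint, so training resumes from the most recent state.
```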