Lengyue 2 лет назад
Родитель
Сommit
60ac11d7fe
3 измененных файлов с 11 добавлено и 5 удалено
  1. 1 1
      dockerfile
  2. 3 2
      requirements.txt
  3. 7 2
      speech_lm/train.py

+ 1 - 1
dockerfile

@@ -29,6 +29,6 @@ RUN pip3 install --upgrade pip && \
 # Project Env
 WORKDIR /exp
 COPY requirements.txt .
-RUN pip3 install -r requirements.txt
+RUN pip3 install -r requirements.txt && pip3 install encodec --no-deps
 
 CMD /bin/zsh

+ 3 - 2
requirements.txt

@@ -6,5 +6,6 @@ lightning>=2.0.9.post0
 hydra-core>=1.3.2
 pyrootutils>=1.0.4
 tensorboard>=2.14.1
-librosa
-encodec
+natsort>=8.4.0
+einops>=0.7.0
+librosa>=0.10.1

+ 7 - 2
speech_lm/train.py

@@ -8,6 +8,7 @@ from omegaconf import DictConfig, OmegaConf
 from tqdm import tqdm
 from transformers import LlamaForCausalLM
 from transformers.utils import is_flash_attn_available
+from natsort import natsorted
 
 # Allow TF32 on Ampere GPUs
 torch.set_float32_matmul_precision("high")
@@ -125,8 +126,12 @@ def main(cfg: DictConfig):
     # Restore training from checkpoint
     checkpoint_dir = Path(cfg.paths.checkpoint_dir)
     checkpoint_dir.mkdir(parents=True, exist_ok=True)
-    checkpoint_path = checkpoint_dir / "last.ckpt"
-    if checkpoint_path.exists():
+
+    # Alphabetically sort checkpoints
+    checkpoints = natsorted(checkpoint_dir.glob("*.ckpt"))
+    if len(checkpoints) > 0:
+        checkpoint_path = checkpoints[-1]
+
         log.info(f"Restoring checkpoint from {checkpoint_path}")
         remainder = fabric.load(
             checkpoint_path,