瀏覽代碼

Fix dataset and training loop

Lengyue 2 年之前
父節點
當前提交
01cf9bbafd
共有 2 個文件被更改,包括 5 次插入5 次删除
  1. 2 2
      speech_lm/datasets/cultura_x.py
  2. 3 3
      speech_lm/train.py

+ 2 - 2
speech_lm/datasets/cultura_x.py

@@ -67,8 +67,8 @@ class CulturaXDataset(IterableDataset):
         for filename in files:
             try:
                 yield from self.parse_data(filename)
-            except:
-                log.exception(f"Failed to parse {filename}")
+            except Exception as e:
+                log.exception(f"Failed to parse {filename}: {e}")
 
     def parse_data(self, filename: str):
         url = f"https://huggingface.co/datasets/uonlp/CulturaX/resolve/main/{filename}"

+ 3 - 3
speech_lm/train.py

@@ -117,6 +117,9 @@ def main(cfg: DictConfig):
     log.info(f"Optimizer: {optimizer}")
     log.info(f"Scheduler: {scheduler}")
 
+    log.info(f"Setup fabric model & dataset")
+    model, optimizer, scheduler = fabric.setup(model, optimizer, scheduler)
+
     # Build state
     global_step = 0
 
@@ -141,9 +144,6 @@ def main(cfg: DictConfig):
         global_step = remainder["global_step"]
         log.info(f"Restored global step: {global_step}")
 
-    log.info(f"Setup fabric model & dataset")
-    model, optimizer, scheduler = fabric.setup(model, optimizer, scheduler)
-
     train_dataloader = hydra.utils.instantiate(cfg.dataloader)
     log.info(f"Dataloader: {train_dataloader}")