
Optimize training loop & upgrade transformers

Lengyue committed 2 years ago
commit 23109ecb94
3 changed files with 12 additions and 8 deletions
  1. requirements.txt (+1 -1)
  2. speech_lm/configs/whisper_vq.yaml (+10 -3)
  3. speech_lm/train.py (+1 -4)

requirements.txt (+1 -1)

@@ -1,4 +1,4 @@
-transformers>=4.34.0
+transformers>=4.34.1
 datasets>=2.14.5
 bitsandbytes>=0.41.1
 peft>=0.5.0
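
The bump from transformers 4.34.0 to 4.34.1 only raises the version floor, so it takes effect on the next dependency install. A standalone sanity check (not code from this repo) that the active environment satisfies the new pin:

```python
from importlib.metadata import version

from packaging.version import Version

# Check the installed transformers against the updated requirements.txt pin.
installed = Version(version("transformers"))
assert installed >= Version("4.34.1"), f"transformers {installed} < 4.34.1"
```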

speech_lm/configs/whisper_vq.yaml (+10 -3)

@@ -43,7 +43,7 @@ schedule:
   save_interval: 2000
   gradient_accumulation_steps: "${eval: ${schedule.batch_size} // ${schedule.micro_batch_size}}"
   clip_grad_norm: 2.0
-  log_interval: 10
+  log_interval: 50
   eval_interval: 2000
 
 train_dataloader:
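
The `${eval: ...}` interpolation used by `gradient_accumulation_steps` is not a built-in OmegaConf resolver, so the project presumably registers one at startup. A minimal sketch of how that pattern works (the registration itself is an assumption, not shown in this diff):

```python
from omegaconf import OmegaConf

# Assumption: the project registers an "eval" resolver roughly like this.
OmegaConf.register_new_resolver("eval", eval)

cfg = OmegaConf.create(
    {
        "schedule": {
            "batch_size": 256,
            "micro_batch_size": 16,
            "gradient_accumulation_steps": "${eval: ${schedule.batch_size} // ${schedule.micro_batch_size}}",
        }
    }
)
# Nested interpolations resolve first, then eval("256 // 16") runs.
print(cfg.schedule.gradient_accumulation_steps)  # 16
```
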
@@ -52,7 +52,11 @@ train_dataloader:
     _target_: speech_lm.datasets.whisper_vq.WhisperVQDataset
     filelist: filelists/whisper-vq.train.filelist
   batch_size: ${schedule.micro_batch_size}
-  num_workers: 8
+  num_workers: 16
+  prefetch_factor: 4
+  pin_memory: true
+  persistent_workers: true
+  shuffle: true
   collate_fn:
     _target_: speech_lm.datasets.whisper_vq.WhisperVQCollator
 
@@ -62,7 +66,10 @@ valid_dataloader:
     _target_: speech_lm.datasets.whisper_vq.WhisperVQDataset
     filelist: filelists/whisper-vq.test.filelist
   batch_size: 32
-  num_workers: 4
+  num_workers: 8
+  prefetch_factor: 4
+  pin_memory: true
+  shuffle: false
   collate_fn:
     _target_: speech_lm.datasets.whisper_vq.WhisperVQCollator
 
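The new keys map directly onto `torch.utils.data.DataLoader` arguments once Hydra instantiates the block. A runnable sketch with a dummy dataset standing in for `WhisperVQDataset`:

```python
import torch
from torch.utils.data import DataLoader, Dataset


class DummyDataset(Dataset):
    """Stand-in for WhisperVQDataset so the sketch runs anywhere."""

    def __len__(self):
        return 256

    def __getitem__(self, idx):
        return torch.randn(80, 100)  # fake mel-spectrogram features


if __name__ == "__main__":
    loader = DataLoader(
        DummyDataset(),
        batch_size=16,            # ${schedule.micro_batch_size} in the config
        num_workers=16,           # doubled: more parallel decoding/IO
        prefetch_factor=4,        # each worker keeps 4 batches queued ahead
        pin_memory=True,          # page-locked buffers speed host-to-GPU copies
        persistent_workers=True,  # workers survive between epochs (no respawn)
        shuffle=True,
    )
    for batch in loader:
        pass  # one training step per batch would go here
```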

speech_lm/train.py (+1 -4)

@@ -37,7 +37,7 @@ def valid(
 
     accumulate_infos = None
 
-    for idx, batch in tqdm(enumerate(valid_dataloader), desc="Evaluating"):
+    for idx, batch in enumerate(tqdm(valid_dataloader, desc="Evaluating")):
         outputs = model(**batch)
         loss = outputs.loss
         metrics = getattr(outputs, "metrics", {})
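
The change in `valid` is a small but real fix: `enumerate()` returns a generator with no `__len__`, so the old ordering gave tqdm no total. Wrapping the dataloader itself restores the percentage bar and ETA:

```python
from tqdm import tqdm

items = list(range(1000))

# Before: tqdm wraps the enumerate generator, which has no __len__,
# so the bar shows only a raw iteration count (no total, no ETA).
for idx, x in tqdm(enumerate(items), desc="Evaluating"):
    ...

# After: tqdm wraps the sized iterable and reads len(items),
# so it renders a proper progress bar; enumerate then adds the index.
for idx, x in enumerate(tqdm(items, desc="Evaluating")):
    ...
```
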
@@ -81,8 +81,6 @@ def train(
     fabric: Fabric,
     cfg: DictConfig,
 ):
-    bar = tqdm(total=cfg.schedule.max_steps, desc="Training")
-    bar.update(global_step)
     accumulate_steps = 0
     optimizer.zero_grad()
 
@@ -163,7 +161,6 @@ def train(
             )
 
             global_step += 1
-            bar.update(1)
 
             if global_step % cfg.schedule.log_interval == 0:
                 step_time = (time.time() - start_time) / cfg.schedule.log_interval
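
With the manual training bar removed, step timing now comes only from this periodic log, which also runs less often with `log_interval` raised from 10 to 50. A minimal sketch of the pattern visible in the hunk above; the reset of `start_time` after each report is assumed, since it falls outside the hunk:

```python
import time

log_interval = 50  # the new schedule.log_interval
global_step = 0
start_time = time.time()

for _ in range(200):  # stand-in for the real batch loop
    # ... forward / backward / optimizer step ...
    global_step += 1

    if global_step % log_interval == 0:
        # Average wall-clock seconds per optimizer step over the last window.
        step_time = (time.time() - start_time) / log_interval
        print(f"step {global_step}: {step_time:.4f} s/step")
        start_time = time.time()  # reset so the next window starts fresh
```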