Explorar el Código

Fix model training mode

Lengyue hace 2 años
padre
commit
0dae48d2dc
Se han modificado 1 ficheros con 2 adiciones y 0 borrados
  1. 2 0
      speech_lm/train.py

+ 2 - 0
speech_lm/train.py

@@ -32,6 +32,8 @@ def train(
     fabric: Fabric,
     fabric: Fabric,
     cfg: DictConfig,
     cfg: DictConfig,
 ):
 ):
+    model.train()
+
     bar = tqdm(total=cfg.schedule.max_steps, desc="Training")
     bar = tqdm(total=cfg.schedule.max_steps, desc="Training")
     bar.update(global_step)
     bar.update(global_step)
     accumulate_steps = 0
     accumulate_steps = 0