|
|
@@ -37,7 +37,7 @@ def valid(
|
|
|
|
|
|
accumulate_infos = None
|
|
|
|
|
|
- for idx, batch in tqdm(enumerate(valid_dataloader), desc="Evaluating"):
|
|
|
+ for idx, batch in enumerate(tqdm(valid_dataloader, desc="Evaluating")):
|
|
|
outputs = model(**batch)
|
|
|
loss = outputs.loss
|
|
|
metrics = getattr(outputs, "metrics", {})
|
|
|
@@ -81,8 +81,6 @@ def train(
|
|
|
fabric: Fabric,
|
|
|
cfg: DictConfig,
|
|
|
):
|
|
|
- bar = tqdm(total=cfg.schedule.max_steps, desc="Training")
|
|
|
- bar.update(global_step)
|
|
|
accumulate_steps = 0
|
|
|
optimizer.zero_grad()
|
|
|
|
|
|
@@ -163,7 +161,6 @@ def train(
|
|
|
)
|
|
|
|
|
|
global_step += 1
|
|
|
- bar.update(1)
|
|
|
|
|
|
if global_step % cfg.schedule.log_interval == 0:
|
|
|
step_time = (time.time() - start_time) / cfg.schedule.log_interval
|