@@ -116,7 +116,6 @@ def train(
         # Update trackers
         trackers["loss"].append(float(loss))
         trackers["lr"].append(float(optimizer.param_groups[0]["lr"]))
-        trackers["grad_norm"].append(float(grad_norm))
         for k, v in metrics.items():
             trackers[f"metrics/{k}"].append(float(v))
 
@@ -127,7 +126,7 @@ def train(
         # Check all trackers has the same length
         assert (
-            len(set(len(v) for v in trackers.values())) == 1
+            len(set(len(v) for k, v in trackers.items() if k != "grad_norm")) == 1
         ), "Trackers has ambiguous length"
 
         # Perform gradient clipping
         grad_norm = fabric.clip_gradients(
@@ -135,6 +134,9 @@ def train(
             model, optimizer, max_norm=cfg.schedule.clip_grad_norm, norm_type=2.0
         )
 
+        # We can't average gradients across multiple steps
+        trackers["grad_norm"].append(float(grad_norm))
+
         # Update
         optimizer.step()
         optimizer.zero_grad()
@@ -163,12 +165,15 @@ def train(
 
             log.info(
                 f"[{global_step}/{cfg.schedule.max_steps}] "
-                + f"step time: {step_time:.2f}s "
-                + f"ETA: {timedelta(round(eta))}s "
+                + f"step_time: {step_time:.2f}s "
+                + f"ETA: {timedelta(seconds=round(eta))}s "
                 f"lr: {optimizer.param_groups[0]['lr']:.2e} "
                 + " ".join(additional_info)
             )
 
+            # Reset trackers
+            trackers = defaultdict(list)
+
             start_time = time.time()
 
         if global_step % cfg.schedule.save_interval == 0:
@@ -183,14 +188,17 @@ def train(
             )
 
         if (
-            global_step % cfg.schedule.eval_interval == 0
+            getattr(cfg.schedule, "eval_interval", None) is not None
+            and global_step % cfg.schedule.eval_interval == 0
             and valid_dataloader is not None
         ):
-            valid(model, valid_dataloader, fabric, global_step, cfg)
+            valid(model, valid_dataloader, global_step, fabric, cfg)
 
         if global_step >= cfg.schedule.max_steps:
             break
 
+    last_batch_time = time.time()
+
 
 @hydra.main(version_base="1.3", config_path="./configs", config_name="pretrain.yaml")
 def main(cfg: DictConfig):
@@ -248,8 +256,8 @@ def main(cfg: DictConfig):
     log.info(f"Train Dataloader: {train_dataloader}")
 
     valid_dataloader = None
-    if getattr(train_dataloader, "valid_dataloader", None) is not None:
-        valid_dataloader = hydra.utils.instantiate(train_dataloader.valid_dataloader)
+    if getattr(cfg, "valid_dataloader", None) is not None:
+        valid_dataloader = hydra.utils.instantiate(cfg.valid_dataloader)
     log.info(f"Valid Dataloader: {valid_dataloader}")
 
     train_dataloader = fabric.setup_dataloaders(train_dataloader)
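
The tracker changes above boil down to one pattern: per-micro-batch values (loss, lr, per-key metrics) are appended on every iteration, while grad_norm only exists once per optimizer step, after clipping, so it is appended later and excluded from the length assertion. Below is a minimal, self-contained sketch of that pattern, not the repository's code: the model, optimizer, and accumulate_steps are placeholder assumptions, and torch.nn.utils.clip_grad_norm_ stands in for fabric.clip_gradients.

    from collections import defaultdict

    import torch

    # Placeholder model/optimizer; the real script wires these up via Fabric.
    model = torch.nn.Linear(4, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
    accumulate_steps = 4  # assumed gradient-accumulation factor

    trackers = defaultdict(list)

    for step, batch in enumerate(torch.randn(8, 4).split(1)):
        loss = model(batch).mean()
        loss.backward()

        # Per-micro-batch trackers: appended on every iteration.
        trackers["loss"].append(float(loss))
        trackers["lr"].append(float(optimizer.param_groups[0]["lr"]))

        if (step + 1) % accumulate_steps == 0:
            # grad_norm exists once per optimizer step (gradients can't be
            # averaged across steps), so it is tracked separately here.
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            trackers["grad_norm"].append(float(grad_norm))
            optimizer.step()
            optimizer.zero_grad()

        # The assertion from the diff: every tracker except grad_norm must
        # have one entry per micro-batch.
        assert (
            len(set(len(v) for k, v in trackers.items() if k != "grad_norm")) == 1
        ), "Trackers has ambiguous length"

Separately, the ETA fix is worth noting: timedelta's first positional argument is days, so timedelta(round(eta)) printed an ETA of eta days; timedelta(seconds=round(eta)) formats the intended duration.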