@@ -1,6 +1,8 @@
+from collections import defaultdict
 import time
 from datetime import timedelta
 from pathlib import Path
+from typing import Optional
 
 import hydra
 import torch
@@ -23,25 +25,81 @@ OmegaConf.register_new_resolver("eval", eval)
 log = RankedLogger(__name__, rank_zero_only=True)
 
 
+@torch.no_grad()
+def valid(
+    model: LlamaForCausalLM,
+    valid_dataloader: Optional[torch.utils.data.DataLoader],
+    global_step: int,
+    fabric: Fabric,
+    cfg: DictConfig,
+):
+    model.eval()
+    log.info(f"Evaluating at step {global_step}")
+
+    accumulate_infos = None
+
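+    # Log each batch's values immediately and accumulate them for a final average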
+    for idx, batch in tqdm(enumerate(valid_dataloader), desc="Evaluating"):
+        outputs = model(**batch)
+        loss = outputs.loss
+        metrics = getattr(outputs, "metrics", {})
+        log_info = {
+            "valid/loss": float(loss),
+            **{f"valid/{k}": float(v) for k, v in metrics.items()},
+        }
+
+        fabric.log_dict(
+            log_info,
+            step=global_step + idx,
+        )
+
+        # Update the running sums
+        if accumulate_infos is None:
+            accumulate_infos = log_info
+        else:
+            assert set(accumulate_infos.keys()) == set(
+                log_info.keys()
+            ), "Log keys changed during evaluation"
+            for k in accumulate_infos.keys():
+                accumulate_infos[k] += log_info[k]
+
+        max_batches = getattr(cfg.schedule, "eval_max_batches", None)
+        if max_batches is not None and idx + 1 >= max_batches:
+            break
+
+    # Log average
+    items = []
+    for k in accumulate_infos.keys():
+        items.append(f"{k}: {accumulate_infos[k] / (idx + 1):.4f}")
+    log.info(f"Average: {' | '.join(items)}")
+
+
 def train(
     model: LlamaForCausalLM,
     optimizer: torch.optim.Optimizer,
     scheduler: torch.optim.lr_scheduler._LRScheduler,
-    dataloader: torch.utils.data.DataLoader,
+    train_dataloader: torch.utils.data.DataLoader,
+    valid_dataloader: Optional[torch.utils.data.DataLoader],
     global_step: int,
     fabric: Fabric,
     cfg: DictConfig,
 ):
-    model.train()
-
     bar = tqdm(total=cfg.schedule.max_steps, desc="Training")
     bar.update(global_step)
     accumulate_steps = 0
     optimizer.zero_grad()
+
+    # Start time is ~model forward time + data loading time
     start_time = time.time()
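+    # Per-micro-batch value buffers; recent windows of them are averaged for logging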
+    trackers = defaultdict(list)
+    grad_norm = 0.0  # updated by each optimizer step; logged per micro-batch
 
     while global_step < cfg.schedule.max_steps:
-        for batch in dataloader:
+        last_batch_time = time.time()
+        for batch in train_dataloader:
+            # Measure time used by data loading
+            trackers["data_time"].append(time.time() - last_batch_time)
+
+            # Measure time spent in the model (forward and backward)
+            model_begin_time = time.time()
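+            # Re-enter train mode every batch, since valid() switches to eval mode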
+            model.train()
+
             # Accumulate gradients
             is_accumulating = (
                 accumulate_steps % cfg.schedule.gradient_accumulation_steps != 0
@@ -55,9 +113,25 @@ def train(
             metrics = getattr(outputs, "metrics", {})
             fabric.backward(loss)
 
+            # Update trackers
+            trackers["loss"].append(float(loss))
+            trackers["lr"].append(float(optimizer.param_groups[0]["lr"]))
+            # grad_norm lags one optimizer step behind (0.0 before the first step)
+            trackers["grad_norm"].append(float(grad_norm))
+            for k, v in metrics.items():
+                trackers[f"metrics/{k}"].append(float(v))
+
+            trackers["model_time"].append(time.time() - model_begin_time)
+
+            # Restart the data-loading timer for the next iteration
+            last_batch_time = time.time()
+
             if is_accumulating:
                 continue
 
+            # Check that all trackers have the same length
+            assert (
+                len(set(len(v) for v in trackers.values())) == 1
+            ), "Trackers have inconsistent lengths"
+
             # Perform gradient clipping
             grad_norm = fabric.clip_gradients(
                 model, optimizer, max_norm=cfg.schedule.clip_grad_norm, norm_type=2.0
@@ -70,10 +144,9 @@ def train(
 
             fabric.log_dict(
                 {
-                    "train/loss": loss,
-                    "train/lr": optimizer.param_groups[0]["lr"],
-                    "train/grad_norm": grad_norm,
-                    **{f"train/{k}": v for k, v in metrics.items()},
+                    # Average each tracker over this step's micro-batches
+                    f"train/{k}": sum(v[-cfg.schedule.gradient_accumulation_steps:])
+                    / len(v[-cfg.schedule.gradient_accumulation_steps:])
+                    for k, v in trackers.items()
                 },
                 step=global_step,
             )
@@ -84,12 +157,18 @@ def train(
             if global_step % cfg.schedule.log_interval == 0:
                 step_time = (time.time() - start_time) / cfg.schedule.log_interval
                 eta = step_time * (cfg.schedule.max_steps - global_step)
+                additional_info = [
+                    f"{k}: {sum(v[-cfg.schedule.log_interval:]) / len(v[-cfg.schedule.log_interval:]):.4f}"
+                    for k, v in trackers.items()
+                    if k != "lr"  # lr is printed separately with .2e format
+                ]
+
                 log.info(
-                    f"[{global_step}/{cfg.schedule.max_steps}] loss: {loss:.4f} "
+                    f"[{global_step}/{cfg.schedule.max_steps}] "
+                    f"step time: {step_time:.2f}s "
+                    + f"ETA: {timedelta(seconds=round(eta))} "
                     f"lr: {optimizer.param_groups[0]['lr']:.2e} "
-                    + f"grad_norm: {grad_norm:.2f} "
-                    + f"ETA: {timedelta(round(eta))}s"
+                    + " ".join(additional_info)
                 )
 
                 start_time = time.time()
@@ -105,6 +184,12 @@ def train(
                 },
             )
 
+            if (
+                global_step % cfg.schedule.eval_interval == 0
+                and valid_dataloader is not None
+            ):
+                valid(model, valid_dataloader, global_step, fabric, cfg)
+
             if global_step >= cfg.schedule.max_steps:
                 break
 
@@ -161,17 +246,26 @@ def main(cfg: DictConfig):
     global_step = remainder["global_step"]
     log.info(f"Restored global step: {global_step}")
 
-    train_dataloader = hydra.utils.instantiate(cfg.dataloader)
-    log.info(f"Dataloader: {train_dataloader}")
+    train_dataloader = hydra.utils.instantiate(cfg.train_dataloader)
+    log.info(f"Train Dataloader: {train_dataloader}")
+
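+    # Validation is optional; evaluation is skipped when no valid_dataloader is configured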
+    valid_dataloader = None
+    if getattr(cfg, "valid_dataloader", None) is not None:
+        valid_dataloader = hydra.utils.instantiate(cfg.valid_dataloader)
+        log.info(f"Valid Dataloader: {valid_dataloader}")
 
     train_dataloader = fabric.setup_dataloaders(train_dataloader)
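+    # Fabric also wraps the validation loader (device placement, distributed sampling)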
+    if valid_dataloader is not None:
+        valid_dataloader = fabric.setup_dataloaders(valid_dataloader)
+
     log.info(f"Begin training")
 
     train(
         model=model,
         optimizer=optimizer,
         scheduler=scheduler,
-        dataloader=train_dataloader,
+        train_dataloader=train_dataloader,
+        valid_dataloader=valid_dataloader,
         global_step=global_step,
         fabric=fabric,
         cfg=cfg,