@@ -0,0 +1,322 @@
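+"""Pretrain a Llama-style causal LM with Lightning Fabric and Hydra.
+
+Flow (summarized from the code below): instantiate Fabric, model, optimizer,
+and scheduler from the Hydra config, resume from the newest checkpoint in
+``paths.checkpoint_dir`` if one exists, then run a step-based training loop
+with gradient accumulation, periodic logging, checkpointing, and optional
+evaluation.
+"""
+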
+import time
+from collections import defaultdict
+from datetime import timedelta
+from pathlib import Path
+from typing import Optional
+
+import hydra
+import torch
+from lightning.fabric import Fabric
+from natsort import natsorted
+from omegaconf import DictConfig, OmegaConf
+from tqdm import tqdm
+from transformers import LlamaForCausalLM
+from transformers.utils import is_flash_attn_available
+
+from fish_speech.logger import RankedLogger
+
+# Allow TF32 on Ampere GPUs
+torch.set_float32_matmul_precision("high")
+torch.backends.cudnn.allow_tf32 = True
+
+# Register the `eval` resolver so configs can use ${eval:...} expressions
+OmegaConf.register_new_resolver("eval", eval)
+
+log = RankedLogger(__name__, rank_zero_only=True)
+
+
+@torch.no_grad()
+def valid(
+    model: LlamaForCausalLM,
+    valid_dataloader: Optional[torch.utils.data.DataLoader],
+    global_step: int,
+    fabric: Fabric,
+    cfg: DictConfig,
+):
+    model.eval()
+    log.info(f"Evaluating at step {global_step}")
+
+    accumulate_infos = None
+
+    for idx, batch in enumerate(tqdm(valid_dataloader, desc="Evaluating")):
+        outputs = model(**batch)
+        loss = outputs.loss
+        metrics = getattr(outputs, "metrics", {})
+        log_info = {
+            "valid/loss": float(loss),
+            **{f"valid/{k}": float(v) for k, v in metrics.items()},
+        }
+
+        fabric.log_dict(
+            log_info,
+            step=global_step + idx,
+        )
+
+        # Accumulate metrics so an average can be reported at the end
+        if accumulate_infos is None:
+            accumulate_infos = log_info
+        else:
+            assert set(accumulate_infos.keys()) == set(
+                log_info.keys()
+            ), "Log keys changed during evaluation"
+            for k in accumulate_infos.keys():
+                accumulate_infos[k] += log_info[k]
+
+        eval_max_batches = getattr(cfg.schedule, "eval_max_batches", None)
+        if eval_max_batches is not None and idx + 1 >= eval_max_batches:
+            break
+
+    # Log the per-batch average of every metric
+    items = []
+    for k in accumulate_infos.keys():
+        items.append(f"{k}: {accumulate_infos[k] / (idx + 1):.4f}")
+    log.info(f"Average: {' | '.join(items)}")
+
+
+def train(
+    model: LlamaForCausalLM,
+    optimizer: torch.optim.Optimizer,
+    scheduler: torch.optim.lr_scheduler._LRScheduler,
+    train_dataloader: torch.utils.data.DataLoader,
+    valid_dataloader: Optional[torch.utils.data.DataLoader],
+    global_step: int,
+    fabric: Fabric,
+    cfg: DictConfig,
+):
+    accumulate_steps = 0
+    optimizer.zero_grad()
+
+    # Wall-clock timer for step_time: covers forward/backward plus data loading
+    start_time = time.time()
+    trackers = defaultdict(list)
+
+    while global_step < cfg.schedule.max_steps:
+        last_batch_time = time.time()
+        for batch in train_dataloader:
+            # Measure time used by data loading
+            trackers["data_time"].append(time.time() - last_batch_time)
+
+            # Measure time used by model forward
+            model_begin_time = time.time()
+            model.train()
+
+            # Accumulate gradients: the optimizer only steps once every
+            # `gradient_accumulation_steps` micro-batches
+            gradient_accumulation_steps = cfg.schedule.gradient_accumulation_steps
+            is_accumulating = accumulate_steps % gradient_accumulation_steps != 0
+            accumulate_steps += 1
+
+            # Train one step; no_backward_sync skips cross-rank gradient
+            # synchronization on pure accumulation micro-steps
+            with fabric.no_backward_sync(model, enabled=is_accumulating):
+                outputs = model(**batch)
+                loss = outputs.loss
+                metrics = getattr(outputs, "metrics", {})
+
+                # Divide the loss by the accumulation steps so the summed
+                # gradients match a single large batch
+                fabric.backward(loss / gradient_accumulation_steps)
+
+            # Update trackers
+            trackers["loss"].append(float(loss))
+            trackers["lr"].append(float(optimizer.param_groups[0]["lr"]))
+            for k, v in metrics.items():
+                trackers[f"metrics/{k}"].append(float(v))
+
+            trackers["model_time"].append(time.time() - model_begin_time)
+
+            if is_accumulating:
+                last_batch_time = time.time()
+                continue
+
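+            # A full accumulation window has been processed at this point;
+            # the effective batch size is the micro-batch size (from the
+            # dataloader config) * gradient_accumulation_steps * fabric.world_size
+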
+            # Check that all trackers (except grad_norm) share the same length
+            assert (
+                len(set(len(v) for k, v in trackers.items() if k != "grad_norm")) == 1
+            ), "Trackers have inconsistent lengths"
+
+            # Perform gradient clipping; error_if_nonfinite=False returns the
+            # norm even when it is NaN/Inf so the update can be skipped below
+            grad_norm = fabric.clip_gradients(
+                model,
+                optimizer,
+                max_norm=cfg.schedule.clip_grad_norm,
+                norm_type=2.0,
+                error_if_nonfinite=False,
+            )
+
+            if torch.isnan(grad_norm) or torch.isinf(grad_norm):
+                log.warning(f"Gradient norm is {grad_norm}, skipping update")
+                # Resetting gradients effectively makes optimizer.step() a no-op
+                optimizer.zero_grad()
+
+            # grad_norm exists once per optimizer update, not per micro-step,
+            # so it is tracked separately from the per-batch metrics
+            trackers["grad_norm"].append(float(grad_norm))
+
+            # Update
+            optimizer.step()
+            optimizer.zero_grad()
+            scheduler.step()
+
+            fabric.log_dict(
+                {
+                    f"train/{k}": sum(v[-gradient_accumulation_steps:])
+                    / len(v[-gradient_accumulation_steps:])
+                    for k, v in trackers.items()
+                },
+                step=global_step,
+            )
+
+            global_step += 1
+
+            if global_step % cfg.schedule.log_interval == 0:
+                step_time = (time.time() - start_time) / cfg.schedule.log_interval
+                eta = step_time * (cfg.schedule.max_steps - global_step)
+                additional_info = [
+                    f"{k}: {sum(v[-cfg.schedule.log_interval:]) / len(v[-cfg.schedule.log_interval:]):.4f}"
+                    for k, v in trackers.items()
+                    if k != "lr"  # lr is logged separately in .2e format
+                ]
+
+                log.info(
+                    f"[{global_step}/{cfg.schedule.max_steps}] "
+                    + f"step_time: {step_time:.2f}s "
+                    + f"ETA: {timedelta(seconds=round(eta))} "
+                    + f"lr: {optimizer.param_groups[0]['lr']:.2e} "
+                    + " ".join(additional_info)
+                )
+
+                # Reset trackers
+                trackers = defaultdict(list)
+
+                start_time = time.time()
+
+            if global_step % cfg.schedule.save_interval == 0:
+                fabric.save(
+                    Path(cfg.paths.checkpoint_dir) / f"step_{global_step}.ckpt",
+                    {
+                        "model": model,
+                        "optimizer": optimizer,
+                        "scheduler": scheduler.state_dict(),
+                        "global_step": global_step,
+                    },
+                )
+
+            if (
+                getattr(cfg.schedule, "eval_interval", None) is not None
+                and global_step % cfg.schedule.eval_interval == 0
+                and valid_dataloader is not None
+            ):
+                valid(model, valid_dataloader, global_step, fabric, cfg)
+
+            if global_step >= cfg.schedule.max_steps:
+                break
+
+            last_batch_time = time.time()
+
+
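+# The Hydra config is expected to provide at least the following keys
+# (inferred from usage in this file; adjust to your llama_pretrain.yaml):
+#   trainer, model, optimizer, scheduler, train_dataloader[, valid_dataloader]
+#   schedule: max_steps, gradient_accumulation_steps, clip_grad_norm,
+#             log_interval, save_interval[, eval_interval, eval_max_batches]
+#   paths: checkpoint_dir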
+@hydra.main(
+    version_base="1.3", config_path="./configs", config_name="llama_pretrain.yaml"
+)
+def main(cfg: DictConfig):
+    log.info(f"Config: \n{OmegaConf.to_yaml(cfg)}")
+
+    if not is_flash_attn_available():
+        log.warning("Flash attention is not available, using default attention")
+
+    fabric: Fabric = hydra.utils.instantiate(cfg.trainer)
+    fabric.launch()
+    log.info(f"Fabric: {fabric}")
+
+    model = hydra.utils.instantiate(cfg.model)
+    log.info(f"Model: {repr(model)}")
+
+    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+    frozen_params = sum(p.numel() for p in model.parameters() if not p.requires_grad)
+    log.info(f"Trainable parameters: {trainable_params / 1e6:.2f}M")
+    log.info(f"Frozen parameters: {frozen_params / 1e6:.2f}M")
+
+    optimizer = hydra.utils.instantiate(cfg.optimizer, params=model.parameters())
+    scheduler = hydra.utils.instantiate(cfg.scheduler, optimizer=optimizer)
+    log.info(f"Optimizer: {optimizer}")
+    log.info(f"Scheduler: {scheduler}")
+
+    log.info("Setting up model & optimizer with Fabric")
+    model = fabric.setup_module(model)
+    optimizer = fabric.setup_optimizers(optimizer)
+
+    # Initialize training state
+    global_step = 0
+
+    # Restore training from the latest checkpoint, if any
+    checkpoint_dir = Path(cfg.paths.checkpoint_dir)
+    checkpoint_dir.mkdir(parents=True, exist_ok=True)
+
+    # Natural sort so step_99.ckpt orders before step_100.ckpt
+    checkpoints = natsorted(checkpoint_dir.glob("*.ckpt"))
+    if len(checkpoints) > 0:
+        checkpoint_path = checkpoints[-1]
+
+        log.info(f"Restoring checkpoint from {checkpoint_path}")
+        # fabric.load restores the given objects in place and returns
+        # any leftover keys (here, global_step)
+        remainder = fabric.load(
+            checkpoint_path,
+            {
+                "model": model,
+                "optimizer": optimizer,
+                "scheduler": scheduler,
+            },
+        )
+        global_step = remainder["global_step"]
+        log.info(f"Restored global step: {global_step}")
+
+    train_dataloader = hydra.utils.instantiate(cfg.train_dataloader)
+    log.info(f"Train Dataloader: {train_dataloader}")
+
+    valid_dataloader = None
+    if getattr(cfg, "valid_dataloader", None) is not None:
+        valid_dataloader = hydra.utils.instantiate(cfg.valid_dataloader)
+        log.info(f"Valid Dataloader: {valid_dataloader}")
+
+    train_dataloader = fabric.setup_dataloaders(train_dataloader)
+    if valid_dataloader is not None:
+        valid_dataloader = fabric.setup_dataloaders(valid_dataloader)
+
+    log.info("Begin training")
+
+    train(
+        model=model,
+        optimizer=optimizer,
+        scheduler=scheduler,
+        train_dataloader=train_dataloader,
+        valid_dataloader=valid_dataloader,
+        global_step=global_step,
+        fabric=fabric,
+        cfg=cfg,
+    )
+
+
+if __name__ == "__main__":
+    main()
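+# Example launch via Hydra overrides (illustrative; keys depend on your configs):
+#   python <path to this script> schedule.max_steps=100000 trainer.devices=4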