@@ -323,7 +323,7 @@ def load_model(config_name, checkpoint_path, device, precision):
cfg = compose(config_name=config_name)
with torch.device("meta"):
- model: Transformer = instantiate(cfg.model.model)
+ model: Transformer = instantiate(cfg.model).model
if "int8" in str(checkpoint_path):
logger.info("Using int8 weight-only quantization!")