@@ -323,7 +323,7 @@ def load_model(config_name, checkpoint_path, device, precision):
     cfg = compose(config_name=config_name)

     with torch.device("meta"):
-        model: Transformer = instantiate(cfg.model.model)
+        model: Transformer = instantiate(cfg.model).model

     if "int8" in str(checkpoint_path):
         logger.info("Using int8 weight-only quantization!")