Lengyue пре 2 година
родитељ
комит
982f78794e
1 измењених фајлова са 1 додато и 1 уклоњено
  1. 1 1
      tools/llama/generate.py

+ 1 - 1
tools/llama/generate.py

@@ -323,7 +323,7 @@ def load_model(config_name, checkpoint_path, device, precision):
         cfg = compose(config_name=config_name)
 
     with torch.device("meta"):
-        model: Transformer = instantiate(cfg.model.model)
+        model: Transformer = instantiate(cfg.model).model
 
     if "int8" in str(checkpoint_path):
         logger.info("Using int8 weight-only quantization!")