Parcourir la source

Fix lora inference

Lengyue il y a 2 ans
Parent
commit
982f78794e
1 fichiers modifiés avec 1 ajouts et 1 suppressions
  1. 1 1
      tools/llama/generate.py

+ 1 - 1
tools/llama/generate.py

@@ -323,7 +323,7 @@ def load_model(config_name, checkpoint_path, device, precision):
         cfg = compose(config_name=config_name)
 
     with torch.device("meta"):
-        model: Transformer = instantiate(cfg.model.model)
+        model: Transformer = instantiate(cfg.model).model
 
     if "int8" in str(checkpoint_path):
         logger.info("Using int8 weight-only quantization!")