Przeglądaj źródła

Fix lora inference

Lengyue 2 lat temu
rodzic
commit
982f78794e
1 zmienionych plików z 1 dodań i 1 usunięć
  1. 1 1
      tools/llama/generate.py

+ 1 - 1
tools/llama/generate.py

@@ -323,7 +323,7 @@ def load_model(config_name, checkpoint_path, device, precision):
         cfg = compose(config_name=config_name)
 
     with torch.device("meta"):
-        model: Transformer = instantiate(cfg.model.model)
+        model: Transformer = instantiate(cfg.model).model
 
     if "int8" in str(checkpoint_path):
         logger.info("Using int8 weight-only quantization!")