Sfoglia il codice sorgente

Fix lora inference

Lengyue 2 anni fa
parent
commit
982f78794e
1 ha cambiato i file con 1 aggiunte e 1 eliminazioni
  1. +1 −1 tools/llama/generate.py

+ 1 - 1
tools/llama/generate.py

@@ -323,7 +323,7 @@ def load_model(config_name, checkpoint_path, device, precision):
         cfg = compose(config_name=config_name)
 
     with torch.device("meta"):
-        model: Transformer = instantiate(cfg.model.model)
+        model: Transformer = instantiate(cfg.model).model
 
     if "int8" in str(checkpoint_path):
         logger.info("Using int8 weight-only quantization!")