Przeglądaj źródła

Update merge_lora.py (#284)

Naozumi 1 rok temu
rodzic
commit
fbe2e3f030
1 zmienionych plików z 12 dodań i 10 usunięć
  1. 12 10
      tools/llama/merge_lora.py

+ 12 - 10
tools/llama/merge_lora.py

@@ -52,16 +52,18 @@ def merge(llama_config, lora_config, llama_weight, lora_weight, output):
         lora_state_dict = lora_state_dict["state_dict"]
 
     # remove prefix model.
-    llama_state_dict = {
-        k.replace("model.", ""): v
-        for k, v in llama_state_dict.items()
-        if k.startswith("model.")
-    }
-    lora_state_dict = {
-        k.replace("model.", ""): v
-        for k, v in lora_state_dict.items()
-        if k.startswith("model.")
-    }
+    if any(k.startswith("model.") for k in llama_state_dict.keys()):
+        llama_state_dict = {
+            k.replace("model.", ""): v
+            for k, v in llama_state_dict.items()
+            if k.startswith("model.")
+        }
+    if any(k.startswith("model.") for k in lora_state_dict.keys()):
+        lora_state_dict = {
+            k.replace("model.", ""): v
+            for k, v in lora_state_dict.items()
+            if k.startswith("model.")
+        }
 
     logger.info(f"Found {len(llama_state_dict)} keys in llama model")
     logger.info(f"Found {len(lora_state_dict)} keys in lora model")