소스 검색

Update merge_lora.py (#284)

Naozumi 1 년 전
부모
커밋
fbe2e3f030
1개의 변경된 파일12개의 추가작업 그리고 10개의 파일을 삭제
  1. 12 10
      tools/llama/merge_lora.py

+ 12 - 10
tools/llama/merge_lora.py

@@ -52,16 +52,18 @@ def merge(llama_config, lora_config, llama_weight, lora_weight, output):
         lora_state_dict = lora_state_dict["state_dict"]
 
     # remove prefix model.
-    llama_state_dict = {
-        k.replace("model.", ""): v
-        for k, v in llama_state_dict.items()
-        if k.startswith("model.")
-    }
-    lora_state_dict = {
-        k.replace("model.", ""): v
-        for k, v in lora_state_dict.items()
-        if k.startswith("model.")
-    }
+    if any(k.startswith("model.") for k in llama_state_dict.keys()):
+        llama_state_dict = {
+            k.replace("model.", ""): v
+            for k, v in llama_state_dict.items()
+            if k.startswith("model.")
+        }
+    if any(k.startswith("model.") for k in lora_state_dict.keys()):
+        lora_state_dict = {
+            k.replace("model.", ""): v
+            for k, v in lora_state_dict.items()
+            if k.startswith("model.")
+        }
 
     logger.info(f"Found {len(llama_state_dict)} keys in llama model")
     logger.info(f"Found {len(lora_state_dict)} keys in lora model")