@@ -52,16 +52,18 @@ def merge(llama_config, lora_config, llama_weight, lora_weight, output):
 
     lora_state_dict = lora_state_dict["state_dict"]
 
     # remove prefix model.
-    llama_state_dict = {
-        k.replace("model.", ""): v
-        for k, v in llama_state_dict.items()
-        if k.startswith("model.")
-    }
-    lora_state_dict = {
-        k.replace("model.", ""): v
-        for k, v in lora_state_dict.items()
-        if k.startswith("model.")
-    }
+    if any(k.startswith("model.") for k in llama_state_dict.keys()):
+        llama_state_dict = {
+            k.replace("model.", ""): v
+            for k, v in llama_state_dict.items()
+            if k.startswith("model.")
+        }
+    if any(k.startswith("model.") for k in lora_state_dict.keys()):
+        lora_state_dict = {
+            k.replace("model.", ""): v
+            for k, v in lora_state_dict.items()
+            if k.startswith("model.")
+        }
     logger.info(f"Found {len(llama_state_dict)} keys in llama model")
     logger.info(f"Found {len(lora_state_dict)} keys in lora model")