@@ -1,5 +1,6 @@
 import json
 import math
+from collections import OrderedDict
 from dataclasses import dataclass
 from pathlib import Path
 from typing import Optional
@@ -370,6 +371,35 @@ class BaseTransformer(nn.Module):
         weights = torch.load(
             Path(path) / "model.pth", map_location="cpu", mmap=True
         )
+
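+        # Lightning checkpoints nest the model weights under a "state_dict" key.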
+        if "state_dict" in weights:
+            logger.warning(
+                "Using a TextToSemantic LightningModule checkpoint, "
+                "please make sure it is a full model, not a LoRA model."
+            )
+            weights = weights["state_dict"]
+
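+        # A LightningModule that stores the network as `self.model` saves every
+        # key with a "model." prefix; strip it so the keys match this module.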
+        if next(iter(weights.keys())).startswith("model."):
+            logger.info(
+                "Removing prefix 'model.' created by TextToSemantic LightningModule from keys"
+            )
+            new_weights = OrderedDict()
+            for k, v in weights.items():
+                new_weights[k.removeprefix("model.")] = v
+            weights = new_weights
+
+        # Verify the name and shape of parameters since strict=False in load_state_dict.
+        for k, v in model.named_parameters():
+            if k not in weights:
+                logger.warning(f"No weight for {k}")
+            elif v.shape != weights[k].shape:
+                logger.warning(
+                    f"Shape mismatch for {k}: {v.shape} vs {weights[k].shape}"
+                )
+
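+        # strict=False returns (missing_keys, unexpected_keys) instead of raising;
+        # assign=True binds the loaded tensors directly rather than copying in place.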
         err = model.load_state_dict(weights, strict=False, assign=True)
         logger.info(f"Loaded weights with error: {err}")
