Quellcode durchsuchen

Add preprocessing and validation steps when loading the checkpoint. (#433)

* Add preprocessing and validation steps when loading the checkpoint.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: liupeng <laupeng1989@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
bfs18 vor 1 Jahr
Ursprung
Commit
dfc95d0f5c
1 geänderte Datei mit 27 neuen und 0 gelöschten Zeilen
  1. 27 0
      fish_speech/models/text2semantic/llama.py

+ 27 - 0
fish_speech/models/text2semantic/llama.py

@@ -1,5 +1,6 @@
 import json
 import math
+from collections import OrderedDict
 from dataclasses import dataclass
 from pathlib import Path
 from typing import Optional
@@ -370,6 +371,32 @@ class BaseTransformer(nn.Module):
             weights = torch.load(
                 Path(path) / "model.pth", map_location="cpu", mmap=True
             )
+
+            if "state_dict" in weights:
+                logger.warning(
+                    "Using a TextToSemantic LightningModule checkpoint, "
+                    "please make sure it is a full model, not a LoRA model."
+                )
+                weights = weights["state_dict"]
+
+            if next(iter(weights.keys())).startswith("model."):
+                logger.info(
+                    f"Remove prefix 'model.' created by TextToSemantic LightningModule from keys"
+                )
+                new_weights = OrderedDict()
+                for k, v in weights.items():
+                    new_weights[k.replace("model.", "")] = v
+                weights = new_weights
+
+            # Verify the name and shape of parameters since strict=False in load_state_dict.
+            for k, v in model.named_parameters():
+                if k not in weights:
+                    logger.warning(f"No weight for {k}")
+                elif v.shape != weights[k].shape:
+                    logger.warning(
+                        f"Shape mismatch for {k}: {v.shape} vs {weights[k].shape}"
+                    )
+
             err = model.load_state_dict(weights, strict=False, assign=True)
             log.info(f"Loaded weights with error: {err}")