فهرست منبع

Update finetune toolchain

Lengyue 2 سال پیش
والد
کامیت
f5919b7e71

+ 2 - 0
.gitignore

@@ -11,3 +11,5 @@ filelists
 /checkpoints
 /.vscode
 /data_server/target
+/*.npy
+/*.wav

+ 3 - 1
fish_speech/configs/text2semantic_finetune.yaml

@@ -2,8 +2,10 @@ defaults:
   - base
   - _self_
 
-project: text2semantic_400m_multi
+project: text2semantic_400m_finetune
 max_length: 4096
+ckpt_path: results/text2semantic_400m_pretrain/checkpoints/step_000065000.ckpt
+resume_weights_only: true
 
 # Lightning Trainer
 trainer:

+ 1 - 1
fish_speech/configs/text2semantic_pretrain.yaml

@@ -2,7 +2,7 @@ defaults:
   - base
   - _self_
 
-project: text2semantic_400m_multi
+project: text2semantic_400m_pretrain
 max_length: 1024
 
 # Lightning Trainer

+ 0 - 0
fish_speech/configs/vqgan.yaml → fish_speech/configs/vqgan_pretrain.yaml


+ 8 - 4
fish_speech/train.py

@@ -76,19 +76,23 @@ def train(cfg: DictConfig) -> tuple[dict, dict]:
         log.info("Starting training!")
 
         ckpt_path = cfg.get("ckpt_path")
+        auto_resume = False
 
         if ckpt_path is None:
             ckpt_path = utils.get_latest_checkpoint(cfg.paths.ckpt_dir)
+            auto_resume = True
 
         if ckpt_path is not None:
             log.info(f"Resuming from checkpoint: {ckpt_path}")
 
-        if cfg.get("resume_weights_only"):
+        # resume weights only is disabled for auto-resume
+        if cfg.get("resume_weights_only") and auto_resume is False:
             log.info("Resuming weights only!")
             ckpt = torch.load(ckpt_path, map_location=model.device)
-            model.load_state_dict(
-                ckpt["state_dict"] if "state_dict" in ckpt else ckpt, strict=False
-            )
+            if "state_dict" in ckpt:
+                ckpt = ckpt["state_dict"]
+            err = model.load_state_dict(ckpt, strict=False)
+            log.info(f"Error loading state dict: {err}")
             ckpt_path = None
 
         trainer.fit(model=model, datamodule=datamodule, ckpt_path=ckpt_path)

+ 29 - 0
tools/extract_model.py

@@ -0,0 +1,29 @@
import click
import torch
from loguru import logger


@click.command()
@click.argument("model_path")
@click.argument("output_path")
def main(model_path, output_path):
    """Extract bare model weights from a Lightning checkpoint.

    Loads the checkpoint at MODEL_PATH, keeps only the entries of its
    ``state_dict`` whose keys start with ``"model."`` (the wrapped model's
    weights — the prefix is kept so they can be re-loaded via the
    ``resume_weights_only`` path), and saves them to OUTPUT_PATH.
    """
    if model_path == output_path:
        logger.error("Model path and output path are the same")
        # BUG FIX: the original code called `click.Abort()` without `raise`,
        # which constructs the exception object and discards it — execution
        # continued and the file would still be overwritten. Raising actually
        # aborts the command with a non-zero exit code.
        raise click.Abort()

    logger.info(f"Loading model from {model_path}")
    state_dict = torch.load(model_path, map_location="cpu")["state_dict"]
    logger.info("Extracting model")

    # Keep only the wrapped model's weights. The original comprehension
    # shadowed `state_dict` as its own loop variable; `key`/`value` make the
    # intent clear without changing behavior.
    state_dict = {
        key: value for key, value in state_dict.items() if key.startswith("model.")
    }

    torch.save(state_dict, output_path)
    logger.info(f"Model saved to {output_path}")


if __name__ == "__main__":
    main()

+ 0 - 12
tools/llama/extract_model.py

@@ -1,12 +0,0 @@
-import torch
-
-state_dict = torch.load(
-    "results/text2semantic_400m/checkpoints/step_000095000.ckpt", map_location="cpu"
-)["state_dict"]
-state_dict = {
-    state_dict.replace("model.", ""): value
-    for state_dict, value in state_dict.items()
-    if state_dict.startswith("model.")
-}
-
-torch.save(state_dict, "results/text2semantic_400m/step_000095000_weights.ckpt")