ソースを参照

Better quantizer

Lengyue 2 年 前
コミット
ee7a1beb41
1 ファイル変更27 行追加12 行削除
  1 file changed: 27 additions, 12 deletions
      tools/llama/quantize.py

+ 27 - 12
tools/llama/quantize.py

@@ -10,7 +10,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
-from fish_speech.models.text2semantic.llama import Transformer, find_multiple
+from fish_speech.models.text2semantic.llama import ModelArgs, Transformer, find_multiple
 
 ##### Quantization Primitives ######
 
@@ -419,7 +419,6 @@ def quantize(
     mode: str = "int8",
     # following arguments only available when setting int4 quantization.
     groupsize: int = 128,
-    label: str = "",
 ) -> None:
     assert checkpoint_path.is_file(), checkpoint_path
 
@@ -430,9 +429,28 @@ def quantize(
     t0 = time.time()
 
     with torch.device("meta"):
-        model = Transformer.from_name(checkpoint_path.parent.name)
+        model = Transformer(
+            ModelArgs(
+                max_seq_len=4096,
+                vocab_size=36408,
+                n_layer=24,
+                n_head=16,
+                dim=1024,
+                rope_base=10000,
+                norm_eps=1e-5,
+                num_codebooks=4,  # NOTE(review): comment said "single codebook" but the value is 4 — verify intended codebook count
+                codebook_size=168,  # NOTE(review): stated breakdown "160 + 2 special tokens" sums to 162, not 168 — confirm the correct size/breakdown
+            )
+        )
 
     checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True)
+    if "state_dict" in checkpoint:
+        checkpoint = checkpoint["state_dict"]
+    checkpoint = {
+        k.replace("model.", ""): v
+        for k, v in checkpoint.items()
+        if k.startswith("model.")
+    }
     model.load_state_dict(checkpoint, assign=True)
     model = model.to(dtype=precision, device=device)
 
@@ -444,8 +462,9 @@ def quantize(
         quantized_state_dict = quant_handler.create_quantized_state_dict()
 
         dir_name = checkpoint_path.parent
-        base_name = checkpoint_path.name
-        new_base_name = base_name.replace(".pth", f"{label}int8.pth")
+        base_name = checkpoint_path.stem
+        suffix = checkpoint_path.suffix
+        quantize_path = dir_name / f"{base_name}.int8{suffix}"
 
     elif mode == "int4":
         print(
@@ -456,19 +475,18 @@ def quantize(
 
         dir_name = checkpoint_path.parent
         base_name = checkpoint_path.name
-        new_base_name = base_name.replace(".pth", f"{label}int4.g{groupsize}.pth")
+        suffix = checkpoint_path.suffix
+        quantize_path = dir_name / f"{base_name}.int4.g{groupsize}{suffix}"
 
     else:
         raise ValueError(
             f"Invalid quantization mode {mode} needs to be one of [int8, int4, int4-gpptq]"
         )
 
-    quantize_path = dir_name / new_base_name
     print(f"Writing quantized weights to {quantize_path}")
     quantize_path.unlink(missing_ok=True)  # remove existing file if one already there
     torch.save(quantized_state_dict, quantize_path)
     print(f"Quantization complete took {time.time() - t0:.02f} seconds")
-    return
 
 
 if __name__ == "__main__":
@@ -492,9 +510,6 @@ if __name__ == "__main__":
     parser.add_argument(
         "--groupsize", type=int, default=32, help="Group size for int4 quantization."
     )
-    parser.add_argument(
-        "--label", type=str, default="_", help="label to add to output filename"
-    )
 
     args = parser.parse_args()
-    quantize(args.checkpoint_path, args.mode, args.groupsize, args.label)
+    quantize(args.checkpoint_path, args.mode, args.groupsize)