|
|
@@ -6,11 +6,12 @@
|
|
|
import time
|
|
|
from pathlib import Path
|
|
|
|
|
|
+import click
|
|
|
import torch
|
|
|
import torch.nn as nn
|
|
|
import torch.nn.functional as F
|
|
|
|
|
|
-from fish_speech.models.text2semantic.llama import ModelArgs, Transformer, find_multiple
|
|
|
+from .generate import load_model
|
|
|
|
|
|
##### Quantization Primitives ######
|
|
|
|
|
|
@@ -414,11 +415,21 @@ class WeightOnlyInt4Linear(torch.nn.Module):
|
|
|
)
|
|
|
|
|
|
|
|
|
+@click.command()
|
|
|
+@click.option(
|
|
|
+ "--checkpoint-path",
|
|
|
+ type=click.Path(path_type=Path, exists=True),
|
|
|
+ default="checkpoints/text2semantic-sft-medium-v1-4k.pth",
|
|
|
+)
|
|
|
+@click.option("--config-name", type=str, default="dual_ar_2_codebook_medium")
|
|
|
+@click.option(
|
|
|
+ "--mode", type=str, default="int8", help="type of quantization to perform"
|
|
|
+)
|
|
|
+@click.option(
|
|
|
+ "--groupsize", type=int, default=128, help="Group size for int4 quantization."
|
|
|
+)
|
|
|
def quantize(
|
|
|
- checkpoint_path: Path = Path("checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth"),
|
|
|
- mode: str = "int8",
|
|
|
- # following arguments only available when setting int4 quantization.
|
|
|
- groupsize: int = 128,
|
|
|
+ checkpoint_path: Path, config_name: str, mode: str, groupsize: int
|
|
|
) -> None:
|
|
|
assert checkpoint_path.is_file(), checkpoint_path
|
|
|
|
|
|
@@ -428,31 +439,14 @@ def quantize(
|
|
|
print("Loading model ...")
|
|
|
t0 = time.time()
|
|
|
|
|
|
- with torch.device("meta"):
|
|
|
- model = Transformer(
|
|
|
- ModelArgs(
|
|
|
- max_seq_len=4096,
|
|
|
- vocab_size=36408,
|
|
|
- n_layer=24,
|
|
|
- n_head=16,
|
|
|
- dim=1024,
|
|
|
- rope_base=10000,
|
|
|
- norm_eps=1e-5,
|
|
|
- num_codebooks=4, # single codebook
|
|
|
- codebook_size=168, # codebook size 160 + 2 special tokens
|
|
|
- )
|
|
|
- )
|
|
|
-
|
|
|
- checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True)
|
|
|
- if "state_dict" in checkpoint:
|
|
|
- checkpoint = checkpoint["state_dict"]
|
|
|
- checkpoint = {
|
|
|
- k.replace("model.", ""): v
|
|
|
- for k, v in checkpoint.items()
|
|
|
- if k.startswith("model.")
|
|
|
- }
|
|
|
- model.load_state_dict(checkpoint, assign=True)
|
|
|
- model = model.to(dtype=precision, device=device)
|
|
|
+ model, _ = load_model(
|
|
|
+ config_name,
|
|
|
+ checkpoint_path=checkpoint_path,
|
|
|
+ device=device,
|
|
|
+ precision=precision,
|
|
|
+ compile=False,
|
|
|
+ max_length=2048,
|
|
|
+ )
|
|
|
|
|
|
if mode == "int8":
|
|
|
print(
|
|
|
@@ -490,26 +484,4 @@ def quantize(
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
- import argparse
|
|
|
-
|
|
|
- parser = argparse.ArgumentParser(description="Quantize a model.")
|
|
|
- parser.add_argument(
|
|
|
- "--checkpoint_path",
|
|
|
- type=Path,
|
|
|
- default=Path("checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth"),
|
|
|
- help="Path to the model checkpoint to be quantized.",
|
|
|
- )
|
|
|
- parser.add_argument(
|
|
|
- "--mode",
|
|
|
- "-q",
|
|
|
- type=str,
|
|
|
- default="int8",
|
|
|
- choices=["int8", "int4"],
|
|
|
- help="type of quantization to perform",
|
|
|
- )
|
|
|
- parser.add_argument(
|
|
|
- "--groupsize", type=int, default=32, help="Group size for int4 quantization."
|
|
|
- )
|
|
|
-
|
|
|
- args = parser.parse_args()
|
|
|
- quantize(args.checkpoint_path, args.mode, args.groupsize)
|
|
|
+ quantize()
|