|
|
@@ -10,7 +10,7 @@ import torch
|
|
|
import torch.nn as nn
|
|
|
import torch.nn.functional as F
|
|
|
|
|
|
-from fish_speech.models.text2semantic.llama import Transformer, find_multiple
|
|
|
+from fish_speech.models.text2semantic.llama import ModelArgs, Transformer, find_multiple
|
|
|
|
|
|
##### Quantization Primitives ######
|
|
|
|
|
|
@@ -419,7 +419,6 @@ def quantize(
|
|
|
mode: str = "int8",
|
|
|
# following arguments only available when setting int4 quantization.
|
|
|
groupsize: int = 128,
|
|
|
- label: str = "",
|
|
|
) -> None:
|
|
|
assert checkpoint_path.is_file(), checkpoint_path
|
|
|
|
|
|
@@ -430,9 +429,28 @@ def quantize(
|
|
|
t0 = time.time()
|
|
|
|
|
|
with torch.device("meta"):
|
|
|
- model = Transformer.from_name(checkpoint_path.parent.name)
|
|
|
+ model = Transformer(
|
|
|
+ ModelArgs(
|
|
|
+ max_seq_len=4096,
|
|
|
+ vocab_size=36408,
|
|
|
+ n_layer=24,
|
|
|
+ n_head=16,
|
|
|
+ dim=1024,
|
|
|
+ rope_base=10000,
|
|
|
+ norm_eps=1e-5,
|
|
|
+ num_codebooks=4, # four codebooks
|
|
|
+ codebook_size=168, # codebook size incl. special tokens; NOTE(review): earlier note said 160 + 2 special tokens (= 162), which does not match 168 - confirm
|
|
|
+ )
|
|
|
+ )
|
|
|
|
|
|
checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True)
|
|
|
+ if "state_dict" in checkpoint:
|
|
|
+ checkpoint = checkpoint["state_dict"]
|
|
|
+ checkpoint = {
|
|
|
+ k.replace("model.", ""): v
|
|
|
+ for k, v in checkpoint.items()
|
|
|
+ if k.startswith("model.")
|
|
|
+ }
|
|
|
model.load_state_dict(checkpoint, assign=True)
|
|
|
model = model.to(dtype=precision, device=device)
|
|
|
|
|
|
@@ -444,8 +462,9 @@ def quantize(
|
|
|
quantized_state_dict = quant_handler.create_quantized_state_dict()
|
|
|
|
|
|
dir_name = checkpoint_path.parent
|
|
|
- base_name = checkpoint_path.name
|
|
|
- new_base_name = base_name.replace(".pth", f"{label}int8.pth")
|
|
|
+ base_name = checkpoint_path.stem
|
|
|
+ suffix = checkpoint_path.suffix
|
|
|
+ quantize_path = dir_name / f"{base_name}.int8{suffix}"
|
|
|
|
|
|
elif mode == "int4":
|
|
|
print(
|
|
|
@@ -456,19 +475,18 @@ def quantize(
|
|
|
|
|
|
dir_name = checkpoint_path.parent
|
|
|
base_name = checkpoint_path.name
|
|
|
- new_base_name = base_name.replace(".pth", f"{label}int4.g{groupsize}.pth")
|
|
|
+ suffix = checkpoint_path.suffix
|
|
|
+ quantize_path = dir_name / f"{base_name}.int4.g{groupsize}{suffix}"
|
|
|
|
|
|
else:
|
|
|
raise ValueError(
|
|
|
f"Invalid quantization mode {mode} needs to be one of [int8, int4, int4-gpptq]"
|
|
|
)
|
|
|
|
|
|
- quantize_path = dir_name / new_base_name
|
|
|
print(f"Writing quantized weights to {quantize_path}")
|
|
|
quantize_path.unlink(missing_ok=True) # remove existing file if one already there
|
|
|
torch.save(quantized_state_dict, quantize_path)
|
|
|
print(f"Quantization complete took {time.time() - t0:.02f} seconds")
|
|
|
- return
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
@@ -492,9 +510,6 @@ if __name__ == "__main__":
|
|
|
parser.add_argument(
|
|
|
"--groupsize", type=int, default=32, help="Group size for int4 quantization."
|
|
|
)
|
|
|
- parser.add_argument(
|
|
|
- "--label", type=str, default="_", help="label to add to output filename"
|
|
|
- )
|
|
|
|
|
|
args = parser.parse_args()
|
|
|
- quantize(args.checkpoint_path, args.mode, args.groupsize, args.label)
|
|
|
+ quantize(args.checkpoint_path, args.mode, args.groupsize)
|