quantize.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515
  1. # Copyright (c) Meta Platforms, Inc. and affiliates.
  2. # All rights reserved.
  3. # This source code is licensed under the license found in the
  4. # LICENSE file in the root directory of this source tree.
  5. import time
  6. from pathlib import Path
  7. import torch
  8. import torch.nn as nn
  9. import torch.nn.functional as F
  10. from fish_speech.models.text2semantic.llama import ModelArgs, Transformer, find_multiple
  11. ##### Quantization Primitives ######
  12. def dynamically_quantize_per_channel(x, quant_min, quant_max, target_dtype):
  13. # assumes symmetric quantization
  14. # assumes axis == 0
  15. # assumes dense memory format
  16. # TODO(future): relax ^ as needed
  17. # default setup for affine quantization of activations
  18. eps = torch.finfo(torch.float32).eps
  19. # get min and max
  20. min_val, max_val = torch.aminmax(x, dim=1)
  21. # calculate scales and zero_points based on min and max
  22. # reference: https://fburl.com/code/srbiybme
  23. min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
  24. max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
  25. device = min_val_neg.device
  26. # reference: https://fburl.com/code/4wll53rk
  27. max_val_pos = torch.max(-min_val_neg, max_val_pos)
  28. scales = max_val_pos / (float(quant_max - quant_min) / 2)
  29. # ensure scales is the same dtype as the original tensor
  30. scales = torch.clamp(scales, min=eps).to(x.dtype)
  31. zero_points = torch.zeros(min_val_neg.size(), dtype=torch.int64, device=device)
  32. # quantize based on qmin/qmax/scales/zp
  33. # reference: https://www.internalfb.com/code/fbsource/[8edc275012b1]/fbcode/caffe2/torch/ao/quantization/fx/_decomposed.py?lines=63
  34. x_div = x / scales.unsqueeze(-1)
  35. x_round = torch.round(x_div)
  36. x_zp = x_round + zero_points.unsqueeze(-1)
  37. quant = torch.clamp(x_zp, quant_min, quant_max).to(target_dtype)
  38. return quant, scales, zero_points
  39. def get_group_qparams(w, n_bit=4, groupsize=128):
  40. # needed for GPTQ with padding
  41. if groupsize > w.shape[-1]:
  42. groupsize = w.shape[-1]
  43. assert groupsize > 1
  44. assert w.shape[-1] % groupsize == 0
  45. assert w.dim() == 2
  46. to_quant = w.reshape(-1, groupsize)
  47. assert torch.isnan(to_quant).sum() == 0
  48. max_val = to_quant.amax(dim=1, keepdim=True)
  49. min_val = to_quant.amin(dim=1, keepdim=True)
  50. max_int = 2**n_bit - 1
  51. scales = (max_val - min_val).clamp(min=1e-6) / max_int
  52. zeros = min_val + scales * (2 ** (n_bit - 1))
  53. return scales.to(torch.bfloat16).reshape(w.shape[0], -1), zeros.to(
  54. torch.bfloat16
  55. ).reshape(w.shape[0], -1)
  56. def pack_scales_and_zeros(scales, zeros):
  57. assert scales.shape == zeros.shape
  58. assert scales.dtype == torch.bfloat16
  59. assert zeros.dtype == torch.bfloat16
  60. return (
  61. torch.cat(
  62. [
  63. scales.reshape(scales.size(0), scales.size(1), 1),
  64. zeros.reshape(zeros.size(0), zeros.size(1), 1),
  65. ],
  66. 2,
  67. )
  68. .transpose(0, 1)
  69. .contiguous()
  70. )
  71. def unpack_scales_and_zeros(scales_and_zeros):
  72. assert len(scales_and_zeros.shape) == 3 and scales_and_zeros.shape[2] == 2
  73. assert scales_and_zeros.dtype == torch.float
  74. return torch.split(scales_and_zeros.transpose(0, 1), 1, 2)
  75. def group_quantize_tensor_from_qparams(w, scales, zeros, n_bit=4, groupsize=128):
  76. assert groupsize > 1
  77. # needed for GPTQ single column quantize
  78. if groupsize > w.shape[-1] and scales.shape[-1] == 1:
  79. groupsize = w.shape[-1]
  80. assert w.shape[-1] % groupsize == 0
  81. assert w.dim() == 2
  82. to_quant = w.reshape(-1, groupsize)
  83. assert torch.isnan(to_quant).sum() == 0
  84. scales = scales.reshape(-1, 1)
  85. zeros = zeros.reshape(-1, 1)
  86. min_val = zeros - scales * (2 ** (n_bit - 1))
  87. max_int = 2**n_bit - 1
  88. min_int = 0
  89. w_int32 = (
  90. to_quant.sub(min_val)
  91. .div(scales)
  92. .round()
  93. .clamp_(min_int, max_int)
  94. .to(torch.int32)
  95. .reshape_as(w)
  96. )
  97. return w_int32
  98. def group_quantize_tensor(w, n_bit=4, groupsize=128):
  99. scales, zeros = get_group_qparams(w, n_bit, groupsize)
  100. w_int32 = group_quantize_tensor_from_qparams(w, scales, zeros, n_bit, groupsize)
  101. scales_and_zeros = pack_scales_and_zeros(scales, zeros)
  102. return w_int32, scales_and_zeros
  103. def group_dequantize_tensor_from_qparams(
  104. w_int32, scales, zeros, n_bit=4, groupsize=128
  105. ):
  106. assert groupsize > 1
  107. # needed for GPTQ single column dequantize
  108. if groupsize > w_int32.shape[-1] and scales.shape[-1] == 1:
  109. groupsize = w_int32.shape[-1]
  110. assert w_int32.shape[-1] % groupsize == 0
  111. assert w_int32.dim() == 2
  112. w_int32_grouped = w_int32.reshape(-1, groupsize)
  113. scales = scales.reshape(-1, 1)
  114. zeros = zeros.reshape(-1, 1)
  115. w_dq = (
  116. w_int32_grouped.sub(2 ** (n_bit - 1)).mul(scales).add(zeros).reshape_as(w_int32)
  117. )
  118. return w_dq
  119. def group_dequantize_tensor(w_int32, scales_and_zeros, n_bit=4, groupsize=128):
  120. scales, zeros = unpack_scales_and_zeros(scales_and_zeros)
  121. return group_dequantize_tensor_from_qparams(
  122. w_int32, scales, zeros, n_bit, groupsize
  123. )
  124. class QuantHandler:
  125. def __init__(self, mod):
  126. self.mod = mod
  127. def create_quantized_state_dict(self) -> "StateDict":
  128. pass
  129. def convert_for_runtime(self) -> "nn.Module":
  130. pass
  131. ##### Weight-only int8 per-channel quantized code ######
  132. def replace_linear_weight_only_int8_per_channel(module):
  133. for name, child in module.named_children():
  134. if isinstance(child, nn.Linear):
  135. setattr(
  136. module,
  137. name,
  138. WeightOnlyInt8Linear(child.in_features, child.out_features),
  139. )
  140. else:
  141. replace_linear_weight_only_int8_per_channel(child)
  142. class WeightOnlyInt8QuantHandler:
  143. def __init__(self, mod):
  144. self.mod = mod
  145. @torch.no_grad()
  146. def create_quantized_state_dict(self):
  147. cur_state_dict = self.mod.state_dict()
  148. for fqn, mod in self.mod.named_modules():
  149. if isinstance(mod, torch.nn.Linear):
  150. int8_weight, scales, _ = dynamically_quantize_per_channel(
  151. mod.weight.float(), -128, 127, torch.int8
  152. )
  153. cur_state_dict[f"{fqn}.weight"] = int8_weight
  154. cur_state_dict[f"{fqn}.scales"] = scales.to(mod.weight.dtype)
  155. return cur_state_dict
  156. def convert_for_runtime(self):
  157. replace_linear_weight_only_int8_per_channel(self.mod)
  158. return self.mod
  159. class WeightOnlyInt8Linear(torch.nn.Module):
  160. __constants__ = ["in_features", "out_features"]
  161. in_features: int
  162. out_features: int
  163. weight: torch.Tensor
  164. def __init__(
  165. self,
  166. in_features: int,
  167. out_features: int,
  168. bias: bool = True,
  169. device=None,
  170. dtype=None,
  171. ) -> None:
  172. factory_kwargs = {"device": device, "dtype": dtype}
  173. super().__init__()
  174. self.in_features = in_features
  175. self.out_features = out_features
  176. self.register_buffer(
  177. "weight", torch.empty((out_features, in_features), dtype=torch.int8)
  178. )
  179. self.register_buffer("scales", torch.ones(out_features, dtype=torch.bfloat16))
  180. def forward(self, input: torch.Tensor) -> torch.Tensor:
  181. return F.linear(input, self.weight.to(dtype=input.dtype)) * self.scales
  182. ##### weight only int4 per channel groupwise quantized code ######
  183. def prepare_int4_weight_and_scales_and_zeros(weight_bf16, groupsize, inner_k_tiles):
  184. weight_int32, scales_and_zeros = group_quantize_tensor(
  185. weight_bf16, n_bit=4, groupsize=groupsize
  186. )
  187. weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(
  188. weight_int32, inner_k_tiles
  189. )
  190. return weight_int4pack, scales_and_zeros
  191. def linear_forward_int4(x, weight_int4pack, scales_and_zeros, out_features, groupsize):
  192. origin_x_size = x.size()
  193. x = x.reshape(-1, origin_x_size[-1])
  194. c = torch.ops.aten._weight_int4pack_mm(
  195. x, weight_int4pack, groupsize, scales_and_zeros
  196. )
  197. new_shape = origin_x_size[:-1] + (out_features,)
  198. c = c.reshape(new_shape)
  199. return c
  200. def _check_linear_int4_k(k, groupsize=1, inner_k_tiles=1):
  201. return k % groupsize == 0 and k % (inner_k_tiles * 16) == 0
  202. def replace_linear_int4(module, groupsize, inner_k_tiles, padding):
  203. for name, child in module.named_children():
  204. if isinstance(child, nn.Linear):
  205. if _check_linear_int4_k(child.in_features, groupsize, inner_k_tiles):
  206. setattr(
  207. module,
  208. name,
  209. WeightOnlyInt4Linear(
  210. child.in_features,
  211. child.out_features,
  212. bias=False,
  213. groupsize=groupsize,
  214. inner_k_tiles=inner_k_tiles,
  215. padding=False,
  216. ),
  217. )
  218. elif padding:
  219. setattr(
  220. module,
  221. name,
  222. WeightOnlyInt4Linear(
  223. child.in_features,
  224. child.out_features,
  225. bias=False,
  226. groupsize=groupsize,
  227. inner_k_tiles=inner_k_tiles,
  228. padding=True,
  229. ),
  230. )
  231. else:
  232. replace_linear_int4(child, groupsize, inner_k_tiles, padding)
class WeightOnlyInt4QuantHandler:
    """Handler for int4 weight-only groupwise quantization of Linear layers."""

    def __init__(self, mod, groupsize=128, inner_k_tiles=8, padding=True):
        # mod: module tree whose nn.Linear weights get quantized.
        # padding: pad incompatible in_features up to a multiple of 1024.
        self.mod = mod
        self.groupsize = groupsize
        self.inner_k_tiles = inner_k_tiles
        self.padding = padding
        assert groupsize in [32, 64, 128, 256]
        assert inner_k_tiles in [2, 4, 8]

    @torch.no_grad()
    def create_quantized_state_dict(self):
        """Return a state dict with each Linear weight int4-packed.

        Quantization/packing runs on CUDA (the weight is moved to "cuda"
        and results are moved back to CPU). Layers failing the kernel's
        divisibility constraints are padded (if enabled) or skipped.
        """
        cur_state_dict = self.mod.state_dict()
        for fqn, mod in self.mod.named_modules():
            if isinstance(mod, torch.nn.Linear):
                assert not mod.bias
                out_features = mod.out_features
                in_features = mod.in_features
                assert out_features % 8 == 0, "require out_features % 8 == 0"
                print(f"linear: {fqn}, in={in_features}, out={out_features}")

                weight = mod.weight.data
                if not _check_linear_int4_k(
                    in_features, self.groupsize, self.inner_k_tiles
                ):
                    if self.padding:
                        import torch.nn.functional as F

                        print(
                            f"warning: {fqn} is padded to satisfy in_features % 1024 == 0"
                        )
                        padded_in_features = find_multiple(in_features, 1024)
                        # Zero-pad trailing input columns to the padded width.
                        weight = F.pad(
                            weight, pad=(0, padded_in_features - in_features)
                        )
                    else:
                        print(
                            f"warning: {fqn} is skipped, int4 requires that in_features is 32, 64, or is divisible by 1024, "
                            + "and that groupsize and inner_k_tiles*16 evenly divide into it"
                        )
                        continue
                (
                    weight_int4pack,
                    scales_and_zeros,
                ) = prepare_int4_weight_and_scales_and_zeros(
                    weight.to(torch.bfloat16).to("cuda"),
                    self.groupsize,
                    self.inner_k_tiles,
                )
                cur_state_dict[f"{fqn}.weight"] = weight_int4pack.to("cpu")
                cur_state_dict[f"{fqn}.scales_and_zeros"] = scales_and_zeros.to("cpu")
        return cur_state_dict

    def convert_for_runtime(self):
        """Swap Linear layers for WeightOnlyInt4Linear in-place and return mod."""
        replace_linear_int4(self.mod, self.groupsize, self.inner_k_tiles, self.padding)
        return self.mod
  284. class WeightOnlyInt4Linear(torch.nn.Module):
  285. __constants__ = ["in_features", "out_features"]
  286. in_features: int
  287. out_features: int
  288. weight: torch.Tensor
  289. def __init__(
  290. self,
  291. in_features: int,
  292. out_features: int,
  293. bias=True,
  294. device=None,
  295. dtype=None,
  296. groupsize: int = 128,
  297. inner_k_tiles: int = 8,
  298. padding: bool = True,
  299. ) -> None:
  300. super().__init__()
  301. self.padding = padding
  302. if padding:
  303. self.origin_in_features = in_features
  304. in_features = find_multiple(in_features, 1024)
  305. self.in_features = in_features
  306. self.out_features = out_features
  307. assert not bias, "require bias=False"
  308. self.groupsize = groupsize
  309. self.inner_k_tiles = inner_k_tiles
  310. assert out_features % 8 == 0, "require out_features % 8 == 0"
  311. assert (
  312. in_features % (inner_k_tiles * 16) == 0
  313. ), "require in_features % (innerKTiles * 16) == 0"
  314. self.register_buffer(
  315. "weight",
  316. torch.empty(
  317. (
  318. out_features // 8,
  319. in_features // (inner_k_tiles * 16),
  320. 32,
  321. inner_k_tiles // 2,
  322. ),
  323. dtype=torch.int32,
  324. ),
  325. )
  326. self.register_buffer(
  327. "scales_and_zeros",
  328. torch.empty(
  329. (in_features // groupsize, out_features, 2), dtype=torch.bfloat16
  330. ),
  331. )
  332. def forward(self, input: torch.Tensor) -> torch.Tensor:
  333. input = input.to(torch.bfloat16)
  334. if self.padding:
  335. import torch.nn.functional as F
  336. input = F.pad(input, pad=(0, self.in_features - self.origin_in_features))
  337. return linear_forward_int4(
  338. input, self.weight, self.scales_and_zeros, self.out_features, self.groupsize
  339. )
  340. def quantize(
  341. checkpoint_path: Path = Path("checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth"),
  342. mode: str = "int8",
  343. # following arguments only available when setting int4 quantization.
  344. groupsize: int = 128,
  345. ) -> None:
  346. assert checkpoint_path.is_file(), checkpoint_path
  347. device = "cpu"
  348. precision = torch.bfloat16
  349. print("Loading model ...")
  350. t0 = time.time()
  351. with torch.device("meta"):
  352. model = Transformer(
  353. ModelArgs(
  354. max_seq_len=4096,
  355. vocab_size=36408,
  356. n_layer=24,
  357. n_head=16,
  358. dim=1024,
  359. rope_base=10000,
  360. norm_eps=1e-5,
  361. num_codebooks=4, # single codebook
  362. codebook_size=168, # codebook size 160 + 2 special tokens
  363. )
  364. )
  365. checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True)
  366. if "state_dict" in checkpoint:
  367. checkpoint = checkpoint["state_dict"]
  368. checkpoint = {
  369. k.replace("model.", ""): v
  370. for k, v in checkpoint.items()
  371. if k.startswith("model.")
  372. }
  373. model.load_state_dict(checkpoint, assign=True)
  374. model = model.to(dtype=precision, device=device)
  375. if mode == "int8":
  376. print(
  377. "Quantizing model weights for int8 weight-only symmetric per-channel quantization"
  378. )
  379. quant_handler = WeightOnlyInt8QuantHandler(model)
  380. quantized_state_dict = quant_handler.create_quantized_state_dict()
  381. dir_name = checkpoint_path.parent
  382. base_name = checkpoint_path.stem
  383. suffix = checkpoint_path.suffix
  384. quantize_path = dir_name / f"{base_name}.int8{suffix}"
  385. elif mode == "int4":
  386. print(
  387. "Quantizing model weights for int4 weight-only affine per-channel groupwise quantization"
  388. )
  389. quant_handler = WeightOnlyInt4QuantHandler(model, groupsize)
  390. quantized_state_dict = quant_handler.create_quantized_state_dict()
  391. dir_name = checkpoint_path.parent
  392. base_name = checkpoint_path.name
  393. suffix = checkpoint_path.suffix
  394. quantize_path = dir_name / f"{base_name}.int4.g{groupsize}{suffix}"
  395. else:
  396. raise ValueError(
  397. f"Invalid quantization mode {mode} needs to be one of [int8, int4, int4-gpptq]"
  398. )
  399. print(f"Writing quantized weights to {quantize_path}")
  400. quantize_path.unlink(missing_ok=True) # remove existing file if one already there
  401. torch.save(quantized_state_dict, quantize_path)
  402. print(f"Quantization complete took {time.time() - t0:.02f} seconds")
# Command-line entry point: parse options and run quantization.
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Quantize a model.")
    parser.add_argument(
        "--checkpoint_path",
        type=Path,
        default=Path("checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth"),
        help="Path to the model checkpoint to be quantized.",
    )
    parser.add_argument(
        "--mode",
        "-q",
        type=str,
        default="int8",
        choices=["int8", "int4"],
        help="type of quantization to perform",
    )
    # NOTE(review): default groupsize here (32) differs from quantize()'s
    # own default (128) — confirm which is intended.
    parser.add_argument(
        "--groupsize", type=int, default=32, help="Group size for int4 quantization."
    )
    args = parser.parse_args()
    quantize(args.checkpoint_path, args.mode, args.groupsize)