quantize.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487
  1. # Copyright (c) Meta Platforms, Inc. and affiliates.
  2. # All rights reserved.
  3. # This source code is licensed under the license found in the
  4. # LICENSE file in the root directory of this source tree.
  5. import time
  6. from pathlib import Path
  7. import click
  8. import torch
  9. import torch.nn as nn
  10. import torch.nn.functional as F
  11. from .generate import load_model
  12. ##### Quantization Primitives ######
  13. def dynamically_quantize_per_channel(x, quant_min, quant_max, target_dtype):
  14. # assumes symmetric quantization
  15. # assumes axis == 0
  16. # assumes dense memory format
  17. # TODO(future): relax ^ as needed
  18. # default setup for affine quantization of activations
  19. eps = torch.finfo(torch.float32).eps
  20. # get min and max
  21. min_val, max_val = torch.aminmax(x, dim=1)
  22. # calculate scales and zero_points based on min and max
  23. # reference: https://fburl.com/code/srbiybme
  24. min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
  25. max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
  26. device = min_val_neg.device
  27. # reference: https://fburl.com/code/4wll53rk
  28. max_val_pos = torch.max(-min_val_neg, max_val_pos)
  29. scales = max_val_pos / (float(quant_max - quant_min) / 2)
  30. # ensure scales is the same dtype as the original tensor
  31. scales = torch.clamp(scales, min=eps).to(x.dtype)
  32. zero_points = torch.zeros(min_val_neg.size(), dtype=torch.int64, device=device)
  33. # quantize based on qmin/qmax/scales/zp
  34. # reference: https://www.internalfb.com/code/fbsource/[8edc275012b1]/fbcode/caffe2/torch/ao/quantization/fx/_decomposed.py?lines=63
  35. x_div = x / scales.unsqueeze(-1)
  36. x_round = torch.round(x_div)
  37. x_zp = x_round + zero_points.unsqueeze(-1)
  38. quant = torch.clamp(x_zp, quant_min, quant_max).to(target_dtype)
  39. return quant, scales, zero_points
  40. def get_group_qparams(w, n_bit=4, groupsize=128):
  41. # needed for GPTQ with padding
  42. if groupsize > w.shape[-1]:
  43. groupsize = w.shape[-1]
  44. assert groupsize > 1
  45. assert w.shape[-1] % groupsize == 0
  46. assert w.dim() == 2
  47. to_quant = w.reshape(-1, groupsize)
  48. assert torch.isnan(to_quant).sum() == 0
  49. max_val = to_quant.amax(dim=1, keepdim=True)
  50. min_val = to_quant.amin(dim=1, keepdim=True)
  51. max_int = 2**n_bit - 1
  52. scales = (max_val - min_val).clamp(min=1e-6) / max_int
  53. zeros = min_val + scales * (2 ** (n_bit - 1))
  54. return scales.to(torch.bfloat16).reshape(w.shape[0], -1), zeros.to(
  55. torch.bfloat16
  56. ).reshape(w.shape[0], -1)
  57. def pack_scales_and_zeros(scales, zeros):
  58. assert scales.shape == zeros.shape
  59. assert scales.dtype == torch.bfloat16
  60. assert zeros.dtype == torch.bfloat16
  61. return (
  62. torch.cat(
  63. [
  64. scales.reshape(scales.size(0), scales.size(1), 1),
  65. zeros.reshape(zeros.size(0), zeros.size(1), 1),
  66. ],
  67. 2,
  68. )
  69. .transpose(0, 1)
  70. .contiguous()
  71. )
  72. def unpack_scales_and_zeros(scales_and_zeros):
  73. assert len(scales_and_zeros.shape) == 3 and scales_and_zeros.shape[2] == 2
  74. assert scales_and_zeros.dtype == torch.float
  75. return torch.split(scales_and_zeros.transpose(0, 1), 1, 2)
  76. def group_quantize_tensor_from_qparams(w, scales, zeros, n_bit=4, groupsize=128):
  77. assert groupsize > 1
  78. # needed for GPTQ single column quantize
  79. if groupsize > w.shape[-1] and scales.shape[-1] == 1:
  80. groupsize = w.shape[-1]
  81. assert w.shape[-1] % groupsize == 0
  82. assert w.dim() == 2
  83. to_quant = w.reshape(-1, groupsize)
  84. assert torch.isnan(to_quant).sum() == 0
  85. scales = scales.reshape(-1, 1)
  86. zeros = zeros.reshape(-1, 1)
  87. min_val = zeros - scales * (2 ** (n_bit - 1))
  88. max_int = 2**n_bit - 1
  89. min_int = 0
  90. w_int32 = (
  91. to_quant.sub(min_val)
  92. .div(scales)
  93. .round()
  94. .clamp_(min_int, max_int)
  95. .to(torch.int32)
  96. .reshape_as(w)
  97. )
  98. return w_int32
  99. def group_quantize_tensor(w, n_bit=4, groupsize=128):
  100. scales, zeros = get_group_qparams(w, n_bit, groupsize)
  101. w_int32 = group_quantize_tensor_from_qparams(w, scales, zeros, n_bit, groupsize)
  102. scales_and_zeros = pack_scales_and_zeros(scales, zeros)
  103. return w_int32, scales_and_zeros
  104. def group_dequantize_tensor_from_qparams(
  105. w_int32, scales, zeros, n_bit=4, groupsize=128
  106. ):
  107. assert groupsize > 1
  108. # needed for GPTQ single column dequantize
  109. if groupsize > w_int32.shape[-1] and scales.shape[-1] == 1:
  110. groupsize = w_int32.shape[-1]
  111. assert w_int32.shape[-1] % groupsize == 0
  112. assert w_int32.dim() == 2
  113. w_int32_grouped = w_int32.reshape(-1, groupsize)
  114. scales = scales.reshape(-1, 1)
  115. zeros = zeros.reshape(-1, 1)
  116. w_dq = (
  117. w_int32_grouped.sub(2 ** (n_bit - 1)).mul(scales).add(zeros).reshape_as(w_int32)
  118. )
  119. return w_dq
  120. def group_dequantize_tensor(w_int32, scales_and_zeros, n_bit=4, groupsize=128):
  121. scales, zeros = unpack_scales_and_zeros(scales_and_zeros)
  122. return group_dequantize_tensor_from_qparams(
  123. w_int32, scales, zeros, n_bit, groupsize
  124. )
  125. class QuantHandler:
  126. def __init__(self, mod):
  127. self.mod = mod
  128. def create_quantized_state_dict(self) -> "StateDict":
  129. pass
  130. def convert_for_runtime(self) -> "nn.Module":
  131. pass
  132. ##### Weight-only int8 per-channel quantized code ######
  133. def replace_linear_weight_only_int8_per_channel(module):
  134. for name, child in module.named_children():
  135. if isinstance(child, nn.Linear):
  136. setattr(
  137. module,
  138. name,
  139. WeightOnlyInt8Linear(child.in_features, child.out_features),
  140. )
  141. else:
  142. replace_linear_weight_only_int8_per_channel(child)
  143. class WeightOnlyInt8QuantHandler:
  144. def __init__(self, mod):
  145. self.mod = mod
  146. @torch.no_grad()
  147. def create_quantized_state_dict(self):
  148. cur_state_dict = self.mod.state_dict()
  149. for fqn, mod in self.mod.named_modules():
  150. if isinstance(mod, torch.nn.Linear):
  151. int8_weight, scales, _ = dynamically_quantize_per_channel(
  152. mod.weight.float(), -128, 127, torch.int8
  153. )
  154. cur_state_dict[f"{fqn}.weight"] = int8_weight
  155. cur_state_dict[f"{fqn}.scales"] = scales.to(mod.weight.dtype)
  156. return cur_state_dict
  157. def convert_for_runtime(self):
  158. replace_linear_weight_only_int8_per_channel(self.mod)
  159. return self.mod
  160. class WeightOnlyInt8Linear(torch.nn.Module):
  161. __constants__ = ["in_features", "out_features"]
  162. in_features: int
  163. out_features: int
  164. weight: torch.Tensor
  165. def __init__(
  166. self,
  167. in_features: int,
  168. out_features: int,
  169. bias: bool = True,
  170. device=None,
  171. dtype=None,
  172. ) -> None:
  173. factory_kwargs = {"device": device, "dtype": dtype}
  174. super().__init__()
  175. self.in_features = in_features
  176. self.out_features = out_features
  177. self.register_buffer(
  178. "weight", torch.empty((out_features, in_features), dtype=torch.int8)
  179. )
  180. self.register_buffer("scales", torch.ones(out_features, dtype=torch.bfloat16))
  181. def forward(self, input: torch.Tensor) -> torch.Tensor:
  182. return F.linear(input, self.weight.to(dtype=input.dtype)) * self.scales
  183. ##### weight only int4 per channel groupwise quantized code ######
  184. def prepare_int4_weight_and_scales_and_zeros(weight_bf16, groupsize, inner_k_tiles):
  185. weight_int32, scales_and_zeros = group_quantize_tensor(
  186. weight_bf16, n_bit=4, groupsize=groupsize
  187. )
  188. weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(
  189. weight_int32, inner_k_tiles
  190. )
  191. return weight_int4pack, scales_and_zeros
  192. def linear_forward_int4(x, weight_int4pack, scales_and_zeros, out_features, groupsize):
  193. origin_x_size = x.size()
  194. x = x.reshape(-1, origin_x_size[-1])
  195. c = torch.ops.aten._weight_int4pack_mm(
  196. x, weight_int4pack, groupsize, scales_and_zeros
  197. )
  198. new_shape = origin_x_size[:-1] + (out_features,)
  199. c = c.reshape(new_shape)
  200. return c
  201. def _check_linear_int4_k(k, groupsize=1, inner_k_tiles=1):
  202. return k % groupsize == 0 and k % (inner_k_tiles * 16) == 0
  203. def replace_linear_int4(module, groupsize, inner_k_tiles, padding):
  204. for name, child in module.named_children():
  205. if isinstance(child, nn.Linear):
  206. if _check_linear_int4_k(child.in_features, groupsize, inner_k_tiles):
  207. setattr(
  208. module,
  209. name,
  210. WeightOnlyInt4Linear(
  211. child.in_features,
  212. child.out_features,
  213. bias=False,
  214. groupsize=groupsize,
  215. inner_k_tiles=inner_k_tiles,
  216. padding=False,
  217. ),
  218. )
  219. elif padding:
  220. setattr(
  221. module,
  222. name,
  223. WeightOnlyInt4Linear(
  224. child.in_features,
  225. child.out_features,
  226. bias=False,
  227. groupsize=groupsize,
  228. inner_k_tiles=inner_k_tiles,
  229. padding=True,
  230. ),
  231. )
  232. else:
  233. replace_linear_int4(child, groupsize, inner_k_tiles, padding)
  234. class WeightOnlyInt4QuantHandler:
  235. def __init__(self, mod, groupsize=128, inner_k_tiles=8, padding=True):
  236. self.mod = mod
  237. self.groupsize = groupsize
  238. self.inner_k_tiles = inner_k_tiles
  239. self.padding = padding
  240. assert groupsize in [32, 64, 128, 256]
  241. assert inner_k_tiles in [2, 4, 8]
  242. @torch.no_grad()
  243. def create_quantized_state_dict(self):
  244. cur_state_dict = self.mod.state_dict()
  245. for fqn, mod in self.mod.named_modules():
  246. if isinstance(mod, torch.nn.Linear):
  247. assert not mod.bias
  248. out_features = mod.out_features
  249. in_features = mod.in_features
  250. assert out_features % 8 == 0, "require out_features % 8 == 0"
  251. print(f"linear: {fqn}, in={in_features}, out={out_features}")
  252. weight = mod.weight.data
  253. if not _check_linear_int4_k(
  254. in_features, self.groupsize, self.inner_k_tiles
  255. ):
  256. if self.padding:
  257. import torch.nn.functional as F
  258. print(
  259. f"warning: {fqn} is padded to satisfy in_features % 1024 == 0"
  260. )
  261. padded_in_features = find_multiple(in_features, 1024)
  262. weight = F.pad(
  263. weight, pad=(0, padded_in_features - in_features)
  264. )
  265. else:
  266. print(
  267. f"warning: {fqn} is skipped, int4 requires that in_features is 32, 64, or is divisible by 1024, "
  268. + "and that groupsize and inner_k_tiles*16 evenly divide into it"
  269. )
  270. continue
  271. (
  272. weight_int4pack,
  273. scales_and_zeros,
  274. ) = prepare_int4_weight_and_scales_and_zeros(
  275. weight.to(torch.bfloat16).to("cuda"),
  276. self.groupsize,
  277. self.inner_k_tiles,
  278. )
  279. cur_state_dict[f"{fqn}.weight"] = weight_int4pack.to("cpu")
  280. cur_state_dict[f"{fqn}.scales_and_zeros"] = scales_and_zeros.to("cpu")
  281. return cur_state_dict
  282. def convert_for_runtime(self):
  283. replace_linear_int4(self.mod, self.groupsize, self.inner_k_tiles, self.padding)
  284. return self.mod
  285. class WeightOnlyInt4Linear(torch.nn.Module):
  286. __constants__ = ["in_features", "out_features"]
  287. in_features: int
  288. out_features: int
  289. weight: torch.Tensor
  290. def __init__(
  291. self,
  292. in_features: int,
  293. out_features: int,
  294. bias=True,
  295. device=None,
  296. dtype=None,
  297. groupsize: int = 128,
  298. inner_k_tiles: int = 8,
  299. padding: bool = True,
  300. ) -> None:
  301. super().__init__()
  302. self.padding = padding
  303. if padding:
  304. self.origin_in_features = in_features
  305. in_features = find_multiple(in_features, 1024)
  306. self.in_features = in_features
  307. self.out_features = out_features
  308. assert not bias, "require bias=False"
  309. self.groupsize = groupsize
  310. self.inner_k_tiles = inner_k_tiles
  311. assert out_features % 8 == 0, "require out_features % 8 == 0"
  312. assert (
  313. in_features % (inner_k_tiles * 16) == 0
  314. ), "require in_features % (innerKTiles * 16) == 0"
  315. self.register_buffer(
  316. "weight",
  317. torch.empty(
  318. (
  319. out_features // 8,
  320. in_features // (inner_k_tiles * 16),
  321. 32,
  322. inner_k_tiles // 2,
  323. ),
  324. dtype=torch.int32,
  325. ),
  326. )
  327. self.register_buffer(
  328. "scales_and_zeros",
  329. torch.empty(
  330. (in_features // groupsize, out_features, 2), dtype=torch.bfloat16
  331. ),
  332. )
  333. def forward(self, input: torch.Tensor) -> torch.Tensor:
  334. input = input.to(torch.bfloat16)
  335. if self.padding:
  336. import torch.nn.functional as F
  337. input = F.pad(input, pad=(0, self.in_features - self.origin_in_features))
  338. return linear_forward_int4(
  339. input, self.weight, self.scales_and_zeros, self.out_features, self.groupsize
  340. )
  341. @click.command()
  342. @click.option(
  343. "--checkpoint-path",
  344. type=click.Path(path_type=Path, exists=True),
  345. default="checkpoints/text2semantic-sft-medium-v1-4k.pth",
  346. )
  347. @click.option("--config-name", type=str, default="dual_ar_2_codebook_medium")
  348. @click.option(
  349. "--mode", type=str, default="int8", help="type of quantization to perform"
  350. )
  351. @click.option(
  352. "--groupsize", type=int, default=128, help="Group size for int4 quantization."
  353. )
  354. def quantize(
  355. checkpoint_path: Path, config_name: str, mode: str, groupsize: int
  356. ) -> None:
  357. assert checkpoint_path.is_file(), checkpoint_path
  358. device = "cpu"
  359. precision = torch.bfloat16
  360. print("Loading model ...")
  361. t0 = time.time()
  362. model, _ = load_model(
  363. config_name,
  364. checkpoint_path=checkpoint_path,
  365. device=device,
  366. precision=precision,
  367. compile=False,
  368. max_length=2048,
  369. )
  370. if mode == "int8":
  371. print(
  372. "Quantizing model weights for int8 weight-only symmetric per-channel quantization"
  373. )
  374. quant_handler = WeightOnlyInt8QuantHandler(model)
  375. quantized_state_dict = quant_handler.create_quantized_state_dict()
  376. dir_name = checkpoint_path.parent
  377. base_name = checkpoint_path.stem
  378. suffix = checkpoint_path.suffix
  379. quantize_path = dir_name / f"{base_name}.int8{suffix}"
  380. elif mode == "int4":
  381. print(
  382. "Quantizing model weights for int4 weight-only affine per-channel groupwise quantization"
  383. )
  384. quant_handler = WeightOnlyInt4QuantHandler(model, groupsize)
  385. quantized_state_dict = quant_handler.create_quantized_state_dict()
  386. dir_name = checkpoint_path.parent
  387. base_name = checkpoint_path.name
  388. suffix = checkpoint_path.suffix
  389. quantize_path = dir_name / f"{base_name}.int4.g{groupsize}{suffix}"
  390. else:
  391. raise ValueError(
  392. f"Invalid quantization mode {mode} needs to be one of [int8, int4, int4-gpptq]"
  393. )
  394. print(f"Writing quantized weights to {quantize_path}")
  395. quantize_path.unlink(missing_ok=True) # remove existing file if one already there
  396. torch.save(quantized_state_dict, quantize_path)
  397. print(f"Quantization complete took {time.time() - t0:.02f} seconds")
  398. if __name__ == "__main__":
  399. quantize()