quantize.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497
  1. # Copyright (c) Meta Platforms, Inc. and affiliates.
  2. # All rights reserved.
  3. import datetime
  4. import shutil
  5. # This source code is licensed under the license found in the
  6. # LICENSE file in the root directory of this source tree.
  7. import time
  8. from pathlib import Path
  9. import click
  10. import torch
  11. import torch.nn as nn
  12. import torch.nn.functional as F
  13. from fish_speech.models.text2semantic.inference import load_model
  14. from fish_speech.models.text2semantic.llama import find_multiple
  15. ##### Quantization Primitives ######
  16. def dynamically_quantize_per_channel(x, quant_min, quant_max, target_dtype):
  17. # assumes symmetric quantization
  18. # assumes axis == 0
  19. # assumes dense memory format
  20. # TODO(future): relax ^ as needed
  21. # default setup for affine quantization of activations
  22. eps = torch.finfo(torch.float32).eps
  23. # get min and max
  24. min_val, max_val = torch.aminmax(x, dim=1)
  25. # calculate scales and zero_points based on min and max
  26. # reference: https://fburl.com/code/srbiybme
  27. min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
  28. max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
  29. device = min_val_neg.device
  30. # reference: https://fburl.com/code/4wll53rk
  31. max_val_pos = torch.max(-min_val_neg, max_val_pos)
  32. scales = max_val_pos / (float(quant_max - quant_min) / 2)
  33. # ensure scales is the same dtype as the original tensor
  34. scales = torch.clamp(scales, min=eps).to(x.dtype)
  35. zero_points = torch.zeros(min_val_neg.size(), dtype=torch.int64, device=device)
  36. # quantize based on qmin/qmax/scales/zp
  37. # reference: https://www.internalfb.com/code/fbsource/[8edc275012b1]/fbcode/caffe2/torch/ao/quantization/fx/_decomposed.py?lines=63
  38. x_div = x / scales.unsqueeze(-1)
  39. x_round = torch.round(x_div)
  40. x_zp = x_round + zero_points.unsqueeze(-1)
  41. quant = torch.clamp(x_zp, quant_min, quant_max).to(target_dtype)
  42. return quant, scales, zero_points
  43. def get_group_qparams(w, n_bit=4, groupsize=128):
  44. # needed for GPTQ with padding
  45. if groupsize > w.shape[-1]:
  46. groupsize = w.shape[-1]
  47. assert groupsize > 1
  48. assert w.shape[-1] % groupsize == 0
  49. assert w.dim() == 2
  50. to_quant = w.reshape(-1, groupsize)
  51. assert torch.isnan(to_quant).sum() == 0
  52. max_val = to_quant.amax(dim=1, keepdim=True)
  53. min_val = to_quant.amin(dim=1, keepdim=True)
  54. max_int = 2**n_bit - 1
  55. scales = (max_val - min_val).clamp(min=1e-6) / max_int
  56. zeros = min_val + scales * (2 ** (n_bit - 1))
  57. return scales.to(torch.bfloat16).reshape(w.shape[0], -1), zeros.to(
  58. torch.bfloat16
  59. ).reshape(w.shape[0], -1)
  60. def pack_scales_and_zeros(scales, zeros):
  61. assert scales.shape == zeros.shape
  62. assert scales.dtype == torch.bfloat16
  63. assert zeros.dtype == torch.bfloat16
  64. return (
  65. torch.cat(
  66. [
  67. scales.reshape(scales.size(0), scales.size(1), 1),
  68. zeros.reshape(zeros.size(0), zeros.size(1), 1),
  69. ],
  70. 2,
  71. )
  72. .transpose(0, 1)
  73. .contiguous()
  74. )
  75. def unpack_scales_and_zeros(scales_and_zeros):
  76. assert len(scales_and_zeros.shape) == 3 and scales_and_zeros.shape[2] == 2
  77. assert scales_and_zeros.dtype == torch.float
  78. return torch.split(scales_and_zeros.transpose(0, 1), 1, 2)
  79. def group_quantize_tensor_from_qparams(w, scales, zeros, n_bit=4, groupsize=128):
  80. assert groupsize > 1
  81. # needed for GPTQ single column quantize
  82. if groupsize > w.shape[-1] and scales.shape[-1] == 1:
  83. groupsize = w.shape[-1]
  84. assert w.shape[-1] % groupsize == 0
  85. assert w.dim() == 2
  86. to_quant = w.reshape(-1, groupsize)
  87. assert torch.isnan(to_quant).sum() == 0
  88. scales = scales.reshape(-1, 1)
  89. zeros = zeros.reshape(-1, 1)
  90. min_val = zeros - scales * (2 ** (n_bit - 1))
  91. max_int = 2**n_bit - 1
  92. min_int = 0
  93. w_int32 = (
  94. to_quant.sub(min_val)
  95. .div(scales)
  96. .round()
  97. .clamp_(min_int, max_int)
  98. .to(torch.int32)
  99. .reshape_as(w)
  100. )
  101. return w_int32
  102. def group_quantize_tensor(w, n_bit=4, groupsize=128):
  103. scales, zeros = get_group_qparams(w, n_bit, groupsize)
  104. w_int32 = group_quantize_tensor_from_qparams(w, scales, zeros, n_bit, groupsize)
  105. scales_and_zeros = pack_scales_and_zeros(scales, zeros)
  106. return w_int32, scales_and_zeros
  107. def group_dequantize_tensor_from_qparams(
  108. w_int32, scales, zeros, n_bit=4, groupsize=128
  109. ):
  110. assert groupsize > 1
  111. # needed for GPTQ single column dequantize
  112. if groupsize > w_int32.shape[-1] and scales.shape[-1] == 1:
  113. groupsize = w_int32.shape[-1]
  114. assert w_int32.shape[-1] % groupsize == 0
  115. assert w_int32.dim() == 2
  116. w_int32_grouped = w_int32.reshape(-1, groupsize)
  117. scales = scales.reshape(-1, 1)
  118. zeros = zeros.reshape(-1, 1)
  119. w_dq = (
  120. w_int32_grouped.sub(2 ** (n_bit - 1)).mul(scales).add(zeros).reshape_as(w_int32)
  121. )
  122. return w_dq
  123. def group_dequantize_tensor(w_int32, scales_and_zeros, n_bit=4, groupsize=128):
  124. scales, zeros = unpack_scales_and_zeros(scales_and_zeros)
  125. return group_dequantize_tensor_from_qparams(
  126. w_int32, scales, zeros, n_bit, groupsize
  127. )
  128. class QuantHandler:
  129. def __init__(self, mod):
  130. self.mod = mod
  131. def create_quantized_state_dict(self) -> "StateDict":
  132. pass
  133. def convert_for_runtime(self) -> "nn.Module":
  134. pass
  135. ##### Weight-only int8 per-channel quantized code ######
  136. def replace_linear_weight_only_int8_per_channel(module):
  137. for name, child in module.named_children():
  138. if isinstance(child, nn.Linear):
  139. setattr(
  140. module,
  141. name,
  142. WeightOnlyInt8Linear(child.in_features, child.out_features),
  143. )
  144. else:
  145. replace_linear_weight_only_int8_per_channel(child)
  146. class WeightOnlyInt8QuantHandler:
  147. def __init__(self, mod):
  148. self.mod = mod
  149. @torch.no_grad()
  150. def create_quantized_state_dict(self):
  151. cur_state_dict = self.mod.state_dict()
  152. for fqn, mod in self.mod.named_modules():
  153. if isinstance(mod, torch.nn.Linear):
  154. int8_weight, scales, _ = dynamically_quantize_per_channel(
  155. mod.weight.float(), -128, 127, torch.int8
  156. )
  157. cur_state_dict[f"{fqn}.weight"] = int8_weight
  158. cur_state_dict[f"{fqn}.scales"] = scales.to(mod.weight.dtype)
  159. return cur_state_dict
  160. def convert_for_runtime(self):
  161. replace_linear_weight_only_int8_per_channel(self.mod)
  162. return self.mod
  163. class WeightOnlyInt8Linear(torch.nn.Module):
  164. __constants__ = ["in_features", "out_features"]
  165. in_features: int
  166. out_features: int
  167. weight: torch.Tensor
  168. def __init__(
  169. self,
  170. in_features: int,
  171. out_features: int,
  172. bias: bool = True,
  173. device=None,
  174. dtype=None,
  175. ) -> None:
  176. factory_kwargs = {"device": device, "dtype": dtype}
  177. super().__init__()
  178. self.in_features = in_features
  179. self.out_features = out_features
  180. self.register_buffer(
  181. "weight", torch.empty((out_features, in_features), dtype=torch.int8)
  182. )
  183. self.register_buffer("scales", torch.ones(out_features, dtype=torch.bfloat16))
  184. def forward(self, input: torch.Tensor) -> torch.Tensor:
  185. return F.linear(input, self.weight.to(dtype=input.dtype)) * self.scales
  186. ##### weight only int4 per channel groupwise quantized code ######
  187. def prepare_int4_weight_and_scales_and_zeros(weight_bf16, groupsize, inner_k_tiles):
  188. weight_int32, scales_and_zeros = group_quantize_tensor(
  189. weight_bf16, n_bit=4, groupsize=groupsize
  190. )
  191. weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(
  192. weight_int32, inner_k_tiles
  193. )
  194. return weight_int4pack, scales_and_zeros
  195. def linear_forward_int4(x, weight_int4pack, scales_and_zeros, out_features, groupsize):
  196. origin_x_size = x.size()
  197. x = x.reshape(-1, origin_x_size[-1])
  198. c = torch.ops.aten._weight_int4pack_mm(
  199. x, weight_int4pack, groupsize, scales_and_zeros
  200. )
  201. new_shape = origin_x_size[:-1] + (out_features,)
  202. c = c.reshape(new_shape)
  203. return c
  204. def _check_linear_int4_k(k, groupsize=1, inner_k_tiles=1):
  205. return k % groupsize == 0 and k % (inner_k_tiles * 16) == 0
  206. def replace_linear_int4(module, groupsize, inner_k_tiles, padding):
  207. for name, child in module.named_children():
  208. if isinstance(child, nn.Linear):
  209. if _check_linear_int4_k(child.in_features, groupsize, inner_k_tiles):
  210. setattr(
  211. module,
  212. name,
  213. WeightOnlyInt4Linear(
  214. child.in_features,
  215. child.out_features,
  216. bias=False,
  217. groupsize=groupsize,
  218. inner_k_tiles=inner_k_tiles,
  219. padding=False,
  220. ),
  221. )
  222. elif padding:
  223. setattr(
  224. module,
  225. name,
  226. WeightOnlyInt4Linear(
  227. child.in_features,
  228. child.out_features,
  229. bias=False,
  230. groupsize=groupsize,
  231. inner_k_tiles=inner_k_tiles,
  232. padding=True,
  233. ),
  234. )
  235. else:
  236. replace_linear_int4(child, groupsize, inner_k_tiles, padding)
  237. class WeightOnlyInt4QuantHandler:
  238. def __init__(self, mod, groupsize=128, inner_k_tiles=8, padding=True):
  239. self.mod = mod
  240. self.groupsize = groupsize
  241. self.inner_k_tiles = inner_k_tiles
  242. self.padding = padding
  243. assert groupsize in [32, 64, 128, 256]
  244. assert inner_k_tiles in [2, 4, 8]
  245. @torch.no_grad()
  246. def create_quantized_state_dict(self):
  247. cur_state_dict = self.mod.state_dict()
  248. for fqn, mod in self.mod.named_modules():
  249. if isinstance(mod, torch.nn.Linear):
  250. assert not mod.bias
  251. out_features = mod.out_features
  252. in_features = mod.in_features
  253. assert out_features % 8 == 0, "require out_features % 8 == 0"
  254. print(f"linear: {fqn}, in={in_features}, out={out_features}")
  255. weight = mod.weight.data
  256. if not _check_linear_int4_k(
  257. in_features, self.groupsize, self.inner_k_tiles
  258. ):
  259. if self.padding:
  260. import torch.nn.functional as F
  261. print(
  262. f"warning: {fqn} is padded to satisfy in_features % 1024 == 0"
  263. )
  264. padded_in_features = find_multiple(in_features, 1024)
  265. weight = F.pad(
  266. weight, pad=(0, padded_in_features - in_features)
  267. )
  268. else:
  269. print(
  270. f"warning: {fqn} is skipped, int4 requires that in_features is 32, 64, or is divisible by 1024, "
  271. + "and that groupsize and inner_k_tiles*16 evenly divide into it"
  272. )
  273. continue
  274. (
  275. weight_int4pack,
  276. scales_and_zeros,
  277. ) = prepare_int4_weight_and_scales_and_zeros(
  278. weight.to(torch.bfloat16).to("cuda"),
  279. self.groupsize,
  280. self.inner_k_tiles,
  281. )
  282. cur_state_dict[f"{fqn}.weight"] = weight_int4pack.to("cpu")
  283. cur_state_dict[f"{fqn}.scales_and_zeros"] = scales_and_zeros.to("cpu")
  284. return cur_state_dict
  285. def convert_for_runtime(self):
  286. replace_linear_int4(self.mod, self.groupsize, self.inner_k_tiles, self.padding)
  287. return self.mod
  288. class WeightOnlyInt4Linear(torch.nn.Module):
  289. __constants__ = ["in_features", "out_features"]
  290. in_features: int
  291. out_features: int
  292. weight: torch.Tensor
  293. def __init__(
  294. self,
  295. in_features: int,
  296. out_features: int,
  297. bias=True,
  298. device=None,
  299. dtype=None,
  300. groupsize: int = 128,
  301. inner_k_tiles: int = 8,
  302. padding: bool = True,
  303. ) -> None:
  304. super().__init__()
  305. self.padding = padding
  306. if padding:
  307. self.origin_in_features = in_features
  308. in_features = find_multiple(in_features, 1024)
  309. self.in_features = in_features
  310. self.out_features = out_features
  311. assert not bias, "require bias=False"
  312. self.groupsize = groupsize
  313. self.inner_k_tiles = inner_k_tiles
  314. assert out_features % 8 == 0, "require out_features % 8 == 0"
  315. assert (
  316. in_features % (inner_k_tiles * 16) == 0
  317. ), "require in_features % (innerKTiles * 16) == 0"
  318. self.register_buffer(
  319. "weight",
  320. torch.empty(
  321. (
  322. out_features // 8,
  323. in_features // (inner_k_tiles * 16),
  324. 32,
  325. inner_k_tiles // 2,
  326. ),
  327. dtype=torch.int32,
  328. ),
  329. )
  330. self.register_buffer(
  331. "scales_and_zeros",
  332. torch.empty(
  333. (in_features // groupsize, out_features, 2), dtype=torch.bfloat16
  334. ),
  335. )
  336. def forward(self, input: torch.Tensor) -> torch.Tensor:
  337. input = input.to(torch.bfloat16)
  338. if self.padding:
  339. import torch.nn.functional as F
  340. input = F.pad(input, pad=(0, self.in_features - self.origin_in_features))
  341. return linear_forward_int4(
  342. input, self.weight, self.scales_and_zeros, self.out_features, self.groupsize
  343. )
  344. def generate_folder_name():
  345. now = datetime.datetime.now()
  346. folder_name = now.strftime("%Y%m%d_%H%M%S")
  347. return folder_name
  348. @click.command()
  349. @click.option(
  350. "--checkpoint-path",
  351. type=click.Path(path_type=Path, exists=True),
  352. default="checkpoints/fish-speech-1.4",
  353. )
  354. @click.option(
  355. "--mode", type=str, default="int8", help="type of quantization to perform"
  356. )
  357. @click.option(
  358. "--groupsize", type=int, default=128, help="Group size for int4 quantization."
  359. )
  360. @click.option("--timestamp", type=str, default="None", help="When to do quantization")
  361. def quantize(checkpoint_path: Path, mode: str, groupsize: int, timestamp: str) -> None:
  362. device = "cpu"
  363. precision = torch.bfloat16
  364. print("Loading model ...")
  365. t0 = time.time()
  366. model, _ = load_model(
  367. checkpoint_path=checkpoint_path,
  368. device=device,
  369. precision=precision,
  370. compile=False,
  371. )
  372. vq_model = "codec.pth"
  373. now = timestamp if timestamp != "None" else generate_folder_name()
  374. if mode == "int8":
  375. print(
  376. "Quantizing model weights for int8 weight-only symmetric per-channel quantization"
  377. )
  378. quant_handler = WeightOnlyInt8QuantHandler(model)
  379. quantized_state_dict = quant_handler.create_quantized_state_dict()
  380. dir_name = checkpoint_path
  381. dst_name = Path(f"checkpoints/fs-1.2-int8-{now}")
  382. shutil.copytree(str(dir_name.resolve()), str(dst_name.resolve()))
  383. if (dst_name / vq_model).exists():
  384. (dst_name / vq_model).unlink()
  385. quantize_path = dst_name / "model.pth"
  386. elif mode == "int4":
  387. print(
  388. "Quantizing model weights for int4 weight-only affine per-channel groupwise quantization"
  389. )
  390. quant_handler = WeightOnlyInt4QuantHandler(model, groupsize)
  391. quantized_state_dict = quant_handler.create_quantized_state_dict()
  392. dir_name = checkpoint_path
  393. dst_name = Path(f"checkpoints/fs-1.2-int4-g{groupsize}-{now}")
  394. shutil.copytree(str(dir_name.resolve()), str(dst_name.resolve()))
  395. if (dst_name / vq_model).exists():
  396. (dst_name / vq_model).unlink()
  397. quantize_path = dst_name / "model.pth"
  398. else:
  399. raise ValueError(
  400. f"Invalid quantization mode {mode} needs to be one of [int8, int4, int4-gpptq]"
  401. )
  402. print(f"Writing quantized weights to {quantize_path}")
  403. quantize_path.unlink(missing_ok=True) # remove existing file if one already there
  404. torch.save(quantized_state_dict, quantize_path)
  405. print(f"Quantization complete took {time.time() - t0:.02f} seconds")
  406. if __name__ == "__main__":
  407. quantize()