tp.py 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170
  1. # Copyright (c) Meta Platforms, Inc. and affiliates.
  2. # All rights reserved.
  3. # This source code is licensed under the license found in the
  4. # LICENSE file in the root directory of this source tree.
  5. import os
  6. from typing import List, Optional
  7. import torch
  8. import torch.distributed as dist
  9. from quantize import WeightOnlyInt4Linear
  10. from torch import nn
  11. from torch.distributed import _functional_collectives as funcol
  12. from fish_speech.models.text2semantic.llama import Attention, FeedForward, Transformer
  13. def _get_rank() -> int:
  14. return int(os.environ.get("LOCAL_RANK", "0"))
  15. def is_local():
  16. return _get_rank() == 0
  17. def local_break():
  18. if is_local():
  19. breakpoint()
  20. dist.barrier()
  21. def _get_world_size() -> int:
  22. return int(os.environ.get("LOCAL_WORLD_SIZE", "1"))
  23. def maybe_init_dist() -> Optional[int]:
  24. try:
  25. # provided by torchrun
  26. rank = _get_rank()
  27. world_size = _get_world_size()
  28. if world_size < 2:
  29. # too few gpus to parallelize, tp is no-op
  30. return None
  31. except KeyError:
  32. # not run via torchrun, no-op
  33. return None
  34. dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
  35. return rank
  36. def _apply_tp_linear(
  37. linear: nn.Linear, style: str, weight_splits: List[int] = []
  38. ) -> None:
  39. rank = _get_rank()
  40. world_size = _get_world_size()
  41. # Linear's weight matrix is transposed, and is of shape
  42. # (linear.out_features, linear.in_features)
  43. dim_lookup = {"colwise": (0, "out_features"), "rowwise": (1, "in_features")}
  44. assert style in dim_lookup
  45. shard_dim, size_attr = dim_lookup[style]
  46. # ensure we can shard evenly
  47. assert getattr(linear, size_attr) % world_size == 0
  48. def shard(x, dim):
  49. assert x.size(dim=dim) % world_size == 0
  50. return torch.tensor_split(x, world_size, dim=dim)[rank]
  51. def shard_qkv(qkv, dim, weight_splits):
  52. q, k, v = qkv.split(weight_splits, dim=dim)
  53. q = shard(q, dim)
  54. k = shard(k, dim)
  55. v = shard(v, dim)
  56. return torch.cat((q, k, v), dim=dim)
  57. # shard
  58. if weight_splits:
  59. # attention
  60. assert len(weight_splits) == 3
  61. if isinstance(linear, WeightOnlyInt4Linear):
  62. sharded_weight = shard_qkv(
  63. linear.weight, shard_dim, [i // 8 for i in weight_splits]
  64. )
  65. linear.scales_and_zeros = shard_qkv(
  66. linear.scales_and_zeros, 1 - shard_dim, weight_splits
  67. )
  68. else:
  69. sharded_weight = shard_qkv(linear.weight, shard_dim, weight_splits)
  70. if hasattr(linear, "scales") and style == "colwise":
  71. linear.scales = shard_qkv(linear.scales, 0, weight_splits)
  72. else:
  73. sharded_weight = shard(linear.weight, shard_dim)
  74. if isinstance(linear, WeightOnlyInt4Linear):
  75. linear.scales_and_zeros = shard(linear.scales_and_zeros, 1 - shard_dim)
  76. if style == "rowwise":
  77. assert (
  78. linear.scales_and_zeros.shape[0] * 32
  79. == sharded_weight.shape[1]
  80. * sharded_weight.shape[2]
  81. * sharded_weight.shape[3]
  82. )
  83. assert linear.scales_and_zeros.shape[1] == sharded_weight.shape[0] * 8
  84. if hasattr(linear, "scales") and style == "colwise":
  85. linear.scales = shard(linear.scales, 0)
  86. # local_break()
  87. linear.weight = nn.Parameter(sharded_weight, requires_grad=False)
  88. setattr(linear, size_attr, getattr(linear, size_attr) // world_size)
  89. # shape info should still be synced
  90. # assert linear.weight.shape == (linear.out_features, linear.in_features)
  91. def _apply_tp_ffn(mlp: FeedForward) -> None:
  92. assert hasattr(mlp, "w1")
  93. assert hasattr(mlp, "w3")
  94. assert hasattr(mlp, "w2")
  95. _apply_tp_linear(mlp.w1, "colwise")
  96. _apply_tp_linear(mlp.w3, "colwise")
  97. _apply_tp_linear(mlp.w2, "rowwise")
  98. world_size = _get_world_size()
  99. mlp.register_forward_hook(
  100. lambda _module, _input, output: funcol.all_reduce(
  101. output, "sum", list(range(world_size))
  102. )
  103. )
  104. def _apply_tp_attn(attn: Attention) -> None:
  105. assert hasattr(attn, "wqkv")
  106. assert hasattr(attn, "wo")
  107. kv_size = attn.n_local_heads * attn.head_dim
  108. _apply_tp_linear(attn.wqkv, "colwise", [attn.dim, kv_size, kv_size])
  109. _apply_tp_linear(attn.wo, "rowwise")
  110. # overwrite
  111. world_size = _get_world_size()
  112. attn.n_head = attn.n_head // world_size
  113. attn.dim = attn.dim // world_size
  114. attn.head_dim = attn.dim // attn.n_head
  115. attn.n_local_heads = attn.n_local_heads // world_size
  116. attn.register_forward_hook(
  117. lambda _module, _input, output: funcol.all_reduce(
  118. output[0], "sum", list(range(world_size))
  119. )
  120. )
  121. def _apply_tp_Transformer(Transformer: Transformer) -> None:
  122. # overwrite config before Transformer.setup_cache is called
  123. world_size = _get_world_size()
  124. Transformer.config.n_head = Transformer.config.n_head // world_size
  125. Transformer.config.dim = Transformer.config.dim // world_size
  126. Transformer.config.n_local_heads = Transformer.config.n_local_heads // world_size
  127. def apply_tp(model: Transformer) -> None:
  128. _apply_tp_Transformer(model)
  129. for block in model.layers:
  130. # Apply to MLP
  131. _apply_tp_ffn(block.feed_forward)
  132. _apply_tp_attn(block.attention)