# balancer.py
  1. import typing as tp
  2. from collections import defaultdict
  3. import torch
  4. from torch import autograd
  5. def rank():
  6. if torch.distributed.is_initialized():
  7. return torch.distributed.get_rank()
  8. else:
  9. return 0
  10. def world_size():
  11. if torch.distributed.is_initialized():
  12. return torch.distributed.get_world_size()
  13. else:
  14. return 1
  15. def is_distributed():
  16. return world_size() > 1
  17. def average_metrics(metrics: tp.Dict[str, float], count=1.0):
  18. """Average a dictionary of metrics across all workers, using the optional
  19. `count` as unnormalized weight.
  20. """
  21. if not is_distributed():
  22. return metrics
  23. keys, values = zip(*metrics.items())
  24. device = "cuda" if torch.cuda.is_available() else "cpu"
  25. tensor = torch.tensor(list(values) + [1], device=device, dtype=torch.float32)
  26. tensor *= count
  27. all_reduce(tensor)
  28. averaged = (tensor[:-1] / tensor[-1]).cpu().tolist()
  29. return dict(zip(keys, averaged))
  30. def all_reduce(tensor: torch.Tensor, op=torch.distributed.ReduceOp.SUM):
  31. if is_distributed():
  32. return torch.distributed.all_reduce(tensor, op)
  33. def averager(beta: float = 1):
  34. """
  35. Exponential Moving Average callback.
  36. Returns a single function that can be called to repeatidly update the EMA
  37. with a dict of metrics. The callback will return
  38. the new averaged dict of metrics.
  39. Note that for `beta=1`, this is just plain averaging.
  40. """
  41. fix: tp.Dict[str, float] = defaultdict(float)
  42. total: tp.Dict[str, float] = defaultdict(float)
  43. def _update(
  44. metrics: tp.Dict[str, tp.Any], weight: float = 1
  45. ) -> tp.Dict[str, float]:
  46. nonlocal total, fix
  47. for key, value in metrics.items():
  48. total[key] = total[key] * beta + weight * float(value)
  49. fix[key] = fix[key] * beta + weight
  50. return {key: tot / fix[key] for key, tot in total.items()}
  51. return _update
  52. class Balancer:
  53. """Loss balancer.
  54. The loss balancer combines losses together to compute gradients for the backward.
  55. A call to the balancer will weight the losses according the specified weight coefficients.
  56. A call to the backward method of the balancer will compute the gradients, combining all the losses and
  57. potentially rescaling the gradients, which can help stabilize the training and reasonate
  58. about multiple losses with varying scales.
  59. Expected usage:
  60. weights = {'loss_a': 1, 'loss_b': 4}
  61. balancer = Balancer(weights, ...)
  62. losses: dict = {}
  63. losses['loss_a'] = compute_loss_a(x, y)
  64. losses['loss_b'] = compute_loss_b(x, y)
  65. if model.training():
  66. balancer.backward(losses, x)
  67. ..Warning:: It is unclear how this will interact with DistributedDataParallel,
  68. in particular if you have some losses not handled by the balancer. In that case
  69. you can use `encodec.distrib.sync_grad(model.parameters())` and
  70. `encodec.distrib.sync_buffwers(model.buffers())` as a safe alternative.
  71. Args:
  72. weights (Dict[str, float]): Weight coefficient for each loss. The balancer expect the losses keys
  73. from the backward method to match the weights keys to assign weight to each of the provided loss.
  74. rescale_grads (bool): Whether to rescale gradients or not, without. If False, this is just
  75. a regular weighted sum of losses.
  76. total_norm (float): Reference norm when rescaling gradients, ignored otherwise.
  77. emay_decay (float): EMA decay for averaging the norms when `rescale_grads` is True.
  78. per_batch_item (bool): Whether to compute the averaged norm per batch item or not. This only holds
  79. when rescaling the gradients.
  80. epsilon (float): Epsilon value for numerical stability.
  81. monitor (bool): Whether to store additional ratio for each loss key in metrics.
  82. """ # noqa: E501
  83. def __init__(
  84. self,
  85. weights: tp.Dict[str, float],
  86. rescale_grads: bool = True,
  87. total_norm: float = 1.0,
  88. ema_decay: float = 0.999,
  89. per_batch_item: bool = True,
  90. epsilon: float = 1e-12,
  91. monitor: bool = False,
  92. ):
  93. self.weights = weights
  94. self.per_batch_item = per_batch_item
  95. self.total_norm = total_norm
  96. self.averager = averager(ema_decay)
  97. self.epsilon = epsilon
  98. self.monitor = monitor
  99. self.rescale_grads = rescale_grads
  100. self._metrics: tp.Dict[str, tp.Any] = {}
  101. @property
  102. def metrics(self):
  103. return self._metrics
  104. def compute(self, losses: tp.Dict[str, torch.Tensor], input: torch.Tensor):
  105. norms = {}
  106. grads = {}
  107. for name, loss in losses.items():
  108. (grad,) = autograd.grad(loss, [input], retain_graph=True)
  109. if self.per_batch_item:
  110. dims = tuple(range(1, grad.dim()))
  111. norm = grad.norm(dim=dims).mean()
  112. else:
  113. norm = grad.norm()
  114. norms[name] = norm
  115. grads[name] = grad
  116. count = 1
  117. if self.per_batch_item:
  118. count = len(grad)
  119. avg_norms = average_metrics(self.averager(norms), count)
  120. total = sum(avg_norms.values())
  121. self._metrics = {}
  122. if self.monitor:
  123. for k, v in avg_norms.items():
  124. self._metrics[f"ratio_{k}"] = v / total
  125. total_weights = sum([self.weights[k] for k in avg_norms])
  126. ratios = {k: w / total_weights for k, w in self.weights.items()}
  127. out_grad: tp.Any = 0
  128. for name, avg_norm in avg_norms.items():
  129. if self.rescale_grads:
  130. scale = ratios[name] * self.total_norm / (self.epsilon + avg_norm)
  131. grad = grads[name] * scale
  132. else:
  133. grad = self.weights[name] * grads[name]
  134. out_grad += grad
  135. return out_grad
  136. def test():
  137. from torch.nn import functional as F
  138. x = torch.zeros(1, requires_grad=True)
  139. one = torch.ones_like(x)
  140. loss_1 = F.l1_loss(x, one)
  141. loss_2 = 100 * F.l1_loss(x, -one)
  142. losses = {"1": loss_1, "2": loss_2}
  143. balancer = Balancer(weights={"1": 1, "2": 1}, rescale_grads=False)
  144. out_grad = balancer.compute(losses, x)
  145. x.backward(out_grad)
  146. assert torch.allclose(x.grad, torch.tensor(99.0)), x.grad
  147. loss_1 = F.l1_loss(x, one)
  148. loss_2 = 100 * F.l1_loss(x, -one)
  149. losses = {"1": loss_1, "2": loss_2}
  150. x.grad = None
  151. balancer = Balancer(weights={"1": 1, "2": 1}, rescale_grads=True)
  152. out_grad = balancer.compute({"1": loss_1, "2": loss_2}, x)
  153. x.backward(out_grad)
  154. assert torch.allclose(x.grad, torch.tensor(0.0)), x.grad
  155. if __name__ == "__main__":
  156. test()