arch_util.py

from typing import List, Type, Union

import torch
from torch import nn as nn
from torch.nn import init as init
from torch.nn.modules.batchnorm import _BatchNorm


@torch.no_grad()
def default_init_weights(
    module_list: Union[List[nn.Module], nn.Module],
    scale: float = 1,
    bias_fill: float = 0,
    **kwargs,
) -> None:
    """Initialize network weights.

    Args:
        module_list (list[nn.Module] | nn.Module): Modules to be initialized.
        scale (float): Scale initialized weights, especially for residual
            blocks. Default: 1.
        bias_fill (float): The value to fill bias. Default: 0.
        kwargs (dict): Other arguments for the initialization function.
    """
    if not isinstance(module_list, list):
        module_list = [module_list]
    for module in module_list:
        for m in module.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, **kwargs)
                m.weight.data *= scale
                if m.bias is not None:
                    m.bias.data.fill_(bias_fill)
            elif isinstance(m, nn.Linear):
                init.kaiming_normal_(m.weight, **kwargs)
                m.weight.data *= scale
                if m.bias is not None:
                    m.bias.data.fill_(bias_fill)
            elif isinstance(m, _BatchNorm):
                init.constant_(m.weight, 1)
                if m.bias is not None:
                    m.bias.data.fill_(bias_fill)
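# Usage sketch (illustrative, not part of the original file): damp the
# convolutions of a hypothetical residual trunk so the block starts close
# to identity; a scale below 1 is the common use case hinted at in the
# docstring. Extra kwargs are forwarded to `init.kaiming_normal_`.
#
#     trunk = nn.Sequential(
#         nn.Conv2d(64, 64, 3, 1, 1),
#         nn.ReLU(inplace=True),
#         nn.Conv2d(64, 64, 3, 1, 1),
#     )
#     default_init_weights(trunk, scale=0.1, a=0, mode='fan_in')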


def make_layer(
    basic_block: Type[nn.Module], num_basic_block: int, **kwarg
) -> nn.Sequential:
    """Make layers by stacking the same blocks.

    Args:
        basic_block (Type[nn.Module]): nn.Module class for the basic block.
        num_basic_block (int): Number of blocks.

    Returns:
        nn.Sequential: Stacked blocks in nn.Sequential.
    """
    layers = []
    for _ in range(num_basic_block):
        layers.append(basic_block(**kwarg))
    return nn.Sequential(*layers)
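# Usage sketch (illustrative): stack 23 copies of one block class.
# `ResidualBlockNoBN` is assumed to be defined elsewhere and to take a
# `num_feat` keyword; any nn.Module class works as long as **kwarg matches
# its constructor.
#
#     body = make_layer(ResidualBlockNoBN, 23, num_feat=64)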


# TODO: may write a cpp file
def pixel_unshuffle(x: torch.Tensor, scale: int) -> torch.Tensor:
    """Pixel unshuffle.

    Args:
        x (Tensor): Input feature with shape (b, c, hh, hw).
        scale (int): Downsample ratio.

    Returns:
        Tensor: The pixel-unshuffled feature with shape
            (b, c * scale**2, hh // scale, hw // scale).
    """
    b, c, hh, hw = x.size()
    out_channel = c * (scale**2)
    assert hh % scale == 0 and hw % scale == 0, (
        'height and width must be divisible by scale')
    h = hh // scale
    w = hw // scale
    # Split each spatial axis into (output size, scale), then fold the two
    # scale axes into the channel dimension.
    x_view = x.view(b, c, h, scale, w, scale)
    return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w)
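
# Shape check (illustrative): a (1, 3, 8, 8) input with scale=2 becomes
# (1, 12, 4, 4); each 2x2 spatial patch is folded into the channel axis.
# On PyTorch >= 1.8 the result should match the built-in
# torch.nn.functional.pixel_unshuffle(x, 2).
#
#     x = torch.arange(3 * 8 * 8, dtype=torch.float32).view(1, 3, 8, 8)
#     y = pixel_unshuffle(x, 2)
#     assert y.shape == (1, 12, 4, 4)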