| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980 |
- from typing import List, Type, Union
- import torch
- from torch import nn as nn
- from torch.nn import init as init
- from torch.nn.modules.batchnorm import _BatchNorm
@torch.no_grad()
def default_init_weights(
    module_list: Union[List[nn.Module], nn.Module],
    scale: float = 1,
    bias_fill: float = 0,
    **kwargs,
) -> None:
    """Initialize network weights in place.

    Every sub-module reachable through ``modules()`` is visited.
    ``nn.Conv2d`` and ``nn.Linear`` weights get Kaiming-normal
    initialization and are then multiplied by ``scale``; batch-norm weights
    are set to 1. Biases, where present, are filled with ``bias_fill``.
    Other module types are left untouched.

    Args:
        module_list (list[nn.Module] | nn.Module): Modules to be initialized.
            A single module, or any iterable of modules (list, tuple, ...).
        scale (float): Scale initialized weights, especially for residual
            blocks. Default: 1.
        bias_fill (float): The value to fill bias. Default: 0
        kwargs (dict): Other arguments forwarded to
            ``torch.nn.init.kaiming_normal_``.
    """
    # A bare module is a single root, even when it is an iterable container
    # such as nn.Sequential. Anything else is treated as a collection of
    # roots, so tuples and generators work as well as lists.
    if isinstance(module_list, nn.Module):
        module_list = [module_list]

    for module in module_list:
        for m in module.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                init.kaiming_normal_(m.weight, **kwargs)
                m.weight.data *= scale
                if m.bias is not None:
                    m.bias.data.fill_(bias_fill)
            elif isinstance(m, _BatchNorm):
                # Norm layers built with affine=False have weight and bias
                # set to None, so both must be guarded.
                if m.weight is not None:
                    init.constant_(m.weight, 1)
                if m.bias is not None:
                    m.bias.data.fill_(bias_fill)
def make_layer(
    basic_block: Type[nn.Module], num_basic_block: int, **kwarg
) -> nn.Sequential:
    """Build a sequential stack of identically configured blocks.

    Args:
        basic_block (Type[nn.Module]): Block class to instantiate.
        num_basic_block (int): How many instances to stack.
        **kwarg: Keyword arguments passed to every ``basic_block(...)`` call.

    Returns:
        nn.Sequential: The freshly constructed blocks, in creation order.
    """
    # Each iteration constructs a new instance, so no parameters are shared.
    blocks = (basic_block(**kwarg) for _ in range(num_basic_block))
    return nn.Sequential(*blocks)
- # TODO: consider implementing pixel_unshuffle in a C++ extension.
def pixel_unshuffle(x: torch.Tensor, scale: int) -> torch.Tensor:
    """Rearrange spatial blocks of a feature map into channels.

    Each ``scale x scale`` spatial patch is moved into the channel
    dimension, shrinking height and width by ``scale`` and multiplying the
    channel count by ``scale ** 2``.

    Args:
        x (Tensor): Input feature with shape (b, c, hh, hw).
        scale (int): Downsample ratio; must divide both hh and hw.

    Returns:
        Tensor: Feature with shape (b, c * scale**2, hh // scale, hw // scale).
    """
    batch, channels, in_h, in_w = x.shape
    assert in_h % scale == 0 and in_w % scale == 0

    out_h, out_w = in_h // scale, in_w // scale
    # Split each spatial axis into (coarse position, offset inside patch).
    patches = x.view(batch, channels, out_h, scale, out_w, scale)
    # Move the two intra-patch offsets next to the channel axis, then fold
    # them into it.
    patches = patches.permute(0, 1, 3, 5, 2, 4)
    return patches.reshape(batch, channels * scale**2, out_h, out_w)
|