bisenet.py

import torch
import torch.nn as nn
import torch.nn.functional as F

from .resnet import ResNet18


class ConvBNReLU(nn.Module):
    # Basic building block: convolution -> batch norm -> ReLU.
    def __init__(self, in_chan, out_chan, ks=3, stride=1, padding=1):
        super(ConvBNReLU, self).__init__()
        self.conv = nn.Conv2d(
            in_chan,
            out_chan,
            kernel_size=ks,
            stride=stride,
            padding=padding,
            bias=False,
        )
        self.bn = nn.BatchNorm2d(out_chan)

    def forward(self, x):
        x = self.conv(x)
        x = F.relu(self.bn(x))
        return x


class BiSeNetOutput(nn.Module):
    # Segmentation head: 3x3 ConvBNReLU followed by a 1x1 classifier conv.
    def __init__(self, in_chan, mid_chan, num_class):
        super(BiSeNetOutput, self).__init__()
        self.conv = ConvBNReLU(in_chan, mid_chan, ks=3, stride=1, padding=1)
        self.conv_out = nn.Conv2d(mid_chan, num_class, kernel_size=1, bias=False)

    def forward(self, x):
        feat = self.conv(x)
        out = self.conv_out(feat)
        return out, feat


class AttentionRefinementModule(nn.Module):
    # ARM: reweights channels of the feature map with a global-average-pooled
    # attention vector passed through conv -> batch norm -> sigmoid.
    def __init__(self, in_chan, out_chan):
        super(AttentionRefinementModule, self).__init__()
        self.conv = ConvBNReLU(in_chan, out_chan, ks=3, stride=1, padding=1)
        self.conv_atten = nn.Conv2d(out_chan, out_chan, kernel_size=1, bias=False)
        self.bn_atten = nn.BatchNorm2d(out_chan)
        self.sigmoid_atten = nn.Sigmoid()

    def forward(self, x):
        feat = self.conv(x)
        atten = F.avg_pool2d(feat, feat.size()[2:])  # global average pooling
        atten = self.conv_atten(atten)
        atten = self.bn_atten(atten)
        atten = self.sigmoid_atten(atten)
        out = torch.mul(feat, atten)
        return out


class ContextPath(nn.Module):
    # Context path: ResNet18 backbone with ARMs on the 1/16 and 1/32 features,
    # plus a global-average-pooled context branch added to the 1/32 features.
    def __init__(self):
        super(ContextPath, self).__init__()
        self.resnet = ResNet18()
        self.arm16 = AttentionRefinementModule(256, 128)
        self.arm32 = AttentionRefinementModule(512, 128)
        self.conv_head32 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
        self.conv_head16 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
        self.conv_avg = ConvBNReLU(512, 128, ks=1, stride=1, padding=0)

    def forward(self, x):
        feat8, feat16, feat32 = self.resnet(x)
        h8, w8 = feat8.size()[2:]
        h16, w16 = feat16.size()[2:]
        h32, w32 = feat32.size()[2:]

        avg = F.avg_pool2d(feat32, feat32.size()[2:])  # global context
        avg = self.conv_avg(avg)
        avg_up = F.interpolate(avg, (h32, w32), mode="nearest")

        feat32_arm = self.arm32(feat32)
        feat32_sum = feat32_arm + avg_up
        feat32_up = F.interpolate(feat32_sum, (h16, w16), mode="nearest")
        feat32_up = self.conv_head32(feat32_up)

        feat16_arm = self.arm16(feat16)
        feat16_sum = feat16_arm + feat32_up
        feat16_up = F.interpolate(feat16_sum, (h8, w8), mode="nearest")
        feat16_up = self.conv_head16(feat16_up)

        return feat8, feat16_up, feat32_up  # x8, x8, x16


class FeatureFusionModule(nn.Module):
    # FFM: concatenates spatial-path and context-path features, then applies a
    # squeeze-and-excitation style channel attention and a residual add.
    def __init__(self, in_chan, out_chan):
        super(FeatureFusionModule, self).__init__()
        self.convblk = ConvBNReLU(in_chan, out_chan, ks=1, stride=1, padding=0)
        self.conv1 = nn.Conv2d(
            out_chan, out_chan // 4, kernel_size=1, stride=1, padding=0, bias=False
        )
        self.conv2 = nn.Conv2d(
            out_chan // 4, out_chan, kernel_size=1, stride=1, padding=0, bias=False
        )
        self.relu = nn.ReLU(inplace=True)
        self.sigmoid = nn.Sigmoid()

    def forward(self, fsp, fcp):
        fcat = torch.cat([fsp, fcp], dim=1)
        feat = self.convblk(fcat)
        atten = F.avg_pool2d(feat, feat.size()[2:])
        atten = self.conv1(atten)
        atten = self.relu(atten)
        atten = self.conv2(atten)
        atten = self.sigmoid(atten)
        feat_atten = torch.mul(feat, atten)
        feat_out = feat_atten + feat
        return feat_out


class BiSeNet(nn.Module):
    # Full model: the spatial path is replaced by the backbone's 1/8 feature map.
    def __init__(self, num_class):
        super(BiSeNet, self).__init__()
        self.cp = ContextPath()
        self.ffm = FeatureFusionModule(256, 256)
        self.conv_out = BiSeNetOutput(256, 256, num_class)
        self.conv_out16 = BiSeNetOutput(128, 64, num_class)
        self.conv_out32 = BiSeNetOutput(128, 64, num_class)

    def forward(self, x, return_feat=False):
        h, w = x.size()[2:]
        feat_res8, feat_cp8, feat_cp16 = self.cp(x)  # return res3b1 feature
        feat_sp = feat_res8  # replace spatial path feature with res3b1 feature
        feat_fuse = self.ffm(feat_sp, feat_cp8)

        out, feat = self.conv_out(feat_fuse)
        out16, feat16 = self.conv_out16(feat_cp8)
        out32, feat32 = self.conv_out32(feat_cp16)

        out = F.interpolate(out, (h, w), mode="bilinear", align_corners=True)
        out16 = F.interpolate(out16, (h, w), mode="bilinear", align_corners=True)
        out32 = F.interpolate(out32, (h, w), mode="bilinear", align_corners=True)

        if return_feat:
            feat = F.interpolate(feat, (h, w), mode="bilinear", align_corners=True)
            feat16 = F.interpolate(feat16, (h, w), mode="bilinear", align_corners=True)
            feat32 = F.interpolate(feat32, (h, w), mode="bilinear", align_corners=True)
            return out, out16, out32, feat, feat16, feat32
        else:
            return out, out16, out32
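

# Minimal smoke-test sketch (not part of the original file). It shows how the
# model is typically driven; num_class=19 and the 512x512 input size are
# assumptions, not requirements of the code above. Because of the relative
# import of ResNet18, this only runs when the module is imported as part of
# its package rather than executed directly as a script.
if __name__ == "__main__":
    net = BiSeNet(num_class=19)  # assumed number of classes
    net.eval()
    with torch.no_grad():
        dummy = torch.randn(1, 3, 512, 512)  # one RGB image, divisible by 32
        out, out16, out32 = net(dummy)
    # All three outputs are upsampled to the input resolution: (1, 19, 512, 512).
    print(out.shape, out16.shape, out32.shape)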