# briarmbg2.py

# copy from https://huggingface.co/briaai/RMBG-2.0/tree/main
import math
import os
from collections import OrderedDict
from functools import partial

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
from torchvision.models import (
    ResNet50_Weights,
    VGG16_BN_Weights,
    VGG16_Weights,
    resnet50,
    vgg16,
    vgg16_bn,
)
from transformers import PretrainedConfig, PreTrainedModel


class Config:
    def __init__(self) -> None:
        # PATH settings
        self.sys_home_dir = os.path.expanduser(
            "~"
        )  # Lay out your file system as: SYS_HOME_DIR/codes/dis/BiRefNet, SYS_HOME_DIR/datasets/dis/xx, SYS_HOME_DIR/weights/xx
        # TASK settings
        self.task = ["DIS5K", "COD", "HRSOD", "DIS5K+HRSOD+HRS10K", "P3M-10k"][0]
        self.training_set = {
            "DIS5K": ["DIS-TR", "DIS-TR+DIS-TE1+DIS-TE2+DIS-TE3+DIS-TE4"][0],
            "COD": "TR-COD10K+TR-CAMO",
            "HRSOD": [
                "TR-DUTS",
                "TR-HRSOD",
                "TR-UHRSD",
                "TR-DUTS+TR-HRSOD",
                "TR-DUTS+TR-UHRSD",
                "TR-HRSOD+TR-UHRSD",
                "TR-DUTS+TR-HRSOD+TR-UHRSD",
            ][5],
            "DIS5K+HRSOD+HRS10K": "DIS-TE1+DIS-TE2+DIS-TE3+DIS-TE4+DIS-TR+TE-HRS10K+TE-HRSOD+TE-UHRSD+TR-HRS10K+TR-HRSOD+TR-UHRSD",  # leave DIS-VD for evaluation.
            "P3M-10k": "TR-P3M-10k",
        }[self.task]
        self.prompt4loc = ["dense", "sparse"][0]
        # Faster-Training settings
        self.load_all = True
        self.compile = True  # 1. May trigger a CPU memory leak to some extent, which is an inherent problem of PyTorch.
        #    Machines with > 70GB CPU memory can run the whole training on DIS5K with the default setting.
        # 2. A higher PyTorch version may fix it: https://github.com/pytorch/pytorch/issues/119607.
        # 3. But compile in PyTorch > 2.0.1 seems to bring no acceleration for training.
        self.precisionHigh = True
        # MODEL settings
        self.ms_supervision = True
        self.out_ref = self.ms_supervision and True
        self.dec_ipt = True
        self.dec_ipt_split = True
        self.cxt_num = [0, 3][1]  # multi-scale skip connections from encoder
        self.mul_scl_ipt = ["", "add", "cat"][2]
        self.dec_att = ["", "ASPP", "ASPPDeformable"][2]
        self.squeeze_block = [
            "",
            "BasicDecBlk_x1",
            "ResBlk_x4",
            "ASPP_x3",
            "ASPPDeformable_x3",
        ][1]
        self.dec_blk = ["BasicDecBlk", "ResBlk", "HierarAttDecBlk"][0]
        # TRAINING settings
        self.batch_size = 4
        self.IoU_finetune_last_epochs = [
            0,
            {
                "DIS5K": -50,
                "COD": -20,
                "HRSOD": -20,
                "DIS5K+HRSOD+HRS10K": -20,
                "P3M-10k": -20,
            }[self.task],
        ][
            1
        ]  # choose 0 to skip
        self.lr = (1e-4 if "DIS5K" in self.task else 1e-5) * math.sqrt(
            self.batch_size / 4
        )  # DIS needs a high lr to converge faster. Adapt the lr linearly with the batch size.
        self.size = 1024
        self.num_workers = max(
            4, self.batch_size
        )  # will be decreased to min(num_workers, batch_size) when the data_loader is initialized
        # Backbone settings
        self.bb = [
            "vgg16",
            "vgg16bn",
            "resnet50",  # 0, 1, 2
            "swin_v1_t",
            "swin_v1_s",  # 3, 4
            "swin_v1_b",
            "swin_v1_l",  # 5-bs9, 6-bs4
            "pvt_v2_b0",
            "pvt_v2_b1",  # 7, 8
            "pvt_v2_b2",
            "pvt_v2_b5",  # 9-bs10, 10-bs5
        ][6]
        self.lateral_channels_in_collection = {
            "vgg16": [512, 256, 128, 64],
            "vgg16bn": [512, 256, 128, 64],
            "resnet50": [1024, 512, 256, 64],
            "pvt_v2_b2": [512, 320, 128, 64],
            "pvt_v2_b5": [512, 320, 128, 64],
            "swin_v1_b": [1024, 512, 256, 128],
            "swin_v1_l": [1536, 768, 384, 192],
            "swin_v1_t": [768, 384, 192, 96],
            "swin_v1_s": [768, 384, 192, 96],
            "pvt_v2_b0": [256, 160, 64, 32],
            "pvt_v2_b1": [512, 320, 128, 64],
        }[self.bb]
        if self.mul_scl_ipt == "cat":
            self.lateral_channels_in_collection = [
                channel * 2 for channel in self.lateral_channels_in_collection
            ]
        self.cxt = (
            self.lateral_channels_in_collection[1:][::-1][-self.cxt_num :]
            if self.cxt_num
            else []
        )
        # MODEL settings - inactive
        self.lat_blk = ["BasicLatBlk"][0]
        self.dec_channels_inter = ["fixed", "adap"][0]
        self.refine = ["", "itself", "RefUNet", "Refiner", "RefinerPVTInChannels4"][0]
        self.progressive_ref = self.refine and True
        self.ender = self.progressive_ref and False
        self.scale = self.progressive_ref and 2
        self.auxiliary_classification = (
            False  # Only for DIS5K, where class labels are saved in `dataset.py`.
        )
        self.refine_iteration = 1
        self.freeze_bb = False
        self.model = [
            "BiRefNet",
        ][0]
        if self.dec_blk == "HierarAttDecBlk":
            self.batch_size = 2 ** [0, 1, 2, 3, 4][2]
        # TRAINING settings - inactive
        self.preproc_methods = ["flip", "enhance", "rotate", "pepper", "crop"][:4]
        self.optimizer = ["Adam", "AdamW"][1]
        self.lr_decay_epochs = [
            1e5
        ]  # Set to a negative N to decay the lr in the N-th epoch from the end.
        self.lr_decay_rate = 0.5
        # Loss
        self.lambdas_pix_last = {
            # a non-zero weight enables the corresponding loss
            # original rate -- 1 : 30 : 1.5 : 0.2, bce x 30
            "bce": 30 * 1,  # high performance
            "iou": 0.5 * 1,  # 0 / 255
            "iou_patch": 0.5 * 0,  # 0 / 255, win_size = (64, 64)
            "mse": 150 * 0,  # can smooth the saliency map
            "triplet": 3 * 0,
            "reg": 100 * 0,
            "ssim": 10 * 1,  # helps contours
            "cnt": 5 * 0,  # helps contours
            "structure": 5
            * 0,  # structure loss from the MVANet code. A small improvement on DIS-TE[1,2,3], a slightly larger decrease on DIS-TE4.
        }
        self.lambdas_cls = {"ce": 5.0}
        # Adv
        self.lambda_adv_g = 10.0 * 0  # set to 0 to disable adversarial training
        self.lambda_adv_d = 3.0 * (self.lambda_adv_g > 0)
        # PATH settings - inactive
        self.data_root_dir = os.path.join(self.sys_home_dir, "datasets/dis")
        self.weights_root_dir = os.path.join(self.sys_home_dir, "weights")
        self.weights = {
            "pvt_v2_b2": os.path.join(self.weights_root_dir, "pvt_v2_b2.pth"),
            "pvt_v2_b5": os.path.join(
                self.weights_root_dir, ["pvt_v2_b5.pth", "pvt_v2_b5_22k.pth"][0]
            ),
            "swin_v1_b": os.path.join(
                self.weights_root_dir,
                [
                    "swin_base_patch4_window12_384_22kto1k.pth",
                    "swin_base_patch4_window12_384_22k.pth",
                ][0],
            ),
            "swin_v1_l": os.path.join(
                self.weights_root_dir,
                [
                    "swin_large_patch4_window12_384_22kto1k.pth",
                    "swin_large_patch4_window12_384_22k.pth",
                ][0],
            ),
            "swin_v1_t": os.path.join(
                self.weights_root_dir,
                ["swin_tiny_patch4_window7_224_22kto1k_finetune.pth"][0],
            ),
            "swin_v1_s": os.path.join(
                self.weights_root_dir,
                ["swin_small_patch4_window7_224_22kto1k_finetune.pth"][0],
            ),
            "pvt_v2_b0": os.path.join(self.weights_root_dir, ["pvt_v2_b0.pth"][0]),
            "pvt_v2_b1": os.path.join(self.weights_root_dir, ["pvt_v2_b1.pth"][0]),
        }
        # Callbacks - inactive
        self.verbose_eval = True
        self.only_S_MAE = False
        self.use_fp16 = False  # Buggy. It may cause NaNs in training.
        self.SDPA_enabled = False  # Buggy. Slower, and errors occur on multi-GPU setups.
        # others
        self.device = [0, "cpu"][0]  # .to(0) == .to('cuda:0')
        self.batch_size_valid = 1
        self.rand_seed = 7
        # run_sh_file = [f for f in os.listdir('.') if 'train.sh' == f] + [os.path.join('..', f) for f in os.listdir('..') if 'train.sh' == f]
        # with open(run_sh_file[0], 'r') as f:
        #     lines = f.readlines()
        #     self.save_last = int([l.strip() for l in lines if '"{}")'.format(self.task) in l and 'val_last=' in l][0].split('val_last=')[-1].split()[0])
        #     self.save_step = int([l.strip() for l in lines if '"{}")'.format(self.task) in l and 'step=' in l][0].split('step=')[-1].split()[0])
        # self.val_step = [0, self.save_step][0]

    def print_task(self) -> None:
        # Print the task so shell scripts can choose settings.
        print(self.task)
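

# The attention modules below read `config.SDPA_enabled` at forward time, so a
# module-level Config instance must exist somewhere in the full upstream file;
# it is instantiated here (an assumption about placement) so this excerpt runs
# on its own.
config = Config()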


class Mlp(nn.Module):
    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
        drop=0.0,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.dwconv = DWConv(hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x, H, W):
        x = self.fc1(x)
        x = self.dwconv(x, H, W)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class Attention(nn.Module):
    def __init__(
        self,
        dim,
        num_heads=8,
        qkv_bias=False,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
        sr_ratio=1,
    ):
        super().__init__()
        assert (
            dim % num_heads == 0
        ), f"dim {dim} should be divided by num_heads {num_heads}."
        self.dim = dim
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim**-0.5
        self.q = nn.Linear(dim, dim, bias=qkv_bias)
        self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
        self.attn_drop_prob = attn_drop
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.sr_ratio = sr_ratio
        if sr_ratio > 1:
            self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
            self.norm = nn.LayerNorm(dim)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x, H, W):
        B, N, C = x.shape
        q = (
            self.q(x)
            .reshape(B, N, self.num_heads, C // self.num_heads)
            .permute(0, 2, 1, 3)
        )
        if self.sr_ratio > 1:
            x_ = x.permute(0, 2, 1).reshape(B, C, H, W)
            x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1)
            x_ = self.norm(x_)
            kv = (
                self.kv(x_)
                .reshape(B, -1, 2, self.num_heads, C // self.num_heads)
                .permute(2, 0, 3, 1, 4)
            )
        else:
            kv = (
                self.kv(x)
                .reshape(B, -1, 2, self.num_heads, C // self.num_heads)
                .permute(2, 0, 3, 1, 4)
            )
        k, v = kv[0], kv[1]
        if config.SDPA_enabled:
            x = (
                torch.nn.functional.scaled_dot_product_attention(
                    q,
                    k,
                    v,
                    attn_mask=None,
                    dropout_p=self.attn_drop_prob,
                    is_causal=False,
                )
                .transpose(1, 2)
                .reshape(B, N, C)
            )
        else:
            attn = (q @ k.transpose(-2, -1)) * self.scale
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
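

# Illustrative sketch (not part of the upstream file): a minimal shape check for the
# spatial-reduction attention above. With sr_ratio=2, keys/values are computed on a
# 2x-downsampled token grid while queries keep the full resolution.
def _demo_pvt_attention():
    attn = Attention(dim=32, num_heads=4, sr_ratio=2)
    x = torch.randn(1, 8 * 8, 32)  # (B, N=H*W, C)
    out = attn(x, H=8, W=8)
    assert out.shape == (1, 64, 32)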


class Block(nn.Module):
    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        sr_ratio=1,
    ):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
            sr_ratio=sr_ratio,
        )
        # NOTE: drop path for stochastic depth; we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
        )
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x, H, W):
        x = x + self.drop_path(self.attn(self.norm1(x), H, W))
        x = x + self.drop_path(self.mlp(self.norm2(x), H, W))
        return x


class OverlapPatchEmbed(nn.Module):
    """Image to Patch Embedding"""

    def __init__(
        self, img_size=224, patch_size=7, stride=4, in_channels=3, embed_dim=768
    ):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1]
        self.num_patches = self.H * self.W
        self.proj = nn.Conv2d(
            in_channels,
            embed_dim,
            kernel_size=patch_size,
            stride=stride,
            padding=(patch_size[0] // 2, patch_size[1] // 2),
        )
        self.norm = nn.LayerNorm(embed_dim)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x):
        x = self.proj(x)
        _, _, H, W = x.shape
        x = x.flatten(2).transpose(1, 2)
        x = self.norm(x)
        return x, H, W
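

# Illustrative sketch (not part of the upstream file): OverlapPatchEmbed with
# patch_size=7 / stride=4 maps a 224x224 image to a 56x56 token grid.
def _demo_overlap_patch_embed():
    embed = OverlapPatchEmbed(
        img_size=224, patch_size=7, stride=4, in_channels=3, embed_dim=64
    )
    tokens, H, W = embed(torch.randn(1, 3, 224, 224))
    assert (H, W) == (56, 56) and tokens.shape == (1, 56 * 56, 64)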


class PyramidVisionTransformerImpr(nn.Module):
    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_channels=3,
        num_classes=1000,
        embed_dims=[64, 128, 256, 512],
        num_heads=[1, 2, 4, 8],
        mlp_ratios=[4, 4, 4, 4],
        qkv_bias=False,
        qk_scale=None,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.0,
        norm_layer=nn.LayerNorm,
        depths=[3, 4, 6, 3],
        sr_ratios=[8, 4, 2, 1],
    ):
        super().__init__()
        self.num_classes = num_classes
        self.depths = depths
        # patch_embed
        self.patch_embed1 = OverlapPatchEmbed(
            img_size=img_size,
            patch_size=7,
            stride=4,
            in_channels=in_channels,
            embed_dim=embed_dims[0],
        )
        self.patch_embed2 = OverlapPatchEmbed(
            img_size=img_size // 4,
            patch_size=3,
            stride=2,
            in_channels=embed_dims[0],
            embed_dim=embed_dims[1],
        )
        self.patch_embed3 = OverlapPatchEmbed(
            img_size=img_size // 8,
            patch_size=3,
            stride=2,
            in_channels=embed_dims[1],
            embed_dim=embed_dims[2],
        )
        self.patch_embed4 = OverlapPatchEmbed(
            img_size=img_size // 16,
            patch_size=3,
            stride=2,
            in_channels=embed_dims[2],
            embed_dim=embed_dims[3],
        )
        # transformer encoder
        dpr = [
            x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
        ]  # stochastic depth decay rule
        cur = 0
        self.block1 = nn.ModuleList(
            [
                Block(
                    dim=embed_dims[0],
                    num_heads=num_heads[0],
                    mlp_ratio=mlp_ratios[0],
                    qkv_bias=qkv_bias,
                    qk_scale=qk_scale,
                    drop=drop_rate,
                    attn_drop=attn_drop_rate,
                    drop_path=dpr[cur + i],
                    norm_layer=norm_layer,
                    sr_ratio=sr_ratios[0],
                )
                for i in range(depths[0])
            ]
        )
        self.norm1 = norm_layer(embed_dims[0])
        cur += depths[0]
        self.block2 = nn.ModuleList(
            [
                Block(
                    dim=embed_dims[1],
                    num_heads=num_heads[1],
                    mlp_ratio=mlp_ratios[1],
                    qkv_bias=qkv_bias,
                    qk_scale=qk_scale,
                    drop=drop_rate,
                    attn_drop=attn_drop_rate,
                    drop_path=dpr[cur + i],
                    norm_layer=norm_layer,
                    sr_ratio=sr_ratios[1],
                )
                for i in range(depths[1])
            ]
        )
        self.norm2 = norm_layer(embed_dims[1])
        cur += depths[1]
        self.block3 = nn.ModuleList(
            [
                Block(
                    dim=embed_dims[2],
                    num_heads=num_heads[2],
                    mlp_ratio=mlp_ratios[2],
                    qkv_bias=qkv_bias,
                    qk_scale=qk_scale,
                    drop=drop_rate,
                    attn_drop=attn_drop_rate,
                    drop_path=dpr[cur + i],
                    norm_layer=norm_layer,
                    sr_ratio=sr_ratios[2],
                )
                for i in range(depths[2])
            ]
        )
        self.norm3 = norm_layer(embed_dims[2])
        cur += depths[2]
        self.block4 = nn.ModuleList(
            [
                Block(
                    dim=embed_dims[3],
                    num_heads=num_heads[3],
                    mlp_ratio=mlp_ratios[3],
                    qkv_bias=qkv_bias,
                    qk_scale=qk_scale,
                    drop=drop_rate,
                    attn_drop=attn_drop_rate,
                    drop_path=dpr[cur + i],
                    norm_layer=norm_layer,
                    sr_ratio=sr_ratios[3],
                )
                for i in range(depths[3])
            ]
        )
        self.norm4 = norm_layer(embed_dims[3])
        # classification head
        # self.head = nn.Linear(embed_dims[3], num_classes) if num_classes > 0 else nn.Identity()
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def init_weights(self, pretrained=None):
        if isinstance(pretrained, str):
            logger = 1
            # load_checkpoint(self, pretrained, map_location='cpu', strict=False, logger=logger)

    def reset_drop_path(self, drop_path_rate):
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(self.depths))]
        cur = 0
        for i in range(self.depths[0]):
            self.block1[i].drop_path.drop_prob = dpr[cur + i]
        cur += self.depths[0]
        for i in range(self.depths[1]):
            self.block2[i].drop_path.drop_prob = dpr[cur + i]
        cur += self.depths[1]
        for i in range(self.depths[2]):
            self.block3[i].drop_path.drop_prob = dpr[cur + i]
        cur += self.depths[2]
        for i in range(self.depths[3]):
            self.block4[i].drop_path.drop_prob = dpr[cur + i]

    def freeze_patch_emb(self):
        self.patch_embed1.requires_grad = False

    @torch.jit.ignore
    def no_weight_decay(self):
        return {
            "pos_embed1",
            "pos_embed2",
            "pos_embed3",
            "pos_embed4",
            "cls_token",
        }  # having pos_embed may be better

    def get_classifier(self):
        return self.head

    def reset_classifier(self, num_classes, global_pool=""):
        self.num_classes = num_classes
        self.head = (
            nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
        )

    def forward_features(self, x):
        B = x.shape[0]
        outs = []
        # stage 1
        x, H, W = self.patch_embed1(x)
        for i, blk in enumerate(self.block1):
            x = blk(x, H, W)
        x = self.norm1(x)
        x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
        outs.append(x)
        # stage 2
        x, H, W = self.patch_embed2(x)
        for i, blk in enumerate(self.block2):
            x = blk(x, H, W)
        x = self.norm2(x)
        x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
        outs.append(x)
        # stage 3
        x, H, W = self.patch_embed3(x)
        for i, blk in enumerate(self.block3):
            x = blk(x, H, W)
        x = self.norm3(x)
        x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
        outs.append(x)
        # stage 4
        x, H, W = self.patch_embed4(x)
        for i, blk in enumerate(self.block4):
            x = blk(x, H, W)
        x = self.norm4(x)
        x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
        outs.append(x)
        return outs
        # return x.mean(dim=1)

    def forward(self, x):
        x = self.forward_features(x)
        # x = self.head(x)
        return x


class DWConv(nn.Module):
    def __init__(self, dim=768):
        super(DWConv, self).__init__()
        self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)

    def forward(self, x, H, W):
        B, N, C = x.shape
        x = x.transpose(1, 2).view(B, C, H, W).contiguous()
        x = self.dwconv(x)
        x = x.flatten(2).transpose(1, 2)
        return x


class pvt_v2_b0(PyramidVisionTransformerImpr):
    def __init__(self, **kwargs):
        super(pvt_v2_b0, self).__init__(
            patch_size=4,
            embed_dims=[32, 64, 160, 256],
            num_heads=[1, 2, 5, 8],
            mlp_ratios=[8, 8, 4, 4],
            qkv_bias=True,
            norm_layer=partial(nn.LayerNorm, eps=1e-6),
            depths=[2, 2, 2, 2],
            sr_ratios=[8, 4, 2, 1],
            drop_rate=0.0,
            drop_path_rate=0.1,
        )


class pvt_v2_b1(PyramidVisionTransformerImpr):
    def __init__(self, **kwargs):
        super(pvt_v2_b1, self).__init__(
            patch_size=4,
            embed_dims=[64, 128, 320, 512],
            num_heads=[1, 2, 5, 8],
            mlp_ratios=[8, 8, 4, 4],
            qkv_bias=True,
            norm_layer=partial(nn.LayerNorm, eps=1e-6),
            depths=[2, 2, 2, 2],
            sr_ratios=[8, 4, 2, 1],
            drop_rate=0.0,
            drop_path_rate=0.1,
        )


## @register_model
class pvt_v2_b2(PyramidVisionTransformerImpr):
    def __init__(self, in_channels=3, **kwargs):
        super(pvt_v2_b2, self).__init__(
            patch_size=4,
            embed_dims=[64, 128, 320, 512],
            num_heads=[1, 2, 5, 8],
            mlp_ratios=[8, 8, 4, 4],
            qkv_bias=True,
            norm_layer=partial(nn.LayerNorm, eps=1e-6),
            depths=[3, 4, 6, 3],
            sr_ratios=[8, 4, 2, 1],
            drop_rate=0.0,
            drop_path_rate=0.1,
            in_channels=in_channels,
        )


## @register_model
class pvt_v2_b3(PyramidVisionTransformerImpr):
    def __init__(self, **kwargs):
        super(pvt_v2_b3, self).__init__(
            patch_size=4,
            embed_dims=[64, 128, 320, 512],
            num_heads=[1, 2, 5, 8],
            mlp_ratios=[8, 8, 4, 4],
            qkv_bias=True,
            norm_layer=partial(nn.LayerNorm, eps=1e-6),
            depths=[3, 4, 18, 3],
            sr_ratios=[8, 4, 2, 1],
            drop_rate=0.0,
            drop_path_rate=0.1,
        )


## @register_model
class pvt_v2_b4(PyramidVisionTransformerImpr):
    def __init__(self, **kwargs):
        super(pvt_v2_b4, self).__init__(
            patch_size=4,
            embed_dims=[64, 128, 320, 512],
            num_heads=[1, 2, 5, 8],
            mlp_ratios=[8, 8, 4, 4],
            qkv_bias=True,
            norm_layer=partial(nn.LayerNorm, eps=1e-6),
            depths=[3, 8, 27, 3],
            sr_ratios=[8, 4, 2, 1],
            drop_rate=0.0,
            drop_path_rate=0.1,
        )


## @register_model
class pvt_v2_b5(PyramidVisionTransformerImpr):
    def __init__(self, **kwargs):
        super(pvt_v2_b5, self).__init__(
            patch_size=4,
            embed_dims=[64, 128, 320, 512],
            num_heads=[1, 2, 5, 8],
            mlp_ratios=[4, 4, 4, 4],
            qkv_bias=True,
            norm_layer=partial(nn.LayerNorm, eps=1e-6),
            depths=[3, 6, 40, 3],
            sr_ratios=[8, 4, 2, 1],
            drop_rate=0.0,
            drop_path_rate=0.1,
        )


### models/backbones/swin_v1.py
# --------------------------------------------------------
# Swin Transformer
# Copyright (c) 2021 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ze Liu, Yutong Lin, Yixuan Wei
# --------------------------------------------------------
# NOTE: this Swin-style Mlp redefines (shadows) the PVT-style Mlp above at module
# scope; that only matters if the PVT blocks are instantiated after this point.
class Mlp(nn.Module):
    """Multilayer perceptron."""

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
        drop=0.0,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


def window_partition(x, window_size):
    """
    Args:
        x: (B, H, W, C)
        window_size (int): window size
    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = (
        x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    )
    return windows


def window_reverse(windows, window_size, H, W):
    """
    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): Window size
        H (int): Height of image
        W (int): Width of image
    Returns:
        x: (B, H, W, C)
    """
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(
        B, H // window_size, W // window_size, window_size, window_size, -1
    )
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x
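

# Illustrative sketch (not part of the upstream file): window_partition and
# window_reverse are inverses whenever H and W are multiples of window_size,
# which the padding in SwinTransformerBlock guarantees.
def _demo_window_roundtrip():
    x = torch.randn(2, 14, 14, 96)  # (B, H, W, C) with H, W divisible by 7
    windows = window_partition(x, window_size=7)
    assert windows.shape == (2 * 2 * 2, 7, 7, 96)  # (num_windows*B, 7, 7, C)
    assert torch.equal(window_reverse(windows, 7, 14, 14), x)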


class WindowAttention(nn.Module):
    """Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both shifted and non-shifted windows.
    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """

    def __init__(
        self,
        dim,
        window_size,
        num_heads,
        qkv_bias=True,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
    ):
        super().__init__()
        self.dim = dim
        self.window_size = window_size  # Wh, Ww
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim**-0.5
        # define a parameter table of relative position bias
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)
        )  # 2*Wh-1 * 2*Ww-1, nH
        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(
            torch.meshgrid([coords_h, coords_w], indexing="ij")
        )  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = (
            coords_flatten[:, :, None] - coords_flatten[:, None, :]
        )  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(
            1, 2, 0
        ).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        self.register_buffer("relative_position_index", relative_position_index)
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop_prob = attn_drop
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        trunc_normal_(self.relative_position_bias_table, std=0.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):
        """Forward function.
        Args:
            x: input features with shape of (num_windows*B, N, C)
            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
        """
        B_, N, C = x.shape
        qkv = (
            self.qkv(x)
            .reshape(B_, N, 3, self.num_heads, C // self.num_heads)
            .permute(2, 0, 3, 1, 4)
        )
        q, k, v = (
            qkv[0],
            qkv[1],
            qkv[2],
        )  # make torchscript happy (cannot use tensor as tuple)
        q = q * self.scale
        if config.SDPA_enabled:
            x = (
                torch.nn.functional.scaled_dot_product_attention(
                    q,
                    k,
                    v,
                    attn_mask=None,
                    dropout_p=self.attn_drop_prob,
                    is_causal=False,
                )
                .transpose(1, 2)
                .reshape(B_, N, C)
            )
        else:
            attn = q @ k.transpose(-2, -1)
            relative_position_bias = self.relative_position_bias_table[
                self.relative_position_index.view(-1)
            ].view(
                self.window_size[0] * self.window_size[1],
                self.window_size[0] * self.window_size[1],
                -1,
            )  # Wh*Ww,Wh*Ww,nH
            relative_position_bias = relative_position_bias.permute(
                2, 0, 1
            ).contiguous()  # nH, Wh*Ww, Wh*Ww
            attn = attn + relative_position_bias.unsqueeze(0)
            if mask is not None:
                nW = mask.shape[0]
                attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(
                    1
                ).unsqueeze(0)
                attn = attn.view(-1, self.num_heads, N, N)
                attn = self.softmax(attn)
            else:
                attn = self.softmax(attn)
            attn = self.attn_drop(attn)
            x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
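

# Illustrative sketch (not part of the upstream file): window attention operates on
# flattened windows of N = Wh*Ww tokens and preserves the (num_windows*B, N, C) shape.
def _demo_window_attention():
    attn = WindowAttention(dim=96, window_size=(7, 7), num_heads=3)
    x = torch.randn(4, 7 * 7, 96)  # 4 windows of 49 tokens each
    assert attn(x).shape == (4, 49, 96)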


class SwinTransformerBlock(nn.Module):
    """Swin Transformer Block.
    Args:
        dim (int): Number of input channels.
        num_heads (int): Number of attention heads.
        window_size (int): Window size.
        shift_size (int): Shift size for SW-MSA.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
    """

    def __init__(
        self,
        dim,
        num_heads,
        window_size=7,
        shift_size=0,
        mlp_ratio=4.0,
        qkv_bias=True,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
    ):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        assert (
            0 <= self.shift_size < self.window_size
        ), "shift_size must be in [0, window_size)"
        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim,
            window_size=to_2tuple(self.window_size),
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
        )
        self.H = None
        self.W = None

    def forward(self, x, mask_matrix):
        """Forward function.
        Args:
            x: Input feature, tensor size (B, H*W, C).
            H, W: Spatial resolution of the input feature.
            mask_matrix: Attention mask for cyclic shift.
        """
        B, L, C = x.shape
        H, W = self.H, self.W
        assert L == H * W, "input feature has wrong size"
        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)
        # pad feature maps to multiples of window size
        pad_l = pad_t = 0
        pad_r = (self.window_size - W % self.window_size) % self.window_size
        pad_b = (self.window_size - H % self.window_size) % self.window_size
        x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
        _, Hp, Wp, _ = x.shape
        # cyclic shift
        if self.shift_size > 0:
            shifted_x = torch.roll(
                x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)
            )
            attn_mask = mask_matrix
        else:
            shifted_x = x
            attn_mask = None
        # partition windows
        x_windows = window_partition(
            shifted_x, self.window_size
        )  # nW*B, window_size, window_size, C
        x_windows = x_windows.view(
            -1, self.window_size * self.window_size, C
        )  # nW*B, window_size*window_size, C
        # W-MSA/SW-MSA
        attn_windows = self.attn(
            x_windows, mask=attn_mask
        )  # nW*B, window_size*window_size, C
        # merge windows
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
        shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp)  # B H' W' C
        # reverse cyclic shift
        if self.shift_size > 0:
            x = torch.roll(
                shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)
            )
        else:
            x = shifted_x
        if pad_r > 0 or pad_b > 0:
            x = x[:, :H, :W, :].contiguous()
        x = x.view(B, H * W, C)
        # FFN
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x


class PatchMerging(nn.Module):
    """Patch Merging Layer
    Args:
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
    """

    def __init__(self, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x, H, W):
        """Forward function.
        Args:
            x: Input feature, tensor size (B, H*W, C).
            H, W: Spatial resolution of the input feature.
        """
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"
        x = x.view(B, H, W, C)
        # padding
        pad_input = (H % 2 == 1) or (W % 2 == 1)
        if pad_input:
            x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
        x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
        x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C
        x = self.norm(x)
        x = self.reduction(x)
        return x
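

# Illustrative sketch (not part of the upstream file): PatchMerging halves the token
# grid and doubles the channels (via an internal 4C -> 2C linear reduction).
def _demo_patch_merging():
    merge = PatchMerging(dim=96)
    x = torch.randn(1, 56 * 56, 96)
    assert merge(x, 56, 56).shape == (1, 28 * 28, 192)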


class BasicLayer(nn.Module):
    """A basic Swin Transformer layer for one stage.
    Args:
        dim (int): Number of feature channels.
        depth (int): Depth of this stage.
        num_heads (int): Number of attention heads.
        window_size (int): Local window size. Default: 7.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
    """

    def __init__(
        self,
        dim,
        depth,
        num_heads,
        window_size=7,
        mlp_ratio=4.0,
        qkv_bias=True,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        norm_layer=nn.LayerNorm,
        downsample=None,
        use_checkpoint=False,
    ):
        super().__init__()
        self.window_size = window_size
        self.shift_size = window_size // 2
        self.depth = depth
        self.use_checkpoint = use_checkpoint
        # build blocks
        self.blocks = nn.ModuleList(
            [
                SwinTransformerBlock(
                    dim=dim,
                    num_heads=num_heads,
                    window_size=window_size,
                    shift_size=0 if (i % 2 == 0) else window_size // 2,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    qk_scale=qk_scale,
                    drop=drop,
                    attn_drop=attn_drop,
                    drop_path=drop_path[i]
                    if isinstance(drop_path, list)
                    else drop_path,
                    norm_layer=norm_layer,
                )
                for i in range(depth)
            ]
        )
        # patch merging layer
        if downsample is not None:
            self.downsample = downsample(dim=dim, norm_layer=norm_layer)
        else:
            self.downsample = None

    def forward(self, x, H, W):
        """Forward function.
        Args:
            x: Input feature, tensor size (B, H*W, C).
            H, W: Spatial resolution of the input feature.
        """
        # calculate attention mask for SW-MSA
        Hp = int(np.ceil(H / self.window_size)) * self.window_size
        Wp = int(np.ceil(W / self.window_size)) * self.window_size
        img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device)  # 1 Hp Wp 1
        h_slices = (
            slice(0, -self.window_size),
            slice(-self.window_size, -self.shift_size),
            slice(-self.shift_size, None),
        )
        w_slices = (
            slice(0, -self.window_size),
            slice(-self.window_size, -self.shift_size),
            slice(-self.shift_size, None),
        )
        cnt = 0
        for h in h_slices:
            for w in w_slices:
                img_mask[:, h, w, :] = cnt
                cnt += 1
        mask_windows = window_partition(
            img_mask, self.window_size
        )  # nW, window_size, window_size, 1
        mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
        attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(
            attn_mask == 0, float(0.0)
        )
        for blk in self.blocks:
            blk.H, blk.W = H, W
            if self.use_checkpoint:
                x = checkpoint.checkpoint(blk, x, attn_mask)
            else:
                x = blk(x, attn_mask)
        if self.downsample is not None:
            x_down = self.downsample(x, H, W)
            Wh, Ww = (H + 1) // 2, (W + 1) // 2
            return x, H, W, x_down, Wh, Ww
        else:
            return x, H, W, x, H, W


class PatchEmbed(nn.Module):
    """Image to Patch Embedding
    Args:
        patch_size (int): Patch token size. Default: 4.
        in_channels (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        norm_layer (nn.Module, optional): Normalization layer. Default: None
    """

    def __init__(self, patch_size=4, in_channels=3, embed_dim=96, norm_layer=None):
        super().__init__()
        patch_size = to_2tuple(patch_size)
        self.patch_size = patch_size
        self.in_channels = in_channels
        self.embed_dim = embed_dim
        self.proj = nn.Conv2d(
            in_channels, embed_dim, kernel_size=patch_size, stride=patch_size
        )
        if norm_layer is not None:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = None

    def forward(self, x):
        """Forward function."""
        # padding
        _, _, H, W = x.size()
        if W % self.patch_size[1] != 0:
            x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
        if H % self.patch_size[0] != 0:
            x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
        x = self.proj(x)  # B C Wh Ww
        if self.norm is not None:
            Wh, Ww = x.size(2), x.size(3)
            x = x.flatten(2).transpose(1, 2)
            x = self.norm(x)
            x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)
        return x
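

# Illustrative sketch (not part of the upstream file): the non-overlapping PatchEmbed
# used by Swin downsamples by patch_size (stride 4 here), unlike the overlapping PVT one.
def _demo_swin_patch_embed():
    embed = PatchEmbed(patch_size=4, in_channels=3, embed_dim=96, norm_layer=nn.LayerNorm)
    assert embed(torch.randn(1, 3, 224, 224)).shape == (1, 96, 56, 56)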
  1234. class SwinTransformer(nn.Module):
  1235. """Swin Transformer backbone.
  1236. A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
  1237. https://arxiv.org/pdf/2103.14030
  1238. Args:
  1239. pretrain_img_size (int): Input image size for training the pretrained model,
  1240. used in absolute postion embedding. Default 224.
  1241. patch_size (int | tuple(int)): Patch size. Default: 4.
  1242. in_channels (int): Number of input image channels. Default: 3.
  1243. embed_dim (int): Number of linear projection output channels. Default: 96.
  1244. depths (tuple[int]): Depths of each Swin Transformer stage.
  1245. num_heads (tuple[int]): Number of attention head of each stage.
  1246. window_size (int): Window size. Default: 7.
  1247. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
  1248. qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
  1249. qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
  1250. drop_rate (float): Dropout rate.
  1251. attn_drop_rate (float): Attention dropout rate. Default: 0.
  1252. drop_path_rate (float): Stochastic depth rate. Default: 0.2.
  1253. norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
  1254. ape (bool): If True, add absolute position embedding to the patch embedding. Default: False.
  1255. patch_norm (bool): If True, add normalization after patch embedding. Default: True.
  1256. out_indices (Sequence[int]): Output from which stages.
  1257. frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
  1258. -1 means not freezing any parameters.
  1259. use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
  1260. """

    def __init__(
        self,
        pretrain_img_size=224,
        patch_size=4,
        in_channels=3,
        embed_dim=96,
        depths=[2, 2, 6, 2],
        num_heads=[3, 6, 12, 24],
        window_size=7,
        mlp_ratio=4.0,
        qkv_bias=True,
        qk_scale=None,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.2,
        norm_layer=nn.LayerNorm,
        ape=False,
        patch_norm=True,
        out_indices=(0, 1, 2, 3),
        frozen_stages=-1,
        use_checkpoint=False,
    ):
        super().__init__()

        self.pretrain_img_size = pretrain_img_size
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.ape = ape
        self.patch_norm = patch_norm
        self.out_indices = out_indices
        self.frozen_stages = frozen_stages

        # split image into non-overlapping patches
        self.patch_embed = PatchEmbed(
            patch_size=patch_size,
            in_channels=in_channels,
            embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None,
        )

        # absolute position embedding
        if self.ape:
            pretrain_img_size = to_2tuple(pretrain_img_size)
            patch_size = to_2tuple(patch_size)
            patches_resolution = [
                pretrain_img_size[0] // patch_size[0],
                pretrain_img_size[1] // patch_size[1],
            ]

            self.absolute_pos_embed = nn.Parameter(
                torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1])
            )
            trunc_normal_(self.absolute_pos_embed, std=0.02)

        self.pos_drop = nn.Dropout(p=drop_rate)

        # stochastic depth
        dpr = [
            x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
        ]  # stochastic depth decay rule

        # build layers
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            layer = BasicLayer(
                dim=int(embed_dim * 2**i_layer),
                depth=depths[i_layer],
                num_heads=num_heads[i_layer],
                window_size=window_size,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])],
                norm_layer=norm_layer,
                downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
                use_checkpoint=use_checkpoint,
            )
            self.layers.append(layer)

        num_features = [int(embed_dim * 2**i) for i in range(self.num_layers)]
        self.num_features = num_features

        # add a norm layer for each output
        for i_layer in out_indices:
            layer = norm_layer(num_features[i_layer])
            layer_name = f"norm{i_layer}"
            self.add_module(layer_name, layer)

        self._freeze_stages()

    def _freeze_stages(self):
        if self.frozen_stages >= 0:
            self.patch_embed.eval()
            for param in self.patch_embed.parameters():
                param.requires_grad = False

        if self.frozen_stages >= 1 and self.ape:
            self.absolute_pos_embed.requires_grad = False

        if self.frozen_stages >= 2:
            self.pos_drop.eval()
            for i in range(0, self.frozen_stages - 1):
                m = self.layers[i]
                m.eval()
                for param in m.parameters():
                    param.requires_grad = False

    def forward(self, x):
        """Forward function."""
        x = self.patch_embed(x)

        Wh, Ww = x.size(2), x.size(3)
        if self.ape:
            # interpolate the position embedding to the corresponding size
            absolute_pos_embed = F.interpolate(
                self.absolute_pos_embed, size=(Wh, Ww), mode="bicubic"
            )
            x = x + absolute_pos_embed  # B Wh*Ww C

        outs = []  # x.contiguous()]
        x = x.flatten(2).transpose(1, 2)
        x = self.pos_drop(x)
        for i in range(self.num_layers):
            layer = self.layers[i]
            x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)
            if i in self.out_indices:
                norm_layer = getattr(self, f"norm{i}")
                x_out = norm_layer(x_out)

                out = (
                    x_out.view(-1, H, W, self.num_features[i])
                    .permute(0, 3, 1, 2)
                    .contiguous()
                )
                outs.append(out)

        return tuple(outs)

    def train(self, mode=True):
        """Convert the model into training mode while keeping the frozen layers in eval mode."""
        super(SwinTransformer, self).train(mode)
        self._freeze_stages()


def swin_v1_t():
    model = SwinTransformer(
        embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], window_size=7
    )
    return model


def swin_v1_s():
    model = SwinTransformer(
        embed_dim=96, depths=[2, 2, 18, 2], num_heads=[3, 6, 12, 24], window_size=7
    )
    return model


def swin_v1_b():
    model = SwinTransformer(
        embed_dim=128, depths=[2, 2, 18, 2], num_heads=[4, 8, 16, 32], window_size=12
    )
    return model


def swin_v1_l():
    model = SwinTransformer(
        embed_dim=192, depths=[2, 2, 18, 2], num_heads=[6, 12, 24, 48], window_size=12
    )
    return model
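

# Illustrative usage sketch (added for clarity, not part of the original training code).
# Each factory above returns a Swin backbone whose forward() yields a 4-tuple of feature
# maps at strides 4/8/16/32, which is the interface build_backbone() below relies on.
# The function name `_example_swin_feature_shapes` is hypothetical.
def _example_swin_feature_shapes():
    bb = swin_v1_t()
    feats = bb(torch.randn(1, 3, 224, 224))
    # For swin_v1_t the channel width doubles per stage: 96, 192, 384, 768.
    return [tuple(f.shape) for f in feats]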


### models/modules/deform_conv.py
import torch
import torch.nn as nn
from torchvision.ops import deform_conv2d


class DeformableConv2d(nn.Module):
    def __init__(
        self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False
    ):
        super(DeformableConv2d, self).__init__()

        assert type(kernel_size) == tuple or type(kernel_size) == int

        kernel_size = (
            kernel_size if type(kernel_size) == tuple else (kernel_size, kernel_size)
        )
        self.stride = stride if type(stride) == tuple else (stride, stride)
        self.padding = padding

        self.offset_conv = nn.Conv2d(
            in_channels,
            2 * kernel_size[0] * kernel_size[1],
            kernel_size=kernel_size,
            stride=stride,
            padding=self.padding,
            bias=True,
        )
        nn.init.constant_(self.offset_conv.weight, 0.0)
        nn.init.constant_(self.offset_conv.bias, 0.0)

        self.modulator_conv = nn.Conv2d(
            in_channels,
            1 * kernel_size[0] * kernel_size[1],
            kernel_size=kernel_size,
            stride=stride,
            padding=self.padding,
            bias=True,
        )
        nn.init.constant_(self.modulator_conv.weight, 0.0)
        nn.init.constant_(self.modulator_conv.bias, 0.0)

        self.regular_conv = nn.Conv2d(
            in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=self.padding,
            bias=bias,
        )

    def forward(self, x):
        # h, w = x.shape[2:]
        # max_offset = max(h, w)/4.
        offset = self.offset_conv(x)  # .clamp(-max_offset, max_offset)
        modulator = 2.0 * torch.sigmoid(self.modulator_conv(x))

        x = deform_conv2d(
            input=x,
            offset=offset,
            weight=self.regular_conv.weight,
            bias=self.regular_conv.bias,
            padding=self.padding,
            mask=modulator,
            stride=self.stride,
        )
        return x
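

# Illustrative usage sketch (added for clarity, not part of the original module).
# DeformableConv2d predicts a per-position sampling offset and a modulation mask with two
# small convolutions; both are zero-initialized, so offsets start at the regular grid and
# the mask starts at 2 * sigmoid(0) = 1, i.e. the layer initially behaves like a plain
# convolution. The function name `_example_deformable_conv_shape` is hypothetical.
def _example_deformable_conv_shape():
    layer = DeformableConv2d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1)
    y = layer(torch.randn(2, 16, 64, 64))
    return tuple(y.shape)  # expected (2, 32, 64, 64): same spatial size, new channel count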


### utils.py
import torch.nn as nn


def build_act_layer(act_layer):
    if act_layer == "ReLU":
        return nn.ReLU(inplace=True)
    elif act_layer == "SiLU":
        return nn.SiLU(inplace=True)
    elif act_layer == "GELU":
        return nn.GELU()

    raise NotImplementedError(f"build_act_layer does not support {act_layer}")


def build_norm_layer(
    dim, norm_layer, in_format="channels_last", out_format="channels_last", eps=1e-6
):
    layers = []
    if norm_layer == "BN":
        if in_format == "channels_last":
            layers.append(to_channels_first())
        layers.append(nn.BatchNorm2d(dim))
        if out_format == "channels_last":
            layers.append(to_channels_last())
    elif norm_layer == "LN":
        if in_format == "channels_first":
            layers.append(to_channels_last())
        layers.append(nn.LayerNorm(dim, eps=eps))
        if out_format == "channels_first":
            layers.append(to_channels_first())
    else:
        raise NotImplementedError(f"build_norm_layer does not support {norm_layer}")
    return nn.Sequential(*layers)


class to_channels_first(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x.permute(0, 3, 1, 2)


class to_channels_last(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x.permute(0, 2, 3, 1)
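

# Illustrative usage sketch (added for clarity, not part of the original helpers).
# build_norm_layer() wraps BatchNorm2d / LayerNorm together with the permutes needed to
# move between NCHW ("channels_first") and NHWC ("channels_last") layouts, which is how
# StemLayer below applies LayerNorm to convolutional feature maps.
# The function name `_example_channels_first_layernorm` is hypothetical.
def _example_channels_first_layernorm():
    norm = build_norm_layer(48, "LN", in_format="channels_first", out_format="channels_first")
    y = norm(torch.randn(2, 48, 32, 32))
    return tuple(y.shape)  # expected (2, 48, 32, 32): layout preserved, normalized over channels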


### dataset.py
_class_labels_TR_sorted = (
    "Airplane, Ant, Antenna, Archery, Axe, BabyCarriage, Bag, BalanceBeam, Balcony, Balloon, Basket, BasketballHoop, Beatle, Bed, Bee, Bench, Bicycle, "
    "BicycleFrame, BicycleStand, Boat, Bonsai, BoomLift, Bridge, BunkBed, Butterfly, Button, Cable, CableLift, Cage, Camcorder, Cannon, Canoe, Car, "
    "CarParkDropArm, Carriage, Cart, Caterpillar, CeilingLamp, Centipede, Chair, Clip, Clock, Clothes, CoatHanger, Comb, ConcretePumpTruck, Crack, Crane, "
    "Cup, DentalChair, Desk, DeskChair, Diagram, DishRack, DoorHandle, Dragonfish, Dragonfly, Drum, Earphone, Easel, ElectricIron, Excavator, Eyeglasses, "
    "Fan, Fence, Fencing, FerrisWheel, FireExtinguisher, Fishing, Flag, FloorLamp, Forklift, GasStation, Gate, Gear, Goal, Golf, GymEquipment, Hammock, "
    "Handcart, Handcraft, Handrail, HangGlider, Harp, Harvester, Headset, Helicopter, Helmet, Hook, HorizontalBar, Hydrovalve, IroningTable, Jewelry, Key, "
    "KidsPlayground, Kitchenware, Kite, Knife, Ladder, LaundryRack, Lightning, Lobster, Locust, Machine, MachineGun, MagazineRack, Mantis, Medal, MemorialArchway, "
    "Microphone, Missile, MobileHolder, Monitor, Mosquito, Motorcycle, MovingTrolley, Mower, MusicPlayer, MusicStand, ObservationTower, Octopus, OilWell, "
    "OlympicLogo, OperatingTable, OutdoorFitnessEquipment, Parachute, Pavilion, Piano, Pipe, PlowHarrow, PoleVault, Punchbag, Rack, Racket, Rifle, Ring, Robot, "
    "RockClimbing, Rope, Sailboat, Satellite, Scaffold, Scale, Scissor, Scooter, Sculpture, Seadragon, Seahorse, Seal, SewingMachine, Ship, Shoe, ShoppingCart, "
    "ShoppingTrolley, Shower, Shrimp, Signboard, Skateboarding, Skeleton, Skiing, Spade, SpeedBoat, Spider, Spoon, Stair, Stand, Stationary, SteeringWheel, "
    "Stethoscope, Stool, Stove, StreetLamp, SweetStand, Swing, Sword, TV, Table, TableChair, TableLamp, TableTennis, Tank, Tapeline, Teapot, Telescope, Tent, "
    "TobaccoPipe, Toy, Tractor, TrafficLight, TrafficSign, Trampoline, TransmissionTower, Tree, Tricycle, TrimmerCover, Tripod, Trombone, Truck, Trumpet, Tuba, "
    "UAV, Umbrella, UnevenBars, UtilityPole, VacuumCleaner, Violin, Wakesurfing, Watch, WaterTower, WateringPot, Well, WellLid, Wheel, Wheelchair, WindTurbine, Windmill, WineGlass, WireWhisk, Yacht"
)
class_labels_TR_sorted = _class_labels_TR_sorted.split(", ")


### models/backbones/build_backbones.py
config = Config()


def build_backbone(bb_name, pretrained=True, params_settings=""):
    if bb_name == "vgg16":
        bb_net = list(
            vgg16(pretrained=VGG16_Weights.DEFAULT if pretrained else None).children()
        )[0]
        bb = nn.Sequential(
            OrderedDict(
                {
                    "conv1": bb_net[:4],
                    "conv2": bb_net[4:9],
                    "conv3": bb_net[9:16],
                    "conv4": bb_net[16:23],
                }
            )
        )
    elif bb_name == "vgg16bn":
        bb_net = list(
            vgg16_bn(
                pretrained=VGG16_BN_Weights.DEFAULT if pretrained else None
            ).children()
        )[0]
        bb = nn.Sequential(
            OrderedDict(
                {
                    "conv1": bb_net[:6],
                    "conv2": bb_net[6:13],
                    "conv3": bb_net[13:23],
                    "conv4": bb_net[23:33],
                }
            )
        )
    elif bb_name == "resnet50":
        bb_net = list(
            resnet50(
                pretrained=ResNet50_Weights.DEFAULT if pretrained else None
            ).children()
        )
        bb = nn.Sequential(
            OrderedDict(
                {
                    "conv1": nn.Sequential(*bb_net[0:3]),
                    "conv2": bb_net[4],
                    "conv3": bb_net[5],
                    "conv4": bb_net[6],
                }
            )
        )
    else:
        bb = eval("{}({})".format(bb_name, params_settings))
        if pretrained:
            bb = load_weights(bb, bb_name)
    return bb


def load_weights(model, model_name):
    save_model = torch.load(config.weights[model_name], map_location="cpu")
    model_dict = model.state_dict()
    state_dict = {
        k: v if v.size() == model_dict[k].size() else model_dict[k]
        for k, v in save_model.items()
        if k in model_dict.keys()
    }
    # Ignore weights with mismatched sizes, e.g. when the backbone itself has been modified.
    if not state_dict:
        save_model_keys = list(save_model.keys())
        sub_item = save_model_keys[0] if len(save_model_keys) == 1 else None
        state_dict = (
            {
                k: v if v.size() == model_dict[k].size() else model_dict[k]
                for k, v in save_model[sub_item].items()
                if k in model_dict.keys()
            }
            if sub_item is not None
            else {}
        )
        if not state_dict or not sub_item:
            print(
                "Weights are not successfully loaded. Check the state dict of the weights file."
            )
            return None
        else:
            print(
                'Found correct weights in the "{}" item of loaded state_dict.'.format(
                    sub_item
                )
            )
    model_dict.update(state_dict)
    model.load_state_dict(model_dict)
    return model
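

# Illustrative sketch of the merge rule used by load_weights() above (added for clarity,
# not part of the original file): a checkpoint tensor is kept only when its shape matches
# the model's parameter, otherwise the model's own freshly initialized tensor is kept,
# and keys absent from the model are dropped. The function name is hypothetical.
def _example_partial_weight_merge():
    model_dict = {"w": torch.zeros(4, 4), "b": torch.zeros(4)}
    checkpoint = {"w": torch.ones(4, 4), "b": torch.ones(8), "extra": torch.ones(2)}
    merged = {
        k: v if v.size() == model_dict[k].size() else model_dict[k]
        for k, v in checkpoint.items()
        if k in model_dict
    }
    # "w" is taken from the checkpoint, "b" keeps the model's shape-(4,) tensor,
    # and "extra" is discarded because the model has no such parameter.
    return merged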


### models/modules/decoder_blocks.py
import torch
import torch.nn as nn

# from models.aspp import ASPP, ASPPDeformable
# from config import Config


# config = Config()


class BasicDecBlk(nn.Module):
    def __init__(self, in_channels=64, out_channels=64, inter_channels=64):
        super(BasicDecBlk, self).__init__()
        inter_channels = in_channels // 4 if config.dec_channels_inter == "adap" else 64
        self.conv_in = nn.Conv2d(in_channels, inter_channels, 3, 1, padding=1)
        self.relu_in = nn.ReLU(inplace=True)
        if config.dec_att == "ASPP":
            self.dec_att = ASPP(in_channels=inter_channels)
        elif config.dec_att == "ASPPDeformable":
            self.dec_att = ASPPDeformable(in_channels=inter_channels)
        self.conv_out = nn.Conv2d(inter_channels, out_channels, 3, 1, padding=1)
        self.bn_in = (
            nn.BatchNorm2d(inter_channels) if config.batch_size > 1 else nn.Identity()
        )
        self.bn_out = (
            nn.BatchNorm2d(out_channels) if config.batch_size > 1 else nn.Identity()
        )

    def forward(self, x):
        x = self.conv_in(x)
        x = self.bn_in(x)
        x = self.relu_in(x)
        if hasattr(self, "dec_att"):
            x = self.dec_att(x)
        x = self.conv_out(x)
        x = self.bn_out(x)
        return x


class ResBlk(nn.Module):
    def __init__(self, in_channels=64, out_channels=None, inter_channels=64):
        super(ResBlk, self).__init__()
        if out_channels is None:
            out_channels = in_channels
        inter_channels = in_channels // 4 if config.dec_channels_inter == "adap" else 64

        self.conv_in = nn.Conv2d(in_channels, inter_channels, 3, 1, padding=1)
        self.bn_in = (
            nn.BatchNorm2d(inter_channels) if config.batch_size > 1 else nn.Identity()
        )
        self.relu_in = nn.ReLU(inplace=True)

        if config.dec_att == "ASPP":
            self.dec_att = ASPP(in_channels=inter_channels)
        elif config.dec_att == "ASPPDeformable":
            self.dec_att = ASPPDeformable(in_channels=inter_channels)

        self.conv_out = nn.Conv2d(inter_channels, out_channels, 3, 1, padding=1)
        self.bn_out = (
            nn.BatchNorm2d(out_channels) if config.batch_size > 1 else nn.Identity()
        )

        self.conv_resi = nn.Conv2d(in_channels, out_channels, 1, 1, 0)

    def forward(self, x):
        _x = self.conv_resi(x)
        x = self.conv_in(x)
        x = self.bn_in(x)
        x = self.relu_in(x)
        if hasattr(self, "dec_att"):
            x = self.dec_att(x)
        x = self.conv_out(x)
        x = self.bn_out(x)
        return x + _x


### models/modules/lateral_blocks.py
from functools import partial

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

# from config import Config


# config = Config()


class BasicLatBlk(nn.Module):
    def __init__(self, in_channels=64, out_channels=64, inter_channels=64):
        super(BasicLatBlk, self).__init__()
        inter_channels = in_channels // 4 if config.dec_channels_inter == "adap" else 64
        self.conv = nn.Conv2d(in_channels, out_channels, 1, 1, 0)

    def forward(self, x):
        x = self.conv(x)
        return x


### models/modules/aspp.py
import torch
import torch.nn as nn
import torch.nn.functional as F

# from models.deform_conv import DeformableConv2d
# from config import Config


# config = Config()


class _ASPPModule(nn.Module):
    def __init__(self, in_channels, planes, kernel_size, padding, dilation):
        super(_ASPPModule, self).__init__()
        self.atrous_conv = nn.Conv2d(
            in_channels,
            planes,
            kernel_size=kernel_size,
            stride=1,
            padding=padding,
            dilation=dilation,
            bias=False,
        )
        self.bn = nn.BatchNorm2d(planes) if config.batch_size > 1 else nn.Identity()
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.atrous_conv(x)
        x = self.bn(x)
        return self.relu(x)


class ASPP(nn.Module):
    def __init__(self, in_channels=64, out_channels=None, output_stride=16):
        super(ASPP, self).__init__()
        self.down_scale = 1
        if out_channels is None:
            out_channels = in_channels
        self.in_channelster = 256 // self.down_scale

        if output_stride == 16:
            dilations = [1, 6, 12, 18]
        elif output_stride == 8:
            dilations = [1, 12, 24, 36]
        else:
            raise NotImplementedError

        self.aspp1 = _ASPPModule(
            in_channels, self.in_channelster, 1, padding=0, dilation=dilations[0]
        )
        self.aspp2 = _ASPPModule(
            in_channels,
            self.in_channelster,
            3,
            padding=dilations[1],
            dilation=dilations[1],
        )
        self.aspp3 = _ASPPModule(
            in_channels,
            self.in_channelster,
            3,
            padding=dilations[2],
            dilation=dilations[2],
        )
        self.aspp4 = _ASPPModule(
            in_channels,
            self.in_channelster,
            3,
            padding=dilations[3],
            dilation=dilations[3],
        )
        self.global_avg_pool = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Conv2d(in_channels, self.in_channelster, 1, stride=1, bias=False),
            nn.BatchNorm2d(self.in_channelster)
            if config.batch_size > 1
            else nn.Identity(),
            nn.ReLU(inplace=True),
        )
        self.conv1 = nn.Conv2d(self.in_channelster * 5, out_channels, 1, bias=False)
        self.bn1 = (
            nn.BatchNorm2d(out_channels) if config.batch_size > 1 else nn.Identity()
        )
        self.relu = nn.ReLU(inplace=True)
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        x1 = self.aspp1(x)
        x2 = self.aspp2(x)
        x3 = self.aspp3(x)
        x4 = self.aspp4(x)
        x5 = self.global_avg_pool(x)
        x5 = F.interpolate(x5, size=x1.size()[2:], mode="bilinear", align_corners=True)
        x = torch.cat((x1, x2, x3, x4, x5), dim=1)

        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)

        return self.dropout(x)


##################### Deformable
class _ASPPModuleDeformable(nn.Module):
    def __init__(self, in_channels, planes, kernel_size, padding):
        super(_ASPPModuleDeformable, self).__init__()
        self.atrous_conv = DeformableConv2d(
            in_channels,
            planes,
            kernel_size=kernel_size,
            stride=1,
            padding=padding,
            bias=False,
        )
        self.bn = nn.BatchNorm2d(planes) if config.batch_size > 1 else nn.Identity()
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.atrous_conv(x)
        x = self.bn(x)
        return self.relu(x)


class ASPPDeformable(nn.Module):
    def __init__(self, in_channels, out_channels=None, parallel_block_sizes=[1, 3, 7]):
        super(ASPPDeformable, self).__init__()
        self.down_scale = 1
        if out_channels is None:
            out_channels = in_channels
        self.in_channelster = 256 // self.down_scale

        self.aspp1 = _ASPPModuleDeformable(
            in_channels, self.in_channelster, 1, padding=0
        )
        self.aspp_deforms = nn.ModuleList(
            [
                _ASPPModuleDeformable(
                    in_channels,
                    self.in_channelster,
                    conv_size,
                    padding=int(conv_size // 2),
                )
                for conv_size in parallel_block_sizes
            ]
        )
        self.global_avg_pool = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Conv2d(in_channels, self.in_channelster, 1, stride=1, bias=False),
            nn.BatchNorm2d(self.in_channelster)
            if config.batch_size > 1
            else nn.Identity(),
            nn.ReLU(inplace=True),
        )
        self.conv1 = nn.Conv2d(
            self.in_channelster * (2 + len(self.aspp_deforms)),
            out_channels,
            1,
            bias=False,
        )
        self.bn1 = (
            nn.BatchNorm2d(out_channels) if config.batch_size > 1 else nn.Identity()
        )
        self.relu = nn.ReLU(inplace=True)
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        x1 = self.aspp1(x)
        x_aspp_deforms = [aspp_deform(x) for aspp_deform in self.aspp_deforms]
        x5 = self.global_avg_pool(x)
        x5 = F.interpolate(x5, size=x1.size()[2:], mode="bilinear", align_corners=True)
        x = torch.cat((x1, *x_aspp_deforms, x5), dim=1)

        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)

        return self.dropout(x)


### models/refinement/refiner.py
class RefinerPVTInChannels4(nn.Module):
    def __init__(self, in_channels=3 + 1):
        super(RefinerPVTInChannels4, self).__init__()
        self.config = Config()
        self.epoch = 1
        self.bb = build_backbone(self.config.bb, params_settings="in_channels=4")

        lateral_channels_in_collection = {
            "vgg16": [512, 256, 128, 64],
            "vgg16bn": [512, 256, 128, 64],
            "resnet50": [1024, 512, 256, 64],
            "pvt_v2_b2": [512, 320, 128, 64],
            "pvt_v2_b5": [512, 320, 128, 64],
            "swin_v1_b": [1024, 512, 256, 128],
            "swin_v1_l": [1536, 768, 384, 192],
        }
        channels = lateral_channels_in_collection[self.config.bb]
        self.squeeze_module = BasicDecBlk(channels[0], channels[0])

        self.decoder = Decoder(channels)

        if 0:
            for key, value in self.named_parameters():
                if "bb." in key:
                    value.requires_grad = False

    def forward(self, x):
        if isinstance(x, list):
            x = torch.cat(x, dim=1)
        ########## Encoder ##########
        if self.config.bb in ["vgg16", "vgg16bn", "resnet50"]:
            x1 = self.bb.conv1(x)
            x2 = self.bb.conv2(x1)
            x3 = self.bb.conv3(x2)
            x4 = self.bb.conv4(x3)
        else:
            x1, x2, x3, x4 = self.bb(x)

        x4 = self.squeeze_module(x4)

        ########## Decoder ##########
        features = [x, x1, x2, x3, x4]
        scaled_preds = self.decoder(features)

        return scaled_preds


class Refiner(nn.Module):
    def __init__(self, in_channels=3 + 1):
        super(Refiner, self).__init__()
        self.config = Config()
        self.epoch = 1
        self.stem_layer = StemLayer(
            in_channels=in_channels,
            inter_channels=48,
            out_channels=3,
            norm_layer="BN" if self.config.batch_size > 1 else "LN",
        )
        self.bb = build_backbone(self.config.bb)

        lateral_channels_in_collection = {
            "vgg16": [512, 256, 128, 64],
            "vgg16bn": [512, 256, 128, 64],
            "resnet50": [1024, 512, 256, 64],
            "pvt_v2_b2": [512, 320, 128, 64],
            "pvt_v2_b5": [512, 320, 128, 64],
            "swin_v1_b": [1024, 512, 256, 128],
            "swin_v1_l": [1536, 768, 384, 192],
        }
        channels = lateral_channels_in_collection[self.config.bb]
        self.squeeze_module = BasicDecBlk(channels[0], channels[0])

        self.decoder = Decoder(channels)

        if 0:
            for key, value in self.named_parameters():
                if "bb." in key:
                    value.requires_grad = False

    def forward(self, x):
        if isinstance(x, list):
            x = torch.cat(x, dim=1)
        x = self.stem_layer(x)
        ########## Encoder ##########
        if self.config.bb in ["vgg16", "vgg16bn", "resnet50"]:
            x1 = self.bb.conv1(x)
            x2 = self.bb.conv2(x1)
            x3 = self.bb.conv3(x2)
            x4 = self.bb.conv4(x3)
        else:
            x1, x2, x3, x4 = self.bb(x)

        x4 = self.squeeze_module(x4)

        ########## Decoder ##########
        features = [x, x1, x2, x3, x4]
        scaled_preds = self.decoder(features)

        return scaled_preds


class Decoder(nn.Module):
    def __init__(self, channels):
        super(Decoder, self).__init__()
        self.config = Config()
        DecoderBlock = eval("BasicDecBlk")
        LateralBlock = eval("BasicLatBlk")

        self.decoder_block4 = DecoderBlock(channels[0], channels[1])
        self.decoder_block3 = DecoderBlock(channels[1], channels[2])
        self.decoder_block2 = DecoderBlock(channels[2], channels[3])
        self.decoder_block1 = DecoderBlock(channels[3], channels[3] // 2)

        self.lateral_block4 = LateralBlock(channels[1], channels[1])
        self.lateral_block3 = LateralBlock(channels[2], channels[2])
        self.lateral_block2 = LateralBlock(channels[3], channels[3])

        if self.config.ms_supervision:
            self.conv_ms_spvn_4 = nn.Conv2d(channels[1], 1, 1, 1, 0)
            self.conv_ms_spvn_3 = nn.Conv2d(channels[2], 1, 1, 1, 0)
            self.conv_ms_spvn_2 = nn.Conv2d(channels[3], 1, 1, 1, 0)
        self.conv_out1 = nn.Sequential(nn.Conv2d(channels[3] // 2, 1, 1, 1, 0))

    def forward(self, features):
        x, x1, x2, x3, x4 = features
        outs = []
        p4 = self.decoder_block4(x4)
        _p4 = F.interpolate(p4, size=x3.shape[2:], mode="bilinear", align_corners=True)
        _p3 = _p4 + self.lateral_block4(x3)

        p3 = self.decoder_block3(_p3)
        _p3 = F.interpolate(p3, size=x2.shape[2:], mode="bilinear", align_corners=True)
        _p2 = _p3 + self.lateral_block3(x2)

        p2 = self.decoder_block2(_p2)
        _p2 = F.interpolate(p2, size=x1.shape[2:], mode="bilinear", align_corners=True)
        _p1 = _p2 + self.lateral_block2(x1)

        _p1 = self.decoder_block1(_p1)
        _p1 = F.interpolate(_p1, size=x.shape[2:], mode="bilinear", align_corners=True)
        p1_out = self.conv_out1(_p1)

        if self.config.ms_supervision:
            outs.append(self.conv_ms_spvn_4(p4))
            outs.append(self.conv_ms_spvn_3(p3))
            outs.append(self.conv_ms_spvn_2(p2))
        outs.append(p1_out)
        return outs


class RefUNet(nn.Module):
    # Refinement
    def __init__(self, in_channels=3 + 1):
        super(RefUNet, self).__init__()
        self.encoder_1 = nn.Sequential(
            nn.Conv2d(in_channels, 64, 3, 1, 1),
            nn.Conv2d(64, 64, 3, 1, 1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
        )

        self.encoder_2 = nn.Sequential(
            nn.MaxPool2d(2, 2, ceil_mode=True),
            nn.Conv2d(64, 64, 3, 1, 1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
        )

        self.encoder_3 = nn.Sequential(
            nn.MaxPool2d(2, 2, ceil_mode=True),
            nn.Conv2d(64, 64, 3, 1, 1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
        )

        self.encoder_4 = nn.Sequential(
            nn.MaxPool2d(2, 2, ceil_mode=True),
            nn.Conv2d(64, 64, 3, 1, 1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
        )

        self.pool4 = nn.MaxPool2d(2, 2, ceil_mode=True)
        #####
        self.decoder_5 = nn.Sequential(
            nn.Conv2d(64, 64, 3, 1, 1), nn.BatchNorm2d(64), nn.ReLU(inplace=True)
        )
        #####
        self.decoder_4 = nn.Sequential(
            nn.Conv2d(128, 64, 3, 1, 1), nn.BatchNorm2d(64), nn.ReLU(inplace=True)
        )

        self.decoder_3 = nn.Sequential(
            nn.Conv2d(128, 64, 3, 1, 1), nn.BatchNorm2d(64), nn.ReLU(inplace=True)
        )

        self.decoder_2 = nn.Sequential(
            nn.Conv2d(128, 64, 3, 1, 1), nn.BatchNorm2d(64), nn.ReLU(inplace=True)
        )

        self.decoder_1 = nn.Sequential(
            nn.Conv2d(128, 64, 3, 1, 1), nn.BatchNorm2d(64), nn.ReLU(inplace=True)
        )

        self.conv_d0 = nn.Conv2d(64, 1, 3, 1, 1)

        self.upscore2 = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)

    def forward(self, x):
        outs = []
        if isinstance(x, list):
            x = torch.cat(x, dim=1)
        hx = x

        hx1 = self.encoder_1(hx)
        hx2 = self.encoder_2(hx1)
        hx3 = self.encoder_3(hx2)
        hx4 = self.encoder_4(hx3)

        hx = self.decoder_5(self.pool4(hx4))
        hx = torch.cat((self.upscore2(hx), hx4), 1)

        d4 = self.decoder_4(hx)
        hx = torch.cat((self.upscore2(d4), hx3), 1)

        d3 = self.decoder_3(hx)
        hx = torch.cat((self.upscore2(d3), hx2), 1)

        d2 = self.decoder_2(hx)
        hx = torch.cat((self.upscore2(d2), hx1), 1)

        d1 = self.decoder_1(hx)

        x = self.conv_d0(d1)
        outs.append(x)
        return outs


### models/stem_layer.py
class StemLayer(nn.Module):
    r"""Stem layer of InternImage

    Args:
        in_channels (int): number of input channels
        inter_channels (int): number of intermediate channels
        out_channels (int): number of output channels
        act_layer (str): activation layer
        norm_layer (str): normalization layer
    """

    def __init__(
        self,
        in_channels=3 + 1,
        inter_channels=48,
        out_channels=96,
        act_layer="GELU",
        norm_layer="BN",
    ):
        super().__init__()
        self.conv1 = nn.Conv2d(
            in_channels, inter_channels, kernel_size=3, stride=1, padding=1
        )
        self.norm1 = build_norm_layer(
            inter_channels, norm_layer, "channels_first", "channels_first"
        )
        self.act = build_act_layer(act_layer)
        self.conv2 = nn.Conv2d(
            inter_channels, out_channels, kernel_size=3, stride=1, padding=1
        )
        self.norm2 = build_norm_layer(
            out_channels, norm_layer, "channels_first", "channels_first"
        )

    def forward(self, x):
        x = self.conv1(x)
        x = self.norm1(x)
        x = self.act(x)
        x = self.conv2(x)
        x = self.norm2(x)
        return x


### models/birefnet.py
class BiRefNetConfig(PretrainedConfig):
    model_type = "SegformerForSemanticSegmentation"

    def __init__(self, bb_pretrained=False, **kwargs):
        self.bb_pretrained = bb_pretrained
        super().__init__(**kwargs)


class BiRefNet(PreTrainedModel):
    config_class = BiRefNetConfig

    def __init__(self, bb_pretrained=True, config=BiRefNetConfig()):
        super(BiRefNet, self).__init__(config)
        print(1)
        bb_pretrained = config.bb_pretrained
        self.config = Config()
        self.epoch = 1
        self.bb = build_backbone(self.config.bb, pretrained=bb_pretrained)

        channels = self.config.lateral_channels_in_collection

        if self.config.auxiliary_classification:
            self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
            self.cls_head = nn.Sequential(
                nn.Linear(channels[0], len(class_labels_TR_sorted))
            )

        if self.config.squeeze_block:
            self.squeeze_module = nn.Sequential(
                *[
                    eval(self.config.squeeze_block.split("_x")[0])(
                        channels[0] + sum(self.config.cxt), channels[0]
                    )
                    for _ in range(eval(self.config.squeeze_block.split("_x")[1]))
                ]
            )

        self.decoder = Decoder(channels)

        if self.config.ender:
            self.dec_end = nn.Sequential(
                nn.Conv2d(1, 16, 3, 1, 1),
                nn.Conv2d(16, 1, 3, 1, 1),
                nn.ReLU(inplace=True),
            )

        # refine patch-level segmentation
        if self.config.refine:
            if self.config.refine == "itself":
                self.stem_layer = StemLayer(
                    in_channels=3 + 1,
                    inter_channels=48,
                    out_channels=3,
                    norm_layer="BN" if self.config.batch_size > 1 else "LN",
                )
            else:
                self.refiner = eval(
                    "{}({})".format(self.config.refine, "in_channels=3+1")
                )

        if self.config.freeze_bb:
            # Freeze the backbone...
            print(self.named_parameters())
            for key, value in self.named_parameters():
                if "bb." in key and "refiner." not in key:
                    value.requires_grad = False

    def forward_enc(self, x):
        if self.config.bb in ["vgg16", "vgg16bn", "resnet50"]:
            x1 = self.bb.conv1(x)
            x2 = self.bb.conv2(x1)
            x3 = self.bb.conv3(x2)
            x4 = self.bb.conv4(x3)
        else:
            x1, x2, x3, x4 = self.bb(x)
            if self.config.mul_scl_ipt == "cat":
                B, C, H, W = x.shape
                x1_, x2_, x3_, x4_ = self.bb(
                    F.interpolate(
                        x, size=(H // 2, W // 2), mode="bilinear", align_corners=True
                    )
                )
                x1 = torch.cat(
                    [
                        x1,
                        F.interpolate(
                            x1_, size=x1.shape[2:], mode="bilinear", align_corners=True
                        ),
                    ],
                    dim=1,
                )
                x2 = torch.cat(
                    [
                        x2,
                        F.interpolate(
                            x2_, size=x2.shape[2:], mode="bilinear", align_corners=True
                        ),
                    ],
                    dim=1,
                )
                x3 = torch.cat(
                    [
                        x3,
                        F.interpolate(
                            x3_, size=x3.shape[2:], mode="bilinear", align_corners=True
                        ),
                    ],
                    dim=1,
                )
                x4 = torch.cat(
                    [
                        x4,
                        F.interpolate(
                            x4_, size=x4.shape[2:], mode="bilinear", align_corners=True
                        ),
                    ],
                    dim=1,
                )
            elif self.config.mul_scl_ipt == "add":
                B, C, H, W = x.shape
                x1_, x2_, x3_, x4_ = self.bb(
                    F.interpolate(
                        x, size=(H // 2, W // 2), mode="bilinear", align_corners=True
                    )
                )
                x1 = x1 + F.interpolate(
                    x1_, size=x1.shape[2:], mode="bilinear", align_corners=True
                )
                x2 = x2 + F.interpolate(
                    x2_, size=x2.shape[2:], mode="bilinear", align_corners=True
                )
                x3 = x3 + F.interpolate(
                    x3_, size=x3.shape[2:], mode="bilinear", align_corners=True
                )
                x4 = x4 + F.interpolate(
                    x4_, size=x4.shape[2:], mode="bilinear", align_corners=True
                )
        class_preds = (
            self.cls_head(self.avgpool(x4).view(x4.shape[0], -1))
            if self.training and self.config.auxiliary_classification
            else None
        )
        if self.config.cxt:
            x4 = torch.cat(
                (
                    *[
                        F.interpolate(
                            x1, size=x4.shape[2:], mode="bilinear", align_corners=True
                        ),
                        F.interpolate(
                            x2, size=x4.shape[2:], mode="bilinear", align_corners=True
                        ),
                        F.interpolate(
                            x3, size=x4.shape[2:], mode="bilinear", align_corners=True
                        ),
                    ][-len(self.config.cxt) :],
                    x4,
                ),
                dim=1,
            )
        return (x1, x2, x3, x4), class_preds

    def forward_ori(self, x):
        ########## Encoder ##########
        (x1, x2, x3, x4), class_preds = self.forward_enc(x)
        if self.config.squeeze_block:
            x4 = self.squeeze_module(x4)
        ########## Decoder ##########
        features = [x, x1, x2, x3, x4]
        # if self.training and self.config.out_ref:
        #     features.append(laplacian(torch.mean(x, dim=1).unsqueeze(1), kernel_size=5))
        scaled_preds = self.decoder(features)
        return scaled_preds, class_preds

    def forward(self, x):
        scaled_preds, class_preds = self.forward_ori(x)
        class_preds_lst = [class_preds]
        return [scaled_preds, class_preds_lst] if self.training else scaled_preds


class Decoder(nn.Module):
    def __init__(self, channels):
        super(Decoder, self).__init__()
        self.config = Config()
        DecoderBlock = eval(self.config.dec_blk)
        LateralBlock = eval(self.config.lat_blk)

        if self.config.dec_ipt:
            self.split = self.config.dec_ipt_split
            N_dec_ipt = 64
            DBlock = SimpleConvs
            ic = 64
            ipt_cha_opt = 1
            self.ipt_blk5 = DBlock(
                2**10 * 3 if self.split else 3,
                [N_dec_ipt, channels[0] // 8][ipt_cha_opt],
                inter_channels=ic,
            )
            self.ipt_blk4 = DBlock(
                2**8 * 3 if self.split else 3,
                [N_dec_ipt, channels[0] // 8][ipt_cha_opt],
                inter_channels=ic,
            )
            self.ipt_blk3 = DBlock(
                2**6 * 3 if self.split else 3,
                [N_dec_ipt, channels[1] // 8][ipt_cha_opt],
                inter_channels=ic,
            )
            self.ipt_blk2 = DBlock(
                2**4 * 3 if self.split else 3,
                [N_dec_ipt, channels[2] // 8][ipt_cha_opt],
                inter_channels=ic,
            )
            self.ipt_blk1 = DBlock(
                2**0 * 3 if self.split else 3,
                [N_dec_ipt, channels[3] // 8][ipt_cha_opt],
                inter_channels=ic,
            )
        else:
            self.split = None

        self.decoder_block4 = DecoderBlock(
            channels[0]
            + (
                [N_dec_ipt, channels[0] // 8][ipt_cha_opt] if self.config.dec_ipt else 0
            ),
            channels[1],
        )
        self.decoder_block3 = DecoderBlock(
            channels[1]
            + (
                [N_dec_ipt, channels[0] // 8][ipt_cha_opt] if self.config.dec_ipt else 0
            ),
            channels[2],
        )
        self.decoder_block2 = DecoderBlock(
            channels[2]
            + (
                [N_dec_ipt, channels[1] // 8][ipt_cha_opt] if self.config.dec_ipt else 0
            ),
            channels[3],
        )
        self.decoder_block1 = DecoderBlock(
            channels[3]
            + (
                [N_dec_ipt, channels[2] // 8][ipt_cha_opt] if self.config.dec_ipt else 0
            ),
            channels[3] // 2,
        )
        self.conv_out1 = nn.Sequential(
            nn.Conv2d(
                channels[3] // 2
                + (
                    [N_dec_ipt, channels[3] // 8][ipt_cha_opt]
                    if self.config.dec_ipt
                    else 0
                ),
                1,
                1,
                1,
                0,
            )
        )

        self.lateral_block4 = LateralBlock(channels[1], channels[1])
        self.lateral_block3 = LateralBlock(channels[2], channels[2])
        self.lateral_block2 = LateralBlock(channels[3], channels[3])

        if self.config.ms_supervision:
            self.conv_ms_spvn_4 = nn.Conv2d(channels[1], 1, 1, 1, 0)
            self.conv_ms_spvn_3 = nn.Conv2d(channels[2], 1, 1, 1, 0)
            self.conv_ms_spvn_2 = nn.Conv2d(channels[3], 1, 1, 1, 0)

            if self.config.out_ref:
                _N = 16
                self.gdt_convs_4 = nn.Sequential(
                    nn.Conv2d(channels[1], _N, 3, 1, 1),
                    nn.BatchNorm2d(_N) if self.config.batch_size > 1 else nn.Identity(),
                    nn.ReLU(inplace=True),
                )
                self.gdt_convs_3 = nn.Sequential(
                    nn.Conv2d(channels[2], _N, 3, 1, 1),
                    nn.BatchNorm2d(_N) if self.config.batch_size > 1 else nn.Identity(),
                    nn.ReLU(inplace=True),
                )
                self.gdt_convs_2 = nn.Sequential(
                    nn.Conv2d(channels[3], _N, 3, 1, 1),
                    nn.BatchNorm2d(_N) if self.config.batch_size > 1 else nn.Identity(),
                    nn.ReLU(inplace=True),
                )

                self.gdt_convs_pred_4 = nn.Sequential(nn.Conv2d(_N, 1, 1, 1, 0))
                self.gdt_convs_pred_3 = nn.Sequential(nn.Conv2d(_N, 1, 1, 1, 0))
                self.gdt_convs_pred_2 = nn.Sequential(nn.Conv2d(_N, 1, 1, 1, 0))

                self.gdt_convs_attn_4 = nn.Sequential(nn.Conv2d(_N, 1, 1, 1, 0))
                self.gdt_convs_attn_3 = nn.Sequential(nn.Conv2d(_N, 1, 1, 1, 0))
                self.gdt_convs_attn_2 = nn.Sequential(nn.Conv2d(_N, 1, 1, 1, 0))

    def get_patches_batch(self, x, p):
        _size_h, _size_w = p.shape[2:]
        patches_batch = []
        for idx in range(x.shape[0]):
            columns_x = torch.split(x[idx], split_size_or_sections=_size_w, dim=-1)
            patches_x = []
            for column_x in columns_x:
                patches_x += [
                    p.unsqueeze(0)
                    for p in torch.split(
                        column_x, split_size_or_sections=_size_h, dim=-2
                    )
                ]
            patch_sample = torch.cat(patches_x, dim=1)
            patches_batch.append(patch_sample)
        return torch.cat(patches_batch, dim=0)

    def forward(self, features):
        if self.training and self.config.out_ref:
            outs_gdt_pred = []
            outs_gdt_label = []
            x, x1, x2, x3, x4, gdt_gt = features
        else:
            x, x1, x2, x3, x4 = features
        outs = []

        if self.config.dec_ipt:
            patches_batch = self.get_patches_batch(x, x4) if self.split else x
            x4 = torch.cat(
                (
                    x4,
                    self.ipt_blk5(
                        F.interpolate(
                            patches_batch,
                            size=x4.shape[2:],
                            mode="bilinear",
                            align_corners=True,
                        )
                    ),
                ),
                1,
            )
        p4 = self.decoder_block4(x4)
        m4 = self.conv_ms_spvn_4(p4) if self.config.ms_supervision else None
        if self.config.out_ref:
            p4_gdt = self.gdt_convs_4(p4)
            if self.training:
                # >> GT:
                m4_dia = m4
                gdt_label_main_4 = gdt_gt * F.interpolate(
                    m4_dia, size=gdt_gt.shape[2:], mode="bilinear", align_corners=True
                )
                outs_gdt_label.append(gdt_label_main_4)
                # >> Pred:
                gdt_pred_4 = self.gdt_convs_pred_4(p4_gdt)
                outs_gdt_pred.append(gdt_pred_4)
            gdt_attn_4 = self.gdt_convs_attn_4(p4_gdt).sigmoid()
            # >> Finally:
            p4 = p4 * gdt_attn_4
        _p4 = F.interpolate(p4, size=x3.shape[2:], mode="bilinear", align_corners=True)
        _p3 = _p4 + self.lateral_block4(x3)

        if self.config.dec_ipt:
            patches_batch = self.get_patches_batch(x, _p3) if self.split else x
            _p3 = torch.cat(
                (
                    _p3,
                    self.ipt_blk4(
                        F.interpolate(
                            patches_batch,
                            size=x3.shape[2:],
                            mode="bilinear",
                            align_corners=True,
                        )
                    ),
                ),
                1,
            )
        p3 = self.decoder_block3(_p3)
        m3 = self.conv_ms_spvn_3(p3) if self.config.ms_supervision else None
        if self.config.out_ref:
            p3_gdt = self.gdt_convs_3(p3)
            if self.training:
                # >> GT:
                # m3 --dilation--> m3_dia
                # G_3^gt * m3_dia --> G_3^m, which is the label of gradient
                m3_dia = m3
                gdt_label_main_3 = gdt_gt * F.interpolate(
                    m3_dia, size=gdt_gt.shape[2:], mode="bilinear", align_corners=True
                )
                outs_gdt_label.append(gdt_label_main_3)
                # >> Pred:
                # p3 --conv--BN--> F_3^G, where F_3^G predicts the \hat{G_3} with xx
                # F_3^G --sigmoid--> A_3^G
                gdt_pred_3 = self.gdt_convs_pred_3(p3_gdt)
                outs_gdt_pred.append(gdt_pred_3)
            gdt_attn_3 = self.gdt_convs_attn_3(p3_gdt).sigmoid()
            # >> Finally:
            # p3 = p3 * A_3^G
            p3 = p3 * gdt_attn_3
        _p3 = F.interpolate(p3, size=x2.shape[2:], mode="bilinear", align_corners=True)
        _p2 = _p3 + self.lateral_block3(x2)

        if self.config.dec_ipt:
            patches_batch = self.get_patches_batch(x, _p2) if self.split else x
            _p2 = torch.cat(
                (
                    _p2,
                    self.ipt_blk3(
                        F.interpolate(
                            patches_batch,
                            size=x2.shape[2:],
                            mode="bilinear",
                            align_corners=True,
                        )
                    ),
                ),
                1,
            )
        p2 = self.decoder_block2(_p2)
        m2 = self.conv_ms_spvn_2(p2) if self.config.ms_supervision else None
        if self.config.out_ref:
            p2_gdt = self.gdt_convs_2(p2)
            if self.training:
                # >> GT:
                m2_dia = m2
                gdt_label_main_2 = gdt_gt * F.interpolate(
                    m2_dia, size=gdt_gt.shape[2:], mode="bilinear", align_corners=True
                )
                outs_gdt_label.append(gdt_label_main_2)
                # >> Pred:
                gdt_pred_2 = self.gdt_convs_pred_2(p2_gdt)
                outs_gdt_pred.append(gdt_pred_2)
            gdt_attn_2 = self.gdt_convs_attn_2(p2_gdt).sigmoid()
            # >> Finally:
            p2 = p2 * gdt_attn_2
        _p2 = F.interpolate(p2, size=x1.shape[2:], mode="bilinear", align_corners=True)
        _p1 = _p2 + self.lateral_block2(x1)

        if self.config.dec_ipt:
            patches_batch = self.get_patches_batch(x, _p1) if self.split else x
            _p1 = torch.cat(
                (
                    _p1,
                    self.ipt_blk2(
                        F.interpolate(
                            patches_batch,
                            size=x1.shape[2:],
                            mode="bilinear",
                            align_corners=True,
                        )
                    ),
                ),
                1,
            )
        _p1 = self.decoder_block1(_p1)
        _p1 = F.interpolate(_p1, size=x.shape[2:], mode="bilinear", align_corners=True)

        if self.config.dec_ipt:
            patches_batch = self.get_patches_batch(x, _p1) if self.split else x
            _p1 = torch.cat(
                (
                    _p1,
                    self.ipt_blk1(
                        F.interpolate(
                            patches_batch,
                            size=x.shape[2:],
                            mode="bilinear",
                            align_corners=True,
                        )
                    ),
                ),
                1,
            )
        p1_out = self.conv_out1(_p1)

        if self.config.ms_supervision:
            outs.append(m4)
            outs.append(m3)
            outs.append(m2)
        outs.append(p1_out)
        return (
            outs
            if not (self.config.out_ref and self.training)
            else ([outs_gdt_pred, outs_gdt_label], outs)
        )


class SimpleConvs(nn.Module):
    def __init__(self, in_channels: int, out_channels: int, inter_channels=64) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, inter_channels, 3, 1, 1)
        self.conv_out = nn.Conv2d(inter_channels, out_channels, 3, 1, 1)

    def forward(self, x):
        return self.conv_out(self.conv1(x))


def create_briarmbg2_session():
    birefnet = BiRefNet.from_pretrained("briaai/RMBG-2.0")
    return birefnet


def briarmbg2_process(device, bgr_np_image, session, only_mask=False):
    from PIL import Image
    from torchvision import transforms

    transform_image = transforms.Compose(
        [
            transforms.Resize((1024, 1024)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
    )

    image = Image.fromarray(bgr_np_image)
    image_size = image.size
    input_images = transform_image(image).unsqueeze(0)
    input_images = input_images.to(device)

    # Prediction
    preds = session(input_images)[-1].sigmoid().cpu()
    pred = preds[0].squeeze()
    pred_pil = transforms.ToPILImage()(pred)
    mask = pred_pil.resize(image_size)

    if only_mask:
        return np.array(mask)

    image.putalpha(mask)
    return np.array(image)
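

# Illustrative end-to-end usage sketch (added for clarity, not part of the original file).
# Loading "briaai/RMBG-2.0" downloads the pretrained weights from the Hugging Face Hub;
# the path "input.jpg" is a placeholder and the function name is hypothetical. Note that
# briarmbg2_process() hands the array to PIL unchanged, so whatever channel order goes in
# (the parameter name suggests a BGR frame, e.g. from OpenCV) comes back with an added alpha.
def _example_briarmbg2_usage(image_path="input.jpg"):
    from PIL import Image

    device = "cuda" if torch.cuda.is_available() else "cpu"
    session = create_briarmbg2_session().to(device).eval()

    np_image = np.array(Image.open(image_path).convert("RGB"))
    mask = briarmbg2_process(device, np_image, session, only_mask=True)  # uint8 H x W mask
    Image.fromarray(mask).save("mask.png")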