# briarmbg.py (15 KB)

# Copied from: https://huggingface.co/spaces/briaai/BRIA-RMBG-1.4/blob/main/briarmbg.py
import inspect

import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from torchvision.transforms.functional import normalize
  9. class REBNCONV(nn.Module):
  10. def __init__(self, in_ch=3, out_ch=3, dirate=1, stride=1):
  11. super(REBNCONV, self).__init__()
  12. self.conv_s1 = nn.Conv2d(
  13. in_ch, out_ch, 3, padding=1 * dirate, dilation=1 * dirate, stride=stride
  14. )
  15. self.bn_s1 = nn.BatchNorm2d(out_ch)
  16. self.relu_s1 = nn.ReLU(inplace=True)
  17. def forward(self, x):
  18. hx = x
  19. xout = self.relu_s1(self.bn_s1(self.conv_s1(hx)))
  20. return xout
  21. ## upsample tensor 'src' to have the same spatial size with tensor 'tar'
  22. def _upsample_like(src, tar):
  23. src = F.interpolate(src, size=tar.shape[2:], mode="bilinear")
  24. return src
  25. ### RSU-7 ###
  26. class RSU7(nn.Module):
  27. def __init__(self, in_ch=3, mid_ch=12, out_ch=3, img_size=512):
  28. super(RSU7, self).__init__()
  29. self.in_ch = in_ch
  30. self.mid_ch = mid_ch
  31. self.out_ch = out_ch
  32. self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1) ## 1 -> 1/2
  33. self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
  34. self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
  35. self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
  36. self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
  37. self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
  38. self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
  39. self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
  40. self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
  41. self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
  42. self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
  43. self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=1)
  44. self.rebnconv7 = REBNCONV(mid_ch, mid_ch, dirate=2)
  45. self.rebnconv6d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
  46. self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
  47. self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
  48. self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
  49. self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
  50. self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
  51. def forward(self, x):
  52. b, c, h, w = x.shape
  53. hx = x
  54. hxin = self.rebnconvin(hx)
  55. hx1 = self.rebnconv1(hxin)
  56. hx = self.pool1(hx1)
  57. hx2 = self.rebnconv2(hx)
  58. hx = self.pool2(hx2)
  59. hx3 = self.rebnconv3(hx)
  60. hx = self.pool3(hx3)
  61. hx4 = self.rebnconv4(hx)
  62. hx = self.pool4(hx4)
  63. hx5 = self.rebnconv5(hx)
  64. hx = self.pool5(hx5)
  65. hx6 = self.rebnconv6(hx)
  66. hx7 = self.rebnconv7(hx6)
  67. hx6d = self.rebnconv6d(torch.cat((hx7, hx6), 1))
  68. hx6dup = _upsample_like(hx6d, hx5)
  69. hx5d = self.rebnconv5d(torch.cat((hx6dup, hx5), 1))
  70. hx5dup = _upsample_like(hx5d, hx4)
  71. hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
  72. hx4dup = _upsample_like(hx4d, hx3)
  73. hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
  74. hx3dup = _upsample_like(hx3d, hx2)
  75. hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
  76. hx2dup = _upsample_like(hx2d, hx1)
  77. hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
  78. return hx1d + hxin
  79. ### RSU-6 ###
  80. class RSU6(nn.Module):
  81. def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
  82. super(RSU6, self).__init__()
  83. self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
  84. self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
  85. self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
  86. self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
  87. self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
  88. self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
  89. self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
  90. self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
  91. self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
  92. self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
  93. self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=2)
  94. self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
  95. self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
  96. self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
  97. self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
  98. self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
  99. def forward(self, x):
  100. hx = x
  101. hxin = self.rebnconvin(hx)
  102. hx1 = self.rebnconv1(hxin)
  103. hx = self.pool1(hx1)
  104. hx2 = self.rebnconv2(hx)
  105. hx = self.pool2(hx2)
  106. hx3 = self.rebnconv3(hx)
  107. hx = self.pool3(hx3)
  108. hx4 = self.rebnconv4(hx)
  109. hx = self.pool4(hx4)
  110. hx5 = self.rebnconv5(hx)
  111. hx6 = self.rebnconv6(hx5)
  112. hx5d = self.rebnconv5d(torch.cat((hx6, hx5), 1))
  113. hx5dup = _upsample_like(hx5d, hx4)
  114. hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
  115. hx4dup = _upsample_like(hx4d, hx3)
  116. hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
  117. hx3dup = _upsample_like(hx3d, hx2)
  118. hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
  119. hx2dup = _upsample_like(hx2d, hx1)
  120. hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
  121. return hx1d + hxin
  122. ### RSU-5 ###
  123. class RSU5(nn.Module):
  124. def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
  125. super(RSU5, self).__init__()
  126. self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
  127. self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
  128. self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
  129. self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
  130. self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
  131. self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
  132. self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
  133. self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
  134. self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=2)
  135. self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
  136. self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
  137. self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
  138. self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
  139. def forward(self, x):
  140. hx = x
  141. hxin = self.rebnconvin(hx)
  142. hx1 = self.rebnconv1(hxin)
  143. hx = self.pool1(hx1)
  144. hx2 = self.rebnconv2(hx)
  145. hx = self.pool2(hx2)
  146. hx3 = self.rebnconv3(hx)
  147. hx = self.pool3(hx3)
  148. hx4 = self.rebnconv4(hx)
  149. hx5 = self.rebnconv5(hx4)
  150. hx4d = self.rebnconv4d(torch.cat((hx5, hx4), 1))
  151. hx4dup = _upsample_like(hx4d, hx3)
  152. hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
  153. hx3dup = _upsample_like(hx3d, hx2)
  154. hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
  155. hx2dup = _upsample_like(hx2d, hx1)
  156. hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
  157. return hx1d + hxin
  158. ### RSU-4 ###
  159. class RSU4(nn.Module):
  160. def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
  161. super(RSU4, self).__init__()
  162. self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
  163. self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
  164. self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
  165. self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
  166. self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
  167. self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
  168. self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=2)
  169. self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
  170. self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
  171. self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
  172. def forward(self, x):
  173. hx = x
  174. hxin = self.rebnconvin(hx)
  175. hx1 = self.rebnconv1(hxin)
  176. hx = self.pool1(hx1)
  177. hx2 = self.rebnconv2(hx)
  178. hx = self.pool2(hx2)
  179. hx3 = self.rebnconv3(hx)
  180. hx4 = self.rebnconv4(hx3)
  181. hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
  182. hx3dup = _upsample_like(hx3d, hx2)
  183. hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
  184. hx2dup = _upsample_like(hx2d, hx1)
  185. hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
  186. return hx1d + hxin
  187. ### RSU-4F ###
  188. class RSU4F(nn.Module):
  189. def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
  190. super(RSU4F, self).__init__()
  191. self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
  192. self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
  193. self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=2)
  194. self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=4)
  195. self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=8)
  196. self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=4)
  197. self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=2)
  198. self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
  199. def forward(self, x):
  200. hx = x
  201. hxin = self.rebnconvin(hx)
  202. hx1 = self.rebnconv1(hxin)
  203. hx2 = self.rebnconv2(hx1)
  204. hx3 = self.rebnconv3(hx2)
  205. hx4 = self.rebnconv4(hx3)
  206. hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
  207. hx2d = self.rebnconv2d(torch.cat((hx3d, hx2), 1))
  208. hx1d = self.rebnconv1d(torch.cat((hx2d, hx1), 1))
  209. return hx1d + hxin
  210. class myrebnconv(nn.Module):
  211. def __init__(
  212. self,
  213. in_ch=3,
  214. out_ch=1,
  215. kernel_size=3,
  216. stride=1,
  217. padding=1,
  218. dilation=1,
  219. groups=1,
  220. ):
  221. super(myrebnconv, self).__init__()
  222. self.conv = nn.Conv2d(
  223. in_ch,
  224. out_ch,
  225. kernel_size=kernel_size,
  226. stride=stride,
  227. padding=padding,
  228. dilation=dilation,
  229. groups=groups,
  230. )
  231. self.bn = nn.BatchNorm2d(out_ch)
  232. self.rl = nn.ReLU(inplace=True)
  233. def forward(self, x):
  234. return self.rl(self.bn(self.conv(x)))
  235. class BriaRMBG(nn.Module):
  236. def __init__(self, in_ch=3, out_ch=1):
  237. super(BriaRMBG, self).__init__()
  238. self.conv_in = nn.Conv2d(in_ch, 64, 3, stride=2, padding=1)
  239. self.pool_in = nn.MaxPool2d(2, stride=2, ceil_mode=True)
  240. self.stage1 = RSU7(64, 32, 64)
  241. self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
  242. self.stage2 = RSU6(64, 32, 128)
  243. self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
  244. self.stage3 = RSU5(128, 64, 256)
  245. self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
  246. self.stage4 = RSU4(256, 128, 512)
  247. self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
  248. self.stage5 = RSU4F(512, 256, 512)
  249. self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
  250. self.stage6 = RSU4F(512, 256, 512)
  251. # decoder
  252. self.stage5d = RSU4F(1024, 256, 512)
  253. self.stage4d = RSU4(1024, 128, 256)
  254. self.stage3d = RSU5(512, 64, 128)
  255. self.stage2d = RSU6(256, 32, 64)
  256. self.stage1d = RSU7(128, 16, 64)
  257. self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
  258. self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
  259. self.side3 = nn.Conv2d(128, out_ch, 3, padding=1)
  260. self.side4 = nn.Conv2d(256, out_ch, 3, padding=1)
  261. self.side5 = nn.Conv2d(512, out_ch, 3, padding=1)
  262. self.side6 = nn.Conv2d(512, out_ch, 3, padding=1)
  263. # self.outconv = nn.Conv2d(6*out_ch,out_ch,1)
  264. def forward(self, x):
  265. hx = x
  266. hxin = self.conv_in(hx)
  267. # hx = self.pool_in(hxin)
  268. # stage 1
  269. hx1 = self.stage1(hxin)
  270. hx = self.pool12(hx1)
  271. # stage 2
  272. hx2 = self.stage2(hx)
  273. hx = self.pool23(hx2)
  274. # stage 3
  275. hx3 = self.stage3(hx)
  276. hx = self.pool34(hx3)
  277. # stage 4
  278. hx4 = self.stage4(hx)
  279. hx = self.pool45(hx4)
  280. # stage 5
  281. hx5 = self.stage5(hx)
  282. hx = self.pool56(hx5)
  283. # stage 6
  284. hx6 = self.stage6(hx)
  285. hx6up = _upsample_like(hx6, hx5)
  286. # -------------------- decoder --------------------
  287. hx5d = self.stage5d(torch.cat((hx6up, hx5), 1))
  288. hx5dup = _upsample_like(hx5d, hx4)
  289. hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1))
  290. hx4dup = _upsample_like(hx4d, hx3)
  291. hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1))
  292. hx3dup = _upsample_like(hx3d, hx2)
  293. hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1))
  294. hx2dup = _upsample_like(hx2d, hx1)
  295. hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1))
  296. # side output
  297. d1 = self.side1(hx1d)
  298. d1 = _upsample_like(d1, x)
  299. d2 = self.side2(hx2d)
  300. d2 = _upsample_like(d2, x)
  301. d3 = self.side3(hx3d)
  302. d3 = _upsample_like(d3, x)
  303. d4 = self.side4(hx4d)
  304. d4 = _upsample_like(d4, x)
  305. d5 = self.side5(hx5d)
  306. d5 = _upsample_like(d5, x)
  307. d6 = self.side6(hx6)
  308. d6 = _upsample_like(d6, x)
  309. return (
  310. [
  311. F.sigmoid(d1),
  312. F.sigmoid(d2),
  313. F.sigmoid(d3),
  314. F.sigmoid(d4),
  315. F.sigmoid(d5),
  316. F.sigmoid(d6),
  317. ],
  318. [hx1d, hx2d, hx3d, hx4d, hx5d, hx6],
  319. )
  320. def resize_image(image):
  321. image = image.convert("RGB")
  322. model_input_size = (1024, 1024)
  323. image = image.resize(model_input_size, Image.BILINEAR)
  324. return image
  325. def create_briarmbg_session():
  326. from huggingface_hub import hf_hub_download
  327. net = BriaRMBG()
  328. model_path = hf_hub_download("briaai/RMBG-1.4", "model.pth")
  329. net.load_state_dict(torch.load(model_path, map_location="cpu"))
  330. net.eval()
  331. return net
  332. def briarmbg_process(device, bgr_np_image, session, only_mask=False):
  333. # prepare input
  334. orig_bgr_image = Image.fromarray(bgr_np_image)
  335. w, h = orig_im_size = orig_bgr_image.size
  336. image = resize_image(orig_bgr_image)
  337. im_np = np.array(image)
  338. im_tensor = torch.tensor(im_np, dtype=torch.float32).permute(2, 0, 1)
  339. im_tensor = torch.unsqueeze(im_tensor, 0)
  340. im_tensor = torch.divide(im_tensor, 255.0)
  341. im_tensor = normalize(im_tensor, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
  342. im_tensor = im_tensor.to(device)
  343. # inference
  344. result = session(im_tensor)
  345. # post process
  346. result = torch.squeeze(F.interpolate(result[0][0], size=(h, w), mode="bilinear"), 0)
  347. ma = torch.max(result)
  348. mi = torch.min(result)
  349. result = (result - mi) / (ma - mi)
  350. # image to pil
  351. im_array = (result * 255).cpu().data.numpy().astype(np.uint8)
  352. mask = np.squeeze(im_array)
  353. if only_mask:
  354. return mask
  355. pil_im = Image.fromarray(mask)
  356. # paste the mask on the original image
  357. new_im = Image.new("RGBA", pil_im.size, (0, 0, 0, 0))
  358. new_im.paste(orig_bgr_image, mask=pil_im)
  359. rgba_np_img = np.asarray(new_im)
  360. return rgba_np_img