anime_seg.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462
  1. import cv2
  2. import numpy as np
  3. import torch
  4. import torch.nn as nn
  5. import torch.nn.functional as F
  6. from PIL import Image
  7. from sorawm.iopaint.helper import load_model
  8. from sorawm.iopaint.plugins.base_plugin import BasePlugin
  9. from sorawm.iopaint.schema import RunPluginRequest
  10. class REBNCONV(nn.Module):
  11. def __init__(self, in_ch=3, out_ch=3, dirate=1, stride=1):
  12. super(REBNCONV, self).__init__()
  13. self.conv_s1 = nn.Conv2d(
  14. in_ch, out_ch, 3, padding=1 * dirate, dilation=1 * dirate, stride=stride
  15. )
  16. self.bn_s1 = nn.BatchNorm2d(out_ch)
  17. self.relu_s1 = nn.ReLU(inplace=True)
  18. def forward(self, x):
  19. hx = x
  20. xout = self.relu_s1(self.bn_s1(self.conv_s1(hx)))
  21. return xout
  22. ## upsample tensor 'src' to have the same spatial size with tensor 'tar'
  23. def _upsample_like(src, tar):
  24. src = F.interpolate(src, size=tar.shape[2:], mode="bilinear", align_corners=False)
  25. return src
  26. ### RSU-7 ###
  27. class RSU7(nn.Module):
  28. def __init__(self, in_ch=3, mid_ch=12, out_ch=3, img_size=512):
  29. super(RSU7, self).__init__()
  30. self.in_ch = in_ch
  31. self.mid_ch = mid_ch
  32. self.out_ch = out_ch
  33. self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1) ## 1 -> 1/2
  34. self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
  35. self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
  36. self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
  37. self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
  38. self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
  39. self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
  40. self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
  41. self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
  42. self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
  43. self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
  44. self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=1)
  45. self.rebnconv7 = REBNCONV(mid_ch, mid_ch, dirate=2)
  46. self.rebnconv6d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
  47. self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
  48. self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
  49. self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
  50. self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
  51. self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
  52. def forward(self, x):
  53. b, c, h, w = x.shape
  54. hx = x
  55. hxin = self.rebnconvin(hx)
  56. hx1 = self.rebnconv1(hxin)
  57. hx = self.pool1(hx1)
  58. hx2 = self.rebnconv2(hx)
  59. hx = self.pool2(hx2)
  60. hx3 = self.rebnconv3(hx)
  61. hx = self.pool3(hx3)
  62. hx4 = self.rebnconv4(hx)
  63. hx = self.pool4(hx4)
  64. hx5 = self.rebnconv5(hx)
  65. hx = self.pool5(hx5)
  66. hx6 = self.rebnconv6(hx)
  67. hx7 = self.rebnconv7(hx6)
  68. hx6d = self.rebnconv6d(torch.cat((hx7, hx6), 1))
  69. hx6dup = _upsample_like(hx6d, hx5)
  70. hx5d = self.rebnconv5d(torch.cat((hx6dup, hx5), 1))
  71. hx5dup = _upsample_like(hx5d, hx4)
  72. hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
  73. hx4dup = _upsample_like(hx4d, hx3)
  74. hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
  75. hx3dup = _upsample_like(hx3d, hx2)
  76. hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
  77. hx2dup = _upsample_like(hx2d, hx1)
  78. hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
  79. return hx1d + hxin
  80. ### RSU-6 ###
  81. class RSU6(nn.Module):
  82. def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
  83. super(RSU6, self).__init__()
  84. self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
  85. self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
  86. self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
  87. self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
  88. self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
  89. self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
  90. self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
  91. self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
  92. self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
  93. self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
  94. self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=2)
  95. self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
  96. self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
  97. self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
  98. self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
  99. self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
  100. def forward(self, x):
  101. hx = x
  102. hxin = self.rebnconvin(hx)
  103. hx1 = self.rebnconv1(hxin)
  104. hx = self.pool1(hx1)
  105. hx2 = self.rebnconv2(hx)
  106. hx = self.pool2(hx2)
  107. hx3 = self.rebnconv3(hx)
  108. hx = self.pool3(hx3)
  109. hx4 = self.rebnconv4(hx)
  110. hx = self.pool4(hx4)
  111. hx5 = self.rebnconv5(hx)
  112. hx6 = self.rebnconv6(hx5)
  113. hx5d = self.rebnconv5d(torch.cat((hx6, hx5), 1))
  114. hx5dup = _upsample_like(hx5d, hx4)
  115. hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
  116. hx4dup = _upsample_like(hx4d, hx3)
  117. hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
  118. hx3dup = _upsample_like(hx3d, hx2)
  119. hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
  120. hx2dup = _upsample_like(hx2d, hx1)
  121. hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
  122. return hx1d + hxin
  123. ### RSU-5 ###
  124. class RSU5(nn.Module):
  125. def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
  126. super(RSU5, self).__init__()
  127. self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
  128. self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
  129. self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
  130. self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
  131. self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
  132. self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
  133. self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
  134. self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
  135. self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=2)
  136. self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
  137. self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
  138. self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
  139. self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
  140. def forward(self, x):
  141. hx = x
  142. hxin = self.rebnconvin(hx)
  143. hx1 = self.rebnconv1(hxin)
  144. hx = self.pool1(hx1)
  145. hx2 = self.rebnconv2(hx)
  146. hx = self.pool2(hx2)
  147. hx3 = self.rebnconv3(hx)
  148. hx = self.pool3(hx3)
  149. hx4 = self.rebnconv4(hx)
  150. hx5 = self.rebnconv5(hx4)
  151. hx4d = self.rebnconv4d(torch.cat((hx5, hx4), 1))
  152. hx4dup = _upsample_like(hx4d, hx3)
  153. hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
  154. hx3dup = _upsample_like(hx3d, hx2)
  155. hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
  156. hx2dup = _upsample_like(hx2d, hx1)
  157. hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
  158. return hx1d + hxin
  159. ### RSU-4 ###
  160. class RSU4(nn.Module):
  161. def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
  162. super(RSU4, self).__init__()
  163. self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
  164. self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
  165. self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
  166. self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
  167. self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
  168. self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
  169. self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=2)
  170. self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
  171. self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
  172. self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
  173. def forward(self, x):
  174. hx = x
  175. hxin = self.rebnconvin(hx)
  176. hx1 = self.rebnconv1(hxin)
  177. hx = self.pool1(hx1)
  178. hx2 = self.rebnconv2(hx)
  179. hx = self.pool2(hx2)
  180. hx3 = self.rebnconv3(hx)
  181. hx4 = self.rebnconv4(hx3)
  182. hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
  183. hx3dup = _upsample_like(hx3d, hx2)
  184. hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
  185. hx2dup = _upsample_like(hx2d, hx1)
  186. hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
  187. return hx1d + hxin
  188. ### RSU-4F ###
  189. class RSU4F(nn.Module):
  190. def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
  191. super(RSU4F, self).__init__()
  192. self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
  193. self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
  194. self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=2)
  195. self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=4)
  196. self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=8)
  197. self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=4)
  198. self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=2)
  199. self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
  200. def forward(self, x):
  201. hx = x
  202. hxin = self.rebnconvin(hx)
  203. hx1 = self.rebnconv1(hxin)
  204. hx2 = self.rebnconv2(hx1)
  205. hx3 = self.rebnconv3(hx2)
  206. hx4 = self.rebnconv4(hx3)
  207. hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
  208. hx2d = self.rebnconv2d(torch.cat((hx3d, hx2), 1))
  209. hx1d = self.rebnconv1d(torch.cat((hx2d, hx1), 1))
  210. return hx1d + hxin
  211. class ISNetDIS(nn.Module):
  212. def __init__(self, in_ch=3, out_ch=1):
  213. super(ISNetDIS, self).__init__()
  214. self.conv_in = nn.Conv2d(in_ch, 64, 3, stride=2, padding=1)
  215. self.pool_in = nn.MaxPool2d(2, stride=2, ceil_mode=True)
  216. self.stage1 = RSU7(64, 32, 64)
  217. self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
  218. self.stage2 = RSU6(64, 32, 128)
  219. self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
  220. self.stage3 = RSU5(128, 64, 256)
  221. self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
  222. self.stage4 = RSU4(256, 128, 512)
  223. self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
  224. self.stage5 = RSU4F(512, 256, 512)
  225. self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
  226. self.stage6 = RSU4F(512, 256, 512)
  227. # decoder
  228. self.stage5d = RSU4F(1024, 256, 512)
  229. self.stage4d = RSU4(1024, 128, 256)
  230. self.stage3d = RSU5(512, 64, 128)
  231. self.stage2d = RSU6(256, 32, 64)
  232. self.stage1d = RSU7(128, 16, 64)
  233. self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
  234. def forward(self, x):
  235. hx = x
  236. hxin = self.conv_in(hx)
  237. hx = self.pool_in(hxin)
  238. # stage 1
  239. hx1 = self.stage1(hxin)
  240. hx = self.pool12(hx1)
  241. # stage 2
  242. hx2 = self.stage2(hx)
  243. hx = self.pool23(hx2)
  244. # stage 3
  245. hx3 = self.stage3(hx)
  246. hx = self.pool34(hx3)
  247. # stage 4
  248. hx4 = self.stage4(hx)
  249. hx = self.pool45(hx4)
  250. # stage 5
  251. hx5 = self.stage5(hx)
  252. hx = self.pool56(hx5)
  253. # stage 6
  254. hx6 = self.stage6(hx)
  255. hx6up = _upsample_like(hx6, hx5)
  256. # -------------------- decoder --------------------
  257. hx5d = self.stage5d(torch.cat((hx6up, hx5), 1))
  258. hx5dup = _upsample_like(hx5d, hx4)
  259. hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1))
  260. hx4dup = _upsample_like(hx4d, hx3)
  261. hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1))
  262. hx3dup = _upsample_like(hx3d, hx2)
  263. hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1))
  264. hx2dup = _upsample_like(hx2d, hx1)
  265. hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1))
  266. # side output
  267. d1 = self.side1(hx1d)
  268. d1 = _upsample_like(d1, x)
  269. return d1.sigmoid()
  270. # 从小到大
  271. ANIME_SEG_MODELS = {
  272. "url": "https://github.com/Sanster/models/releases/download/isnetis/isnetis.pth",
  273. "md5": "5f25479076b73074730ab8de9e8f2051",
  274. }
  275. class AnimeSeg(BasePlugin):
  276. # Model from: https://github.com/SkyTNT/anime-segmentation
  277. name = "AnimeSeg"
  278. support_gen_image = True
  279. support_gen_mask = True
  280. def __init__(self):
  281. super().__init__()
  282. self.model = load_model(
  283. ISNetDIS(),
  284. ANIME_SEG_MODELS["url"],
  285. "cpu",
  286. ANIME_SEG_MODELS["md5"],
  287. )
  288. def gen_image(self, rgb_np_img, req: RunPluginRequest) -> np.ndarray:
  289. mask = self.forward(rgb_np_img)
  290. mask = Image.fromarray(mask, mode="L")
  291. h0, w0 = rgb_np_img.shape[0], rgb_np_img.shape[1]
  292. empty = Image.new("RGBA", (w0, h0), 0)
  293. img = Image.fromarray(rgb_np_img)
  294. cutout = Image.composite(img, empty, mask)
  295. return np.asarray(cutout)
  296. def gen_mask(self, rgb_np_img, req: RunPluginRequest) -> np.ndarray:
  297. return self.forward(rgb_np_img)
  298. @torch.inference_mode()
  299. def forward(self, rgb_np_img):
  300. s = 1024
  301. h0, w0 = h, w = rgb_np_img.shape[0], rgb_np_img.shape[1]
  302. if h > w:
  303. h, w = s, int(s * w / h)
  304. else:
  305. h, w = int(s * h / w), s
  306. ph, pw = s - h, s - w
  307. tmpImg = np.zeros([s, s, 3], dtype=np.float32)
  308. tmpImg[ph // 2 : ph // 2 + h, pw // 2 : pw // 2 + w] = (
  309. cv2.resize(rgb_np_img, (w, h)) / 255
  310. )
  311. tmpImg = tmpImg.transpose((2, 0, 1))
  312. tmpImg = torch.from_numpy(tmpImg).unsqueeze(0).type(torch.FloatTensor)
  313. mask = self.model(tmpImg)
  314. mask = mask[0, :, ph // 2 : ph // 2 + h, pw // 2 : pw // 2 + w]
  315. mask = cv2.resize(mask.cpu().numpy().transpose((1, 2, 0)), (w0, h0))
  316. return (mask * 255).astype("uint8")