import math

import cv2
import numpy as np
import torch
import torch.nn.functional as F
from loguru import logger
from torch import nn

from sorawm.iopaint.helper import download_model
from sorawm.iopaint.plugins.base_plugin import BasePlugin
from sorawm.iopaint.schema import RealESRGANModel, RunPluginRequest


class RealESRGANer:
    """A helper class for upsampling images with RealESRGAN.

    Args:
        scale (int): Upsampling scale factor used in the networks. It is usually 2 or 4.
        model_path (str): The path to the pretrained model. It can be a URL (the model will be downloaded first).
        model (nn.Module): The defined network. Default: None.
        tile (int): Since very large images can cause GPU out-of-memory errors, this tile option first crops
            the input image into tiles, processes each tile separately, and finally merges them back into
            one image. 0 means tiling is not used. Default: 0.
        tile_pad (int): The pad size for each tile, to remove border artifacts. Default: 10.
        pre_pad (int): Pad the input images to avoid border artifacts. Default: 10.
        half (bool): Whether to use half precision during inference. Default: False.
    """

    def __init__(
        self,
        scale,
        model_path,
        dni_weight=None,
        model=None,
        tile=0,
        tile_pad=10,
        pre_pad=10,
        half=False,
        device=None,
        gpu_id=None,
    ):
        self.scale = scale
        self.tile_size = tile
        self.tile_pad = tile_pad
        self.pre_pad = pre_pad
        self.mod_scale = None
        self.half = half

        # initialize model
        if gpu_id:
            self.device = (
                torch.device(f"cuda:{gpu_id}" if torch.cuda.is_available() else "cpu")
                if device is None
                else device
            )
        else:
            self.device = (
                torch.device("cuda" if torch.cuda.is_available() else "cpu")
                if device is None
                else device
            )

        if isinstance(model_path, list):
            # dni: deep network interpolation between two checkpoints
            assert len(model_path) == len(
                dni_weight
            ), "model_path and dni_weight should have the same length."
            loadnet = self.dni(model_path[0], model_path[1], dni_weight)
        else:
            # model_path is a local file here; remote checkpoints are downloaded beforehand (see download_model)
            loadnet = torch.load(model_path, map_location=torch.device("cpu"))

        # prefer to use params_ema
        if "params_ema" in loadnet:
            keyname = "params_ema"
        else:
            keyname = "params"
        model.load_state_dict(loadnet[keyname], strict=True)

        model.eval()
        self.model = model.to(self.device)
        if self.half:
            self.model = self.model.half()

    def dni(self, net_a, net_b, dni_weight, key="params", loc="cpu"):
        """Deep network interpolation.

        ``Paper: Deep Network Interpolation for Continuous Imagery Effect Transition``
        """
        net_a = torch.load(net_a, map_location=torch.device(loc))
        net_b = torch.load(net_b, map_location=torch.device(loc))
        for k, v_a in net_a[key].items():
            net_a[key][k] = dni_weight[0] * v_a + dni_weight[1] * net_b[key][k]
        return net_a
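
    # Illustrative sketch of how `dni` is meant to be used (the checkpoint
    # names and weights below are hypothetical, not part of this module):
    #
    #   upsampler = RealESRGANer(
    #       scale=4,
    #       model_path=["realesr-general-x4v3.pth", "realesr-general-wdn-x4v3.pth"],
    #       dni_weight=[0.5, 0.5],
    #       model=some_net,
    #   )
    #
    # With dni_weight=[0.5, 0.5], every parameter becomes the average
    # 0.5 * net_a[k] + 0.5 * net_b[k], i.e. a linear interpolation between the
    # two checkpoints (here trading denoising strength against detail).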

    def pre_process(self, img):
        """Pre-process, such as pre-pad and mod pad, so that the image sizes are divisible."""
        img = torch.from_numpy(np.transpose(img, (2, 0, 1))).float()
        self.img = img.unsqueeze(0).to(self.device)
        if self.half:
            self.img = self.img.half()

        # pre_pad
        if self.pre_pad != 0:
            self.img = F.pad(self.img, (0, self.pre_pad, 0, self.pre_pad), "reflect")
        # mod pad for divisible borders
        if self.scale == 2:
            self.mod_scale = 2
        elif self.scale == 1:
            self.mod_scale = 4
        if self.mod_scale is not None:
            self.mod_pad_h, self.mod_pad_w = 0, 0
            _, _, h, w = self.img.size()
            if h % self.mod_scale != 0:
                self.mod_pad_h = self.mod_scale - h % self.mod_scale
            if w % self.mod_scale != 0:
                self.mod_pad_w = self.mod_scale - w % self.mod_scale
            self.img = F.pad(
                self.img, (0, self.mod_pad_w, 0, self.mod_pad_h), "reflect"
            )
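
    # Worked example of the padding above (illustrative, with pre_pad=0): for
    # scale=2 the tensor must be divisible by mod_scale=2, so a 1x3x101x50
    # input gets mod_pad_h=1, mod_pad_w=0 and is reflect-padded to 102x50.
    # For scale=4, mod_scale stays None and no mod padding is applied.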

    def process(self):
        # model inference
        self.output = self.model(self.img)

    def tile_process(self):
        """It will first crop input images to tiles, and then process each tile.
        Finally, all the processed tiles are merged into one image.

        Modified from: https://github.com/ata4/esrgan-launcher
        """
        batch, channel, height, width = self.img.shape
        output_height = height * self.scale
        output_width = width * self.scale
        output_shape = (batch, channel, output_height, output_width)

        # start with black image
        self.output = self.img.new_zeros(output_shape)
        tiles_x = math.ceil(width / self.tile_size)
        tiles_y = math.ceil(height / self.tile_size)

        # loop over all tiles
        for y in range(tiles_y):
            for x in range(tiles_x):
                # extract tile from input image
                ofs_x = x * self.tile_size
                ofs_y = y * self.tile_size
                # input tile area on total image
                input_start_x = ofs_x
                input_end_x = min(ofs_x + self.tile_size, width)
                input_start_y = ofs_y
                input_end_y = min(ofs_y + self.tile_size, height)

                # input tile area on total image with padding
                input_start_x_pad = max(input_start_x - self.tile_pad, 0)
                input_end_x_pad = min(input_end_x + self.tile_pad, width)
                input_start_y_pad = max(input_start_y - self.tile_pad, 0)
                input_end_y_pad = min(input_end_y + self.tile_pad, height)

                # input tile dimensions
                input_tile_width = input_end_x - input_start_x
                input_tile_height = input_end_y - input_start_y
                tile_idx = y * tiles_x + x + 1
                input_tile = self.img[
                    :,
                    :,
                    input_start_y_pad:input_end_y_pad,
                    input_start_x_pad:input_end_x_pad,
                ]

                # upscale tile
                try:
                    with torch.no_grad():
                        output_tile = self.model(input_tile)
                except RuntimeError as error:
                    print("Error", error)
                    print(f"\tTile {tile_idx}/{tiles_x * tiles_y}")

                # output tile area on total image
                output_start_x = input_start_x * self.scale
                output_end_x = input_end_x * self.scale
                output_start_y = input_start_y * self.scale
                output_end_y = input_end_y * self.scale

                # output tile area without padding
                output_start_x_tile = (input_start_x - input_start_x_pad) * self.scale
                output_end_x_tile = output_start_x_tile + input_tile_width * self.scale
                output_start_y_tile = (input_start_y - input_start_y_pad) * self.scale
                output_end_y_tile = output_start_y_tile + input_tile_height * self.scale

                # put tile into output image
                self.output[
                    :, :, output_start_y:output_end_y, output_start_x:output_end_x
                ] = output_tile[
                    :,
                    :,
                    output_start_y_tile:output_end_y_tile,
                    output_start_x_tile:output_end_x_tile,
                ]
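
    # Worked example of the tile bookkeeping above (illustrative numbers):
    # with width=1000, tile_size=512, tile_pad=10, scale=4, the tile at x=1
    # covers input columns [512, 1000); the padded crop is [502, 1000). After
    # 4x upscaling, the unpadded region starts at (512 - 502) * 4 = 40 inside
    # the output tile and is written to columns [2048, 4000) of the output.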

    def post_process(self):
        # remove extra pad
        if self.mod_scale is not None:
            _, _, h, w = self.output.size()
            self.output = self.output[
                :,
                :,
                0 : h - self.mod_pad_h * self.scale,
                0 : w - self.mod_pad_w * self.scale,
            ]
        # remove pre-pad
        if self.pre_pad != 0:
            _, _, h, w = self.output.size()
            self.output = self.output[
                :,
                :,
                0 : h - self.pre_pad * self.scale,
                0 : w - self.pre_pad * self.scale,
            ]
        return self.output

    @torch.no_grad()
    def enhance(self, img, outscale=None, alpha_upsampler="realesrgan"):
        h_input, w_input = img.shape[0:2]
        # img: numpy
        img = img.astype(np.float32)
        if np.max(img) > 256:  # 16-bit image
            max_range = 65535
            print("\tInput is a 16-bit image")
        else:
            max_range = 255
        img = img / max_range
        if len(img.shape) == 2:  # gray image
            img_mode = "L"
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
        elif img.shape[2] == 4:  # RGBA image with alpha channel
            img_mode = "RGBA"
            alpha = img[:, :, 3]
            img = img[:, :, 0:3]
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            if alpha_upsampler == "realesrgan":
                alpha = cv2.cvtColor(alpha, cv2.COLOR_GRAY2RGB)
        else:
            img_mode = "RGB"
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        # ------------------- process image (without the alpha channel) ------------------- #
        self.pre_process(img)
        if self.tile_size > 0:
            self.tile_process()
        else:
            self.process()
        output_img = self.post_process()
        output_img = output_img.data.squeeze().float().cpu().clamp_(0, 1).numpy()
        output_img = np.transpose(output_img[[2, 1, 0], :, :], (1, 2, 0))
        if img_mode == "L":
            output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2GRAY)

        # ------------------- process the alpha channel if necessary ------------------- #
        if img_mode == "RGBA":
            if alpha_upsampler == "realesrgan":
                self.pre_process(alpha)
                if self.tile_size > 0:
                    self.tile_process()
                else:
                    self.process()
                output_alpha = self.post_process()
                output_alpha = (
                    output_alpha.data.squeeze().float().cpu().clamp_(0, 1).numpy()
                )
                output_alpha = np.transpose(output_alpha[[2, 1, 0], :, :], (1, 2, 0))
                output_alpha = cv2.cvtColor(output_alpha, cv2.COLOR_BGR2GRAY)
            else:  # use the cv2 resize for alpha channel
                h, w = alpha.shape[0:2]
                output_alpha = cv2.resize(
                    alpha,
                    (w * self.scale, h * self.scale),
                    interpolation=cv2.INTER_LINEAR,
                )

            # merge the alpha channel
            output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2BGRA)
            output_img[:, :, 3] = output_alpha

        # ------------------------------ return ------------------------------ #
        if max_range == 65535:  # 16-bit image
            output = (output_img * 65535.0).round().astype(np.uint16)
        else:
            output = (output_img * 255.0).round().astype(np.uint8)

        if outscale is not None and outscale != float(self.scale):
            output = cv2.resize(
                output,
                (
                    int(w_input * outscale),
                    int(h_input * outscale),
                ),
                interpolation=cv2.INTER_LANCZOS4,
            )

        return output, img_mode
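
# A minimal usage sketch for RealESRGANer on its own (the checkpoint path and
# input file are placeholders; RRDBNet is the vendored basicsr implementation
# imported in RealESRGANUpscaler._init_model below):
#
#   from sorawm.iopaint.plugins.basicsr import RRDBNet
#
#   net = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23,
#                 num_grow_ch=32, scale=4)
#   upsampler = RealESRGANer(scale=4, model_path="RealESRGAN_x4plus.pth",
#                            model=net, tile=512, tile_pad=10, pre_pad=10)
#   bgr = cv2.imread("input.png")  # enhance() expects a BGR numpy image
#   out, img_mode = upsampler.enhance(bgr, outscale=2)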

class SRVGGNetCompact(nn.Module):
    """A compact VGG-style network structure for super-resolution.

    It is a compact network structure that performs upsampling in the last layer, so no convolution
    is conducted on the high-resolution (HR) feature space.

    Args:
        num_in_ch (int): Channel number of inputs. Default: 3.
        num_out_ch (int): Channel number of outputs. Default: 3.
        num_feat (int): Channel number of intermediate features. Default: 64.
        num_conv (int): Number of convolution layers in the body network. Default: 16.
        upscale (int): Upsampling factor. Default: 4.
        act_type (str): Activation type, options: 'relu', 'prelu', 'leakyrelu'. Default: prelu.
    """

    def __init__(
        self,
        num_in_ch=3,
        num_out_ch=3,
        num_feat=64,
        num_conv=16,
        upscale=4,
        act_type="prelu",
    ):
        super(SRVGGNetCompact, self).__init__()
        self.num_in_ch = num_in_ch
        self.num_out_ch = num_out_ch
        self.num_feat = num_feat
        self.num_conv = num_conv
        self.upscale = upscale
        self.act_type = act_type

        self.body = nn.ModuleList()
        # the first conv
        self.body.append(nn.Conv2d(num_in_ch, num_feat, 3, 1, 1))
        # the first activation
        if act_type == "relu":
            activation = nn.ReLU(inplace=True)
        elif act_type == "prelu":
            activation = nn.PReLU(num_parameters=num_feat)
        elif act_type == "leakyrelu":
            activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
        self.body.append(activation)

        # the body structure
        for _ in range(num_conv):
            self.body.append(nn.Conv2d(num_feat, num_feat, 3, 1, 1))
            # activation
            if act_type == "relu":
                activation = nn.ReLU(inplace=True)
            elif act_type == "prelu":
                activation = nn.PReLU(num_parameters=num_feat)
            elif act_type == "leakyrelu":
                activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
            self.body.append(activation)

        # the last conv
        self.body.append(nn.Conv2d(num_feat, num_out_ch * upscale * upscale, 3, 1, 1))
        # upsample
        self.upsampler = nn.PixelShuffle(upscale)

    def forward(self, x):
        out = x
        for i in range(0, len(self.body)):
            out = self.body[i](out)

        out = self.upsampler(out)
        # add the nearest upsampled image, so that the network learns the residual
        base = F.interpolate(x, scale_factor=self.upscale, mode="nearest")
        out += base
        return out
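
# A quick shape sanity check for SRVGGNetCompact (illustrative; random weights):
#
#   net = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32,
#                         upscale=4, act_type="prelu")
#   x = torch.rand(1, 3, 64, 64)
#   assert net(x).shape == (1, 3, 256, 256)  # 48 = 3 * 4 * 4 channels -> PixelShuffle(4)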

class RealESRGANUpscaler(BasePlugin):
    name = "RealESRGAN"
    support_gen_image = True

    def __init__(self, name, device, no_half=False):
        super().__init__()
        self.model_name = name
        self.device = device
        self.no_half = no_half
        self._init_model(name)

    def _init_model(self, name):
        from .basicsr import RRDBNet

        REAL_ESRGAN_MODELS = {
            RealESRGANModel.realesr_general_x4v3: {
                "url": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth",
                "scale": 4,
                "model": lambda: SRVGGNetCompact(
                    num_in_ch=3,
                    num_out_ch=3,
                    num_feat=64,
                    num_conv=32,
                    upscale=4,
                    act_type="prelu",
                ),
                "model_md5": "91a7644643c884ee00737db24e478156",
            },
            RealESRGANModel.RealESRGAN_x4plus: {
                "url": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
                "scale": 4,
                "model": lambda: RRDBNet(
                    num_in_ch=3,
                    num_out_ch=3,
                    num_feat=64,
                    num_block=23,
                    num_grow_ch=32,
                    scale=4,
                ),
                "model_md5": "99ec365d4afad750833258a1a24f44ca",
            },
            RealESRGANModel.RealESRGAN_x4plus_anime_6B: {
                "url": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",
                "scale": 4,
                "model": lambda: RRDBNet(
                    num_in_ch=3,
                    num_out_ch=3,
                    num_feat=64,
                    num_block=6,
                    num_grow_ch=32,
                    scale=4,
                ),
                "model_md5": "d58ce384064ec1591c2ea7b79dbf47ba",
            },
        }

        if name not in REAL_ESRGAN_MODELS:
            raise ValueError(f"Unknown RealESRGAN model name: {name}")
        model_info = REAL_ESRGAN_MODELS[name]

        model_path = download_model(model_info["url"], model_info["model_md5"])
        logger.info(f"RealESRGAN model path: {model_path}")

        self.model = RealESRGANer(
            scale=model_info["scale"],
            model_path=model_path,
            model=model_info["model"](),
            half=True if "cuda" in str(self.device) and not self.no_half else False,
            tile=512,
            tile_pad=10,
            pre_pad=10,
            device=self.device,
        )

    def switch_model(self, new_model_name: str):
        if self.model_name == new_model_name:
            return
        self._init_model(new_model_name)
        self.model_name = new_model_name

    def gen_image(self, rgb_np_img, req: RunPluginRequest) -> np.ndarray:
        bgr_np_img = cv2.cvtColor(rgb_np_img, cv2.COLOR_RGB2BGR)
        logger.info(f"RealESRGAN input shape: {bgr_np_img.shape}, scale: {req.scale}")
        result = self.forward(bgr_np_img, req.scale)
        logger.info(f"RealESRGAN output shape: {result.shape}")
        return result

    @torch.inference_mode()
    def forward(self, bgr_np_img, scale: float):
        # the output is BGR
        upsampled = self.model.enhance(bgr_np_img, outscale=scale)[0]
        return upsampled
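

# A minimal sketch of driving the plugin end to end (how RunPluginRequest is
# constructed elsewhere in the app may differ; the fields shown here are
# assumptions, apart from `scale`, which gen_image reads above):
#
#   upscaler = RealESRGANUpscaler(
#       RealESRGANModel.RealESRGAN_x4plus, device=torch.device("cpu"), no_half=True
#   )
#   rgb = cv2.cvtColor(cv2.imread("input.png"), cv2.COLOR_BGR2RGB)
#   req = RunPluginRequest(name=RealESRGANUpscaler.name, image="...", scale=2.0)
#   out_bgr = upscaler.gen_image(rgb, req)  # returns a BGR numpy array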