  1. """
  2. AnyText: Multilingual Visual Text Generation And Editing
  3. Paper: https://arxiv.org/abs/2311.03054
  4. Code: https://github.com/tyxsspa/AnyText
  5. Copyright (c) Alibaba, Inc. and its affiliates.
  6. """
  7. import os
  8. from pathlib import Path
  9. from safetensors.torch import load_file
  10. from sorawm.iopaint.model.utils import set_seed
  11. os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
  12. import re
  13. import cv2
  14. import einops
  15. import numpy as np
  16. import torch
  17. from PIL import ImageFont
  18. from sorawm.iopaint.model.anytext.cldm.ddim_hacked import DDIMSampler
  19. from sorawm.iopaint.model.anytext.cldm.model import create_model, load_state_dict
  20. from sorawm.iopaint.model.anytext.utils import check_channels, draw_glyph, draw_glyph2
  21. BBOX_MAX_NUM = 8
  22. PLACE_HOLDER = "*"
  23. max_chars = 20
  24. ANYTEXT_CFG = os.path.join(
  25. os.path.dirname(os.path.abspath(__file__)), "anytext_sd15.yaml"
  26. )


def check_limits(tensor):
    float16_min = torch.finfo(torch.float16).min
    float16_max = torch.finfo(torch.float16).max
    # Check whether any value in the tensor falls below the float16 minimum
    # or above the float16 maximum.
    is_below_min = (tensor < float16_min).any()
    is_above_max = (tensor > float16_max).any()
    return is_below_min or is_above_max


class AnyTextPipeline:
    def __init__(self, ckpt_path, font_path, device, use_fp16=True):
        self.cfg_path = ANYTEXT_CFG
        self.font_path = font_path
        self.use_fp16 = use_fp16
        self.device = device

        self.font = ImageFont.truetype(font_path, size=60)
        self.model = create_model(
            self.cfg_path,
            device=self.device,
            use_fp16=self.use_fp16,
        )
        if self.use_fp16:
            self.model = self.model.half()
        if Path(ckpt_path).suffix == ".safetensors":
            state_dict = load_file(ckpt_path, device="cpu")
        else:
            state_dict = load_state_dict(ckpt_path, location="cpu")
        self.model.load_state_dict(state_dict, strict=False)
        self.model = self.model.eval().to(self.device)
        self.ddim_sampler = DDIMSampler(self.model, device=self.device)

    def __call__(
        self,
        prompt: str,
        negative_prompt: str,
        image: np.ndarray,
        masked_image: np.ndarray,
        num_inference_steps: int,
        strength: float,
        guidance_scale: float,
        height: int,
        width: int,
        seed: int,
        sort_priority: str = "y",
        callback=None,
    ):
        """
        Args:
            prompt: text prompt; each text line to render is wrapped in double quotes
            negative_prompt: negative text prompt
            image: reference image to edit (file path, np.ndarray or torch.Tensor)
            masked_image: position image marking where text is drawn
                (file path, np.ndarray or torch.Tensor)
            num_inference_steps: number of DDIM sampling steps
            strength: scale applied to the ControlNet control signals
            guidance_scale: classifier-free guidance scale
            height: output image height
            width: output image width
            seed: random seed
            sort_priority: x: left-right, y: top-down
            callback: optional per-step callback forwarded to the DDIM sampler
        Returns:
            result: list of images in numpy.ndarray format
            rst_code: 0: normal -1: error 1: warning
            rst_info: string of error or warning
        """
        set_seed(seed)
        str_warning = ""

        mode = "text-editing"
        revise_pos = False
        img_count = 1
        ddim_steps = num_inference_steps
        w = width
        h = height
        cfg_scale = guidance_scale
        eta = 0.0
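
        # Extract every double-quoted string from the prompt; each one becomes
        # a text line to render and is replaced in the prompt by PLACE_HOLDER
        # (see modify_prompt below).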
        prompt, texts = self.modify_prompt(prompt)
        if prompt is None and texts is None:
            return (
                None,
                -1,
                "You have input a Chinese prompt but the translator is not loaded!",
            )
        n_lines = len(texts)
        if mode in ["text-generation", "gen"]:
            edit_image = np.ones((h, w, 3)) * 127.5  # empty mask image
        elif mode in ["text-editing", "edit"]:
            if masked_image is None or image is None:
                return (
                    None,
                    -1,
                    "Reference image and position image are needed for text editing!",
                )
            if isinstance(image, str):
                image = cv2.imread(image)[..., ::-1]
                assert image is not None, f"Can't read ori_image from {image}!"
            elif isinstance(image, torch.Tensor):
                image = image.cpu().numpy()
            else:
                assert isinstance(
                    image, np.ndarray
                ), f"Unknown format of ori_image: {type(image)}"
            # clip to [1, 255] so no original pixel is exactly 0; only
            # masked-out pixels survive the position threshold below
            edit_image = image.clip(1, 255)
            edit_image = check_channels(edit_image)
            # edit_image = resize_image(
            #     edit_image, max_length=768
            # )  # make w, h multiples of 64, resize if w or h > max_length
            h, w = edit_image.shape[:2]  # change h, w by input ref_img
        # preprocess pos_imgs (if numpy, make sure it's white pos in black bg)
        if masked_image is None:
            pos_imgs = np.zeros((h, w, 1))
        elif isinstance(masked_image, str):
            masked_image = cv2.imread(masked_image)[..., ::-1]
            assert (
                masked_image is not None
            ), f"Can't read draw_pos image from {masked_image}!"
            pos_imgs = 255 - masked_image
        elif isinstance(masked_image, torch.Tensor):
            pos_imgs = masked_image.cpu().numpy()
        else:
            assert isinstance(
                masked_image, np.ndarray
            ), f"Unknown format of draw_pos: {type(masked_image)}"
            pos_imgs = 255 - masked_image
        pos_imgs = pos_imgs[..., 0:1]
        pos_imgs = cv2.convertScaleAbs(pos_imgs)
        # after the inversion above, only pixels that were exactly 0 in
        # masked_image survive this threshold as position pixels
        _, pos_imgs = cv2.threshold(pos_imgs, 254, 255, cv2.THRESH_BINARY)
        # separate pos_imgs
        pos_imgs = self.separate_pos_imgs(pos_imgs, sort_priority)
        if len(pos_imgs) == 0:
            pos_imgs = [np.zeros((h, w, 1))]
        if len(pos_imgs) < n_lines:
            if n_lines == 1 and texts[0] == " ":
                pass  # text-to-image without text
            else:
                raise RuntimeError(
                    f"{n_lines} text line(s) to draw from the prompt, but only {len(pos_imgs)} mask area(s) found on the image"
                )
        elif len(pos_imgs) > n_lines:
            str_warning = f"Warning: found {len(pos_imgs)} positions, more than the {n_lines} needed by the prompt."
        # get pre_pos, poly_list, hint that are needed for anytext
        pre_pos = []
        poly_list = []
        for input_pos in pos_imgs:
            if input_pos.mean() != 0:
                input_pos = (
                    input_pos[..., np.newaxis]
                    if len(input_pos.shape) == 2
                    else input_pos
                )
                poly, pos_img = self.find_polygon(input_pos)
                pre_pos += [pos_img / 255.0]
                poly_list += [poly]
            else:
                pre_pos += [np.zeros((h, w, 1))]
                poly_list += [None]
        np_hint = np.sum(pre_pos, axis=0).clip(0, 1)
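        # np_hint is the union of all position masks, clipped to [0, 1]; it is
        # used below both as the ControlNet hint and to blank out the text
        # regions before VAE-encoding the masked image.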

        # prepare info dict
        info = {}
        info["glyphs"] = []
        info["gly_line"] = []
        info["positions"] = []
        info["n_lines"] = [len(texts)] * img_count
        gly_pos_imgs = []
        for i in range(len(texts)):
            text = texts[i]
            if len(text) > max_chars:
                str_warning = (
                    f'"{text}" length > max_chars: {max_chars}, will be cut off...'
                )
                text = text[:max_chars]
            gly_scale = 2
            if pre_pos[i].mean() != 0:
                gly_line = draw_glyph(self.font, text)
                glyphs = draw_glyph2(
                    self.font,
                    text,
                    poly_list[i],
                    scale=gly_scale,
                    width=w,
                    height=h,
                    add_space=False,
                )
                gly_pos_img = cv2.drawContours(
                    glyphs * 255, [poly_list[i] * gly_scale], 0, (255, 255, 255), 1
                )
                if revise_pos:
                    resize_gly = cv2.resize(
                        glyphs, (pre_pos[i].shape[1], pre_pos[i].shape[0])
                    )
                    new_pos = cv2.morphologyEx(
                        (resize_gly * 255).astype(np.uint8),
                        cv2.MORPH_CLOSE,
                        kernel=np.ones(
                            (resize_gly.shape[0] // 10, resize_gly.shape[1] // 10),
                            dtype=np.uint8,
                        ),
                        iterations=1,
                    )
                    new_pos = (
                        new_pos[..., np.newaxis] if len(new_pos.shape) == 2 else new_pos
                    )
                    contours, _ = cv2.findContours(
                        new_pos, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE
                    )
                    if len(contours) != 1:
                        str_warning = f"Failed to revise position {i} to a bounding rect, position left unchanged..."
                    else:
                        rect = cv2.minAreaRect(contours[0])
                        # np.int0 was removed in NumPy 2.0; np.intp is the same type
                        poly = np.intp(cv2.boxPoints(rect))
                        pre_pos[i] = (
                            cv2.drawContours(new_pos, [poly], -1, 255, -1) / 255.0
                        )
                        gly_pos_img = cv2.drawContours(
                            glyphs * 255, [poly * gly_scale], 0, (255, 255, 255), 1
                        )
                gly_pos_imgs += [gly_pos_img]  # for show
            else:
                glyphs = np.zeros((h * gly_scale, w * gly_scale, 1))
                gly_line = np.zeros((80, 512, 1))
                gly_pos_imgs += [
                    np.zeros((h * gly_scale, w * gly_scale, 1))
                ]  # for show
            pos = pre_pos[i]
            info["glyphs"] += [self.arr2tensor(glyphs, img_count)]
            info["gly_line"] += [self.arr2tensor(gly_line, img_count)]
            info["positions"] += [self.arr2tensor(pos, img_count)]
        # get masked_x
        masked_img = ((edit_image.astype(np.float32) / 127.5) - 1.0) * (1 - np_hint)
        masked_img = np.transpose(masked_img, (2, 0, 1))
        masked_img = torch.from_numpy(masked_img.copy()).float().to(self.device)
        if self.use_fp16:
            masked_img = masked_img.half()
        encoder_posterior = self.model.encode_first_stage(masked_img[None, ...])
        masked_x = self.model.get_first_stage_encoding(encoder_posterior).detach()
        if self.use_fp16:
            masked_x = masked_x.half()
        info["masked_x"] = torch.cat([masked_x for _ in range(img_count)], dim=0)

        hint = self.arr2tensor(np_hint, img_count)
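
        # Conditional and unconditional branches share the same control inputs
        # (hint + text_info); only the cross-attention prompt differs, which is
        # what classifier-free guidance contrasts at sampling time.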
        cond = self.model.get_learned_conditioning(
            dict(
                c_concat=[hint],
                c_crossattn=[[prompt] * img_count],
                text_info=info,
            )
        )
        un_cond = self.model.get_learned_conditioning(
            dict(
                c_concat=[hint],
                c_crossattn=[[negative_prompt] * img_count],
                text_info=info,
            )
        )
        shape = (4, h // 8, w // 8)  # SD1.5 latent: 4 channels, 8x downsampled
        # one scale per ControlNet output (12 skip connections + middle block)
        self.model.control_scales = [strength] * 13
        samples, intermediates = self.ddim_sampler.sample(
            ddim_steps,
            img_count,
            shape,
            cond,
            verbose=False,
            eta=eta,
            unconditional_guidance_scale=cfg_scale,
            unconditional_conditioning=un_cond,
            callback=callback,
        )
        if self.use_fp16:
            samples = samples.half()
        x_samples = self.model.decode_first_stage(samples)
        x_samples = (
            (einops.rearrange(x_samples, "b c h w -> b h w c") * 127.5 + 127.5)
            .cpu()
            .numpy()
            .clip(0, 255)
            .astype(np.uint8)
        )
        results = [x_samples[i] for i in range(img_count)]
        # if (
        #     mode == "edit" and False
        # ):  # replace background in text editing, but not ideal yet
        #     results = [r * np_hint + edit_image * (1 - np_hint) for r in results]
        #     results = [r.clip(0, 255).astype(np.uint8) for r in results]
        # if len(gly_pos_imgs) > 0 and show_debug:
        #     glyph_bs = np.stack(gly_pos_imgs, axis=2)
        #     glyph_img = np.sum(glyph_bs, axis=2) * 255
        #     glyph_img = glyph_img.clip(0, 255).astype(np.uint8)
        #     results += [np.repeat(glyph_img, 3, axis=2)]
        rst_code = 1 if str_warning else 0
        return results, rst_code, str_warning

    def modify_prompt(self, prompt):
        prompt = prompt.replace("“", '"')
        prompt = prompt.replace("”", '"')
        p = '"(.*?)"'
        strs = re.findall(p, prompt)
        if len(strs) == 0:
            strs = [" "]
        else:
            for s in strs:
                prompt = prompt.replace(f'"{s}"', f" {PLACE_HOLDER} ", 1)
        # if self.is_chinese(prompt):
        #     if self.trans_pipe is None:
        #         return None, None
        #     old_prompt = prompt
        #     prompt = self.trans_pipe(input=prompt + " .")["translation"][:-1]
        #     print(f"Translate: {old_prompt} --> {prompt}")
        return prompt, strs
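
    # Example: modify_prompt('A sign that says "HELLO" and "WORLD"') returns
    # ('A sign that says  *  and  * ', ['HELLO', 'WORLD']).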

    # def is_chinese(self, text):
    #     text = checker._clean_text(text)
    #     for char in text:
    #         cp = ord(char)
    #         if checker._is_chinese_char(cp):
    #             return True
    #     return False

    def separate_pos_imgs(self, img, sort_priority, gap=102):
        num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(img)
        components = []
        for label in range(1, num_labels):
            component = np.zeros_like(img)
            component[labels == label] = 255
            components.append((component, centroids[label]))
        if sort_priority == "x":
            fir, sec = 0, 1  # left-right first
        else:
            fir, sec = 1, 0  # top-down first (default, sort_priority == "y")
        # bucket centroids into bands of `gap` pixels so components roughly on
        # the same row (or column) keep a stable reading order
        components.sort(key=lambda c: (c[1][fir] // gap, c[1][sec] // gap))
        sorted_components = [c[0] for c in components]
        return sorted_components

    def find_polygon(self, image, min_rect=False):
        contours, hierarchy = cv2.findContours(
            image, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE
        )
        max_contour = max(contours, key=cv2.contourArea)  # get contour with max area
        if min_rect:
            # get the minimum enclosing rectangle
            rect = cv2.minAreaRect(max_contour)
            # np.int0 was removed in NumPy 2.0; np.intp is the same type
            poly = np.intp(cv2.boxPoints(rect))
        else:
            # get an approximate polygon
            epsilon = 0.01 * cv2.arcLength(max_contour, True)
            poly = cv2.approxPolyDP(max_contour, epsilon, True)
            n, _, xy = poly.shape
            poly = poly.reshape(n, xy)
        cv2.drawContours(image, [poly], -1, 255, -1)
        return poly, image

    def arr2tensor(self, arr, bs):
        arr = np.transpose(arr, (2, 0, 1))
        _arr = torch.from_numpy(arr.copy()).float().to(self.device)
        if self.use_fp16:
            _arr = _arr.half()
        _arr = torch.stack([_arr for _ in range(bs)], dim=0)
        return _arr
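

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the upstream pipeline). All paths below
# are placeholders; substitute your own checkpoint, font, and image. The mask
# convention follows the code above: pixels that are exactly 0 in
# `masked_image` become text positions after the 255-inversion and threshold.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    pipe = AnyTextPipeline(
        ckpt_path="./models/anytext_v1.1.ckpt",  # placeholder checkpoint path
        font_path="./fonts/Arial_Unicode.ttf",  # placeholder font path
        device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
        use_fp16=torch.cuda.is_available(),
    )
    # Placeholder reference image (must be larger than the rectangle below).
    ref = cv2.imread("./example_ref.png")[..., ::-1]  # BGR -> RGB
    # Build the position image: clip so no original pixel is exactly 0, then
    # black out the rectangle where the quoted text should be rendered.
    masked = ref.clip(1, 255)
    masked[100:160, 80:300] = 0
    results, rst_code, rst_info = pipe(
        prompt='photo of a street sign that says "HELLO"',
        negative_prompt="low quality, watermark, blurry",
        image=ref,
        masked_image=masked,
        num_inference_steps=20,
        strength=1.0,
        guidance_scale=9.0,
        height=ref.shape[0],
        width=ref.shape[1],
        seed=42,
    )
    if rst_code >= 0:
        for idx, res in enumerate(results):
            cv2.imwrite(f"anytext_result_{idx}.png", res[..., ::-1])  # RGB -> BGR
        if rst_info:
            print(f"warning: {rst_info}")
    else:
        print(f"error: {rst_info}")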