anytext_model.py 2.5 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273
  1. import torch
  2. from huggingface_hub import hf_hub_download
  3. from sorawm.iopaint.const import ANYTEXT_NAME
  4. from sorawm.iopaint.model.anytext.anytext_pipeline import AnyTextPipeline
  5. from sorawm.iopaint.model.base import DiffusionInpaintModel
  6. from sorawm.iopaint.model.utils import get_torch_dtype, is_local_files_only
  7. from sorawm.iopaint.schema import InpaintRequest
  8. class AnyText(DiffusionInpaintModel):
  9. name = ANYTEXT_NAME
  10. pad_mod = 64
  11. is_erase_model = False
  12. @staticmethod
  13. def download(local_files_only=False):
  14. hf_hub_download(
  15. repo_id=ANYTEXT_NAME,
  16. filename="model_index.json",
  17. local_files_only=local_files_only,
  18. )
  19. ckpt_path = hf_hub_download(
  20. repo_id=ANYTEXT_NAME,
  21. filename="pytorch_model.fp16.safetensors",
  22. local_files_only=local_files_only,
  23. )
  24. font_path = hf_hub_download(
  25. repo_id=ANYTEXT_NAME,
  26. filename="SourceHanSansSC-Medium.otf",
  27. local_files_only=local_files_only,
  28. )
  29. return ckpt_path, font_path
  30. def init_model(self, device, **kwargs):
  31. local_files_only = is_local_files_only(**kwargs)
  32. ckpt_path, font_path = self.download(local_files_only)
  33. use_gpu, torch_dtype = get_torch_dtype(device, kwargs.get("no_half", False))
  34. self.model = AnyTextPipeline(
  35. ckpt_path=ckpt_path,
  36. font_path=font_path,
  37. device=device,
  38. use_fp16=torch_dtype == torch.float16,
  39. )
  40. self.callback = kwargs.pop("callback", None)
  41. def forward(self, image, mask, config: InpaintRequest):
  42. """Input image and output image have same size
  43. image: [H, W, C] RGB
  44. mask: [H, W, 1] 255 means area to inpainting
  45. return: BGR IMAGE
  46. """
  47. height, width = image.shape[:2]
  48. mask = mask.astype("float32") / 255.0
  49. masked_image = image * (1 - mask)
  50. # list of rgb ndarray
  51. results, rtn_code, rtn_warning = self.model(
  52. image=image,
  53. masked_image=masked_image,
  54. prompt=config.prompt,
  55. negative_prompt=config.negative_prompt,
  56. num_inference_steps=config.sd_steps,
  57. strength=config.sd_strength,
  58. guidance_scale=config.sd_guidance_scale,
  59. height=height,
  60. width=width,
  61. seed=config.sd_seed,
  62. sort_priority="y",
  63. callback=self.callback,
  64. )
  65. inpainted_rgb_image = results[0][..., ::-1]
  66. return inpainted_rgb_image