powerpaint_tokenizer.py

import copy
import random
from typing import Any, List, Union

from transformers import CLIPTokenizer

from sorawm.iopaint.schema import PowerPaintTask

def add_task_to_prompt(prompt, negative_prompt, task: PowerPaintTask):
    if task == PowerPaintTask.object_remove:
        promptA = prompt + " P_ctxt"
        promptB = prompt + " P_ctxt"
        negative_promptA = negative_prompt + " P_obj"
        negative_promptB = negative_prompt + " P_obj"
    elif task == PowerPaintTask.context_aware:
        promptA = prompt + " P_ctxt"
        promptB = prompt + " P_ctxt"
        negative_promptA = negative_prompt
        negative_promptB = negative_prompt
    elif task == PowerPaintTask.shape_guided:
        promptA = prompt + " P_shape"
        promptB = prompt + " P_ctxt"
        negative_promptA = negative_prompt
        negative_promptB = negative_prompt
    elif task == PowerPaintTask.outpainting:
        promptA = prompt + " P_ctxt"
        promptB = prompt + " P_ctxt"
        negative_promptA = negative_prompt + " P_obj"
        negative_promptB = negative_prompt + " P_obj"
    else:
        promptA = prompt + " P_obj"
        promptB = prompt + " P_obj"
        negative_promptA = negative_prompt
        negative_promptB = negative_prompt
    return promptA, promptB, negative_promptA, negative_promptB

def task_to_prompt(task: PowerPaintTask):
    promptA, promptB, negative_promptA, negative_promptB = add_task_to_prompt(
        "", "", task
    )
    return (
        promptA.strip(),
        promptB.strip(),
        negative_promptA.strip(),
        negative_promptB.strip(),
    )
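
# Illustrative note (not part of the original module): with empty prompts,
# task_to_prompt reduces to the stripped task keywords, e.g.
#   task_to_prompt(PowerPaintTask.object_remove) -> ("P_ctxt", "P_ctxt", "P_obj", "P_obj")
#   task_to_prompt(PowerPaintTask.shape_guided)  -> ("P_shape", "P_ctxt", "", "")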

class PowerPaintTokenizer:
    def __init__(self, tokenizer: CLIPTokenizer):
        self.wrapped = tokenizer
        self.token_map = {}
        placeholder_tokens = ["P_ctxt", "P_shape", "P_obj"]
        num_vec_per_token = 10
        for placeholder_token in placeholder_tokens:
            output = []
            for i in range(num_vec_per_token):
                ith_token = placeholder_token + f"_{i}"
                output.append(ith_token)
            self.token_map[placeholder_token] = output
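        # Descriptive note (not in the original code): after this loop,
        # self.token_map expands each placeholder keyword into ten per-vector
        # tokens, e.g. "P_ctxt" -> ["P_ctxt_0", "P_ctxt_1", ..., "P_ctxt_9"].
        # The constructor only builds the mapping; the tokens themselves must
        # already exist in (or be added to) the wrapped tokenizer's vocabulary,
        # e.g. via add_placeholder_token or a checkpoint tokenizer that ships
        # with them.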

    def __getattr__(self, name: str) -> Any:
        if name == "wrapped":
            # Fall back to the default lookup for "wrapped" itself to avoid
            # infinite recursion before __init__ has set the attribute.
            return super().__getattr__("wrapped")
        try:
            return getattr(self.wrapped, name)
        except AttributeError:
            try:
                return super().__getattr__(name)
            except AttributeError:
                raise AttributeError(
                    f"'{name}' cannot be found in both "
                    f"'{self.__class__.__name__}' and "
                    f"'{self.__class__.__name__}.tokenizer'."
                )
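        # Descriptive note (not in the original code): any other attribute is
        # delegated to the wrapped tokenizer, so e.g. tokenizer.model_max_length
        # resolves to the underlying CLIPTokenizer value.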

    def try_adding_tokens(self, tokens: Union[str, List[str]], *args, **kwargs):
        """Attempt to add tokens to the tokenizer.

        Args:
            tokens (Union[str, List[str]]): The tokens to be added.
        """
        num_added_tokens = self.wrapped.add_tokens(tokens, *args, **kwargs)
        assert num_added_tokens != 0, (
            f"The tokenizer already contains the token {tokens}. Please pass "
            "a different `placeholder_token` that is not already in the "
            "tokenizer."
        )

    def get_token_info(self, token: str) -> dict:
        """Get the information of a token, including its start and end index in
        the current tokenizer.

        Args:
            token (str): The token to be queried.

        Returns:
            dict: The information of the token, including its start and end
                index in the current tokenizer.
        """
        # token_ids[0] and token_ids[-1] are the begin/end-of-text ids added by
        # the wrapped CLIP tokenizer, so they are sliced off to get the id
        # range of the token itself.
        token_ids = self.__call__(token).input_ids
        start, end = token_ids[1], token_ids[-2] + 1
        return {"name": token, "start": start, "end": end}

    def add_placeholder_token(
        self, placeholder_token: str, *args, num_vec_per_token: int = 1, **kwargs
    ):
        """Add placeholder tokens to the tokenizer.

        Args:
            placeholder_token (str): The placeholder token to be added.
            num_vec_per_token (int, optional): The number of vectors of
                the added placeholder token.
            *args, **kwargs: The arguments for `self.wrapped.add_tokens`.
        """
        output = []
        if num_vec_per_token == 1:
            self.try_adding_tokens(placeholder_token, *args, **kwargs)
            output.append(placeholder_token)
        else:
            for i in range(num_vec_per_token):
                ith_token = placeholder_token + f"_{i}"
                self.try_adding_tokens(ith_token, *args, **kwargs)
                output.append(ith_token)

        for token in self.token_map:
            if token in placeholder_token:
                raise ValueError(
                    f"The tokenizer already has placeholder token {token} "
                    f"that can get confused with {placeholder_token}; "
                    "keep placeholder tokens independent."
                )
        self.token_map[placeholder_token] = output
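        # Illustrative usage (not part of the original module; "P_new" is a
        # made-up placeholder):
        #   tokenizer.add_placeholder_token("P_new", num_vec_per_token=4)
        # registers "P_new_0" .. "P_new_3" with the wrapped tokenizer and maps
        # "P_new" -> ["P_new_0", ..., "P_new_3"] in self.token_map.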

    def replace_placeholder_tokens_in_text(
        self,
        text: Union[str, List[str]],
        vector_shuffle: bool = False,
        prop_tokens_to_load: float = 1.0,
    ) -> Union[str, List[str]]:
        """Replace the keywords in text with placeholder tokens. This function
        will be called in `self.__call__` and `self.encode`.

        Args:
            text (Union[str, List[str]]): The text to be processed.
            vector_shuffle (bool, optional): Whether to shuffle the vectors.
                Defaults to False.
            prop_tokens_to_load (float, optional): The proportion of tokens to
                be loaded. If 1.0, all tokens will be loaded. Defaults to 1.0.

        Returns:
            Union[str, List[str]]: The processed text.
        """
        if isinstance(text, list):
            output = []
            for i in range(len(text)):
                output.append(
                    self.replace_placeholder_tokens_in_text(
                        text[i],
                        vector_shuffle=vector_shuffle,
                        prop_tokens_to_load=prop_tokens_to_load,
                    )
                )
            return output

        for placeholder_token in self.token_map:
            if placeholder_token in text:
                tokens = self.token_map[placeholder_token]
                tokens = tokens[: 1 + int(len(tokens) * prop_tokens_to_load)]
                if vector_shuffle:
                    tokens = copy.copy(tokens)
                    random.shuffle(tokens)
                text = text.replace(placeholder_token, " ".join(tokens))
        return text
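        # Example (descriptive, not in the original code): with the default
        # token_map, "a cat P_obj" becomes
        # "a cat P_obj_0 P_obj_1 P_obj_2 P_obj_3 P_obj_4 P_obj_5 P_obj_6 P_obj_7 P_obj_8 P_obj_9".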

    def replace_text_with_placeholder_tokens(
        self, text: Union[str, List[str]]
    ) -> Union[str, List[str]]:
        """Replace the placeholder tokens in text with the original keywords.
        This function will be called in `self.decode`.

        Args:
            text (Union[str, List[str]]): The text to be processed.

        Returns:
            Union[str, List[str]]: The processed text.
        """
        if isinstance(text, list):
            output = []
            for i in range(len(text)):
                output.append(self.replace_text_with_placeholder_tokens(text[i]))
            return output

        for placeholder_token, tokens in self.token_map.items():
            merged_tokens = " ".join(tokens)
            if merged_tokens in text:
                text = text.replace(merged_tokens, placeholder_token)
        return text
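        # Descriptive note (not in the original code): this is the inverse of
        # replace_placeholder_tokens_in_text, collapsing an expanded run such as
        # "P_obj_0 P_obj_1 ... P_obj_9" back to the single keyword "P_obj".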

    def __call__(
        self,
        text: Union[str, List[str]],
        *args,
        vector_shuffle: bool = False,
        prop_tokens_to_load: float = 1.0,
        **kwargs,
    ):
        """The call function of the wrapper.

        Args:
            text (Union[str, List[str]]): The text to be tokenized.
            vector_shuffle (bool, optional): Whether to shuffle the vectors.
                Defaults to False.
            prop_tokens_to_load (float, optional): The proportion of tokens to
                be loaded. If 1.0, all tokens will be loaded. Defaults to 1.0.
            *args, **kwargs: The arguments for `self.wrapped.__call__`.
        """
        replaced_text = self.replace_placeholder_tokens_in_text(
            text, vector_shuffle=vector_shuffle, prop_tokens_to_load=prop_tokens_to_load
        )
        return self.wrapped.__call__(replaced_text, *args, **kwargs)

    def encode(self, text: Union[str, List[str]], *args, **kwargs):
        """Encode the passed text to token indices.

        Args:
            text (Union[str, List[str]]): The text to be encoded.
            *args, **kwargs: The arguments for `self.wrapped.__call__`.
        """
        replaced_text = self.replace_placeholder_tokens_in_text(text)
        return self.wrapped(replaced_text, *args, **kwargs)

    def decode(
        self, token_ids, return_raw: bool = False, *args, **kwargs
    ) -> Union[str, List[str]]:
        """Decode token indices to text.

        Args:
            token_ids: The token indices to be decoded.
            return_raw: Whether to keep the placeholder tokens in the text.
                Defaults to False.
            *args, **kwargs: The arguments for `self.wrapped.decode`.

        Returns:
            Union[str, List[str]]: The decoded text.
        """
        text = self.wrapped.decode(token_ids, *args, **kwargs)
        if return_raw:
            return text
        replaced_text = self.replace_text_with_placeholder_tokens(text)
        return replaced_text
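
# Illustrative usage sketch (not part of the original module). It assumes a
# CLIP tokenizer whose vocabulary already contains the PowerPaint placeholder
# tokens P_ctxt_0..P_ctxt_9, P_shape_0..P_shape_9 and P_obj_0..P_obj_9; the
# checkpoint path below is hypothetical.
#
#   from transformers import CLIPTokenizer
#
#   base = CLIPTokenizer.from_pretrained("path/to/powerpaint", subfolder="tokenizer")
#   tokenizer = PowerPaintTokenizer(base)
#   promptA, promptB, neg_A, neg_B = task_to_prompt(PowerPaintTask.object_remove)
#   ids_A = tokenizer(promptA, padding="max_length", truncation=True,
#                     return_tensors="pt").input_ids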