import copy
import random
from typing import Any, List, Union

from transformers import CLIPTokenizer

from sorawm.iopaint.schema import PowerPaintTask

def add_task_to_prompt(prompt, negative_prompt, task: PowerPaintTask):
    """Append the PowerPaint task tokens to the positive and negative prompts
    for the two text branches (A and B) used by the model."""
    if task == PowerPaintTask.object_remove:
        promptA = prompt + " P_ctxt"
        promptB = prompt + " P_ctxt"
        negative_promptA = negative_prompt + " P_obj"
        negative_promptB = negative_prompt + " P_obj"
    elif task == PowerPaintTask.context_aware:
        promptA = prompt + " P_ctxt"
        promptB = prompt + " P_ctxt"
        negative_promptA = negative_prompt
        negative_promptB = negative_prompt
    elif task == PowerPaintTask.shape_guided:
        promptA = prompt + " P_shape"
        promptB = prompt + " P_ctxt"
        negative_promptA = negative_prompt
        negative_promptB = negative_prompt
    elif task == PowerPaintTask.outpainting:
        promptA = prompt + " P_ctxt"
        promptB = prompt + " P_ctxt"
        negative_promptA = negative_prompt + " P_obj"
        negative_promptB = negative_prompt + " P_obj"
    else:
        promptA = prompt + " P_obj"
        promptB = prompt + " P_obj"
        negative_promptA = negative_prompt
        negative_promptB = negative_prompt
    return promptA, promptB, negative_promptA, negative_promptB

def task_to_prompt(task: PowerPaintTask):
    promptA, promptB, negative_promptA, negative_promptB = add_task_to_prompt(
        "", "", task
    )
    return (
        promptA.strip(),
        promptB.strip(),
        negative_promptA.strip(),
        negative_promptB.strip(),
    )
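
# Usage sketch (hedged): for object removal, both branches get the context
# token while the object token goes into the negative prompts, e.g.
#   task_to_prompt(PowerPaintTask.object_remove)
#   # -> ("P_ctxt", "P_ctxt", "P_obj", "P_obj")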

class PowerPaintTokenizer:
    """Wrap a CLIPTokenizer so that each PowerPaint placeholder token
    ("P_ctxt", "P_shape", "P_obj") expands to several numbered sub-tokens
    before tokenization."""

    def __init__(self, tokenizer: CLIPTokenizer):
        self.wrapped = tokenizer
        self.token_map = {}
        placeholder_tokens = ["P_ctxt", "P_shape", "P_obj"]
        num_vec_per_token = 10
        for placeholder_token in placeholder_tokens:
            output = []
            for i in range(num_vec_per_token):
                ith_token = placeholder_token + f"_{i}"
                output.append(ith_token)
            self.token_map[placeholder_token] = output
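
    # Illustration (hedged sketch): after construction, `token_map` holds ten
    # numbered sub-tokens per placeholder, e.g.
    #   self.token_map["P_ctxt"]
    #   # -> ["P_ctxt_0", "P_ctxt_1", ..., "P_ctxt_9"]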

    def __getattr__(self, name: str) -> Any:
        # Only reached when normal attribute lookup fails, e.g. before
        # `wrapped` has been set in `__init__`.
        if name == "wrapped":
            raise AttributeError(
                f"'{self.__class__.__name__}' object has no attribute 'wrapped'"
            )
        try:
            return getattr(self.wrapped, name)
        except AttributeError:
            raise AttributeError(
                f"'{name}' cannot be found in either "
                f"'{self.__class__.__name__}' or "
                f"'{self.__class__.__name__}.wrapped'."
            )

    def try_adding_tokens(self, tokens: Union[str, List[str]], *args, **kwargs):
        """Attempt to add tokens to the wrapped tokenizer.

        Args:
            tokens (Union[str, List[str]]): The tokens to be added.
        """
        num_added_tokens = self.wrapped.add_tokens(tokens, *args, **kwargs)
        assert num_added_tokens != 0, (
            f"The tokenizer already contains the token {tokens}. Please pass "
            "a different `placeholder_token` that is not already in the "
            "tokenizer."
        )

    def get_token_info(self, token: str) -> dict:
        """Get the information of a token, including its start and end index
        in the current tokenizer.

        Args:
            token (str): The token to be queried.

        Returns:
            dict: The information of the token, including its start and end
                index in the current tokenizer.
        """
        token_ids = self.__call__(token).input_ids
        # Skip the BOS token at [0] and the EOS token at [-1] that the CLIP
        # tokenizer adds around the input.
        start, end = token_ids[1], token_ids[-2] + 1
        return {"name": token, "start": start, "end": end}

    def add_placeholder_token(
        self, placeholder_token: str, *args, num_vec_per_token: int = 1, **kwargs
    ):
        """Add placeholder tokens to the tokenizer.

        Args:
            placeholder_token (str): The placeholder token to be added.
            num_vec_per_token (int, optional): The number of vectors for
                the added placeholder token. Defaults to 1.
            *args, **kwargs: The arguments for `self.wrapped.add_tokens`.
        """
        # Reject placeholders that overlap an existing one before mutating the
        # vocabulary, since substring matches would make replacement ambiguous.
        for token in self.token_map:
            if token in placeholder_token:
                raise ValueError(
                    f"The tokenizer already has placeholder token {token} "
                    f"that can get confused with {placeholder_token}; "
                    "keep placeholder tokens independent."
                )
        output = []
        if num_vec_per_token == 1:
            self.try_adding_tokens(placeholder_token, *args, **kwargs)
            output.append(placeholder_token)
        else:
            for i in range(num_vec_per_token):
                ith_token = placeholder_token + f"_{i}"
                self.try_adding_tokens(ith_token, *args, **kwargs)
                output.append(ith_token)
        self.token_map[placeholder_token] = output
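
    # Usage sketch (hypothetical token name, not part of PowerPaint):
    #   tok.add_placeholder_token("P_style", num_vec_per_token=4)
    #   # registers P_style_0 .. P_style_3 in the wrapped vocab and token_map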

    def replace_placeholder_tokens_in_text(
        self,
        text: Union[str, List[str]],
        vector_shuffle: bool = False,
        prop_tokens_to_load: float = 1.0,
    ) -> Union[str, List[str]]:
        """Expand each placeholder keyword in text into its numbered
        sub-tokens. This function will be called in `self.__call__` and
        `self.encode`.

        Args:
            text (Union[str, List[str]]): The text to be processed.
            vector_shuffle (bool, optional): Whether to shuffle the vectors.
                Defaults to False.
            prop_tokens_to_load (float, optional): The proportion of tokens to
                be loaded. If 1.0, all tokens will be loaded. Defaults to 1.0.

        Returns:
            Union[str, List[str]]: The processed text.
        """
        if isinstance(text, list):
            return [
                self.replace_placeholder_tokens_in_text(
                    t,
                    vector_shuffle=vector_shuffle,
                    prop_tokens_to_load=prop_tokens_to_load,
                )
                for t in text
            ]
        for placeholder_token in self.token_map:
            if placeholder_token in text:
                tokens = self.token_map[placeholder_token]
                tokens = tokens[: 1 + int(len(tokens) * prop_tokens_to_load)]
                if vector_shuffle:
                    tokens = copy.copy(tokens)
                    random.shuffle(tokens)
                text = text.replace(placeholder_token, " ".join(tokens))
        return text
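
    # Example (hedged): with the default map from __init__,
    #   tok.replace_placeholder_tokens_in_text("a photo of P_obj")
    #   # -> "a photo of P_obj_0 P_obj_1 ... P_obj_9"
    # With vector_shuffle=True the sub-tokens appear in a random order.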

    def replace_text_with_placeholder_tokens(
        self, text: Union[str, List[str]]
    ) -> Union[str, List[str]]:
        """Replace runs of numbered sub-tokens in text with their original
        placeholder tokens. This function will be called in `self.decode`.

        Args:
            text (Union[str, List[str]]): The text to be processed.

        Returns:
            Union[str, List[str]]: The processed text.
        """
        if isinstance(text, list):
            return [self.replace_text_with_placeholder_tokens(t) for t in text]
        for placeholder_token, tokens in self.token_map.items():
            merged_tokens = " ".join(tokens)
            if merged_tokens in text:
                text = text.replace(merged_tokens, placeholder_token)
        return text
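
    # Example (hedged): the inverse of the expansion above; the full run of
    # sub-tokens collapses back to its placeholder, e.g.
    #   tok.replace_text_with_placeholder_tokens(
    #       "a photo of " + " ".join(tok.token_map["P_obj"])
    #   )
    #   # -> "a photo of P_obj"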

    def __call__(
        self,
        text: Union[str, List[str]],
        *args,
        vector_shuffle: bool = False,
        prop_tokens_to_load: float = 1.0,
        **kwargs,
    ):
        """The call function of the wrapper.

        Args:
            text (Union[str, List[str]]): The text to be tokenized.
            vector_shuffle (bool, optional): Whether to shuffle the vectors.
                Defaults to False.
            prop_tokens_to_load (float, optional): The proportion of tokens to
                be loaded. If 1.0, all tokens will be loaded. Defaults to 1.0.
            *args, **kwargs: The arguments for `self.wrapped.__call__`.
        """
        replaced_text = self.replace_placeholder_tokens_in_text(
            text, vector_shuffle=vector_shuffle, prop_tokens_to_load=prop_tokens_to_load
        )
        return self.wrapped.__call__(replaced_text, *args, **kwargs)

    def encode(self, text: Union[str, List[str]], *args, **kwargs):
        """Encode the passed text into token indices.

        Args:
            text (Union[str, List[str]]): The text to be encoded.
            *args, **kwargs: The arguments for `self.wrapped.__call__`.
        """
        replaced_text = self.replace_placeholder_tokens_in_text(text)
        return self.wrapped(replaced_text, *args, **kwargs)

    def decode(
        self, token_ids, return_raw: bool = False, *args, **kwargs
    ) -> Union[str, List[str]]:
        """Decode token indices into text.

        Args:
            token_ids: The token indices to be decoded.
            return_raw (bool, optional): Whether to keep the numbered
                sub-tokens in the text. Defaults to False.
            *args, **kwargs: The arguments for `self.wrapped.decode`.

        Returns:
            Union[str, List[str]]: The decoded text.
        """
        text = self.wrapped.decode(token_ids, *args, **kwargs)
        if return_raw:
            return text
        return self.replace_text_with_placeholder_tokens(text)
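

if __name__ == "__main__":
    # Smoke test: a minimal hedged sketch. The checkpoint name below is an
    # assumption (the standard CLIP tokenizer on the Hugging Face hub), not
    # something this module pins; substitute a local path if needed.
    base = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
    tok = PowerPaintTokenizer(base)

    # Task tokens derived by task_to_prompt for object removal.
    print(task_to_prompt(PowerPaintTask.object_remove))
    # -> ("P_ctxt", "P_ctxt", "P_obj", "P_obj")

    # __call__ expands "P_ctxt" into its ten sub-tokens before delegating to
    # the wrapped CLIP tokenizer.
    out = tok("a beach P_ctxt", padding="max_length", truncation=True)
    print(len(out.input_ids))  # model_max_length of the wrapped tokenizer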