# pipeline_brushnet.py
# https://github.com/TencentARC/BrushNet
import inspect
from typing import Any, Callable, Dict, List, Optional, Union

import numpy as np
import PIL.Image
import torch
import torch.nn.functional as F
from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
from diffusers.loaders import (
    FromSingleFileMixin,
    IPAdapterMixin,
    LoraLoaderMixin,
    TextualInversionLoaderMixin,
)
from diffusers.models import AutoencoderKL, ImageProjection, UNet2DConditionModel
from diffusers.models.lora import adjust_lora_scale_text_encoder
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from diffusers.pipelines.stable_diffusion.pipeline_output import (
    StableDiffusionPipelineOutput,
)
from diffusers.pipelines.stable_diffusion.safety_checker import (
    StableDiffusionSafetyChecker,
)
from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.utils import (
    USE_PEFT_BACKEND,
    deprecate,
    logging,
    replace_example_docstring,
    scale_lora_layers,
    unscale_lora_layers,
)
from diffusers.utils.torch_utils import (
    is_compiled_module,
    is_torch_version,
    randn_tensor,
)
from transformers import (
    CLIPImageProcessor,
    CLIPTextModel,
    CLIPTokenizer,
    CLIPVisionModelWithProjection,
)

from .brushnet import BrushNetModel

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

EXAMPLE_DOC_STRING = """
    Examples:
        ```py
        from diffusers import StableDiffusionBrushNetPipeline, BrushNetModel, UniPCMultistepScheduler
        from diffusers.utils import load_image
        import torch
        import cv2
        import numpy as np
        from PIL import Image

        base_model_path = "runwayml/stable-diffusion-v1-5"
        brushnet_path = "ckpt_path"

        brushnet = BrushNetModel.from_pretrained(brushnet_path, torch_dtype=torch.float16)
        pipe = StableDiffusionBrushNetPipeline.from_pretrained(
            base_model_path, brushnet=brushnet, torch_dtype=torch.float16, low_cpu_mem_usage=False
        )

        # speed up the diffusion process with a faster scheduler and memory optimization
        pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
        # remove the following line if xformers is not installed or when using Torch 2.0.
        # pipe.enable_xformers_memory_efficient_attention()
        # memory optimization
        pipe.enable_model_cpu_offload()

        image_path = "examples/brushnet/src/test_image.jpg"
        mask_path = "examples/brushnet/src/test_mask.jpg"
        caption = "A cake on the table."

        init_image = cv2.imread(image_path)
        mask_image = 1.0 * (cv2.imread(mask_path).sum(-1) > 255)[:, :, np.newaxis]
        init_image = init_image * (1 - mask_image)

        init_image = Image.fromarray(init_image.astype(np.uint8)).convert("RGB")
        mask_image = Image.fromarray(mask_image.astype(np.uint8).repeat(3, -1) * 255).convert("RGB")

        generator = torch.Generator("cuda").manual_seed(1234)

        image = pipe(
            caption,
            init_image,
            mask_image,
            num_inference_steps=50,
            generator=generator,
            brushnet_conditioning_scale=1.0,
        ).images[0]
        image.save("output.png")
        ```
"""
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    **kwargs,
):
    """
    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. If used,
            `timesteps` must be `None`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        timesteps (`List[int]`, *optional*):
            Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
            timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
            must be `None`.

    Returns:
        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and
        the second element is the number of inference steps.
    """
    if timesteps is not None:
        accepts_timesteps = "timesteps" in set(
            inspect.signature(scheduler.set_timesteps).parameters.keys()
        )
        if not accepts_timesteps:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    else:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        timesteps = scheduler.timesteps
    return timesteps, num_inference_steps
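
# Illustrative sketch of how ``retrieve_timesteps`` is used further below. ``custom_schedule``
# is a hypothetical example value, not something defined in this module, and the second call
# only works for schedulers whose ``set_timesteps`` accepts a ``timesteps`` argument:
#
#     # default scheduler spacing for 50 steps
#     timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, 50, device)
#     # caller-provided spacing, e.g. custom_schedule = [999, 749, 499, 249, 0]
#     timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, None, device, custom_schedule)
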
class StableDiffusionBrushNetPipeline(
    DiffusionPipeline,
    StableDiffusionMixin,
    TextualInversionLoaderMixin,
    LoraLoaderMixin,
    IPAdapterMixin,
    FromSingleFileMixin,
):
    r"""
    Pipeline for text-to-image generation using Stable Diffusion with BrushNet guidance.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
    implemented for all pipelines (downloading, saving, running on a particular device, etc.).

    The pipeline also inherits the following loading methods:
        - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
        - [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
        - [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
        - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
        - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters

    Args:
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
        text_encoder ([`~transformers.CLIPTextModel`]):
            Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
        tokenizer ([`~transformers.CLIPTokenizer`]):
            A `CLIPTokenizer` to tokenize text.
        unet ([`UNet2DConditionModel`]):
            A `UNet2DConditionModel` to denoise the encoded image latents.
        brushnet ([`BrushNetModel`]):
            Provides additional conditioning to the `unet` during the denoising process.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
        safety_checker ([`StableDiffusionSafetyChecker`]):
            Classification module that estimates whether generated images could be considered offensive or harmful.
            Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
            about a model's potential harms.
        feature_extractor ([`~transformers.CLIPImageProcessor`]):
            A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
    """

    model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae"
    _optional_components = ["safety_checker", "feature_extractor", "image_encoder"]
    _exclude_from_cpu_offload = ["safety_checker"]
    _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        brushnet: BrushNetModel,
        scheduler: KarrasDiffusionSchedulers,
        safety_checker: StableDiffusionSafetyChecker,
        feature_extractor: CLIPImageProcessor,
        image_encoder: CLIPVisionModelWithProjection = None,
        requires_safety_checker: bool = True,
    ):
        super().__init__()

        if safety_checker is None and requires_safety_checker:
            logger.warning(
                f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
                " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
                " results in services or applications open to the public. Both the diffusers team and Hugging Face"
                " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
                " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
                " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
            )

        if safety_checker is not None and feature_extractor is None:
            raise ValueError(
                f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
                " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
            )

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            brushnet=brushnet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
            image_encoder=image_encoder,
        )
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        self.image_processor = VaeImageProcessor(
            vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True
        )
        self.register_to_config(requires_safety_checker=requires_safety_checker)
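
    # For the Stable Diffusion v1.5 VAE, ``block_out_channels`` has four entries, so
    # ``vae_scale_factor`` evaluates to 2 ** 3 == 8: a 512x512 image maps to a 64x64 latent.
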
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
    def _encode_prompt(
        self,
        prompt,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
        negative_prompt=None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        lora_scale: Optional[float] = None,
        **kwargs,
    ):
        deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
        deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)

        prompt_embeds_tuple = self.encode_prompt(
            prompt=prompt,
            device=device,
            num_images_per_prompt=num_images_per_prompt,
            do_classifier_free_guidance=do_classifier_free_guidance,
            negative_prompt=negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            lora_scale=lora_scale,
            **kwargs,
        )

        # concatenate for backwards comp
        prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]])

        return prompt_embeds
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt
    def encode_prompt(
        self,
        prompt,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
        negative_prompt=None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        lora_scale: Optional[float] = None,
        clip_skip: Optional[int] = None,
    ):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            device: (`torch.device`):
                torch device
            num_images_per_prompt (`int`):
                number of images that should be generated per prompt
            do_classifier_free_guidance (`bool`):
                whether to use classifier free guidance or not
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale`
                is less than `1`).
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If
                not provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            lora_scale (`float`, *optional*):
                A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
            clip_skip (`int`, *optional*):
                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means
                that the output of the pre-final layer will be used for computing the prompt embeddings.
        """
        # set lora scale so that monkey patched LoRA
        # function of text encoder can correctly access it
        if lora_scale is not None and isinstance(self, LoraLoaderMixin):
            self._lora_scale = lora_scale

            # dynamically adjust the LoRA scale
            if not USE_PEFT_BACKEND:
                adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
            else:
                scale_lora_layers(self.text_encoder, lora_scale)

        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        if prompt_embeds is None:
            # textual inversion: process multi-vector tokens if necessary
            if isinstance(self, TextualInversionLoaderMixin):
                prompt = self.maybe_convert_prompt(prompt, self.tokenizer)

            text_inputs = self.tokenizer(
                prompt,
                padding="max_length",
                max_length=self.tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            text_input_ids = text_inputs.input_ids
            untruncated_ids = self.tokenizer(
                prompt, padding="longest", return_tensors="pt"
            ).input_ids

            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
                text_input_ids, untruncated_ids
            ):
                removed_text = self.tokenizer.batch_decode(
                    untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
                )
                logger.warning(
                    "The following part of your input was truncated because CLIP can only handle sequences up to"
                    f" {self.tokenizer.model_max_length} tokens: {removed_text}"
                )

            if (
                hasattr(self.text_encoder.config, "use_attention_mask")
                and self.text_encoder.config.use_attention_mask
            ):
                attention_mask = text_inputs.attention_mask.to(device)
            else:
                attention_mask = None

            if clip_skip is None:
                prompt_embeds = self.text_encoder(
                    text_input_ids.to(device), attention_mask=attention_mask
                )
                prompt_embeds = prompt_embeds[0]
            else:
                prompt_embeds = self.text_encoder(
                    text_input_ids.to(device),
                    attention_mask=attention_mask,
                    output_hidden_states=True,
                )
                # Access the `hidden_states` first, that contains a tuple of
                # all the hidden states from the encoder layers. Then index into
                # the tuple to access the hidden states from the desired layer.
                prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
                # We also need to apply the final LayerNorm here to not mess with the
                # representations. The `last_hidden_states` that we typically use for
                # obtaining the final prompt representations passes through the LayerNorm
                # layer.
                prompt_embeds = self.text_encoder.text_model.final_layer_norm(
                    prompt_embeds
                )

        if self.text_encoder is not None:
            prompt_embeds_dtype = self.text_encoder.dtype
        elif self.unet is not None:
            prompt_embeds_dtype = self.unet.dtype
        else:
            prompt_embeds_dtype = prompt_embeds.dtype

        prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)

        bs_embed, seq_len, _ = prompt_embeds.shape
        # duplicate text embeddings for each generation per prompt, using mps friendly method
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(
            bs_embed * num_images_per_prompt, seq_len, -1
        )

        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance and negative_prompt_embeds is None:
            uncond_tokens: List[str]
            if negative_prompt is None:
                uncond_tokens = [""] * batch_size
            elif prompt is not None and type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif isinstance(negative_prompt, str):
                uncond_tokens = [negative_prompt]
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )
            else:
                uncond_tokens = negative_prompt

            # textual inversion: process multi-vector tokens if necessary
            if isinstance(self, TextualInversionLoaderMixin):
                uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)

            max_length = prompt_embeds.shape[1]
            uncond_input = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_tensors="pt",
            )

            if (
                hasattr(self.text_encoder.config, "use_attention_mask")
                and self.text_encoder.config.use_attention_mask
            ):
                attention_mask = uncond_input.attention_mask.to(device)
            else:
                attention_mask = None

            negative_prompt_embeds = self.text_encoder(
                uncond_input.input_ids.to(device),
                attention_mask=attention_mask,
            )
            negative_prompt_embeds = negative_prompt_embeds[0]

        if do_classifier_free_guidance:
            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
            seq_len = negative_prompt_embeds.shape[1]

            negative_prompt_embeds = negative_prompt_embeds.to(
                dtype=prompt_embeds_dtype, device=device
            )
            negative_prompt_embeds = negative_prompt_embeds.repeat(
                1, num_images_per_prompt, 1
            )
            negative_prompt_embeds = negative_prompt_embeds.view(
                batch_size * num_images_per_prompt, seq_len, -1
            )

        if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
            # Retrieve the original scale by scaling back the LoRA layers
            unscale_lora_layers(self.text_encoder, lora_scale)

        return prompt_embeds, negative_prompt_embeds
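
    # ``encode_prompt`` returns (prompt_embeds, negative_prompt_embeds) as separate tensors of
    # shape (batch_size * num_images_per_prompt, seq_len, hidden_dim); under classifier-free
    # guidance, ``__call__`` below concatenates them as [negative, positive] along the batch axis.
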
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
    def encode_image(
        self, image, device, num_images_per_prompt, output_hidden_states=None
    ):
        dtype = next(self.image_encoder.parameters()).dtype

        if not isinstance(image, torch.Tensor):
            image = self.feature_extractor(image, return_tensors="pt").pixel_values

        image = image.to(device=device, dtype=dtype)
        if output_hidden_states:
            image_enc_hidden_states = self.image_encoder(
                image, output_hidden_states=True
            ).hidden_states[-2]
            image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(
                num_images_per_prompt, dim=0
            )
            uncond_image_enc_hidden_states = self.image_encoder(
                torch.zeros_like(image), output_hidden_states=True
            ).hidden_states[-2]
            uncond_image_enc_hidden_states = (
                uncond_image_enc_hidden_states.repeat_interleave(
                    num_images_per_prompt, dim=0
                )
            )
            return image_enc_hidden_states, uncond_image_enc_hidden_states
        else:
            image_embeds = self.image_encoder(image).image_embeds
            image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
            uncond_image_embeds = torch.zeros_like(image_embeds)

            return image_embeds, uncond_image_embeds
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
    def prepare_ip_adapter_image_embeds(
        self,
        ip_adapter_image,
        ip_adapter_image_embeds,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
    ):
        if ip_adapter_image_embeds is None:
            if not isinstance(ip_adapter_image, list):
                ip_adapter_image = [ip_adapter_image]

            if len(ip_adapter_image) != len(
                self.unet.encoder_hid_proj.image_projection_layers
            ):
                raise ValueError(
                    f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
                )

            image_embeds = []
            for single_ip_adapter_image, image_proj_layer in zip(
                ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
            ):
                output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
                single_image_embeds, single_negative_image_embeds = self.encode_image(
                    single_ip_adapter_image, device, 1, output_hidden_state
                )
                single_image_embeds = torch.stack(
                    [single_image_embeds] * num_images_per_prompt, dim=0
                )
                single_negative_image_embeds = torch.stack(
                    [single_negative_image_embeds] * num_images_per_prompt, dim=0
                )

                if do_classifier_free_guidance:
                    single_image_embeds = torch.cat(
                        [single_negative_image_embeds, single_image_embeds]
                    )
                    single_image_embeds = single_image_embeds.to(device)

                image_embeds.append(single_image_embeds)
        else:
            repeat_dims = [1]
            image_embeds = []
            for single_image_embeds in ip_adapter_image_embeds:
                if do_classifier_free_guidance:
                    (
                        single_negative_image_embeds,
                        single_image_embeds,
                    ) = single_image_embeds.chunk(2)
                    single_image_embeds = single_image_embeds.repeat(
                        num_images_per_prompt,
                        *(repeat_dims * len(single_image_embeds.shape[1:])),
                    )
                    single_negative_image_embeds = single_negative_image_embeds.repeat(
                        num_images_per_prompt,
                        *(repeat_dims * len(single_negative_image_embeds.shape[1:])),
                    )
                    single_image_embeds = torch.cat(
                        [single_negative_image_embeds, single_image_embeds]
                    )
                else:
                    single_image_embeds = single_image_embeds.repeat(
                        num_images_per_prompt,
                        *(repeat_dims * len(single_image_embeds.shape[1:])),
                    )
                image_embeds.append(single_image_embeds)

        return image_embeds
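
    # The returned ``image_embeds`` is a list with one tensor per IP-Adapter; when classifier-free
    # guidance is enabled, each tensor already holds the unconditional and conditional image
    # embeddings concatenated along the batch dimension, mirroring the prompt-embedding layout.
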
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
    def run_safety_checker(self, image, device, dtype):
        if self.safety_checker is None:
            has_nsfw_concept = None
        else:
            if torch.is_tensor(image):
                feature_extractor_input = self.image_processor.postprocess(
                    image, output_type="pil"
                )
            else:
                feature_extractor_input = self.image_processor.numpy_to_pil(image)
            safety_checker_input = self.feature_extractor(
                feature_extractor_input, return_tensors="pt"
            ).to(device)
            image, has_nsfw_concept = self.safety_checker(
                images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
            )
        return image, has_nsfw_concept
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
    def decode_latents(self, latents):
        deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead"
        deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)

        latents = 1 / self.vae.config.scaling_factor * latents
        image = self.vae.decode(latents, return_dict=False)[0]
        image = (image / 2 + 0.5).clamp(0, 1)
        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
        return image
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
    def prepare_extra_step_kwargs(self, generator, eta):
        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]
        accepts_eta = "eta" in set(
            inspect.signature(self.scheduler.step).parameters.keys()
        )
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        # check if the scheduler accepts generator
        accepts_generator = "generator" in set(
            inspect.signature(self.scheduler.step).parameters.keys()
        )
        if accepts_generator:
            extra_step_kwargs["generator"] = generator
        return extra_step_kwargs
    def check_inputs(
        self,
        prompt,
        image,
        mask,
        callback_steps,
        negative_prompt=None,
        prompt_embeds=None,
        negative_prompt_embeds=None,
        ip_adapter_image=None,
        ip_adapter_image_embeds=None,
        brushnet_conditioning_scale=1.0,
        control_guidance_start=0.0,
        control_guidance_end=1.0,
        callback_on_step_end_tensor_inputs=None,
    ):
        if callback_steps is not None and (
            not isinstance(callback_steps, int) or callback_steps <= 0
        ):
            raise ValueError(
                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                f" {type(callback_steps)}."
            )

        if callback_on_step_end_tensor_inputs is not None and not all(
            k in self._callback_tensor_inputs
            for k in callback_on_step_end_tensor_inputs
        ):
            raise ValueError(
                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
            )

        if prompt is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        elif prompt is None and prompt_embeds is None:
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
            )
        elif prompt is not None and (
            not isinstance(prompt, str) and not isinstance(prompt, list)
        ):
            raise ValueError(
                f"`prompt` has to be of type `str` or `list` but is {type(prompt)}"
            )

        if negative_prompt is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )

        if prompt_embeds is not None and negative_prompt_embeds is not None:
            if prompt_embeds.shape != negative_prompt_embeds.shape:
                raise ValueError(
                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                    f" {negative_prompt_embeds.shape}."
                )

        # Check `image`
        is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance(
            self.brushnet, torch._dynamo.eval_frame.OptimizedModule
        )
        if (
            isinstance(self.brushnet, BrushNetModel)
            or is_compiled
            and isinstance(self.brushnet._orig_mod, BrushNetModel)
        ):
            self.check_image(image, mask, prompt, prompt_embeds)
        else:
            assert False

        # Check `brushnet_conditioning_scale`
        if (
            isinstance(self.brushnet, BrushNetModel)
            or is_compiled
            and isinstance(self.brushnet._orig_mod, BrushNetModel)
        ):
            if not isinstance(brushnet_conditioning_scale, float):
                raise TypeError(
                    "For single brushnet: `brushnet_conditioning_scale` must be type `float`."
                )
        else:
            assert False

        if not isinstance(control_guidance_start, (tuple, list)):
            control_guidance_start = [control_guidance_start]

        if not isinstance(control_guidance_end, (tuple, list)):
            control_guidance_end = [control_guidance_end]

        if len(control_guidance_start) != len(control_guidance_end):
            raise ValueError(
                f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list."
            )

        for start, end in zip(control_guidance_start, control_guidance_end):
            if start >= end:
                raise ValueError(
                    f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}."
                )
            if start < 0.0:
                raise ValueError(
                    f"control guidance start: {start} can't be smaller than 0."
                )
            if end > 1.0:
                raise ValueError(
                    f"control guidance end: {end} can't be larger than 1.0."
                )

        if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
            raise ValueError(
                "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
            )

        if ip_adapter_image_embeds is not None:
            if not isinstance(ip_adapter_image_embeds, list):
                raise ValueError(
                    f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
                )
            elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
                raise ValueError(
                    f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
                )
    def check_image(self, image, mask, prompt, prompt_embeds):
        image_is_pil = isinstance(image, PIL.Image.Image)
        image_is_tensor = isinstance(image, torch.Tensor)
        image_is_np = isinstance(image, np.ndarray)
        image_is_pil_list = isinstance(image, list) and isinstance(
            image[0], PIL.Image.Image
        )
        image_is_tensor_list = isinstance(image, list) and isinstance(
            image[0], torch.Tensor
        )
        image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray)

        if (
            not image_is_pil
            and not image_is_tensor
            and not image_is_np
            and not image_is_pil_list
            and not image_is_tensor_list
            and not image_is_np_list
        ):
            raise TypeError(
                f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}"
            )

        mask_is_pil = isinstance(mask, PIL.Image.Image)
        mask_is_tensor = isinstance(mask, torch.Tensor)
        mask_is_np = isinstance(mask, np.ndarray)
        mask_is_pil_list = isinstance(mask, list) and isinstance(
            mask[0], PIL.Image.Image
        )
        mask_is_tensor_list = isinstance(mask, list) and isinstance(
            mask[0], torch.Tensor
        )
        mask_is_np_list = isinstance(mask, list) and isinstance(mask[0], np.ndarray)

        if (
            not mask_is_pil
            and not mask_is_tensor
            and not mask_is_np
            and not mask_is_pil_list
            and not mask_is_tensor_list
            and not mask_is_np_list
        ):
            raise TypeError(
                f"mask must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(mask)}"
            )

        if image_is_pil:
            image_batch_size = 1
        else:
            image_batch_size = len(image)

        if prompt is not None and isinstance(prompt, str):
            prompt_batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            prompt_batch_size = len(prompt)
        elif prompt_embeds is not None:
            prompt_batch_size = prompt_embeds.shape[0]

        if image_batch_size != 1 and image_batch_size != prompt_batch_size:
            raise ValueError(
                f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}"
            )
    def prepare_image(
        self,
        image,
        width,
        height,
        batch_size,
        num_images_per_prompt,
        device,
        dtype,
        do_classifier_free_guidance=False,
        guess_mode=False,
    ):
        image = self.image_processor.preprocess(image, height=height, width=width).to(
            dtype=torch.float32
        )
        image_batch_size = image.shape[0]

        if image_batch_size == 1:
            repeat_by = batch_size
        else:
            # image batch size is the same as prompt batch size
            repeat_by = num_images_per_prompt

        image = image.repeat_interleave(repeat_by, dim=0)
        image = image.to(device=device, dtype=dtype)

        if do_classifier_free_guidance and not guess_mode:
            image = torch.cat([image] * 2)

        return image
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
    def prepare_latents(
        self,
        batch_size,
        num_channels_latents,
        height,
        width,
        dtype,
        device,
        generator,
        latents=None,
    ):
        shape = (
            batch_size,
            num_channels_latents,
            height // self.vae_scale_factor,
            width // self.vae_scale_factor,
        )
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        if latents is None:
            noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        else:
            noise = latents.to(device)

        # scale the initial noise by the standard deviation required by the scheduler
        latents = noise * self.scheduler.init_noise_sigma
        return latents, noise
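
    # Note: despite the ``# Copied from`` marker above, this variant differs from the upstream
    # ``StableDiffusionPipeline.prepare_latents`` in that it returns both the scheduler-scaled
    # starting latents and the underlying noise tensor.
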
    # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
    def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
        """
        See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298

        Args:
            w (`torch.Tensor`):
                Guidance scale values at which to generate embedding vectors.
            embedding_dim (`int`, *optional*, defaults to 512):
                dimension of the embeddings to generate
            dtype:
                data type of the generated embeddings

        Returns:
            `torch.FloatTensor`: Embedding vectors with shape `(len(w), embedding_dim)`
        """
        assert len(w.shape) == 1
        w = w * 1000.0

        half_dim = embedding_dim // 2
        emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
        emb = w.to(dtype)[:, None] * emb[None, :]
        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
        if embedding_dim % 2 == 1:  # zero pad
            emb = torch.nn.functional.pad(emb, (0, 1))
        assert emb.shape == (w.shape[0], embedding_dim)
        return emb
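
    # The embedding above is the standard sinusoidal (Fourier) feature map applied to w * 1000:
    # for half_dim = embedding_dim // 2 and k = 0 .. half_dim - 1,
    #     freq_k = 10000 ** (-k / (half_dim - 1))
    #     emb    = [sin(1000 * w * freq_k), cos(1000 * w * freq_k)]
    # with a single zero column appended when embedding_dim is odd.
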
    @property
    def guidance_scale(self):
        return self._guidance_scale

    @property
    def clip_skip(self):
        return self._clip_skip

    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    @property
    def do_classifier_free_guidance(self):
        return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None

    @property
    def cross_attention_kwargs(self):
        return self._cross_attention_kwargs

    @property
    def num_timesteps(self):
        return self._num_timesteps
    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        image: PipelineImageInput = None,
        mask: PipelineImageInput = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        timesteps: List[int] = None,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        ip_adapter_image: Optional[PipelineImageInput] = None,
        ip_adapter_image_embeds: Optional[List[torch.FloatTensor]] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        brushnet_conditioning_scale: Union[float, List[float]] = 1.0,
        guess_mode: bool = False,
        control_guidance_start: Union[float, List[float]] = 0.0,
        control_guidance_end: Union[float, List[float]] = 1.0,
        clip_skip: Optional[int] = None,
        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        **kwargs,
    ):
  873. r"""
  874. The call function to the pipeline for generation.
  875. Args:
  876. prompt (`str` or `List[str]`, *optional*):
  877. The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
  878. image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:
  879. `List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
  880. The BrushNet input condition to provide guidance to the `unet` for generation. If the type is
  881. specified as `torch.FloatTensor`, it is passed to BrushNet as is. `PIL.Image.Image` can also be
  882. accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If height
  883. and/or width are passed, `image` is resized accordingly. If multiple BrushNets are specified in
  884. `init`, images must be passed as a list such that each element of the list can be correctly batched for
  885. input to a single BrushNet. When `prompt` is a list, and if a list of images is passed for a single BrushNet,
  886. each will be paired with each prompt in the `prompt` list. This also applies to multiple BrushNets,
  887. where a list of image lists can be passed to batch for each prompt and each BrushNet.
  888. mask (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:
  889. `List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
  890. The BrushNet input condition to provide guidance to the `unet` for generation. If the type is
  891. specified as `torch.FloatTensor`, it is passed to BrushNet as is. `PIL.Image.Image` can also be
  892. accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If height
  893. and/or width are passed, `image` is resized accordingly. If multiple BrushNets are specified in
  894. `init`, images must be passed as a list such that each element of the list can be correctly batched for
  895. input to a single BrushNet. When `prompt` is a list, and if a list of images is passed for a single BrushNet,
  896. each will be paired with each prompt in the `prompt` list. This also applies to multiple BrushNets,
  897. where a list of image lists can be passed to batch for each prompt and each BrushNet.
  898. height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
  899. The height in pixels of the generated image.
  900. width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
  901. The width in pixels of the generated image.
  902. num_inference_steps (`int`, *optional*, defaults to 50):
  903. The number of denoising steps. More denoising steps usually lead to a higher quality image at the
  904. expense of slower inference.
  905. timesteps (`List[int]`, *optional*):
  906. Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
  907. in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
  908. passed will be used. Must be in descending order.
  909. guidance_scale (`float`, *optional*, defaults to 7.5):
  910. A higher guidance scale value encourages the model to generate images closely linked to the text
  911. `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
  912. negative_prompt (`str` or `List[str]`, *optional*):
  913. The prompt or prompts to guide what to not include in image generation. If not defined, you need to
  914. pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
  915. num_images_per_prompt (`int`, *optional*, defaults to 1):
  916. The number of images to generate per prompt.
  917. eta (`float`, *optional*, defaults to 0.0):
  918. Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
  919. to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
  920. generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
  921. A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
  922. generation deterministic.
  923. latents (`torch.FloatTensor`, *optional*):
  924. Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
  925. generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
  926. tensor is generated by sampling using the supplied random `generator`.
  927. prompt_embeds (`torch.FloatTensor`, *optional*):
  928. Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
  929. provided, text embeddings are generated from the `prompt` input argument.
  930. negative_prompt_embeds (`torch.FloatTensor`, *optional*):
  931. Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
  932. not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
  933. ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
  934. ip_adapter_image_embeds (`List[torch.FloatTensor]`, *optional*):
  935. Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of IP-adapters.
  936. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should contain the negative image embedding
  937. if `do_classifier_free_guidance` is set to `True`.
  938. If not provided, embeddings are computed from the `ip_adapter_image` input argument.
  939. output_type (`str`, *optional*, defaults to `"pil"`):
  940. The output format of the generated image. Choose between `PIL.Image` or `np.array`.
  941. return_dict (`bool`, *optional*, defaults to `True`):
  942. Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
  943. plain tuple.
  944. callback (`Callable`, *optional*):
  945. A function that calls every `callback_steps` steps during inference. The function is called with the
  946. following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
  947. callback_steps (`int`, *optional*, defaults to 1):
  948. The frequency at which the `callback` function is called. If not specified, the callback is called at
  949. every step.
  950. cross_attention_kwargs (`dict`, *optional*):
  951. A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
  952. [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
  953. brushnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
  954. The outputs of the BrushNet are multiplied by `brushnet_conditioning_scale` before they are added
  955. to the residual in the original `unet`. If multiple BrushNets are specified in `init`, you can set
  956. the corresponding scale as a list.
  957. guess_mode (`bool`, *optional*, defaults to `False`):
  958. The BrushNet encoder tries to recognize the content of the input image even if you remove all
  959. prompts. A `guidance_scale` value between 3.0 and 5.0 is recommended.
  960. control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
  961. The percentage of total steps at which the BrushNet starts applying.
  962. control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
  963. The percentage of total steps at which the BrushNet stops applying.
  964. clip_skip (`int`, *optional*):
  965. Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
  966. the output of the pre-final layer will be used for computing the prompt embeddings.
  967. callback_on_step_end (`Callable`, *optional*):
  968. A function that calls at the end of each denoising steps during the inference. The function is called
  969. with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
  970. callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
  971. `callback_on_step_end_tensor_inputs`.
  972. callback_on_step_end_tensor_inputs (`List`, *optional*):
  973. The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
  974. will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
  975. `._callback_tensor_inputs` attribute of your pipeine class.
  976. Examples:
  977. Returns:
  978. [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
  979. If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
  980. otherwise a `tuple` is returned where the first element is a list with the generated images and the
  981. second element is a list of `bool`s indicating whether the corresponding generated image contains
  982. "not-safe-for-work" (nsfw) content.
  983. """
        callback = kwargs.pop("callback", None)
        callback_steps = kwargs.pop("callback_steps", None)

        if callback is not None:
            deprecate(
                "callback",
                "1.0.0",
                "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
            )
        if callback_steps is not None:
            deprecate(
                "callback_steps",
                "1.0.0",
                "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
            )

        brushnet = (
            self.brushnet._orig_mod
            if is_compiled_module(self.brushnet)
            else self.brushnet
        )

        # align format for control guidance
        if not isinstance(control_guidance_start, list) and isinstance(
            control_guidance_end, list
        ):
            control_guidance_start = len(control_guidance_end) * [
                control_guidance_start
            ]
        elif not isinstance(control_guidance_end, list) and isinstance(
            control_guidance_start, list
        ):
            control_guidance_end = len(control_guidance_start) * [control_guidance_end]
        elif not isinstance(control_guidance_start, list) and not isinstance(
            control_guidance_end, list
        ):
            control_guidance_start, control_guidance_end = (
                [control_guidance_start],
                [control_guidance_end],
            )
        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt,
            image,
            mask,
            callback_steps,
            negative_prompt,
            prompt_embeds,
            negative_prompt_embeds,
            ip_adapter_image,
            ip_adapter_image_embeds,
            brushnet_conditioning_scale,
            control_guidance_start,
            control_guidance_end,
            callback_on_step_end_tensor_inputs,
        )

        self._guidance_scale = guidance_scale
        self._clip_skip = clip_skip
        self._cross_attention_kwargs = cross_attention_kwargs

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device

        global_pool_conditions = (
            brushnet.config.global_pool_conditions
            if isinstance(brushnet, BrushNetModel)
            else brushnet.nets[0].config.global_pool_conditions
        )
        guess_mode = guess_mode or global_pool_conditions
        # 3. Encode input prompt
        text_encoder_lora_scale = (
            self.cross_attention_kwargs.get("scale", None)
            if self.cross_attention_kwargs is not None
            else None
        )
        prompt_embeds, negative_prompt_embeds = self.encode_prompt(
            prompt,
            device,
            num_images_per_prompt,
            self.do_classifier_free_guidance,
            negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            lora_scale=text_encoder_lora_scale,
            clip_skip=self.clip_skip,
        )
        # For classifier free guidance, we need to do two forward passes.
        # Here we concatenate the unconditional and text embeddings into a single batch
        # to avoid doing two forward passes
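        # The ordering [negative, positive] matters: `noise_pred.chunk(2)` in the
        # denoising loop assumes the unconditional half comes first.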
        if self.do_classifier_free_guidance:
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
        if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
            image_embeds = self.prepare_ip_adapter_image_embeds(
                ip_adapter_image,
                ip_adapter_image_embeds,
                device,
                batch_size * num_images_per_prompt,
                self.do_classifier_free_guidance,
            )
        # 4. Prepare image
        if isinstance(brushnet, BrushNetModel):
            image = self.prepare_image(
                image=image,
                width=width,
                height=height,
                batch_size=batch_size * num_images_per_prompt,
                num_images_per_prompt=num_images_per_prompt,
                device=device,
                dtype=brushnet.dtype,
                do_classifier_free_guidance=self.do_classifier_free_guidance,
                guess_mode=guess_mode,
            )
            original_mask = self.prepare_image(
                image=mask,
                width=width,
                height=height,
                batch_size=batch_size * num_images_per_prompt,
                num_images_per_prompt=num_images_per_prompt,
                device=device,
                dtype=brushnet.dtype,
                do_classifier_free_guidance=self.do_classifier_free_guidance,
                guess_mode=guess_mode,
            )
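            # Collapse the 3-channel mask image to a single binary channel. Assuming
            # `prepare_image` normalizes pixels to [-1, 1] (the VaeImageProcessor
            # default), summing over channels and testing `< 0` marks pixels darker
            # than mid-gray as 1 and everything else as 0.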
            original_mask = (original_mask.sum(1)[:, None, :, :] < 0).to(image.dtype)
            height, width = image.shape[-2:]
        else:
            assert False, "`brushnet` must be a `BrushNetModel` instance"
        # 5. Prepare timesteps
        timesteps, num_inference_steps = retrieve_timesteps(
            self.scheduler, num_inference_steps, device, timesteps
        )
        self._num_timesteps = len(timesteps)
        # 6. Prepare latent variables
        num_channels_latents = self.unet.config.in_channels
        latents, noise = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )
        # 6.1 prepare condition latents
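        # BrushNet is conditioned on the VAE encoding of the input image with the
        # mask, downsampled to latent resolution, concatenated as an extra channel
        # (typically 4 latent channels + 1 mask channel).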
        conditioning_latents = (
            self.vae.encode(image).latent_dist.sample() * self.vae.config.scaling_factor
        )
        mask = torch.nn.functional.interpolate(
            original_mask,
            size=(conditioning_latents.shape[-2], conditioning_latents.shape[-1]),
        )
        conditioning_latents = torch.concat([conditioning_latents, mask], 1)
        # 6.5 Optionally get Guidance Scale Embedding
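        # `timestep_cond` is only populated for UNets that define
        # `time_cond_proj_dim` (e.g. guidance-distilled / LCM-style models); for a
        # standard Stable Diffusion UNet it stays None.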
        timestep_cond = None
        if self.unet.config.time_cond_proj_dim is not None:
            guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(
                batch_size * num_images_per_prompt
            )
            timestep_cond = self.get_guidance_scale_embedding(
                guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
            ).to(device=device, dtype=latents.dtype)
        # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
        # 7.1 Add image embeds for IP-Adapter
        added_cond_kwargs = (
            {"image_embeds": image_embeds}
            if ip_adapter_image is not None or ip_adapter_image_embeds is not None
            else None
        )
        # 7.2 Create tensor stating which brushnets to keep
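        # brushnet_keep[i] is 1.0 while step i falls inside the
        # [control_guidance_start, control_guidance_end] window and 0.0 outside it.
        # For example, with start=0.0, end=0.5 and 10 steps, BrushNet is applied on
        # steps 0-4 and its conditioning scale is zeroed on steps 5-9.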
        brushnet_keep = []
        for i in range(len(timesteps)):
            keeps = [
                1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
                for s, e in zip(control_guidance_start, control_guidance_end)
            ]
            brushnet_keep.append(
                keeps[0] if isinstance(brushnet, BrushNetModel) else keeps
            )
        # 8. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        is_unet_compiled = is_compiled_module(self.unet)
        is_brushnet_compiled = is_compiled_module(self.brushnet)
        is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1")
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # Relevant thread:
                # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428
                if (
                    is_unet_compiled and is_brushnet_compiled
                ) and is_torch_higher_equal_2_1:
                    torch._inductor.cudagraph_mark_step_begin()
                # expand the latents if we are doing classifier free guidance
                latent_model_input = (
                    torch.cat([latents] * 2)
                    if self.do_classifier_free_guidance
                    else latents
                )
                latent_model_input = self.scheduler.scale_model_input(
                    latent_model_input, t
                )
                # brushnet(s) inference
                if guess_mode and self.do_classifier_free_guidance:
                    # Infer BrushNet only for the conditional batch.
                    control_model_input = latents
                    control_model_input = self.scheduler.scale_model_input(
                        control_model_input, t
                    )
                    brushnet_prompt_embeds = prompt_embeds.chunk(2)[1]
                else:
                    control_model_input = latent_model_input
                    brushnet_prompt_embeds = prompt_embeds
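                # Per-step conditioning scale: the user-supplied
                # `brushnet_conditioning_scale` is multiplied by `brushnet_keep[i]`,
                # which is 0.0 outside the control-guidance window, so BrushNet's
                # residuals are effectively disabled for those steps.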
                if isinstance(brushnet_keep[i], list):
                    cond_scale = [
                        c * s
                        for c, s in zip(brushnet_conditioning_scale, brushnet_keep[i])
                    ]
                else:
                    brushnet_cond_scale = brushnet_conditioning_scale
                    if isinstance(brushnet_cond_scale, list):
                        brushnet_cond_scale = brushnet_cond_scale[0]
                    cond_scale = brushnet_cond_scale * brushnet_keep[i]
                (
                    down_block_res_samples,
                    mid_block_res_sample,
                    up_block_res_samples,
                ) = self.brushnet(
                    control_model_input,
                    t,
                    encoder_hidden_states=brushnet_prompt_embeds,
                    brushnet_cond=conditioning_latents,
                    conditioning_scale=cond_scale,
                    guess_mode=guess_mode,
                    return_dict=False,
                )
                if guess_mode and self.do_classifier_free_guidance:
                    # BrushNet was inferred only for the conditional batch.
                    # To apply the output of BrushNet to both the unconditional and conditional batches,
                    # add 0 to the unconditional batch to keep it unchanged.
                    down_block_res_samples = [
                        torch.cat([torch.zeros_like(d), d])
                        for d in down_block_res_samples
                    ]
                    mid_block_res_sample = torch.cat(
                        [torch.zeros_like(mid_block_res_sample), mid_block_res_sample]
                    )
                    up_block_res_samples = [
                        torch.cat([torch.zeros_like(d), d])
                        for d in up_block_res_samples
                    ]
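                # Note: `down_block_add_samples` / `mid_block_add_sample` /
                # `up_block_add_samples` are not standard diffusers UNet kwargs; they
                # are assumed to be consumed by the BrushNet-patched UNet forward,
                # which adds these residuals at the matching down/mid/up blocks.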
                # predict the noise residual
                noise_pred = self.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=prompt_embeds,
                    timestep_cond=timestep_cond,
                    cross_attention_kwargs=self.cross_attention_kwargs,
                    down_block_add_samples=down_block_res_samples,
                    mid_block_add_sample=mid_block_res_sample,
                    up_block_add_samples=up_block_res_samples,
                    added_cond_kwargs=added_cond_kwargs,
                    return_dict=False,
                )[0]
                # perform guidance
                if self.do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + self.guidance_scale * (
                        noise_pred_text - noise_pred_uncond
                    )
                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(
                    noise_pred, t, latents, **extra_step_kwargs, return_dict=False
                )[0]
                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
                    latents = callback_outputs.pop("latents", latents)
                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                    negative_prompt_embeds = callback_outputs.pop(
                        "negative_prompt_embeds", negative_prompt_embeds
                    )
                # call the callback, if provided
                if i == len(timesteps) - 1 or (
                    (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
                ):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        step_idx = i // getattr(self.scheduler, "order", 1)
                        callback(step_idx, t, latents)
        # If we do sequential model offloading, let's offload unet and brushnet
        # manually for max memory savings
        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
            self.unet.to("cpu")
            self.brushnet.to("cpu")
            torch.cuda.empty_cache()
        if not output_type == "latent":
            image = self.vae.decode(
                latents / self.vae.config.scaling_factor,
                return_dict=False,
                generator=generator,
            )[0]
            image, has_nsfw_concept = self.run_safety_checker(
                image, device, prompt_embeds.dtype
            )
        else:
            image = latents
            has_nsfw_concept = None
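        # Only denormalize images that were not flagged by the safety checker;
        # flagged outputs are assumed to have been replaced already by
        # `run_safety_checker`, so denormalizing them again would be wrong.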
        if has_nsfw_concept is None:
            do_denormalize = [True] * image.shape[0]
        else:
            do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
        image = self.image_processor.postprocess(
            image, output_type=output_type, do_denormalize=do_denormalize
        )
        # Offload all models
        self.maybe_free_model_hooks()
        if not return_dict:
            return (image, has_nsfw_concept)
        return StableDiffusionPipelineOutput(
            images=image, nsfw_content_detected=has_nsfw_concept
        )