hack.py

import einops
import torch
from transformers import logging

import sorawm.iopaint.model.anytext.ldm.modules.attention
import sorawm.iopaint.model.anytext.ldm.modules.encoders.modules
from sorawm.iopaint.model.anytext.ldm.modules.attention import default


def disable_verbosity():
    # Silence transformers' info/warning output; only errors are reported.
    logging.set_verbosity_error()
    print("logging improved.")
    return


def enable_sliced_attention():
    # Monkey-patch CrossAttention.forward with the sliced (chunked) variant below.
    sorawm.iopaint.model.anytext.ldm.modules.attention.CrossAttention.forward = (
        _hacked_sliced_attentin_forward
    )
    print("Enabled sliced_attention.")
    return


def hack_everything(clip_skip=0):
    # Silence verbose logging and monkey-patch FrozenCLIPEmbedder so that long
    # prompts are encoded in three 77-token windows and clip_skip is honored.
    disable_verbosity()
    sorawm.iopaint.model.anytext.ldm.modules.encoders.modules.FrozenCLIPEmbedder.forward = (
        _hacked_clip_forward
    )
    sorawm.iopaint.model.anytext.ldm.modules.encoders.modules.FrozenCLIPEmbedder.clip_skip = (
        clip_skip
    )
    print("Enabled clip hacks.")
    return


# Written by Lvmin
def _hacked_clip_forward(self, text):
    PAD = self.tokenizer.pad_token_id
    EOS = self.tokenizer.eos_token_id
    BOS = self.tokenizer.bos_token_id

    def tokenize(t):
        return self.tokenizer(t, truncation=False, add_special_tokens=False)[
            "input_ids"
        ]

    def transformer_encode(t):
        if self.clip_skip > 1:
            # Take an earlier hidden state and re-apply the final layer norm.
            rt = self.transformer(input_ids=t, output_hidden_states=True)
            return self.transformer.text_model.final_layer_norm(
                rt.hidden_states[-self.clip_skip]
            )
        else:
            return self.transformer(
                input_ids=t, output_hidden_states=False
            ).last_hidden_state

    def split(x):
        # Split the raw token stream into three windows of up to 75 tokens each.
        return x[75 * 0 : 75 * 1], x[75 * 1 : 75 * 2], x[75 * 2 : 75 * 3]

    def pad(x, p, i):
        # Truncate to length i, or right-pad with p up to length i.
        return x[:i] if len(x) >= i else x + [p] * (i - len(x))

    raw_tokens_list = tokenize(text)
    tokens_list = []
    for raw_tokens in raw_tokens_list:
        # Wrap each window with BOS/EOS and pad to the CLIP context length of 77.
        raw_tokens_123 = split(raw_tokens)
        raw_tokens_123 = [
            [BOS] + raw_tokens_i + [EOS] for raw_tokens_i in raw_tokens_123
        ]
        raw_tokens_123 = [pad(raw_tokens_i, PAD, 77) for raw_tokens_i in raw_tokens_123]
        tokens_list.append(raw_tokens_123)
    tokens_list = torch.IntTensor(tokens_list).to(self.device)

    # Encode the three windows in one batch, then concatenate them along the
    # sequence dimension so each prompt yields a 3 * 77 token embedding.
    feed = einops.rearrange(tokens_list, "b f i -> (b f) i")
    y = transformer_encode(feed)
    z = einops.rearrange(y, "(b f) i c -> b (f i) c", f=3)
    return z


# Stolen from https://github.com/basujindal/stable-diffusion/blob/main/optimizedSD/splitAttention.py
def _hacked_sliced_attentin_forward(self, x, context=None, mask=None):
    h = self.heads

    q = self.to_q(x)
    context = default(context, x)
    k = self.to_k(context)
    v = self.to_v(context)
    del context, x

    q, k, v = map(
        lambda t: einops.rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v)
    )

    # Compute attention one (batch * head) slice at a time to cap peak memory.
    limit = k.shape[0]
    att_step = 1
    q_chunks = list(torch.tensor_split(q, limit // att_step, dim=0))
    k_chunks = list(torch.tensor_split(k, limit // att_step, dim=0))
    v_chunks = list(torch.tensor_split(v, limit // att_step, dim=0))
    q_chunks.reverse()
    k_chunks.reverse()
    v_chunks.reverse()

    sim = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)
    del k, q, v
    for i in range(0, limit, att_step):
        q_buffer = q_chunks.pop()
        k_buffer = k_chunks.pop()
        v_buffer = v_chunks.pop()
        sim_buffer = (
            torch.einsum("b i d, b j d -> b i j", q_buffer, k_buffer) * self.scale
        )
        del k_buffer, q_buffer

        # attention, what we cannot get enough of, by chunks
        sim_buffer = sim_buffer.softmax(dim=-1)
        sim_buffer = torch.einsum("b i j, b j d -> b i d", sim_buffer, v_buffer)
        del v_buffer
        sim[i : i + att_step, :, :] = sim_buffer
        del sim_buffer

    sim = einops.rearrange(sim, "(b h) n d -> b n (h d)", h=h)
    return self.to_out(sim)
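

# Minimal usage sketch (an editorial addition, not part of the upstream file):
# the patches are meant to be applied once at startup, before the AnyText model
# and its FrozenCLIPEmbedder are constructed. The __main__ guard and the
# clip_skip value below are illustrative assumptions; the host code normally
# calls these functions directly from its own setup path.
if __name__ == "__main__":
    hack_everything(clip_skip=2)  # patch FrozenCLIPEmbedder; clip_skip=2 is just an example value
    enable_sliced_attention()  # optional: lower attention memory at some speed cost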