
feat: add core functionality and resource files for the Sora watermark removal tool

Added the core modules of the Sora watermark removal tool, including:
- Multiple image inpainting models (LaMa, LDM, ZITS, etc.)
- Video watermark detection and removal
- Test image and video resources
- Improved project configuration and documentation
- YOLO watermark detection model support
- A FastAPI backend service
- A one-click launch script for the portable build

New resource files include test images, video templates, and model configuration files.
max_liu committed 2 weeks ago
commit f8b09a41da
100 files changed, 10693 additions, 0 deletions
  1. LICENSE (+201 -0)
  2. README-run.md (+4 -0)
  3. README-zh.md (+164 -0)
  4. README.md (+183 -0)
  5. app.py (+175 -0)
  6. data/db.sqlite3 (BIN)
  7. datasets/make_yolo_images.py (+49 -0)
  8. example.py (+9 -0)
  9. ffmpeg/README.md (+69 -0)
  10. logs/log_file.log (+43 -0)
  11. notebooks/imputation.ipynb (+48 -0)
  12. one-click-portable.md (+26 -0)
  13. pyproject.toml (+40 -0)
  14. resources/19700121_1645_68e0a027836c8191a50bea3717ea7485.mp4 (BIN)
  15. resources/53abf3fd-11a9-4dd7-a348-34920775f8ad.png (BIN)
  16. resources/app.png (BIN)
  17. resources/best.pt (BIN)
  18. resources/dog_vs_sam.mp4 (BIN)
  19. resources/first_frame.json (+23 -0)
  20. resources/first_frame.png (BIN)
  21. resources/puppies.mp4 (BIN)
  22. resources/sora_watermark_removed.mp4 (BIN)
  23. resources/watermark_template.png (BIN)
  24. sorawm/__init__.py (+0 -0)
  25. sorawm/__pycache__/__init__.cpython-312.pyc (BIN)
  26. sorawm/__pycache__/configs.cpython-312.pyc (BIN)
  27. sorawm/__pycache__/core.cpython-312.pyc (BIN)
  28. sorawm/__pycache__/watermark_cleaner.cpython-312.pyc (BIN)
  29. sorawm/__pycache__/watermark_detector.cpython-312.pyc (BIN)
  30. sorawm/configs.py (+27 -0)
  31. sorawm/core.py (+199 -0)
  32. sorawm/iopaint/__init__.py (+56 -0)
  33. sorawm/iopaint/__main__.py (+4 -0)
  34. sorawm/iopaint/__pycache__/__init__.cpython-312.pyc (BIN)
  35. sorawm/iopaint/__pycache__/const.cpython-312.pyc (BIN)
  36. sorawm/iopaint/__pycache__/download.cpython-312.pyc (BIN)
  37. sorawm/iopaint/__pycache__/helper.cpython-312.pyc (BIN)
  38. sorawm/iopaint/__pycache__/model_manager.cpython-312.pyc (BIN)
  39. sorawm/iopaint/__pycache__/schema.cpython-312.pyc (BIN)
  40. sorawm/iopaint/api.py (+411 -0)
  41. sorawm/iopaint/batch_processing.py (+128 -0)
  42. sorawm/iopaint/benchmark.py (+109 -0)
  43. sorawm/iopaint/cli.py (+245 -0)
  44. sorawm/iopaint/const.py (+134 -0)
  45. sorawm/iopaint/download.py (+314 -0)
  46. sorawm/iopaint/file_manager/__init__.py (+1 -0)
  47. sorawm/iopaint/file_manager/file_manager.py (+220 -0)
  48. sorawm/iopaint/file_manager/storage_backends.py (+46 -0)
  49. sorawm/iopaint/file_manager/utils.py (+64 -0)
  50. sorawm/iopaint/helper.py (+411 -0)
  51. sorawm/iopaint/installer.py (+11 -0)
  52. sorawm/iopaint/model/__init__.py (+38 -0)
  53. sorawm/iopaint/model/__pycache__/__init__.cpython-312.pyc (BIN)
  54. sorawm/iopaint/model/__pycache__/base.cpython-312.pyc (BIN)
  55. sorawm/iopaint/model/__pycache__/controlnet.cpython-312.pyc (BIN)
  56. sorawm/iopaint/model/__pycache__/ddim_sampler.cpython-312.pyc (BIN)
  57. sorawm/iopaint/model/__pycache__/fcf.cpython-312.pyc (BIN)
  58. sorawm/iopaint/model/__pycache__/instruct_pix2pix.cpython-312.pyc (BIN)
  59. sorawm/iopaint/model/__pycache__/kandinsky.cpython-312.pyc (BIN)
  60. sorawm/iopaint/model/__pycache__/lama.cpython-312.pyc (BIN)
  61. sorawm/iopaint/model/__pycache__/ldm.cpython-312.pyc (BIN)
  62. sorawm/iopaint/model/__pycache__/manga.cpython-312.pyc (BIN)
  63. sorawm/iopaint/model/__pycache__/mat.cpython-312.pyc (BIN)
  64. sorawm/iopaint/model/__pycache__/mi_gan.cpython-312.pyc (BIN)
  65. sorawm/iopaint/model/__pycache__/opencv2.cpython-312.pyc (BIN)
  66. sorawm/iopaint/model/__pycache__/paint_by_example.cpython-312.pyc (BIN)
  67. sorawm/iopaint/model/__pycache__/plms_sampler.cpython-312.pyc (BIN)
  68. sorawm/iopaint/model/__pycache__/sd.cpython-312.pyc (BIN)
  69. sorawm/iopaint/model/__pycache__/sdxl.cpython-312.pyc (BIN)
  70. sorawm/iopaint/model/__pycache__/utils.cpython-312.pyc (BIN)
  71. sorawm/iopaint/model/__pycache__/zits.cpython-312.pyc (BIN)
  72. sorawm/iopaint/model/anytext/__init__.py (+0 -0)
  73. sorawm/iopaint/model/anytext/__pycache__/__init__.cpython-312.pyc (BIN)
  74. sorawm/iopaint/model/anytext/__pycache__/anytext_model.cpython-312.pyc (BIN)
  75. sorawm/iopaint/model/anytext/__pycache__/anytext_pipeline.cpython-312.pyc (BIN)
  76. sorawm/iopaint/model/anytext/__pycache__/utils.cpython-312.pyc (BIN)
  77. sorawm/iopaint/model/anytext/anytext_model.py (+73 -0)
  78. sorawm/iopaint/model/anytext/anytext_pipeline.py (+401 -0)
  79. sorawm/iopaint/model/anytext/anytext_sd15.yaml (+99 -0)
  80. sorawm/iopaint/model/anytext/cldm/__init__.py (+0 -0)
  81. sorawm/iopaint/model/anytext/cldm/__pycache__/__init__.cpython-312.pyc (BIN)
  82. sorawm/iopaint/model/anytext/cldm/__pycache__/ddim_hacked.cpython-312.pyc (BIN)
  83. sorawm/iopaint/model/anytext/cldm/__pycache__/model.cpython-312.pyc (BIN)
  84. sorawm/iopaint/model/anytext/cldm/cldm.py (+780 -0)
  85. sorawm/iopaint/model/anytext/cldm/ddim_hacked.py (+486 -0)
  86. sorawm/iopaint/model/anytext/cldm/embedding_manager.py (+185 -0)
  87. sorawm/iopaint/model/anytext/cldm/hack.py (+128 -0)
  88. sorawm/iopaint/model/anytext/cldm/model.py (+41 -0)
  89. sorawm/iopaint/model/anytext/cldm/recognizer.py (+302 -0)
  90. sorawm/iopaint/model/anytext/ldm/__init__.py (+0 -0)
  91. sorawm/iopaint/model/anytext/ldm/__pycache__/__init__.cpython-312.pyc (BIN)
  92. sorawm/iopaint/model/anytext/ldm/__pycache__/util.cpython-312.pyc (BIN)
  93. sorawm/iopaint/model/anytext/ldm/models/__init__.py (+0 -0)
  94. sorawm/iopaint/model/anytext/ldm/models/autoencoder.py (+275 -0)
  95. sorawm/iopaint/model/anytext/ldm/models/diffusion/__init__.py (+0 -0)
  96. sorawm/iopaint/model/anytext/ldm/models/diffusion/ddim.py (+525 -0)
  97. sorawm/iopaint/model/anytext/ldm/models/diffusion/ddpm.py (+2386 -0)
  98. sorawm/iopaint/model/anytext/ldm/models/diffusion/dpm_solver/__init__.py (+1 -0)
  99. sorawm/iopaint/model/anytext/ldm/models/diffusion/dpm_solver/dpm_solver.py (+1464 -0)
  100. sorawm/iopaint/model/anytext/ldm/models/diffusion/dpm_solver/sampler.py (+95 -0)

+ 201 - 0
LICENSE

@@ -0,0 +1,201 @@
+                                 Apache License
+                           Version 2.0, January 2004
+                        http://www.apache.org/licenses/
+
+   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+   1. Definitions.
+
+      "License" shall mean the terms and conditions for use, reproduction,
+      and distribution as defined by Sections 1 through 9 of this document.
+
+      "Licensor" shall mean the copyright owner or entity authorized by
+      the copyright owner that is granting the License.
+
+      "Legal Entity" shall mean the union of the acting entity and all
+      other entities that control, are controlled by, or are under common
+      control with that entity. For the purposes of this definition,
+      "control" means (i) the power, direct or indirect, to cause the
+      direction or management of such entity, whether by contract or
+      otherwise, or (ii) ownership of fifty percent (50%) or more of the
+      outstanding shares, or (iii) beneficial ownership of such entity.
+
+      "You" (or "Your") shall mean an individual or Legal Entity
+      exercising permissions granted by this License.
+
+      "Source" form shall mean the preferred form for making modifications,
+      including but not limited to software source code, documentation
+      source, and configuration files.
+
+      "Object" form shall mean any form resulting from mechanical
+      transformation or translation of a Source form, including but
+      not limited to compiled object code, generated documentation,
+      and conversions to other media types.
+
+      "Work" shall mean the work of authorship, whether in Source or
+      Object form, made available under the License, as indicated by a
+      copyright notice that is included in or attached to the work
+      (an example is provided in the Appendix below).
+
+      "Derivative Works" shall mean any work, whether in Source or Object
+      form, that is based on (or derived from) the Work and for which the
+      editorial revisions, annotations, elaborations, or other modifications
+      represent, as a whole, an original work of authorship. For the purposes
+      of this License, Derivative Works shall not include works that remain
+      separable from, or merely link (or bind by name) to the interfaces of,
+      the Work and Derivative Works thereof.
+
+      "Contribution" shall mean any work of authorship, including
+      the original version of the Work and any modifications or additions
+      to that Work or Derivative Works thereof, that is intentionally
+      submitted to Licensor for inclusion in the Work by the copyright owner
+      or by an individual or Legal Entity authorized to submit on behalf of
+      the copyright owner. For the purposes of this definition, "submitted"
+      means any form of electronic, verbal, or written communication sent
+      to the Licensor or its representatives, including but not limited to
+      communication on electronic mailing lists, source code control systems,
+      and issue tracking systems that are managed by, or on behalf of, the
+      Licensor for the purpose of discussing and improving the Work, but
+      excluding communication that is conspicuously marked or otherwise
+      designated in writing by the copyright owner as "Not a Contribution."
+
+      "Contributor" shall mean Licensor and any individual or Legal Entity
+      on behalf of whom a Contribution has been received by Licensor and
+      subsequently incorporated within the Work.
+
+   2. Grant of Copyright License. Subject to the terms and conditions of
+      this License, each Contributor hereby grants to You a perpetual,
+      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+      copyright license to reproduce, prepare Derivative Works of,
+      publicly display, publicly perform, sublicense, and distribute the
+      Work and such Derivative Works in Source or Object form.
+
+   3. Grant of Patent License. Subject to the terms and conditions of
+      this License, each Contributor hereby grants to You a perpetual,
+      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+      (except as stated in this section) patent license to make, have made,
+      use, offer to sell, sell, import, and otherwise transfer the Work,
+      where such license applies only to those patent claims licensable
+      by such Contributor that are necessarily infringed by their
+      Contribution(s) alone or by combination of their Contribution(s)
+      with the Work to which such Contribution(s) was submitted. If You
+      institute patent litigation against any entity (including a
+      cross-claim or counterclaim in a lawsuit) alleging that the Work
+      or a Contribution incorporated within the Work constitutes direct
+      or contributory patent infringement, then any patent licenses
+      granted to You under this License for that Work shall terminate
+      as of the date such litigation is filed.
+
+   4. Redistribution. You may reproduce and distribute copies of the
+      Work or Derivative Works thereof in any medium, with or without
+      modifications, and in Source or Object form, provided that You
+      meet the following conditions:
+
+      (a) You must give any other recipients of the Work or
+          Derivative Works a copy of this License; and
+
+      (b) You must cause any modified files to carry prominent notices
+          stating that You changed the files; and
+
+      (c) You must retain, in the Source form of any Derivative Works
+          that You distribute, all copyright, patent, trademark, and
+          attribution notices from the Source form of the Work,
+          excluding those notices that do not pertain to any part of
+          the Derivative Works; and
+
+      (d) If the Work includes a "NOTICE" text file as part of its
+          distribution, then any Derivative Works that You distribute must
+          include a readable copy of the attribution notices contained
+          within such NOTICE file, excluding those notices that do not
+          pertain to any part of the Derivative Works, in at least one
+          of the following places: within a NOTICE text file distributed
+          as part of the Derivative Works; within the Source form or
+          documentation, if provided along with the Derivative Works; or,
+          within a display generated by the Derivative Works, if and
+          wherever such third-party notices normally appear. The contents
+          of the NOTICE file are for informational purposes only and
+          do not modify the License. You may add Your own attribution
+          notices within Derivative Works that You distribute, alongside
+          or as an addendum to the NOTICE text from the Work, provided
+          that such additional attribution notices cannot be construed
+          as modifying the License.
+
+      You may add Your own copyright statement to Your modifications and
+      may provide additional or different license terms and conditions
+      for use, reproduction, or distribution of Your modifications, or
+      for any such Derivative Works as a whole, provided Your use,
+      reproduction, and distribution of the Work otherwise complies with
+      the conditions stated in this License.
+
+   5. Submission of Contributions. Unless You explicitly state otherwise,
+      any Contribution intentionally submitted for inclusion in the Work
+      by You to the Licensor shall be under the terms and conditions of
+      this License, without any additional terms or conditions.
+      Notwithstanding the above, nothing herein shall supersede or modify
+      the terms of any separate license agreement you may have executed
+      with Licensor regarding such Contributions.
+
+   6. Trademarks. This License does not grant permission to use the trade
+      names, trademarks, service marks, or product names of the Licensor,
+      except as required for reasonable and customary use in describing the
+      origin of the Work and reproducing the content of the NOTICE file.
+
+   7. Disclaimer of Warranty. Unless required by applicable law or
+      agreed to in writing, Licensor provides the Work (and each
+      Contributor provides its Contributions) on an "AS IS" BASIS,
+      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+      implied, including, without limitation, any warranties or conditions
+      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+      PARTICULAR PURPOSE. You are solely responsible for determining the
+      appropriateness of using or redistributing the Work and assume any
+      risks associated with Your exercise of permissions under this License.
+
+   8. Limitation of Liability. In no event and under no legal theory,
+      whether in tort (including negligence), contract, or otherwise,
+      unless required by applicable law (such as deliberate and grossly
+      negligent acts) or agreed to in writing, shall any Contributor be
+      liable to You for damages, including any direct, indirect, special,
+      incidental, or consequential damages of any character arising as a
+      result of this License or out of the use or inability to use the
+      Work (including but not limited to damages for loss of goodwill,
+      work stoppage, computer failure or malfunction, or any and all
+      other commercial damages or losses), even if such Contributor
+      has been advised of the possibility of such damages.
+
+   9. Accepting Warranty or Additional Liability. While redistributing
+      the Work or Derivative Works thereof, You may choose to offer,
+      and charge a fee for, acceptance of support, warranty, indemnity,
+      or other liability obligations and/or rights consistent with this
+      License. However, in accepting such obligations, You may act only
+      on Your own behalf and on Your sole responsibility, not on behalf
+      of any other Contributor, and only if You agree to indemnify,
+      defend, and hold each Contributor harmless for any liability
+      incurred by, or claims asserted against, such Contributor by reason
+      of your accepting any such warranty or additional liability.
+
+   END OF TERMS AND CONDITIONS
+
+   APPENDIX: How to apply the Apache License to your work.
+
+      To apply the Apache License to your work, attach the following
+      boilerplate notice, with the fields enclosed by brackets "[]"
+      replaced with your own identifying information. (Don't include
+      the brackets!)  The text should be enclosed in the appropriate
+      comment syntax for the file format. We also recommend that a
+      file or class name and description of purpose be included on the
+      same "printed page" as the copyright notice for easier
+      identification within third-party archives.
+
+   Copyright [yyyy] [name of copyright owner]
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License.

+ 4 - 0
README-run.md

@@ -0,0 +1,4 @@
+前端正确启动方式(关闭交互提示,便于自动化预览):
+uv run streamlit run app.py --server.port 8503 --server.headless true --browser.gatherUsageStats false
+后端正确启动方式(关闭交互提示,便于自动化预览):
+uv run python start_server.py --host 0.0.0.0 --port 5344

+ 164 - 0
README-zh.md

@@ -0,0 +1,164 @@
+# SoraWatermarkCleaner
+
+[English](README.md) | 中文
+
+这个项目提供了一种优雅的方式来移除 Sora2 生成视频中的 Sora 水印。
+
+
+- 移除水印后
+
+https://github.com/user-attachments/assets/8cdc075e-7d15-4d04-8fa2-53dd287e5f4c
+
+- 原始视频
+
+https://github.com/user-attachments/assets/3c850ff1-b8e3-41af-a46f-2c734406e77d
+
+⭐️: 
+
+1. **YOLO 权重已更新** — 请尝试新版本的水印检测模型,效果会更好!
+
+2. **数据集已开源** — 我们已经将标注好的数据集上传到了 Hugging Face,查看[此数据集](https://huggingface.co/datasets/LLinked/sora-watermark-dataset)。欢迎训练你自己的检测模型或改进我们的模型!
+
+3. **一键便携版已发布** — [点击这里下载](#3-一键便携版),Windows 用户无需安装即可使用!
+
+
+## 1. 方法
+
+SoraWatermarkCleaner(后面我们简称为 `SoraWm`)由两部分组成:
+
+- SoraWaterMarkDetector:我们训练了一个 yolov11s 版本来检测 Sora 水印。(感谢 YOLO!)
+
+- WaterMarkCleaner:我们参考了 IOPaint 的实现,使用 LAMA 模型进行水印移除。
+
+  (此代码库来自 https://github.com/Sanster/IOPaint#,感谢他们的出色工作!)
+
+我们的 SoraWm 完全由深度学习驱动,在许多生成的视频中都能产生良好的效果。
+
+
+
+## 2. 安装
+视频处理需要 [FFmpeg](https://ffmpeg.org/),请先安装它。我们强烈推荐使用 `uv` 来安装环境:
+
+1. 安装:
+
+```bash
+uv sync
+```
+
+> 现在环境将被安装在 `.venv` 目录下,你可以使用以下命令激活环境:
+>
+> ```bash
+> source .venv/bin/activate
+> ```
+
+2. 下载预训练模型:
+
+训练好的 YOLO 权重将存储在 `resources` 目录中,文件名为 `best.pt`。它将从 https://github.com/linkedlist771/SoraWatermarkCleaner/releases/download/V0.0.1/best.pt 自动下载。`Lama` 模型从 https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt 下载,并将存储在 torch 缓存目录中。两者都是自动下载的,如果失败,请检查你的网络状态。
+
+## 3. 一键便携版
+
+对于不想手动安装的用户,我们提供了**一键便携版本**,包含所有预配置的依赖项,开箱即用。
+
+### 下载链接
+
+**Google Drive(谷歌云盘):**
+- [从 Google Drive 下载](https://drive.google.com/file/d/1ujH28aHaCXGgB146g6kyfz3Qxd-wHR1c/view?usp=share_link)
+
+**百度网盘(推荐国内用户使用):**
+- 链接:https://pan.baidu.com/s/1i4exYsPvXv0evnGs5MWcYA?pwd=3jr6
+- 提取码:`3jr6`
+
+### 特点
+- ✅ 无需安装
+- ✅ 包含所有依赖
+- ✅ 预配置环境
+- ✅ 开箱即用
+
+只需下载、解压并运行!
+
+## 4. 演示
+
+基本用法,只需尝试 `example.py`:
+
+```python
+
+from pathlib import Path
+from sorawm.core import SoraWM
+
+
+if __name__ == "__main__":
+    input_video_path = Path(
+        "resources/dog_vs_sam.mp4"
+    )
+    output_video_path = Path("outputs/sora_watermark_removed.mp4")
+    sora_wm = SoraWM()
+    sora_wm.run(input_video_path, output_video_path)
+
+```
+
+我们还提供了基于 `streamlit` 的交互式网页界面,使用以下命令尝试:
+
+```bash
+streamlit run app.py
+```
+
+<img src="resources/app.png" style="zoom: 25%;" />
+
+## 5. WebServer
+
+在这里,我们提供了一个基于 FastAPI 的 Web 服务器,可以快速将这个水印清除器转换为服务。
+
+只需运行:
+
+```bash
+python start_server.py
+```
+
+Web 服务器将在端口 `5344` 启动,你可以查看 FastAPI [文档](http://localhost:5344/docs) 了解详情,有三个路由:
+
+1. submit_remove_task:
+
+   > 上传视频后,会返回一个任务 ID,该视频将立即被处理。
+
+   <img src="resources/53abf3fd-11a9-4dd7-a348-34920775f8ad.png" alt="image" style="zoom: 25%;" />
+
+2. get_results:
+
+你可以使用上面的任务 ID 检索任务状态,它会显示视频处理的百分比。一旦完成,返回的数据中会有下载 URL。
+
+3. download:
+
+你可以使用第2步中的下载 URL 来获取清理后的视频。
+
+## 6. 数据集
+
+我们已经将标注好的数据集上传到了 Hugging Face,请查看 https://huggingface.co/datasets/LLinked/sora-watermark-dataset。欢迎训练你自己的检测模型或改进我们的模型!
+
+
+
+## 7. API
+
+打包为 Cog 并[发布到 Replicate](https://replicate.com/uglyrobot/sora2-watermark-remover),便于基于 API 的简单使用。
+
+## 8. 许可证
+
+Apache License
+
+
+## 9. 引用
+
+如果你使用了这个项目,请引用:
+
+```bibtex
+@misc{sorawatermarkcleaner2025,
+  author = {linkedlist771},
+  title = {SoraWatermarkCleaner},
+  year = {2025},
+  url = {https://github.com/linkedlist771/SoraWatermarkCleaner}
+}
+```
+
+## 10. 致谢
+
+- [IOPaint](https://github.com/Sanster/IOPaint) 提供的 LAMA 实现
+- [Ultralytics YOLO](https://github.com/ultralytics/ultralytics) 提供的目标检测

+ 183 - 0
README.md

@@ -0,0 +1,183 @@
+# SoraWatermarkCleaner
+
+English | [中文](README-zh.md)
+
+This project provides an elegant way to remove the Sora watermark from Sora 2-generated videos.
+
+
+
+<table>
+  <tr>
+    <td width="50%">
+      <h3 align="center">Watermark removed</h3>
+      <video src="https://github.com/user-attachments/assets/8cdc075e-7d15-4d04-8fa2-53dd287e5f4c" width="100%"></video>
+    </td>
+    <td width="50%">
+      <h3 align="center">Original</h3>
+      <video src="https://github.com/user-attachments/assets/4f032fc7-97da-471b-9a54-9de2a434fa57" width="100%"></video>
+    </td>
+  </tr>
+</table>
+
+
+
+
+
+
+
+
+⭐️: 
+
+1. **YOLO weights have been updated, try the new watermark detection model, it should work better!**
+
+2. **We have uploaded the labelled dataset to Hugging Face, check this [dataset](https://huggingface.co/datasets/LLinked/sora-watermark-dataset) out. Feel free to train your own detector model or improve ours!**
+3. **One-click portable build is available** — [Download here](#3-one-click-portable-version) for Windows users! No installation required.
+
+
+## 1. Method
+
+The SoraWatermarkCleaner (we call it `SoraWm` below) is composed of two parts:
+
+- SoraWaterMarkDetector: We trained a YOLOv11s model to detect the Sora watermark. (Thank you, YOLO!)
+
+- WaterMarkCleaner: We follow IOPaint's implementation and use the LaMa model for watermark removal.
+
+  (This codebase is from https://github.com/Sanster/IOPaint#, thanks for their amazing work!)
+
+Our SoraWm is purely deep-learning driven and yields good results on many generated videos.
+
+
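+Conceptually, each frame goes through the detector and then the cleaner. The snippet below is a rough sketch of what `sorawm/core.py` does per frame; gap filling for missed detections, video encoding, and audio merging are omitted:
+
+```python
+import numpy as np
+
+from sorawm.watermark_cleaner import WaterMarkCleaner
+from sorawm.watermark_detector import SoraWaterMarkDetector
+
+detector = SoraWaterMarkDetector()
+cleaner = WaterMarkCleaner()
+
+
+def clean_frame(frame: np.ndarray) -> np.ndarray:
+    # Locate the watermark bounding box on this frame
+    result = detector.detect(frame)
+    if not result["detected"]:
+        return frame
+    x1, y1, x2, y2 = result["bbox"]
+    # Mask the watermark region and inpaint it with the cleaner
+    mask = np.zeros(frame.shape[:2], dtype=np.uint8)
+    mask[y1:y2, x1:x2] = 255
+    return cleaner.clean(frame, mask)
+```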
+
+## 2. Installation
+
+[FFmpeg](https://ffmpeg.org/) is needed for video processing, so please install it first. We highly recommend using `uv` to set up the environment:
+
+1. Installation:
+
+```bash
+uv sync
+```
+
+> The environment will now be installed in `.venv`; you can activate it with:
+>
+> ```bash
+> source .venv/bin/activate
+> ```
+
+2. Download the pretrained models:
+
+The trained YOLO weights are stored in the `resources` dir as `best.pt` and will be automatically downloaded from https://github.com/linkedlist771/SoraWatermarkCleaner/releases/download/V0.0.1/best.pt. The `Lama` model is downloaded from https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt and stored in the torch cache dir. Both downloads are automatic; if they fail, please check your internet connection.
+
+## 3. One-Click Portable Version
+
+For users who prefer a ready-to-use solution without manual installation, we provide a **one-click portable distribution** that includes all dependencies pre-configured.
+
+### Download Links
+
+**Google Drive:**
+- [Download from Google Drive](https://drive.google.com/file/d/1ujH28aHaCXGgB146g6kyfz3Qxd-wHR1c/view?usp=share_link)
+
+**Baidu Pan (百度网盘) - For users in China:**
+- Link: https://pan.baidu.com/s/1i4exYsPvXv0evnGs5MWcYA?pwd=3jr6
+- Extract Code (提取码): `3jr6`
+
+### Features
+- ✅ No installation required
+- ✅ All dependencies included
+- ✅ Pre-configured environment
+- ✅ Ready to use out of the box
+
+Simply download, extract, and run!
+
+## 4. Demo
+
+For basic usage, just try `example.py`:
+
+```python
+
+from pathlib import Path
+from sorawm.core import SoraWM
+
+
+if __name__ == "__main__":
+    input_video_path = Path(
+        "resources/dog_vs_sam.mp4"
+    )
+    output_video_path = Path("outputs/sora_watermark_removed.mp4")
+    sora_wm = SoraWM()
+    sora_wm.run(input_video_path, output_video_path)
+
+```
+
+We also provide you with a `streamlit` based interactive web page, try it with:
+
+```bash
+streamlit run app.py
+```
+
+<img src="resources/app.png" style="zoom: 25%;" />
+
+## 5. WebServer
+
+Here, we provide a **FastAPI-based web server** that can quickly turn this watermark remover into a service.
+
+Simply run:
+
+```
+python start_server.py
+```
+
+The web server will start on port **5344**.
+
+You can view the FastAPI [documentation](http://localhost:5344/docs) for more details.
+
+There are three routes available:
+
+1. **submit_remove_task**
+
+   > After uploading a video, a task ID will be returned, and the video will begin processing immediately.
+
+<img src="resources/53abf3fd-11a9-4dd7-a348-34920775f8ad.png" alt="image" style="zoom: 25%;" />
+
+2. **get_results**
+
+You can use the task ID obtained above to check the task status.
+
+It will display the percentage of video processing completed.
+
+Once finished, the returned data will include a **download URL**.
+
+3. **download**
+
+You can use the **download URL** from step 2 to retrieve the cleaned video.
+
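+For reference, a minimal Python client for this workflow could look like the sketch below. It relies on the same response fields that `app.py` uses (`task_id`, `remove_task_id`, `status`, `percentage`, `download_url`); adjust the file paths for your own videos.
+
+```python
+import time
+
+import requests
+
+BASE_URL = "http://localhost:5344"
+
+# 1. Submit a video for watermark removal
+with open("resources/puppies.mp4", "rb") as f:
+    resp = requests.post(f"{BASE_URL}/submit_remove_task", files={"video": f})
+task_id = resp.json()["task_id"]
+
+# 2. Poll the task status until it finishes
+while True:
+    info = requests.get(
+        f"{BASE_URL}/get_results", params={"remove_task_id": task_id}
+    ).json()
+    print(f"{info['status']}: {info.get('percentage', 0)}%")
+    if info["status"] in ("FINISHED", "ERROR"):
+        break
+    time.sleep(1)
+
+# 3. Download the cleaned video
+if info["status"] == "FINISHED":
+    video = requests.get(f"{BASE_URL}{info['download_url']}")
+    with open("cleaned.mp4", "wb") as out:
+        out.write(video.content)
+```
+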
+## 6. Datasets
+
+We have uploaded the labelled dataset to Hugging Face; check it out at https://huggingface.co/datasets/LLinked/sora-watermark-dataset. Feel free to train your own detector model or improve ours!
+
+## 7. API
+
+Packaged as a Cog and [published to Replicate](https://replicate.com/uglyrobot/sora2-watermark-remover) for simple API based usage.
+
+## 8. License
+
+ Apache License
+
+
+## 9. Citation
+
+If you use this project, please cite:
+
+```bibtex
+@misc{sorawatermarkcleaner2025,
+  author = {linkedlist771},
+  title = {SoraWatermarkCleaner},
+  year = {2025},
+  url = {https://github.com/linkedlist771/SoraWatermarkCleaner}
+}
+```
+
+## 10. Acknowledgments
+
+- [IOPaint](https://github.com/Sanster/IOPaint) for the LAMA implementation
+- [Ultralytics YOLO](https://github.com/ultralytics/ultralytics) for object detection

+ 175 - 0
app.py

@@ -0,0 +1,175 @@
+import shutil
+import tempfile
+from pathlib import Path
+
+import requests
+import time
+import streamlit as st
+
+from sorawm.core import SoraWM
+
+
+def main():
+    st.set_page_config(
+        page_title="Sora Watermark Cleaner", page_icon="🎬", layout="centered"
+    )
+
+    st.title("🎬 Sora Watermark Cleaner")
+    st.markdown("Remove watermarks from Sora-generated videos with ease")
+
+    # Backend mode toggle
+    use_backend = st.sidebar.checkbox("使用后端服务 (FastAPI)", value=True)
+    backend_url = st.sidebar.text_input(
+        "后端地址",
+        value="http://localhost:5344",
+        help="FastAPI 服务的基础地址",
+    )
+
+    # Initialize SoraWM only in local mode
+    if not use_backend:
+        if "sora_wm" not in st.session_state:
+            with st.spinner("Loading AI models..."):
+                st.session_state.sora_wm = SoraWM()
+
+    st.markdown("---")
+
+    # File uploader
+    uploaded_file = st.file_uploader(
+        "Upload your video",
+        type=["mp4", "avi", "mov", "mkv"],
+        help="Select a video file to remove watermarks",
+    )
+
+    if uploaded_file is not None:
+        # Display video info
+        st.success(f"✅ Uploaded: {uploaded_file.name}")
+        st.video(uploaded_file)
+
+        # Process button
+        if st.button("🚀 Remove Watermark", type="primary", use_container_width=True):
+            try:
+                # Create progress bar and status text
+                progress_bar = st.progress(0)
+                status_text = st.empty()
+
+                if use_backend:
+                    # Backend mode: submit the task
+                    submit_url = f"{backend_url}/submit_remove_task"
+                    video_bytes = uploaded_file.getvalue()
+                    files = {"video": (uploaded_file.name, video_bytes, "video/mp4")}
+                    status_text.text("📤 正在提交到后端...")
+                    resp = requests.post(submit_url, files=files, timeout=60)
+                    resp.raise_for_status()
+                    data = resp.json()
+                    task_id = data.get("task_id")
+                    if not task_id:
+                        raise RuntimeError("后端未返回 task_id")
+
+                    # Poll the task status
+                    status_url = f"{backend_url}/get_results"
+                    download_url = None
+                    while True:
+                        r = requests.get(status_url, params={"remove_task_id": task_id}, timeout=30)
+                        r.raise_for_status()
+                        info = r.json()
+                        percentage = int(info.get("percentage", 0))
+                        status = info.get("status", "UNKNOWN")
+                        progress_bar.progress(percentage / 100)
+                        if percentage < 50:
+                            status_text.text(f"🔍 后端检测水印... {percentage}% ({status})")
+                        elif percentage < 95:
+                            status_text.text(f"🧹 后端移除水印... {percentage}% ({status})")
+                        else:
+                            status_text.text(f"🎵 后端合并音频... {percentage}% ({status})")
+                        if status == "FINISHED":
+                            download_url = info.get("download_url")
+                            break
+                        elif status == "ERROR":
+                            raise RuntimeError("后端处理失败")
+                        time.sleep(1)
+
+                    if not download_url:
+                        raise RuntimeError("后端未提供下载链接")
+
+                    final_url = f"{backend_url}{download_url}"
+                    progress_bar.progress(1.0)
+                    status_text.text("✅ 后端处理完成!")
+                    st.success("✅ 水印移除成功!")
+                    st.markdown("### 处理结果")
+                    st.video(final_url)
+
+                    # Download button (fetch the result locally so it can be offered for download)
+                    status_text.text("⬇️ 正在拉取处理后的视频用于下载...")
+                    dl_resp = requests.get(final_url, timeout=300)
+                    dl_resp.raise_for_status()
+                    st.download_button(
+                        label="⬇️ 下载处理后视频",
+                        data=dl_resp.content,
+                        file_name=f"cleaned_{uploaded_file.name}",
+                        mime="video/mp4",
+                        use_container_width=True,
+                    )
+                else:
+                    # Local-mode processing
+                    with tempfile.TemporaryDirectory() as tmp_dir:
+                        tmp_path = Path(tmp_dir)
+
+                        # Save uploaded file
+                        input_path = tmp_path / uploaded_file.name
+                        with open(input_path, "wb") as f:
+                            f.write(uploaded_file.read())
+
+                        # Process video
+                        output_path = tmp_path / f"cleaned_{uploaded_file.name}"
+
+                        def update_progress(progress: int):
+                            progress_bar.progress(progress / 100)
+                            if progress < 50:
+                                status_text.text(f"🔍 Detecting watermarks... {progress}%")
+                            elif progress < 95:
+                                status_text.text(f"🧹 Removing watermarks... {progress}%")
+                            else:
+                                status_text.text(f"🎵 Merging audio... {progress}%")
+
+                        # Run the watermark removal with progress callback
+                        st.session_state.sora_wm.run(
+                            input_path, output_path, progress_callback=update_progress
+                        )
+
+                        # Complete the progress bar
+                        progress_bar.progress(1.0)
+                        status_text.text("✅ Processing complete!")
+
+                        st.success("✅ Watermark removed successfully!")
+
+                        # Display result
+                        st.markdown("### Result")
+                        st.video(str(output_path))
+
+                        # Download button
+                        with open(output_path, "rb") as f:
+                            st.download_button(
+                                label="⬇️ Download Cleaned Video",
+                                data=f,
+                                file_name=f"cleaned_{uploaded_file.name}",
+                                mime="video/mp4",
+                                use_container_width=True,
+                            )
+            except Exception as e:
+                st.error(f"❌ 处理失败: {str(e)}")
+
+    # Footer
+    st.markdown("---")
+    st.markdown(
+        """
+        <div style='text-align: center'>
+            <p>Built with ❤️ using Streamlit and AI</p>
+            <p><a href='https://github.com/linkedlist771/SoraWatermarkCleaner'>GitHub Repository</a></p>
+        </div>
+        """,
+        unsafe_allow_html=True,
+    )
+
+
+if __name__ == "__main__":
+    main()

BIN
data/db.sqlite3


+ 49 - 0
datasets/make_yolo_images.py

@@ -0,0 +1,49 @@
+from pathlib import Path
+
+import cv2
+from tqdm import tqdm
+
+from sorawm.configs import ROOT
+
+videos_dir = ROOT / "videos"
+datasets_dir = ROOT / "datasets"
+images_dir = datasets_dir / "images"
+images_dir.mkdir(exist_ok=True, parents=True)
+
+if __name__ == "__main__":
+    fps_save_interval = 1  # Save every frame (interval of 1)
+
+    idx = 0
+    for video_path in tqdm(list(videos_dir.rglob("*.mp4"))):
+        # Open the video file
+        cap = cv2.VideoCapture(str(video_path))
+
+        if not cap.isOpened():
+            print(f"Error opening video: {video_path}")
+            continue
+
+        frame_count = 0
+
+        while True:
+            ret, frame = cap.read()
+
+            # Break if no more frames
+            if not ret:
+                break
+
+            # Save frame at the specified interval
+            if frame_count % fps_save_interval == 0:
+                # Create filename: image_idx_framecount.jpg
+                image_filename = f"image_{idx:06d}_frame_{frame_count:06d}.jpg"
+                image_path = images_dir / image_filename
+
+                # Save the frame
+                cv2.imwrite(str(image_path), frame)
+
+            frame_count += 1
+
+        # Release the video capture object
+        cap.release()
+        idx += 1
+
+    print(f"Processed {idx} videos, extracted frames saved to {images_dir}")

+ 9 - 0
example.py

@@ -0,0 +1,9 @@
+from pathlib import Path
+
+from sorawm.core import SoraWM
+
+if __name__ == "__main__":
+    input_video_path = Path("resources/dog_vs_sam.mp4")
+    output_video_path = Path("outputs/sora_watermark_removed.mp4")
+    sora_wm = SoraWM()
+    sora_wm.run(input_video_path, output_video_path)

+ 69 - 0
ffmpeg/README.md

@@ -0,0 +1,69 @@
+# FFmpeg 可执行文件目录
+
+## 用途
+
+这个目录用于存放 FFmpeg 可执行文件,使项目成为真正的便携版(无需系统安装 FFmpeg)。
+
+## Windows 用户配置步骤
+
+### 1. 下载 FFmpeg
+
+访问 [FFmpeg-Builds Release](https://github.com/BtbN/FFmpeg-Builds/releases) 页面:
+
+- 下载最新的 `ffmpeg-master-latest-win64-gpl.zip`(约 120MB)
+- 或者下载特定版本,如 `ffmpeg-n6.1-latest-win64-gpl-6.1.zip`
+
+### 2. 解压并复制文件
+
+1. 解压下载的 zip 文件
+2. 在解压后的文件夹中找到 `bin` 目录
+3. 将以下两个文件复制到**当前目录**(`ffmpeg/`):
+   - `ffmpeg.exe` - FFmpeg 主程序
+   - `ffprobe.exe` - FFmpeg 媒体信息探测工具
+
+### 3. 验证配置
+
+完成后,此目录应包含:
+
+```
+ffmpeg/
+├── .gitkeep
+├── README.md
+├── ffmpeg.exe    ← 你复制的文件
+└── ffprobe.exe   ← 你复制的文件
+```
+
+### 4. 测试
+
+运行项目中的测试脚本验证配置:
+
+```bash
+python test_ffmpeg_setup.py
+```
+
+如果配置正确,你将看到:`✓ 测试通过!FFmpeg已正确配置并可以使用`
+
+## macOS/Linux 用户
+
+如果需要便携版,请:
+
+1. 下载对应平台的 FFmpeg 二进制文件
+2. 将 `ffmpeg` 和 `ffprobe` 可执行文件放到此目录
+3. 确保文件有执行权限:`chmod +x ffmpeg ffprobe`
+
+## 注意事项
+
+- 这些可执行文件不会被 git 提交(已在 `.gitignore` 中配置)
+- 程序会自动检测并使用此目录下的 FFmpeg
+- 如果此目录没有 FFmpeg,程序会尝试使用系统安装的版本(检测与回退逻辑的示意见下方代码)
+
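+以下是这种检测与回退逻辑的一个简化示意(仅作说明,`resolve_ffmpeg` 为假设的辅助函数名,并非项目中的实际实现):
+
+```python
+import shutil
+from pathlib import Path
+
+# 项目根目录下的 ffmpeg/ 目录
+FFMPEG_DIR = Path("ffmpeg")
+
+
+def resolve_ffmpeg() -> str:
+    """优先使用 ffmpeg/ 目录下的可执行文件,否则回退到系统安装的版本。"""
+    for name in ("ffmpeg.exe", "ffmpeg"):
+        local = FFMPEG_DIR / name
+        if local.exists():
+            return str(local)
+    system = shutil.which("ffmpeg")
+    if system:
+        return system
+    raise FileNotFoundError("未找到 FFmpeg,请按上面的步骤配置")
+```
+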
+## 下载链接汇总
+
+- **Windows**: https://github.com/BtbN/FFmpeg-Builds/releases
+- **官方网站**: https://ffmpeg.org/download.html
+- **镜像站点**: https://www.gyan.dev/ffmpeg/builds/ (Windows)
+
+## 许可证
+
+FFmpeg 使用 GPL 许可证,请遵守相关条款。
+

+ 43 - 0
logs/log_file.log

@@ -0,0 +1,43 @@
+2025-10-29 15:15:58.739 | INFO     | __main__:start_server:19 - Starting server at 0.0.0.0:5344
+2025-10-29 15:15:58.763 | INFO     | sorawm.server.lifespan:lifespan:13 - Starting up...
+2025-10-29 15:15:58.765 | INFO     | __main__:start_server:26 - Server shutdown.
+2025-10-29 15:16:46.273 | INFO     | __main__:start_server:19 - Starting server at 0.0.0.0:5344
+2025-10-29 15:16:46.283 | INFO     | sorawm.server.lifespan:lifespan:13 - Starting up...
+2025-10-29 15:16:46.290 | INFO     | sorawm.server.lifespan:lifespan:16 - Database initialized
+2025-10-29 15:16:46.290 | INFO     | sorawm.server.worker:initialize:26 - Initializing SoraWM models...
+2025-10-29 15:16:46.290 | DEBUG    | sorawm.watermark_detector:__init__:18 - Begin to load yolo water mark detet model.
+2025-10-29 15:16:46.335 | DEBUG    | sorawm.utils.devices_utils:get_device:13 - Using device: mps
+2025-10-29 15:16:46.604 | DEBUG    | sorawm.watermark_detector:__init__:21 - Yolo water mark detet model loaded.
+2025-10-29 15:16:46.607 | INFO     | sorawm.iopaint.model_manager:init_model:47 - Loading model: cv2
+2025-10-29 15:16:46.607 | INFO     | sorawm.iopaint.helper:switch_mps_device:30 - cv2 not support mps, switch to cpu
+2025-10-29 15:16:46.607 | INFO     | sorawm.server.worker:initialize:28 - SoraWM models initialized
+2025-10-29 15:16:46.608 | INFO     | sorawm.server.lifespan:lifespan:22 - Application started successfully
+2025-10-29 15:16:46.608 | INFO     | sorawm.server.worker:run:64 - Worker started, waiting for tasks...
+2025-10-29 16:05:14.030 | INFO     | __main__:start_server:19 - Starting server at 0.0.0.0:5344
+2025-10-29 16:05:14.043 | INFO     | sorawm.server.lifespan:lifespan:13 - Starting up...
+2025-10-29 16:05:14.047 | INFO     | sorawm.server.lifespan:lifespan:16 - Database initialized
+2025-10-29 16:05:14.048 | INFO     | sorawm.server.worker:initialize:26 - Initializing SoraWM models...
+2025-10-29 16:05:14.048 | DEBUG    | sorawm.watermark_detector:__init__:18 - Begin to load yolo water mark detet model.
+2025-10-29 16:05:14.132 | DEBUG    | sorawm.utils.devices_utils:get_device:13 - Using device: mps
+2025-10-29 16:05:14.227 | DEBUG    | sorawm.watermark_detector:__init__:21 - Yolo water mark detet model loaded.
+2025-10-29 16:05:14.230 | INFO     | sorawm.iopaint.model_manager:init_model:47 - Loading model: cv2
+2025-10-29 16:05:14.230 | INFO     | sorawm.iopaint.helper:switch_mps_device:30 - cv2 not support mps, switch to cpu
+2025-10-29 16:05:14.230 | INFO     | sorawm.server.worker:initialize:28 - SoraWM models initialized
+2025-10-29 16:05:14.231 | INFO     | sorawm.server.lifespan:lifespan:22 - Application started successfully
+2025-10-29 16:05:14.231 | INFO     | sorawm.server.worker:run:64 - Worker started, waiting for tasks...
+2025-10-29 16:06:43.138 | INFO     | __main__:start_server:19 - Starting server at 0.0.0.0:5344
+2025-10-29 16:06:43.148 | INFO     | sorawm.server.lifespan:lifespan:13 - Starting up...
+2025-10-29 16:06:43.150 | INFO     | sorawm.server.lifespan:lifespan:16 - Database initialized
+2025-10-29 16:06:43.150 | INFO     | sorawm.server.worker:initialize:26 - Initializing SoraWM models...
+2025-10-29 16:06:43.150 | DEBUG    | sorawm.watermark_detector:__init__:18 - Begin to load yolo water mark detet model.
+2025-10-29 16:06:43.190 | DEBUG    | sorawm.utils.devices_utils:get_device:13 - Using device: mps
+2025-10-29 16:06:43.279 | DEBUG    | sorawm.watermark_detector:__init__:21 - Yolo water mark detet model loaded.
+2025-10-29 16:06:43.280 | INFO     | sorawm.iopaint.model_manager:init_model:47 - Loading model: cv2
+2025-10-29 16:06:43.280 | INFO     | sorawm.iopaint.helper:switch_mps_device:30 - cv2 not support mps, switch to cpu
+2025-10-29 16:06:43.280 | INFO     | sorawm.server.worker:initialize:28 - SoraWM models initialized
+2025-10-29 16:06:43.280 | INFO     | sorawm.server.lifespan:lifespan:22 - Application started successfully
+2025-10-29 16:06:43.280 | INFO     | sorawm.server.worker:run:64 - Worker started, waiting for tasks...
+2025-10-29 21:53:46.469 | INFO     | sorawm.server.worker:create_task:40 - Task d75ba3bc-f8ea-4c54-a183-35394d3609dc created with UPLOADING status
+2025-10-29 21:53:46.516 | INFO     | sorawm.server.worker:queue_task:52 - Task d75ba3bc-f8ea-4c54-a183-35394d3609dc queued for processing: /Users/max_liu/Downloads/SoraWatermarkCleaner-main/working_dir/uploads/b061e57b-3f18-4d2a-96c4-d42e4b3d76e5_4_1709482478.mp4
+2025-10-29 21:53:46.517 | INFO     | sorawm.server.worker:run:67 - Processing task d75ba3bc-f8ea-4c54-a183-35394d3609dc: /Users/max_liu/Downloads/SoraWatermarkCleaner-main/working_dir/uploads/b061e57b-3f18-4d2a-96c4-d42e4b3d76e5_4_1709482478.mp4
+2025-10-29 21:53:46.728 | ERROR    | sorawm.server.worker:run:109 - Error processing task d75ba3bc-f8ea-4c54-a183-35394d3609dc: [Errno 2] No such file or directory: 'ffprobe'

File diff suppressed because it is too large
+ 48 - 0
notebooks/imputation.ipynb


+ 26 - 0
one-click-portable.md

@@ -0,0 +1,26 @@
+# One-Click Portable Version | 一键便携版
+
+For **Windows** users - No installation required! 
+
+适用于 **Windows** 用户 - 无需安装!
+
+## Download | 下载
+
+**Google Drive:**
+- https://drive.google.com/file/d/1ujH28aHaCXGgB146g6kyfz3Qxd-wHR1c/view?usp=share_link
+
+**Baidu Pan | 百度网盘:**
+- Link | 链接: https://pan.baidu.com/s/1_tdgs-3-dLNn0IbufIM75g?pwd=fiju
+- Extract Code | 提取码: `fiju`
+
+## Usage | 使用方法
+
+1. Download and extract the zip file | 下载并解压 zip 文件
+2. Double-click `run.bat` | 双击 `run.bat` 文件
+3. The web service will start automatically! | 网页服务将自动启动!
+
+## Features | 特点
+
+- ✅ Zero installation | 无需安装
+- ✅ All dependencies included | 包含所有依赖
+- ✅ Ready to use | 开箱即用

+ 40 - 0
pyproject.toml

@@ -0,0 +1,40 @@
+[project]
+name = "sorawatermarkcleaner"
+version = "0.1.0"
+description = "Add your description here"
+readme = "README.md"
+requires-python = ">=3.12"
+dependencies = [
+    "aiofiles>=24.1.0",
+    "aiosqlite>=0.21.0",
+    "diffusers>=0.35.1",
+    "einops>=0.8.1",
+    "fastapi==0.108.0",
+    "ffmpeg-python>=0.2.0",
+    "fire>=0.7.1",
+    "httpx>=0.28.1",
+    "huggingface-hub>=0.35.3",
+    "jupyter>=1.1.1",
+    "loguru>=0.7.3",
+    "matplotlib>=3.10.6",
+    "notebook>=7.4.7",
+    "omegaconf>=2.3.0",
+    "opencv-python>=4.12.0.88",
+    "pandas>=2.3.3",
+    "pydantic>=2.11.10",
+    "python-multipart>=0.0.20",
+    "requests>=2.32.5",
+    "ruptures>=1.1.10",
+    "scikit-learn>=1.7.2",
+    "sqlalchemy>=2.0.43",
+    "greenlet>=3.0.3",
+    "streamlit>=1.50.0",
+    "torch>=2.5.0",
+    "torchvision>=0.20.0",
+    "tqdm>=4.67.1",
+    "transformers>=4.57.0",
+    "ultralytics>=8.3.204",
+    "uuid>=1.30",
+    "uvicorn>=0.35.0",
+]
+

BIN
resources/19700121_1645_68e0a027836c8191a50bea3717ea7485.mp4


BIN
resources/53abf3fd-11a9-4dd7-a348-34920775f8ad.png


BIN
resources/app.png


BIN
resources/best.pt


BIN
resources/dog_vs_sam.mp4


File diff suppressed because it is too large
+ 23 - 0
resources/first_frame.json


BIN
resources/first_frame.png


BIN
resources/puppies.mp4


BIN
resources/sora_watermark_removed.mp4


BIN
resources/watermark_template.png


+ 0 - 0
sorawm/__init__.py


BIN
sorawm/__pycache__/__init__.cpython-312.pyc


BIN
sorawm/__pycache__/configs.cpython-312.pyc


BIN
sorawm/__pycache__/core.cpython-312.pyc


BIN
sorawm/__pycache__/watermark_cleaner.cpython-312.pyc


BIN
sorawm/__pycache__/watermark_detector.cpython-312.pyc


+ 27 - 0
sorawm/configs.py

@@ -0,0 +1,27 @@
+from pathlib import Path
+
+ROOT = Path(__file__).parent.parent
+
+
+RESOURCES_DIR = ROOT / "resources"
+WATER_MARK_TEMPLATE_IMAGE_PATH = RESOURCES_DIR / "watermark_template.png"
+
+WATER_MARK_DETECT_YOLO_WEIGHTS = RESOURCES_DIR / "best.pt"
+
+OUTPUT_DIR = ROOT / "output"
+
+OUTPUT_DIR.mkdir(exist_ok=True, parents=True)
+
+
+DEFAULT_WATERMARK_REMOVE_MODEL = "cv2"
+
+WORKING_DIR = ROOT / "working_dir"
+WORKING_DIR.mkdir(exist_ok=True, parents=True)
+
+LOGS_PATH = ROOT / "logs"
+LOGS_PATH.mkdir(exist_ok=True, parents=True)
+
+DATA_PATH = ROOT / "data"
+DATA_PATH.mkdir(exist_ok=True, parents=True)
+
+SQLITE_PATH = DATA_PATH / "db.sqlite3"

+ 199 - 0
sorawm/core.py

@@ -0,0 +1,199 @@
+from pathlib import Path
+from typing import Callable
+
+import ffmpeg
+import numpy as np
+from loguru import logger
+from tqdm import tqdm
+
+from sorawm.utils.video_utils import VideoLoader
+from sorawm.watermark_cleaner import WaterMarkCleaner
+from sorawm.watermark_detector import SoraWaterMarkDetector
+from sorawm.utils.imputation_utils import (
+    find_2d_data_bkps,
+    get_interval_average_bbox,
+    find_idxs_interval,
+)
+
+
+class SoraWM:
+    def __init__(self):
+        self.detector = SoraWaterMarkDetector()
+        self.cleaner = WaterMarkCleaner()
+
+    def run(
+        self,
+        input_video_path: Path,
+        output_video_path: Path,
+        progress_callback: Callable[[int], None] | None = None,
+    ):
+        input_video_loader = VideoLoader(input_video_path)
+        output_video_path.parent.mkdir(parents=True, exist_ok=True)
+        width = input_video_loader.width
+        height = input_video_loader.height
+        fps = input_video_loader.fps
+        total_frames = input_video_loader.total_frames
+
+        temp_output_path = output_video_path.parent / f"temp_{output_video_path.name}"
+        output_options = {
+            "pix_fmt": "yuv420p",
+            "vcodec": "libx264",
+            "preset": "slow",
+        }
+
+        if input_video_loader.original_bitrate:
+            output_options["video_bitrate"] = str(
+                int(int(input_video_loader.original_bitrate) * 1.2)
+            )
+        else:
+            output_options["crf"] = "18"
+
+        process_out = (
+            ffmpeg.input(
+                "pipe:",
+                format="rawvideo",
+                pix_fmt="bgr24",
+                s=f"{width}x{height}",
+                r=fps,
+            )
+            .output(str(temp_output_path), **output_options)
+            .overwrite_output()
+            .global_args("-loglevel", "error")
+            .run_async(pipe_stdin=True)
+        )
+
+        frame_bboxes = {}
+        detect_missed = []
+        bbox_centers = []
+        bboxes = []
+
+        logger.debug(
+            f"total frames: {total_frames}, fps: {fps}, width: {width}, height: {height}"
+        )
+        for idx, frame in enumerate(
+            tqdm(input_video_loader, total=total_frames, desc="Detect watermarks")
+        ):
+            detection_result = self.detector.detect(frame)
+            if detection_result["detected"]:
+                frame_bboxes[idx] = { "bbox": detection_result["bbox"]}
+                x1, y1, x2, y2 = detection_result["bbox"]
+                bbox_centers.append((int((x1 + x2) / 2), int((y1 + y2) / 2)))
+                bboxes.append((x1, y1, x2, y2))
+
+            else:
+                frame_bboxes[idx] = {"bbox": None}
+                detect_missed.append(idx)
+                bbox_centers.append(None)
+                bboxes.append(None)
+            # 10% - 50%
+            if progress_callback and idx % 10 == 0:
+                progress = 10 + int((idx / total_frames) * 40)
+                progress_callback(progress)
+
+        logger.debug(f"detect missed frames: {detect_missed}")
+        # logger.debug(f"bbox centers: \n{bbox_centers}")
+        if detect_missed:
+            # 1. find the bkps of the bbox centers
+            bkps = find_2d_data_bkps(bbox_centers)
+            # add the start and end position, to form the complete interval boundaries
+            bkps_full = [0] + bkps + [total_frames]
+            # logger.debug(f"bkps intervals: {bkps_full}")
+
+            # 2. calculate the average bbox of each interval
+            interval_bboxes = get_interval_average_bbox(bboxes, bkps_full)
+            # logger.debug(f"interval average bboxes: {interval_bboxes}")
+
+            # 3. find the interval index of each missed frame
+            missed_intervals = find_idxs_interval(detect_missed, bkps_full)
+            # logger.debug(
+            #     f"missed frame intervals: {list(zip(detect_missed, missed_intervals))}"
+            # )
+
+            # 4. fill the missed frames with the average bbox of the corresponding interval
+            for missed_idx, interval_idx in zip(detect_missed, missed_intervals):
+                if (
+                    interval_idx < len(interval_bboxes)
+                    and interval_bboxes[interval_idx] is not None
+                ):
+                    frame_bboxes[missed_idx]["bbox"] = interval_bboxes[interval_idx]
+                    logger.debug(f"Filled missed frame {missed_idx} with bbox:\n"
+                    f" {interval_bboxes[interval_idx]}")
+                else:
+                    # if the interval has no valid bbox, use the previous and next frame to complete (fallback strategy)
+                    before = max(missed_idx - 1, 0)
+                    after = min(missed_idx + 1, total_frames - 1)
+                    before_box = frame_bboxes[before]["bbox"]
+                    after_box = frame_bboxes[after]["bbox"]
+                    if before_box:
+                        frame_bboxes[missed_idx]["bbox"] = before_box
+                    elif after_box:
+                        frame_bboxes[missed_idx]["bbox"] = after_box
+        else:
+            del bboxes
+            del bbox_centers
+            del detect_missed
+        
+        input_video_loader = VideoLoader(input_video_path)
+
+        for idx, frame in enumerate(
+            tqdm(input_video_loader, total=total_frames, desc="Remove watermarks")
+        ):
+            bbox = frame_bboxes[idx]["bbox"]
+            if bbox is not None:
+                x1, y1, x2, y2 = bbox
+                mask = np.zeros((height, width), dtype=np.uint8)
+                mask[y1:y2, x1:x2] = 255
+                cleaned_frame = self.cleaner.clean(frame, mask)
+            else:
+                cleaned_frame = frame
+            process_out.stdin.write(cleaned_frame.tobytes())
+
+            # 50% - 95%
+            if progress_callback and idx % 10 == 0:
+                progress = 50 + int((idx / total_frames) * 45)
+                progress_callback(progress)
+
+        process_out.stdin.close()
+        process_out.wait()
+
+        # 95% - 99%
+        if progress_callback:
+            progress_callback(95)
+
+        self.merge_audio_track(input_video_path, temp_output_path, output_video_path)
+
+        if progress_callback:
+            progress_callback(99)
+
+    def merge_audio_track(
+        self, input_video_path: Path, temp_output_path: Path, output_video_path: Path
+    ):
+        logger.info("Merging audio track...")
+        video_stream = ffmpeg.input(str(temp_output_path))
+        audio_stream = ffmpeg.input(str(input_video_path)).audio
+
+        (
+            ffmpeg.output(
+                video_stream,
+                audio_stream,
+                str(output_video_path),
+                vcodec="copy",
+                acodec="aac",
+            )
+            .overwrite_output()
+            .run(quiet=True)
+        )
+        # Clean up temporary file
+        temp_output_path.unlink()
+        logger.info(f"Saved no watermark video with audio at: {output_video_path}")
+
+
+if __name__ == "__main__":
+    from pathlib import Path
+
+    input_video_path = Path(
+        "resources/19700121_1645_68e0a027836c8191a50bea3717ea7485.mp4"
+    )
+    output_video_path = Path("outputs/sora_watermark_removed.mp4")
+    sora_wm = SoraWM()
+    sora_wm.run(input_video_path, output_video_path)

+ 56 - 0
sorawm/iopaint/__init__.py

@@ -0,0 +1,56 @@
+import ctypes
+import importlib.util
+import logging
+import os
+import shutil
+
+os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
+# https://github.com/pytorch/pytorch/issues/27971#issuecomment-1768868068
+os.environ["ONEDNN_PRIMITIVE_CACHE_CAPACITY"] = "1"
+os.environ["LRU_CACHE_CAPACITY"] = "1"
+# prevent CPU memory leak when run model on GPU
+# https://github.com/pytorch/pytorch/issues/98688#issuecomment-1869288431
+# https://github.com/pytorch/pytorch/issues/108334#issuecomment-1752763633
+os.environ["TORCH_CUDNN_V8_API_LRU_CACHE_LIMIT"] = "1"
+
+import warnings
+
+warnings.simplefilter("ignore", UserWarning)
+
+
+def fix_window_pytorch():
+    # copy from: https://github.com/comfyanonymous/ComfyUI/blob/5cbaa9e07c97296b536f240688f5a19300ecf30d/fix_torch.py#L4
+    import platform
+
+    try:
+        if platform.system() != "Windows":
+            return
+        torch_spec = importlib.util.find_spec("torch")
+        for folder in torch_spec.submodule_search_locations:
+            lib_folder = os.path.join(folder, "lib")
+            test_file = os.path.join(lib_folder, "fbgemm.dll")
+            dest = os.path.join(lib_folder, "libomp140.x86_64.dll")
+            if os.path.exists(dest):
+                break
+
+            with open(test_file, "rb") as f:
+                contents = f.read()
+                if b"libomp140.x86_64.dll" not in contents:
+                    break
+            try:
+                mydll = ctypes.cdll.LoadLibrary(test_file)
+            except FileNotFoundError:
+                logging.warning("Detected pytorch version with libomp issue, patching.")
+                shutil.copyfile(os.path.join(lib_folder, "libiomp5md.dll"), dest)
+    except:
+        pass
+
+
+def entry_point():
+    # To make os.environ["XDG_CACHE_HOME"] = args.model_cache_dir works for diffusers
+    # https://github.com/huggingface/diffusers/blob/be99201a567c1ccd841dc16fb24e88f7f239c187/src/diffusers/utils/constants.py#L18
+    from sorawm.iopaint.cli import typer_app
+
+    fix_window_pytorch()
+
+    typer_app()

+ 4 - 0
sorawm/iopaint/__main__.py

@@ -0,0 +1,4 @@
+from sorawm.iopaint import entry_point
+
+if __name__ == "__main__":
+    entry_point()

BIN
sorawm/iopaint/__pycache__/__init__.cpython-312.pyc


BIN
sorawm/iopaint/__pycache__/const.cpython-312.pyc


BIN
sorawm/iopaint/__pycache__/download.cpython-312.pyc


BIN
sorawm/iopaint/__pycache__/helper.cpython-312.pyc


BIN
sorawm/iopaint/__pycache__/model_manager.cpython-312.pyc


BIN
sorawm/iopaint/__pycache__/schema.cpython-312.pyc


+ 411 - 0
sorawm/iopaint/api.py

@@ -0,0 +1,411 @@
+import asyncio
+import os
+import threading
+import time
+import traceback
+from pathlib import Path
+from typing import Dict, List, Optional
+
+import cv2
+import numpy as np
+import socketio
+import torch
+
+try:
+    torch._C._jit_override_can_fuse_on_cpu(False)
+    torch._C._jit_override_can_fuse_on_gpu(False)
+    torch._C._jit_set_texpr_fuser_enabled(False)
+    torch._C._jit_set_nvfuser_enabled(False)
+    torch._C._jit_set_profiling_mode(False)
+except:
+    pass
+
+import uvicorn
+from fastapi import APIRouter, FastAPI, Request, UploadFile
+from fastapi.encoders import jsonable_encoder
+from fastapi.exceptions import HTTPException
+from fastapi.middleware.cors import CORSMiddleware
+from fastapi.responses import FileResponse, JSONResponse, Response
+from fastapi.staticfiles import StaticFiles
+from loguru import logger
+from PIL import Image
+from socketio import AsyncServer
+
+from sorawm.iopaint.file_manager import FileManager
+from sorawm.iopaint.helper import (
+    adjust_mask,
+    concat_alpha_channel,
+    decode_base64_to_image,
+    gen_frontend_mask,
+    load_img,
+    numpy_to_bytes,
+    pil_to_bytes,
+)
+from sorawm.iopaint.model.utils import torch_gc
+from sorawm.iopaint.model_manager import ModelManager
+from sorawm.iopaint.plugins import InteractiveSeg, RealESRGANUpscaler, build_plugins
+from sorawm.iopaint.plugins.base_plugin import BasePlugin
+from sorawm.iopaint.plugins.remove_bg import RemoveBG
+from sorawm.iopaint.schema import (
+    AdjustMaskRequest,
+    ApiConfig,
+    GenInfoResponse,
+    InpaintRequest,
+    InteractiveSegModel,
+    ModelInfo,
+    PluginInfo,
+    RealESRGANModel,
+    RemoveBGModel,
+    RunPluginRequest,
+    SDSampler,
+    ServerConfigResponse,
+    SwitchModelRequest,
+    SwitchPluginModelRequest,
+)
+
+CURRENT_DIR = Path(__file__).parent.absolute().resolve()
+WEB_APP_DIR = CURRENT_DIR / "web_app"
+
+
+def api_middleware(app: FastAPI):
+    rich_available = False
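+    # Rich-formatted tracebacks are only enabled when the WEBUI_RICH_EXCEPTIONS env var is set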
+    try:
+        if os.environ.get("WEBUI_RICH_EXCEPTIONS", None) is not None:
+            import anyio  # importing just so it can be placed on silent list
+            import starlette  # importing just so it can be placed on silent list
+            from rich.console import Console
+
+            console = Console()
+            rich_available = True
+    except Exception:
+        pass
+
+    def handle_exception(request: Request, e: Exception):
+        err = {
+            "error": type(e).__name__,
+            "detail": vars(e).get("detail", ""),
+            "body": vars(e).get("body", ""),
+            "errors": str(e),
+        }
+        if not isinstance(
+            e, HTTPException
+        ):  # do not print backtrace on known httpexceptions
+            message = f"API error: {request.method}: {request.url} {err}"
+            if rich_available:
+                print(message)
+                console.print_exception(
+                    show_locals=True,
+                    max_frames=2,
+                    extra_lines=1,
+                    suppress=[anyio, starlette],
+                    word_wrap=False,
+                    width=min([console.width, 200]),
+                )
+            else:
+                traceback.print_exc()
+        return JSONResponse(
+            status_code=vars(e).get("status_code", 500), content=jsonable_encoder(err)
+        )
+
+    @app.middleware("http")
+    async def exception_handling(request: Request, call_next):
+        try:
+            return await call_next(request)
+        except Exception as e:
+            return handle_exception(request, e)
+
+    @app.exception_handler(Exception)
+    async def fastapi_exception_handler(request: Request, e: Exception):
+        return handle_exception(request, e)
+
+    @app.exception_handler(HTTPException)
+    async def http_exception_handler(request: Request, e: HTTPException):
+        return handle_exception(request, e)
+
+    cors_options = {
+        "allow_methods": ["*"],
+        "allow_headers": ["*"],
+        "allow_origins": ["*"],
+        "allow_credentials": True,
+        "expose_headers": ["X-Seed"],
+    }
+    app.add_middleware(CORSMiddleware, **cors_options)
+
+
+global_sio: Optional[AsyncServer] = None
+
+
+def diffuser_callback(pipe, step: int, timestep: int, callback_kwargs: Dict = {}):
+    # self: DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict
+    # logger.info(f"diffusion callback: step={step}, timestep={timestep}")
+
+    # We use an asyncio loop for task processing. Perhaps in the future we can add a processing queue similar to InvokeAI,
+    # but for now we just start a separate event loop. It shouldn't make a difference for single-user use
+    asyncio.run(global_sio.emit("diffusion_progress", {"step": step}))
+    return {}
+
+
+class Api:
+    def __init__(self, app: FastAPI, config: ApiConfig):
+        self.app = app
+        self.config = config
+        self.router = APIRouter()
+        self.queue_lock = threading.Lock()
+        api_middleware(self.app)
+
+        self.file_manager = self._build_file_manager()
+        self.plugins = self._build_plugins()
+        self.model_manager = self._build_model_manager()
+
+        # fmt: off
+        self.add_api_route("/api/v1/gen-info", self.api_geninfo, methods=["POST"], response_model=GenInfoResponse)
+        self.add_api_route("/api/v1/server-config", self.api_server_config, methods=["GET"],
+                           response_model=ServerConfigResponse)
+        self.add_api_route("/api/v1/model", self.api_current_model, methods=["GET"], response_model=ModelInfo)
+        self.add_api_route("/api/v1/model", self.api_switch_model, methods=["POST"], response_model=ModelInfo)
+        self.add_api_route("/api/v1/inputimage", self.api_input_image, methods=["GET"])
+        self.add_api_route("/api/v1/inpaint", self.api_inpaint, methods=["POST"])
+        self.add_api_route("/api/v1/switch_plugin_model", self.api_switch_plugin_model, methods=["POST"])
+        self.add_api_route("/api/v1/run_plugin_gen_mask", self.api_run_plugin_gen_mask, methods=["POST"])
+        self.add_api_route("/api/v1/run_plugin_gen_image", self.api_run_plugin_gen_image, methods=["POST"])
+        self.add_api_route("/api/v1/samplers", self.api_samplers, methods=["GET"])
+        self.add_api_route("/api/v1/adjust_mask", self.api_adjust_mask, methods=["POST"])
+        self.add_api_route("/api/v1/save_image", self.api_save_image, methods=["POST"])
+        self.app.mount("/", StaticFiles(directory=WEB_APP_DIR, html=True), name="assets")
+        # fmt: on
+
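+        # Wrap the FastAPI app in a Socket.IO ASGI app so the backend can push
+        # progress events (diffusion_progress / diffusion_finish) to the frontend over /ws.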
+        global global_sio
+        self.sio = socketio.AsyncServer(async_mode="asgi", cors_allowed_origins="*")
+        self.combined_asgi_app = socketio.ASGIApp(self.sio, self.app)
+        self.app.mount("/ws", self.combined_asgi_app)
+        global_sio = self.sio
+
+    def add_api_route(self, path: str, endpoint, **kwargs):
+        return self.app.add_api_route(path, endpoint, **kwargs)
+
+    def api_save_image(self, file: UploadFile):
+        # Ensure output directory is configured and exists before writing anything
+        if not self.config.output_dir or not self.config.output_dir.exists():
+            raise HTTPException(
+                status_code=400,
+                detail="Output directory not configured or doesn't exist",
+            )
+
+        # Sanitize filename to prevent path traversal
+        safe_filename = Path(file.filename).name  # Get just the filename component
+
+        # Construct the full path within output_dir
+        output_path = self.config.output_dir / safe_filename
+
+        # Read and write the file
+        origin_image_bytes = file.file.read()
+        with open(output_path, "wb") as fw:
+            fw.write(origin_image_bytes)
+
+    def api_current_model(self) -> ModelInfo:
+        return self.model_manager.current_model
+
+    def api_switch_model(self, req: SwitchModelRequest) -> ModelInfo:
+        if req.name == self.model_manager.name:
+            return self.model_manager.current_model
+        self.model_manager.switch(req.name)
+        return self.model_manager.current_model
+
+    def api_switch_plugin_model(self, req: SwitchPluginModelRequest):
+        if req.plugin_name in self.plugins:
+            self.plugins[req.plugin_name].switch_model(req.model_name)
+            if req.plugin_name == RemoveBG.name:
+                self.config.remove_bg_model = req.model_name
+            if req.plugin_name == RealESRGANUpscaler.name:
+                self.config.realesrgan_model = req.model_name
+            if req.plugin_name == InteractiveSeg.name:
+                self.config.interactive_seg_model = req.model_name
+            torch_gc()
+
+    def api_server_config(self) -> ServerConfigResponse:
+        plugins = []
+        for it in self.plugins.values():
+            plugins.append(
+                PluginInfo(
+                    name=it.name,
+                    support_gen_image=it.support_gen_image,
+                    support_gen_mask=it.support_gen_mask,
+                )
+            )
+
+        return ServerConfigResponse(
+            plugins=plugins,
+            modelInfos=self.model_manager.scan_models(),
+            removeBGModel=self.config.remove_bg_model,
+            removeBGModels=RemoveBGModel.values(),
+            realesrganModel=self.config.realesrgan_model,
+            realesrganModels=RealESRGANModel.values(),
+            interactiveSegModel=self.config.interactive_seg_model,
+            interactiveSegModels=InteractiveSegModel.values(),
+            enableFileManager=self.file_manager is not None,
+            enableAutoSaving=self.config.output_dir is not None,
+            enableControlnet=self.model_manager.enable_controlnet,
+            controlnetMethod=self.model_manager.controlnet_method,
+            disableModelSwitch=False,
+            isDesktop=False,
+            samplers=self.api_samplers(),
+        )
+
+    def api_input_image(self) -> FileResponse:
+        if self.config.input is None:
+            raise HTTPException(status_code=200, detail="No input image configured")
+
+        if self.config.input.is_file():
+            return FileResponse(self.config.input)
+        raise HTTPException(status_code=404, detail="Input image not found")
+
+    def api_geninfo(self, file: UploadFile) -> GenInfoResponse:
+        _, _, info = load_img(file.file.read(), return_info=True)
+        parts = info.get("parameters", "").split("Negative prompt: ")
+        prompt = parts[0].strip()
+        negative_prompt = ""
+        if len(parts) > 1:
+            negative_prompt = parts[1].split("\n")[0].strip()
+        return GenInfoResponse(prompt=prompt, negative_prompt=negative_prompt)
+
+    def api_inpaint(self, req: InpaintRequest):
+        image, alpha_channel, infos, ext = decode_base64_to_image(req.image)
+        mask, _, _, _ = decode_base64_to_image(req.mask, gray=True)
+        logger.info(f"image ext: {ext}")
+
+        mask = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)[1]
+        if image.shape[:2] != mask.shape[:2]:
+            raise HTTPException(
+                400,
+                detail=f"Image size({image.shape[:2]}) and mask size({mask.shape[:2]}) not match.",
+            )
+
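+        # The inpainting model returns a BGR image; it is converted to RGB below before encoding.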
+        start = time.time()
+        rgb_np_img = self.model_manager(image, mask, req)
+        logger.info(f"process time: {(time.time() - start) * 1000:.2f}ms")
+        torch_gc()
+
+        rgb_np_img = cv2.cvtColor(rgb_np_img.astype(np.uint8), cv2.COLOR_BGR2RGB)
+        rgb_res = concat_alpha_channel(rgb_np_img, alpha_channel)
+
+        res_img_bytes = pil_to_bytes(
+            Image.fromarray(rgb_res),
+            ext=ext,
+            quality=self.config.quality,
+            infos=infos,
+        )
+
+        asyncio.run(self.sio.emit("diffusion_finish"))
+
+        return Response(
+            content=res_img_bytes,
+            media_type=f"image/{ext}",
+            headers={"X-Seed": str(req.sd_seed)},
+        )
+
+    def api_run_plugin_gen_image(self, req: RunPluginRequest):
+        ext = "png"
+        if req.name not in self.plugins:
+            raise HTTPException(status_code=422, detail="Plugin not found")
+        if not self.plugins[req.name].support_gen_image:
+            raise HTTPException(
+                status_code=422, detail="Plugin does not support output image"
+            )
+        rgb_np_img, alpha_channel, infos, _ = decode_base64_to_image(req.image)
+        bgr_or_rgba_np_img = self.plugins[req.name].gen_image(rgb_np_img, req)
+        torch_gc()
+
+        if bgr_or_rgba_np_img.shape[2] == 4:
+            rgba_np_img = bgr_or_rgba_np_img
+        else:
+            rgba_np_img = cv2.cvtColor(bgr_or_rgba_np_img, cv2.COLOR_BGR2RGB)
+            rgba_np_img = concat_alpha_channel(rgba_np_img, alpha_channel)
+
+        return Response(
+            content=pil_to_bytes(
+                Image.fromarray(rgba_np_img),
+                ext=ext,
+                quality=self.config.quality,
+                infos=infos,
+            ),
+            media_type=f"image/{ext}",
+        )
+
+    def api_run_plugin_gen_mask(self, req: RunPluginRequest):
+        if req.name not in self.plugins:
+            raise HTTPException(status_code=422, detail="Plugin not found")
+        if not self.plugins[req.name].support_gen_mask:
+            raise HTTPException(
+                status_code=422, detail="Plugin does not support output image"
+            )
+        rgb_np_img, _, _, _ = decode_base64_to_image(req.image)
+        bgr_or_gray_mask = self.plugins[req.name].gen_mask(rgb_np_img, req)
+        torch_gc()
+        res_mask = gen_frontend_mask(bgr_or_gray_mask)
+        return Response(
+            content=numpy_to_bytes(res_mask, "png"),
+            media_type="image/png",
+        )
+
+    def api_samplers(self) -> List[str]:
+        return [member.value for member in SDSampler.__members__.values()]
+
+    def api_adjust_mask(self, req: AdjustMaskRequest):
+        mask, _, _, _ = decode_base64_to_image(req.mask, gray=True)
+        mask = adjust_mask(mask, req.kernel_size, req.operate)
+        return Response(content=numpy_to_bytes(mask, "png"), media_type="image/png")
+
+    def launch(self):
+        self.app.include_router(self.router)
+        uvicorn.run(
+            self.combined_asgi_app,
+            host=self.config.host,
+            port=self.config.port,
+            timeout_keep_alive=999999999,
+        )
+
+    def _build_file_manager(self) -> Optional[FileManager]:
+        if self.config.input and self.config.input.is_dir():
+            logger.info(
+                f"Input is directory, initialize file manager {self.config.input}"
+            )
+
+            return FileManager(
+                app=self.app,
+                input_dir=self.config.input,
+                mask_dir=self.config.mask_dir,
+                output_dir=self.config.output_dir,
+            )
+        return None
+
+    def _build_plugins(self) -> Dict[str, BasePlugin]:
+        return build_plugins(
+            self.config.enable_interactive_seg,
+            self.config.interactive_seg_model,
+            self.config.interactive_seg_device,
+            self.config.enable_remove_bg,
+            self.config.remove_bg_device,
+            self.config.remove_bg_model,
+            self.config.enable_anime_seg,
+            self.config.enable_realesrgan,
+            self.config.realesrgan_device,
+            self.config.realesrgan_model,
+            self.config.enable_gfpgan,
+            self.config.gfpgan_device,
+            self.config.enable_restoreformer,
+            self.config.restoreformer_device,
+            self.config.no_half,
+        )
+
+    def _build_model_manager(self):
+        return ModelManager(
+            name=self.config.model,
+            device=torch.device(self.config.device),
+            no_half=self.config.no_half,
+            low_mem=self.config.low_mem,
+            disable_nsfw=self.config.disable_nsfw_checker,
+            sd_cpu_textencoder=self.config.cpu_textencoder,
+            local_files_only=self.config.local_files_only,
+            cpu_offload=self.config.cpu_offload,
+            callback=diffuser_callback,
+        )

+ 128 - 0
sorawm/iopaint/batch_processing.py

@@ -0,0 +1,128 @@
+import json
+from pathlib import Path
+from typing import Dict, Optional
+
+import cv2
+import numpy as np
+from loguru import logger
+from PIL import Image
+from rich.console import Console
+from rich.progress import (
+    BarColumn,
+    MofNCompleteColumn,
+    Progress,
+    SpinnerColumn,
+    TaskProgressColumn,
+    TextColumn,
+    TimeElapsedColumn,
+)
+
+from sorawm.iopaint.helper import pil_to_bytes
+from sorawm.iopaint.model.utils import torch_gc
+from sorawm.iopaint.model_manager import ModelManager
+from sorawm.iopaint.schema import InpaintRequest
+
+
+def glob_images(path: Path) -> Dict[str, Path]:
+    # png/jpg/jpeg
+    if path.is_file():
+        return {path.stem: path}
+    elif path.is_dir():
+        res = {}
+        for it in path.glob("*.*"):
+            if it.suffix.lower() in [".png", ".jpg", ".jpeg"]:
+                res[it.stem] = it
+        return res
+
+
+def batch_inpaint(
+    model: str,
+    device,
+    image: Path,
+    mask: Path,
+    output: Path,
+    config: Optional[Path] = None,
+    concat: bool = False,
+):
+    if image.is_dir() and output.is_file():
+        logger.error(
+            "invalid --output: when image is a directory, output should be a directory"
+        )
+        exit(-1)
+    output.mkdir(parents=True, exist_ok=True)
+
+    image_paths = glob_images(image)
+    mask_paths = glob_images(mask)
+    if len(image_paths) == 0:
+        logger.error("invalid --image: empty image folder")
+        exit(-1)
+    if len(mask_paths) == 0:
+        logger.error("invalid --mask: empty mask folder")
+        exit(-1)
+
+    if config is None:
+        inpaint_request = InpaintRequest()
+        logger.info(f"Using default config: {inpaint_request}")
+    else:
+        with open(config, "r", encoding="utf-8") as f:
+            inpaint_request = InpaintRequest(**json.load(f))
+        logger.info(f"Using config: {inpaint_request}")
+
+    model_manager = ModelManager(name=model, device=device)
+    first_mask = list(mask_paths.values())[0]
+
+    console = Console()
+
+    with Progress(
+        SpinnerColumn(),
+        TextColumn("[progress.description]{task.description}"),
+        BarColumn(),
+        TaskProgressColumn(),
+        MofNCompleteColumn(),
+        TimeElapsedColumn(),
+        console=console,
+        transient=False,
+    ) as progress:
+        task = progress.add_task("Batch processing...", total=len(image_paths))
+        for stem, image_p in image_paths.items():
+            if stem not in mask_paths and mask.is_dir():
+                progress.log(f"mask for {image_p} not found")
+                progress.update(task, advance=1)
+                continue
+            mask_p = mask_paths.get(stem, first_mask)
+
+            infos = Image.open(image_p).info
+
+            img = np.array(Image.open(image_p).convert("RGB"))
+            mask_img = np.array(Image.open(mask_p).convert("L"))
+
+            if mask_img.shape[:2] != img.shape[:2]:
+                progress.log(
+                    f"resize mask {mask_p.name} to image {image_p.name} size: {img.shape[:2]}"
+                )
+                mask_img = cv2.resize(
+                    mask_img,
+                    (img.shape[1], img.shape[0]),
+                    interpolation=cv2.INTER_NEAREST,
+                )
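+            # Binarize the mask: pixels >= 127 become foreground (255), the rest background (0)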
+            mask_img[mask_img >= 127] = 255
+            mask_img[mask_img < 127] = 0
+
+            # bgr
+            inpaint_result = model_manager(img, mask_img, inpaint_request)
+            inpaint_result = cv2.cvtColor(inpaint_result, cv2.COLOR_BGR2RGB)
+            if concat:
+                mask_img = cv2.cvtColor(mask_img, cv2.COLOR_GRAY2RGB)
+                inpaint_result = cv2.hconcat([img, mask_img, inpaint_result])
+
+            img_bytes = pil_to_bytes(Image.fromarray(inpaint_result), "png", 100, infos)
+            save_p = output / f"{stem}.png"
+            with open(save_p, "wb") as fw:
+                fw.write(img_bytes)
+
+            progress.update(task, advance=1)
+            torch_gc()
+            # pid = psutil.Process().pid
+            # memory_info = psutil.Process(pid).memory_info()
+            # memory_in_mb = memory_info.rss / (1024 * 1024)
+            # print(f"原图大小:{img.shape},当前进程的内存占用:{memory_in_mb}MB")

+ 109 - 0
sorawm/iopaint/benchmark.py

@@ -0,0 +1,109 @@
+#!/usr/bin/env python3
+
+import argparse
+import os
+import time
+
+import numpy as np
+import nvidia_smi
+import psutil
+import torch
+
+from sorawm.iopaint.model_manager import ModelManager
+from sorawm.iopaint.schema import HDStrategy, InpaintRequest, SDSampler
+
+try:
+    torch._C._jit_override_can_fuse_on_cpu(False)
+    torch._C._jit_override_can_fuse_on_gpu(False)
+    torch._C._jit_set_texpr_fuser_enabled(False)
+    torch._C._jit_set_nvfuser_enabled(False)
+except:
+    pass
+
+NUM_THREADS = str(4)
+
+os.environ["OMP_NUM_THREADS"] = NUM_THREADS
+os.environ["OPENBLAS_NUM_THREADS"] = NUM_THREADS
+os.environ["MKL_NUM_THREADS"] = NUM_THREADS
+os.environ["VECLIB_MAXIMUM_THREADS"] = NUM_THREADS
+os.environ["NUMEXPR_NUM_THREADS"] = NUM_THREADS
+if os.environ.get("CACHE_DIR"):
+    os.environ["TORCH_HOME"] = os.environ["CACHE_DIR"]
+
+
+def run_model(model, size):
+    # RGB
+    image = np.random.randint(0, 256, (size[0], size[1], 3)).astype(np.uint8)
+    mask = np.random.randint(0, 255, size).astype(np.uint8)
+
+    config = InpaintRequest(
+        ldm_steps=2,
+        hd_strategy=HDStrategy.ORIGINAL,
+        hd_strategy_crop_margin=128,
+        hd_strategy_crop_trigger_size=128,
+        hd_strategy_resize_limit=128,
+        prompt="a fox is sitting on a bench",
+        sd_steps=5,
+        sd_sampler=SDSampler.ddim,
+    )
+    model(image, mask, config)
+
+
+def benchmark(model, times: int, empty_cache: bool):
+    sizes = [(512, 512)]
+
+    nvidia_smi.nvmlInit()
+    device_id = 0
+    handle = nvidia_smi.nvmlDeviceGetHandleByIndex(device_id)
+
+    def format(metrics):
+        return f"{np.mean(metrics):.2f} ± {np.std(metrics):.2f}"
+
+    process = psutil.Process(os.getpid())
+    # Report GPU memory and RAM usage metrics for each size
+    for size in sizes:
+        torch.cuda.empty_cache()
+        time_metrics = []
+        cpu_metrics = []
+        memory_metrics = []
+        gpu_memory_metrics = []
+        for _ in range(times):
+            start = time.time()
+            run_model(model, size)
+            torch.cuda.synchronize()
+
+            # cpu_metrics.append(process.cpu_percent())
+            time_metrics.append((time.time() - start) * 1000)
+            memory_metrics.append(process.memory_info().rss / 1024 / 1024)
+            gpu_memory_metrics.append(
+                nvidia_smi.nvmlDeviceGetMemoryInfo(handle).used / 1024 / 1024
+            )
+
+        print(f"size: {size}".center(80, "-"))
+        # print(f"cpu: {format(cpu_metrics)}")
+        print(f"latency: {format(time_metrics)}ms")
+        print(f"memory: {format(memory_metrics)} MB")
+        print(f"gpu memory: {format(gpu_memory_metrics)} MB")
+
+    nvidia_smi.nvmlShutdown()
+
+
+def get_args_parser():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--name")
+    parser.add_argument("--device", default="cuda", type=str)
+    parser.add_argument("--times", default=10, type=int)
+    parser.add_argument("--empty-cache", action="store_true")
+    return parser.parse_args()
+
+
+if __name__ == "__main__":
+    args = get_args_parser()
+    device = torch.device(args.device)
+    model = ModelManager(
+        name=args.name,
+        device=device,
+        disable_nsfw=True,
+        sd_cpu_textencoder=True,
+    )
+    benchmark(model, args.times, args.empty_cache)

+ 245 - 0
sorawm/iopaint/cli.py

@@ -0,0 +1,245 @@
+import webbrowser
+from contextlib import asynccontextmanager
+from pathlib import Path
+from typing import Optional
+
+import typer
+from fastapi import FastAPI
+from loguru import logger
+from typer import Option
+from typer_config import use_json_config
+
+from sorawm.iopaint.const import *
+from sorawm.iopaint.runtime import check_device, dump_environment_info, setup_model_dir
+from sorawm.iopaint.schema import (
+    Device,
+    InteractiveSegModel,
+    RealESRGANModel,
+    RemoveBGModel,
+)
+
+typer_app = typer.Typer(pretty_exceptions_show_locals=False, add_completion=False)
+
+
+@typer_app.command(help="Install all plugins dependencies")
+def install_plugins_packages():
+    from sorawm.iopaint.installer import install_plugins_package
+
+    install_plugins_package()
+
+
+@typer_app.command(help="Download SD/SDXL normal/inpainting model from HuggingFace")
+def download(
+    model: str = Option(
+        ..., help="Model id on HuggingFace e.g: runwayml/stable-diffusion-inpainting"
+    ),
+    model_dir: Path = Option(
+        DEFAULT_MODEL_DIR,
+        help=MODEL_DIR_HELP,
+        file_okay=False,
+        callback=setup_model_dir,
+    ),
+):
+    from sorawm.iopaint.download import cli_download_model
+
+    cli_download_model(model)
+
+
+@typer_app.command(name="list", help="List downloaded models")
+def list_model(
+    model_dir: Path = Option(
+        DEFAULT_MODEL_DIR,
+        help=MODEL_DIR_HELP,
+        file_okay=False,
+        callback=setup_model_dir,
+    ),
+):
+    from sorawm.iopaint.download import scan_models
+
+    scanned_models = scan_models()
+    for it in scanned_models:
+        print(it.name)
+
+
+@typer_app.command(help="Batch processing images")
+def run(
+    model: str = Option("lama"),
+    device: Device = Option(Device.cpu),
+    image: Path = Option(..., help="Image folders or file path"),
+    mask: Path = Option(
+        ...,
+        help="Mask folders or file path. "
+        "If it is a directory, the mask images in the directory should have the same name as the original image."
+        "If it is a file, all images will use this mask."
+        "Mask will automatically resize to the same size as the original image.",
+    ),
+    output: Path = Option(..., help="Output directory or file path"),
+    config: Path = Option(
+        None, help="Config file path. You can use dump command to create a base config."
+    ),
+    concat: bool = Option(
+        False, help="Concat original image, mask and output images into one image"
+    ),
+    model_dir: Path = Option(
+        DEFAULT_MODEL_DIR,
+        help=MODEL_DIR_HELP,
+        file_okay=False,
+        callback=setup_model_dir,
+    ),
+):
+    from sorawm.iopaint.download import cli_download_model, scan_models
+
+    scanned_models = scan_models()
+    if model not in [it.name for it in scanned_models]:
+        logger.info(f"{model} not found in {model_dir}, try to downloading")
+        cli_download_model(model)
+
+    from sorawm.iopaint.batch_processing import batch_inpaint
+
+    batch_inpaint(model, device, image, mask, output, config, concat)
+
+
+@typer_app.command(help="Start IOPaint server")
+@use_json_config()
+def start(
+    host: str = Option("127.0.0.1"),
+    port: int = Option(8080),
+    inbrowser: bool = Option(False, help=INBROWSER_HELP),
+    model: str = Option(
+        DEFAULT_MODEL,
+        help=f"Erase models: [{', '.join(AVAILABLE_MODELS)}].\n"
+        f"Diffusion models: [{', '.join(DIFFUSION_MODELS)}] or any SD/SDXL normal/inpainting models on HuggingFace.",
+    ),
+    model_dir: Path = Option(
+        DEFAULT_MODEL_DIR,
+        help=MODEL_DIR_HELP,
+        dir_okay=True,
+        file_okay=False,
+        callback=setup_model_dir,
+    ),
+    low_mem: bool = Option(False, help=LOW_MEM_HELP),
+    no_half: bool = Option(False, help=NO_HALF_HELP),
+    cpu_offload: bool = Option(False, help=CPU_OFFLOAD_HELP),
+    disable_nsfw_checker: bool = Option(False, help=DISABLE_NSFW_HELP),
+    cpu_textencoder: bool = Option(False, help=CPU_TEXTENCODER_HELP),
+    local_files_only: bool = Option(False, help=LOCAL_FILES_ONLY_HELP),
+    device: Device = Option(Device.cpu),
+    input: Optional[Path] = Option(None, help=INPUT_HELP),
+    mask_dir: Optional[Path] = Option(
+        None, help=MASK_DIR_HELP, dir_okay=True, file_okay=False
+    ),
+    output_dir: Optional[Path] = Option(
+        None, help=OUTPUT_DIR_HELP, dir_okay=True, file_okay=False
+    ),
+    quality: int = Option(100, help=QUALITY_HELP),
+    enable_interactive_seg: bool = Option(False, help=INTERACTIVE_SEG_HELP),
+    interactive_seg_model: InteractiveSegModel = Option(
+        InteractiveSegModel.sam2_1_tiny, help=INTERACTIVE_SEG_MODEL_HELP
+    ),
+    interactive_seg_device: Device = Option(Device.cpu),
+    enable_remove_bg: bool = Option(False, help=REMOVE_BG_HELP),
+    remove_bg_device: Device = Option(Device.cpu, help=REMOVE_BG_DEVICE_HELP),
+    remove_bg_model: RemoveBGModel = Option(RemoveBGModel.briaai_rmbg_1_4),
+    enable_anime_seg: bool = Option(False, help=ANIMESEG_HELP),
+    enable_realesrgan: bool = Option(False),
+    realesrgan_device: Device = Option(Device.cpu),
+    realesrgan_model: RealESRGANModel = Option(RealESRGANModel.realesr_general_x4v3),
+    enable_gfpgan: bool = Option(False),
+    gfpgan_device: Device = Option(Device.cpu),
+    enable_restoreformer: bool = Option(False),
+    restoreformer_device: Device = Option(Device.cpu),
+):
+    dump_environment_info()
+    device = check_device(device)
+    remove_bg_device = check_device(remove_bg_device)
+    realesrgan_device = check_device(realesrgan_device)
+    gfpgan_device = check_device(gfpgan_device)
+
+    if input and not input.exists():
+        logger.error(f"invalid --input: {input} not exists")
+        exit(-1)
+    if mask_dir and not mask_dir.exists():
+        logger.error(f"invalid --mask-dir: {mask_dir} not exists")
+        exit(-1)
+    if input and input.is_dir() and not output_dir:
+        logger.error(
+            "invalid --output-dir: --output-dir must be set when --input is a directory"
+        )
+        exit(-1)
+    if output_dir:
+        output_dir = output_dir.expanduser().absolute()
+        logger.info(f"Image will be saved to {output_dir}")
+        if not output_dir.exists():
+            logger.info(f"Create output directory {output_dir}")
+            output_dir.mkdir(parents=True)
+    if mask_dir:
+        mask_dir = mask_dir.expanduser().absolute()
+
+    model_dir = model_dir.expanduser().absolute()
+
+    if local_files_only:
+        os.environ["TRANSFORMERS_OFFLINE"] = "1"
+        os.environ["HF_HUB_OFFLINE"] = "1"
+
+    from sorawm.iopaint.download import cli_download_model, scan_models
+
+    scanned_models = scan_models()
+    if model not in [it.name for it in scanned_models]:
+        logger.info(f"{model} not found in {model_dir}, try to downloading")
+        cli_download_model(model)
+
+    from sorawm.iopaint.api import Api
+    from sorawm.iopaint.schema import ApiConfig
+
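+    # FastAPI lifespan hook: optionally open the browser once the server starts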
+    @asynccontextmanager
+    async def lifespan(app: FastAPI):
+        if inbrowser:
+            webbrowser.open(f"http://localhost:{port}", new=0, autoraise=True)
+        yield
+
+    app = FastAPI(lifespan=lifespan)
+
+    api_config = ApiConfig(
+        host=host,
+        port=port,
+        inbrowser=inbrowser,
+        model=model,
+        no_half=no_half,
+        low_mem=low_mem,
+        cpu_offload=cpu_offload,
+        disable_nsfw_checker=disable_nsfw_checker,
+        local_files_only=local_files_only,
+        cpu_textencoder=cpu_textencoder if device == Device.cuda else False,
+        device=device,
+        input=input,
+        mask_dir=mask_dir,
+        output_dir=output_dir,
+        quality=quality,
+        enable_interactive_seg=enable_interactive_seg,
+        interactive_seg_model=interactive_seg_model,
+        interactive_seg_device=interactive_seg_device,
+        enable_remove_bg=enable_remove_bg,
+        remove_bg_device=remove_bg_device,
+        remove_bg_model=remove_bg_model,
+        enable_anime_seg=enable_anime_seg,
+        enable_realesrgan=enable_realesrgan,
+        realesrgan_device=realesrgan_device,
+        realesrgan_model=realesrgan_model,
+        enable_gfpgan=enable_gfpgan,
+        gfpgan_device=gfpgan_device,
+        enable_restoreformer=enable_restoreformer,
+        restoreformer_device=restoreformer_device,
+    )
+    print(api_config.model_dump_json(indent=4))
+    api = Api(app, api_config)
+    api.launch()
+
+
+@typer_app.command(help="Start IOPaint web config page")
+def start_web_config(
+    config_file: Path = Option("config.json"),
+):
+    dump_environment_info()
+    from sorawm.iopaint.web_config import main
+
+    main(config_file)

+ 134 - 0
sorawm/iopaint/const.py

@@ -0,0 +1,134 @@
+import os
+from typing import List
+
+INSTRUCT_PIX2PIX_NAME = "timbrooks/instruct-pix2pix"
+KANDINSKY22_NAME = "kandinsky-community/kandinsky-2-2-decoder-inpaint"
+POWERPAINT_NAME = "Sanster/PowerPaint-V1-stable-diffusion-inpainting"
+ANYTEXT_NAME = "Sanster/AnyText"
+
+DIFFUSERS_SD_CLASS_NAME = "StableDiffusionPipeline"
+DIFFUSERS_SD_INPAINT_CLASS_NAME = "StableDiffusionInpaintPipeline"
+DIFFUSERS_SDXL_CLASS_NAME = "StableDiffusionXLPipeline"
+DIFFUSERS_SDXL_INPAINT_CLASS_NAME = "StableDiffusionXLInpaintPipeline"
+
+MPS_UNSUPPORT_MODELS = [
+    # "lama",
+    "ldm",
+    "zits",
+    "mat",
+    "fcf",
+    "cv2",
+    "manga",
+]
+
+DEFAULT_MODEL = "lama"
+AVAILABLE_MODELS = ["lama", "ldm", "zits", "mat", "fcf", "manga", "cv2", "migan"]
+DIFFUSION_MODELS = [
+    "runwayml/stable-diffusion-inpainting",
+    "Uminosachi/realisticVisionV51_v51VAE-inpainting",
+    "redstonehero/dreamshaper-inpainting",
+    "Sanster/anything-4.0-inpainting",
+    "diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
+    "Fantasy-Studio/Paint-by-Example",
+    "RunDiffusion/Juggernaut-XI-v11",
+    "SG161222/RealVisXL_V5.0",
+    "eienmojiki/Anything-XL",
+    POWERPAINT_NAME,
+    ANYTEXT_NAME,
+]
+
+NO_HALF_HELP = """
+Use the full precision (fp32) model.
+If your diffusion model's results are always black or green, use this argument.
+"""
+
+CPU_OFFLOAD_HELP = """
+Offloads the diffusion model's weights to CPU RAM, significantly reducing VRAM usage.
+"""
+
+LOW_MEM_HELP = "Enable attention slicing and vae tiling to save memory."
+
+DISABLE_NSFW_HELP = """
+Disable NSFW checker for diffusion model.
+"""
+
+CPU_TEXTENCODER_HELP = """
+Run the diffusion model's text encoder on CPU to reduce VRAM usage.
+"""
+
+SD_CONTROLNET_CHOICES: List[str] = [
+    "lllyasviel/control_v11p_sd15_canny",
+    # "lllyasviel/control_v11p_sd15_seg",
+    "lllyasviel/control_v11p_sd15_openpose",
+    "lllyasviel/control_v11p_sd15_inpaint",
+    "lllyasviel/control_v11f1p_sd15_depth",
+]
+
+SD_BRUSHNET_CHOICES: List[str] = [
+    "Sanster/brushnet_random_mask",
+    "Sanster/brushnet_segmentation_mask",
+]
+
+SD2_CONTROLNET_CHOICES = [
+    "thibaud/controlnet-sd21-canny-diffusers",
+    "thibaud/controlnet-sd21-depth-diffusers",
+    "thibaud/controlnet-sd21-openpose-diffusers",
+]
+
+SDXL_CONTROLNET_CHOICES = [
+    "thibaud/controlnet-openpose-sdxl-1.0",
+    "destitech/controlnet-inpaint-dreamer-sdxl",
+    "diffusers/controlnet-canny-sdxl-1.0",
+    "diffusers/controlnet-canny-sdxl-1.0-mid",
+    "diffusers/controlnet-canny-sdxl-1.0-small",
+    "diffusers/controlnet-depth-sdxl-1.0",
+    "diffusers/controlnet-depth-sdxl-1.0-mid",
+    "diffusers/controlnet-depth-sdxl-1.0-small",
+]
+
+SDXL_BRUSHNET_CHOICES = ["Regulus0725/random_mask_brushnet_ckpt_sdxl_regulus_v1"]
+
+LOCAL_FILES_ONLY_HELP = """
+When loading diffusion models, use local files only and do not connect to the HuggingFace server.
+"""
+
+DEFAULT_MODEL_DIR = os.path.abspath(
+    os.getenv("XDG_CACHE_HOME", os.path.join(os.path.expanduser("~"), ".cache"))
+)
+
+MODEL_DIR_HELP = f"""
+Model download directory (set via the XDG_CACHE_HOME environment variable); by default models are downloaded to {DEFAULT_MODEL_DIR}
+"""
+
+OUTPUT_DIR_HELP = """
+Result images will be saved to the output directory automatically.
+"""
+
+MASK_DIR_HELP = """
+You can view masks in FileManager
+"""
+
+INPUT_HELP = """
+If the input is an image, it will be loaded by default.
+If the input is a directory, you can browse and select images in the file manager.
+"""
+
+GUI_HELP = """
+Launch Lama Cleaner as desktop app
+"""
+
+QUALITY_HELP = """
+Quality of image encoding, 0-100. Higher quality produces larger file sizes.
+"""
+
+INTERACTIVE_SEG_HELP = "Enable interactive segmentation using Segment Anything."
+INTERACTIVE_SEG_MODEL_HELP = "Model size: mobile_sam < vit_b < vit_l < vit_h. Bigger model size means better segmentation but slower speed."
+REMOVE_BG_HELP = "Enable remove background plugin."
+REMOVE_BG_DEVICE_HELP = "Device for the remove background plugin. 'cuda' only supports briaai models (briaai/RMBG-1.4 and briaai/RMBG-2.0)"
+ANIMESEG_HELP = "Enable anime segmentation plugin. Always runs on CPU."
+REALESRGAN_HELP = "Enable RealESRGAN super resolution."
+GFPGAN_HELP = "Enable GFPGAN face restoration. To also enhance the background, use with --enable-realesrgan"
+RESTOREFORMER_HELP = "Enable RestoreFormer face restoration. To also enhance the background, use with --enable-realesrgan"
+GIF_HELP = "Enable GIF plugin. Makes a GIF comparing the original and cleaned image."
+
+INBROWSER_HELP = "Automatically launch IOPaint in a new tab on the default browser"

+ 314 - 0
sorawm/iopaint/download.py

@@ -0,0 +1,314 @@
+import glob
+import json
+import os
+from functools import lru_cache
+from pathlib import Path
+from typing import List, Optional
+
+from loguru import logger
+
+from sorawm.iopaint.const import (
+    ANYTEXT_NAME,
+    DEFAULT_MODEL_DIR,
+    DIFFUSERS_SD_CLASS_NAME,
+    DIFFUSERS_SD_INPAINT_CLASS_NAME,
+    DIFFUSERS_SDXL_CLASS_NAME,
+    DIFFUSERS_SDXL_INPAINT_CLASS_NAME,
+)
+from sorawm.iopaint.model.original_sd_configs import get_config_files
+from sorawm.iopaint.schema import ModelInfo, ModelType
+
+
+def cli_download_model(model: str):
+    from sorawm.iopaint.model import models
+    from sorawm.iopaint.model.utils import handle_from_pretrained_exceptions
+
+    if model in models and models[model].is_erase_model:
+        logger.info(f"Downloading {model}...")
+        models[model].download()
+        logger.info("Done.")
+    elif model == ANYTEXT_NAME:
+        logger.info(f"Downloading {model}...")
+        models[model].download()
+        logger.info("Done.")
+    else:
+        logger.info(f"Downloading model from Huggingface: {model}")
+        from diffusers import DiffusionPipeline
+
+        downloaded_path = handle_from_pretrained_exceptions(
+            DiffusionPipeline.download, pretrained_model_name=model, variant="fp16"
+        )
+        logger.info(f"Done. Downloaded to {downloaded_path}")
+
+
+def folder_name_to_show_name(name: str) -> str:
+    return name.replace("models--", "").replace("--", "/")
+
+
+@lru_cache(maxsize=512)
+def get_sd_model_type(model_abs_path: str) -> Optional[ModelType]:
+    if "inpaint" in Path(model_abs_path).name.lower():
+        model_type = ModelType.DIFFUSERS_SD_INPAINT
+    else:
+        # load once to check num_in_channels
+        from diffusers import StableDiffusionInpaintPipeline
+
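+        # Inpainting UNets take 9 input channels; forcing num_in_channels=9 on a regular
+        # 4-channel checkpoint raises a ValueError mentioning the [320, 4, 3, 3] conv weight.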
+        try:
+            StableDiffusionInpaintPipeline.from_single_file(
+                model_abs_path,
+                load_safety_checker=False,
+                num_in_channels=9,
+                original_config_file=get_config_files()["v1"],
+            )
+            model_type = ModelType.DIFFUSERS_SD_INPAINT
+        except ValueError as e:
+            if "[320, 4, 3, 3]" in str(e):
+                model_type = ModelType.DIFFUSERS_SD
+            else:
+                logger.info(f"Ignore non sdxl file: {model_abs_path}")
+                return
+        except Exception as e:
+            logger.error(f"Failed to load {model_abs_path}: {e}")
+            return
+    return model_type
+
+
+@lru_cache()
+def get_sdxl_model_type(model_abs_path: str) -> Optional[ModelType]:
+    if "inpaint" in model_abs_path:
+        model_type = ModelType.DIFFUSERS_SDXL_INPAINT
+    else:
+        # load once to check num_in_channels
+        from diffusers import StableDiffusionXLInpaintPipeline
+
+        try:
+            model = StableDiffusionXLInpaintPipeline.from_single_file(
+                model_abs_path,
+                load_safety_checker=False,
+                num_in_channels=9,
+                original_config_file=get_config_files()["xl"],
+            )
+            if model.unet.config.in_channels == 9:
+                # https://github.com/huggingface/diffusers/issues/6610
+                model_type = ModelType.DIFFUSERS_SDXL_INPAINT
+            else:
+                model_type = ModelType.DIFFUSERS_SDXL
+        except ValueError as e:
+            if "[320, 4, 3, 3]" in str(e):
+                model_type = ModelType.DIFFUSERS_SDXL
+            else:
+                logger.info(f"Ignore non sdxl file: {model_abs_path}")
+                return
+        except Exception as e:
+            logger.error(f"Failed to load {model_abs_path}: {e}")
+            return
+    return model_type
+
+
+def scan_single_file_diffusion_models(cache_dir) -> List[ModelInfo]:
+    cache_dir = Path(cache_dir)
+    stable_diffusion_dir = cache_dir / "stable_diffusion"
+    cache_file = stable_diffusion_dir / "iopaint_cache.json"
+    model_type_cache = {}
+    if cache_file.exists():
+        try:
+            with open(cache_file, "r", encoding="utf-8") as f:
+                model_type_cache = json.load(f)
+                assert isinstance(model_type_cache, dict)
+        except:
+            pass
+
+    res = []
+    for it in stable_diffusion_dir.glob("*.*"):
+        if it.suffix not in [".safetensors", ".ckpt"]:
+            continue
+        model_abs_path = str(it.absolute())
+        model_type = model_type_cache.get(it.name)
+        if model_type is None:
+            model_type = get_sd_model_type(model_abs_path)
+        if model_type is None:
+            continue
+
+        model_type_cache[it.name] = model_type
+        res.append(
+            ModelInfo(
+                name=it.name,
+                path=model_abs_path,
+                model_type=model_type,
+                is_single_file_diffusers=True,
+            )
+        )
+    if stable_diffusion_dir.exists():
+        with open(cache_file, "w", encoding="utf-8") as fw:
+            json.dump(model_type_cache, fw, indent=2, ensure_ascii=False)
+
+    stable_diffusion_xl_dir = cache_dir / "stable_diffusion_xl"
+    sdxl_cache_file = stable_diffusion_xl_dir / "iopaint_cache.json"
+    sdxl_model_type_cache = {}
+    if sdxl_cache_file.exists():
+        try:
+            with open(sdxl_cache_file, "r", encoding="utf-8") as f:
+                sdxl_model_type_cache = json.load(f)
+                assert isinstance(sdxl_model_type_cache, dict)
+        except:
+            pass
+
+    for it in stable_diffusion_xl_dir.glob("*.*"):
+        if it.suffix not in [".safetensors", ".ckpt"]:
+            continue
+        model_abs_path = str(it.absolute())
+        model_type = sdxl_model_type_cache.get(it.name)
+        if model_type is None:
+            model_type = get_sdxl_model_type(model_abs_path)
+        if model_type is None:
+            continue
+
+        sdxl_model_type_cache[it.name] = model_type
+        if stable_diffusion_xl_dir.exists():
+            with open(sdxl_cache_file, "w", encoding="utf-8") as fw:
+                json.dump(sdxl_model_type_cache, fw, indent=2, ensure_ascii=False)
+
+        res.append(
+            ModelInfo(
+                name=it.name,
+                path=model_abs_path,
+                model_type=model_type,
+                is_single_file_diffusers=True,
+            )
+        )
+    return res
+
+
+def scan_inpaint_models(model_dir: Path) -> List[ModelInfo]:
+    res = []
+    from sorawm.iopaint.model import models
+
+    # logger.info(f"Scanning inpaint models in {model_dir}")
+
+    for name, m in models.items():
+        if m.is_erase_model and m.is_downloaded():
+            res.append(
+                ModelInfo(
+                    name=name,
+                    path=name,
+                    model_type=ModelType.INPAINT,
+                )
+            )
+    return res
+
+
+def scan_diffusers_models() -> List[ModelInfo]:
+    from huggingface_hub.constants import HF_HUB_CACHE
+
+    available_models = []
+    cache_dir = Path(HF_HUB_CACHE)
+    # logger.info(f"Scanning diffusers models in {cache_dir}")
+    diffusers_model_names = []
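+    # Every diffusers-format pipeline in the HF cache ships a model_index.json with its pipeline class name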
+    model_index_files = glob.glob(
+        os.path.join(cache_dir, "**/*", "model_index.json"), recursive=True
+    )
+    for it in model_index_files:
+        it = Path(it)
+        try:
+            with open(it, "r", encoding="utf-8") as f:
+                data = json.load(f)
+        except:
+            continue
+
+        _class_name = data["_class_name"]
+        name = folder_name_to_show_name(it.parent.parent.parent.name)
+        if name in diffusers_model_names:
+            continue
+        if "PowerPaint" in name:
+            model_type = ModelType.DIFFUSERS_OTHER
+        elif _class_name == DIFFUSERS_SD_CLASS_NAME:
+            model_type = ModelType.DIFFUSERS_SD
+        elif _class_name == DIFFUSERS_SD_INPAINT_CLASS_NAME:
+            model_type = ModelType.DIFFUSERS_SD_INPAINT
+        elif _class_name == DIFFUSERS_SDXL_CLASS_NAME:
+            model_type = ModelType.DIFFUSERS_SDXL
+        elif _class_name == DIFFUSERS_SDXL_INPAINT_CLASS_NAME:
+            model_type = ModelType.DIFFUSERS_SDXL_INPAINT
+        elif _class_name in [
+            "StableDiffusionInstructPix2PixPipeline",
+            "PaintByExamplePipeline",
+            "KandinskyV22InpaintPipeline",
+            "AnyText",
+        ]:
+            model_type = ModelType.DIFFUSERS_OTHER
+        else:
+            continue
+
+        diffusers_model_names.append(name)
+        available_models.append(
+            ModelInfo(
+                name=name,
+                path=name,
+                model_type=model_type,
+            )
+        )
+    return available_models
+
+
+def _scan_converted_diffusers_models(cache_dir) -> List[ModelInfo]:
+    cache_dir = Path(cache_dir)
+    available_models = []
+    diffusers_model_names = []
+    model_index_files = glob.glob(
+        os.path.join(cache_dir, "**/*", "model_index.json"), recursive=True
+    )
+    for it in model_index_files:
+        it = Path(it)
+        with open(it, "r", encoding="utf-8") as f:
+            try:
+                data = json.load(f)
+            except:
+                logger.error(
+                    f"Failed to load {it}, please try revert from original model or fix model_index.json by hand."
+                )
+                continue
+
+            _class_name = data["_class_name"]
+            name = folder_name_to_show_name(it.parent.name)
+            if name in diffusers_model_names:
+                continue
+            elif _class_name == DIFFUSERS_SD_CLASS_NAME:
+                model_type = ModelType.DIFFUSERS_SD
+            elif _class_name == DIFFUSERS_SD_INPAINT_CLASS_NAME:
+                model_type = ModelType.DIFFUSERS_SD_INPAINT
+            elif _class_name == DIFFUSERS_SDXL_CLASS_NAME:
+                model_type = ModelType.DIFFUSERS_SDXL
+            elif _class_name == DIFFUSERS_SDXL_INPAINT_CLASS_NAME:
+                model_type = ModelType.DIFFUSERS_SDXL_INPAINT
+            else:
+                continue
+
+            diffusers_model_names.append(name)
+            available_models.append(
+                ModelInfo(
+                    name=name,
+                    path=str(it.parent.absolute()),
+                    model_type=model_type,
+                )
+            )
+    return available_models
+
+
+def scan_converted_diffusers_models(cache_dir) -> List[ModelInfo]:
+    cache_dir = Path(cache_dir)
+    available_models = []
+    stable_diffusion_dir = cache_dir / "stable_diffusion"
+    stable_diffusion_xl_dir = cache_dir / "stable_diffusion_xl"
+    available_models.extend(_scan_converted_diffusers_models(stable_diffusion_dir))
+    available_models.extend(_scan_converted_diffusers_models(stable_diffusion_xl_dir))
+    return available_models
+
+
+def scan_models() -> List[ModelInfo]:
+    model_dir = os.getenv("XDG_CACHE_HOME", DEFAULT_MODEL_DIR)
+    available_models = []
+    available_models.extend(scan_inpaint_models(model_dir))
+    available_models.extend(scan_single_file_diffusion_models(model_dir))
+    available_models.extend(scan_diffusers_models())
+    available_models.extend(scan_converted_diffusers_models(model_dir))
+    return available_models

+ 1 - 0
sorawm/iopaint/file_manager/__init__.py

@@ -0,0 +1 @@
+from .file_manager import FileManager

+ 220 - 0
sorawm/iopaint/file_manager/file_manager.py

@@ -0,0 +1,220 @@
+import os
+from io import BytesIO
+from pathlib import Path
+from typing import List
+
+from fastapi import FastAPI, HTTPException
+from loguru import logger
+from PIL import Image, ImageOps, PngImagePlugin
+from starlette.responses import FileResponse
+
+from ..schema import MediasResponse, MediaTab
+
+LARGE_ENOUGH_NUMBER = 100
+PngImagePlugin.MAX_TEXT_CHUNK = LARGE_ENOUGH_NUMBER * (1024**2)
+from .storage_backends import FilesystemStorageBackend
+from .utils import aspect_to_string, generate_filename, glob_img
+
+
+class FileManager:
+    def __init__(self, app: FastAPI, input_dir: Path, mask_dir: Path, output_dir: Path):
+        self.app = app
+        self.input_dir: Path = input_dir
+        self.mask_dir: Path = mask_dir
+        self.output_dir: Path = output_dir
+
+        self.image_dir_filenames = []
+        self.output_dir_filenames = []
+        if not self.thumbnail_directory.exists():
+            self.thumbnail_directory.mkdir(parents=True)
+
+        # fmt: off
+        self.app.add_api_route("/api/v1/medias", self.api_medias, methods=["GET"], response_model=List[MediasResponse])
+        self.app.add_api_route("/api/v1/media_file", self.api_media_file, methods=["GET"])
+        self.app.add_api_route("/api/v1/media_thumbnail_file", self.api_media_thumbnail_file, methods=["GET"])
+        # fmt: on
+
+    def api_medias(self, tab: MediaTab) -> List[MediasResponse]:
+        img_dir = self._get_dir(tab)
+        return self._media_names(img_dir)
+
+    def api_media_file(self, tab: MediaTab, filename: str) -> FileResponse:
+        file_path = self._get_file(tab, filename)
+        return FileResponse(file_path, media_type="image/png")
+
+    # tab=${tab}?filename=${filename.name}?width=${width}&height=${height}
+    def api_media_thumbnail_file(
+        self, tab: MediaTab, filename: str, width: int, height: int
+    ) -> FileResponse:
+        img_dir = self._get_dir(tab)
+        thumb_filename, (width, height) = self.get_thumbnail(
+            img_dir, filename, width=width, height=height
+        )
+        thumbnail_filepath = self.thumbnail_directory / thumb_filename
+        return FileResponse(
+            thumbnail_filepath,
+            headers={
+                "X-Width": str(width),
+                "X-Height": str(height),
+            },
+            media_type="image/jpeg",
+        )
+
+    def _get_dir(self, tab: MediaTab) -> Path:
+        if tab == "input":
+            return self.input_dir
+        elif tab == "output":
+            return self.output_dir
+        elif tab == "mask":
+            return self.mask_dir
+        else:
+            raise HTTPException(status_code=422, detail=f"tab not found: {tab}")
+
+    def _get_file(self, tab: MediaTab, filename: str) -> Path:
+        file_path = self._get_dir(tab) / filename
+        if not file_path.exists():
+            raise HTTPException(status_code=422, detail=f"file not found: {file_path}")
+        return file_path
+
+    @property
+    def thumbnail_directory(self) -> Path:
+        return self.output_dir / "thumbnails"
+
+    @staticmethod
+    def _media_names(directory: Path) -> List[MediasResponse]:
+        if directory is None:
+            return []
+        names = sorted([it.name for it in glob_img(directory)])
+        res = []
+        for name in names:
+            path = os.path.join(directory, name)
+            img = Image.open(path)
+            res.append(
+                MediasResponse(
+                    name=name,
+                    height=img.height,
+                    width=img.width,
+                    ctime=os.path.getctime(path),
+                    mtime=os.path.getmtime(path),
+                )
+            )
+        return res
+
+    def get_thumbnail(
+        self, directory: Path, original_filename: str, width, height, **options
+    ):
+        directory = Path(directory)
+        storage = FilesystemStorageBackend(self.app)
+        crop = options.get("crop", "fit")
+        background = options.get("background")
+        quality = options.get("quality", 90)
+
+        original_path, original_filename = os.path.split(original_filename)
+        original_filepath = os.path.join(directory, original_path, original_filename)
+        image = Image.open(BytesIO(storage.read(original_filepath)))
+
+        # keep ratio resize
+        if not width and not height:
+            width = 256
+
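+        # Derive the missing dimension from the image aspect ratio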
+        if width != 0:
+            height = int(image.height * width / image.width)
+        else:
+            width = int(image.width * height / image.height)
+
+        thumbnail_size = (width, height)
+
+        thumbnail_filename = generate_filename(
+            directory,
+            original_filename,
+            aspect_to_string(thumbnail_size),
+            crop,
+            background,
+            quality,
+        )
+
+        thumbnail_filepath = os.path.join(
+            self.thumbnail_directory, original_path, thumbnail_filename
+        )
+
+        if storage.exists(thumbnail_filepath):
+            return thumbnail_filepath, (width, height)
+
+        try:
+            image.load()
+        except (IOError, OSError):
+            logger.warning(f"Thumbnail could not load image: {original_filepath}")
+            return thumbnail_filepath, (width, height)
+
+        # get original image format
+        options["format"] = options.get("format", image.format)
+
+        image = self._create_thumbnail(
+            image, thumbnail_size, crop, background=background
+        )
+
+        raw_data = self.get_raw_data(image, **options)
+        storage.save(thumbnail_filepath, raw_data)
+
+        return thumbnail_filepath, (width, height)
+
+    def get_raw_data(self, image, **options):
+        data = {
+            "format": self._get_format(image, **options),
+            "quality": options.get("quality", 90),
+        }
+
+        _file = BytesIO()
+        image.save(_file, **data)
+        return _file.getvalue()
+
+    @staticmethod
+    def colormode(image, colormode="RGB"):
+        if colormode == "RGB" or colormode == "RGBA":
+            if image.mode == "RGBA":
+                return image
+            if image.mode == "LA":
+                return image.convert("RGBA")
+            return image.convert(colormode)
+
+        if colormode == "GRAY":
+            return image.convert("L")
+
+        return image.convert(colormode)
+
+    @staticmethod
+    def background(original_image, color=0xFF):
+        size = (max(original_image.size),) * 2
+        image = Image.new("L", size, color)
+        image.paste(
+            original_image,
+            tuple(map(lambda x: (x[0] - x[1]) // 2, zip(size, original_image.size))),  # paste offsets must be ints
+        )
+
+        return image
+
+    def _get_format(self, image, **options):
+        if options.get("format"):
+            return options.get("format")
+        if image.format:
+            return image.format
+
+        return "JPEG"
+
+    def _create_thumbnail(self, image, size, crop="fit", background=None):
+        try:
+            resample = Image.Resampling.LANCZOS
+        except AttributeError:  # pylint: disable=raise-missing-from
+            resample = Image.ANTIALIAS
+
+        if crop == "fit":
+            image = ImageOps.fit(image, size, resample)
+        else:
+            image = image.copy()
+            image.thumbnail(size, resample=resample)
+
+        if background is not None:
+            image = self.background(image)
+
+        image = self.colormode(image)
+
+        return image

+ 46 - 0
sorawm/iopaint/file_manager/storage_backends.py

@@ -0,0 +1,46 @@
+# Copy from https://github.com/silentsokolov/flask-thumbnails/blob/master/flask_thumbnails/storage_backends.py
+import errno
+import os
+from abc import ABC, abstractmethod
+
+
+class BaseStorageBackend(ABC):
+    def __init__(self, app=None):
+        self.app = app
+
+    @abstractmethod
+    def read(self, filepath, mode="rb", **kwargs):
+        raise NotImplementedError
+
+    @abstractmethod
+    def exists(self, filepath):
+        raise NotImplementedError
+
+    @abstractmethod
+    def save(self, filepath, data):
+        raise NotImplementedError
+
+
+class FilesystemStorageBackend(BaseStorageBackend):
+    def read(self, filepath, mode="rb", **kwargs):
+        with open(filepath, mode) as f:  # pylint: disable=unspecified-encoding
+            return f.read()
+
+    def exists(self, filepath):
+        return os.path.exists(filepath)
+
+    def save(self, filepath, data):
+        directory = os.path.dirname(filepath)
+
+        if not os.path.exists(directory):
+            try:
+                os.makedirs(directory)
+            except OSError as e:
+                if e.errno != errno.EEXIST:
+                    raise
+
+        if not os.path.isdir(directory):
+            raise IOError("{} is not a directory".format(directory))
+
+        with open(filepath, "wb") as f:
+            f.write(data)

+ 64 - 0
sorawm/iopaint/file_manager/utils.py

@@ -0,0 +1,64 @@
+# Copy from: https://github.com/silentsokolov/flask-thumbnails/blob/master/flask_thumbnails/utils.py
+import hashlib
+from pathlib import Path
+from typing import Union
+
+
+def generate_filename(directory: Path, original_filename, *options) -> str:
+    text = str(directory.absolute()) + original_filename
+    for v in options:
+        text += "%s" % v
+    md5_hash = hashlib.md5()
+    md5_hash.update(text.encode("utf-8"))
+    return md5_hash.hexdigest() + ".jpg"
+
+
+def parse_size(size):
+    if isinstance(size, int):
+        # If the size parameter is a single number, assume square aspect.
+        return [size, size]
+
+    if isinstance(size, (tuple, list)):
+        if len(size) == 1:
+            # If a single-value tuple/list is provided, expand it to two elements
+            return size + type(size)(size)
+        return size
+
+    try:
+        thumbnail_size = [int(x) for x in size.lower().split("x", 1)]
+    except ValueError:
+        raise ValueError(  # pylint: disable=raise-missing-from
+            "Bad thumbnail size format. Valid format is INTxINT."
+        )
+
+    if len(thumbnail_size) == 1:
+        # If the size parameter only contains a single integer, assume square aspect.
+        thumbnail_size.append(thumbnail_size[0])
+
+    return thumbnail_size
+
+
+def aspect_to_string(size):
+    if isinstance(size, str):
+        return size
+
+    return "x".join(map(str, size))
+
+
+IMG_SUFFIX = {".jpg", ".jpeg", ".png", ".JPG", ".JPEG", ".PNG"}
+
+
+def glob_img(p: Union[Path, str], recursive: bool = False):
+    p = Path(p)
+    if p.is_file() and p.suffix in IMG_SUFFIX:
+        yield p
+    else:
+        if recursive:
+            files = Path(p).glob("**/*.*")
+        else:
+            files = Path(p).glob("*.*")
+
+        for it in files:
+            if it.suffix not in IMG_SUFFIX:
+                continue
+            yield it
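A few illustrative calls against these helpers (paths and file names are placeholders):

```python
from pathlib import Path

from sorawm.iopaint.file_manager.utils import (
    aspect_to_string,
    generate_filename,
    glob_img,
    parse_size,
)

parse_size(200)               # -> [200, 200]
parse_size("256x128")         # -> [256, 128]
aspect_to_string([256, 128])  # -> "256x128"

# Cache file name: md5 over the directory, original name and options, always with a .jpg suffix.
generate_filename(Path("/tmp/thumbs"), "cat.png", "256x128", "fit")

# Yield image files under a directory (non-recursive unless recursive=True).
images = list(glob_img("/tmp/images", recursive=True))
```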

+ 411 - 0
sorawm/iopaint/helper.py

@@ -0,0 +1,411 @@
+import base64
+import hashlib
+import imghdr
+import io
+import os
+import sys
+from typing import Dict, List, Optional, Tuple
+from urllib.parse import urlparse
+
+import cv2
+import numpy as np
+import torch
+from loguru import logger
+from PIL import Image, ImageOps, PngImagePlugin
+from torch.hub import download_url_to_file, get_dir
+
+from sorawm.iopaint.const import MPS_UNSUPPORT_MODELS
+
+
+def md5sum(filename):
+    md5 = hashlib.md5()
+    with open(filename, "rb") as f:
+        for chunk in iter(lambda: f.read(128 * md5.block_size), b""):
+            md5.update(chunk)
+    return md5.hexdigest()
+
+
+def switch_mps_device(model_name, device):
+    if model_name in MPS_UNSUPPORT_MODELS and str(device) == "mps":
+        logger.info(f"{model_name} not support mps, switch to cpu")
+        return torch.device("cpu")
+    return device
+
+
+def get_cache_path_by_url(url):
+    parts = urlparse(url)
+    hub_dir = get_dir()
+    model_dir = os.path.join(hub_dir, "checkpoints")
+    if not os.path.isdir(model_dir):
+        os.makedirs(model_dir)
+    filename = os.path.basename(parts.path)
+    cached_file = os.path.join(model_dir, filename)
+    return cached_file
+
+
+def download_model(url, model_md5: str = None):
+    if os.path.exists(url):
+        cached_file = url
+    else:
+        cached_file = get_cache_path_by_url(url)
+    if not os.path.exists(cached_file):
+        sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
+        hash_prefix = None
+        download_url_to_file(url, cached_file, hash_prefix, progress=True)
+        if model_md5:
+            _md5 = md5sum(cached_file)
+            if model_md5 == _md5:
+                logger.info(f"Download model success, md5: {_md5}")
+            else:
+                try:
+                    os.remove(cached_file)
+                    logger.error(
+                        f"Model md5: {_md5}, expected md5: {model_md5}, wrong model deleted. Please restart sorawm.iopaint."
+                        f"If you still have errors, please try download model manually first https://lama-cleaner-docs.vercel.app/install/download_model_manually.\n"
+                    )
+                except:
+                    logger.error(
+                        f"Model md5: {_md5}, expected md5: {model_md5}, please delete {cached_file} and restart sorawm.iopaint."
+                    )
+                exit(-1)
+
+    return cached_file
+
+
+def ceil_modulo(x, mod):
+    if x % mod == 0:
+        return x
+    return (x // mod + 1) * mod
+
+
+def handle_error(model_path, model_md5, e):
+    _md5 = md5sum(model_path)
+    if _md5 != model_md5:
+        try:
+            os.remove(model_path)
+            logger.error(
+                f"Model md5: {_md5}, expected md5: {model_md5}, wrong model deleted. Please restart sorawm.iopaint."
+                f"If you still have errors, please try download model manually first https://lama-cleaner-docs.vercel.app/install/download_model_manually.\n"
+            )
+        except:
+            logger.error(
+                f"Model md5: {_md5}, expected md5: {model_md5}, please delete {model_path} and restart sorawm.iopaint."
+            )
+    else:
+        logger.error(
+            f"Failed to load model {model_path},"
+            f"please submit an issue at https://github.com/Sanster/lama-cleaner/issues and include a screenshot of the error:\n{e}"
+        )
+    exit(-1)
+
+
+def load_jit_model(url_or_path, device, model_md5: str):
+    if os.path.exists(url_or_path):
+        model_path = url_or_path
+    else:
+        model_path = download_model(url_or_path, model_md5)
+
+    logger.info(f"Loading model from: {model_path}")
+    try:
+        model = torch.jit.load(model_path, map_location="cpu").to(device)
+    except Exception as e:
+        handle_error(model_path, model_md5, e)
+    model.eval()
+    return model
+
+
+def load_model(model: torch.nn.Module, url_or_path, device, model_md5):
+    if os.path.exists(url_or_path):
+        model_path = url_or_path
+    else:
+        model_path = download_model(url_or_path, model_md5)
+
+    try:
+        logger.info(f"Loading model from: {model_path}")
+        state_dict = torch.load(model_path, map_location="cpu")
+        model.load_state_dict(state_dict, strict=True)
+        model.to(device)
+    except Exception as e:
+        handle_error(model_path, model_md5, e)
+    model.eval()
+    return model
+
+
+def numpy_to_bytes(image_numpy: np.ndarray, ext: str) -> bytes:
+    data = cv2.imencode(
+        f".{ext}",
+        image_numpy,
+        [int(cv2.IMWRITE_JPEG_QUALITY), 100, int(cv2.IMWRITE_PNG_COMPRESSION), 0],
+    )[1]
+    image_bytes = data.tobytes()
+    return image_bytes
+
+
+def pil_to_bytes(pil_img, ext: str, quality: int = 95, infos={}) -> bytes:
+    with io.BytesIO() as output:
+        kwargs = {k: v for k, v in infos.items() if v is not None}
+        if ext == "jpg":
+            ext = "jpeg"
+        if "png" == ext.lower() and "parameters" in kwargs:
+            pnginfo_data = PngImagePlugin.PngInfo()
+            pnginfo_data.add_text("parameters", kwargs["parameters"])
+            kwargs["pnginfo"] = pnginfo_data
+
+        pil_img.save(output, format=ext, quality=quality, **kwargs)
+        image_bytes = output.getvalue()
+    return image_bytes
+
+
+def load_img(img_bytes, gray: bool = False, return_info: bool = False):
+    alpha_channel = None
+    image = Image.open(io.BytesIO(img_bytes))
+
+    if return_info:
+        infos = image.info
+
+    try:
+        image = ImageOps.exif_transpose(image)
+    except:
+        pass
+
+    if gray:
+        image = image.convert("L")
+        np_img = np.array(image)
+    else:
+        if image.mode == "RGBA":
+            np_img = np.array(image)
+            alpha_channel = np_img[:, :, -1]
+            np_img = cv2.cvtColor(np_img, cv2.COLOR_RGBA2RGB)
+        else:
+            image = image.convert("RGB")
+            np_img = np.array(image)
+
+    if return_info:
+        return np_img, alpha_channel, infos
+    return np_img, alpha_channel
+
+
+def norm_img(np_img):
+    if len(np_img.shape) == 2:
+        np_img = np_img[:, :, np.newaxis]
+    np_img = np.transpose(np_img, (2, 0, 1))
+    np_img = np_img.astype("float32") / 255
+    return np_img
+
+
+def resize_max_size(
+    np_img, size_limit: int, interpolation=cv2.INTER_CUBIC
+) -> np.ndarray:
+    # Resize so the longer side equals size_limit when it currently exceeds size_limit
+    h, w = np_img.shape[:2]
+    if max(h, w) > size_limit:
+        ratio = size_limit / max(h, w)
+        new_w = int(w * ratio + 0.5)
+        new_h = int(h * ratio + 0.5)
+        return cv2.resize(np_img, dsize=(new_w, new_h), interpolation=interpolation)
+    else:
+        return np_img
+
+
+def pad_img_to_modulo(
+    img: np.ndarray, mod: int, square: bool = False, min_size: Optional[int] = None
+):
+    """
+
+    Args:
+        img: [H, W, C]
+        mod: pad height and width up to the nearest multiple of this value
+        square: whether to pad the result to a square
+        min_size: optional minimum output size; must itself be a multiple of mod
+
+    Returns:
+        The symmetrically padded image.
+
+    """
+    if len(img.shape) == 2:
+        img = img[:, :, np.newaxis]
+    height, width = img.shape[:2]
+    out_height = ceil_modulo(height, mod)
+    out_width = ceil_modulo(width, mod)
+
+    if min_size is not None:
+        assert min_size % mod == 0
+        out_width = max(min_size, out_width)
+        out_height = max(min_size, out_height)
+
+    if square:
+        max_size = max(out_height, out_width)
+        out_height = max_size
+        out_width = max_size
+
+    return np.pad(
+        img,
+        ((0, out_height - height), (0, out_width - width), (0, 0)),
+        mode="symmetric",
+    )
+
+
+def boxes_from_mask(mask: np.ndarray) -> List[np.ndarray]:
+    """
+    Args:
+        mask: (h, w, 1)  0~255
+
+    Returns:
+        A list of [x1, y1, x2, y2] bounding boxes, one per external contour.
+
+    """
+    height, width = mask.shape[:2]
+    _, thresh = cv2.threshold(mask, 127, 255, 0)
+    contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
+
+    boxes = []
+    for cnt in contours:
+        x, y, w, h = cv2.boundingRect(cnt)
+        box = np.array([x, y, x + w, y + h]).astype(int)
+
+        box[::2] = np.clip(box[::2], 0, width)
+        box[1::2] = np.clip(box[1::2], 0, height)
+        boxes.append(box)
+
+    return boxes
+
+
+def only_keep_largest_contour(mask: np.ndarray) -> np.ndarray:
+    """
+    Args:
+        mask: (h, w)  0~255
+
+    Returns:
+        A mask with only the largest contour kept (filled with 255), or the
+        original mask unchanged if no contour is found.
+    """
+    _, thresh = cv2.threshold(mask, 127, 255, 0)
+    contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
+
+    max_area = 0
+    max_index = -1
+    for i, cnt in enumerate(contours):
+        area = cv2.contourArea(cnt)
+        if area > max_area:
+            max_area = area
+            max_index = i
+
+    if max_index != -1:
+        new_mask = np.zeros_like(mask)
+        return cv2.drawContours(new_mask, contours, max_index, 255, -1)
+    else:
+        return mask
+
+
+def is_mac():
+    return sys.platform == "darwin"
+
+
+def get_image_ext(img_bytes):
+    w = imghdr.what("", img_bytes)
+    if w is None:
+        w = "jpeg"
+    return w
+
+
+def decode_base64_to_image(
+    encoding: str, gray=False
+) -> Tuple[np.array, Optional[np.array], Dict, str]:
+    if encoding.startswith("data:image/") or encoding.startswith(
+        "data:application/octet-stream;base64,"
+    ):
+        encoding = encoding.split(";")[1].split(",")[1]
+    image_bytes = base64.b64decode(encoding)
+    ext = get_image_ext(image_bytes)
+    image = Image.open(io.BytesIO(image_bytes))
+
+    alpha_channel = None
+    try:
+        image = ImageOps.exif_transpose(image)
+    except:
+        pass
+    # exif_transpose removes the EXIF rotation info, so image.info must be read after calling it
+    infos = image.info
+
+    if gray:
+        image = image.convert("L")
+        np_img = np.array(image)
+    else:
+        if image.mode == "RGBA":
+            np_img = np.array(image)
+            alpha_channel = np_img[:, :, -1]
+            np_img = cv2.cvtColor(np_img, cv2.COLOR_RGBA2RGB)
+        else:
+            image = image.convert("RGB")
+            np_img = np.array(image)
+
+    return np_img, alpha_channel, infos, ext
+
+
+def encode_pil_to_base64(image: Image, quality: int, infos: Dict) -> bytes:
+    img_bytes = pil_to_bytes(
+        image,
+        "png",
+        quality=quality,
+        infos=infos,
+    )
+    return base64.b64encode(img_bytes)
+
+
+def concat_alpha_channel(rgb_np_img, alpha_channel) -> np.ndarray:
+    if alpha_channel is not None:
+        if alpha_channel.shape[:2] != rgb_np_img.shape[:2]:
+            alpha_channel = cv2.resize(
+                alpha_channel, dsize=(rgb_np_img.shape[1], rgb_np_img.shape[0])
+            )
+        rgb_np_img = np.concatenate(
+            (rgb_np_img, alpha_channel[:, :, np.newaxis]), axis=-1
+        )
+    return rgb_np_img
+
+
+def adjust_mask(mask: np.ndarray, kernel_size: int, operate):
+    # frontend brush color "ffcc00bb"
+    # kernel_size = kernel_size*2+1
+    mask[mask >= 127] = 255
+    mask[mask < 127] = 0
+
+    if operate == "reverse":
+        mask = 255 - mask
+    else:
+        kernel = cv2.getStructuringElement(
+            cv2.MORPH_ELLIPSE, (2 * kernel_size + 1, 2 * kernel_size + 1)
+        )
+        if operate == "expand":
+            mask = cv2.dilate(
+                mask,
+                kernel,
+                iterations=1,
+            )
+        else:
+            mask = cv2.erode(
+                mask,
+                kernel,
+                iterations=1,
+            )
+    res_mask = np.zeros((mask.shape[0], mask.shape[1], 4), dtype=np.uint8)
+    res_mask[mask > 128] = [255, 203, 0, int(255 * 0.73)]
+    res_mask = cv2.cvtColor(res_mask, cv2.COLOR_BGRA2RGBA)
+    return res_mask
+
+
+def gen_frontend_mask(bgr_or_gray_mask):
+    if len(bgr_or_gray_mask.shape) == 3 and bgr_or_gray_mask.shape[2] != 1:
+        bgr_or_gray_mask = cv2.cvtColor(bgr_or_gray_mask, cv2.COLOR_BGR2GRAY)
+
+    # frontend brush color "ffcc00bb"
+    # TODO: how to set kernel size?
+    kernel_size = 9
+    bgr_or_gray_mask = cv2.dilate(
+        bgr_or_gray_mask,
+        np.ones((kernel_size, kernel_size), np.uint8),
+        iterations=1,
+    )
+    res_mask = np.zeros(
+        (bgr_or_gray_mask.shape[0], bgr_or_gray_mask.shape[1], 4), dtype=np.uint8
+    )
+    res_mask[bgr_or_gray_mask > 128] = [255, 203, 0, int(255 * 0.73)]
+    res_mask = cv2.cvtColor(res_mask, cv2.COLOR_BGRA2RGBA)
+    return res_mask
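An illustrative preprocessing chain built from the helpers above, the way the inpainting models typically consume an input frame; `input.png` is a placeholder:

```python
from pathlib import Path

from sorawm.iopaint.helper import concat_alpha_channel, load_img, norm_img, pad_img_to_modulo

img_bytes = Path("input.png").read_bytes()    # placeholder path
np_img, alpha_channel = load_img(img_bytes)   # HWC uint8 RGB, plus alpha channel if the source had one
padded = pad_img_to_modulo(np_img, mod=8)     # pad H/W up to multiples of 8
model_input = norm_img(padded)                # CHW float32 scaled to [0, 1]

# After inference, the alpha channel can be re-attached to the RGB result:
# result = concat_alpha_channel(rgb_result, alpha_channel)
```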

+ 11 - 0
sorawm/iopaint/installer.py

@@ -0,0 +1,11 @@
+import subprocess
+import sys
+
+
+def install(package):
+    subprocess.check_call([sys.executable, "-m", "pip", "install", package])
+
+
+def install_plugins_package():
+    install("onnxruntime<=1.19.2")
+    install("rembg[cpu]")

+ 38 - 0
sorawm/iopaint/model/__init__.py

@@ -0,0 +1,38 @@
+from .anytext.anytext_model import AnyText
+from .controlnet import ControlNet
+from .fcf import FcF
+from .instruct_pix2pix import InstructPix2Pix
+from .kandinsky import Kandinsky22
+from .lama import AnimeLaMa, LaMa
+from .ldm import LDM
+from .manga import Manga
+from .mat import MAT
+from .mi_gan import MIGAN
+from .opencv2 import OpenCV2
+from .paint_by_example import PaintByExample
+from .power_paint.power_paint import PowerPaint
+from .sd import SD, SD2, SD15, Anything4, RealisticVision14
+from .sdxl import SDXL
+from .zits import ZITS
+
+models = {
+    LaMa.name: LaMa,
+    AnimeLaMa.name: AnimeLaMa,
+    LDM.name: LDM,
+    ZITS.name: ZITS,
+    MAT.name: MAT,
+    FcF.name: FcF,
+    OpenCV2.name: OpenCV2,
+    Manga.name: Manga,
+    MIGAN.name: MIGAN,
+    SD15.name: SD15,
+    Anything4.name: Anything4,
+    RealisticVision14.name: RealisticVision14,
+    SD2.name: SD2,
+    PaintByExample.name: PaintByExample,
+    InstructPix2Pix.name: InstructPix2Pix,
+    Kandinsky22.name: Kandinsky22,
+    SDXL.name: SDXL,
+    PowerPaint.name: PowerPaint,
+    AnyText.name: AnyText,
+}
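The `models` dict is a plain name-to-class registry keyed by each class's `name` attribute; a quick sketch:

```python
from sorawm.iopaint.model import LaMa, models

assert models[LaMa.name] is LaMa  # look up a model class by its registered name
print(sorted(models.keys()))      # list every registered inpainting/erase model name
```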

BIN
sorawm/iopaint/model/__pycache__/__init__.cpython-312.pyc


BIN
sorawm/iopaint/model/__pycache__/base.cpython-312.pyc


BIN
sorawm/iopaint/model/__pycache__/controlnet.cpython-312.pyc


BIN
sorawm/iopaint/model/__pycache__/ddim_sampler.cpython-312.pyc


BIN
sorawm/iopaint/model/__pycache__/fcf.cpython-312.pyc


BIN
sorawm/iopaint/model/__pycache__/instruct_pix2pix.cpython-312.pyc


BIN
sorawm/iopaint/model/__pycache__/kandinsky.cpython-312.pyc


BIN
sorawm/iopaint/model/__pycache__/lama.cpython-312.pyc


BIN
sorawm/iopaint/model/__pycache__/ldm.cpython-312.pyc


BIN
sorawm/iopaint/model/__pycache__/manga.cpython-312.pyc


BIN
sorawm/iopaint/model/__pycache__/mat.cpython-312.pyc


BIN
sorawm/iopaint/model/__pycache__/mi_gan.cpython-312.pyc


BIN
sorawm/iopaint/model/__pycache__/opencv2.cpython-312.pyc


BIN
sorawm/iopaint/model/__pycache__/paint_by_example.cpython-312.pyc


BIN
sorawm/iopaint/model/__pycache__/plms_sampler.cpython-312.pyc


BIN
sorawm/iopaint/model/__pycache__/sd.cpython-312.pyc


BIN
sorawm/iopaint/model/__pycache__/sdxl.cpython-312.pyc


BIN
sorawm/iopaint/model/__pycache__/utils.cpython-312.pyc


BIN
sorawm/iopaint/model/__pycache__/zits.cpython-312.pyc


+ 0 - 0
sorawm/iopaint/model/anytext/__init__.py


BIN
sorawm/iopaint/model/anytext/__pycache__/__init__.cpython-312.pyc


BIN
sorawm/iopaint/model/anytext/__pycache__/anytext_model.cpython-312.pyc


BIN
sorawm/iopaint/model/anytext/__pycache__/anytext_pipeline.cpython-312.pyc


BIN
sorawm/iopaint/model/anytext/__pycache__/utils.cpython-312.pyc


+ 73 - 0
sorawm/iopaint/model/anytext/anytext_model.py

@@ -0,0 +1,73 @@
+import torch
+from huggingface_hub import hf_hub_download
+
+from sorawm.iopaint.const import ANYTEXT_NAME
+from sorawm.iopaint.model.anytext.anytext_pipeline import AnyTextPipeline
+from sorawm.iopaint.model.base import DiffusionInpaintModel
+from sorawm.iopaint.model.utils import get_torch_dtype, is_local_files_only
+from sorawm.iopaint.schema import InpaintRequest
+
+
+class AnyText(DiffusionInpaintModel):
+    name = ANYTEXT_NAME
+    pad_mod = 64
+    is_erase_model = False
+
+    @staticmethod
+    def download(local_files_only=False):
+        hf_hub_download(
+            repo_id=ANYTEXT_NAME,
+            filename="model_index.json",
+            local_files_only=local_files_only,
+        )
+        ckpt_path = hf_hub_download(
+            repo_id=ANYTEXT_NAME,
+            filename="pytorch_model.fp16.safetensors",
+            local_files_only=local_files_only,
+        )
+        font_path = hf_hub_download(
+            repo_id=ANYTEXT_NAME,
+            filename="SourceHanSansSC-Medium.otf",
+            local_files_only=local_files_only,
+        )
+        return ckpt_path, font_path
+
+    def init_model(self, device, **kwargs):
+        local_files_only = is_local_files_only(**kwargs)
+        ckpt_path, font_path = self.download(local_files_only)
+        use_gpu, torch_dtype = get_torch_dtype(device, kwargs.get("no_half", False))
+        self.model = AnyTextPipeline(
+            ckpt_path=ckpt_path,
+            font_path=font_path,
+            device=device,
+            use_fp16=torch_dtype == torch.float16,
+        )
+        self.callback = kwargs.pop("callback", None)
+
+    def forward(self, image, mask, config: InpaintRequest):
+        """Input image and output image have same size
+        image: [H, W, C] RGB
+        mask: [H, W, 1] 255 means area to inpainting
+        return: BGR IMAGE
+        """
+        height, width = image.shape[:2]
+        mask = mask.astype("float32") / 255.0
+        masked_image = image * (1 - mask)
+
+        # list of rgb ndarray
+        results, rtn_code, rtn_warning = self.model(
+            image=image,
+            masked_image=masked_image,
+            prompt=config.prompt,
+            negative_prompt=config.negative_prompt,
+            num_inference_steps=config.sd_steps,
+            strength=config.sd_strength,
+            guidance_scale=config.sd_guidance_scale,
+            height=height,
+            width=width,
+            seed=config.sd_seed,
+            sort_priority="y",
+            callback=self.callback,
+        )
+        inpainted_rgb_image = results[0][..., ::-1]
+        return inpainted_rgb_image
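The mask handling in `forward` relies on NumPy broadcasting over the channel axis; a small standalone illustration with made-up shapes:

```python
import numpy as np

image = np.full((64, 64, 3), 200, dtype=np.uint8)  # [H, W, C] RGB
mask = np.zeros((64, 64, 1), dtype=np.uint8)       # [H, W, 1], 255 marks the inpaint region
mask[16:48, 16:48] = 255

mask_f = mask.astype("float32") / 255.0
masked_image = image * (1 - mask_f)                # masked pixels zeroed, broadcast over channels
```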

+ 401 - 0
sorawm/iopaint/model/anytext/anytext_pipeline.py

@@ -0,0 +1,401 @@
+"""
+AnyText: Multilingual Visual Text Generation And Editing
+Paper: https://arxiv.org/abs/2311.03054
+Code: https://github.com/tyxsspa/AnyText
+Copyright (c) Alibaba, Inc. and its affiliates.
+"""
+import os
+from pathlib import Path
+
+from safetensors.torch import load_file
+
+from sorawm.iopaint.model.utils import set_seed
+
+os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
+import re
+
+import cv2
+import einops
+import numpy as np
+import torch
+from PIL import ImageFont
+
+from sorawm.iopaint.model.anytext.cldm.ddim_hacked import DDIMSampler
+from sorawm.iopaint.model.anytext.cldm.model import create_model, load_state_dict
+from sorawm.iopaint.model.anytext.utils import check_channels, draw_glyph, draw_glyph2
+
+BBOX_MAX_NUM = 8
+PLACE_HOLDER = "*"
+max_chars = 20
+
+ANYTEXT_CFG = os.path.join(
+    os.path.dirname(os.path.abspath(__file__)), "anytext_sd15.yaml"
+)
+
+
+def check_limits(tensor):
+    float16_min = torch.finfo(torch.float16).min
+    float16_max = torch.finfo(torch.float16).max
+
+    # Check whether any value in the tensor is below the float16 minimum or above the float16 maximum
+    is_below_min = (tensor < float16_min).any()
+    is_above_max = (tensor > float16_max).any()
+
+    return is_below_min or is_above_max
+
+
+class AnyTextPipeline:
+    def __init__(self, ckpt_path, font_path, device, use_fp16=True):
+        self.cfg_path = ANYTEXT_CFG
+        self.font_path = font_path
+        self.use_fp16 = use_fp16
+        self.device = device
+
+        self.font = ImageFont.truetype(font_path, size=60)
+        self.model = create_model(
+            self.cfg_path,
+            device=self.device,
+            use_fp16=self.use_fp16,
+        )
+        if self.use_fp16:
+            self.model = self.model.half()
+        if Path(ckpt_path).suffix == ".safetensors":
+            state_dict = load_file(ckpt_path, device="cpu")
+        else:
+            state_dict = load_state_dict(ckpt_path, location="cpu")
+        self.model.load_state_dict(state_dict, strict=False)
+        self.model = self.model.eval().to(self.device)
+        self.ddim_sampler = DDIMSampler(self.model, device=self.device)
+
+    def __call__(
+        self,
+        prompt: str,
+        negative_prompt: str,
+        image: np.ndarray,
+        masked_image: np.ndarray,
+        num_inference_steps: int,
+        strength: float,
+        guidance_scale: float,
+        height: int,
+        width: int,
+        seed: int,
+        sort_priority: str = "y",
+        callback=None,
+    ):
+        """
+
+        Args:
+            prompt:
+            negative_prompt:
+            image:
+            masked_image:
+            num_inference_steps:
+            strength:
+            guidance_scale:
+            height:
+            width:
+            seed:
+            sort_priority: x: left-right, y: top-down
+
+        Returns:
+            result: list of images in numpy.ndarray format
+            rst_code: 0 = normal, -1 = error, 1 = warning
+            rst_info: string of error or warning
+
+        """
+        set_seed(seed)
+        str_warning = ""
+
+        mode = "text-editing"
+        revise_pos = False
+        img_count = 1
+        ddim_steps = num_inference_steps
+        w = width
+        h = height
+        strength = strength
+        cfg_scale = guidance_scale
+        eta = 0.0
+
+        prompt, texts = self.modify_prompt(prompt)
+        if prompt is None and texts is None:
+            return (
+                None,
+                -1,
+                "You have input Chinese prompt but the translator is not loaded!",
+                "",
+            )
+        n_lines = len(texts)
+        if mode in ["text-generation", "gen"]:
+            edit_image = np.ones((h, w, 3)) * 127.5  # empty mask image
+        elif mode in ["text-editing", "edit"]:
+            if masked_image is None or image is None:
+                return (
+                    None,
+                    -1,
+                    "Reference image and position image are needed for text editing!",
+                    "",
+                )
+            if isinstance(image, str):
+                image = cv2.imread(image)[..., ::-1]
+                assert image is not None, f"Can't read ori_image from {image}!"
+            elif isinstance(image, torch.Tensor):
+                image = image.cpu().numpy()
+            else:
+                assert isinstance(
+                    image, np.ndarray
+                ), f"Unknown format of ori_image: {type(image)}"
+            edit_image = image.clip(1, 255)  # for mask reason
+            edit_image = check_channels(edit_image)
+            # edit_image = resize_image(
+            #     edit_image, max_length=768
+            # )  # make w h multiple of 64, resize if w or h > max_length
+            h, w = edit_image.shape[:2]  # change h, w by input ref_img
+        # preprocess pos_imgs (if numpy, make sure positions are white on a black background)
+        if masked_image is None:
+            pos_imgs = np.zeros((w, h, 1))
+        if isinstance(masked_image, str):
+            masked_image = cv2.imread(masked_image)[..., ::-1]
+            assert (
+                masked_image is not None
+            ), f"Can't read draw_pos image from{masked_image}!"
+            pos_imgs = 255 - masked_image
+        elif isinstance(masked_image, torch.Tensor):
+            pos_imgs = masked_image.cpu().numpy()
+        else:
+            assert isinstance(
+                masked_image, np.ndarray
+            ), f"Unknown format of draw_pos: {type(masked_image)}"
+            pos_imgs = 255 - masked_image
+        pos_imgs = pos_imgs[..., 0:1]
+        pos_imgs = cv2.convertScaleAbs(pos_imgs)
+        _, pos_imgs = cv2.threshold(pos_imgs, 254, 255, cv2.THRESH_BINARY)
+        # separate pos_imgs
+        pos_imgs = self.separate_pos_imgs(pos_imgs, sort_priority)
+        if len(pos_imgs) == 0:
+            pos_imgs = [np.zeros((h, w, 1))]
+        if len(pos_imgs) < n_lines:
+            if n_lines == 1 and texts[0] == " ":
+                pass  # text-to-image without text
+            else:
+                raise RuntimeError(
+                    f"{n_lines} text line to draw from prompt, not enough mask area({len(pos_imgs)}) on images"
+                )
+        elif len(pos_imgs) > n_lines:
+            str_warning = f"Warning: found {len(pos_imgs)} positions that > needed {n_lines} from prompt."
+        # get pre_pos, poly_list, hint that needed for anytext
+        pre_pos = []
+        poly_list = []
+        for input_pos in pos_imgs:
+            if input_pos.mean() != 0:
+                input_pos = (
+                    input_pos[..., np.newaxis]
+                    if len(input_pos.shape) == 2
+                    else input_pos
+                )
+                poly, pos_img = self.find_polygon(input_pos)
+                pre_pos += [pos_img / 255.0]
+                poly_list += [poly]
+            else:
+                pre_pos += [np.zeros((h, w, 1))]
+                poly_list += [None]
+        np_hint = np.sum(pre_pos, axis=0).clip(0, 1)
+        # prepare info dict
+        info = {}
+        info["glyphs"] = []
+        info["gly_line"] = []
+        info["positions"] = []
+        info["n_lines"] = [len(texts)] * img_count
+        gly_pos_imgs = []
+        for i in range(len(texts)):
+            text = texts[i]
+            if len(text) > max_chars:
+                str_warning = (
+                    f'"{text}" length > max_chars: {max_chars}, will be cut off...'
+                )
+                text = text[:max_chars]
+            gly_scale = 2
+            if pre_pos[i].mean() != 0:
+                gly_line = draw_glyph(self.font, text)
+                glyphs = draw_glyph2(
+                    self.font,
+                    text,
+                    poly_list[i],
+                    scale=gly_scale,
+                    width=w,
+                    height=h,
+                    add_space=False,
+                )
+                gly_pos_img = cv2.drawContours(
+                    glyphs * 255, [poly_list[i] * gly_scale], 0, (255, 255, 255), 1
+                )
+                if revise_pos:
+                    resize_gly = cv2.resize(
+                        glyphs, (pre_pos[i].shape[1], pre_pos[i].shape[0])
+                    )
+                    new_pos = cv2.morphologyEx(
+                        (resize_gly * 255).astype(np.uint8),
+                        cv2.MORPH_CLOSE,
+                        kernel=np.ones(
+                            (resize_gly.shape[0] // 10, resize_gly.shape[1] // 10),
+                            dtype=np.uint8,
+                        ),
+                        iterations=1,
+                    )
+                    new_pos = (
+                        new_pos[..., np.newaxis] if len(new_pos.shape) == 2 else new_pos
+                    )
+                    contours, _ = cv2.findContours(
+                        new_pos, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE
+                    )
+                    if len(contours) != 1:
+                        str_warning = f"Fail to revise position {i} to bounding rect, remain position unchanged..."
+                    else:
+                        rect = cv2.minAreaRect(contours[0])
+                        poly = np.int0(cv2.boxPoints(rect))
+                        pre_pos[i] = (
+                            cv2.drawContours(new_pos, [poly], -1, 255, -1) / 255.0
+                        )
+                        gly_pos_img = cv2.drawContours(
+                            glyphs * 255, [poly * gly_scale], 0, (255, 255, 255), 1
+                        )
+                gly_pos_imgs += [gly_pos_img]  # for show
+            else:
+                glyphs = np.zeros((h * gly_scale, w * gly_scale, 1))
+                gly_line = np.zeros((80, 512, 1))
+                gly_pos_imgs += [
+                    np.zeros((h * gly_scale, w * gly_scale, 1))
+                ]  # for show
+            pos = pre_pos[i]
+            info["glyphs"] += [self.arr2tensor(glyphs, img_count)]
+            info["gly_line"] += [self.arr2tensor(gly_line, img_count)]
+            info["positions"] += [self.arr2tensor(pos, img_count)]
+        # get masked_x
+        masked_img = ((edit_image.astype(np.float32) / 127.5) - 1.0) * (1 - np_hint)
+        masked_img = np.transpose(masked_img, (2, 0, 1))
+        masked_img = torch.from_numpy(masked_img.copy()).float().to(self.device)
+        if self.use_fp16:
+            masked_img = masked_img.half()
+        encoder_posterior = self.model.encode_first_stage(masked_img[None, ...])
+        masked_x = self.model.get_first_stage_encoding(encoder_posterior).detach()
+        if self.use_fp16:
+            masked_x = masked_x.half()
+        info["masked_x"] = torch.cat([masked_x for _ in range(img_count)], dim=0)
+
+        hint = self.arr2tensor(np_hint, img_count)
+        cond = self.model.get_learned_conditioning(
+            dict(
+                c_concat=[hint],
+                c_crossattn=[[prompt] * img_count],
+                text_info=info,
+            )
+        )
+        un_cond = self.model.get_learned_conditioning(
+            dict(
+                c_concat=[hint],
+                c_crossattn=[[negative_prompt] * img_count],
+                text_info=info,
+            )
+        )
+        shape = (4, h // 8, w // 8)
+        self.model.control_scales = [strength] * 13
+        samples, intermediates = self.ddim_sampler.sample(
+            ddim_steps,
+            img_count,
+            shape,
+            cond,
+            verbose=False,
+            eta=eta,
+            unconditional_guidance_scale=cfg_scale,
+            unconditional_conditioning=un_cond,
+            callback=callback,
+        )
+        if self.use_fp16:
+            samples = samples.half()
+        x_samples = self.model.decode_first_stage(samples)
+        x_samples = (
+            (einops.rearrange(x_samples, "b c h w -> b h w c") * 127.5 + 127.5)
+            .cpu()
+            .numpy()
+            .clip(0, 255)
+            .astype(np.uint8)
+        )
+        results = [x_samples[i] for i in range(img_count)]
+        # if (
+        #     mode == "edit" and False
+        # ):  # replace background in text editing but not ideal yet
+        #     results = [r * np_hint + edit_image * (1 - np_hint) for r in results]
+        #     results = [r.clip(0, 255).astype(np.uint8) for r in results]
+        # if len(gly_pos_imgs) > 0 and show_debug:
+        #     glyph_bs = np.stack(gly_pos_imgs, axis=2)
+        #     glyph_img = np.sum(glyph_bs, axis=2) * 255
+        #     glyph_img = glyph_img.clip(0, 255).astype(np.uint8)
+        #     results += [np.repeat(glyph_img, 3, axis=2)]
+        rst_code = 1 if str_warning else 0
+        return results, rst_code, str_warning
+
+    def modify_prompt(self, prompt):
+        prompt = prompt.replace("“", '"')
+        prompt = prompt.replace("”", '"')
+        p = '"(.*?)"'
+        strs = re.findall(p, prompt)
+        if len(strs) == 0:
+            strs = [" "]
+        else:
+            for s in strs:
+                prompt = prompt.replace(f'"{s}"', f" {PLACE_HOLDER} ", 1)
+        # if self.is_chinese(prompt):
+        #     if self.trans_pipe is None:
+        #         return None, None
+        #     old_prompt = prompt
+        #     prompt = self.trans_pipe(input=prompt + " .")["translation"][:-1]
+        #     print(f"Translate: {old_prompt} --> {prompt}")
+        return prompt, strs
+
+    # def is_chinese(self, text):
+    #     text = checker._clean_text(text)
+    #     for char in text:
+    #         cp = ord(char)
+    #         if checker._is_chinese_char(cp):
+    #             return True
+    #     return False
+
+    def separate_pos_imgs(self, img, sort_priority, gap=102):
+        num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(img)
+        components = []
+        for label in range(1, num_labels):
+            component = np.zeros_like(img)
+            component[labels == label] = 255
+            components.append((component, centroids[label]))
+        if sort_priority == "y":
+            fir, sec = 1, 0  # top-down first
+        elif sort_priority == "x":
+            fir, sec = 0, 1  # left-right first
+        components.sort(key=lambda c: (c[1][fir] // gap, c[1][sec] // gap))
+        sorted_components = [c[0] for c in components]
+        return sorted_components
+
+    def find_polygon(self, image, min_rect=False):
+        contours, hierarchy = cv2.findContours(
+            image, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE
+        )
+        max_contour = max(contours, key=cv2.contourArea)  # get contour with max area
+        if min_rect:
+            # get minimum enclosing rectangle
+            rect = cv2.minAreaRect(max_contour)
+            poly = np.int0(cv2.boxPoints(rect))
+        else:
+            # get approximate polygon
+            epsilon = 0.01 * cv2.arcLength(max_contour, True)
+            poly = cv2.approxPolyDP(max_contour, epsilon, True)
+            n, _, xy = poly.shape
+            poly = poly.reshape(n, xy)
+        cv2.drawContours(image, [poly], -1, 255, -1)
+        return poly, image
+
+    def arr2tensor(self, arr, bs):
+        arr = np.transpose(arr, (2, 0, 1))
+        _arr = torch.from_numpy(arr.copy()).float().to(self.device)
+        if self.use_fp16:
+            _arr = _arr.half()
+        _arr = torch.stack([_arr for _ in range(bs)], dim=0)
+        return _arr
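The quoted-text extraction in `modify_prompt` can be reproduced standalone: quoted spans are swapped for the `*` placeholder and later rendered as glyph images in `__call__` (the prompt below is made up):

```python
import re

prompt = 'A street sign that says "OPEN" and "24H"'
texts = re.findall(r'"(.*?)"', prompt)
for s in texts:
    prompt = prompt.replace(f'"{s}"', " * ", 1)

# texts  -> ['OPEN', '24H']
# prompt -> 'A street sign that says  *  and  * '
```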

+ 99 - 0
sorawm/iopaint/model/anytext/anytext_sd15.yaml

@@ -0,0 +1,99 @@
+model:
+  target: sorawm.iopaint.model.anytext.cldm.cldm.ControlLDM
+  params:
+    linear_start: 0.00085
+    linear_end: 0.0120
+    num_timesteps_cond: 1
+    log_every_t: 200
+    timesteps: 1000
+    first_stage_key: "img"
+    cond_stage_key: "caption"
+    control_key: "hint"
+    glyph_key: "glyphs"
+    position_key: "positions"
+    image_size: 64
+    channels: 4
+    cond_stage_trainable: true  # needs to be true when embedding_manager is valid
+    conditioning_key: crossattn
+    monitor: val/loss_simple_ema
+    scale_factor: 0.18215
+    use_ema: False
+    only_mid_control: False
+    loss_alpha: 0  # perceptual loss, 0.003
+    loss_beta: 0  # ctc loss
+    latin_weight: 1.0  # latin text lines may need a smaller weight
+    with_step_weight: true
+    use_vae_upsample: true
+    embedding_manager_config:
+      target: sorawm.iopaint.model.anytext.cldm.embedding_manager.EmbeddingManager
+      params:
+        valid: true  # v6
+        emb_type: ocr  # ocr, vit, conv
+        glyph_channels: 1
+        position_channels: 1
+        add_pos: false
+        placeholder_string: '*'
+
+    control_stage_config:
+      target: sorawm.iopaint.model.anytext.cldm.cldm.ControlNet
+      params:
+        image_size: 32 # unused
+        in_channels: 4
+        model_channels: 320
+        glyph_channels: 1
+        position_channels: 1
+        attention_resolutions: [ 4, 2, 1 ]
+        num_res_blocks: 2
+        channel_mult: [ 1, 2, 4, 4 ]
+        num_heads: 8
+        use_spatial_transformer: True
+        transformer_depth: 1
+        context_dim: 768
+        use_checkpoint: True
+        legacy: False
+
+    unet_config:
+      target: sorawm.iopaint.model.anytext.cldm.cldm.ControlledUnetModel
+      params:
+        image_size: 32 # unused
+        in_channels: 4
+        out_channels: 4
+        model_channels: 320
+        attention_resolutions: [ 4, 2, 1 ]
+        num_res_blocks: 2
+        channel_mult: [ 1, 2, 4, 4 ]
+        num_heads: 8
+        use_spatial_transformer: True
+        transformer_depth: 1
+        context_dim: 768
+        use_checkpoint: True
+        legacy: False
+
+    first_stage_config:
+      target: sorawm.iopaint.model.anytext.ldm.models.autoencoder.AutoencoderKL
+      params:
+        embed_dim: 4
+        monitor: val/rec_loss
+        ddconfig:
+          double_z: true
+          z_channels: 4
+          resolution: 256
+          in_channels: 3
+          out_ch: 3
+          ch: 128
+          ch_mult:
+          - 1
+          - 2
+          - 4
+          - 4
+          num_res_blocks: 2
+          attn_resolutions: []
+          dropout: 0.0
+        lossconfig:
+          target: torch.nn.Identity
+
+    cond_stage_config:
+      target: sorawm.iopaint.model.anytext.ldm.modules.encoders.modules.FrozenCLIPEmbedderT3
+      params:
+        version: openai/clip-vit-large-patch14
+        use_vision: false  # v6
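Each `target`/`params` node in this config names a class by dotted path plus its constructor kwargs. A simplified sketch of how such a node can be resolved, mirroring the usual `instantiate_from_config` pattern in the ldm code (this helper is an illustration, not the project's exact implementation):

```python
import importlib

def instantiate(node: dict):
    """Resolve a {"target": "pkg.mod.Class", "params": {...}} node into an instance."""
    module_path, cls_name = node["target"].rsplit(".", 1)
    cls = getattr(importlib.import_module(module_path), cls_name)
    return cls(**node.get("params", {}))
```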

+ 0 - 0
sorawm/iopaint/model/anytext/cldm/__init__.py


BIN
sorawm/iopaint/model/anytext/cldm/__pycache__/__init__.cpython-312.pyc


BIN
sorawm/iopaint/model/anytext/cldm/__pycache__/ddim_hacked.cpython-312.pyc


BIN
sorawm/iopaint/model/anytext/cldm/__pycache__/model.cpython-312.pyc


+ 780 - 0
sorawm/iopaint/model/anytext/cldm/cldm.py

@@ -0,0 +1,780 @@
+import copy
+import os
+from pathlib import Path
+
+import einops
+import torch
+import torch as th
+import torch.nn as nn
+from easydict import EasyDict as edict
+from einops import rearrange, repeat
+
+from sorawm.iopaint.model.anytext.ldm.models.diffusion.ddim import DDIMSampler
+from sorawm.iopaint.model.anytext.ldm.models.diffusion.ddpm import LatentDiffusion
+from sorawm.iopaint.model.anytext.ldm.modules.attention import SpatialTransformer
+from sorawm.iopaint.model.anytext.ldm.modules.diffusionmodules.openaimodel import (
+    AttentionBlock,
+    Downsample,
+    ResBlock,
+    TimestepEmbedSequential,
+    UNetModel,
+)
+from sorawm.iopaint.model.anytext.ldm.modules.diffusionmodules.util import (
+    conv_nd,
+    linear,
+    timestep_embedding,
+    zero_module,
+)
+from sorawm.iopaint.model.anytext.ldm.modules.distributions.distributions import (
+    DiagonalGaussianDistribution,
+)
+from sorawm.iopaint.model.anytext.ldm.util import (
+    exists,
+    instantiate_from_config,
+    log_txt_as_img,
+)
+
+from .recognizer import TextRecognizer, create_predictor
+
+CURRENT_DIR = Path(os.path.dirname(os.path.abspath(__file__)))
+
+
+def count_parameters(model):
+    return sum(p.numel() for p in model.parameters() if p.requires_grad)
+
+
+class ControlledUnetModel(UNetModel):
+    def forward(
+        self,
+        x,
+        timesteps=None,
+        context=None,
+        control=None,
+        only_mid_control=False,
+        **kwargs,
+    ):
+        hs = []
+        with torch.no_grad():
+            t_emb = timestep_embedding(
+                timesteps, self.model_channels, repeat_only=False
+            )
+            if self.use_fp16:
+                t_emb = t_emb.half()
+            emb = self.time_embed(t_emb)
+            h = x.type(self.dtype)
+            for module in self.input_blocks:
+                h = module(h, emb, context)
+                hs.append(h)
+            h = self.middle_block(h, emb, context)
+
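+        # ControlNet residuals are consumed from the end of the list: the last entry is
+        # added to the middle-block output, and the remaining ones are added to the skip
+        # activations before being concatenated into each decoder block below.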
+        if control is not None:
+            h += control.pop()
+
+        for i, module in enumerate(self.output_blocks):
+            if only_mid_control or control is None:
+                h = torch.cat([h, hs.pop()], dim=1)
+            else:
+                h = torch.cat([h, hs.pop() + control.pop()], dim=1)
+            h = module(h, emb, context)
+
+        h = h.type(x.dtype)
+        return self.out(h)
+
+
+class ControlNet(nn.Module):
+    def __init__(
+        self,
+        image_size,
+        in_channels,
+        model_channels,
+        glyph_channels,
+        position_channels,
+        num_res_blocks,
+        attention_resolutions,
+        dropout=0,
+        channel_mult=(1, 2, 4, 8),
+        conv_resample=True,
+        dims=2,
+        use_checkpoint=False,
+        use_fp16=False,
+        num_heads=-1,
+        num_head_channels=-1,
+        num_heads_upsample=-1,
+        use_scale_shift_norm=False,
+        resblock_updown=False,
+        use_new_attention_order=False,
+        use_spatial_transformer=False,  # custom transformer support
+        transformer_depth=1,  # custom transformer support
+        context_dim=None,  # custom transformer support
+        n_embed=None,  # custom support for prediction of discrete ids into codebook of first stage vq model
+        legacy=True,
+        disable_self_attentions=None,
+        num_attention_blocks=None,
+        disable_middle_self_attn=False,
+        use_linear_in_transformer=False,
+    ):
+        super().__init__()
+        if use_spatial_transformer:
+            assert (
+                context_dim is not None
+            ), "Fool!! You forgot to include the dimension of your cross-attention conditioning..."
+
+        if context_dim is not None:
+            assert (
+                use_spatial_transformer
+            ), "Fool!! You forgot to use the spatial transformer for your cross-attention conditioning..."
+            from omegaconf.listconfig import ListConfig
+
+            if type(context_dim) == ListConfig:
+                context_dim = list(context_dim)
+
+        if num_heads_upsample == -1:
+            num_heads_upsample = num_heads
+
+        if num_heads == -1:
+            assert (
+                num_head_channels != -1
+            ), "Either num_heads or num_head_channels has to be set"
+
+        if num_head_channels == -1:
+            assert (
+                num_heads != -1
+            ), "Either num_heads or num_head_channels has to be set"
+        self.dims = dims
+        self.image_size = image_size
+        self.in_channels = in_channels
+        self.model_channels = model_channels
+        if isinstance(num_res_blocks, int):
+            self.num_res_blocks = len(channel_mult) * [num_res_blocks]
+        else:
+            if len(num_res_blocks) != len(channel_mult):
+                raise ValueError(
+                    "provide num_res_blocks either as an int (globally constant) or "
+                    "as a list/tuple (per-level) with the same length as channel_mult"
+                )
+            self.num_res_blocks = num_res_blocks
+        if disable_self_attentions is not None:
+            # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
+            assert len(disable_self_attentions) == len(channel_mult)
+        if num_attention_blocks is not None:
+            assert len(num_attention_blocks) == len(self.num_res_blocks)
+            assert all(
+                map(
+                    lambda i: self.num_res_blocks[i] >= num_attention_blocks[i],
+                    range(len(num_attention_blocks)),
+                )
+            )
+            print(
+                f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
+                f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
+                f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
+                f"attention will still not be set."
+            )
+        self.attention_resolutions = attention_resolutions
+        self.dropout = dropout
+        self.channel_mult = channel_mult
+        self.conv_resample = conv_resample
+        self.use_checkpoint = use_checkpoint
+        self.use_fp16 = use_fp16
+        self.dtype = th.float16 if use_fp16 else th.float32
+        self.num_heads = num_heads
+        self.num_head_channels = num_head_channels
+        self.num_heads_upsample = num_heads_upsample
+        self.predict_codebook_ids = n_embed is not None
+
+        time_embed_dim = model_channels * 4
+        self.time_embed = nn.Sequential(
+            linear(model_channels, time_embed_dim),
+            nn.SiLU(),
+            linear(time_embed_dim, time_embed_dim),
+        )
+
+        self.input_blocks = nn.ModuleList(
+            [
+                TimestepEmbedSequential(
+                    conv_nd(dims, in_channels, model_channels, 3, padding=1)
+                )
+            ]
+        )
+        self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels)])
+
+        self.glyph_block = TimestepEmbedSequential(
+            conv_nd(dims, glyph_channels, 8, 3, padding=1),
+            nn.SiLU(),
+            conv_nd(dims, 8, 8, 3, padding=1),
+            nn.SiLU(),
+            conv_nd(dims, 8, 16, 3, padding=1, stride=2),
+            nn.SiLU(),
+            conv_nd(dims, 16, 16, 3, padding=1),
+            nn.SiLU(),
+            conv_nd(dims, 16, 32, 3, padding=1, stride=2),
+            nn.SiLU(),
+            conv_nd(dims, 32, 32, 3, padding=1),
+            nn.SiLU(),
+            conv_nd(dims, 32, 96, 3, padding=1, stride=2),
+            nn.SiLU(),
+            conv_nd(dims, 96, 96, 3, padding=1),
+            nn.SiLU(),
+            conv_nd(dims, 96, 256, 3, padding=1, stride=2),
+            nn.SiLU(),
+        )
+
+        self.position_block = TimestepEmbedSequential(
+            conv_nd(dims, position_channels, 8, 3, padding=1),
+            nn.SiLU(),
+            conv_nd(dims, 8, 8, 3, padding=1),
+            nn.SiLU(),
+            conv_nd(dims, 8, 16, 3, padding=1, stride=2),
+            nn.SiLU(),
+            conv_nd(dims, 16, 16, 3, padding=1),
+            nn.SiLU(),
+            conv_nd(dims, 16, 32, 3, padding=1, stride=2),
+            nn.SiLU(),
+            conv_nd(dims, 32, 32, 3, padding=1),
+            nn.SiLU(),
+            conv_nd(dims, 32, 64, 3, padding=1, stride=2),
+            nn.SiLU(),
+        )
+
+        self.fuse_block = zero_module(
+            conv_nd(dims, 256 + 64 + 4, model_channels, 3, padding=1)
+        )
+
+        self._feature_size = model_channels
+        input_block_chans = [model_channels]
+        ch = model_channels
+        ds = 1
+        for level, mult in enumerate(channel_mult):
+            for nr in range(self.num_res_blocks[level]):
+                layers = [
+                    ResBlock(
+                        ch,
+                        time_embed_dim,
+                        dropout,
+                        out_channels=mult * model_channels,
+                        dims=dims,
+                        use_checkpoint=use_checkpoint,
+                        use_scale_shift_norm=use_scale_shift_norm,
+                    )
+                ]
+                ch = mult * model_channels
+                if ds in attention_resolutions:
+                    if num_head_channels == -1:
+                        dim_head = ch // num_heads
+                    else:
+                        num_heads = ch // num_head_channels
+                        dim_head = num_head_channels
+                    if legacy:
+                        # num_heads = 1
+                        dim_head = (
+                            ch // num_heads
+                            if use_spatial_transformer
+                            else num_head_channels
+                        )
+                    if exists(disable_self_attentions):
+                        disabled_sa = disable_self_attentions[level]
+                    else:
+                        disabled_sa = False
+
+                    if (
+                        not exists(num_attention_blocks)
+                        or nr < num_attention_blocks[level]
+                    ):
+                        layers.append(
+                            AttentionBlock(
+                                ch,
+                                use_checkpoint=use_checkpoint,
+                                num_heads=num_heads,
+                                num_head_channels=dim_head,
+                                use_new_attention_order=use_new_attention_order,
+                            )
+                            if not use_spatial_transformer
+                            else SpatialTransformer(
+                                ch,
+                                num_heads,
+                                dim_head,
+                                depth=transformer_depth,
+                                context_dim=context_dim,
+                                disable_self_attn=disabled_sa,
+                                use_linear=use_linear_in_transformer,
+                                use_checkpoint=use_checkpoint,
+                            )
+                        )
+                self.input_blocks.append(TimestepEmbedSequential(*layers))
+                self.zero_convs.append(self.make_zero_conv(ch))
+                self._feature_size += ch
+                input_block_chans.append(ch)
+            if level != len(channel_mult) - 1:
+                out_ch = ch
+                self.input_blocks.append(
+                    TimestepEmbedSequential(
+                        ResBlock(
+                            ch,
+                            time_embed_dim,
+                            dropout,
+                            out_channels=out_ch,
+                            dims=dims,
+                            use_checkpoint=use_checkpoint,
+                            use_scale_shift_norm=use_scale_shift_norm,
+                            down=True,
+                        )
+                        if resblock_updown
+                        else Downsample(
+                            ch, conv_resample, dims=dims, out_channels=out_ch
+                        )
+                    )
+                )
+                ch = out_ch
+                input_block_chans.append(ch)
+                self.zero_convs.append(self.make_zero_conv(ch))
+                ds *= 2
+                self._feature_size += ch
+
+        if num_head_channels == -1:
+            dim_head = ch // num_heads
+        else:
+            num_heads = ch // num_head_channels
+            dim_head = num_head_channels
+        if legacy:
+            # num_heads = 1
+            dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
+        self.middle_block = TimestepEmbedSequential(
+            ResBlock(
+                ch,
+                time_embed_dim,
+                dropout,
+                dims=dims,
+                use_checkpoint=use_checkpoint,
+                use_scale_shift_norm=use_scale_shift_norm,
+            ),
+            AttentionBlock(
+                ch,
+                use_checkpoint=use_checkpoint,
+                num_heads=num_heads,
+                num_head_channels=dim_head,
+                use_new_attention_order=use_new_attention_order,
+            )
+            if not use_spatial_transformer
+            else SpatialTransformer(  # always uses a self-attn
+                ch,
+                num_heads,
+                dim_head,
+                depth=transformer_depth,
+                context_dim=context_dim,
+                disable_self_attn=disable_middle_self_attn,
+                use_linear=use_linear_in_transformer,
+                use_checkpoint=use_checkpoint,
+            ),
+            ResBlock(
+                ch,
+                time_embed_dim,
+                dropout,
+                dims=dims,
+                use_checkpoint=use_checkpoint,
+                use_scale_shift_norm=use_scale_shift_norm,
+            ),
+        )
+        self.middle_block_out = self.make_zero_conv(ch)
+        self._feature_size += ch
+
+    def make_zero_conv(self, channels):
+        return TimestepEmbedSequential(
+            zero_module(conv_nd(self.dims, channels, channels, 1, padding=0))
+        )
+
+    def forward(self, x, hint, text_info, timesteps, context, **kwargs):
+        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
+        if self.use_fp16:
+            t_emb = t_emb.half()
+        emb = self.time_embed(t_emb)
+
+        # guided_hint from text_info
+        B, C, H, W = x.shape
+        glyphs = torch.cat(text_info["glyphs"], dim=1).sum(dim=1, keepdim=True)
+        positions = torch.cat(text_info["positions"], dim=1).sum(dim=1, keepdim=True)
+        enc_glyph = self.glyph_block(glyphs, emb, context)
+        enc_pos = self.position_block(positions, emb, context)
+        guided_hint = self.fuse_block(
+            torch.cat([enc_glyph, enc_pos, text_info["masked_x"]], dim=1)
+        )
+
+        outs = []
+
+        h = x.type(self.dtype)
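+        # guided_hint fuses the glyph, position and masked-latent features; it is added to h
+        # once, right after the first input block, and then cleared inside the loop.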
+        for module, zero_conv in zip(self.input_blocks, self.zero_convs):
+            if guided_hint is not None:
+                h = module(h, emb, context)
+                h += guided_hint
+                guided_hint = None
+            else:
+                h = module(h, emb, context)
+            outs.append(zero_conv(h, emb, context))
+
+        h = self.middle_block(h, emb, context)
+        outs.append(self.middle_block_out(h, emb, context))
+
+        return outs
+
+
+class ControlLDM(LatentDiffusion):
+    def __init__(
+        self,
+        control_stage_config,
+        control_key,
+        glyph_key,
+        position_key,
+        only_mid_control,
+        loss_alpha=0,
+        loss_beta=0,
+        with_step_weight=False,
+        use_vae_upsample=False,
+        latin_weight=1.0,
+        embedding_manager_config=None,
+        *args,
+        **kwargs,
+    ):
+        self.use_fp16 = kwargs.pop("use_fp16", False)
+        super().__init__(*args, **kwargs)
+        self.control_model = instantiate_from_config(control_stage_config)
+        self.control_key = control_key
+        self.glyph_key = glyph_key
+        self.position_key = position_key
+        self.only_mid_control = only_mid_control
+        self.control_scales = [1.0] * 13
+        self.loss_alpha = loss_alpha
+        self.loss_beta = loss_beta
+        self.with_step_weight = with_step_weight
+        self.use_vae_upsample = use_vae_upsample
+        self.latin_weight = latin_weight
+
+        if (
+            embedding_manager_config is not None
+            and embedding_manager_config.params.valid
+        ):
+            self.embedding_manager = self.instantiate_embedding_manager(
+                embedding_manager_config, self.cond_stage_model
+            )
+            for param in self.embedding_manager.embedding_parameters():
+                param.requires_grad = True
+        else:
+            self.embedding_manager = None
+        if self.loss_alpha > 0 or self.loss_beta > 0 or self.embedding_manager:
+            if embedding_manager_config.params.emb_type == "ocr":
+                self.text_predictor = create_predictor().eval()
+                args = edict()
+                args.rec_image_shape = "3, 48, 320"
+                args.rec_batch_num = 6
+                args.rec_char_dict_path = str(
+                    CURRENT_DIR.parent / "ocr_recog" / "ppocr_keys_v1.txt"
+                )
+                args.use_fp16 = self.use_fp16
+                self.cn_recognizer = TextRecognizer(args, self.text_predictor)
+                for param in self.text_predictor.parameters():
+                    param.requires_grad = False
+                if self.embedding_manager:
+                    self.embedding_manager.recog = self.cn_recognizer
+
+    @torch.no_grad()
+    def get_input(self, batch, k, bs=None, *args, **kwargs):
+        if self.embedding_manager is None:  # fill in full caption
+            self.fill_caption(batch)
+        x, c, mx = super().get_input(
+            batch, self.first_stage_key, mask_k="masked_img", *args, **kwargs
+        )
+        control = batch[
+            self.control_key
+        ]  # for log_images and loss_alpha, not real control
+        if bs is not None:
+            control = control[:bs]
+        control = control.to(self.device)
+        control = einops.rearrange(control, "b h w c -> b c h w")
+        control = control.to(memory_format=torch.contiguous_format).float()
+
+        inv_mask = batch["inv_mask"]
+        if bs is not None:
+            inv_mask = inv_mask[:bs]
+        inv_mask = inv_mask.to(self.device)
+        inv_mask = einops.rearrange(inv_mask, "b h w c -> b c h w")
+        inv_mask = inv_mask.to(memory_format=torch.contiguous_format).float()
+
+        glyphs = batch[self.glyph_key]
+        gly_line = batch["gly_line"]
+        positions = batch[self.position_key]
+        n_lines = batch["n_lines"]
+        language = batch["language"]
+        texts = batch["texts"]
+        assert len(glyphs) == len(positions)
+        for i in range(len(glyphs)):
+            if bs is not None:
+                glyphs[i] = glyphs[i][:bs]
+                gly_line[i] = gly_line[i][:bs]
+                positions[i] = positions[i][:bs]
+                n_lines = n_lines[:bs]
+            glyphs[i] = glyphs[i].to(self.device)
+            gly_line[i] = gly_line[i].to(self.device)
+            positions[i] = positions[i].to(self.device)
+            glyphs[i] = einops.rearrange(glyphs[i], "b h w c -> b c h w")
+            gly_line[i] = einops.rearrange(gly_line[i], "b h w c -> b c h w")
+            positions[i] = einops.rearrange(positions[i], "b h w c -> b c h w")
+            glyphs[i] = glyphs[i].to(memory_format=torch.contiguous_format).float()
+            gly_line[i] = gly_line[i].to(memory_format=torch.contiguous_format).float()
+            positions[i] = (
+                positions[i].to(memory_format=torch.contiguous_format).float()
+            )
+        info = {}
+        info["glyphs"] = glyphs
+        info["positions"] = positions
+        info["n_lines"] = n_lines
+        info["language"] = language
+        info["texts"] = texts
+        info["img"] = batch["img"]  # nhwc, (-1,1)
+        info["masked_x"] = mx
+        info["gly_line"] = gly_line
+        info["inv_mask"] = inv_mask
+        return x, dict(c_crossattn=[c], c_concat=[control], text_info=info)
+
+    def apply_model(self, x_noisy, t, cond, *args, **kwargs):
+        assert isinstance(cond, dict)
+        diffusion_model = self.model.diffusion_model
+        _cond = torch.cat(cond["c_crossattn"], 1)
+        _hint = torch.cat(cond["c_concat"], 1)
+        if self.use_fp16:
+            x_noisy = x_noisy.half()
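+        # Run the trainable ControlNet branch; it returns one residual per UNet block,
+        # each scaled by control_scales before being injected into the frozen diffusion UNet.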
+        control = self.control_model(
+            x=x_noisy,
+            timesteps=t,
+            context=_cond,
+            hint=_hint,
+            text_info=cond["text_info"],
+        )
+        control = [c * scale for c, scale in zip(control, self.control_scales)]
+        eps = diffusion_model(
+            x=x_noisy,
+            timesteps=t,
+            context=_cond,
+            control=control,
+            only_mid_control=self.only_mid_control,
+        )
+
+        return eps
+
+    def instantiate_embedding_manager(self, config, embedder):
+        model = instantiate_from_config(config, embedder=embedder)
+        return model
+
+    @torch.no_grad()
+    def get_unconditional_conditioning(self, N):
+        return self.get_learned_conditioning(
+            dict(c_crossattn=[[""] * N], text_info=None)
+        )
+
+    def get_learned_conditioning(self, c):
+        if self.cond_stage_forward is None:
+            if hasattr(self.cond_stage_model, "encode") and callable(
+                self.cond_stage_model.encode
+            ):
+                if self.embedding_manager is not None and c["text_info"] is not None:
+                    self.embedding_manager.encode_text(c["text_info"])
+                if isinstance(c, dict):
+                    cond_txt = c["c_crossattn"][0]
+                else:
+                    cond_txt = c
+                if self.embedding_manager is not None:
+                    cond_txt = self.cond_stage_model.encode(
+                        cond_txt, embedding_manager=self.embedding_manager
+                    )
+                else:
+                    cond_txt = self.cond_stage_model.encode(cond_txt)
+                if isinstance(c, dict):
+                    c["c_crossattn"][0] = cond_txt
+                else:
+                    c = cond_txt
+                if isinstance(c, DiagonalGaussianDistribution):
+                    c = c.mode()
+            else:
+                c = self.cond_stage_model(c)
+        else:
+            assert hasattr(self.cond_stage_model, self.cond_stage_forward)
+            c = getattr(self.cond_stage_model, self.cond_stage_forward)(c)
+        return c
+
+    def fill_caption(self, batch, place_holder="*"):
+        bs = len(batch["n_lines"])
+        cond_list = copy.deepcopy(batch[self.cond_stage_key])
+        for i in range(bs):
+            n_lines = batch["n_lines"][i]
+            if n_lines == 0:
+                continue
+            cur_cap = cond_list[i]
+            for j in range(n_lines):
+                r_txt = batch["texts"][j][i]
+                cur_cap = cur_cap.replace(place_holder, f'"{r_txt}"', 1)
+            cond_list[i] = cur_cap
+        batch[self.cond_stage_key] = cond_list
+
+    @torch.no_grad()
+    def log_images(
+        self,
+        batch,
+        N=4,
+        n_row=2,
+        sample=False,
+        ddim_steps=50,
+        ddim_eta=0.0,
+        return_keys=None,
+        quantize_denoised=True,
+        inpaint=True,
+        plot_denoise_rows=False,
+        plot_progressive_rows=True,
+        plot_diffusion_rows=False,
+        unconditional_guidance_scale=9.0,
+        unconditional_guidance_label=None,
+        use_ema_scope=True,
+        **kwargs,
+    ):
+        use_ddim = ddim_steps is not None
+
+        log = dict()
+        z, c = self.get_input(batch, self.first_stage_key, bs=N)
+        if self.cond_stage_trainable:
+            with torch.no_grad():
+                c = self.get_learned_conditioning(c)
+        c_crossattn = c["c_crossattn"][0][:N]
+        c_cat = c["c_concat"][0][:N]
+        text_info = c["text_info"]
+        text_info["glyphs"] = [i[:N] for i in text_info["glyphs"]]
+        text_info["gly_line"] = [i[:N] for i in text_info["gly_line"]]
+        text_info["positions"] = [i[:N] for i in text_info["positions"]]
+        text_info["n_lines"] = text_info["n_lines"][:N]
+        text_info["masked_x"] = text_info["masked_x"][:N]
+        text_info["img"] = text_info["img"][:N]
+
+        N = min(z.shape[0], N)
+        n_row = min(z.shape[0], n_row)
+        log["reconstruction"] = self.decode_first_stage(z)
+        log["masked_image"] = self.decode_first_stage(text_info["masked_x"])
+        log["control"] = c_cat * 2.0 - 1.0
+        log["img"] = text_info["img"].permute(0, 3, 1, 2)  # log source image if needed
+        # get glyph
+        glyph_bs = torch.stack(text_info["glyphs"])
+        glyph_bs = torch.sum(glyph_bs, dim=0) * 2.0 - 1.0
+        log["glyph"] = torch.nn.functional.interpolate(
+            glyph_bs,
+            size=(512, 512),
+            mode="bilinear",
+            align_corners=True,
+        )
+        # fill caption
+        if not self.embedding_manager:
+            self.fill_caption(batch)
+        captions = batch[self.cond_stage_key]
+        log["conditioning"] = log_txt_as_img((512, 512), captions, size=16)
+
+        if plot_diffusion_rows:
+            # get diffusion row
+            diffusion_row = list()
+            z_start = z[:n_row]
+            for t in range(self.num_timesteps):
+                if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
+                    t = repeat(torch.tensor([t]), "1 -> b", b=n_row)
+                    t = t.to(self.device).long()
+                    noise = torch.randn_like(z_start)
+                    z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
+                    diffusion_row.append(self.decode_first_stage(z_noisy))
+
+            diffusion_row = torch.stack(diffusion_row)  # n_log_step, n_row, C, H, W
+            diffusion_grid = rearrange(diffusion_row, "n b c h w -> b n c h w")
+            diffusion_grid = rearrange(diffusion_grid, "b n c h w -> (b n) c h w")
+            diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
+            log["diffusion_row"] = diffusion_grid
+
+        if sample:
+            # get denoise row
+            samples, z_denoise_row = self.sample_log(
+                cond={"c_concat": [c_cat], "c_crossattn": [c_crossattn], "text_info": text_info},
+                batch_size=N,
+                ddim=use_ddim,
+                ddim_steps=ddim_steps,
+                eta=ddim_eta,
+            )
+            x_samples = self.decode_first_stage(samples)
+            log["samples"] = x_samples
+            if plot_denoise_rows:
+                denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
+                log["denoise_row"] = denoise_grid
+
+        if unconditional_guidance_scale > 1.0:
+            uc_cross = self.get_unconditional_conditioning(N)
+            uc_cat = c_cat  # torch.zeros_like(c_cat)
+            uc_full = {
+                "c_concat": [uc_cat],
+                "c_crossattn": [uc_cross["c_crossattn"][0]],
+                "text_info": text_info,
+            }
+            samples_cfg, tmps = self.sample_log(
+                cond={
+                    "c_concat": [c_cat],
+                    "c_crossattn": [c_crossattn],
+                    "text_info": text_info,
+                },
+                batch_size=N,
+                ddim=use_ddim,
+                ddim_steps=ddim_steps,
+                eta=ddim_eta,
+                unconditional_guidance_scale=unconditional_guidance_scale,
+                unconditional_conditioning=uc_full,
+            )
+            x_samples_cfg = self.decode_first_stage(samples_cfg)
+            log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg
+            pred_x0 = False  # whether to log pred_x0
+            if pred_x0:
+                for idx in range(len(tmps["pred_x0"])):
+                    pred_x0 = self.decode_first_stage(tmps["pred_x0"][idx])
+                    log[f"pred_x0_{tmps['index'][idx]}"] = pred_x0
+
+        return log
+
+    @torch.no_grad()
+    def sample_log(self, cond, batch_size, ddim, ddim_steps, **kwargs):
+        ddim_sampler = DDIMSampler(self)
+        b, c, h, w = cond["c_concat"][0].shape
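+        # The concat hint is in pixel space; the sampled latent is 8x smaller per spatial dim (VAE factor).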
+        shape = (self.channels, h // 8, w // 8)
+        samples, intermediates = ddim_sampler.sample(
+            ddim_steps, batch_size, shape, cond, verbose=False, log_every_t=5, **kwargs
+        )
+        return samples, intermediates
+
+    def configure_optimizers(self):
+        lr = self.learning_rate
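+        # Only the ControlNet branch is trained by default; the embedding manager, the UNet's
+        # output blocks, and the cross-attention K/V projections are added only when unlocked.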
+        params = list(self.control_model.parameters())
+        if self.embedding_manager:
+            params += list(self.embedding_manager.embedding_parameters())
+        if not self.sd_locked:
+            # params += list(self.model.diffusion_model.input_blocks.parameters())
+            # params += list(self.model.diffusion_model.middle_block.parameters())
+            params += list(self.model.diffusion_model.output_blocks.parameters())
+            params += list(self.model.diffusion_model.out.parameters())
+        if self.unlockKV:
+            nCount = 0
+            for name, param in self.model.diffusion_model.named_parameters():
+                if "attn2.to_k" in name or "attn2.to_v" in name:
+                    params += [param]
+                    nCount += 1
+            print(
+                f"Cross attention is unlocked; {nCount} Wk/Wv weights were added to the optimizer."
+            )
+
+        opt = torch.optim.AdamW(params, lr=lr)
+        return opt
+
+    def low_vram_shift(self, is_diffusing):
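+        # Keep only the models needed for the current phase on the GPU: the diffusion UNet and
+        # ControlNet while diffusing, the VAE and text encoder otherwise.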
+        if is_diffusing:
+            self.model = self.model.cuda()
+            self.control_model = self.control_model.cuda()
+            self.first_stage_model = self.first_stage_model.cpu()
+            self.cond_stage_model = self.cond_stage_model.cpu()
+        else:
+            self.model = self.model.cpu()
+            self.control_model = self.control_model.cpu()
+            self.first_stage_model = self.first_stage_model.cuda()
+            self.cond_stage_model = self.cond_stage_model.cuda()

+ 486 - 0
sorawm/iopaint/model/anytext/cldm/ddim_hacked.py

@@ -0,0 +1,486 @@
+"""SAMPLING ONLY."""
+
+import numpy as np
+import torch
+from tqdm import tqdm
+
+from sorawm.iopaint.model.anytext.ldm.modules.diffusionmodules.util import (
+    extract_into_tensor,
+    make_ddim_sampling_parameters,
+    make_ddim_timesteps,
+    noise_like,
+)
+
+
+class DDIMSampler(object):
+    def __init__(self, model, device, schedule="linear", **kwargs):
+        super().__init__()
+        self.device = device
+        self.model = model
+        self.ddpm_num_timesteps = model.num_timesteps
+        self.schedule = schedule
+
+    def register_buffer(self, name, attr):
+        if type(attr) == torch.Tensor:
+            if attr.device != torch.device(self.device):
+                attr = attr.to(torch.device(self.device))
+        setattr(self, name, attr)
+
+    def make_schedule(
+        self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.0, verbose=True
+    ):
+        self.ddim_timesteps = make_ddim_timesteps(
+            ddim_discr_method=ddim_discretize,
+            num_ddim_timesteps=ddim_num_steps,
+            num_ddpm_timesteps=self.ddpm_num_timesteps,
+            verbose=verbose,
+        )
+        alphas_cumprod = self.model.alphas_cumprod
+        assert (
+            alphas_cumprod.shape[0] == self.ddpm_num_timesteps
+        ), "alphas have to be defined for each timestep"
+        to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.device)
+
+        self.register_buffer("betas", to_torch(self.model.betas))
+        self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
+        self.register_buffer(
+            "alphas_cumprod_prev", to_torch(self.model.alphas_cumprod_prev)
+        )
+
+        # calculations for diffusion q(x_t | x_{t-1}) and others
+        self.register_buffer(
+            "sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod.cpu()))
+        )
+        self.register_buffer(
+            "sqrt_one_minus_alphas_cumprod",
+            to_torch(np.sqrt(1.0 - alphas_cumprod.cpu())),
+        )
+        self.register_buffer(
+            "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod.cpu()))
+        )
+        self.register_buffer(
+            "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod.cpu()))
+        )
+        self.register_buffer(
+            "sqrt_recipm1_alphas_cumprod",
+            to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1)),
+        )
+
+        # ddim sampling parameters
+        ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(
+            alphacums=alphas_cumprod.cpu(),
+            ddim_timesteps=self.ddim_timesteps,
+            eta=ddim_eta,
+            verbose=verbose,
+        )
+        self.register_buffer("ddim_sigmas", ddim_sigmas)
+        self.register_buffer("ddim_alphas", ddim_alphas)
+        self.register_buffer("ddim_alphas_prev", ddim_alphas_prev)
+        self.register_buffer("ddim_sqrt_one_minus_alphas", np.sqrt(1.0 - ddim_alphas))
+        sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
+            (1 - self.alphas_cumprod_prev)
+            / (1 - self.alphas_cumprod)
+            * (1 - self.alphas_cumprod / self.alphas_cumprod_prev)
+        )
+        self.register_buffer(
+            "ddim_sigmas_for_original_num_steps", sigmas_for_original_sampling_steps
+        )
+
+    @torch.no_grad()
+    def sample(
+        self,
+        S,
+        batch_size,
+        shape,
+        conditioning=None,
+        callback=None,
+        normals_sequence=None,
+        img_callback=None,
+        quantize_x0=False,
+        eta=0.0,
+        mask=None,
+        x0=None,
+        temperature=1.0,
+        noise_dropout=0.0,
+        score_corrector=None,
+        corrector_kwargs=None,
+        verbose=True,
+        x_T=None,
+        log_every_t=100,
+        unconditional_guidance_scale=1.0,
+        unconditional_conditioning=None,  # this has to come in the same format as the conditioning, e.g. as encoded tokens, ...
+        dynamic_threshold=None,
+        ucg_schedule=None,
+        **kwargs,
+    ):
+        if conditioning is not None:
+            if isinstance(conditioning, dict):
+                ctmp = conditioning[list(conditioning.keys())[0]]
+                while isinstance(ctmp, list):
+                    ctmp = ctmp[0]
+                cbs = ctmp.shape[0]
+                if cbs != batch_size:
+                    print(
+                        f"Warning: Got {cbs} conditionings but batch-size is {batch_size}"
+                    )
+
+            elif isinstance(conditioning, list):
+                for ctmp in conditioning:
+                    if ctmp.shape[0] != batch_size:
+                        print(
+                            f"Warning: Got {ctmp.shape[0]} conditionings but batch-size is {batch_size}"
+                        )
+
+            else:
+                if conditioning.shape[0] != batch_size:
+                    print(
+                        f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}"
+                    )
+
+        self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
+        # sampling
+        C, H, W = shape
+        size = (batch_size, C, H, W)
+        print(f"Data shape for DDIM sampling is {size}, eta {eta}")
+
+        samples, intermediates = self.ddim_sampling(
+            conditioning,
+            size,
+            callback=callback,
+            img_callback=img_callback,
+            quantize_denoised=quantize_x0,
+            mask=mask,
+            x0=x0,
+            ddim_use_original_steps=False,
+            noise_dropout=noise_dropout,
+            temperature=temperature,
+            score_corrector=score_corrector,
+            corrector_kwargs=corrector_kwargs,
+            x_T=x_T,
+            log_every_t=log_every_t,
+            unconditional_guidance_scale=unconditional_guidance_scale,
+            unconditional_conditioning=unconditional_conditioning,
+            dynamic_threshold=dynamic_threshold,
+            ucg_schedule=ucg_schedule,
+        )
+        return samples, intermediates
+
+    @torch.no_grad()
+    def ddim_sampling(
+        self,
+        cond,
+        shape,
+        x_T=None,
+        ddim_use_original_steps=False,
+        callback=None,
+        timesteps=None,
+        quantize_denoised=False,
+        mask=None,
+        x0=None,
+        img_callback=None,
+        log_every_t=100,
+        temperature=1.0,
+        noise_dropout=0.0,
+        score_corrector=None,
+        corrector_kwargs=None,
+        unconditional_guidance_scale=1.0,
+        unconditional_conditioning=None,
+        dynamic_threshold=None,
+        ucg_schedule=None,
+    ):
+        device = self.model.betas.device
+        b = shape[0]
+        if x_T is None:
+            img = torch.randn(shape, device=device)
+        else:
+            img = x_T
+
+        if timesteps is None:
+            timesteps = (
+                self.ddpm_num_timesteps
+                if ddim_use_original_steps
+                else self.ddim_timesteps
+            )
+        elif timesteps is not None and not ddim_use_original_steps:
+            subset_end = (
+                int(
+                    min(timesteps / self.ddim_timesteps.shape[0], 1)
+                    * self.ddim_timesteps.shape[0]
+                )
+                - 1
+            )
+            timesteps = self.ddim_timesteps[:subset_end]
+
+        intermediates = {"x_inter": [img], "pred_x0": [img]}
+        time_range = (
+            reversed(range(0, timesteps))
+            if ddim_use_original_steps
+            else np.flip(timesteps)
+        )
+        total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
+        print(f"Running DDIM Sampling with {total_steps} timesteps")
+
+        iterator = tqdm(time_range, desc="DDIM Sampler", total=total_steps)
+
+        for i, step in enumerate(iterator):
+            index = total_steps - i - 1
+            ts = torch.full((b,), step, device=device, dtype=torch.long)
+
+            if mask is not None:
+                assert x0 is not None
+                img_orig = self.model.q_sample(
+                    x0, ts
+                )  # TODO: deterministic forward pass?
+                img = img_orig * mask + (1.0 - mask) * img
+
+            if ucg_schedule is not None:
+                assert len(ucg_schedule) == len(time_range)
+                unconditional_guidance_scale = ucg_schedule[i]
+
+            outs = self.p_sample_ddim(
+                img,
+                cond,
+                ts,
+                index=index,
+                use_original_steps=ddim_use_original_steps,
+                quantize_denoised=quantize_denoised,
+                temperature=temperature,
+                noise_dropout=noise_dropout,
+                score_corrector=score_corrector,
+                corrector_kwargs=corrector_kwargs,
+                unconditional_guidance_scale=unconditional_guidance_scale,
+                unconditional_conditioning=unconditional_conditioning,
+                dynamic_threshold=dynamic_threshold,
+            )
+            img, pred_x0 = outs
+            if callback:
+                callback(None, i, None, None)
+            if img_callback:
+                img_callback(pred_x0, i)
+
+            if index % log_every_t == 0 or index == total_steps - 1:
+                intermediates["x_inter"].append(img)
+                intermediates["pred_x0"].append(pred_x0)
+
+        return img, intermediates
+
+    @torch.no_grad()
+    def p_sample_ddim(
+        self,
+        x,
+        c,
+        t,
+        index,
+        repeat_noise=False,
+        use_original_steps=False,
+        quantize_denoised=False,
+        temperature=1.0,
+        noise_dropout=0.0,
+        score_corrector=None,
+        corrector_kwargs=None,
+        unconditional_guidance_scale=1.0,
+        unconditional_conditioning=None,
+        dynamic_threshold=None,
+    ):
+        b, *_, device = *x.shape, x.device
+
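+        # Classifier-free guidance: with a guidance scale > 1, evaluate the model on both the
+        # conditional and unconditional inputs and extrapolate between the two predictions.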
+        if unconditional_conditioning is None or unconditional_guidance_scale == 1.0:
+            model_output = self.model.apply_model(x, t, c)
+        else:
+            model_t = self.model.apply_model(x, t, c)
+            model_uncond = self.model.apply_model(x, t, unconditional_conditioning)
+            model_output = model_uncond + unconditional_guidance_scale * (
+                model_t - model_uncond
+            )
+
+        if self.model.parameterization == "v":
+            e_t = self.model.predict_eps_from_z_and_v(x, t, model_output)
+        else:
+            e_t = model_output
+
+        if score_corrector is not None:
+            assert self.model.parameterization == "eps", "not implemented"
+            e_t = score_corrector.modify_score(
+                self.model, e_t, x, t, c, **corrector_kwargs
+            )
+
+        alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
+        alphas_prev = (
+            self.model.alphas_cumprod_prev
+            if use_original_steps
+            else self.ddim_alphas_prev
+        )
+        sqrt_one_minus_alphas = (
+            self.model.sqrt_one_minus_alphas_cumprod
+            if use_original_steps
+            else self.ddim_sqrt_one_minus_alphas
+        )
+        sigmas = (
+            self.model.ddim_sigmas_for_original_num_steps
+            if use_original_steps
+            else self.ddim_sigmas
+        )
+        # select parameters corresponding to the currently considered timestep
+        a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
+        a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
+        sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
+        sqrt_one_minus_at = torch.full(
+            (b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device
+        )
+
+        # current prediction for x_0
+        if self.model.parameterization != "v":
+            pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
+        else:
+            pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output)
+
+        if quantize_denoised:
+            pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
+
+        if dynamic_threshold is not None:
+            raise NotImplementedError()
+
+        # direction pointing to x_t
+        dir_xt = (1.0 - a_prev - sigma_t**2).sqrt() * e_t
+        noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
+        if noise_dropout > 0.0:
+            noise = torch.nn.functional.dropout(noise, p=noise_dropout)
+        x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
+        return x_prev, pred_x0
+
+    @torch.no_grad()
+    def encode(
+        self,
+        x0,
+        c,
+        t_enc,
+        use_original_steps=False,
+        return_intermediates=None,
+        unconditional_guidance_scale=1.0,
+        unconditional_conditioning=None,
+        callback=None,
+    ):
+        timesteps = (
+            np.arange(self.ddpm_num_timesteps)
+            if use_original_steps
+            else self.ddim_timesteps
+        )
+        num_reference_steps = timesteps.shape[0]
+
+        assert t_enc <= num_reference_steps
+        num_steps = t_enc
+
+        if use_original_steps:
+            alphas_next = self.alphas_cumprod[:num_steps]
+            alphas = self.alphas_cumprod_prev[:num_steps]
+        else:
+            alphas_next = self.ddim_alphas[:num_steps]
+            alphas = torch.tensor(self.ddim_alphas_prev[:num_steps])
+
+        x_next = x0
+        intermediates = []
+        inter_steps = []
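+        # Deterministic DDIM inversion: step x0 forward along the DDIM schedule,
+        # optionally applying classifier-free guidance to the predicted noise.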
+        for i in tqdm(range(num_steps), desc="Encoding Image"):
+            t = torch.full(
+                (x0.shape[0],), timesteps[i], device=self.model.device, dtype=torch.long
+            )
+            if unconditional_guidance_scale == 1.0:
+                noise_pred = self.model.apply_model(x_next, t, c)
+            else:
+                assert unconditional_conditioning is not None
+                e_t_uncond, noise_pred = torch.chunk(
+                    self.model.apply_model(
+                        torch.cat((x_next, x_next)),
+                        torch.cat((t, t)),
+                        torch.cat((unconditional_conditioning, c)),
+                    ),
+                    2,
+                )
+                noise_pred = e_t_uncond + unconditional_guidance_scale * (
+                    noise_pred - e_t_uncond
+                )
+
+            xt_weighted = (alphas_next[i] / alphas[i]).sqrt() * x_next
+            weighted_noise_pred = (
+                alphas_next[i].sqrt()
+                * ((1 / alphas_next[i] - 1).sqrt() - (1 / alphas[i] - 1).sqrt())
+                * noise_pred
+            )
+            x_next = xt_weighted + weighted_noise_pred
+            if (
+                return_intermediates
+                and i % (num_steps // return_intermediates) == 0
+                and i < num_steps - 1
+            ):
+                intermediates.append(x_next)
+                inter_steps.append(i)
+            elif return_intermediates and i >= num_steps - 2:
+                intermediates.append(x_next)
+                inter_steps.append(i)
+            if callback:
+                callback(i)
+
+        out = {"x_encoded": x_next, "intermediate_steps": inter_steps}
+        if return_intermediates:
+            out.update({"intermediates": intermediates})
+        return x_next, out
+
+    @torch.no_grad()
+    def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
+        # fast, but does not allow for exact reconstruction
+        # t serves as an index to gather the correct alphas
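+        # q(x_t | x_0): scale x0 by sqrt(alpha_cumprod_t) and add noise scaled by sqrt(1 - alpha_cumprod_t)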
+        if use_original_steps:
+            sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
+            sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
+        else:
+            sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
+            sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas
+
+        if noise is None:
+            noise = torch.randn_like(x0)
+        return (
+            extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0
+            + extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise
+        )
+
+    @torch.no_grad()
+    def decode(
+        self,
+        x_latent,
+        cond,
+        t_start,
+        unconditional_guidance_scale=1.0,
+        unconditional_conditioning=None,
+        use_original_steps=False,
+        callback=None,
+    ):
+        timesteps = (
+            np.arange(self.ddpm_num_timesteps)
+            if use_original_steps
+            else self.ddim_timesteps
+        )
+        timesteps = timesteps[:t_start]
+
+        time_range = np.flip(timesteps)
+        total_steps = timesteps.shape[0]
+        print(f"Running DDIM Sampling with {total_steps} timesteps")
+
+        iterator = tqdm(time_range, desc="Decoding image", total=total_steps)
+        x_dec = x_latent
+        for i, step in enumerate(iterator):
+            index = total_steps - i - 1
+            ts = torch.full(
+                (x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long
+            )
+            x_dec, _ = self.p_sample_ddim(
+                x_dec,
+                cond,
+                ts,
+                index=index,
+                use_original_steps=use_original_steps,
+                unconditional_guidance_scale=unconditional_guidance_scale,
+                unconditional_conditioning=unconditional_conditioning,
+            )
+            if callback:
+                callback(i)
+        return x_dec

+ 185 - 0
sorawm/iopaint/model/anytext/cldm/embedding_manager.py

@@ -0,0 +1,185 @@
+"""
+Copyright (c) Alibaba, Inc. and its affiliates.
+"""
+from functools import partial
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from sorawm.iopaint.model.anytext.ldm.modules.diffusionmodules.util import (
+    conv_nd,
+    linear,
+)
+
+
+def get_clip_token_for_string(tokenizer, string):
+    batch_encoding = tokenizer(
+        string,
+        truncation=True,
+        max_length=77,
+        return_length=True,
+        return_overflowing_tokens=False,
+        padding="max_length",
+        return_tensors="pt",
+    )
+    tokens = batch_encoding["input_ids"]
+    assert (
+        torch.count_nonzero(tokens - 49407) == 2
+    ), f"String '{string}' maps to more than a single token. Please use another string"
+    return tokens[0, 1]
+
+
+def get_bert_token_for_string(tokenizer, string):
+    token = tokenizer(string)
+    assert (
+        torch.count_nonzero(token) == 3
+    ), f"String '{string}' maps to more than a single token. Please use another string"
+    token = token[0, 1]
+    return token
+
+
+def get_clip_vision_emb(encoder, processor, img):
+    _img = img.repeat(1, 3, 1, 1) * 255
+    inputs = processor(images=_img, return_tensors="pt")
+    inputs["pixel_values"] = inputs["pixel_values"].to(img.device)
+    outputs = encoder(**inputs)
+    emb = outputs.image_embeds
+    return emb
+
+
+def get_recog_emb(encoder, img_list):
+    _img_list = [(img.repeat(1, 3, 1, 1) * 255)[0] for img in img_list]
+    encoder.predictor.eval()
+    _, preds_neck = encoder.pred_imglist(_img_list, show_debug=False)
+    return preds_neck
+
+
+def pad_H(x):
+    _, _, H, W = x.shape
+    p_top = (W - H) // 2
+    p_bot = W - H - p_top
+    return F.pad(x, (0, 0, p_top, p_bot))
+
+
+class EncodeNet(nn.Module):
+    def __init__(self, in_channels, out_channels):
+        super(EncodeNet, self).__init__()
+        chan = 16
+        n_layer = 4  # downsample
+
+        self.conv1 = conv_nd(2, in_channels, chan, 3, padding=1)
+        self.conv_list = nn.ModuleList([])
+        _c = chan
+        for i in range(n_layer):
+            self.conv_list.append(conv_nd(2, _c, _c * 2, 3, padding=1, stride=2))
+            _c *= 2
+        self.conv2 = conv_nd(2, _c, out_channels, 3, padding=1)
+        self.avgpool = nn.AdaptiveAvgPool2d(1)
+        self.act = nn.SiLU()
+
+    def forward(self, x):
+        x = self.act(self.conv1(x))
+        for layer in self.conv_list:
+            x = self.act(layer(x))
+        x = self.act(self.conv2(x))
+        x = self.avgpool(x)
+        x = x.view(x.size(0), -1)
+        return x
+
+
+class EmbeddingManager(nn.Module):
+    def __init__(
+        self,
+        embedder,
+        valid=True,
+        glyph_channels=20,
+        position_channels=1,
+        placeholder_string="*",
+        add_pos=False,
+        emb_type="ocr",
+        **kwargs,
+    ):
+        super().__init__()
+        if hasattr(embedder, "tokenizer"):  # using Stable Diffusion's CLIP encoder
+            get_token_for_string = partial(
+                get_clip_token_for_string, embedder.tokenizer
+            )
+            token_dim = 768
+            if hasattr(embedder, "vit"):
+                assert emb_type == "vit"
+                self.get_vision_emb = partial(
+                    get_clip_vision_emb, embedder.vit, embedder.processor
+                )
+            self.get_recog_emb = None
+        else:  # using LDM's BERT encoder
+            get_token_for_string = partial(get_bert_token_for_string, embedder.tknz_fn)
+            token_dim = 1280
+        self.token_dim = token_dim
+        self.emb_type = emb_type
+
+        self.add_pos = add_pos
+        if add_pos:
+            self.position_encoder = EncodeNet(position_channels, token_dim)
+        if emb_type == "ocr":
+            self.proj = linear(40 * 64, token_dim)
+        if emb_type == "conv":
+            self.glyph_encoder = EncodeNet(glyph_channels, token_dim)
+
+        self.placeholder_token = get_token_for_string(placeholder_string)
+
+    def encode_text(self, text_info):
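+        # Encode every rendered text line in the batch into one token-sized embedding
+        # (OCR features, CLIP-ViT, or a small conv encoder, depending on emb_type).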
+        if self.get_recog_emb is None and self.emb_type == "ocr":
+            self.get_recog_emb = partial(get_recog_emb, self.recog)
+
+        gline_list = []
+        pos_list = []
+        for i in range(len(text_info["n_lines"])):  # sample index in a batch
+            n_lines = text_info["n_lines"][i]
+            for j in range(n_lines):  # line
+                gline_list += [text_info["gly_line"][j][i : i + 1]]
+                if self.add_pos:
+                    pos_list += [text_info["positions"][j][i : i + 1]]
+
+        if len(gline_list) > 0:
+            if self.emb_type == "ocr":
+                recog_emb = self.get_recog_emb(gline_list)
+                enc_glyph = self.proj(recog_emb.reshape(recog_emb.shape[0], -1))
+            elif self.emb_type == "vit":
+                enc_glyph = self.get_vision_emb(pad_H(torch.cat(gline_list, dim=0)))
+            elif self.emb_type == "conv":
+                enc_glyph = self.glyph_encoder(pad_H(torch.cat(gline_list, dim=0)))
+            if self.add_pos:
+                enc_pos = self.position_encoder(torch.cat(gline_list, dim=0))
+                enc_glyph = enc_glyph + enc_pos
+
+        self.text_embs_all = []
+        n_idx = 0
+        for i in range(len(text_info["n_lines"])):  # sample index in a batch
+            n_lines = text_info["n_lines"][i]
+            text_embs = []
+            for j in range(n_lines):  # line
+                text_embs += [enc_glyph[n_idx : n_idx + 1]]
+                n_idx += 1
+            self.text_embs_all += [text_embs]
+
+    def forward(
+        self,
+        tokenized_text,
+        embedded_text,
+    ):
+        b, device = tokenized_text.shape[0], tokenized_text.device
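+        # Replace the embedding of every placeholder token ("*") with the per-line glyph
+        # embeddings computed in encode_text, truncating when a caption has fewer placeholders than lines.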
+        for i in range(b):
+            idx = tokenized_text[i] == self.placeholder_token.to(device)
+            if sum(idx) > 0:
+                if i >= len(self.text_embs_all):
+                    print("truncation for log images...")
+                    break
+                text_emb = torch.cat(self.text_embs_all[i], dim=0)
+                if sum(idx) != len(text_emb):
+                    print("truncation for long caption...")
+                embedded_text[i][idx] = text_emb[: sum(idx)]
+        return embedded_text
+
+    def embedding_parameters(self):
+        return self.parameters()

+ 128 - 0
sorawm/iopaint/model/anytext/cldm/hack.py

@@ -0,0 +1,128 @@
+import einops
+import torch
+from transformers import logging
+
+import sorawm.iopaint.model.anytext.ldm.modules.attention
+import sorawm.iopaint.model.anytext.ldm.modules.encoders.modules
+from sorawm.iopaint.model.anytext.ldm.modules.attention import default
+
+
+def disable_verbosity():
+    logging.set_verbosity_error()
+    print("logging improved.")
+    return
+
+
+def enable_sliced_attention():
+    sorawm.iopaint.model.anytext.ldm.modules.attention.CrossAttention.forward = (
+        _hacked_sliced_attentin_forward
+    )
+    print("Enabled sliced_attention.")
+    return
+
+
+def hack_everything(clip_skip=0):
+    disable_verbosity()
+    sorawm.iopaint.model.anytext.ldm.modules.encoders.modules.FrozenCLIPEmbedder.forward = (
+        _hacked_clip_forward
+    )
+    sorawm.iopaint.model.anytext.ldm.modules.encoders.modules.FrozenCLIPEmbedder.clip_skip = (
+        clip_skip
+    )
+    print("Enabled clip hacks.")
+    return
+
+
+# Written by Lvmin
+def _hacked_clip_forward(self, text):
+    PAD = self.tokenizer.pad_token_id
+    EOS = self.tokenizer.eos_token_id
+    BOS = self.tokenizer.bos_token_id
+
+    def tokenize(t):
+        return self.tokenizer(t, truncation=False, add_special_tokens=False)[
+            "input_ids"
+        ]
+
+    def transformer_encode(t):
+        if self.clip_skip > 1:
+            rt = self.transformer(input_ids=t, output_hidden_states=True)
+            return self.transformer.text_model.final_layer_norm(
+                rt.hidden_states[-self.clip_skip]
+            )
+        else:
+            return self.transformer(
+                input_ids=t, output_hidden_states=False
+            ).last_hidden_state
+
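+    # Long prompts are handled by splitting the token stream into three 75-token chunks,
+    # wrapping each with BOS/EOS, padding to length 77, and concatenating the encoded chunks.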
+    def split(x):
+        return x[75 * 0 : 75 * 1], x[75 * 1 : 75 * 2], x[75 * 2 : 75 * 3]
+
+    def pad(x, p, i):
+        return x[:i] if len(x) >= i else x + [p] * (i - len(x))
+
+    raw_tokens_list = tokenize(text)
+    tokens_list = []
+
+    for raw_tokens in raw_tokens_list:
+        raw_tokens_123 = split(raw_tokens)
+        raw_tokens_123 = [
+            [BOS] + raw_tokens_i + [EOS] for raw_tokens_i in raw_tokens_123
+        ]
+        raw_tokens_123 = [pad(raw_tokens_i, PAD, 77) for raw_tokens_i in raw_tokens_123]
+        tokens_list.append(raw_tokens_123)
+
+    tokens_list = torch.IntTensor(tokens_list).to(self.device)
+
+    feed = einops.rearrange(tokens_list, "b f i -> (b f) i")
+    y = transformer_encode(feed)
+    z = einops.rearrange(y, "(b f) i c -> b (f i) c", f=3)
+
+    return z
+
+
+# Stolen from https://github.com/basujindal/stable-diffusion/blob/main/optimizedSD/splitAttention.py
+def _hacked_sliced_attentin_forward(self, x, context=None, mask=None):
+    h = self.heads
+
+    q = self.to_q(x)
+    context = default(context, x)
+    k = self.to_k(context)
+    v = self.to_v(context)
+    del context, x
+
+    q, k, v = map(
+        lambda t: einops.rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v)
+    )
+
+    limit = k.shape[0]
+    att_step = 1
+    q_chunks = list(torch.tensor_split(q, limit // att_step, dim=0))
+    k_chunks = list(torch.tensor_split(k, limit // att_step, dim=0))
+    v_chunks = list(torch.tensor_split(v, limit // att_step, dim=0))
+
+    q_chunks.reverse()
+    k_chunks.reverse()
+    v_chunks.reverse()
+    sim = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)
+    del k, q, v
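+    # Compute attention one chunk of batched heads at a time to bound peak memory;
+    # each partial result is written into the preallocated sim buffer.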
+    for i in range(0, limit, att_step):
+        q_buffer = q_chunks.pop()
+        k_buffer = k_chunks.pop()
+        v_buffer = v_chunks.pop()
+        sim_buffer = (
+            torch.einsum("b i d, b j d -> b i j", q_buffer, k_buffer) * self.scale
+        )
+
+        del k_buffer, q_buffer
+        # attention, what we cannot get enough of, by chunks
+
+        sim_buffer = sim_buffer.softmax(dim=-1)
+
+        sim_buffer = torch.einsum("b i j, b j d -> b i d", sim_buffer, v_buffer)
+        del v_buffer
+        sim[i : i + att_step, :, :] = sim_buffer
+
+        del sim_buffer
+    sim = einops.rearrange(sim, "(b h) n d -> b n (h d)", h=h)
+    return self.to_out(sim)

+ 41 - 0
sorawm/iopaint/model/anytext/cldm/model.py

@@ -0,0 +1,41 @@
+import os
+
+import torch
+from omegaconf import OmegaConf
+
+from sorawm.iopaint.model.anytext.ldm.util import instantiate_from_config
+
+
+def get_state_dict(d):
+    return d.get("state_dict", d)
+
+
+def load_state_dict(ckpt_path, location="cpu"):
+    _, extension = os.path.splitext(ckpt_path)
+    if extension.lower() == ".safetensors":
+        import safetensors.torch
+
+        state_dict = safetensors.torch.load_file(ckpt_path, device=location)
+    else:
+        state_dict = get_state_dict(
+            torch.load(ckpt_path, map_location=torch.device(location))
+        )
+    state_dict = get_state_dict(state_dict)
+    print(f"Loaded state_dict from [{ckpt_path}]")
+    return state_dict
+
+
+def create_model(config_path, device, cond_stage_path=None, use_fp16=False):
+    config = OmegaConf.load(config_path)
+    # if cond_stage_path:
+    #     config.model.params.cond_stage_config.params.version = (
+    #         cond_stage_path  # use pre-downloaded ckpts, in case blocked
+    #     )
+    config.model.params.cond_stage_config.params.device = str(device)
+    if use_fp16:
+        config.model.params.use_fp16 = True
+        config.model.params.control_stage_config.params.use_fp16 = True
+        config.model.params.unet_config.params.use_fp16 = True
+    model = instantiate_from_config(config.model).cpu()
+    print(f"Loaded model config from [{config_path}]")
+    return model

+ 302 - 0
sorawm/iopaint/model/anytext/cldm/recognizer.py

@@ -0,0 +1,302 @@
+"""
+Copyright (c) Alibaba, Inc. and its affiliates.
+"""
+import math
+import os
+import time
+import traceback
+
+import cv2
+import numpy as np
+import torch
+import torch.nn.functional as F
+from easydict import EasyDict as edict
+
+from sorawm.iopaint.model.anytext.ocr_recog.RecModel import RecModel
+
+
+def min_bounding_rect(img):
+    ret, thresh = cv2.threshold(img, 127, 255, 0)
+    contours, hierarchy = cv2.findContours(
+        thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
+    )
+    if len(contours) == 0:
+        print("Bad contours, using fake bbox...")
+        return np.array([[0, 0], [100, 0], [100, 100], [0, 100]])
+    max_contour = max(contours, key=cv2.contourArea)
+    rect = cv2.minAreaRect(max_contour)
+    box = cv2.boxPoints(rect)
+    box = np.int0(box)
+    # sort
+    x_sorted = sorted(box, key=lambda x: x[0])
+    left = x_sorted[:2]
+    right = x_sorted[2:]
+    left = sorted(left, key=lambda x: x[1])
+    (tl, bl) = left
+    right = sorted(right, key=lambda x: x[1])
+    (tr, br) = right
+    if tl[1] > bl[1]:
+        (tl, bl) = (bl, tl)
+    if tr[1] > br[1]:
+        (tr, br) = (br, tr)
+    return np.array([tl, tr, br, bl])
+
+
+def create_predictor(model_dir=None, model_lang="ch", is_onnx=False):
+    model_file_path = model_dir
+    if model_file_path is not None and not os.path.exists(model_file_path):
+        raise ValueError("could not find model file path: {}".format(model_file_path))
+
+    if is_onnx:
+        import onnxruntime as ort
+
+        sess = ort.InferenceSession(
+            model_file_path, providers=["CPUExecutionProvider"]
+        )  # 'TensorrtExecutionProvider', 'CUDAExecutionProvider', 'CPUExecutionProvider'
+        return sess
+    else:
+        if model_lang == "ch":
+            n_class = 6625
+        elif model_lang == "en":
+            n_class = 97
+        else:
+            raise ValueError(f"Unsupported OCR recog model_lang: {model_lang}")
+        rec_config = edict(
+            in_channels=3,
+            backbone=edict(
+                type="MobileNetV1Enhance",
+                scale=0.5,
+                last_conv_stride=[1, 2],
+                last_pool_type="avg",
+            ),
+            neck=edict(
+                type="SequenceEncoder",
+                encoder_type="svtr",
+                dims=64,
+                depth=2,
+                hidden_dims=120,
+                use_guide=True,
+            ),
+            head=edict(
+                type="CTCHead",
+                fc_decay=0.00001,
+                out_channels=n_class,
+                return_feats=True,
+            ),
+        )
+
+        rec_model = RecModel(rec_config)
+        if model_file_path is not None:
+            rec_model.load_state_dict(torch.load(model_file_path, map_location="cpu"))
+            rec_model.eval()
+        return rec_model.eval()
+
+
+def _check_image_file(path):
+    img_end = {"jpg", "bmp", "png", "jpeg", "rgb", "tif", "tiff"}
+    return any([path.lower().endswith(e) for e in img_end])
+
+
+def get_image_file_list(img_file):
+    imgs_lists = []
+    if img_file is None or not os.path.exists(img_file):
+        raise Exception("no image files found in {}".format(img_file))
+    if os.path.isfile(img_file) and _check_image_file(img_file):
+        imgs_lists.append(img_file)
+    elif os.path.isdir(img_file):
+        for single_file in os.listdir(img_file):
+            file_path = os.path.join(img_file, single_file)
+            if os.path.isfile(file_path) and _check_image_file(file_path):
+                imgs_lists.append(file_path)
+    if len(imgs_lists) == 0:
+        raise Exception("no image files found in {}".format(img_file))
+    imgs_lists = sorted(imgs_lists)
+    return imgs_lists
+
+
+class TextRecognizer(object):
+    def __init__(self, args, predictor):
+        self.rec_image_shape = [int(v) for v in args.rec_image_shape.split(",")]
+        self.rec_batch_num = args.rec_batch_num
+        self.predictor = predictor
+        self.chars = self.get_char_dict(args.rec_char_dict_path)
+        self.char2id = {x: i for i, x in enumerate(self.chars)}
+        self.is_onnx = not isinstance(self.predictor, torch.nn.Module)
+        self.use_fp16 = args.use_fp16
+
+    # img: CHW
+    def resize_norm_img(self, img, max_wh_ratio):
+        imgC, imgH, imgW = self.rec_image_shape
+        assert imgC == img.shape[0]
+        imgW = int((imgH * max_wh_ratio))
+
+        h, w = img.shape[1:]
+        ratio = w / float(h)
+        if math.ceil(imgH * ratio) > imgW:
+            resized_w = imgW
+        else:
+            resized_w = int(math.ceil(imgH * ratio))
+        resized_image = torch.nn.functional.interpolate(
+            img.unsqueeze(0),
+            size=(imgH, resized_w),
+            mode="bilinear",
+            align_corners=True,
+        )
+        resized_image /= 255.0
+        resized_image -= 0.5
+        resized_image /= 0.5
+        padding_im = torch.zeros((imgC, imgH, imgW), dtype=torch.float32).to(img.device)
+        padding_im[:, :, 0:resized_w] = resized_image[0]
+        return padding_im
+
+    # img_list: list of tensors with shape chw 0-255
+    def pred_imglist(self, img_list, show_debug=False, is_ori=False):
+        img_num = len(img_list)
+        assert img_num > 0
+        # Calculate the aspect ratio of all text bars
+        width_list = []
+        for img in img_list:
+            width_list.append(img.shape[2] / float(img.shape[1]))
+        # Sorting can speed up the recognition process
+        indices = torch.from_numpy(np.argsort(np.array(width_list)))
+        batch_num = self.rec_batch_num
+        preds_all = [None] * img_num
+        preds_neck_all = [None] * img_num
+        for beg_img_no in range(0, img_num, batch_num):
+            end_img_no = min(img_num, beg_img_no + batch_num)
+            norm_img_batch = []
+
+            imgC, imgH, imgW = self.rec_image_shape[:3]
+            max_wh_ratio = imgW / imgH
+            for ino in range(beg_img_no, end_img_no):
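+                # Rotate near-vertical crops (h > 1.2 * w) by 90 degrees so the text reads horizontally.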
+                h, w = img_list[indices[ino]].shape[1:]
+                if h > w * 1.2:
+                    img = img_list[indices[ino]]
+                    img = torch.transpose(img, 1, 2).flip(dims=[1])
+                    img_list[indices[ino]] = img
+                    h, w = img.shape[1:]
+                # wh_ratio = w * 1.0 / h
+                # max_wh_ratio = max(max_wh_ratio, wh_ratio)  # comment to not use different ratio
+            for ino in range(beg_img_no, end_img_no):
+                norm_img = self.resize_norm_img(img_list[indices[ino]], max_wh_ratio)
+                if self.use_fp16:
+                    norm_img = norm_img.half()
+                norm_img = norm_img.unsqueeze(0)
+                norm_img_batch.append(norm_img)
+            norm_img_batch = torch.cat(norm_img_batch, dim=0)
+            if show_debug:
+                for i in range(len(norm_img_batch)):
+                    _img = norm_img_batch[i].permute(1, 2, 0).detach().cpu().numpy()
+                    _img = (_img + 0.5) * 255
+                    _img = _img[:, :, ::-1]
+                    file_name = f"{indices[beg_img_no + i]}"
+                    file_name = file_name + "_ori" if is_ori else file_name
+                    cv2.imwrite(file_name + ".jpg", _img)
+            if self.is_onnx:
+                input_dict = {}
+                input_dict[self.predictor.get_inputs()[0].name] = (
+                    norm_img_batch.detach().cpu().numpy()
+                )
+                outputs = self.predictor.run(None, input_dict)
+                preds = {}
+                preds["ctc"] = torch.from_numpy(outputs[0])
+                preds["ctc_neck"] = [torch.zeros(1)] * img_num
+            else:
+                preds = self.predictor(norm_img_batch)
+            for rno in range(preds["ctc"].shape[0]):
+                preds_all[indices[beg_img_no + rno]] = preds["ctc"][rno]
+                preds_neck_all[indices[beg_img_no + rno]] = preds["ctc_neck"][rno]
+
+        return torch.stack(preds_all, dim=0), torch.stack(preds_neck_all, dim=0)
+
+    def get_char_dict(self, character_dict_path):
+        character_str = []
+        with open(character_dict_path, "rb") as fin:
+            lines = fin.readlines()
+            for line in lines:
+                line = line.decode("utf-8").strip("\n").strip("\r\n")
+                character_str.append(line)
+        dict_character = list(character_str)
+        dict_character = ["sos"] + dict_character + [" "]  # eos is space
+        return dict_character
+
+    def get_text(self, order):
+        char_list = [self.chars[text_id] for text_id in order]
+        return "".join(char_list)
+
+    def decode(self, mat):
+        text_index = mat.detach().cpu().numpy().argmax(axis=1)
+        ignored_tokens = [0]
+        selection = np.ones(len(text_index), dtype=bool)
+        selection[1:] = text_index[1:] != text_index[:-1]
+        for ignored_token in ignored_tokens:
+            selection &= text_index != ignored_token
+        return text_index[selection], np.where(selection)[0]
+
+    def get_ctcloss(self, preds, gt_text, weight):
+        if not isinstance(weight, torch.Tensor):
+            weight = torch.tensor(weight).to(preds.device)
+        ctc_loss = torch.nn.CTCLoss(reduction="none")
+        log_probs = preds.log_softmax(dim=2).permute(1, 0, 2)  # NTC-->TNC
+        targets = []
+        target_lengths = []
+        for t in gt_text:
+            targets += [self.char2id.get(i, len(self.chars) - 1) for i in t]
+            target_lengths += [len(t)]
+        targets = torch.tensor(targets).to(preds.device)
+        target_lengths = torch.tensor(target_lengths).to(preds.device)
+        input_lengths = torch.tensor([log_probs.shape[0]] * (log_probs.shape[1])).to(
+            preds.device
+        )
+        loss = ctc_loss(log_probs, targets, input_lengths, target_lengths)
+        loss = loss / input_lengths * weight
+        return loss
+
+
+def main():
+    rec_model_dir = "./ocr_weights/ppv3_rec.pth"
+    predictor = create_predictor(rec_model_dir)
+    args = edict()
+    args.rec_image_shape = "3, 48, 320"
+    args.rec_char_dict_path = "./ocr_weights/ppocr_keys_v1.txt"
+    args.rec_batch_num = 6
+    text_recognizer = TextRecognizer(args, predictor)
+    image_dir = "./test_imgs_cn"
+    gt_text = ["韩国小馆"] * 14
+
+    image_file_list = get_image_file_list(image_dir)
+    valid_image_file_list = []
+    img_list = []
+
+    for image_file in image_file_list:
+        img = cv2.imread(image_file)
+        if img is None:
+            print("error loading image: {}".format(image_file))
+            continue
+        valid_image_file_list.append(image_file)
+        img_list.append(torch.from_numpy(img).permute(2, 0, 1).float())
+    try:
+        tic = time.time()
+        times = []
+        for i in range(10):
+            preds, _ = text_recognizer.pred_imglist(img_list)  # get text
+            preds_all = preds.softmax(dim=2)
+            times += [(time.time() - tic) * 1000.0]
+            tic = time.time()
+        print(times)
+        print(np.mean(times[1:]) / len(preds_all))
+        weight = np.ones(len(gt_text))
+        loss = text_recognizer.get_ctcloss(preds, gt_text, weight)
+        for i in range(len(valid_image_file_list)):
+            pred = preds_all[i]
+            order, idx = text_recognizer.decode(pred)
+            text = text_recognizer.get_text(order)
+            print(
+                f'{valid_image_file_list[i]}: pred/gt="{text}"/"{gt_text[i]}", loss={loss[i]:.2f}'
+            )
+    except Exception as E:
+        print(traceback.format_exc(), E)
+
+
+if __name__ == "__main__":
+    main()

+ 0 - 0
sorawm/iopaint/model/anytext/ldm/__init__.py


BIN
sorawm/iopaint/model/anytext/ldm/__pycache__/__init__.cpython-312.pyc


BIN
sorawm/iopaint/model/anytext/ldm/__pycache__/util.cpython-312.pyc


+ 0 - 0
sorawm/iopaint/model/anytext/ldm/models/__init__.py


+ 275 - 0
sorawm/iopaint/model/anytext/ldm/models/autoencoder.py

@@ -0,0 +1,275 @@
+from contextlib import contextmanager
+
+import torch
+import torch.nn.functional as F
+
+from sorawm.iopaint.model.anytext.ldm.modules.diffusionmodules.model import (
+    Decoder,
+    Encoder,
+)
+from sorawm.iopaint.model.anytext.ldm.modules.distributions.distributions import (
+    DiagonalGaussianDistribution,
+)
+from sorawm.iopaint.model.anytext.ldm.modules.ema import LitEma
+from sorawm.iopaint.model.anytext.ldm.util import instantiate_from_config
+
+
+class AutoencoderKL(torch.nn.Module):
+    def __init__(
+        self,
+        ddconfig,
+        lossconfig,
+        embed_dim,
+        ckpt_path=None,
+        ignore_keys=[],
+        image_key="image",
+        colorize_nlabels=None,
+        monitor=None,
+        ema_decay=None,
+        learn_logvar=False,
+    ):
+        super().__init__()
+        self.learn_logvar = learn_logvar
+        self.image_key = image_key
+        self.encoder = Encoder(**ddconfig)
+        self.decoder = Decoder(**ddconfig)
+        self.loss = instantiate_from_config(lossconfig)
+        assert ddconfig["double_z"]
+        self.quant_conv = torch.nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1)
+        self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
+        self.embed_dim = embed_dim
+        if colorize_nlabels is not None:
+            assert type(colorize_nlabels) == int
+            self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
+        if monitor is not None:
+            self.monitor = monitor
+
+        self.use_ema = ema_decay is not None
+        if self.use_ema:
+            self.ema_decay = ema_decay
+            assert 0.0 < ema_decay < 1.0
+            self.model_ema = LitEma(self, decay=ema_decay)
+            print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
+
+        if ckpt_path is not None:
+            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
+
+    def init_from_ckpt(self, path, ignore_keys=list()):
+        sd = torch.load(path, map_location="cpu")["state_dict"]
+        keys = list(sd.keys())
+        for k in keys:
+            for ik in ignore_keys:
+                if k.startswith(ik):
+                    print("Deleting key {} from state_dict.".format(k))
+                    del sd[k]
+        self.load_state_dict(sd, strict=False)
+        print(f"Restored from {path}")
+
+    @contextmanager
+    def ema_scope(self, context=None):
+        if self.use_ema:
+            self.model_ema.store(self.parameters())
+            self.model_ema.copy_to(self)
+            if context is not None:
+                print(f"{context}: Switched to EMA weights")
+        try:
+            yield None
+        finally:
+            if self.use_ema:
+                self.model_ema.restore(self.parameters())
+                if context is not None:
+                    print(f"{context}: Restored training weights")
+
+    def on_train_batch_end(self, *args, **kwargs):
+        if self.use_ema:
+            self.model_ema(self)
+
+    def encode(self, x):
+        h = self.encoder(x)
+        moments = self.quant_conv(h)
+        posterior = DiagonalGaussianDistribution(moments)
+        return posterior
+
+    def decode(self, z):
+        z = self.post_quant_conv(z)
+        dec = self.decoder(z)
+        return dec
+
+    def forward(self, input, sample_posterior=True):
+        posterior = self.encode(input)
+        if sample_posterior:
+            z = posterior.sample()
+        else:
+            z = posterior.mode()
+        dec = self.decode(z)
+        return dec, posterior
+
+    def get_input(self, batch, k):
+        x = batch[k]
+        if len(x.shape) == 3:
+            x = x[..., None]
+        x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
+        return x
+
+    def training_step(self, batch, batch_idx, optimizer_idx):
+        inputs = self.get_input(batch, self.image_key)
+        reconstructions, posterior = self(inputs)
+
+        if optimizer_idx == 0:
+            # train encoder+decoder+logvar
+            aeloss, log_dict_ae = self.loss(
+                inputs,
+                reconstructions,
+                posterior,
+                optimizer_idx,
+                self.global_step,
+                last_layer=self.get_last_layer(),
+                split="train",
+            )
+            self.log(
+                "aeloss",
+                aeloss,
+                prog_bar=True,
+                logger=True,
+                on_step=True,
+                on_epoch=True,
+            )
+            self.log_dict(
+                log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False
+            )
+            return aeloss
+
+        if optimizer_idx == 1:
+            # train the discriminator
+            discloss, log_dict_disc = self.loss(
+                inputs,
+                reconstructions,
+                posterior,
+                optimizer_idx,
+                self.global_step,
+                last_layer=self.get_last_layer(),
+                split="train",
+            )
+
+            self.log(
+                "discloss",
+                discloss,
+                prog_bar=True,
+                logger=True,
+                on_step=True,
+                on_epoch=True,
+            )
+            self.log_dict(
+                log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False
+            )
+            return discloss
+
+    def validation_step(self, batch, batch_idx):
+        log_dict = self._validation_step(batch, batch_idx)
+        with self.ema_scope():
+            log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema")
+        return log_dict
+
+    def _validation_step(self, batch, batch_idx, postfix=""):
+        inputs = self.get_input(batch, self.image_key)
+        reconstructions, posterior = self(inputs)
+        aeloss, log_dict_ae = self.loss(
+            inputs,
+            reconstructions,
+            posterior,
+            0,
+            self.global_step,
+            last_layer=self.get_last_layer(),
+            split="val" + postfix,
+        )
+
+        discloss, log_dict_disc = self.loss(
+            inputs,
+            reconstructions,
+            posterior,
+            1,
+            self.global_step,
+            last_layer=self.get_last_layer(),
+            split="val" + postfix,
+        )
+
+        self.log(f"val{postfix}/rec_loss", log_dict_ae[f"val{postfix}/rec_loss"])
+        self.log_dict(log_dict_ae)
+        self.log_dict(log_dict_disc)
+        return self.log_dict
+
+    def configure_optimizers(self):
+        lr = self.learning_rate
+        ae_params_list = (
+            list(self.encoder.parameters())
+            + list(self.decoder.parameters())
+            + list(self.quant_conv.parameters())
+            + list(self.post_quant_conv.parameters())
+        )
+        if self.learn_logvar:
+            print(f"{self.__class__.__name__}: Learning logvar")
+            ae_params_list.append(self.loss.logvar)
+        opt_ae = torch.optim.Adam(ae_params_list, lr=lr, betas=(0.5, 0.9))
+        opt_disc = torch.optim.Adam(
+            self.loss.discriminator.parameters(), lr=lr, betas=(0.5, 0.9)
+        )
+        return [opt_ae, opt_disc], []
+
+    def get_last_layer(self):
+        return self.decoder.conv_out.weight
+
+    @torch.no_grad()
+    def log_images(self, batch, only_inputs=False, log_ema=False, **kwargs):
+        log = dict()
+        x = self.get_input(batch, self.image_key)
+        x = x.to(self.device)
+        if not only_inputs:
+            xrec, posterior = self(x)
+            if x.shape[1] > 3:
+                # colorize with random projection
+                assert xrec.shape[1] > 3
+                x = self.to_rgb(x)
+                xrec = self.to_rgb(xrec)
+            log["samples"] = self.decode(torch.randn_like(posterior.sample()))
+            log["reconstructions"] = xrec
+            if log_ema or self.use_ema:
+                with self.ema_scope():
+                    xrec_ema, posterior_ema = self(x)
+                    if x.shape[1] > 3:
+                        # colorize with random projection
+                        assert xrec_ema.shape[1] > 3
+                        xrec_ema = self.to_rgb(xrec_ema)
+                    log["samples_ema"] = self.decode(
+                        torch.randn_like(posterior_ema.sample())
+                    )
+                    log["reconstructions_ema"] = xrec_ema
+        log["inputs"] = x
+        return log
+
+    def to_rgb(self, x):
+        assert self.image_key == "segmentation"
+        if not hasattr(self, "colorize"):
+            self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
+        x = F.conv2d(x, weight=self.colorize)
+        x = 2.0 * (x - x.min()) / (x.max() - x.min()) - 1.0
+        return x
+
+
+class IdentityFirstStage(torch.nn.Module):
+    def __init__(self, *args, vq_interface=False, **kwargs):
+        self.vq_interface = vq_interface
+        super().__init__()
+
+    def encode(self, x, *args, **kwargs):
+        return x
+
+    def decode(self, x, *args, **kwargs):
+        return x
+
+    def quantize(self, x, *args, **kwargs):
+        if self.vq_interface:
+            return x, None, [None, None, None]
+        return x
+
+    def forward(self, x, *args, **kwargs):
+        return x
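
A minimal usage sketch for the AutoencoderKL defined above, assuming `ae` is an already-constructed, checkpoint-loaded instance (the `reconstruct` helper is illustrative, not part of the module):

import torch

@torch.no_grad()
def reconstruct(ae, x: torch.Tensor, deterministic: bool = True) -> torch.Tensor:
    # x: (B, 3, H, W) float tensor in channel-first layout, as produced by get_input()
    posterior = ae.encode(x)  # DiagonalGaussianDistribution over the latent
    z = posterior.mode() if deterministic else posterior.sample()
    return ae.decode(z)       # reconstruction with the same shape as x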

+ 0 - 0
sorawm/iopaint/model/anytext/ldm/models/diffusion/__init__.py


+ 525 - 0
sorawm/iopaint/model/anytext/ldm/models/diffusion/ddim.py

@@ -0,0 +1,525 @@
+"""SAMPLING ONLY."""
+
+import numpy as np
+import torch
+from tqdm import tqdm
+
+from sorawm.iopaint.model.anytext.ldm.modules.diffusionmodules.util import (
+    extract_into_tensor,
+    make_ddim_sampling_parameters,
+    make_ddim_timesteps,
+    noise_like,
+)
+
+
+class DDIMSampler(object):
+    def __init__(self, model, schedule="linear", **kwargs):
+        super().__init__()
+        self.model = model
+        self.ddpm_num_timesteps = model.num_timesteps
+        self.schedule = schedule
+
+    def register_buffer(self, name, attr):
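+        # note: plain setattr, not nn.Module.register_buffer (DDIMSampler is not a Module);
+        # tensors are moved to CUDA unconditionally, so CPU-only runs would need this relaxed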
+        if type(attr) == torch.Tensor:
+            if attr.device != torch.device("cuda"):
+                attr = attr.to(torch.device("cuda"))
+        setattr(self, name, attr)
+
+    def make_schedule(
+        self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.0, verbose=True
+    ):
+        self.ddim_timesteps = make_ddim_timesteps(
+            ddim_discr_method=ddim_discretize,
+            num_ddim_timesteps=ddim_num_steps,
+            num_ddpm_timesteps=self.ddpm_num_timesteps,
+            verbose=verbose,
+        )
+        alphas_cumprod = self.model.alphas_cumprod
+        assert (
+            alphas_cumprod.shape[0] == self.ddpm_num_timesteps
+        ), "alphas have to be defined for each timestep"
+        to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
+
+        self.register_buffer("betas", to_torch(self.model.betas))
+        self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
+        self.register_buffer(
+            "alphas_cumprod_prev", to_torch(self.model.alphas_cumprod_prev)
+        )
+
+        # calculations for diffusion q(x_t | x_{t-1}) and others
+        self.register_buffer(
+            "sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod.cpu()))
+        )
+        self.register_buffer(
+            "sqrt_one_minus_alphas_cumprod",
+            to_torch(np.sqrt(1.0 - alphas_cumprod.cpu())),
+        )
+        self.register_buffer(
+            "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod.cpu()))
+        )
+        self.register_buffer(
+            "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod.cpu()))
+        )
+        self.register_buffer(
+            "sqrt_recipm1_alphas_cumprod",
+            to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1)),
+        )
+
+        # ddim sampling parameters
+        ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(
+            alphacums=alphas_cumprod.cpu(),
+            ddim_timesteps=self.ddim_timesteps,
+            eta=ddim_eta,
+            verbose=verbose,
+        )
+        self.register_buffer("ddim_sigmas", ddim_sigmas)
+        self.register_buffer("ddim_alphas", ddim_alphas)
+        self.register_buffer("ddim_alphas_prev", ddim_alphas_prev)
+        self.register_buffer("ddim_sqrt_one_minus_alphas", np.sqrt(1.0 - ddim_alphas))
+        sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
+            (1 - self.alphas_cumprod_prev)
+            / (1 - self.alphas_cumprod)
+            * (1 - self.alphas_cumprod / self.alphas_cumprod_prev)
+        )
+        self.register_buffer(
+            "ddim_sigmas_for_original_num_steps", sigmas_for_original_sampling_steps
+        )
+
+    @torch.no_grad()
+    def sample(
+        self,
+        S,
+        batch_size,
+        shape,
+        conditioning=None,
+        callback=None,
+        normals_sequence=None,
+        img_callback=None,
+        quantize_x0=False,
+        eta=0.0,
+        mask=None,
+        x0=None,
+        temperature=1.0,
+        noise_dropout=0.0,
+        score_corrector=None,
+        corrector_kwargs=None,
+        verbose=True,
+        x_T=None,
+        log_every_t=100,
+        unconditional_guidance_scale=1.0,
+        unconditional_conditioning=None,  # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
+        dynamic_threshold=None,
+        ucg_schedule=None,
+        **kwargs,
+    ):
+        if conditioning is not None:
+            if isinstance(conditioning, dict):
+                ctmp = conditioning[list(conditioning.keys())[0]]
+                while isinstance(ctmp, list):
+                    ctmp = ctmp[0]
+                cbs = ctmp.shape[0]
+                # cbs = len(ctmp[0])
+                if cbs != batch_size:
+                    print(
+                        f"Warning: Got {cbs} conditionings but batch-size is {batch_size}"
+                    )
+
+            elif isinstance(conditioning, list):
+                for ctmp in conditioning:
+                    if ctmp.shape[0] != batch_size:
+                        print(
+                            f"Warning: Got {cbs} conditionings but batch-size is {batch_size}"
+                        )
+
+            else:
+                if conditioning.shape[0] != batch_size:
+                    print(
+                        f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}"
+                    )
+
+        self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
+        # sampling
+        C, H, W = shape
+        size = (batch_size, C, H, W)
+        print(f"Data shape for DDIM sampling is {size}, eta {eta}")
+
+        samples, intermediates = self.ddim_sampling(
+            conditioning,
+            size,
+            callback=callback,
+            img_callback=img_callback,
+            quantize_denoised=quantize_x0,
+            mask=mask,
+            x0=x0,
+            ddim_use_original_steps=False,
+            noise_dropout=noise_dropout,
+            temperature=temperature,
+            score_corrector=score_corrector,
+            corrector_kwargs=corrector_kwargs,
+            x_T=x_T,
+            log_every_t=log_every_t,
+            unconditional_guidance_scale=unconditional_guidance_scale,
+            unconditional_conditioning=unconditional_conditioning,
+            dynamic_threshold=dynamic_threshold,
+            ucg_schedule=ucg_schedule,
+        )
+        return samples, intermediates
+
+    @torch.no_grad()
+    def ddim_sampling(
+        self,
+        cond,
+        shape,
+        x_T=None,
+        ddim_use_original_steps=False,
+        callback=None,
+        timesteps=None,
+        quantize_denoised=False,
+        mask=None,
+        x0=None,
+        img_callback=None,
+        log_every_t=100,
+        temperature=1.0,
+        noise_dropout=0.0,
+        score_corrector=None,
+        corrector_kwargs=None,
+        unconditional_guidance_scale=1.0,
+        unconditional_conditioning=None,
+        dynamic_threshold=None,
+        ucg_schedule=None,
+    ):
+        device = self.model.betas.device
+        b = shape[0]
+        if x_T is None:
+            img = torch.randn(shape, device=device)
+        else:
+            img = x_T
+
+        if timesteps is None:
+            timesteps = (
+                self.ddpm_num_timesteps
+                if ddim_use_original_steps
+                else self.ddim_timesteps
+            )
+        elif timesteps is not None and not ddim_use_original_steps:
+            subset_end = (
+                int(
+                    min(timesteps / self.ddim_timesteps.shape[0], 1)
+                    * self.ddim_timesteps.shape[0]
+                )
+                - 1
+            )
+            timesteps = self.ddim_timesteps[:subset_end]
+
+        intermediates = {"x_inter": [img], "pred_x0": [img], "index": [10000]}
+        time_range = (
+            reversed(range(0, timesteps))
+            if ddim_use_original_steps
+            else np.flip(timesteps)
+        )
+        total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
+        print(f"Running DDIM Sampling with {total_steps} timesteps")
+
+        iterator = tqdm(time_range, desc="DDIM Sampler", total=total_steps)
+
+        for i, step in enumerate(iterator):
+            index = total_steps - i - 1
+            ts = torch.full((b,), step, device=device, dtype=torch.long)
+
+            if mask is not None:
+                assert x0 is not None
+                img_orig = self.model.q_sample(
+                    x0, ts
+                )  # TODO: deterministic forward pass?
+                img = img_orig * mask + (1.0 - mask) * img
+
+            if ucg_schedule is not None:
+                assert len(ucg_schedule) == len(time_range)
+                unconditional_guidance_scale = ucg_schedule[i]
+
+            outs = self.p_sample_ddim(
+                img,
+                cond,
+                ts,
+                index=index,
+                use_original_steps=ddim_use_original_steps,
+                quantize_denoised=quantize_denoised,
+                temperature=temperature,
+                noise_dropout=noise_dropout,
+                score_corrector=score_corrector,
+                corrector_kwargs=corrector_kwargs,
+                unconditional_guidance_scale=unconditional_guidance_scale,
+                unconditional_conditioning=unconditional_conditioning,
+                dynamic_threshold=dynamic_threshold,
+            )
+            img, pred_x0 = outs
+            if callback:
+                callback(i)
+            if img_callback:
+                img_callback(pred_x0, i)
+
+            if index % log_every_t == 0 or index == total_steps - 1:
+                intermediates["x_inter"].append(img)
+                intermediates["pred_x0"].append(pred_x0)
+                intermediates["index"].append(index)
+
+        return img, intermediates
+
+    @torch.no_grad()
+    def p_sample_ddim(
+        self,
+        x,
+        c,
+        t,
+        index,
+        repeat_noise=False,
+        use_original_steps=False,
+        quantize_denoised=False,
+        temperature=1.0,
+        noise_dropout=0.0,
+        score_corrector=None,
+        corrector_kwargs=None,
+        unconditional_guidance_scale=1.0,
+        unconditional_conditioning=None,
+        dynamic_threshold=None,
+    ):
+        b, *_, device = *x.shape, x.device
+
+        if unconditional_conditioning is None or unconditional_guidance_scale == 1.0:
+            model_output = self.model.apply_model(x, t, c)
+        else:
+            x_in = torch.cat([x] * 2)
+            t_in = torch.cat([t] * 2)
+            if isinstance(c, dict):
+                assert isinstance(unconditional_conditioning, dict)
+                c_in = dict()
+                for k in c:
+                    if isinstance(c[k], list):
+                        c_in[k] = [
+                            torch.cat([unconditional_conditioning[k][i], c[k][i]])
+                            for i in range(len(c[k]))
+                        ]
+                    elif isinstance(c[k], dict):
+                        c_in[k] = dict()
+                        for key in c[k]:
+                            if isinstance(c[k][key], list):
+                                if not isinstance(c[k][key][0], torch.Tensor):
+                                    continue
+                                c_in[k][key] = [
+                                    torch.cat(
+                                        [
+                                            unconditional_conditioning[k][key][i],
+                                            c[k][key][i],
+                                        ]
+                                    )
+                                    for i in range(len(c[k][key]))
+                                ]
+                            else:
+                                c_in[k][key] = torch.cat(
+                                    [unconditional_conditioning[k][key], c[k][key]]
+                                )
+
+                    else:
+                        c_in[k] = torch.cat([unconditional_conditioning[k], c[k]])
+            elif isinstance(c, list):
+                c_in = list()
+                assert isinstance(unconditional_conditioning, list)
+                for i in range(len(c)):
+                    c_in.append(torch.cat([unconditional_conditioning[i], c[i]]))
+            else:
+                c_in = torch.cat([unconditional_conditioning, c])
+            model_uncond, model_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
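+            # classifier-free guidance: extrapolate from the unconditional prediction
+            # toward the conditional one by the guidance scale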
+            model_output = model_uncond + unconditional_guidance_scale * (
+                model_t - model_uncond
+            )
+
+        if self.model.parameterization == "v":
+            e_t = self.model.predict_eps_from_z_and_v(x, t, model_output)
+        else:
+            e_t = model_output
+
+        if score_corrector is not None:
+            assert self.model.parameterization == "eps", "not implemented"
+            e_t = score_corrector.modify_score(
+                self.model, e_t, x, t, c, **corrector_kwargs
+            )
+
+        alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
+        alphas_prev = (
+            self.model.alphas_cumprod_prev
+            if use_original_steps
+            else self.ddim_alphas_prev
+        )
+        sqrt_one_minus_alphas = (
+            self.model.sqrt_one_minus_alphas_cumprod
+            if use_original_steps
+            else self.ddim_sqrt_one_minus_alphas
+        )
+        sigmas = (
+            self.model.ddim_sigmas_for_original_num_steps
+            if use_original_steps
+            else self.ddim_sigmas
+        )
+        # select parameters corresponding to the currently considered timestep
+        a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
+        a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
+        sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
+        sqrt_one_minus_at = torch.full(
+            (b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device
+        )
+
+        # current prediction for x_0
+        if self.model.parameterization != "v":
+            pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
+        else:
+            pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output)
+
+        if quantize_denoised:
+            pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
+
+        if dynamic_threshold is not None:
+            raise NotImplementedError()
+
+        # direction pointing to x_t
+        dir_xt = (1.0 - a_prev - sigma_t**2).sqrt() * e_t
+        noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
+        if noise_dropout > 0.0:
+            noise = torch.nn.functional.dropout(noise, p=noise_dropout)
+        x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
+        return x_prev, pred_x0
+
+    @torch.no_grad()
+    def encode(
+        self,
+        x0,
+        c,
+        t_enc,
+        use_original_steps=False,
+        return_intermediates=None,
+        unconditional_guidance_scale=1.0,
+        unconditional_conditioning=None,
+        callback=None,
+    ):
+        num_reference_steps = (
+            self.ddpm_num_timesteps
+            if use_original_steps
+            else self.ddim_timesteps.shape[0]
+        )
+
+        assert t_enc <= num_reference_steps
+        num_steps = t_enc
+
+        if use_original_steps:
+            alphas_next = self.alphas_cumprod[:num_steps]
+            alphas = self.alphas_cumprod_prev[:num_steps]
+        else:
+            alphas_next = self.ddim_alphas[:num_steps]
+            alphas = torch.tensor(self.ddim_alphas_prev[:num_steps])
+
+        x_next = x0
+        intermediates = []
+        inter_steps = []
+        for i in tqdm(range(num_steps), desc="Encoding Image"):
+            t = torch.full(
+                (x0.shape[0],), i, device=self.model.device, dtype=torch.long
+            )
+            if unconditional_guidance_scale == 1.0:
+                noise_pred = self.model.apply_model(x_next, t, c)
+            else:
+                assert unconditional_conditioning is not None
+                e_t_uncond, noise_pred = torch.chunk(
+                    self.model.apply_model(
+                        torch.cat((x_next, x_next)),
+                        torch.cat((t, t)),
+                        torch.cat((unconditional_conditioning, c)),
+                    ),
+                    2,
+                )
+                noise_pred = e_t_uncond + unconditional_guidance_scale * (
+                    noise_pred - e_t_uncond
+                )
+
+            xt_weighted = (alphas_next[i] / alphas[i]).sqrt() * x_next
+            weighted_noise_pred = (
+                alphas_next[i].sqrt()
+                * ((1 / alphas_next[i] - 1).sqrt() - (1 / alphas[i] - 1).sqrt())
+                * noise_pred
+            )
+            x_next = xt_weighted + weighted_noise_pred
+            if (
+                return_intermediates
+                and i % (num_steps // return_intermediates) == 0
+                and i < num_steps - 1
+            ):
+                intermediates.append(x_next)
+                inter_steps.append(i)
+            elif return_intermediates and i >= num_steps - 2:
+                intermediates.append(x_next)
+                inter_steps.append(i)
+            if callback:
+                callback(i)
+
+        out = {"x_encoded": x_next, "intermediate_steps": inter_steps}
+        if return_intermediates:
+            out.update({"intermediates": intermediates})
+        return x_next, out
+
+    @torch.no_grad()
+    def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
+        # fast, but does not allow for exact reconstruction
+        # t serves as an index to gather the correct alphas
+        if use_original_steps:
+            sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
+            sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
+        else:
+            sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
+            sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas
+
+        if noise is None:
+            noise = torch.randn_like(x0)
+        return (
+            extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0
+            + extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise
+        )
+
+    @torch.no_grad()
+    def decode(
+        self,
+        x_latent,
+        cond,
+        t_start,
+        unconditional_guidance_scale=1.0,
+        unconditional_conditioning=None,
+        use_original_steps=False,
+        callback=None,
+    ):
+        timesteps = (
+            np.arange(self.ddpm_num_timesteps)
+            if use_original_steps
+            else self.ddim_timesteps
+        )
+        timesteps = timesteps[:t_start]
+
+        time_range = np.flip(timesteps)
+        total_steps = timesteps.shape[0]
+        print(f"Running DDIM Sampling with {total_steps} timesteps")
+
+        iterator = tqdm(time_range, desc="Decoding image", total=total_steps)
+        x_dec = x_latent
+        for i, step in enumerate(iterator):
+            index = total_steps - i - 1
+            ts = torch.full(
+                (x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long
+            )
+            x_dec, _ = self.p_sample_ddim(
+                x_dec,
+                cond,
+                ts,
+                index=index,
+                use_original_steps=use_original_steps,
+                unconditional_guidance_scale=unconditional_guidance_scale,
+                unconditional_conditioning=unconditional_conditioning,
+            )
+            if callback:
+                callback(i)
+        return x_dec

+ 2386 - 0
sorawm/iopaint/model/anytext/ldm/models/diffusion/ddpm.py

@@ -0,0 +1,2386 @@
+"""
+Part of the implementation is borrowed and modified from ControlNet, publicly available at https://github.com/lllyasviel/ControlNet/blob/main/ldm/models/diffusion/ddpm.py
+"""
+
+import itertools
+from contextlib import contextmanager, nullcontext
+from functools import partial
+
+import cv2
+import numpy as np
+import torch
+import torch.nn as nn
+from einops import rearrange, repeat
+from omegaconf import ListConfig
+from torch.optim.lr_scheduler import LambdaLR
+from torchvision.utils import make_grid
+from tqdm import tqdm
+
+from sorawm.iopaint.model.anytext.ldm.models.autoencoder import (
+    AutoencoderKL,
+    IdentityFirstStage,
+)
+from sorawm.iopaint.model.anytext.ldm.models.diffusion.ddim import DDIMSampler
+from sorawm.iopaint.model.anytext.ldm.modules.diffusionmodules.util import (
+    extract_into_tensor,
+    make_beta_schedule,
+    noise_like,
+)
+from sorawm.iopaint.model.anytext.ldm.modules.distributions.distributions import (
+    DiagonalGaussianDistribution,
+    normal_kl,
+)
+from sorawm.iopaint.model.anytext.ldm.modules.ema import LitEma
+from sorawm.iopaint.model.anytext.ldm.util import (
+    count_params,
+    default,
+    exists,
+    instantiate_from_config,
+    isimage,
+    ismap,
+    log_txt_as_img,
+    mean_flat,
+)
+
+__conditioning_keys__ = {"concat": "c_concat", "crossattn": "c_crossattn", "adm": "y"}
+
+PRINT_DEBUG = False
+
+
+def print_grad(grad):
+    # print('Gradient:', grad)
+    # print(grad.shape)
+    a = grad.max()
+    b = grad.min()
+    # print(f'mean={grad.mean():.4f}, max={a:.4f}, min={b:.4f}')
+    s = 255.0 / (a - b)
+    c = 255 * (-b / (a - b))
+    grad = grad * s + c
+    # print(f'mean={grad.mean():.4f}, max={grad.max():.4f}, min={grad.min():.4f}')
+    img = grad[0].permute(1, 2, 0).detach().cpu().numpy()
+    if img.shape[0] == 512:
+        cv2.imwrite("grad-img.jpg", img)
+    elif img.shape[0] == 64:
+        cv2.imwrite("grad-latent.jpg", img)
+
+
+def disabled_train(self, mode=True):
+    """Overwrite model.train with this function to make sure train/eval mode
+    does not change anymore."""
+    return self
+
+
+def uniform_on_device(r1, r2, shape, device):
+    return (r1 - r2) * torch.rand(*shape, device=device) + r2
+
+
+class DDPM(torch.nn.Module):
+    # classic DDPM with Gaussian diffusion, in image space
+    def __init__(
+        self,
+        unet_config,
+        timesteps=1000,
+        beta_schedule="linear",
+        loss_type="l2",
+        ckpt_path=None,
+        ignore_keys=[],
+        load_only_unet=False,
+        monitor="val/loss",
+        use_ema=True,
+        first_stage_key="image",
+        image_size=256,
+        channels=3,
+        log_every_t=100,
+        clip_denoised=True,
+        linear_start=1e-4,
+        linear_end=2e-2,
+        cosine_s=8e-3,
+        given_betas=None,
+        original_elbo_weight=0.0,
+        v_posterior=0.0,  # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta
+        l_simple_weight=1.0,
+        conditioning_key=None,
+        parameterization="eps",  # all assuming fixed variance schedules
+        scheduler_config=None,
+        use_positional_encodings=False,
+        learn_logvar=False,
+        logvar_init=0.0,
+        make_it_fit=False,
+        ucg_training=None,
+        reset_ema=False,
+        reset_num_ema_updates=False,
+    ):
+        super().__init__()
+        assert parameterization in [
+            "eps",
+            "x0",
+            "v",
+        ], 'currently only supporting "eps" and "x0" and "v"'
+        self.parameterization = parameterization
+        print(
+            f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode"
+        )
+        self.cond_stage_model = None
+        self.clip_denoised = clip_denoised
+        self.log_every_t = log_every_t
+        self.first_stage_key = first_stage_key
+        self.image_size = image_size  # try conv?
+        self.channels = channels
+        self.use_positional_encodings = use_positional_encodings
+        self.model = DiffusionWrapper(unet_config, conditioning_key)
+        count_params(self.model, verbose=True)
+        self.use_ema = use_ema
+        if self.use_ema:
+            self.model_ema = LitEma(self.model)
+            print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
+
+        self.use_scheduler = scheduler_config is not None
+        if self.use_scheduler:
+            self.scheduler_config = scheduler_config
+
+        self.v_posterior = v_posterior
+        self.original_elbo_weight = original_elbo_weight
+        self.l_simple_weight = l_simple_weight
+
+        if monitor is not None:
+            self.monitor = monitor
+        self.make_it_fit = make_it_fit
+        if reset_ema:
+            assert exists(ckpt_path)
+        if ckpt_path is not None:
+            self.init_from_ckpt(
+                ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet
+            )
+            if reset_ema:
+                assert self.use_ema
+                print(
+                    f"Resetting ema to pure model weights. This is useful when restoring from an ema-only checkpoint."
+                )
+                self.model_ema = LitEma(self.model)
+        if reset_num_ema_updates:
+            print(
+                " +++++++++++ WARNING: RESETTING NUM_EMA UPDATES TO ZERO +++++++++++ "
+            )
+            assert self.use_ema
+            self.model_ema.reset_num_updates()
+
+        self.register_schedule(
+            given_betas=given_betas,
+            beta_schedule=beta_schedule,
+            timesteps=timesteps,
+            linear_start=linear_start,
+            linear_end=linear_end,
+            cosine_s=cosine_s,
+        )
+
+        self.loss_type = loss_type
+
+        self.learn_logvar = learn_logvar
+        logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,))
+        if self.learn_logvar:
+            self.logvar = nn.Parameter(logvar, requires_grad=True)
+        else:
+            self.register_buffer("logvar", logvar)
+
+        self.ucg_training = ucg_training or dict()
+        if self.ucg_training:
+            self.ucg_prng = np.random.RandomState()
+
+    def register_schedule(
+        self,
+        given_betas=None,
+        beta_schedule="linear",
+        timesteps=1000,
+        linear_start=1e-4,
+        linear_end=2e-2,
+        cosine_s=8e-3,
+    ):
+        if exists(given_betas):
+            betas = given_betas
+        else:
+            betas = make_beta_schedule(
+                beta_schedule,
+                timesteps,
+                linear_start=linear_start,
+                linear_end=linear_end,
+                cosine_s=cosine_s,
+            )
+        alphas = 1.0 - betas
+        alphas_cumprod = np.cumprod(alphas, axis=0)
+        # np.save('1.npy', alphas_cumprod)
+        alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])
+
+        (timesteps,) = betas.shape
+        self.num_timesteps = int(timesteps)
+        self.linear_start = linear_start
+        self.linear_end = linear_end
+        assert (
+            alphas_cumprod.shape[0] == self.num_timesteps
+        ), "alphas have to be defined for each timestep"
+
+        to_torch = partial(torch.tensor, dtype=torch.float32)
+
+        self.register_buffer("betas", to_torch(betas))
+        self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
+        self.register_buffer("alphas_cumprod_prev", to_torch(alphas_cumprod_prev))
+
+        # calculations for diffusion q(x_t | x_{t-1}) and others
+        self.register_buffer("sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod)))
+        self.register_buffer(
+            "sqrt_one_minus_alphas_cumprod", to_torch(np.sqrt(1.0 - alphas_cumprod))
+        )
+        self.register_buffer(
+            "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod))
+        )
+        self.register_buffer(
+            "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod))
+        )
+        self.register_buffer(
+            "sqrt_recipm1_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod - 1))
+        )
+
+        # calculations for posterior q(x_{t-1} | x_t, x_0)
+        posterior_variance = (1 - self.v_posterior) * betas * (
+            1.0 - alphas_cumprod_prev
+        ) / (1.0 - alphas_cumprod) + self.v_posterior * betas
+        # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
+        self.register_buffer("posterior_variance", to_torch(posterior_variance))
+        # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
+        self.register_buffer(
+            "posterior_log_variance_clipped",
+            to_torch(np.log(np.maximum(posterior_variance, 1e-20))),
+        )
+        self.register_buffer(
+            "posterior_mean_coef1",
+            to_torch(betas * np.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)),
+        )
+        self.register_buffer(
+            "posterior_mean_coef2",
+            to_torch(
+                (1.0 - alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - alphas_cumprod)
+            ),
+        )
+
+        if self.parameterization == "eps":
+            lvlb_weights = self.betas**2 / (
+                2
+                * self.posterior_variance
+                * to_torch(alphas)
+                * (1 - self.alphas_cumprod)
+            )
+        elif self.parameterization == "x0":
+            lvlb_weights = (
+                0.5
+                * np.sqrt(torch.Tensor(alphas_cumprod))
+                / (2.0 * 1 - torch.Tensor(alphas_cumprod))
+            )
+        elif self.parameterization == "v":
+            lvlb_weights = torch.ones_like(
+                self.betas**2
+                / (
+                    2
+                    * self.posterior_variance
+                    * to_torch(alphas)
+                    * (1 - self.alphas_cumprod)
+                )
+            )
+        else:
+            raise NotImplementedError("mu not supported")
+        lvlb_weights[0] = lvlb_weights[1]
+        self.register_buffer("lvlb_weights", lvlb_weights, persistent=False)
+        assert not torch.isnan(self.lvlb_weights).all()
+
+    @contextmanager
+    def ema_scope(self, context=None):
+        if self.use_ema:
+            self.model_ema.store(self.model.parameters())
+            self.model_ema.copy_to(self.model)
+            if context is not None:
+                print(f"{context}: Switched to EMA weights")
+        try:
+            yield None
+        finally:
+            if self.use_ema:
+                self.model_ema.restore(self.model.parameters())
+                if context is not None:
+                    print(f"{context}: Restored training weights")
+
+    @torch.no_grad()
+    def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
+        sd = torch.load(path, map_location="cpu")
+        if "state_dict" in list(sd.keys()):
+            sd = sd["state_dict"]
+        keys = list(sd.keys())
+        for k in keys:
+            for ik in ignore_keys:
+                if k.startswith(ik):
+                    print("Deleting key {} from state_dict.".format(k))
+                    del sd[k]
+        if self.make_it_fit:
+            n_params = len(
+                [
+                    name
+                    for name, _ in itertools.chain(
+                        self.named_parameters(), self.named_buffers()
+                    )
+                ]
+            )
+            for name, param in tqdm(
+                itertools.chain(self.named_parameters(), self.named_buffers()),
+                desc="Fitting old weights to new weights",
+                total=n_params,
+            ):
+                if name not in sd:
+                    continue
+                old_shape = sd[name].shape
+                new_shape = param.shape
+                assert len(old_shape) == len(new_shape)
+                if len(new_shape) > 2:
+                    # we only modify first two axes
+                    assert new_shape[2:] == old_shape[2:]
+                # assumes first axis corresponds to output dim
+                if not new_shape == old_shape:
+                    new_param = param.clone()
+                    old_param = sd[name]
+                    if len(new_shape) == 1:
+                        for i in range(new_param.shape[0]):
+                            new_param[i] = old_param[i % old_shape[0]]
+                    elif len(new_shape) >= 2:
+                        for i in range(new_param.shape[0]):
+                            for j in range(new_param.shape[1]):
+                                new_param[i, j] = old_param[
+                                    i % old_shape[0], j % old_shape[1]
+                                ]
+
+                        n_used_old = torch.ones(old_shape[1])
+                        for j in range(new_param.shape[1]):
+                            n_used_old[j % old_shape[1]] += 1
+                        n_used_new = torch.zeros(new_shape[1])
+                        for j in range(new_param.shape[1]):
+                            n_used_new[j] = n_used_old[j % old_shape[1]]
+
+                        n_used_new = n_used_new[None, :]
+                        while len(n_used_new.shape) < len(new_shape):
+                            n_used_new = n_used_new.unsqueeze(-1)
+                        new_param /= n_used_new
+
+                    sd[name] = new_param
+
+        missing, unexpected = (
+            self.load_state_dict(sd, strict=False)
+            if not only_model
+            else self.model.load_state_dict(sd, strict=False)
+        )
+        print(
+            f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys"
+        )
+        if len(missing) > 0:
+            print(f"Missing Keys:\n {missing}")
+        if len(unexpected) > 0:
+            print(f"\nUnexpected Keys:\n {unexpected}")
+
+    def q_mean_variance(self, x_start, t):
+        """
+        Get the distribution q(x_t | x_0).
+        :param x_start: the [N x C x ...] tensor of noiseless inputs.
+        :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
+        :return: A tuple (mean, variance, log_variance), all of x_start's shape.
+        """
+        mean = extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
+        variance = extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
+        log_variance = extract_into_tensor(
+            self.log_one_minus_alphas_cumprod, t, x_start.shape
+        )
+        return mean, variance, log_variance
+
+    def predict_start_from_noise(self, x_t, t, noise):
+        return (
+            extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
+            - extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
+            * noise
+        )
+
+    def predict_start_from_z_and_v(self, x_t, t, v):
+        # self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
+        # self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
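+        # v-parameterization: x_0 = sqrt(alpha_bar_t) * x_t - sqrt(1 - alpha_bar_t) * v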
+        return (
+            extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t
+            - extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v
+        )
+
+    def predict_eps_from_z_and_v(self, x_t, t, v):
+        return (
+            extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * v
+            + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape)
+            * x_t
+        )
+
+    def q_posterior(self, x_start, x_t, t):
+        posterior_mean = (
+            extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start
+            + extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
+        )
+        posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape)
+        posterior_log_variance_clipped = extract_into_tensor(
+            self.posterior_log_variance_clipped, t, x_t.shape
+        )
+        return posterior_mean, posterior_variance, posterior_log_variance_clipped
+
+    def p_mean_variance(self, x, t, clip_denoised: bool):
+        model_out = self.model(x, t)
+        if self.parameterization == "eps":
+            x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
+        elif self.parameterization == "x0":
+            x_recon = model_out
+        if clip_denoised:
+            x_recon.clamp_(-1.0, 1.0)
+
+        model_mean, posterior_variance, posterior_log_variance = self.q_posterior(
+            x_start=x_recon, x_t=x, t=t
+        )
+        return model_mean, posterior_variance, posterior_log_variance
+
+    @torch.no_grad()
+    def p_sample(self, x, t, clip_denoised=True, repeat_noise=False):
+        b, *_, device = *x.shape, x.device
+        model_mean, _, model_log_variance = self.p_mean_variance(
+            x=x, t=t, clip_denoised=clip_denoised
+        )
+        noise = noise_like(x.shape, device, repeat_noise)
+        # no noise when t == 0
+        nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
+        return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
+
+    @torch.no_grad()
+    def p_sample_loop(self, shape, return_intermediates=False):
+        device = self.betas.device
+        b = shape[0]
+        img = torch.randn(shape, device=device)
+        intermediates = [img]
+        for i in tqdm(
+            reversed(range(0, self.num_timesteps)),
+            desc="Sampling t",
+            total=self.num_timesteps,
+        ):
+            img = self.p_sample(
+                img,
+                torch.full((b,), i, device=device, dtype=torch.long),
+                clip_denoised=self.clip_denoised,
+            )
+            if i % self.log_every_t == 0 or i == self.num_timesteps - 1:
+                intermediates.append(img)
+        if return_intermediates:
+            return img, intermediates
+        return img
+
+    @torch.no_grad()
+    def sample(self, batch_size=16, return_intermediates=False):
+        image_size = self.image_size
+        channels = self.channels
+        return self.p_sample_loop(
+            (batch_size, channels, image_size, image_size),
+            return_intermediates=return_intermediates,
+        )
+
+    def q_sample(self, x_start, t, noise=None):
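+        # closed-form forward process: x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise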
+        noise = default(noise, lambda: torch.randn_like(x_start))
+        return (
+            extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
+            + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
+            * noise
+        )
+
+    def get_v(self, x, noise, t):
+        return (
+            extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * noise
+            - extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x
+        )
+
+    def get_loss(self, pred, target, mean=True):
+        if self.loss_type == "l1":
+            loss = (target - pred).abs()
+            if mean:
+                loss = loss.mean()
+        elif self.loss_type == "l2":
+            if mean:
+                loss = torch.nn.functional.mse_loss(target, pred)
+            else:
+                loss = torch.nn.functional.mse_loss(target, pred, reduction="none")
+        else:
+            raise NotImplementedError("unknown loss type '{loss_type}'")
+
+        return loss
+
+    def p_losses(self, x_start, t, noise=None):
+        noise = default(noise, lambda: torch.randn_like(x_start))
+        x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
+        model_out = self.model(x_noisy, t)
+
+        loss_dict = {}
+        if self.parameterization == "eps":
+            target = noise
+        elif self.parameterization == "x0":
+            target = x_start
+        elif self.parameterization == "v":
+            target = self.get_v(x_start, noise, t)
+        else:
+            raise NotImplementedError(
+                f"Parameterization {self.parameterization} not yet supported"
+            )
+
+        loss = self.get_loss(model_out, target, mean=False).mean(dim=[1, 2, 3])
+
+        log_prefix = "train" if self.training else "val"
+
+        loss_dict.update({f"{log_prefix}/loss_simple": loss.mean()})
+        loss_simple = loss.mean() * self.l_simple_weight
+
+        loss_vlb = (self.lvlb_weights[t] * loss).mean()
+        loss_dict.update({f"{log_prefix}/loss_vlb": loss_vlb})
+
+        loss = loss_simple + self.original_elbo_weight * loss_vlb
+
+        loss_dict.update({f"{log_prefix}/loss": loss})
+
+        return loss, loss_dict
+
+    def forward(self, x, *args, **kwargs):
+        # b, c, h, w, device, img_size, = *x.shape, x.device, self.image_size
+        # assert h == img_size and w == img_size, f'height and width of image must be {img_size}'
+        t = torch.randint(
+            0, self.num_timesteps, (x.shape[0],), device=self.device
+        ).long()
+        return self.p_losses(x, t, *args, **kwargs)
+
+    def get_input(self, batch, k):
+        x = batch[k]
+        if len(x.shape) == 3:
+            x = x[..., None]
+        x = rearrange(x, "b h w c -> b c h w")
+        x = x.to(memory_format=torch.contiguous_format).float()
+        return x
+
+    def shared_step(self, batch):
+        x = self.get_input(batch, self.first_stage_key)
+        loss, loss_dict = self(x)
+        return loss, loss_dict
+
+    def training_step(self, batch, batch_idx):
+        for k in self.ucg_training:
+            p = self.ucg_training[k]["p"]
+            val = self.ucg_training[k]["val"]
+            if val is None:
+                val = ""
+            for i in range(len(batch[k])):
+                if self.ucg_prng.choice(2, p=[1 - p, p]):
+                    batch[k][i] = val
+
+        loss, loss_dict = self.shared_step(batch)
+
+        self.log_dict(
+            loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=True
+        )
+
+        self.log(
+            "global_step",
+            self.global_step,
+            prog_bar=True,
+            logger=True,
+            on_step=True,
+            on_epoch=False,
+        )
+
+        if self.use_scheduler:
+            lr = self.optimizers().param_groups[0]["lr"]
+            self.log(
+                "lr_abs", lr, prog_bar=True, logger=True, on_step=True, on_epoch=False
+            )
+
+        return loss
+
+    @torch.no_grad()
+    def validation_step(self, batch, batch_idx):
+        _, loss_dict_no_ema = self.shared_step(batch)
+        with self.ema_scope():
+            _, loss_dict_ema = self.shared_step(batch)
+            loss_dict_ema = {key + "_ema": loss_dict_ema[key] for key in loss_dict_ema}
+        self.log_dict(
+            loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True
+        )
+        self.log_dict(
+            loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True
+        )
+
+    def on_train_batch_end(self, *args, **kwargs):
+        if self.use_ema:
+            self.model_ema(self.model)
+
+    def _get_rows_from_list(self, samples):
+        n_imgs_per_row = len(samples)
+        denoise_grid = rearrange(samples, "n b c h w -> b n c h w")
+        denoise_grid = rearrange(denoise_grid, "b n c h w -> (b n) c h w")
+        denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)
+        return denoise_grid
+
+    @torch.no_grad()
+    def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs):
+        log = dict()
+        x = self.get_input(batch, self.first_stage_key)
+        N = min(x.shape[0], N)
+        n_row = min(x.shape[0], n_row)
+        x = x.to(self.device)[:N]
+        log["inputs"] = x
+
+        # get diffusion row
+        diffusion_row = list()
+        x_start = x[:n_row]
+
+        for t in range(self.num_timesteps):
+            if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
+                t = repeat(torch.tensor([t]), "1 -> b", b=n_row)
+                t = t.to(self.device).long()
+                noise = torch.randn_like(x_start)
+                x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
+                diffusion_row.append(x_noisy)
+
+        log["diffusion_row"] = self._get_rows_from_list(diffusion_row)
+
+        if sample:
+            # get denoise row
+            with self.ema_scope("Plotting"):
+                samples, denoise_row = self.sample(
+                    batch_size=N, return_intermediates=True
+                )
+
+            log["samples"] = samples
+            log["denoise_row"] = self._get_rows_from_list(denoise_row)
+
+        if return_keys:
+            if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
+                return log
+            else:
+                return {key: log[key] for key in return_keys}
+        return log
+
+    def configure_optimizers(self):
+        lr = self.learning_rate
+        params = list(self.model.parameters())
+        if self.learn_logvar:
+            params = params + [self.logvar]
+        opt = torch.optim.AdamW(params, lr=lr)
+        return opt
+
+
+class LatentDiffusion(DDPM):
+    """main class"""
+
+    def __init__(
+        self,
+        first_stage_config,
+        cond_stage_config,
+        num_timesteps_cond=None,
+        cond_stage_key="image",
+        cond_stage_trainable=False,
+        concat_mode=True,
+        cond_stage_forward=None,
+        conditioning_key=None,
+        scale_factor=1.0,
+        scale_by_std=False,
+        force_null_conditioning=False,
+        *args,
+        **kwargs,
+    ):
+        self.force_null_conditioning = force_null_conditioning
+        self.num_timesteps_cond = default(num_timesteps_cond, 1)
+        self.scale_by_std = scale_by_std
+        assert self.num_timesteps_cond <= kwargs["timesteps"]
+        # for backwards compatibility after implementation of DiffusionWrapper
+        if conditioning_key is None:
+            conditioning_key = "concat" if concat_mode else "crossattn"
+        if (
+            cond_stage_config == "__is_unconditional__"
+            and not self.force_null_conditioning
+        ):
+            conditioning_key = None
+        ckpt_path = kwargs.pop("ckpt_path", None)
+        reset_ema = kwargs.pop("reset_ema", False)
+        reset_num_ema_updates = kwargs.pop("reset_num_ema_updates", False)
+        ignore_keys = kwargs.pop("ignore_keys", [])
+        super().__init__(conditioning_key=conditioning_key, *args, **kwargs)
+        self.concat_mode = concat_mode
+        self.cond_stage_trainable = cond_stage_trainable
+        self.cond_stage_key = cond_stage_key
+        try:
+            self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1
+        except Exception:
+            self.num_downs = 0
+        if not scale_by_std:
+            self.scale_factor = scale_factor
+        else:
+            self.register_buffer("scale_factor", torch.tensor(scale_factor))
+        self.instantiate_first_stage(first_stage_config)
+        self.instantiate_cond_stage(cond_stage_config)
+        self.cond_stage_forward = cond_stage_forward
+        self.clip_denoised = False
+        self.bbox_tokenizer = None
+
+        self.restarted_from_ckpt = False
+        if ckpt_path is not None:
+            self.init_from_ckpt(ckpt_path, ignore_keys)
+            self.restarted_from_ckpt = True
+            if reset_ema:
+                assert self.use_ema
+                print(
+                    f"Resetting ema to pure model weights. This is useful when restoring from an ema-only checkpoint."
+                )
+                self.model_ema = LitEma(self.model)
+        if reset_num_ema_updates:
+            print(
+                " +++++++++++ WARNING: RESETTING NUM_EMA UPDATES TO ZERO +++++++++++ "
+            )
+            assert self.use_ema
+            self.model_ema.reset_num_updates()
+
+    def make_cond_schedule(
+        self,
+    ):
+        self.cond_ids = torch.full(
+            size=(self.num_timesteps,),
+            fill_value=self.num_timesteps - 1,
+            dtype=torch.long,
+        )
+        ids = torch.round(
+            torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)
+        ).long()
+        self.cond_ids[: self.num_timesteps_cond] = ids
+
+    @torch.no_grad()
+    def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
+        # only for very first batch
+        if (
+            self.scale_by_std
+            and self.current_epoch == 0
+            and self.global_step == 0
+            and batch_idx == 0
+            and not self.restarted_from_ckpt
+        ):
+            assert (
+                self.scale_factor == 1.0
+            ), "rather not use custom rescaling and std-rescaling simultaneously"
+            # set rescale weight to 1./std of encodings
+            print("### USING STD-RESCALING ###")
+            x = super().get_input(batch, self.first_stage_key)
+            x = x.to(self.device)
+            encoder_posterior = self.encode_first_stage(x)
+            z = self.get_first_stage_encoding(encoder_posterior).detach()
+            del self.scale_factor
+            self.register_buffer("scale_factor", 1.0 / z.flatten().std())
+            print(f"setting self.scale_factor to {self.scale_factor}")
+            print("### USING STD-RESCALING ###")
+
+    def register_schedule(
+        self,
+        given_betas=None,
+        beta_schedule="linear",
+        timesteps=1000,
+        linear_start=1e-4,
+        linear_end=2e-2,
+        cosine_s=8e-3,
+    ):
+        super().register_schedule(
+            given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s
+        )
+
+        self.shorten_cond_schedule = self.num_timesteps_cond > 1
+        if self.shorten_cond_schedule:
+            self.make_cond_schedule()
+
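+    # The first stage (autoencoder) is frozen: kept in eval mode with gradients
+    # disabled for all of its parameters.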
+    def instantiate_first_stage(self, config):
+        model = instantiate_from_config(config)
+        self.first_stage_model = model.eval()
+        self.first_stage_model.train = disabled_train
+        for param in self.first_stage_model.parameters():
+            param.requires_grad = False
+
+    def instantiate_cond_stage(self, config):
+        if not self.cond_stage_trainable:
+            if config == "__is_first_stage__":
+                print("Using first stage also as cond stage.")
+                self.cond_stage_model = self.first_stage_model
+            elif config == "__is_unconditional__":
+                print(f"Training {self.__class__.__name__} as an unconditional model.")
+                self.cond_stage_model = None
+                # self.be_unconditional = True
+            else:
+                model = instantiate_from_config(config)
+                self.cond_stage_model = model.eval()
+                self.cond_stage_model.train = disabled_train
+                for param in self.cond_stage_model.parameters():
+                    param.requires_grad = False
+        else:
+            assert config != "__is_first_stage__"
+            assert config != "__is_unconditional__"
+            model = instantiate_from_config(config)
+            self.cond_stage_model = model
+
+    def _get_denoise_row_from_list(
+        self, samples, desc="", force_no_decoder_quantization=False
+    ):
+        denoise_row = []
+        for zd in tqdm(samples, desc=desc):
+            denoise_row.append(
+                self.decode_first_stage(
+                    zd.to(self.device), force_not_quantize=force_no_decoder_quantization
+                )
+            )
+        n_imgs_per_row = len(denoise_row)
+        denoise_row = torch.stack(denoise_row)  # n_log_step, n_row, C, H, W
+        denoise_grid = rearrange(denoise_row, "n b c h w -> b n c h w")
+        denoise_grid = rearrange(denoise_grid, "b n c h w -> (b n) c h w")
+        denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)
+        return denoise_grid
+
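+    # Turn the encoder output into a latent z: sample from a diagonal Gaussian
+    # posterior (or use the tensor directly) and apply the scale factor.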
+    def get_first_stage_encoding(self, encoder_posterior):
+        if isinstance(encoder_posterior, DiagonalGaussianDistribution):
+            z = encoder_posterior.sample()
+        elif isinstance(encoder_posterior, torch.Tensor):
+            z = encoder_posterior
+        else:
+            raise NotImplementedError(
+                f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented"
+            )
+        return self.scale_factor * z
+
+    def get_learned_conditioning(self, c):
+        if self.cond_stage_forward is None:
+            if hasattr(self.cond_stage_model, "encode") and callable(
+                self.cond_stage_model.encode
+            ):
+                c = self.cond_stage_model.encode(c)
+                if isinstance(c, DiagonalGaussianDistribution):
+                    c = c.mode()
+            else:
+                c = self.cond_stage_model(c)
+        else:
+            assert hasattr(self.cond_stage_model, self.cond_stage_forward)
+            c = getattr(self.cond_stage_model, self.cond_stage_forward)(c)
+        return c
+
+    def meshgrid(self, h, w):
+        y = torch.arange(0, h).view(h, 1, 1).repeat(1, w, 1)
+        x = torch.arange(0, w).view(1, w, 1).repeat(h, 1, 1)
+
+        arr = torch.cat([y, x], dim=-1)
+        return arr
+
+    def delta_border(self, h, w):
+        """
+        :param h: height
+        :param w: width
+        :return: normalized distance to the image border,
+         with min distance = 0 at the border and max distance = 0.5 at the image center
+        """
+        lower_right_corner = torch.tensor([h - 1, w - 1]).view(1, 1, 2)
+        arr = self.meshgrid(h, w) / lower_right_corner
+        dist_left_up = torch.min(arr, dim=-1, keepdim=True)[0]
+        dist_right_down = torch.min(1 - arr, dim=-1, keepdim=True)[0]
+        edge_dist = torch.min(
+            torch.cat([dist_left_up, dist_right_down], dim=-1), dim=-1
+        )[0]
+        return edge_dist
+
+    def get_weighting(self, h, w, Ly, Lx, device):
+        weighting = self.delta_border(h, w)
+        weighting = torch.clip(
+            weighting,
+            self.split_input_params["clip_min_weight"],
+            self.split_input_params["clip_max_weight"],
+        )
+        weighting = weighting.view(1, h * w, 1).repeat(1, 1, Ly * Lx).to(device)
+
+        if self.split_input_params["tie_braker"]:
+            L_weighting = self.delta_border(Ly, Lx)
+            L_weighting = torch.clip(
+                L_weighting,
+                self.split_input_params["clip_min_tie_weight"],
+                self.split_input_params["clip_max_tie_weight"],
+            )
+
+            L_weighting = L_weighting.view(1, 1, Ly * Lx).to(device)
+            weighting = weighting * L_weighting
+        return weighting
+
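+    # Build unfold/fold operators for sliding-window (tiled) processing, plus a
+    # weighting map and its fold-based normalization for blending overlapping
+    # crops. uf/df are up-/down-scaling factors between the input and output grids.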
+    def get_fold_unfold(
+        self, x, kernel_size, stride, uf=1, df=1
+    ):  # todo load once not every time, shorten code
+        """
+        :param x: img of size (bs, c, h, w)
+        :return: n img crops of size (n, bs, c, kernel_size[0], kernel_size[1])
+        """
+        bs, nc, h, w = x.shape
+
+        # number of crops in image
+        Ly = (h - kernel_size[0]) // stride[0] + 1
+        Lx = (w - kernel_size[1]) // stride[1] + 1
+
+        if uf == 1 and df == 1:
+            fold_params = dict(
+                kernel_size=kernel_size, dilation=1, padding=0, stride=stride
+            )
+            unfold = torch.nn.Unfold(**fold_params)
+
+            fold = torch.nn.Fold(output_size=x.shape[2:], **fold_params)
+
+            weighting = self.get_weighting(
+                kernel_size[0], kernel_size[1], Ly, Lx, x.device
+            ).to(x.dtype)
+            normalization = fold(weighting).view(1, 1, h, w)  # normalizes the overlap
+            weighting = weighting.view((1, 1, kernel_size[0], kernel_size[1], Ly * Lx))
+
+        elif uf > 1 and df == 1:
+            fold_params = dict(
+                kernel_size=kernel_size, dilation=1, padding=0, stride=stride
+            )
+            unfold = torch.nn.Unfold(**fold_params)
+
+            fold_params2 = dict(
+                kernel_size=(kernel_size[0] * uf, kernel_size[0] * uf),
+                dilation=1,
+                padding=0,
+                stride=(stride[0] * uf, stride[1] * uf),
+            )
+            fold = torch.nn.Fold(
+                output_size=(x.shape[2] * uf, x.shape[3] * uf), **fold_params2
+            )
+
+            weighting = self.get_weighting(
+                kernel_size[0] * uf, kernel_size[1] * uf, Ly, Lx, x.device
+            ).to(x.dtype)
+            normalization = fold(weighting).view(
+                1, 1, h * uf, w * uf
+            )  # normalizes the overlap
+            weighting = weighting.view(
+                (1, 1, kernel_size[0] * uf, kernel_size[1] * uf, Ly * Lx)
+            )
+
+        elif df > 1 and uf == 1:
+            fold_params = dict(
+                kernel_size=kernel_size, dilation=1, padding=0, stride=stride
+            )
+            unfold = torch.nn.Unfold(**fold_params)
+
+            fold_params2 = dict(
+                kernel_size=(kernel_size[0] // df, kernel_size[0] // df),
+                dilation=1,
+                padding=0,
+                stride=(stride[0] // df, stride[1] // df),
+            )
+            fold = torch.nn.Fold(
+                output_size=(x.shape[2] // df, x.shape[3] // df), **fold_params2
+            )
+
+            weighting = self.get_weighting(
+                kernel_size[0] // df, kernel_size[1] // df, Ly, Lx, x.device
+            ).to(x.dtype)
+            normalization = fold(weighting).view(
+                1, 1, h // df, w // df
+            )  # normalizes the overlap
+            weighting = weighting.view(
+                (1, 1, kernel_size[0] // df, kernel_size[1] // df, Ly * Lx)
+            )
+
+        else:
+            raise NotImplementedError
+
+        return fold, unfold, normalization, weighting
+
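+    # Encode the input batch to latents z and assemble the conditioning c
+    # (text, class label, image, ...), optionally also returning the decoded
+    # reconstruction, the raw input, the original conditioning and an encoded mask.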
+    @torch.no_grad()
+    def get_input(
+        self,
+        batch,
+        k,
+        return_first_stage_outputs=False,
+        force_c_encode=False,
+        cond_key=None,
+        return_original_cond=False,
+        bs=None,
+        return_x=False,
+        mask_k=None,
+    ):
+        x = super().get_input(batch, k)
+        if bs is not None:
+            x = x[:bs]
+        x = x.to(self.device)
+        encoder_posterior = self.encode_first_stage(x)
+        z = self.get_first_stage_encoding(encoder_posterior).detach()
+
+        if mask_k is not None:
+            mx = super().get_input(batch, mask_k)
+            if bs is not None:
+                mx = mx[:bs]
+            mx = mx.to(self.device)
+            encoder_posterior = self.encode_first_stage(mx)
+            mx = self.get_first_stage_encoding(encoder_posterior).detach()
+
+        if self.model.conditioning_key is not None and not self.force_null_conditioning:
+            if cond_key is None:
+                cond_key = self.cond_stage_key
+            if cond_key != self.first_stage_key:
+                if cond_key in ["caption", "coordinates_bbox", "txt"]:
+                    xc = batch[cond_key]
+                elif cond_key in ["class_label", "cls"]:
+                    xc = batch
+                else:
+                    xc = super().get_input(batch, cond_key).to(self.device)
+            else:
+                xc = x
+            if not self.cond_stage_trainable or force_c_encode:
+                if isinstance(xc, dict) or isinstance(xc, list):
+                    c = self.get_learned_conditioning(xc)
+                else:
+                    c = self.get_learned_conditioning(xc.to(self.device))
+            else:
+                c = xc
+            if bs is not None:
+                c = c[:bs]
+
+            if self.use_positional_encodings:
+                pos_x, pos_y = self.compute_latent_shifts(batch)
+                ckey = __conditioning_keys__[self.model.conditioning_key]
+                c = {ckey: c, "pos_x": pos_x, "pos_y": pos_y}
+
+        else:
+            c = None
+            xc = None
+            if self.use_positional_encodings:
+                pos_x, pos_y = self.compute_latent_shifts(batch)
+                c = {"pos_x": pos_x, "pos_y": pos_y}
+        out = [z, c]
+        if return_first_stage_outputs:
+            xrec = self.decode_first_stage(z)
+            out.extend([x, xrec])
+        if return_x:
+            out.extend([x])
+        if return_original_cond:
+            out.append(xc)
+        if mask_k is not None:
+            out.append(mx)
+        return out
+
+    @torch.no_grad()
+    def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
+        if predict_cids:
+            if z.dim() == 4:
+                z = torch.argmax(z.exp(), dim=1).long()
+            z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
+            z = rearrange(z, "b h w c -> b c h w").contiguous()
+
+        z = 1.0 / self.scale_factor * z
+        return self.first_stage_model.decode(z)
+
+    def decode_first_stage_grad(self, z, predict_cids=False, force_not_quantize=False):
+        if predict_cids:
+            if z.dim() == 4:
+                z = torch.argmax(z.exp(), dim=1).long()
+            z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
+            z = rearrange(z, "b h w c -> b c h w").contiguous()
+
+        z = 1.0 / self.scale_factor * z
+        return self.first_stage_model.decode(z)
+
+    @torch.no_grad()
+    def encode_first_stage(self, x):
+        return self.first_stage_model.encode(x)
+
+    def shared_step(self, batch, **kwargs):
+        x, c = self.get_input(batch, self.first_stage_key)
+        loss = self(x, c)
+        return loss
+
+    def forward(self, x, c, *args, **kwargs):
+        t = torch.randint(
+            0, self.num_timesteps, (x.shape[0],), device=self.device
+        ).long()
+        # t = torch.randint(500, 501, (x.shape[0],), device=self.device).long()
+        if self.model.conditioning_key is not None:
+            assert c is not None
+            if self.cond_stage_trainable:
+                c = self.get_learned_conditioning(c)
+            if self.shorten_cond_schedule:  # TODO: drop this option
+                tc = self.cond_ids[t].to(self.device)
+                c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float()))
+        return self.p_losses(x, c, t, *args, **kwargs)
+
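+    # Normalize the conditioning into the {"c_concat": [...]} / {"c_crossattn": [...]}
+    # dict format expected by DiffusionWrapper and run the denoising model.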
+    def apply_model(self, x_noisy, t, cond, return_ids=False):
+        if isinstance(cond, dict):
+            # hybrid case, cond is expected to be a dict
+            pass
+        else:
+            if not isinstance(cond, list):
+                cond = [cond]
+            key = (
+                "c_concat" if self.model.conditioning_key == "concat" else "c_crossattn"
+            )
+            cond = {key: cond}
+
+        x_recon = self.model(x_noisy, t, **cond)
+
+        if isinstance(x_recon, tuple) and not return_ids:
+            return x_recon[0]
+        else:
+            return x_recon
+
+    def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
+        return (
+            extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
+            - pred_xstart
+        ) / extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
+
+    def _prior_bpd(self, x_start):
+        """
+        Get the prior KL term for the variational lower-bound, measured in
+        bits-per-dim.
+        This term can't be optimized, as it only depends on the encoder.
+        :param x_start: the [N x C x ...] tensor of inputs.
+        :return: a batch of [N] KL values (in bits), one per batch element.
+        """
+        batch_size = x_start.shape[0]
+        t = torch.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
+        qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
+        kl_prior = normal_kl(
+            mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0
+        )
+        return mean_flat(kl_prior) / np.log(2.0)
+
+    def p_mean_variance(
+        self,
+        x,
+        c,
+        t,
+        clip_denoised: bool,
+        return_codebook_ids=False,
+        quantize_denoised=False,
+        return_x0=False,
+        score_corrector=None,
+        corrector_kwargs=None,
+    ):
+        t_in = t
+        model_out = self.apply_model(x, t_in, c, return_ids=return_codebook_ids)
+
+        if score_corrector is not None:
+            assert self.parameterization == "eps"
+            model_out = score_corrector.modify_score(
+                self, model_out, x, t, c, **corrector_kwargs
+            )
+
+        if return_codebook_ids:
+            model_out, logits = model_out
+
+        if self.parameterization == "eps":
+            x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
+        elif self.parameterization == "x0":
+            x_recon = model_out
+        else:
+            raise NotImplementedError()
+
+        if clip_denoised:
+            x_recon.clamp_(-1.0, 1.0)
+        if quantize_denoised:
+            x_recon, _, [_, _, indices] = self.first_stage_model.quantize(x_recon)
+        model_mean, posterior_variance, posterior_log_variance = self.q_posterior(
+            x_start=x_recon, x_t=x, t=t
+        )
+        if return_codebook_ids:
+            return model_mean, posterior_variance, posterior_log_variance, logits
+        elif return_x0:
+            return model_mean, posterior_variance, posterior_log_variance, x_recon
+        else:
+            return model_mean, posterior_variance, posterior_log_variance
+
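+    # Single reverse-diffusion step: predict the posterior mean/variance and add
+    # noise scaled by temperature (no noise is added at t == 0).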
+    @torch.no_grad()
+    def p_sample(
+        self,
+        x,
+        c,
+        t,
+        clip_denoised=False,
+        repeat_noise=False,
+        return_codebook_ids=False,
+        quantize_denoised=False,
+        return_x0=False,
+        temperature=1.0,
+        noise_dropout=0.0,
+        score_corrector=None,
+        corrector_kwargs=None,
+    ):
+        b, *_, device = *x.shape, x.device
+        outputs = self.p_mean_variance(
+            x=x,
+            c=c,
+            t=t,
+            clip_denoised=clip_denoised,
+            return_codebook_ids=return_codebook_ids,
+            quantize_denoised=quantize_denoised,
+            return_x0=return_x0,
+            score_corrector=score_corrector,
+            corrector_kwargs=corrector_kwargs,
+        )
+        if return_codebook_ids:
+            raise DeprecationWarning("Support dropped.")
+            # unreachable: kept from the original implementation for reference
+            model_mean, _, model_log_variance, logits = outputs
+        elif return_x0:
+            model_mean, _, model_log_variance, x0 = outputs
+        else:
+            model_mean, _, model_log_variance = outputs
+
+        noise = noise_like(x.shape, device, repeat_noise) * temperature
+        if noise_dropout > 0.0:
+            noise = torch.nn.functional.dropout(noise, p=noise_dropout)
+        # no noise when t == 0
+        nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
+
+        if return_codebook_ids:
+            return (
+                model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise,
+                logits.argmax(dim=1),
+            )
+        if return_x0:
+            return (
+                model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise,
+                x0,
+            )
+        else:
+            return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
+
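+    # Full reverse process that also collects intermediate x0 predictions every
+    # log_every_t steps; supports masked (inpainting-style) generation via x0/mask.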
+    @torch.no_grad()
+    def progressive_denoising(
+        self,
+        cond,
+        shape,
+        verbose=True,
+        callback=None,
+        quantize_denoised=False,
+        img_callback=None,
+        mask=None,
+        x0=None,
+        temperature=1.0,
+        noise_dropout=0.0,
+        score_corrector=None,
+        corrector_kwargs=None,
+        batch_size=None,
+        x_T=None,
+        start_T=None,
+        log_every_t=None,
+    ):
+        if not log_every_t:
+            log_every_t = self.log_every_t
+        timesteps = self.num_timesteps
+        if batch_size is not None:
+            b = batch_size
+            shape = [batch_size] + list(shape)
+        else:
+            b = batch_size = shape[0]
+        if x_T is None:
+            img = torch.randn(shape, device=self.device)
+        else:
+            img = x_T
+        intermediates = []
+        if cond is not None:
+            if isinstance(cond, dict):
+                cond = {
+                    key: cond[key][:batch_size]
+                    if not isinstance(cond[key], list)
+                    else list(map(lambda x: x[:batch_size], cond[key]))
+                    for key in cond
+                }
+            else:
+                cond = (
+                    [c[:batch_size] for c in cond]
+                    if isinstance(cond, list)
+                    else cond[:batch_size]
+                )
+
+        if start_T is not None:
+            timesteps = min(timesteps, start_T)
+        iterator = (
+            tqdm(
+                reversed(range(0, timesteps)),
+                desc="Progressive Generation",
+                total=timesteps,
+            )
+            if verbose
+            else reversed(range(0, timesteps))
+        )
+        if isinstance(temperature, float):
+            temperature = [temperature] * timesteps
+
+        for i in iterator:
+            ts = torch.full((b,), i, device=self.device, dtype=torch.long)
+            if self.shorten_cond_schedule:
+                assert self.model.conditioning_key != "hybrid"
+                tc = self.cond_ids[ts].to(cond.device)
+                cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
+
+            img, x0_partial = self.p_sample(
+                img,
+                cond,
+                ts,
+                clip_denoised=self.clip_denoised,
+                quantize_denoised=quantize_denoised,
+                return_x0=True,
+                temperature=temperature[i],
+                noise_dropout=noise_dropout,
+                score_corrector=score_corrector,
+                corrector_kwargs=corrector_kwargs,
+            )
+            if mask is not None:
+                assert x0 is not None
+                img_orig = self.q_sample(x0, ts)
+                img = img_orig * mask + (1.0 - mask) * img
+
+            if i % log_every_t == 0 or i == timesteps - 1:
+                intermediates.append(x0_partial)
+            if callback:
+                callback(i)
+            if img_callback:
+                img_callback(img, i)
+        return img, intermediates
+
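+    # Standard ancestral sampling loop over all (or the first start_T) timesteps,
+    # with optional masking against a reference x0 for inpainting.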
+    @torch.no_grad()
+    def p_sample_loop(
+        self,
+        cond,
+        shape,
+        return_intermediates=False,
+        x_T=None,
+        verbose=True,
+        callback=None,
+        timesteps=None,
+        quantize_denoised=False,
+        mask=None,
+        x0=None,
+        img_callback=None,
+        start_T=None,
+        log_every_t=None,
+    ):
+        if not log_every_t:
+            log_every_t = self.log_every_t
+        device = self.betas.device
+        b = shape[0]
+        if x_T is None:
+            img = torch.randn(shape, device=device)
+        else:
+            img = x_T
+
+        intermediates = [img]
+        if timesteps is None:
+            timesteps = self.num_timesteps
+
+        if start_T is not None:
+            timesteps = min(timesteps, start_T)
+        iterator = (
+            tqdm(reversed(range(0, timesteps)), desc="Sampling t", total=timesteps)
+            if verbose
+            else reversed(range(0, timesteps))
+        )
+
+        if mask is not None:
+            assert x0 is not None
+            assert x0.shape[2:3] == mask.shape[2:3]  # spatial size has to match
+
+        for i in iterator:
+            ts = torch.full((b,), i, device=device, dtype=torch.long)
+            if self.shorten_cond_schedule:
+                assert self.model.conditioning_key != "hybrid"
+                tc = self.cond_ids[ts].to(cond.device)
+                cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
+
+            img = self.p_sample(
+                img,
+                cond,
+                ts,
+                clip_denoised=self.clip_denoised,
+                quantize_denoised=quantize_denoised,
+            )
+            if mask is not None:
+                img_orig = self.q_sample(x0, ts)
+                img = img_orig * mask + (1.0 - mask) * img
+
+            if i % log_every_t == 0 or i == timesteps - 1:
+                intermediates.append(img)
+            if callback:
+                callback(i)
+            if img_callback:
+                img_callback(img, i)
+
+        if return_intermediates:
+            return img, intermediates
+        return img
+
+    @torch.no_grad()
+    def sample(
+        self,
+        cond,
+        batch_size=16,
+        return_intermediates=False,
+        x_T=None,
+        verbose=True,
+        timesteps=None,
+        quantize_denoised=False,
+        mask=None,
+        x0=None,
+        shape=None,
+        **kwargs,
+    ):
+        if shape is None:
+            shape = (batch_size, self.channels, self.image_size, self.image_size)
+        if cond is not None:
+            if isinstance(cond, dict):
+                cond = {
+                    key: cond[key][:batch_size]
+                    if not isinstance(cond[key], list)
+                    else list(map(lambda x: x[:batch_size], cond[key]))
+                    for key in cond
+                }
+            else:
+                cond = (
+                    [c[:batch_size] for c in cond]
+                    if isinstance(cond, list)
+                    else cond[:batch_size]
+                )
+        return self.p_sample_loop(
+            cond,
+            shape,
+            return_intermediates=return_intermediates,
+            x_T=x_T,
+            verbose=verbose,
+            timesteps=timesteps,
+            quantize_denoised=quantize_denoised,
+            mask=mask,
+            x0=x0,
+        )
+
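+    # Sampling helper used for logging: run a DDIMSampler when ddim is True,
+    # otherwise fall back to the model's own ancestral sampling.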
+    @torch.no_grad()
+    def sample_log(self, cond, batch_size, ddim, ddim_steps, **kwargs):
+        if ddim:
+            ddim_sampler = DDIMSampler(self)
+            shape = (self.channels, self.image_size, self.image_size)
+            samples, intermediates = ddim_sampler.sample(
+                ddim_steps, batch_size, shape, cond, verbose=False, **kwargs
+            )
+
+        else:
+            samples, intermediates = self.sample(
+                cond=cond, batch_size=batch_size, return_intermediates=True, **kwargs
+            )
+
+        return samples, intermediates
+
+    @torch.no_grad()
+    def get_unconditional_conditioning(self, batch_size, null_label=None):
+        if null_label is not None:
+            xc = null_label
+            if isinstance(xc, ListConfig):
+                xc = list(xc)
+            if isinstance(xc, dict) or isinstance(xc, list):
+                c = self.get_learned_conditioning(xc)
+            else:
+                if hasattr(xc, "to"):
+                    xc = xc.to(self.device)
+                c = self.get_learned_conditioning(xc)
+        else:
+            if self.cond_stage_key in ["class_label", "cls"]:
+                xc = self.cond_stage_model.get_unconditional_conditioning(
+                    batch_size, device=self.device
+                )
+                return self.get_learned_conditioning(xc)
+            else:
+                raise NotImplementedError("todo")
+        if isinstance(c, list):  # in case the encoder gives us a list
+            for i in range(len(c)):
+                c[i] = repeat(c[i], "1 ... -> b ...", b=batch_size).to(self.device)
+        else:
+            c = repeat(c, "1 ... -> b ...", b=batch_size).to(self.device)
+        return c
+
+    @torch.no_grad()
+    def log_images(
+        self,
+        batch,
+        N=8,
+        n_row=4,
+        sample=True,
+        ddim_steps=50,
+        ddim_eta=0.0,
+        return_keys=None,
+        quantize_denoised=True,
+        inpaint=True,
+        plot_denoise_rows=False,
+        plot_progressive_rows=True,
+        plot_diffusion_rows=True,
+        unconditional_guidance_scale=1.0,
+        unconditional_guidance_label=None,
+        use_ema_scope=True,
+        **kwargs,
+    ):
+        ema_scope = self.ema_scope if use_ema_scope else nullcontext
+        use_ddim = ddim_steps is not None
+
+        log = dict()
+        z, c, x, xrec, xc = self.get_input(
+            batch,
+            self.first_stage_key,
+            return_first_stage_outputs=True,
+            force_c_encode=True,
+            return_original_cond=True,
+            bs=N,
+        )
+        N = min(x.shape[0], N)
+        n_row = min(x.shape[0], n_row)
+        log["inputs"] = x
+        log["reconstruction"] = xrec
+        if self.model.conditioning_key is not None:
+            if hasattr(self.cond_stage_model, "decode"):
+                xc = self.cond_stage_model.decode(c)
+                log["conditioning"] = xc
+            elif self.cond_stage_key in ["caption", "txt"]:
+                xc = log_txt_as_img(
+                    (x.shape[2], x.shape[3]),
+                    batch[self.cond_stage_key],
+                    size=x.shape[2] // 25,
+                )
+                log["conditioning"] = xc
+            elif self.cond_stage_key in ["class_label", "cls"]:
+                try:
+                    xc = log_txt_as_img(
+                        (x.shape[2], x.shape[3]),
+                        batch["human_label"],
+                        size=x.shape[2] // 25,
+                    )
+                    log["conditioning"] = xc
+                except KeyError:
+                    # probably no "human_label" in batch
+                    pass
+            elif isimage(xc):
+                log["conditioning"] = xc
+            if ismap(xc):
+                log["original_conditioning"] = self.to_rgb(xc)
+
+        if plot_diffusion_rows:
+            # get diffusion row
+            diffusion_row = list()
+            z_start = z[:n_row]
+            for t in range(self.num_timesteps):
+                if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
+                    t = repeat(torch.tensor([t]), "1 -> b", b=n_row)
+                    t = t.to(self.device).long()
+                    noise = torch.randn_like(z_start)
+                    z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
+                    diffusion_row.append(self.decode_first_stage(z_noisy))
+
+            diffusion_row = torch.stack(diffusion_row)  # n_log_step, n_row, C, H, W
+            diffusion_grid = rearrange(diffusion_row, "n b c h w -> b n c h w")
+            diffusion_grid = rearrange(diffusion_grid, "b n c h w -> (b n) c h w")
+            diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
+            log["diffusion_row"] = diffusion_grid
+
+        if sample:
+            # get denoise row
+            with ema_scope("Sampling"):
+                samples, z_denoise_row = self.sample_log(
+                    cond=c,
+                    batch_size=N,
+                    ddim=use_ddim,
+                    ddim_steps=ddim_steps,
+                    eta=ddim_eta,
+                )
+                # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
+            x_samples = self.decode_first_stage(samples)
+            log["samples"] = x_samples
+            if plot_denoise_rows:
+                denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
+                log["denoise_row"] = denoise_grid
+
+            if (
+                quantize_denoised
+                and not isinstance(self.first_stage_model, AutoencoderKL)
+                and not isinstance(self.first_stage_model, IdentityFirstStage)
+            ):
+                # also display when quantizing x0 while sampling
+                with ema_scope("Plotting Quantized Denoised"):
+                    samples, z_denoise_row = self.sample_log(
+                        cond=c,
+                        batch_size=N,
+                        ddim=use_ddim,
+                        ddim_steps=ddim_steps,
+                        eta=ddim_eta,
+                        quantize_denoised=True,
+                    )
+                    # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True,
+                    #                                      quantize_denoised=True)
+                x_samples = self.decode_first_stage(samples.to(self.device))
+                log["samples_x0_quantized"] = x_samples
+
+        if unconditional_guidance_scale > 1.0:
+            uc = self.get_unconditional_conditioning(N, unconditional_guidance_label)
+            if self.model.conditioning_key == "crossattn-adm":
+                uc = {"c_crossattn": [uc], "c_adm": c["c_adm"]}
+            with ema_scope("Sampling with classifier-free guidance"):
+                samples_cfg, _ = self.sample_log(
+                    cond=c,
+                    batch_size=N,
+                    ddim=use_ddim,
+                    ddim_steps=ddim_steps,
+                    eta=ddim_eta,
+                    unconditional_guidance_scale=unconditional_guidance_scale,
+                    unconditional_conditioning=uc,
+                )
+                x_samples_cfg = self.decode_first_stage(samples_cfg)
+                log[
+                    f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"
+                ] = x_samples_cfg
+
+        if inpaint:
+            # make a simple center square
+            b, h, w = z.shape[0], z.shape[2], z.shape[3]
+            mask = torch.ones(N, h, w).to(self.device)
+            # zeros will be filled in
+            mask[:, h // 4 : 3 * h // 4, w // 4 : 3 * w // 4] = 0.0
+            mask = mask[:, None, ...]
+            with ema_scope("Plotting Inpaint"):
+                samples, _ = self.sample_log(
+                    cond=c,
+                    batch_size=N,
+                    ddim=use_ddim,
+                    eta=ddim_eta,
+                    ddim_steps=ddim_steps,
+                    x0=z[:N],
+                    mask=mask,
+                )
+            x_samples = self.decode_first_stage(samples.to(self.device))
+            log["samples_inpainting"] = x_samples
+            log["mask"] = mask
+
+            # outpaint
+            mask = 1.0 - mask
+            with ema_scope("Plotting Outpaint"):
+                samples, _ = self.sample_log(
+                    cond=c,
+                    batch_size=N,
+                    ddim=use_ddim,
+                    eta=ddim_eta,
+                    ddim_steps=ddim_steps,
+                    x0=z[:N],
+                    mask=mask,
+                )
+            x_samples = self.decode_first_stage(samples.to(self.device))
+            log["samples_outpainting"] = x_samples
+
+        if plot_progressive_rows:
+            with ema_scope("Plotting Progressives"):
+                img, progressives = self.progressive_denoising(
+                    c,
+                    shape=(self.channels, self.image_size, self.image_size),
+                    batch_size=N,
+                )
+            prog_row = self._get_denoise_row_from_list(
+                progressives, desc="Progressive Generation"
+            )
+            log["progressive_row"] = prog_row
+
+        if return_keys:
+            if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
+                return log
+            else:
+                return {key: log[key] for key in return_keys}
+        return log
+
+    def configure_optimizers(self):
+        lr = self.learning_rate
+        params = list(self.model.parameters())
+        if self.cond_stage_trainable:
+            print(f"{self.__class__.__name__}: Also optimizing conditioner params!")
+            params = params + list(self.cond_stage_model.parameters())
+        if self.learn_logvar:
+            print("Diffusion model optimizing logvar")
+            params.append(self.logvar)
+        opt = torch.optim.AdamW(params, lr=lr)
+        if self.use_scheduler:
+            assert "target" in self.scheduler_config
+            scheduler = instantiate_from_config(self.scheduler_config)
+
+            print("Setting up LambdaLR scheduler...")
+            scheduler = [
+                {
+                    "scheduler": LambdaLR(opt, lr_lambda=scheduler.schedule),
+                    "interval": "step",
+                    "frequency": 1,
+                }
+            ]
+            return [opt], scheduler
+        return opt
+
+    @torch.no_grad()
+    def to_rgb(self, x):
+        x = x.float()
+        if not hasattr(self, "colorize"):
+            self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x)
+        x = nn.functional.conv2d(x, weight=self.colorize)
+        x = 2.0 * (x - x.min()) / (x.max() - x.min()) - 1.0
+        return x
+
+
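+# Thin wrapper around the denoising UNet that routes the conditioning tensors
+# (concat channels, cross-attention context, adm/class embedding) to the model
+# according to conditioning_key.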
+class DiffusionWrapper(torch.nn.Module):
+    def __init__(self, diff_model_config, conditioning_key):
+        super().__init__()
+        self.sequential_cross_attn = diff_model_config.pop(
+            "sequential_crossattn", False
+        )
+        self.diffusion_model = instantiate_from_config(diff_model_config)
+        self.conditioning_key = conditioning_key
+        assert self.conditioning_key in [
+            None,
+            "concat",
+            "crossattn",
+            "hybrid",
+            "adm",
+            "hybrid-adm",
+            "crossattn-adm",
+        ]
+
+    def forward(
+        self, x, t, c_concat: list = None, c_crossattn: list = None, c_adm=None
+    ):
+        if self.conditioning_key is None:
+            out = self.diffusion_model(x, t)
+        elif self.conditioning_key == "concat":
+            xc = torch.cat([x] + c_concat, dim=1)
+            out = self.diffusion_model(xc, t)
+        elif self.conditioning_key == "crossattn":
+            if not self.sequential_cross_attn:
+                cc = torch.cat(c_crossattn, 1)
+            else:
+                cc = c_crossattn
+            out = self.diffusion_model(x, t, context=cc)
+        elif self.conditioning_key == "hybrid":
+            xc = torch.cat([x] + c_concat, dim=1)
+            cc = torch.cat(c_crossattn, 1)
+            out = self.diffusion_model(xc, t, context=cc)
+        elif self.conditioning_key == "hybrid-adm":
+            assert c_adm is not None
+            xc = torch.cat([x] + c_concat, dim=1)
+            cc = torch.cat(c_crossattn, 1)
+            out = self.diffusion_model(xc, t, context=cc, y=c_adm)
+        elif self.conditioning_key == "crossattn-adm":
+            assert c_adm is not None
+            cc = torch.cat(c_crossattn, 1)
+            out = self.diffusion_model(x, t, context=cc, y=c_adm)
+        elif self.conditioning_key == "adm":
+            cc = c_crossattn[0]
+            out = self.diffusion_model(x, t, y=cc)
+        else:
+            raise NotImplementedError()
+
+        return out
+
+
+class LatentUpscaleDiffusion(LatentDiffusion):
+    def __init__(
+        self,
+        *args,
+        low_scale_config,
+        low_scale_key="LR",
+        noise_level_key=None,
+        **kwargs,
+    ):
+        super().__init__(*args, **kwargs)
+        # assumes that neither the cond_stage nor the low_scale_model contain trainable params
+        assert not self.cond_stage_trainable
+        self.instantiate_low_stage(low_scale_config)
+        self.low_scale_key = low_scale_key
+        self.noise_level_key = noise_level_key
+
+    def instantiate_low_stage(self, config):
+        model = instantiate_from_config(config)
+        self.low_scale_model = model.eval()
+        self.low_scale_model.train = disabled_train
+        for param in self.low_scale_model.parameters():
+            param.requires_grad = False
+
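+    # Build the conditioning from a low-resolution image: the low-scale model
+    # returns its (possibly noise-augmented) encoding and a noise level, passed
+    # on as c_concat and c_adm alongside the usual cross-attention conditioning.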
+    @torch.no_grad()
+    def get_input(self, batch, k, cond_key=None, bs=None, log_mode=False):
+        if not log_mode:
+            z, c = super().get_input(batch, k, force_c_encode=True, bs=bs)
+        else:
+            z, c, x, xrec, xc = super().get_input(
+                batch,
+                self.first_stage_key,
+                return_first_stage_outputs=True,
+                force_c_encode=True,
+                return_original_cond=True,
+                bs=bs,
+            )
+        x_low = batch[self.low_scale_key][:bs]
+        x_low = rearrange(x_low, "b h w c -> b c h w")
+        x_low = x_low.to(memory_format=torch.contiguous_format).float()
+        zx, noise_level = self.low_scale_model(x_low)
+        if self.noise_level_key is not None:
+            # get noise level from batch instead, e.g. when extracting a custom noise level for bsr
+            raise NotImplementedError("TODO")
+
+        all_conds = {"c_concat": [zx], "c_crossattn": [c], "c_adm": noise_level}
+        if log_mode:
+            # TODO: maybe disable if too expensive
+            x_low_rec = self.low_scale_model.decode(zx)
+            return z, all_conds, x, xrec, xc, x_low, x_low_rec, noise_level
+        return z, all_conds
+
+    @torch.no_grad()
+    def log_images(
+        self,
+        batch,
+        N=8,
+        n_row=4,
+        sample=True,
+        ddim_steps=200,
+        ddim_eta=1.0,
+        return_keys=None,
+        plot_denoise_rows=False,
+        plot_progressive_rows=True,
+        plot_diffusion_rows=True,
+        unconditional_guidance_scale=1.0,
+        unconditional_guidance_label=None,
+        use_ema_scope=True,
+        **kwargs,
+    ):
+        ema_scope = self.ema_scope if use_ema_scope else nullcontext
+        use_ddim = ddim_steps is not None
+
+        log = dict()
+        z, c, x, xrec, xc, x_low, x_low_rec, noise_level = self.get_input(
+            batch, self.first_stage_key, bs=N, log_mode=True
+        )
+        N = min(x.shape[0], N)
+        n_row = min(x.shape[0], n_row)
+        log["inputs"] = x
+        log["reconstruction"] = xrec
+        log["x_lr"] = x_low
+        log[
+            f"x_lr_rec_@noise_levels{'-'.join(str(x) for x in noise_level.cpu().numpy())}"
+        ] = x_low_rec
+        if self.model.conditioning_key is not None:
+            if hasattr(self.cond_stage_model, "decode"):
+                xc = self.cond_stage_model.decode(c)
+                log["conditioning"] = xc
+            elif self.cond_stage_key in ["caption", "txt"]:
+                xc = log_txt_as_img(
+                    (x.shape[2], x.shape[3]),
+                    batch[self.cond_stage_key],
+                    size=x.shape[2] // 25,
+                )
+                log["conditioning"] = xc
+            elif self.cond_stage_key in ["class_label", "cls"]:
+                xc = log_txt_as_img(
+                    (x.shape[2], x.shape[3]),
+                    batch["human_label"],
+                    size=x.shape[2] // 25,
+                )
+                log["conditioning"] = xc
+            elif isimage(xc):
+                log["conditioning"] = xc
+            if ismap(xc):
+                log["original_conditioning"] = self.to_rgb(xc)
+
+        if plot_diffusion_rows:
+            # get diffusion row
+            diffusion_row = list()
+            z_start = z[:n_row]
+            for t in range(self.num_timesteps):
+                if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
+                    t = repeat(torch.tensor([t]), "1 -> b", b=n_row)
+                    t = t.to(self.device).long()
+                    noise = torch.randn_like(z_start)
+                    z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
+                    diffusion_row.append(self.decode_first_stage(z_noisy))
+
+            diffusion_row = torch.stack(diffusion_row)  # n_log_step, n_row, C, H, W
+            diffusion_grid = rearrange(diffusion_row, "n b c h w -> b n c h w")
+            diffusion_grid = rearrange(diffusion_grid, "b n c h w -> (b n) c h w")
+            diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
+            log["diffusion_row"] = diffusion_grid
+
+        if sample:
+            # get denoise row
+            with ema_scope("Sampling"):
+                samples, z_denoise_row = self.sample_log(
+                    cond=c,
+                    batch_size=N,
+                    ddim=use_ddim,
+                    ddim_steps=ddim_steps,
+                    eta=ddim_eta,
+                )
+                # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
+            x_samples = self.decode_first_stage(samples)
+            log["samples"] = x_samples
+            if plot_denoise_rows:
+                denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
+                log["denoise_row"] = denoise_grid
+
+        if unconditional_guidance_scale > 1.0:
+            uc_tmp = self.get_unconditional_conditioning(
+                N, unconditional_guidance_label
+            )
+            # TODO explore better "unconditional" choices for the other keys
+            # maybe guide away from empty text label and highest noise level and maximally degraded zx?
+            uc = dict()
+            for k in c:
+                if k == "c_crossattn":
+                    assert isinstance(c[k], list) and len(c[k]) == 1
+                    uc[k] = [uc_tmp]
+                elif k == "c_adm":  # todo: only run with text-based guidance?
+                    assert isinstance(c[k], torch.Tensor)
+                    # uc[k] = torch.ones_like(c[k]) * self.low_scale_model.max_noise_level
+                    uc[k] = c[k]
+                elif isinstance(c[k], list):
+                    uc[k] = [c[k][i] for i in range(len(c[k]))]
+                else:
+                    uc[k] = c[k]
+
+            with ema_scope("Sampling with classifier-free guidance"):
+                samples_cfg, _ = self.sample_log(
+                    cond=c,
+                    batch_size=N,
+                    ddim=use_ddim,
+                    ddim_steps=ddim_steps,
+                    eta=ddim_eta,
+                    unconditional_guidance_scale=unconditional_guidance_scale,
+                    unconditional_conditioning=uc,
+                )
+                x_samples_cfg = self.decode_first_stage(samples_cfg)
+                log[
+                    f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"
+                ] = x_samples_cfg
+
+        if plot_progressive_rows:
+            with ema_scope("Plotting Progressives"):
+                img, progressives = self.progressive_denoising(
+                    c,
+                    shape=(self.channels, self.image_size, self.image_size),
+                    batch_size=N,
+                )
+            prog_row = self._get_denoise_row_from_list(
+                progressives, desc="Progressive Generation"
+            )
+            log["progressive_row"] = prog_row
+
+        return log
+
+
+class LatentFinetuneDiffusion(LatentDiffusion):
+    """
+    Basis for different finetuning variants, such as inpainting or depth2image.
+    To disable finetuning mode, set finetune_keys to None
+    """
+
+    def __init__(
+        self,
+        concat_keys: tuple,
+        finetune_keys=(
+            "model.diffusion_model.input_blocks.0.0.weight",
+            "model_ema.diffusion_modelinput_blocks00weight",
+        ),
+        keep_finetune_dims=4,
+        # if model was trained without concat mode before and we would like to keep these channels
+        c_concat_log_start=None,  # to log reconstruction of c_concat codes
+        c_concat_log_end=None,
+        *args,
+        **kwargs,
+    ):
+        ckpt_path = kwargs.pop("ckpt_path", None)
+        ignore_keys = kwargs.pop("ignore_keys", list())
+        super().__init__(*args, **kwargs)
+        self.finetune_keys = finetune_keys
+        self.concat_keys = concat_keys
+        self.keep_dims = keep_finetune_dims
+        self.c_concat_log_start = c_concat_log_start
+        self.c_concat_log_end = c_concat_log_end
+        if exists(self.finetune_keys):
+            assert exists(ckpt_path), "can only finetune from a given checkpoint"
+        if exists(ckpt_path):
+            self.init_from_ckpt(ckpt_path, ignore_keys)
+
+    def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
+        sd = torch.load(path, map_location="cpu")
+        if "state_dict" in list(sd.keys()):
+            sd = sd["state_dict"]
+        keys = list(sd.keys())
+        for k in keys:
+            for ik in ignore_keys:
+                if k.startswith(ik):
+                    print("Deleting key {} from state_dict.".format(k))
+                    del sd[k]
+
+            # make it explicit, finetune by including extra input channels
+            if exists(self.finetune_keys) and k in self.finetune_keys:
+                new_entry = None
+                for name, param in self.named_parameters():
+                    if name in self.finetune_keys:
+                        print(
+                            f"modifying key '{name}' and keeping its original {self.keep_dims} (channels) dimensions only"
+                        )
+                        new_entry = torch.zeros_like(param)  # zero init
+                assert exists(new_entry), "did not find matching parameter to modify"
+                new_entry[:, : self.keep_dims, ...] = sd[k]
+                sd[k] = new_entry
+
+        missing, unexpected = (
+            self.load_state_dict(sd, strict=False)
+            if not only_model
+            else self.model.load_state_dict(sd, strict=False)
+        )
+        print(
+            f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys"
+        )
+        if len(missing) > 0:
+            print(f"Missing Keys: {missing}")
+        if len(unexpected) > 0:
+            print(f"Unexpected Keys: {unexpected}")
+
+    @torch.no_grad()
+    def log_images(
+        self,
+        batch,
+        N=8,
+        n_row=4,
+        sample=True,
+        ddim_steps=200,
+        ddim_eta=1.0,
+        return_keys=None,
+        quantize_denoised=True,
+        inpaint=True,
+        plot_denoise_rows=False,
+        plot_progressive_rows=True,
+        plot_diffusion_rows=True,
+        unconditional_guidance_scale=1.0,
+        unconditional_guidance_label=None,
+        use_ema_scope=True,
+        **kwargs,
+    ):
+        ema_scope = self.ema_scope if use_ema_scope else nullcontext
+        use_ddim = ddim_steps is not None
+
+        log = dict()
+        z, c, x, xrec, xc = self.get_input(
+            batch, self.first_stage_key, bs=N, return_first_stage_outputs=True
+        )
+        c_cat, c = c["c_concat"][0], c["c_crossattn"][0]
+        N = min(x.shape[0], N)
+        n_row = min(x.shape[0], n_row)
+        log["inputs"] = x
+        log["reconstruction"] = xrec
+        if self.model.conditioning_key is not None:
+            if hasattr(self.cond_stage_model, "decode"):
+                xc = self.cond_stage_model.decode(c)
+                log["conditioning"] = xc
+            elif self.cond_stage_key in ["caption", "txt"]:
+                xc = log_txt_as_img(
+                    (x.shape[2], x.shape[3]),
+                    batch[self.cond_stage_key],
+                    size=x.shape[2] // 25,
+                )
+                log["conditioning"] = xc
+            elif self.cond_stage_key in ["class_label", "cls"]:
+                xc = log_txt_as_img(
+                    (x.shape[2], x.shape[3]),
+                    batch["human_label"],
+                    size=x.shape[2] // 25,
+                )
+                log["conditioning"] = xc
+            elif isimage(xc):
+                log["conditioning"] = xc
+            if ismap(xc):
+                log["original_conditioning"] = self.to_rgb(xc)
+
+        if not (self.c_concat_log_start is None and self.c_concat_log_end is None):
+            log["c_concat_decoded"] = self.decode_first_stage(
+                c_cat[:, self.c_concat_log_start : self.c_concat_log_end]
+            )
+
+        if plot_diffusion_rows:
+            # get diffusion row
+            diffusion_row = list()
+            z_start = z[:n_row]
+            for t in range(self.num_timesteps):
+                if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
+                    t = repeat(torch.tensor([t]), "1 -> b", b=n_row)
+                    t = t.to(self.device).long()
+                    noise = torch.randn_like(z_start)
+                    z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
+                    diffusion_row.append(self.decode_first_stage(z_noisy))
+
+            diffusion_row = torch.stack(diffusion_row)  # n_log_step, n_row, C, H, W
+            diffusion_grid = rearrange(diffusion_row, "n b c h w -> b n c h w")
+            diffusion_grid = rearrange(diffusion_grid, "b n c h w -> (b n) c h w")
+            diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
+            log["diffusion_row"] = diffusion_grid
+
+        if sample:
+            # get denoise row
+            with ema_scope("Sampling"):
+                samples, z_denoise_row = self.sample_log(
+                    cond={"c_concat": [c_cat], "c_crossattn": [c]},
+                    batch_size=N,
+                    ddim=use_ddim,
+                    ddim_steps=ddim_steps,
+                    eta=ddim_eta,
+                )
+                # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
+            x_samples = self.decode_first_stage(samples)
+            log["samples"] = x_samples
+            if plot_denoise_rows:
+                denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
+                log["denoise_row"] = denoise_grid
+
+        if unconditional_guidance_scale > 1.0:
+            uc_cross = self.get_unconditional_conditioning(
+                N, unconditional_guidance_label
+            )
+            uc_cat = c_cat
+            uc_full = {"c_concat": [uc_cat], "c_crossattn": [uc_cross]}
+            with ema_scope("Sampling with classifier-free guidance"):
+                samples_cfg, _ = self.sample_log(
+                    cond={"c_concat": [c_cat], "c_crossattn": [c]},
+                    batch_size=N,
+                    ddim=use_ddim,
+                    ddim_steps=ddim_steps,
+                    eta=ddim_eta,
+                    unconditional_guidance_scale=unconditional_guidance_scale,
+                    unconditional_conditioning=uc_full,
+                )
+                x_samples_cfg = self.decode_first_stage(samples_cfg)
+                log[
+                    f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"
+                ] = x_samples_cfg
+
+        return log
+
+
+class LatentInpaintDiffusion(LatentFinetuneDiffusion):
+    """
+    can either run as pure inpainting model (only concat mode) or with mixed conditionings,
+    e.g. mask as concat and text via cross-attn.
+    To disable finetuning mode, set finetune_keys to None
+    """
+
+    def __init__(
+        self,
+        concat_keys=("mask", "masked_image"),
+        masked_image_key="masked_image",
+        *args,
+        **kwargs,
+    ):
+        super().__init__(concat_keys, *args, **kwargs)
+        self.masked_image_key = masked_image_key
+        assert self.masked_image_key in concat_keys
+
+    @torch.no_grad()
+    def get_input(
+        self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False
+    ):
+        # note: restricted to non-trainable encoders currently
+        assert (
+            not self.cond_stage_trainable
+        ), "trainable cond stages not yet supported for inpainting"
+        z, c, x, xrec, xc = super().get_input(
+            batch,
+            self.first_stage_key,
+            return_first_stage_outputs=True,
+            force_c_encode=True,
+            return_original_cond=True,
+            bs=bs,
+        )
+
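+        # Build the concat conditioning: the mask is resized to the latent
+        # resolution, while the masked image is encoded through the first stage.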
+        assert exists(self.concat_keys)
+        c_cat = list()
+        for ck in self.concat_keys:
+            cc = (
+                rearrange(batch[ck], "b h w c -> b c h w")
+                .to(memory_format=torch.contiguous_format)
+                .float()
+            )
+            if bs is not None:
+                cc = cc[:bs]
+                cc = cc.to(self.device)
+            bchw = z.shape
+            if ck != self.masked_image_key:
+                cc = torch.nn.functional.interpolate(cc, size=bchw[-2:])
+            else:
+                cc = self.get_first_stage_encoding(self.encode_first_stage(cc))
+            c_cat.append(cc)
+        c_cat = torch.cat(c_cat, dim=1)
+        all_conds = {"c_concat": [c_cat], "c_crossattn": [c]}
+        if return_first_stage_outputs:
+            return z, all_conds, x, xrec, xc
+        return z, all_conds
+
+    @torch.no_grad()
+    def log_images(self, *args, **kwargs):
+        log = super(LatentInpaintDiffusion, self).log_images(*args, **kwargs)
+        log["masked_image"] = (
+            rearrange(args[0]["masked_image"], "b h w c -> b c h w")
+            .to(memory_format=torch.contiguous_format)
+            .float()
+        )
+        return log
+
+
+class LatentDepth2ImageDiffusion(LatentFinetuneDiffusion):
+    """
+    condition on monocular depth estimation
+    """
+
+    def __init__(self, depth_stage_config, concat_keys=("midas_in",), *args, **kwargs):
+        super().__init__(concat_keys=concat_keys, *args, **kwargs)
+        self.depth_model = instantiate_from_config(depth_stage_config)
+        self.depth_stage_key = concat_keys[0]
+
+    @torch.no_grad()
+    def get_input(
+        self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False
+    ):
+        # note: restricted to non-trainable encoders currently
+        assert (
+            not self.cond_stage_trainable
+        ), "trainable cond stages not yet supported for depth2img"
+        z, c, x, xrec, xc = super().get_input(
+            batch,
+            self.first_stage_key,
+            return_first_stage_outputs=True,
+            force_c_encode=True,
+            return_original_cond=True,
+            bs=bs,
+        )
+
+        assert exists(self.concat_keys)
+        assert len(self.concat_keys) == 1
+        c_cat = list()
+        for ck in self.concat_keys:
+            cc = batch[ck]
+            if bs is not None:
+                cc = cc[:bs]
+                cc = cc.to(self.device)
+            cc = self.depth_model(cc)
+            cc = torch.nn.functional.interpolate(
+                cc,
+                size=z.shape[2:],
+                mode="bicubic",
+                align_corners=False,
+            )
+
+            depth_min, depth_max = (
+                torch.amin(cc, dim=[1, 2, 3], keepdim=True),
+                torch.amax(cc, dim=[1, 2, 3], keepdim=True),
+            )
+            cc = 2.0 * (cc - depth_min) / (depth_max - depth_min + 0.001) - 1.0
+            c_cat.append(cc)
+        c_cat = torch.cat(c_cat, dim=1)
+        all_conds = {"c_concat": [c_cat], "c_crossattn": [c]}
+        if return_first_stage_outputs:
+            return z, all_conds, x, xrec, xc
+        return z, all_conds
+
+    @torch.no_grad()
+    def log_images(self, *args, **kwargs):
+        log = super().log_images(*args, **kwargs)
+        depth = self.depth_model(args[0][self.depth_stage_key])
+        depth_min, depth_max = (
+            torch.amin(depth, dim=[1, 2, 3], keepdim=True),
+            torch.amax(depth, dim=[1, 2, 3], keepdim=True),
+        )
+        log["depth"] = 2.0 * (depth - depth_min) / (depth_max - depth_min) - 1.0
+        return log
+
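+# Minimal sketch, not part of the upstream file: the per-sample depth normalisation
+# used in get_input above, factored out for clarity. It maps a (B, 1, H, W) depth map
+# roughly into [-1, 1]; the small eps matches the 0.001 guard in get_input
+# (log_images omits it). Reuses this module's existing torch import.
+def _normalize_depth_sketch(depth, eps=0.001):
+    depth_min = torch.amin(depth, dim=[1, 2, 3], keepdim=True)
+    depth_max = torch.amax(depth, dim=[1, 2, 3], keepdim=True)
+    return 2.0 * (depth - depth_min) / (depth_max - depth_min + eps) - 1.0
+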
+
+class LatentUpscaleFinetuneDiffusion(LatentFinetuneDiffusion):
+    """
+    condition on low-res image (and optionally on some spatial noise augmentation)
+    """
+
+    def __init__(
+        self,
+        concat_keys=("lr",),
+        reshuffle_patch_size=None,
+        low_scale_config=None,
+        low_scale_key=None,
+        *args,
+        **kwargs,
+    ):
+        super().__init__(concat_keys=concat_keys, *args, **kwargs)
+        self.reshuffle_patch_size = reshuffle_patch_size
+        self.low_scale_model = None
+        if low_scale_config is not None:
+            print("Initializing a low-scale model")
+            assert exists(low_scale_key)
+            self.instantiate_low_stage(low_scale_config)
+            self.low_scale_key = low_scale_key
+
+    def instantiate_low_stage(self, config):
+        model = instantiate_from_config(config)
+        self.low_scale_model = model.eval()
+        self.low_scale_model.train = disabled_train
+        for param in self.low_scale_model.parameters():
+            param.requires_grad = False
+
+    @torch.no_grad()
+    def get_input(
+        self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False
+    ):
+        # note: restricted to non-trainable encoders currently
+        assert (
+            not self.cond_stage_trainable
+        ), "trainable cond stages not yet supported for upscaling-ft"
+        z, c, x, xrec, xc = super().get_input(
+            batch,
+            self.first_stage_key,
+            return_first_stage_outputs=True,
+            force_c_encode=True,
+            return_original_cond=True,
+            bs=bs,
+        )
+
+        assert exists(self.concat_keys)
+        assert len(self.concat_keys) == 1
+        # optionally make spatial noise_level here
+        c_cat = list()
+        noise_level = None
+        for ck in self.concat_keys:
+            cc = batch[ck]
+            cc = rearrange(cc, "b h w c -> b c h w")
+            if exists(self.reshuffle_patch_size):
+                assert isinstance(self.reshuffle_patch_size, int)
+                cc = rearrange(
+                    cc,
+                    "b c (p1 h) (p2 w) -> b (p1 p2 c) h w",
+                    p1=self.reshuffle_patch_size,
+                    p2=self.reshuffle_patch_size,
+                )
+            if bs is not None:
+                cc = cc[:bs]
+                cc = cc.to(self.device)
+            if exists(self.low_scale_model) and ck == self.low_scale_key:
+                cc, noise_level = self.low_scale_model(cc)
+            c_cat.append(cc)
+        c_cat = torch.cat(c_cat, dim=1)
+        if exists(noise_level):
+            all_conds = {"c_concat": [c_cat], "c_crossattn": [c], "c_adm": noise_level}
+        else:
+            all_conds = {"c_concat": [c_cat], "c_crossattn": [c]}
+        if return_first_stage_outputs:
+            return z, all_conds, x, xrec, xc
+        return z, all_conds
+
+    @torch.no_grad()
+    def log_images(self, *args, **kwargs):
+        log = super().log_images(*args, **kwargs)
+        log["lr"] = rearrange(args[0]["lr"], "b h w c -> b c h w")
+        return log
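+
+# Hedged note (comment only): when low_scale_config is set, the conditioning gains a
+# third entry, e.g. {"c_concat": [lr], "c_crossattn": [text], "c_adm": noise_level},
+# where "c_adm" carries the noise-augmentation level returned by the low-scale model;
+# without it, only "c_concat" and "c_crossattn" are present.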

+ 1 - 0
sorawm/iopaint/model/anytext/ldm/models/diffusion/dpm_solver/__init__.py

@@ -0,0 +1 @@
+from .sampler import DPMSolverSampler

+ 1464 - 0
sorawm/iopaint/model/anytext/ldm/models/diffusion/dpm_solver/dpm_solver.py

@@ -0,0 +1,1464 @@
+import math
+
+import torch
+import torch.nn.functional as F
+from tqdm import tqdm
+
+
+class NoiseScheduleVP:
+    def __init__(
+        self,
+        schedule="discrete",
+        betas=None,
+        alphas_cumprod=None,
+        continuous_beta_0=0.1,
+        continuous_beta_1=20.0,
+    ):
+        """Create a wrapper class for the forward SDE (VP type).
+        ***
+        Update: We support discrete-time diffusion models by implementing a piecewise linear interpolation for log_alpha_t.
+                We recommend using schedule='discrete' for discrete-time diffusion models, especially for high-resolution images.
+        ***
+        The forward SDE ensures that the conditional distribution q_{t|0}(x_t | x_0) = N ( alpha_t * x_0, sigma_t^2 * I ).
+        We further define lambda_t = log(alpha_t) - log(sigma_t), which is the half-logSNR (described in the DPM-Solver paper).
+        Therefore, we implement the functions for computing alpha_t, sigma_t and lambda_t. For t in [0, T], we have:
+            log_alpha_t = self.marginal_log_mean_coeff(t)
+            sigma_t = self.marginal_std(t)
+            lambda_t = self.marginal_lambda(t)
+        Moreover, as lambda(t) is an invertible function, we also support its inverse function:
+            t = self.inverse_lambda(lambda_t)
+        ===============================================================
+        We support both discrete-time DPMs (trained on n = 0, 1, ..., N-1) and continuous-time DPMs (trained on t in [t_0, T]).
+        1. For discrete-time DPMs:
+            For discrete-time DPMs trained on n = 0, 1, ..., N-1, we convert the discrete steps to continuous time steps by:
+                t_i = (i + 1) / N
+            e.g. for N = 1000, we have t_0 = 1e-3 and T = t_{N-1} = 1.
+            We solve the corresponding diffusion ODE from time T = 1 to time t_0 = 1e-3.
+            Args:
+                betas: A `torch.Tensor`. The beta array for the discrete-time DPM. (See the original DDPM paper for details)
+                alphas_cumprod: A `torch.Tensor`. The cumprod alphas for the discrete-time DPM. (See the original DDPM paper for details)
+            Note that we always have alphas_cumprod = cumprod(1 - betas). Therefore, we only need to set one of `betas` and `alphas_cumprod`.
+            **Important**: Please pay special attention to the argument `alphas_cumprod`:
+                The `alphas_cumprod` is the \hat{alpha_n} arrays in the notations of DDPM. Specifically, DDPMs assume that
+                    q_{t_n | 0}(x_{t_n} | x_0) = N ( \sqrt{\hat{alpha_n}} * x_0, (1 - \hat{alpha_n}) * I ).
+                Therefore, the notation \hat{alpha_n} is different from the notation alpha_t in DPM-Solver. In fact, we have
+                    alpha_{t_n} = \sqrt{\hat{alpha_n}},
+                and
+                    log(alpha_{t_n}) = 0.5 * log(\hat{alpha_n}).
+        2. For continuous-time DPMs:
+            We support two types of VPSDEs: linear (DDPM) and cosine (improved-DDPM). The hyperparameters for the noise
+            schedule are the default settings in DDPM and improved-DDPM:
+            Args:
+                beta_min: A `float` number. The smallest beta for the linear schedule.
+                beta_max: A `float` number. The largest beta for the linear schedule.
+                cosine_s: A `float` number. The hyperparameter in the cosine schedule.
+                cosine_beta_max: A `float` number. The hyperparameter in the cosine schedule.
+                T: A `float` number. The ending time of the forward process.
+        ===============================================================
+        Args:
+            schedule: A `str`. The noise schedule of the forward SDE. 'discrete' for discrete-time DPMs,
+                    'linear' or 'cosine' for continuous-time DPMs.
+        Returns:
+            A wrapper object of the forward SDE (VP type).
+
+        ===============================================================
+        Example:
+        # For discrete-time DPMs, given betas (the beta array for n = 0, 1, ..., N - 1):
+        >>> ns = NoiseScheduleVP('discrete', betas=betas)
+        # For discrete-time DPMs, given alphas_cumprod (the \hat{alpha_n} array for n = 0, 1, ..., N - 1):
+        >>> ns = NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod)
+        # For continuous-time DPMs (VPSDE), linear schedule:
+        >>> ns = NoiseScheduleVP('linear', continuous_beta_0=0.1, continuous_beta_1=20.)
+        """
+
+        if schedule not in ["discrete", "linear", "cosine"]:
+            raise ValueError(
+                "Unsupported noise schedule {}. The schedule needs to be 'discrete' or 'linear' or 'cosine'".format(
+                    schedule
+                )
+            )
+
+        self.schedule = schedule
+        if schedule == "discrete":
+            if betas is not None:
+                log_alphas = 0.5 * torch.log(1 - betas).cumsum(dim=0)
+            else:
+                assert alphas_cumprod is not None
+                log_alphas = 0.5 * torch.log(alphas_cumprod)
+            self.total_N = len(log_alphas)
+            self.T = 1.0
+            self.t_array = torch.linspace(0.0, 1.0, self.total_N + 1)[1:].reshape(
+                (1, -1)
+            )
+            self.log_alpha_array = log_alphas.reshape(
+                (
+                    1,
+                    -1,
+                )
+            )
+        else:
+            self.total_N = 1000
+            self.beta_0 = continuous_beta_0
+            self.beta_1 = continuous_beta_1
+            self.cosine_s = 0.008
+            self.cosine_beta_max = 999.0
+            self.cosine_t_max = (
+                math.atan(self.cosine_beta_max * (1.0 + self.cosine_s) / math.pi)
+                * 2.0
+                * (1.0 + self.cosine_s)
+                / math.pi
+                - self.cosine_s
+            )
+            self.cosine_log_alpha_0 = math.log(
+                math.cos(self.cosine_s / (1.0 + self.cosine_s) * math.pi / 2.0)
+            )
+            self.schedule = schedule
+            if schedule == "cosine":
+                # For the cosine schedule, T = 1 will have numerical issues. So we manually set the ending time T.
+                # Note that T = 0.9946 may not be the optimal setting. However, we find it works well.
+                self.T = 0.9946
+            else:
+                self.T = 1.0
+
+    def marginal_log_mean_coeff(self, t):
+        """
+        Compute log(alpha_t) of a given continuous-time label t in [0, T].
+        """
+        if self.schedule == "discrete":
+            return interpolate_fn(
+                t.reshape((-1, 1)),
+                self.t_array.to(t.device),
+                self.log_alpha_array.to(t.device),
+            ).reshape((-1))
+        elif self.schedule == "linear":
+            return -0.25 * t**2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0
+        elif self.schedule == "cosine":
+            log_alpha_fn = lambda s: torch.log(
+                torch.cos((s + self.cosine_s) / (1.0 + self.cosine_s) * math.pi / 2.0)
+            )
+            log_alpha_t = log_alpha_fn(t) - self.cosine_log_alpha_0
+            return log_alpha_t
+
+    def marginal_alpha(self, t):
+        """
+        Compute alpha_t of a given continuous-time label t in [0, T].
+        """
+        return torch.exp(self.marginal_log_mean_coeff(t))
+
+    def marginal_std(self, t):
+        """
+        Compute sigma_t of a given continuous-time label t in [0, T].
+        """
+        return torch.sqrt(1.0 - torch.exp(2.0 * self.marginal_log_mean_coeff(t)))
+
+    def marginal_lambda(self, t):
+        """
+        Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T].
+        """
+        log_mean_coeff = self.marginal_log_mean_coeff(t)
+        log_std = 0.5 * torch.log(1.0 - torch.exp(2.0 * log_mean_coeff))
+        return log_mean_coeff - log_std
+
+    def inverse_lambda(self, lamb):
+        """
+        Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t.
+        """
+        if self.schedule == "linear":
+            tmp = (
+                2.0
+                * (self.beta_1 - self.beta_0)
+                * torch.logaddexp(-2.0 * lamb, torch.zeros((1,)).to(lamb))
+            )
+            Delta = self.beta_0**2 + tmp
+            return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0)
+        elif self.schedule == "discrete":
+            log_alpha = -0.5 * torch.logaddexp(
+                torch.zeros((1,)).to(lamb.device), -2.0 * lamb
+            )
+            t = interpolate_fn(
+                log_alpha.reshape((-1, 1)),
+                torch.flip(self.log_alpha_array.to(lamb.device), [1]),
+                torch.flip(self.t_array.to(lamb.device), [1]),
+            )
+            return t.reshape((-1,))
+        else:
+            log_alpha = -0.5 * torch.logaddexp(-2.0 * lamb, torch.zeros((1,)).to(lamb))
+            t_fn = (
+                lambda log_alpha_t: torch.arccos(
+                    torch.exp(log_alpha_t + self.cosine_log_alpha_0)
+                )
+                * 2.0
+                * (1.0 + self.cosine_s)
+                / math.pi
+                - self.cosine_s
+            )
+            t = t_fn(log_alpha)
+            return t
+
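+# Hedged sanity-check sketch (comment only, assuming a schedule built as in the class
+# docstring): marginal_lambda and inverse_lambda are mutual inverses up to numerical
+# and interpolation error, e.g.
+#     ns = NoiseScheduleVP("linear")
+#     t = torch.tensor([0.5])
+#     assert torch.allclose(ns.inverse_lambda(ns.marginal_lambda(t)), t, atol=1e-4)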
+
+def model_wrapper(
+    model,
+    noise_schedule,
+    model_type="noise",
+    model_kwargs={},
+    guidance_type="uncond",
+    condition=None,
+    unconditional_condition=None,
+    guidance_scale=1.0,
+    classifier_fn=None,
+    classifier_kwargs={},
+):
+    """Create a wrapper function for the noise prediction model.
+    DPM-Solver needs to solve the continuous-time diffusion ODEs. For DPMs trained on discrete-time labels, we first
+    wrap the model function into a noise prediction model that accepts continuous time as input.
+    We support four types of diffusion models by setting `model_type`:
+        1. "noise": noise prediction model. (Trained by predicting noise).
+        2. "x_start": data prediction model. (Trained by predicting the data x_0 at time 0).
+        3. "v": velocity prediction model. (Trained by predicting the velocity).
+            The derivation of the "v" prediction is detailed in Appendix D of [1]; it is used in Imagen-Video [2].
+            [1] Salimans, Tim, and Jonathan Ho. "Progressive distillation for fast sampling of diffusion models."
+                arXiv preprint arXiv:2202.00512 (2022).
+            [2] Ho, Jonathan, et al. "Imagen Video: High Definition Video Generation with Diffusion Models."
+                arXiv preprint arXiv:2210.02303 (2022).
+
+        4. "score": marginal score function. (Trained by denoising score matching).
+            Note that the score function and the noise prediction model follow a simple relationship:
+            ```
+                noise(x_t, t) = -sigma_t * score(x_t, t)
+            ```
+    We support three types of guided sampling by DPMs by setting `guidance_type`:
+        1. "uncond": unconditional sampling by DPMs.
+            The input `model` has the following format:
+            ``
+                model(x, t_input, **model_kwargs) -> noise | x_start | v | score
+            ``
+        2. "classifier": classifier guidance sampling [3] by DPMs and another classifier.
+            The input `model` has the following format:
+            ``
+                model(x, t_input, **model_kwargs) -> noise | x_start | v | score
+            ``
+            The input `classifier_fn` has the following format:
+            ``
+                classifier_fn(x, t_input, cond, **classifier_kwargs) -> logits(x, t_input, cond)
+            ``
+            [3] P. Dhariwal and A. Q. Nichol, "Diffusion models beat GANs on image synthesis,"
+                in Advances in Neural Information Processing Systems, vol. 34, 2021, pp. 8780-8794.
+        3. "classifier-free": classifier-free guidance sampling by conditional DPMs.
+            The input `model` has the following format:
+            ``
+                model(x, t_input, cond, **model_kwargs) -> noise | x_start | v | score
+            ``
+            And if cond == `unconditional_condition`, the model output is the unconditional DPM output.
+            [4] Ho, Jonathan, and Tim Salimans. "Classifier-free diffusion guidance."
+                arXiv preprint arXiv:2207.12598 (2022).
+
+    The `t_input` is the time label of the model, which may be a discrete-time label (i.e. 0 to 999)
+    or a continuous-time label (i.e. epsilon to T).
+    We wrap the model function to accept only `x` and `t_continuous` as inputs and output the predicted noise:
+    ``
+        def model_fn(x, t_continuous) -> noise:
+            t_input = get_model_input_time(t_continuous)
+            return noise_pred(model, x, t_input, **model_kwargs)
+    ``
+    where `t_continuous` is the continuous time label (i.e. epsilon to T). We use `model_fn` for DPM-Solver.
+    ===============================================================
+    Args:
+        model: A diffusion model with the corresponding format described above.
+        noise_schedule: A noise schedule object, such as NoiseScheduleVP.
+        model_type: A `str`. The parameterization type of the diffusion model.
+                    "noise" or "x_start" or "v" or "score".
+        model_kwargs: A `dict`. A dict for the other inputs of the model function.
+        guidance_type: A `str`. The type of the guidance for sampling.
+                    "uncond" or "classifier" or "classifier-free".
+        condition: A pytorch tensor. The condition for the guided sampling.
+                    Only used for "classifier" or "classifier-free" guidance type.
+        unconditional_condition: A pytorch tensor. The condition for the unconditional sampling.
+                    Only used for "classifier-free" guidance type.
+        guidance_scale: A `float`. The scale for the guided sampling.
+        classifier_fn: A classifier function. Only used for the classifier guidance.
+        classifier_kwargs: A `dict`. A dict for the other inputs of the classifier function.
+    Returns:
+        A noise prediction model that accepts the noised data and the continuous time as the inputs.
+    """
+
+    def get_model_input_time(t_continuous):
+        """
+        Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time.
+        For discrete-time DPMs, we convert `t_continuous` in [1 / N, 1] to `t_input` in [0, 1000 * (N - 1) / N].
+        For continuous-time DPMs, we just use `t_continuous`.
+        """
+        if noise_schedule.schedule == "discrete":
+            return (t_continuous - 1.0 / noise_schedule.total_N) * 1000.0
+        else:
+            return t_continuous
+
+    def noise_pred_fn(x, t_continuous, cond=None):
+        if t_continuous.reshape((-1,)).shape[0] == 1:
+            t_continuous = t_continuous.expand((x.shape[0]))
+        t_input = get_model_input_time(t_continuous)
+        if cond is None:
+            output = model(x, t_input, **model_kwargs)
+        else:
+            output = model(x, t_input, cond, **model_kwargs)
+        if model_type == "noise":
+            return output
+        elif model_type == "x_start":
+            alpha_t, sigma_t = (
+                noise_schedule.marginal_alpha(t_continuous),
+                noise_schedule.marginal_std(t_continuous),
+            )
+            dims = x.dim()
+            return (x - expand_dims(alpha_t, dims) * output) / expand_dims(
+                sigma_t, dims
+            )
+        elif model_type == "v":
+            alpha_t, sigma_t = (
+                noise_schedule.marginal_alpha(t_continuous),
+                noise_schedule.marginal_std(t_continuous),
+            )
+            dims = x.dim()
+            return expand_dims(alpha_t, dims) * output + expand_dims(sigma_t, dims) * x
+        elif model_type == "score":
+            sigma_t = noise_schedule.marginal_std(t_continuous)
+            dims = x.dim()
+            return -expand_dims(sigma_t, dims) * output
+
+    def cond_grad_fn(x, t_input):
+        """
+        Compute the gradient of the classifier, i.e. nabla_{x} log p_t(cond | x_t).
+        """
+        with torch.enable_grad():
+            x_in = x.detach().requires_grad_(True)
+            log_prob = classifier_fn(x_in, t_input, condition, **classifier_kwargs)
+            return torch.autograd.grad(log_prob.sum(), x_in)[0]
+
+    def model_fn(x, t_continuous):
+        """
+        The noise prediction model function that is used for DPM-Solver.
+        """
+        if t_continuous.reshape((-1,)).shape[0] == 1:
+            t_continuous = t_continuous.expand((x.shape[0]))
+        if guidance_type == "uncond":
+            return noise_pred_fn(x, t_continuous)
+        elif guidance_type == "classifier":
+            assert classifier_fn is not None
+            t_input = get_model_input_time(t_continuous)
+            cond_grad = cond_grad_fn(x, t_input)
+            sigma_t = noise_schedule.marginal_std(t_continuous)
+            noise = noise_pred_fn(x, t_continuous)
+            return (
+                noise
+                - guidance_scale
+                * expand_dims(sigma_t, dims=cond_grad.dim())
+                * cond_grad
+            )
+        elif guidance_type == "classifier-free":
+            if guidance_scale == 1.0 or unconditional_condition is None:
+                return noise_pred_fn(x, t_continuous, cond=condition)
+            else:
+                x_in = torch.cat([x] * 2)
+                t_in = torch.cat([t_continuous] * 2)
+                c_in = torch.cat([unconditional_condition, condition])
+                noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2)
+                return noise_uncond + guidance_scale * (noise - noise_uncond)
+
+    assert model_type in ["noise", "x_start", "v", "score"]
+    assert guidance_type in ["uncond", "classifier", "classifier-free"]
+    return model_fn
+
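+# Hedged usage sketch (names are illustrative, not from this file): wrap an
+# epsilon-prediction UNet for classifier-free guidance so the DPM_Solver class below
+# can treat it as a plain continuous-time noise predictor.
+def _wrap_for_cfg_sketch(unet, ns, cond, uncond, scale=7.5):
+    return model_wrapper(
+        unet,
+        ns,
+        model_type="noise",
+        guidance_type="classifier-free",
+        condition=cond,
+        unconditional_condition=uncond,
+        guidance_scale=scale,
+    )
+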
+
+class DPM_Solver:
+    def __init__(
+        self,
+        model_fn,
+        noise_schedule,
+        predict_x0=False,
+        thresholding=False,
+        max_val=1.0,
+    ):
+        """Construct a DPM-Solver.
+        We support both the noise prediction model ("predicting epsilon") and the data prediction model ("predicting x0").
+        If `predict_x0` is False, we use the solver for the noise prediction model (DPM-Solver).
+        If `predict_x0` is True, we use the solver for the data prediction model (DPM-Solver++).
+            In such a case, we further support the "dynamic thresholding" in [1] when `thresholding` is True.
+            The "dynamic thresholding" can greatly improve the sample quality for pixel-space DPMs with large guidance scales.
+        Args:
+            model_fn: A noise prediction model function which accepts the continuous-time input (t in [epsilon, T]):
+                ``
+                def model_fn(x, t_continuous):
+                    return noise
+                ``
+            noise_schedule: A noise schedule object, such as NoiseScheduleVP.
+            predict_x0: A `bool`. If true, use the data prediction model; else, use the noise prediction model.
+            thresholding: A `bool`. Valid when `predict_x0` is True. Whether to use the "dynamic thresholding" in [1].
+            max_val: A `float`. Valid when both `predict_x0` and `thresholding` are True. The max value for thresholding.
+
+        [1] Chitwan Saharia, William Chan, Saurabh Saxena, Lala Li, Jay Whang, Emily Denton, Seyed Kamyar Seyed Ghasemipour, Burcu Karagol Ayan, S Sara Mahdavi, Rapha Gontijo Lopes, et al. Photorealistic text-to-image diffusion models with deep language understanding. arXiv preprint arXiv:2205.11487, 2022b.
+        """
+        self.model = model_fn
+        self.noise_schedule = noise_schedule
+        self.predict_x0 = predict_x0
+        self.thresholding = thresholding
+        self.max_val = max_val
+
+    def noise_prediction_fn(self, x, t):
+        """
+        Return the noise prediction for `x` at time `t`.
+        """
+        return self.model(x, t)
+
+    def data_prediction_fn(self, x, t):
+        """
+        Return the data prediction (x0) for `x` at time `t`, with optional dynamic thresholding.
+        """
+        noise = self.noise_prediction_fn(x, t)
+        dims = x.dim()
+        alpha_t, sigma_t = (
+            self.noise_schedule.marginal_alpha(t),
+            self.noise_schedule.marginal_std(t),
+        )
+        x0 = (x - expand_dims(sigma_t, dims) * noise) / expand_dims(alpha_t, dims)
+        if self.thresholding:
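+            # Dynamic thresholding (Imagen): clip each sample's x0 to its 99.5th-percentile
+            # absolute value s (never below max_val), then divide by s so the result lies in [-1, 1].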
+            p = 0.995  # A hyperparameter in the paper of "Imagen" [1].
+            s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
+            s = expand_dims(
+                torch.maximum(s, self.max_val * torch.ones_like(s).to(s.device)), dims
+            )
+            x0 = torch.clamp(x0, -s, s) / s
+        return x0
+
+    def model_fn(self, x, t):
+        """
+        Convert the model to the noise prediction model or the data prediction model.
+        """
+        if self.predict_x0:
+            return self.data_prediction_fn(x, t)
+        else:
+            return self.noise_prediction_fn(x, t)
+
+    def get_time_steps(self, skip_type, t_T, t_0, N, device):
+        """Compute the intermediate time steps for sampling.
+        Args:
+            skip_type: A `str`. The type for the spacing of the time steps. We support three types:
+                - 'logSNR': uniform logSNR for the time steps.
+                - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolution data**.)
+                - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolution data.)
+            t_T: A `float`. The starting time of the sampling (default is T).
+            t_0: A `float`. The ending time of the sampling (default is epsilon).
+            N: An `int`. The total number of time-step intervals (the returned tensor has N + 1 entries).
+            device: A torch device.
+        Returns:
+            A pytorch tensor of the time steps, with the shape (N + 1,).
+        """
+        if skip_type == "logSNR":
+            lambda_T = self.noise_schedule.marginal_lambda(torch.tensor(t_T).to(device))
+            lambda_0 = self.noise_schedule.marginal_lambda(torch.tensor(t_0).to(device))
+            logSNR_steps = torch.linspace(
+                lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1
+            ).to(device)
+            return self.noise_schedule.inverse_lambda(logSNR_steps)
+        elif skip_type == "time_uniform":
+            return torch.linspace(t_T, t_0, N + 1).to(device)
+        elif skip_type == "time_quadratic":
+            t_order = 2
+            t = (
+                torch.linspace(t_T ** (1.0 / t_order), t_0 ** (1.0 / t_order), N + 1)
+                .pow(t_order)
+                .to(device)
+            )
+            return t
+        else:
+            raise ValueError(
+                "Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(
+                    skip_type
+                )
+            )
+
+    def get_orders_and_timesteps_for_singlestep_solver(
+        self, steps, order, skip_type, t_T, t_0, device
+    ):
+        """
+        Get the order of each step for sampling by the singlestep DPM-Solver.
+        We combine DPM-Solver-1, 2 and 3 to use all the function evaluations; the combined scheme is named "DPM-Solver-fast".
+        Given a fixed number of function evaluations `steps`, the sampling procedure of DPM-Solver-fast is:
+            - If order == 1:
+                We take `steps` of DPM-Solver-1 (i.e. DDIM).
+            - If order == 2:
+                - Denote K = (steps // 2). We take K or (K + 1) intermediate time steps for sampling.
+                - If steps % 2 == 0, we use K steps of DPM-Solver-2.
+                - If steps % 2 == 1, we use K steps of DPM-Solver-2 and 1 step of DPM-Solver-1.
+            - If order == 3:
+                - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling.
+                - If steps % 3 == 0, we use (K - 2) steps of DPM-Solver-3, and 1 step of DPM-Solver-2 and 1 step of DPM-Solver-1.
+                - If steps % 3 == 1, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-1.
+                - If steps % 3 == 2, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-2.
+        ============================================
+        Args:
+            order: An `int`. The max order for the solver (2 or 3).
+            steps: An `int`. The total number of function evaluations (NFE).
+            skip_type: A `str`. The type for the spacing of the time steps. We support three types:
+                - 'logSNR': uniform logSNR for the time steps.
+                - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolution data**.)
+                - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolution data.)
+            t_T: A `float`. The starting time of the sampling (default is T).
+            t_0: A `float`. The ending time of the sampling (default is epsilon).
+            device: A torch device.
+        Returns:
+            orders: A list of the solver order of each step.
+        """
+        if order == 3:
+            K = steps // 3 + 1
+            if steps % 3 == 0:
+                orders = [3] * (K - 2) + [2, 1]
+            elif steps % 3 == 1:
+                orders = [3] * (K - 1) + [1]
+            else:
+                orders = [3] * (K - 1) + [2]
+        elif order == 2:
+            if steps % 2 == 0:
+                K = steps // 2
+                orders = [2] * K
+            else:
+                K = steps // 2 + 1
+                orders = [2] * (K - 1) + [1]
+        elif order == 1:
+            K = 1
+            orders = [1] * steps
+        else:
+            raise ValueError("'order' must be '1' or '2' or '3'.")
+        if skip_type == "logSNR":
+            # To reproduce the results in DPM-Solver paper
+            timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, K, device)
+        else:
+            # Index the fine grid of steps + 1 time points at the cumulative order counts.
+            timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, steps, device)[
+                torch.cumsum(torch.tensor([0] + orders), dim=0).to(device)
+            ]
+        return timesteps_outer, orders
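+    # Worked example (comment only): with order=3, steps=6 yields orders [3, 2, 1],
+    # steps=7 yields [3, 3, 1], and steps=8 yields [3, 3, 2]; with order=2, steps=5
+    # yields [2, 2, 1]. The cumulative sums of `orders` index the outer time points.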
+
+    def denoise_to_zero_fn(self, x, s):
+        """
+        Denoise at the final step, which is equivalent to solving the ODE from lambda_s to infinity by first-order discretization.
+        """
+        return self.data_prediction_fn(x, s)
+
+    def dpm_solver_first_update(self, x, s, t, model_s=None, return_intermediate=False):
+        """
+        DPM-Solver-1 (equivalent to DDIM) from time `s` to time `t`.
+        Args:
+            x: A pytorch tensor. The initial value at time `s`.
+            s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
+            t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
+            model_s: A pytorch tensor. The model function evaluated at time `s`.
+                If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
+            return_intermediate: A `bool`. If true, also return the model value at time `s`.
+        Returns:
+            x_t: A pytorch tensor. The approximated solution at time `t`.
+        """
+        ns = self.noise_schedule
+        dims = x.dim()
+        lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
+        h = lambda_t - lambda_s
+        log_alpha_s, log_alpha_t = (
+            ns.marginal_log_mean_coeff(s),
+            ns.marginal_log_mean_coeff(t),
+        )
+        sigma_s, sigma_t = ns.marginal_std(s), ns.marginal_std(t)
+        alpha_t = torch.exp(log_alpha_t)
+
+        if self.predict_x0:
+            phi_1 = torch.expm1(-h)
+            if model_s is None:
+                model_s = self.model_fn(x, s)
+            x_t = (
+                expand_dims(sigma_t / sigma_s, dims) * x
+                - expand_dims(alpha_t * phi_1, dims) * model_s
+            )
+            if return_intermediate:
+                return x_t, {"model_s": model_s}
+            else:
+                return x_t
+        else:
+            phi_1 = torch.expm1(h)
+            if model_s is None:
+                model_s = self.model_fn(x, s)
+            x_t = (
+                expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
+                - expand_dims(sigma_t * phi_1, dims) * model_s
+            )
+            if return_intermediate:
+                return x_t, {"model_s": model_s}
+            else:
+                return x_t
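+    # In the noise-prediction branch above, with h = lambda_t - lambda_s, the order-1
+    # update reads x_t = (alpha_t / alpha_s) * x - sigma_t * (exp(h) - 1) * model_s,
+    # i.e. the DDIM step rewritten in half-logSNR time.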
+
+    def singlestep_dpm_solver_second_update(
+        self,
+        x,
+        s,
+        t,
+        r1=0.5,
+        model_s=None,
+        return_intermediate=False,
+        solver_type="dpm_solver",
+    ):
+        """
+        Singlestep solver DPM-Solver-2 from time `s` to time `t`.
+        Args:
+            x: A pytorch tensor. The initial value at time `s`.
+            s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
+            t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
+            r1: A `float`. The hyperparameter of the second-order solver.
+            model_s: A pytorch tensor. The model function evaluated at time `s`.
+                If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
+            return_intermediate: A `bool`. If true, also return the model value at time `s` and `s1` (the intermediate time).
+            solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
+                The type slightly affects performance. We recommend using the 'dpm_solver' type.
+        Returns:
+            x_t: A pytorch tensor. The approximated solution at time `t`.
+        """
+        if solver_type not in ["dpm_solver", "taylor"]:
+            raise ValueError(
+                "'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(
+                    solver_type
+                )
+            )
+        if r1 is None:
+            r1 = 0.5
+        ns = self.noise_schedule
+        dims = x.dim()
+        lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
+        h = lambda_t - lambda_s
+        lambda_s1 = lambda_s + r1 * h
+        s1 = ns.inverse_lambda(lambda_s1)
+        log_alpha_s, log_alpha_s1, log_alpha_t = (
+            ns.marginal_log_mean_coeff(s),
+            ns.marginal_log_mean_coeff(s1),
+            ns.marginal_log_mean_coeff(t),
+        )
+        sigma_s, sigma_s1, sigma_t = (
+            ns.marginal_std(s),
+            ns.marginal_std(s1),
+            ns.marginal_std(t),
+        )
+        alpha_s1, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_t)
+
+        if self.predict_x0:
+            phi_11 = torch.expm1(-r1 * h)
+            phi_1 = torch.expm1(-h)
+
+            if model_s is None:
+                model_s = self.model_fn(x, s)
+            x_s1 = (
+                expand_dims(sigma_s1 / sigma_s, dims) * x
+                - expand_dims(alpha_s1 * phi_11, dims) * model_s
+            )
+            model_s1 = self.model_fn(x_s1, s1)
+            if solver_type == "dpm_solver":
+                x_t = (
+                    expand_dims(sigma_t / sigma_s, dims) * x
+                    - expand_dims(alpha_t * phi_1, dims) * model_s
+                    - (0.5 / r1)
+                    * expand_dims(alpha_t * phi_1, dims)
+                    * (model_s1 - model_s)
+                )
+            elif solver_type == "taylor":
+                x_t = (
+                    expand_dims(sigma_t / sigma_s, dims) * x
+                    - expand_dims(alpha_t * phi_1, dims) * model_s
+                    + (1.0 / r1)
+                    * expand_dims(alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0), dims)
+                    * (model_s1 - model_s)
+                )
+        else:
+            phi_11 = torch.expm1(r1 * h)
+            phi_1 = torch.expm1(h)
+
+            if model_s is None:
+                model_s = self.model_fn(x, s)
+            x_s1 = (
+                expand_dims(torch.exp(log_alpha_s1 - log_alpha_s), dims) * x
+                - expand_dims(sigma_s1 * phi_11, dims) * model_s
+            )
+            model_s1 = self.model_fn(x_s1, s1)
+            if solver_type == "dpm_solver":
+                x_t = (
+                    expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
+                    - expand_dims(sigma_t * phi_1, dims) * model_s
+                    - (0.5 / r1)
+                    * expand_dims(sigma_t * phi_1, dims)
+                    * (model_s1 - model_s)
+                )
+            elif solver_type == "taylor":
+                x_t = (
+                    expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
+                    - expand_dims(sigma_t * phi_1, dims) * model_s
+                    - (1.0 / r1)
+                    * expand_dims(sigma_t * ((torch.exp(h) - 1.0) / h - 1.0), dims)
+                    * (model_s1 - model_s)
+                )
+        if return_intermediate:
+            return x_t, {"model_s": model_s, "model_s1": model_s1}
+        else:
+            return x_t
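+    # Note (comment only): the second-order singlestep update above spends two model
+    # evaluations per step: one at `s` (reusable via `model_s`) and one at the
+    # intermediate time s1 = inverse_lambda(lambda_s + r1 * h).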
+
+    def singlestep_dpm_solver_third_update(
+        self,
+        x,
+        s,
+        t,
+        r1=1.0 / 3.0,
+        r2=2.0 / 3.0,
+        model_s=None,
+        model_s1=None,
+        return_intermediate=False,
+        solver_type="dpm_solver",
+    ):
+        """
+        Singlestep solver DPM-Solver-3 from time `s` to time `t`.
+        Args:
+            x: A pytorch tensor. The initial value at time `s`.
+            s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
+            t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
+            r1: A `float`. The hyperparameter of the third-order solver.
+            r2: A `float`. The hyperparameter of the third-order solver.
+            model_s: A pytorch tensor. The model function evaluated at time `s`.
+                If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
+            model_s1: A pytorch tensor. The model function evaluated at time `s1` (the intermediate time given by `r1`).
+                If `model_s1` is None, we evaluate the model at `s1`; otherwise we directly use it.
+            return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times).
+            solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
+                The type slightly affects performance. We recommend using the 'dpm_solver' type.
+        Returns:
+            x_t: A pytorch tensor. The approximated solution at time `t`.
+        """
+        if solver_type not in ["dpm_solver", "taylor"]:
+            raise ValueError(
+                "'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(
+                    solver_type
+                )
+            )
+        if r1 is None:
+            r1 = 1.0 / 3.0
+        if r2 is None:
+            r2 = 2.0 / 3.0
+        ns = self.noise_schedule
+        dims = x.dim()
+        lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
+        h = lambda_t - lambda_s
+        lambda_s1 = lambda_s + r1 * h
+        lambda_s2 = lambda_s + r2 * h
+        s1 = ns.inverse_lambda(lambda_s1)
+        s2 = ns.inverse_lambda(lambda_s2)
+        log_alpha_s, log_alpha_s1, log_alpha_s2, log_alpha_t = (
+            ns.marginal_log_mean_coeff(s),
+            ns.marginal_log_mean_coeff(s1),
+            ns.marginal_log_mean_coeff(s2),
+            ns.marginal_log_mean_coeff(t),
+        )
+        sigma_s, sigma_s1, sigma_s2, sigma_t = (
+            ns.marginal_std(s),
+            ns.marginal_std(s1),
+            ns.marginal_std(s2),
+            ns.marginal_std(t),
+        )
+        alpha_s1, alpha_s2, alpha_t = (
+            torch.exp(log_alpha_s1),
+            torch.exp(log_alpha_s2),
+            torch.exp(log_alpha_t),
+        )
+
+        if self.predict_x0:
+            phi_11 = torch.expm1(-r1 * h)
+            phi_12 = torch.expm1(-r2 * h)
+            phi_1 = torch.expm1(-h)
+            phi_22 = torch.expm1(-r2 * h) / (r2 * h) + 1.0
+            phi_2 = phi_1 / h + 1.0
+            phi_3 = phi_2 / h - 0.5
+
+            if model_s is None:
+                model_s = self.model_fn(x, s)
+            if model_s1 is None:
+                x_s1 = (
+                    expand_dims(sigma_s1 / sigma_s, dims) * x
+                    - expand_dims(alpha_s1 * phi_11, dims) * model_s
+                )
+                model_s1 = self.model_fn(x_s1, s1)
+            x_s2 = (
+                expand_dims(sigma_s2 / sigma_s, dims) * x
+                - expand_dims(alpha_s2 * phi_12, dims) * model_s
+                + r2 / r1 * expand_dims(alpha_s2 * phi_22, dims) * (model_s1 - model_s)
+            )
+            model_s2 = self.model_fn(x_s2, s2)
+            if solver_type == "dpm_solver":
+                x_t = (
+                    expand_dims(sigma_t / sigma_s, dims) * x
+                    - expand_dims(alpha_t * phi_1, dims) * model_s
+                    + (1.0 / r2)
+                    * expand_dims(alpha_t * phi_2, dims)
+                    * (model_s2 - model_s)
+                )
+            elif solver_type == "taylor":
+                D1_0 = (1.0 / r1) * (model_s1 - model_s)
+                D1_1 = (1.0 / r2) * (model_s2 - model_s)
+                D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1)
+                D2 = 2.0 * (D1_1 - D1_0) / (r2 - r1)
+                x_t = (
+                    expand_dims(sigma_t / sigma_s, dims) * x
+                    - expand_dims(alpha_t * phi_1, dims) * model_s
+                    + expand_dims(alpha_t * phi_2, dims) * D1
+                    - expand_dims(alpha_t * phi_3, dims) * D2
+                )
+        else:
+            phi_11 = torch.expm1(r1 * h)
+            phi_12 = torch.expm1(r2 * h)
+            phi_1 = torch.expm1(h)
+            phi_22 = torch.expm1(r2 * h) / (r2 * h) - 1.0
+            phi_2 = phi_1 / h - 1.0
+            phi_3 = phi_2 / h - 0.5
+
+            if model_s is None:
+                model_s = self.model_fn(x, s)
+            if model_s1 is None:
+                x_s1 = (
+                    expand_dims(torch.exp(log_alpha_s1 - log_alpha_s), dims) * x
+                    - expand_dims(sigma_s1 * phi_11, dims) * model_s
+                )
+                model_s1 = self.model_fn(x_s1, s1)
+            x_s2 = (
+                expand_dims(torch.exp(log_alpha_s2 - log_alpha_s), dims) * x
+                - expand_dims(sigma_s2 * phi_12, dims) * model_s
+                - r2 / r1 * expand_dims(sigma_s2 * phi_22, dims) * (model_s1 - model_s)
+            )
+            model_s2 = self.model_fn(x_s2, s2)
+            if solver_type == "dpm_solver":
+                x_t = (
+                    expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
+                    - expand_dims(sigma_t * phi_1, dims) * model_s
+                    - (1.0 / r2)
+                    * expand_dims(sigma_t * phi_2, dims)
+                    * (model_s2 - model_s)
+                )
+            elif solver_type == "taylor":
+                D1_0 = (1.0 / r1) * (model_s1 - model_s)
+                D1_1 = (1.0 / r2) * (model_s2 - model_s)
+                D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1)
+                D2 = 2.0 * (D1_1 - D1_0) / (r2 - r1)
+                x_t = (
+                    expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
+                    - expand_dims(sigma_t * phi_1, dims) * model_s
+                    - expand_dims(sigma_t * phi_2, dims) * D1
+                    - expand_dims(sigma_t * phi_3, dims) * D2
+                )
+
+        if return_intermediate:
+            return x_t, {"model_s": model_s, "model_s1": model_s1, "model_s2": model_s2}
+        else:
+            return x_t
+
+    def multistep_dpm_solver_second_update(
+        self, x, model_prev_list, t_prev_list, t, solver_type="dpm_solver"
+    ):
+        """
+        Multistep solver DPM-Solver-2 from time `t_prev_list[-1]` to time `t`.
+        Args:
+            x: A pytorch tensor. The initial value at time `t_prev_list[-1]`.
+            model_prev_list: A list of pytorch tensors. The previously computed model values.
+            t_prev_list: A list of pytorch tensors. The previous times, each with shape (x.shape[0],).
+            t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
+            solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
+                The type slightly affects performance. We recommend using the 'dpm_solver' type.
+        Returns:
+            x_t: A pytorch tensor. The approximated solution at time `t`.
+        """
+        if solver_type not in ["dpm_solver", "taylor"]:
+            raise ValueError(
+                "'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(
+                    solver_type
+                )
+            )
+        ns = self.noise_schedule
+        dims = x.dim()
+        model_prev_1, model_prev_0 = model_prev_list
+        t_prev_1, t_prev_0 = t_prev_list
+        lambda_prev_1, lambda_prev_0, lambda_t = (
+            ns.marginal_lambda(t_prev_1),
+            ns.marginal_lambda(t_prev_0),
+            ns.marginal_lambda(t),
+        )
+        log_alpha_prev_0, log_alpha_t = (
+            ns.marginal_log_mean_coeff(t_prev_0),
+            ns.marginal_log_mean_coeff(t),
+        )
+        sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
+        alpha_t = torch.exp(log_alpha_t)
+
+        h_0 = lambda_prev_0 - lambda_prev_1
+        h = lambda_t - lambda_prev_0
+        r0 = h_0 / h
+        D1_0 = expand_dims(1.0 / r0, dims) * (model_prev_0 - model_prev_1)
+        if self.predict_x0:
+            if solver_type == "dpm_solver":
+                x_t = (
+                    expand_dims(sigma_t / sigma_prev_0, dims) * x
+                    - expand_dims(alpha_t * (torch.exp(-h) - 1.0), dims) * model_prev_0
+                    - 0.5 * expand_dims(alpha_t * (torch.exp(-h) - 1.0), dims) * D1_0
+                )
+            elif solver_type == "taylor":
+                x_t = (
+                    expand_dims(sigma_t / sigma_prev_0, dims) * x
+                    - expand_dims(alpha_t * (torch.exp(-h) - 1.0), dims) * model_prev_0
+                    + expand_dims(alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0), dims)
+                    * D1_0
+                )
+        else:
+            if solver_type == "dpm_solver":
+                x_t = (
+                    expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
+                    - expand_dims(sigma_t * (torch.exp(h) - 1.0), dims) * model_prev_0
+                    - 0.5 * expand_dims(sigma_t * (torch.exp(h) - 1.0), dims) * D1_0
+                )
+            elif solver_type == "taylor":
+                x_t = (
+                    expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
+                    - expand_dims(sigma_t * (torch.exp(h) - 1.0), dims) * model_prev_0
+                    - expand_dims(sigma_t * ((torch.exp(h) - 1.0) / h - 1.0), dims)
+                    * D1_0
+                )
+        return x_t
+
+    def multistep_dpm_solver_third_update(
+        self, x, model_prev_list, t_prev_list, t, solver_type="dpm_solver"
+    ):
+        """
+        Multistep solver DPM-Solver-3 from time `t_prev_list[-1]` to time `t`.
+        Args:
+            x: A pytorch tensor. The initial value at time `t_prev_list[-1]`.
+            model_prev_list: A list of pytorch tensors. The previously computed model values.
+            t_prev_list: A list of pytorch tensors. The previous times, each with shape (x.shape[0],).
+            t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
+            solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
+                The type slightly affects performance. We recommend using the 'dpm_solver' type.
+        Returns:
+            x_t: A pytorch tensor. The approximated solution at time `t`.
+        """
+        ns = self.noise_schedule
+        dims = x.dim()
+        model_prev_2, model_prev_1, model_prev_0 = model_prev_list
+        t_prev_2, t_prev_1, t_prev_0 = t_prev_list
+        lambda_prev_2, lambda_prev_1, lambda_prev_0, lambda_t = (
+            ns.marginal_lambda(t_prev_2),
+            ns.marginal_lambda(t_prev_1),
+            ns.marginal_lambda(t_prev_0),
+            ns.marginal_lambda(t),
+        )
+        log_alpha_prev_0, log_alpha_t = (
+            ns.marginal_log_mean_coeff(t_prev_0),
+            ns.marginal_log_mean_coeff(t),
+        )
+        sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
+        alpha_t = torch.exp(log_alpha_t)
+
+        h_1 = lambda_prev_1 - lambda_prev_2
+        h_0 = lambda_prev_0 - lambda_prev_1
+        h = lambda_t - lambda_prev_0
+        r0, r1 = h_0 / h, h_1 / h
+        D1_0 = expand_dims(1.0 / r0, dims) * (model_prev_0 - model_prev_1)
+        D1_1 = expand_dims(1.0 / r1, dims) * (model_prev_1 - model_prev_2)
+        D1 = D1_0 + expand_dims(r0 / (r0 + r1), dims) * (D1_0 - D1_1)
+        D2 = expand_dims(1.0 / (r0 + r1), dims) * (D1_0 - D1_1)
+        if self.predict_x0:
+            x_t = (
+                expand_dims(sigma_t / sigma_prev_0, dims) * x
+                - expand_dims(alpha_t * (torch.exp(-h) - 1.0), dims) * model_prev_0
+                + expand_dims(alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0), dims) * D1
+                - expand_dims(
+                    alpha_t * ((torch.exp(-h) - 1.0 + h) / h**2 - 0.5), dims
+                )
+                * D2
+            )
+        else:
+            x_t = (
+                expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
+                - expand_dims(sigma_t * (torch.exp(h) - 1.0), dims) * model_prev_0
+                - expand_dims(sigma_t * ((torch.exp(h) - 1.0) / h - 1.0), dims) * D1
+                - expand_dims(sigma_t * ((torch.exp(h) - 1.0 - h) / h**2 - 0.5), dims)
+                * D2
+            )
+        return x_t
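+    # Note (comment only): D1_0 and D1_1 above are finite differences of the cached
+    # model outputs in lambda (half-logSNR) time, and D1, D2 combine them into
+    # estimates of the first- and second-order changes used for the third-order
+    # multistep correction.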
+
+    def singlestep_dpm_solver_update(
+        self,
+        x,
+        s,
+        t,
+        order,
+        return_intermediate=False,
+        solver_type="dpm_solver",
+        r1=None,
+        r2=None,
+    ):
+        """
+        Singlestep DPM-Solver with the order `order` from time `s` to time `t`.
+        Args:
+            x: A pytorch tensor. The initial value at time `s`.
+            s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
+            t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
+            order: An `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3.
+            return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times).
+            solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
+                The type slightly affects performance. We recommend using the 'dpm_solver' type.
+            r1: A `float`. The hyperparameter of the second-order or third-order solver.
+            r2: A `float`. The hyperparameter of the third-order solver.
+        Returns:
+            x_t: A pytorch tensor. The approximated solution at time `t`.
+        """
+        if order == 1:
+            return self.dpm_solver_first_update(
+                x, s, t, return_intermediate=return_intermediate
+            )
+        elif order == 2:
+            return self.singlestep_dpm_solver_second_update(
+                x,
+                s,
+                t,
+                return_intermediate=return_intermediate,
+                solver_type=solver_type,
+                r1=r1,
+            )
+        elif order == 3:
+            return self.singlestep_dpm_solver_third_update(
+                x,
+                s,
+                t,
+                return_intermediate=return_intermediate,
+                solver_type=solver_type,
+                r1=r1,
+                r2=r2,
+            )
+        else:
+            raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order))
+
+    def multistep_dpm_solver_update(
+        self, x, model_prev_list, t_prev_list, t, order, solver_type="dpm_solver"
+    ):
+        """
+        Multistep DPM-Solver with the order `order` from time `t_prev_list[-1]` to time `t`.
+        Args:
+            x: A pytorch tensor. The initial value at time `t_prev_list[-1]`.
+            model_prev_list: A list of pytorch tensors. The previously computed model values.
+            t_prev_list: A list of pytorch tensors. The previous times, each with shape (x.shape[0],).
+            t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
+            order: An `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3.
+            solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
+                The type slightly affects performance. We recommend using the 'dpm_solver' type.
+        Returns:
+            x_t: A pytorch tensor. The approximated solution at time `t`.
+        """
+        if order == 1:
+            return self.dpm_solver_first_update(
+                x, t_prev_list[-1], t, model_s=model_prev_list[-1]
+            )
+        elif order == 2:
+            return self.multistep_dpm_solver_second_update(
+                x, model_prev_list, t_prev_list, t, solver_type=solver_type
+            )
+        elif order == 3:
+            return self.multistep_dpm_solver_third_update(
+                x, model_prev_list, t_prev_list, t, solver_type=solver_type
+            )
+        else:
+            raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order))
+
+    def dpm_solver_adaptive(
+        self,
+        x,
+        order,
+        t_T,
+        t_0,
+        h_init=0.05,
+        atol=0.0078,
+        rtol=0.05,
+        theta=0.9,
+        t_err=1e-5,
+        solver_type="dpm_solver",
+    ):
+        """
+        The adaptive step size solver based on singlestep DPM-Solver.
+        Args:
+            x: A pytorch tensor. The initial value at time `t_T`.
+            order: An `int`. The (higher) order of the solver. We only support order == 2 or 3.
+            t_T: A `float`. The starting time of the sampling (default is T).
+            t_0: A `float`. The ending time of the sampling (default is epsilon).
+            h_init: A `float`. The initial step size (for logSNR).
+            atol: A `float`. The absolute tolerance of the solver. For image data, the default setting is 0.0078, following [1].
+            rtol: A `float`. The relative tolerance of the solver. The default setting is 0.05.
+            theta: A `float`. The safety hyperparameter for adapting the step size. The default setting is 0.9, following [1].
+            t_err: A `float`. The tolerance for the time. We solve the diffusion ODE until the absolute error between the
+                current time and `t_0` is less than `t_err`. The default setting is 1e-5.
+            solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
+                The type slightly impacts the performance. We recommend using the 'dpm_solver' type.
+        Returns:
+            x_0: A pytorch tensor. The approximated solution at time `t_0`.
+        [1] A. Jolicoeur-Martineau, K. Li, R. Piché-Taillefer, T. Kachman, and I. Mitliagkas, "Gotta go fast when generating data with score-based models," arXiv preprint arXiv:2105.14080, 2021.
+        """
+        ns = self.noise_schedule
+        s = t_T * torch.ones((x.shape[0],)).to(x)
+        lambda_s = ns.marginal_lambda(s)
+        lambda_0 = ns.marginal_lambda(t_0 * torch.ones_like(s).to(x))
+        h = h_init * torch.ones_like(s).to(x)
+        x_prev = x
+        nfe = 0
+        if order == 2:
+            r1 = 0.5
+            lower_update = lambda x, s, t: self.dpm_solver_first_update(
+                x, s, t, return_intermediate=True
+            )
+            higher_update = (
+                lambda x, s, t, **kwargs: self.singlestep_dpm_solver_second_update(
+                    x, s, t, r1=r1, solver_type=solver_type, **kwargs
+                )
+            )
+        elif order == 3:
+            r1, r2 = 1.0 / 3.0, 2.0 / 3.0
+            lower_update = lambda x, s, t: self.singlestep_dpm_solver_second_update(
+                x, s, t, r1=r1, return_intermediate=True, solver_type=solver_type
+            )
+            higher_update = (
+                lambda x, s, t, **kwargs: self.singlestep_dpm_solver_third_update(
+                    x, s, t, r1=r1, r2=r2, solver_type=solver_type, **kwargs
+                )
+            )
+        else:
+            raise ValueError(
+                "For adaptive step size solver, order must be 2 or 3, got {}".format(
+                    order
+                )
+            )
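+        # Step-size control: each iteration takes one lower-order and one higher-order
+        # step over the same interval, estimates the scaled local error
+        # E = ||(x_higher - x_lower) / max(atol, rtol * max(|x_lower|, |x_prev|))||,
+        # accepts the step only when E <= 1, and rescales h by theta * E^(-1/order),
+        # clipped so the solver never steps past lambda(t_0).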
+        while torch.abs((s - t_0)).mean() > t_err:
+            t = ns.inverse_lambda(lambda_s + h)
+            x_lower, lower_noise_kwargs = lower_update(x, s, t)
+            x_higher = higher_update(x, s, t, **lower_noise_kwargs)
+            delta = torch.max(
+                torch.ones_like(x).to(x) * atol,
+                rtol * torch.max(torch.abs(x_lower), torch.abs(x_prev)),
+            )
+            norm_fn = lambda v: torch.sqrt(
+                torch.square(v.reshape((v.shape[0], -1))).mean(dim=-1, keepdim=True)
+            )
+            E = norm_fn((x_higher - x_lower) / delta).max()
+            if torch.all(E <= 1.0):
+                x = x_higher
+                s = t
+                x_prev = x_lower
+                lambda_s = ns.marginal_lambda(s)
+            h = torch.min(
+                theta * h * torch.float_power(E, -1.0 / order).float(),
+                lambda_0 - lambda_s,
+            )
+            nfe += order
+        print("adaptive solver nfe", nfe)
+        return x
+
+    def sample(
+        self,
+        x,
+        steps=20,
+        t_start=None,
+        t_end=None,
+        order=3,
+        skip_type="time_uniform",
+        method="singlestep",
+        lower_order_final=True,
+        denoise_to_zero=False,
+        solver_type="dpm_solver",
+        atol=0.0078,
+        rtol=0.05,
+    ):
+        """
+        Compute the sample at time `t_end` by DPM-Solver, given the initial `x` at time `t_start`.
+        =====================================================
+        We support the following algorithms for both noise prediction model and data prediction model:
+            - 'singlestep':
+                Singlestep DPM-Solver (i.e. "DPM-Solver-fast" in the paper), which combines different orders of singlestep DPM-Solver.
+                We combine all the singlestep solvers with order <= `order` to use up all the function evaluations (steps).
+                The total number of function evaluations (NFE) == `steps`.
+                Given a fixed NFE == `steps`, the sampling procedure is:
+                    - If `order` == 1:
+                        - Denote K = steps. We use K steps of DPM-Solver-1 (i.e. DDIM).
+                    - If `order` == 2:
+                        - Denote K = (steps // 2) + (steps % 2). We take K intermediate time steps for sampling.
+                        - If steps % 2 == 0, we use K steps of singlestep DPM-Solver-2.
+                        - If steps % 2 == 1, we use (K - 1) steps of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1.
+                    - If `order` == 3:
+                        - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling.
+                        - If steps % 3 == 0, we use (K - 2) steps of singlestep DPM-Solver-3, and 1 step of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1.
+                        - If steps % 3 == 1, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of DPM-Solver-1.
+                        - If steps % 3 == 2, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of singlestep DPM-Solver-2.
+            - 'multistep':
+                Multistep DPM-Solver with the order of `order`. The total number of function evaluations (NFE) == `steps`.
+                We initialize the first `order` values by lower order multistep solvers.
+                Given a fixed NFE == `steps`, the sampling procedure is:
+                    Denote K = steps.
+                    - If `order` == 1:
+                        - We use K steps of DPM-Solver-1 (i.e. DDIM).
+                    - If `order` == 2:
+                        - We first use 1 step of DPM-Solver-1, then (K - 1) steps of multistep DPM-Solver-2.
+                    - If `order` == 3:
+                        - We first use 1 step of DPM-Solver-1, then 1 step of multistep DPM-Solver-2, then (K - 2) steps of multistep DPM-Solver-3.
+            - 'singlestep_fixed':
+                Fixed order singlestep DPM-Solver (i.e. DPM-Solver-1 or singlestep DPM-Solver-2 or singlestep DPM-Solver-3).
+                We use singlestep DPM-Solver-`order` for `order`=1 or 2 or 3, with total [`steps` // `order`] * `order` NFE.
+            - 'adaptive':
+                Adaptive step size DPM-Solver (i.e. "DPM-Solver-12" and "DPM-Solver-23" in the paper).
+                We ignore `steps` and use adaptive step size DPM-Solver with a higher order of `order`.
+                You can adjust the absolute tolerance `atol` and the relative tolerance `rtol` to balance the computation costs
+                (NFE) and the sample quality.
+                    - If `order` == 2, we use DPM-Solver-12 which combines DPM-Solver-1 and singlestep DPM-Solver-2.
+                    - If `order` == 3, we use DPM-Solver-23 which combines singlestep DPM-Solver-2 and singlestep DPM-Solver-3.
+        =====================================================
+        Some advice for choosing the algorithm:
+            - For **unconditional sampling** or **guided sampling with small guidance scale** by DPMs:
+                Use singlestep DPM-Solver ("DPM-Solver-fast" in the paper) with `order = 3`.
+                e.g.
+                    >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=False)
+                    >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=3,
+                            skip_type='time_uniform', method='singlestep')
+            - For **guided sampling with large guidance scale** by DPMs:
+                Use multistep DPM-Solver with `predict_x0 = True` and `order = 2`.
+                e.g.
+                    >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=True)
+                    >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=2,
+                            skip_type='time_uniform', method='multistep')
+        We support three types of `skip_type`:
+            - 'logSNR': uniform logSNR for the time steps. **Recommended for low-resolution images**.
+            - 'time_uniform': uniform time for the time steps. **Recommended for high-resolution images**.
+            - 'time_quadratic': quadratic time for the time steps.
+        =====================================================
+        Args:
+            x: A pytorch tensor. The initial value at time `t_start`
+                e.g. if `t_start` == T, then `x` is a sample from the standard normal distribution.
+            steps: An `int`. The total number of function evaluations (NFE).
+            t_start: A `float`. The starting time of the sampling.
+                If `t_start` is None, we use self.noise_schedule.T (default is 1.0).
+            t_end: A `float`. The ending time of the sampling.
+                If `t_end` is None, we use 1. / self.noise_schedule.total_N.
+                e.g. if total_N == 1000, we have `t_end` == 1e-3.
+                For discrete-time DPMs:
+                    - We recommend `t_end` == 1. / self.noise_schedule.total_N.
+                For continuous-time DPMs:
+                    - We recommend `t_end` == 1e-3 when `steps` <= 15; and `t_end` == 1e-4 when `steps` > 15.
+            order: An `int`. The order of DPM-Solver.
+            skip_type: A `str`. The type for the spacing of the time steps. 'time_uniform' or 'logSNR' or 'time_quadratic'.
+            method: A `str`. The method for sampling. 'singlestep' or 'multistep' or 'singlestep_fixed' or 'adaptive'.
+            denoise_to_zero: A `bool`. Whether to denoise to time 0 at the final step.
+                Default is `False`. If `denoise_to_zero` is `True`, the total NFE is (`steps` + 1).
+                This trick was first proposed by DDPM (https://arxiv.org/abs/2006.11239) and
+                score_sde (https://arxiv.org/abs/2011.13456). It can improve the FID when
+                sampling diffusion models by diffusion SDEs on low-resolution images
+                (such as CIFAR-10). However, we observed that this trick does not matter for
+                high-resolution images. As it needs an additional NFE, we do not recommend
+                it for high-resolution images.
+            lower_order_final: A `bool`. Whether to use lower order solvers at the final steps.
+                Only valid for `method=multistep` and `steps < 15`. We empirically find that
+                this trick is a key to stabilizing the sampling by DPM-Solver with very few steps
+                (especially for steps <= 10). So we recommend setting it to `True`.
+            solver_type: A `str`. The Taylor expansion type for the solver. `dpm_solver` or `taylor`. We recommend `dpm_solver`.
+            atol: A `float`. The absolute tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'.
+            rtol: A `float`. The relative tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'.
+        Returns:
+            x_end: A pytorch tensor. The approximated solution at time `t_end`.
+        """
+        t_0 = 1.0 / self.noise_schedule.total_N if t_end is None else t_end
+        t_T = self.noise_schedule.T if t_start is None else t_start
+        device = x.device
+        if method == "adaptive":
+            with torch.no_grad():
+                x = self.dpm_solver_adaptive(
+                    x,
+                    order=order,
+                    t_T=t_T,
+                    t_0=t_0,
+                    atol=atol,
+                    rtol=rtol,
+                    solver_type=solver_type,
+                )
+        elif method == "multistep":
+            assert steps >= order
+            timesteps = self.get_time_steps(
+                skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device
+            )
+            assert timesteps.shape[0] - 1 == steps
+            with torch.no_grad():
+                vec_t = timesteps[0].expand((x.shape[0]))
+                model_prev_list = [self.model_fn(x, vec_t)]
+                t_prev_list = [vec_t]
+                # Init the first `order` values by lower order multistep DPM-Solver.
+                for init_order in tqdm(range(1, order), desc="DPM init order"):
+                    vec_t = timesteps[init_order].expand(x.shape[0])
+                    x = self.multistep_dpm_solver_update(
+                        x,
+                        model_prev_list,
+                        t_prev_list,
+                        vec_t,
+                        init_order,
+                        solver_type=solver_type,
+                    )
+                    model_prev_list.append(self.model_fn(x, vec_t))
+                    t_prev_list.append(vec_t)
+                # Compute the remaining values by `order`-th order multistep DPM-Solver.
+                for step in tqdm(range(order, steps + 1), desc="DPM multistep"):
+                    vec_t = timesteps[step].expand(x.shape[0])
+                    if lower_order_final and steps < 15:
+                        step_order = min(order, steps + 1 - step)
+                    else:
+                        step_order = order
+                    x = self.multistep_dpm_solver_update(
+                        x,
+                        model_prev_list,
+                        t_prev_list,
+                        vec_t,
+                        step_order,
+                        solver_type=solver_type,
+                    )
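+                    # Shift the window of cached times and model outputs by one step (drop the oldest).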
+                    for i in range(order - 1):
+                        t_prev_list[i] = t_prev_list[i + 1]
+                        model_prev_list[i] = model_prev_list[i + 1]
+                    t_prev_list[-1] = vec_t
+                    # We do not need to evaluate the final model value.
+                    if step < steps:
+                        model_prev_list[-1] = self.model_fn(x, vec_t)
+        elif method in ["singlestep", "singlestep_fixed"]:
+            if method == "singlestep":
+                (
+                    timesteps_outer,
+                    orders,
+                ) = self.get_orders_and_timesteps_for_singlestep_solver(
+                    steps=steps,
+                    order=order,
+                    skip_type=skip_type,
+                    t_T=t_T,
+                    t_0=t_0,
+                    device=device,
+                )
+            elif method == "singlestep_fixed":
+                K = steps // order
+                orders = [
+                    order,
+                ] * K
+                timesteps_outer = self.get_time_steps(
+                    skip_type=skip_type, t_T=t_T, t_0=t_0, N=K, device=device
+                )
+            for i, order in enumerate(orders):
+                t_T_inner, t_0_inner = timesteps_outer[i], timesteps_outer[i + 1]
+                timesteps_inner = self.get_time_steps(
+                    skip_type=skip_type,
+                    t_T=t_T_inner.item(),
+                    t_0=t_0_inner.item(),
+                    N=order,
+                    device=device,
+                )
+                lambda_inner = self.noise_schedule.marginal_lambda(timesteps_inner)
+                vec_s, vec_t = t_T_inner.tile(x.shape[0]), t_0_inner.tile(x.shape[0])
+                h = lambda_inner[-1] - lambda_inner[0]
+                r1 = None if order <= 1 else (lambda_inner[1] - lambda_inner[0]) / h
+                r2 = None if order <= 2 else (lambda_inner[2] - lambda_inner[0]) / h
+                x = self.singlestep_dpm_solver_update(
+                    x, vec_s, vec_t, order, solver_type=solver_type, r1=r1, r2=r2
+                )
+        if denoise_to_zero:
+            x = self.denoise_to_zero_fn(x, torch.ones((x.shape[0],)).to(device) * t_0)
+        return x
+
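+# A minimal usage sketch (illustrative only; `model`, `alphas_cumprod` and `img_shape`
+# are placeholder names, and model_wrapper's default (unconditional) guidance is assumed):
+#
+#     ns = NoiseScheduleVP("discrete", alphas_cumprod=alphas_cumprod)
+#     model_fn = model_wrapper(model, ns, model_type="noise")
+#     solver = DPM_Solver(model_fn, ns, predict_x0=True)
+#     x_0 = solver.sample(torch.randn(img_shape), steps=20, order=2,
+#                         skip_type="time_uniform", method="multistep",
+#                         lower_order_final=True)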
+
+#############################################################
+# other utility functions
+#############################################################
+
+
+def interpolate_fn(x, xp, yp):
+    """
+    A piecewise linear function y = f(x), using xp and yp as keypoints.
+    We implement f(x) in a differentiable way (i.e. applicable for autograd).
+    The function f(x) is well-defined for all x. (For x beyond the bounds of xp, we use the outermost points of xp to define the linear function.)
+    Args:
+        x: PyTorch tensor with shape [N, C], where N is the batch size, C is the number of channels (we use C = 1 for DPM-Solver).
+        xp: PyTorch tensor with shape [C, K], where K is the number of keypoints.
+        yp: PyTorch tensor with shape [C, K].
+    Returns:
+        The function values f(x), with shape [N, C].
+    """
+    N, K = x.shape[0], xp.shape[1]
+    all_x = torch.cat([x.unsqueeze(2), xp.unsqueeze(0).repeat((N, 1, 1))], dim=2)
+    sorted_all_x, x_indices = torch.sort(all_x, dim=2)
+    x_idx = torch.argmin(x_indices, dim=2)
+    cand_start_idx = x_idx - 1
+    start_idx = torch.where(
+        torch.eq(x_idx, 0),
+        torch.tensor(1, device=x.device),
+        torch.where(
+            torch.eq(x_idx, K),
+            torch.tensor(K - 2, device=x.device),
+            cand_start_idx,
+        ),
+    )
+    end_idx = torch.where(
+        torch.eq(start_idx, cand_start_idx), start_idx + 2, start_idx + 1
+    )
+    start_x = torch.gather(sorted_all_x, dim=2, index=start_idx.unsqueeze(2)).squeeze(2)
+    end_x = torch.gather(sorted_all_x, dim=2, index=end_idx.unsqueeze(2)).squeeze(2)
+    start_idx2 = torch.where(
+        torch.eq(x_idx, 0),
+        torch.tensor(0, device=x.device),
+        torch.where(
+            torch.eq(x_idx, K),
+            torch.tensor(K - 2, device=x.device),
+            cand_start_idx,
+        ),
+    )
+    y_positions_expanded = yp.unsqueeze(0).expand(N, -1, -1)
+    start_y = torch.gather(
+        y_positions_expanded, dim=2, index=start_idx2.unsqueeze(2)
+    ).squeeze(2)
+    end_y = torch.gather(
+        y_positions_expanded, dim=2, index=(start_idx2 + 1).unsqueeze(2)
+    ).squeeze(2)
+    cand = start_y + (x - start_x) * (end_y - start_y) / (end_x - start_x)
+    return cand
+
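+# Example (hypothetical values): with xp = torch.tensor([[0., 1., 2.]]) and
+# yp = torch.tensor([[0., 10., 20.]]), interpolate_fn(torch.tensor([[0.5], [1.5]]), xp, yp)
+# returns [[5.], [15.]]; inputs outside [0, 2] are extrapolated linearly from the two
+# keypoints nearest that end.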
+
+def expand_dims(v, dims):
+    """
+    Expand the tensor `v` to the dim `dims`.
+    Args:
+        `v`: a PyTorch tensor with shape [N].
+        `dims`: an `int`.
+    Returns:
+        a PyTorch tensor with shape [N, 1, 1, ..., 1] and the total dimension is `dims`.
+    """
+    return v[(...,) + (None,) * (dims - 1)]
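+# e.g. expand_dims(torch.ones(4), 3) returns a tensor of shape [4, 1, 1], which
+# broadcasts against a batch of shape [4, C, H] (illustrative shapes only).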

+ 95 - 0
sorawm/iopaint/model/anytext/ldm/models/diffusion/dpm_solver/sampler.py

@@ -0,0 +1,95 @@
+"""SAMPLING ONLY."""
+import torch
+
+from .dpm_solver import DPM_Solver, NoiseScheduleVP, model_wrapper
+
+MODEL_TYPES = {"eps": "noise", "v": "v"}
+
+
+class DPMSolverSampler(object):
+    def __init__(self, model, **kwargs):
+        super().__init__()
+        self.model = model
+        to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device)
+        self.register_buffer("alphas_cumprod", to_torch(model.alphas_cumprod))
+
+    def register_buffer(self, name, attr):
+        if type(attr) == torch.Tensor:
+            if attr.device != torch.device("cuda"):
+                attr = attr.to(torch.device("cuda"))
+        setattr(self, name, attr)
+
+    @torch.no_grad()
+    def sample(
+        self,
+        S,
+        batch_size,
+        shape,
+        conditioning=None,
+        callback=None,
+        normals_sequence=None,
+        img_callback=None,
+        quantize_x0=False,
+        eta=0.0,
+        mask=None,
+        x0=None,
+        temperature=1.0,
+        noise_dropout=0.0,
+        score_corrector=None,
+        corrector_kwargs=None,
+        verbose=True,
+        x_T=None,
+        log_every_t=100,
+        unconditional_guidance_scale=1.0,
+        unconditional_conditioning=None,
+        # this has to come in the same format as the conditioning, e.g. as encoded tokens, ...
+        **kwargs,
+    ):
+        if conditioning is not None:
+            if isinstance(conditioning, dict):
+                cbs = conditioning[list(conditioning.keys())[0]].shape[0]
+                if cbs != batch_size:
+                    print(
+                        f"Warning: Got {cbs} conditionings but batch-size is {batch_size}"
+                    )
+            else:
+                if conditioning.shape[0] != batch_size:
+                    print(
+                        f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}"
+                    )
+
+        # sampling
+        C, H, W = shape
+        size = (batch_size, C, H, W)
+
+        print(f"Data shape for DPM-Solver sampling is {size}, sampling steps {S}")
+
+        device = self.model.betas.device
+        if x_T is None:
+            img = torch.randn(size, device=device)
+        else:
+            img = x_T
+
+        ns = NoiseScheduleVP("discrete", alphas_cumprod=self.alphas_cumprod)
+
+        model_fn = model_wrapper(
+            lambda x, t, c: self.model.apply_model(x, t, c),
+            ns,
+            model_type=MODEL_TYPES[self.model.parameterization],
+            guidance_type="classifier-free",
+            condition=conditioning,
+            unconditional_condition=unconditional_conditioning,
+            guidance_scale=unconditional_guidance_scale,
+        )
+
+        dpm_solver = DPM_Solver(model_fn, ns, predict_x0=True, thresholding=False)
+        x = dpm_solver.sample(
+            img,
+            steps=S,
+            skip_type="time_uniform",
+            method="multistep",
+            order=2,
+            lower_order_final=True,
+        )
+
+        return x.to(device), None
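+# Minimal usage sketch (assumed names, not part of this file): for a latent-diffusion
+# model `ldm` exposing `apply_model`, `betas`, `alphas_cumprod` and `parameterization`,
+# and with buffers on CUDA (as register_buffer above assumes), classifier-free-guided
+# sampling could look roughly like
+#
+#     sampler = DPMSolverSampler(ldm)
+#     samples, _ = sampler.sample(S=20, batch_size=4, shape=(4, 64, 64),
+#                                 conditioning=cond, unconditional_conditioning=uncond,
+#                                 unconditional_guidance_scale=7.5)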

Some files were not shown because too many files changed in this diff