RioShiina committed
Commit c2bcd10 · verified · 1 Parent(s): e9c4f1a

Upload folder using huggingface_hub

This view is limited to 50 files because it contains too many changes. See the raw diff for the rest.

Files changed (50)
  1. .gitattributes +4 -35
  2. .gitignore +26 -0
  3. CODEOWNERS +24 -0
  4. CONTRIBUTING.md +41 -0
  5. LICENSE +674 -0
  6. README.md +9 -9
  7. alembic.ini +84 -0
  8. alembic_db/README.md +4 -0
  9. alembic_db/env.py +64 -0
  10. alembic_db/script.py.mako +28 -0
  11. api_server/__init__.py +0 -0
  12. api_server/routes/__init__.py +0 -0
  13. api_server/routes/internal/README.md +3 -0
  14. api_server/routes/internal/__init__.py +0 -0
  15. api_server/routes/internal/internal_routes.py +73 -0
  16. api_server/services/__init__.py +0 -0
  17. api_server/services/terminal_service.py +60 -0
  18. api_server/utils/file_operations.py +42 -0
  19. app.py +553 -621
  20. app/__init__.py +0 -0
  21. app/app_settings.py +65 -0
  22. app/custom_node_manager.py +145 -0
  23. app/database/db.py +112 -0
  24. app/database/models.py +14 -0
  25. app/frontend_management.py +361 -0
  26. app/logger.py +98 -0
  27. app/model_manager.py +195 -0
  28. app/user_manager.py +438 -0
  29. comfy/checkpoint_pickle.py +13 -0
  30. comfy/cldm/cldm.py +433 -0
  31. comfy/cldm/control_types.py +10 -0
  32. comfy/cldm/dit_embedder.py +120 -0
  33. comfy/cldm/mmdit.py +81 -0
  34. comfy/cli_args.py +237 -0
  35. comfy/clip_config_bigg.json +23 -0
  36. comfy/clip_model.py +244 -0
  37. comfy/clip_vision.py +148 -0
  38. comfy/clip_vision_config_g.json +18 -0
  39. comfy/clip_vision_config_h.json +18 -0
  40. comfy/clip_vision_config_vitl.json +18 -0
  41. comfy/clip_vision_config_vitl_336.json +18 -0
  42. comfy/clip_vision_config_vitl_336_llava.json +19 -0
  43. comfy/clip_vision_siglip_384.json +13 -0
  44. comfy/clip_vision_siglip_512.json +13 -0
  45. comfy/comfy_types/README.md +43 -0
  46. comfy/comfy_types/__init__.py +46 -0
  47. comfy/comfy_types/examples/example_nodes.py +28 -0
  48. comfy/comfy_types/examples/input_options.png +0 -0
  49. comfy/comfy_types/examples/input_types.png +0 -0
  50. comfy/comfy_types/examples/required_hint.png +0 -0
.gitattributes CHANGED
@@ -1,35 +1,4 @@
- *.7z filter=lfs diff=lfs merge=lfs -text
- *.arrow filter=lfs diff=lfs merge=lfs -text
- *.bin filter=lfs diff=lfs merge=lfs -text
- *.bz2 filter=lfs diff=lfs merge=lfs -text
- *.ckpt filter=lfs diff=lfs merge=lfs -text
- *.ftz filter=lfs diff=lfs merge=lfs -text
- *.gz filter=lfs diff=lfs merge=lfs -text
- *.h5 filter=lfs diff=lfs merge=lfs -text
- *.joblib filter=lfs diff=lfs merge=lfs -text
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
- *.model filter=lfs diff=lfs merge=lfs -text
- *.msgpack filter=lfs diff=lfs merge=lfs -text
- *.npy filter=lfs diff=lfs merge=lfs -text
- *.npz filter=lfs diff=lfs merge=lfs -text
- *.onnx filter=lfs diff=lfs merge=lfs -text
- *.ot filter=lfs diff=lfs merge=lfs -text
- *.parquet filter=lfs diff=lfs merge=lfs -text
- *.pb filter=lfs diff=lfs merge=lfs -text
- *.pickle filter=lfs diff=lfs merge=lfs -text
- *.pkl filter=lfs diff=lfs merge=lfs -text
- *.pt filter=lfs diff=lfs merge=lfs -text
- *.pth filter=lfs diff=lfs merge=lfs -text
- *.rar filter=lfs diff=lfs merge=lfs -text
- *.safetensors filter=lfs diff=lfs merge=lfs -text
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
- *.tar.* filter=lfs diff=lfs merge=lfs -text
- *.tar filter=lfs diff=lfs merge=lfs -text
- *.tflite filter=lfs diff=lfs merge=lfs -text
- *.tgz filter=lfs diff=lfs merge=lfs -text
- *.wasm filter=lfs diff=lfs merge=lfs -text
- *.xz filter=lfs diff=lfs merge=lfs -text
- *.zip filter=lfs diff=lfs merge=lfs -text
- *.zst filter=lfs diff=lfs merge=lfs -text
- *tfevents* filter=lfs diff=lfs merge=lfs -text
+ /web/assets/** linguist-generated
+ /web/** linguist-vendored
+ comfy_api_nodes/apis/__init__.py linguist-generated
+ comfy/text_encoders/t5_pile_tokenizer/tokenizer.model filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,26 @@
+ __pycache__/
+ *.py[cod]
+ /output/
+ /input/
+ !/input/example.png
+ /models/
+ /temp/
+ /custom_nodes/
+ !custom_nodes/example_node.py.example
+ extra_model_paths.yaml
+ /.vs
+ .vscode/
+ .idea/
+ venv/
+ .venv/
+ /web/extensions/*
+ !/web/extensions/logging.js.example
+ !/web/extensions/core/
+ /tests-ui/data/object_info.json
+ /user/
+ *.log
+ web_custom_versions/
+ .DS_Store
+ openapi.yaml
+ filtered-openapi.yaml
+ uv.lock
CODEOWNERS ADDED
@@ -0,0 +1,24 @@
+ # Admins
+ * @comfyanonymous
+
+ # Note: Github teams syntax cannot be used here as the repo is not owned by Comfy-Org.
+ # Inlined the team members for now.
+
+ # Maintainers
+ *.md @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
+ /tests/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
+ /tests-unit/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
+ /notebooks/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
+ /script_examples/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
+ /.github/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
+ /requirements.txt @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
+ /pyproject.toml @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
+
+ # Python web server
+ /api_server/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @christian-byrne
+ /app/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @christian-byrne
+ /utils/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @christian-byrne
+
+ # Node developers
+ /comfy_extras/ @yoland68 @robinjhuang @pythongosssss @ltdrdata @Kosinkadink @webfiltered @christian-byrne
+ /comfy/comfy_types/ @yoland68 @robinjhuang @pythongosssss @ltdrdata @Kosinkadink @webfiltered @christian-byrne
CONTRIBUTING.md ADDED
@@ -0,0 +1,41 @@
+ # Contributing to ComfyUI
+
+ Welcome, and thank you for your interest in contributing to ComfyUI!
+
+ There are several ways in which you can contribute, beyond writing code. The goal of this document is to provide a high-level overview of how you can get involved.
+
+ ## Asking Questions
+
+ Have a question? Instead of opening an issue, please ask on [Discord](https://comfy.org/discord) or [Matrix](https://app.element.io/#/room/%23comfyui_space%3Amatrix.org) channels. Our team and the community will help you.
+
+ ## Providing Feedback
+
+ Your comments and feedback are welcome, and the development team is available via a handful of different channels.
+
+ See the `#bug-report`, `#feature-request` and `#feedback` channels on Discord.
+
+ ## Reporting Issues
+
+ Have you identified a reproducible problem in ComfyUI? Do you have a feature request? We want to hear about it! Here's how you can report your issue as effectively as possible.
+
+
+ ### Look For an Existing Issue
+
+ Before you create a new issue, please do a search in [open issues](https://github.com/comfyanonymous/ComfyUI/issues) to see if the issue or feature request has already been filed.
+
+ If you find your issue already exists, make relevant comments and add your [reaction](https://github.com/blog/2119-add-reactions-to-pull-requests-issues-and-comments). Use a reaction in place of a "+1" comment:
+
+ * 👍 - upvote
+ * 👎 - downvote
+
+ If you cannot find an existing issue that describes your bug or feature, create a new issue. We have an issue template in place to organize new issues.
+
+
+ ### Creating Pull Requests
+
+ * Please refer to the article on [creating pull requests](https://github.com/comfyanonymous/ComfyUI/wiki/How-to-Contribute-Code) and contributing to this project.
+
+
+ ## Thank You
+
+ Your contributions to open source, large or small, make great projects like this possible. Thank you for taking the time to contribute.
LICENSE ADDED
@@ -0,0 +1,674 @@
+ GNU GENERAL PUBLIC LICENSE
+ Version 3, 29 June 2007
+
+ Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
+ Everyone is permitted to copy and distribute verbatim copies
+ of this license document, but changing it is not allowed.
+
+ Preamble
+
+ The GNU General Public License is a free, copyleft license for
+ software and other kinds of works.
+
+ The licenses for most software and other practical works are designed
+ to take away your freedom to share and change the works. By contrast,
+ the GNU General Public License is intended to guarantee your freedom to
+ share and change all versions of a program--to make sure it remains free
+ software for all its users. We, the Free Software Foundation, use the
+ GNU General Public License for most of our software; it applies also to
+ any other work released this way by its authors. You can apply it to
+ your programs, too.
+
+ When we speak of free software, we are referring to freedom, not
+ price. Our General Public Licenses are designed to make sure that you
+ have the freedom to distribute copies of free software (and charge for
+ them if you wish), that you receive source code or can get it if you
+ want it, that you can change the software or use pieces of it in new
+ free programs, and that you know you can do these things.
+
+ To protect your rights, we need to prevent others from denying you
+ these rights or asking you to surrender the rights. Therefore, you have
+ certain responsibilities if you distribute copies of the software, or if
+ you modify it: responsibilities to respect the freedom of others.
+
+ For example, if you distribute copies of such a program, whether
+ gratis or for a fee, you must pass on to the recipients the same
+ freedoms that you received. You must make sure that they, too, receive
+ or can get the source code. And you must show them these terms so they
+ know their rights.
+
+ Developers that use the GNU GPL protect your rights with two steps:
+ (1) assert copyright on the software, and (2) offer you this License
+ giving you legal permission to copy, distribute and/or modify it.
+
+ For the developers' and authors' protection, the GPL clearly explains
+ that there is no warranty for this free software. For both users' and
+ authors' sake, the GPL requires that modified versions be marked as
+ changed, so that their problems will not be attributed erroneously to
+ authors of previous versions.
+
+ Some devices are designed to deny users access to install or run
+ modified versions of the software inside them, although the manufacturer
+ can do so. This is fundamentally incompatible with the aim of
+ protecting users' freedom to change the software. The systematic
+ pattern of such abuse occurs in the area of products for individuals to
+ use, which is precisely where it is most unacceptable. Therefore, we
+ have designed this version of the GPL to prohibit the practice for those
+ products. If such problems arise substantially in other domains, we
+ stand ready to extend this provision to those domains in future versions
+ of the GPL, as needed to protect the freedom of users.
+
+ Finally, every program is threatened constantly by software patents.
+ States should not allow patents to restrict development and use of
+ software on general-purpose computers, but in those that do, we wish to
+ avoid the special danger that patents applied to a free program could
+ make it effectively proprietary. To prevent this, the GPL assures that
+ patents cannot be used to render the program non-free.
+
+ The precise terms and conditions for copying, distribution and
+ modification follow.
+
+ TERMS AND CONDITIONS
+
+ 0. Definitions.
+
+ "This License" refers to version 3 of the GNU General Public License.
+
+ "Copyright" also means copyright-like laws that apply to other kinds of
+ works, such as semiconductor masks.
+
+ "The Program" refers to any copyrightable work licensed under this
+ License. Each licensee is addressed as "you". "Licensees" and
+ "recipients" may be individuals or organizations.
+
+ To "modify" a work means to copy from or adapt all or part of the work
+ in a fashion requiring copyright permission, other than the making of an
+ exact copy. The resulting work is called a "modified version" of the
+ earlier work or a work "based on" the earlier work.
+
+ A "covered work" means either the unmodified Program or a work based
+ on the Program.
+
+ To "propagate" a work means to do anything with it that, without
+ permission, would make you directly or secondarily liable for
+ infringement under applicable copyright law, except executing it on a
+ computer or modifying a private copy. Propagation includes copying,
+ distribution (with or without modification), making available to the
+ public, and in some countries other activities as well.
+
+ To "convey" a work means any kind of propagation that enables other
+ parties to make or receive copies. Mere interaction with a user through
+ a computer network, with no transfer of a copy, is not conveying.
+
+ An interactive user interface displays "Appropriate Legal Notices"
+ to the extent that it includes a convenient and prominently visible
+ feature that (1) displays an appropriate copyright notice, and (2)
+ tells the user that there is no warranty for the work (except to the
+ extent that warranties are provided), that licensees may convey the
+ work under this License, and how to view a copy of this License. If
+ the interface presents a list of user commands or options, such as a
+ menu, a prominent item in the list meets this criterion.
+
+ 1. Source Code.
+
+ The "source code" for a work means the preferred form of the work
+ for making modifications to it. "Object code" means any non-source
+ form of a work.
+
+ A "Standard Interface" means an interface that either is an official
+ standard defined by a recognized standards body, or, in the case of
+ interfaces specified for a particular programming language, one that
+ is widely used among developers working in that language.
+
+ The "System Libraries" of an executable work include anything, other
+ than the work as a whole, that (a) is included in the normal form of
+ packaging a Major Component, but which is not part of that Major
+ Component, and (b) serves only to enable use of the work with that
+ Major Component, or to implement a Standard Interface for which an
+ implementation is available to the public in source code form. A
+ "Major Component", in this context, means a major essential component
+ (kernel, window system, and so on) of the specific operating system
+ (if any) on which the executable work runs, or a compiler used to
+ produce the work, or an object code interpreter used to run it.
+
+ The "Corresponding Source" for a work in object code form means all
+ the source code needed to generate, install, and (for an executable
+ work) run the object code and to modify the work, including scripts to
+ control those activities. However, it does not include the work's
+ System Libraries, or general-purpose tools or generally available free
+ programs which are used unmodified in performing those activities but
+ which are not part of the work. For example, Corresponding Source
+ includes interface definition files associated with source files for
+ the work, and the source code for shared libraries and dynamically
+ linked subprograms that the work is specifically designed to require,
+ such as by intimate data communication or control flow between those
+ subprograms and other parts of the work.
+
+ The Corresponding Source need not include anything that users
+ can regenerate automatically from other parts of the Corresponding
+ Source.
+
+ The Corresponding Source for a work in source code form is that
+ same work.
+
+ 2. Basic Permissions.
+
+ All rights granted under this License are granted for the term of
+ copyright on the Program, and are irrevocable provided the stated
+ conditions are met. This License explicitly affirms your unlimited
+ permission to run the unmodified Program. The output from running a
+ covered work is covered by this License only if the output, given its
+ content, constitutes a covered work. This License acknowledges your
+ rights of fair use or other equivalent, as provided by copyright law.
+
+ You may make, run and propagate covered works that you do not
+ convey, without conditions so long as your license otherwise remains
+ in force. You may convey covered works to others for the sole purpose
+ of having them make modifications exclusively for you, or provide you
+ with facilities for running those works, provided that you comply with
+ the terms of this License in conveying all material for which you do
+ not control copyright. Those thus making or running the covered works
+ for you must do so exclusively on your behalf, under your direction
+ and control, on terms that prohibit them from making any copies of
+ your copyrighted material outside their relationship with you.
+
+ Conveying under any other circumstances is permitted solely under
+ the conditions stated below. Sublicensing is not allowed; section 10
+ makes it unnecessary.
+
+ 3. Protecting Users' Legal Rights From Anti-Circumvention Law.
+
+ No covered work shall be deemed part of an effective technological
+ measure under any applicable law fulfilling obligations under article
+ 11 of the WIPO copyright treaty adopted on 20 December 1996, or
+ similar laws prohibiting or restricting circumvention of such
+ measures.
+
+ When you convey a covered work, you waive any legal power to forbid
+ circumvention of technological measures to the extent such circumvention
+ is effected by exercising rights under this License with respect to
+ the covered work, and you disclaim any intention to limit operation or
+ modification of the work as a means of enforcing, against the work's
+ users, your or third parties' legal rights to forbid circumvention of
+ technological measures.
+
+ 4. Conveying Verbatim Copies.
+
+ You may convey verbatim copies of the Program's source code as you
+ receive it, in any medium, provided that you conspicuously and
+ appropriately publish on each copy an appropriate copyright notice;
+ keep intact all notices stating that this License and any
+ non-permissive terms added in accord with section 7 apply to the code;
+ keep intact all notices of the absence of any warranty; and give all
+ recipients a copy of this License along with the Program.
+
+ You may charge any price or no price for each copy that you convey,
+ and you may offer support or warranty protection for a fee.
+
+ 5. Conveying Modified Source Versions.
+
+ You may convey a work based on the Program, or the modifications to
+ produce it from the Program, in the form of source code under the
+ terms of section 4, provided that you also meet all of these conditions:
+
+ a) The work must carry prominent notices stating that you modified
+ it, and giving a relevant date.
+
+ b) The work must carry prominent notices stating that it is
+ released under this License and any conditions added under section
+ 7. This requirement modifies the requirement in section 4 to
+ "keep intact all notices".
+
+ c) You must license the entire work, as a whole, under this
+ License to anyone who comes into possession of a copy. This
+ License will therefore apply, along with any applicable section 7
+ additional terms, to the whole of the work, and all its parts,
+ regardless of how they are packaged. This License gives no
+ permission to license the work in any other way, but it does not
+ invalidate such permission if you have separately received it.
+
+ d) If the work has interactive user interfaces, each must display
+ Appropriate Legal Notices; however, if the Program has interactive
+ interfaces that do not display Appropriate Legal Notices, your
+ work need not make them do so.
+
+ A compilation of a covered work with other separate and independent
+ works, which are not by their nature extensions of the covered work,
+ and which are not combined with it such as to form a larger program,
+ in or on a volume of a storage or distribution medium, is called an
+ "aggregate" if the compilation and its resulting copyright are not
+ used to limit the access or legal rights of the compilation's users
+ beyond what the individual works permit. Inclusion of a covered work
+ in an aggregate does not cause this License to apply to the other
+ parts of the aggregate.
+
+ 6. Conveying Non-Source Forms.
+
+ You may convey a covered work in object code form under the terms
+ of sections 4 and 5, provided that you also convey the
+ machine-readable Corresponding Source under the terms of this License,
+ in one of these ways:
+
+ a) Convey the object code in, or embodied in, a physical product
+ (including a physical distribution medium), accompanied by the
+ Corresponding Source fixed on a durable physical medium
+ customarily used for software interchange.
+
+ b) Convey the object code in, or embodied in, a physical product
+ (including a physical distribution medium), accompanied by a
+ written offer, valid for at least three years and valid for as
+ long as you offer spare parts or customer support for that product
+ model, to give anyone who possesses the object code either (1) a
+ copy of the Corresponding Source for all the software in the
+ product that is covered by this License, on a durable physical
+ medium customarily used for software interchange, for a price no
+ more than your reasonable cost of physically performing this
+ conveying of source, or (2) access to copy the
+ Corresponding Source from a network server at no charge.
+
+ c) Convey individual copies of the object code with a copy of the
+ written offer to provide the Corresponding Source. This
+ alternative is allowed only occasionally and noncommercially, and
+ only if you received the object code with such an offer, in accord
+ with subsection 6b.
+
+ d) Convey the object code by offering access from a designated
+ place (gratis or for a charge), and offer equivalent access to the
+ Corresponding Source in the same way through the same place at no
+ further charge. You need not require recipients to copy the
+ Corresponding Source along with the object code. If the place to
+ copy the object code is a network server, the Corresponding Source
+ may be on a different server (operated by you or a third party)
+ that supports equivalent copying facilities, provided you maintain
+ clear directions next to the object code saying where to find the
+ Corresponding Source. Regardless of what server hosts the
+ Corresponding Source, you remain obligated to ensure that it is
+ available for as long as needed to satisfy these requirements.
+
+ e) Convey the object code using peer-to-peer transmission, provided
+ you inform other peers where the object code and Corresponding
+ Source of the work are being offered to the general public at no
+ charge under subsection 6d.
+
+ A separable portion of the object code, whose source code is excluded
+ from the Corresponding Source as a System Library, need not be
+ included in conveying the object code work.
+
+ A "User Product" is either (1) a "consumer product", which means any
+ tangible personal property which is normally used for personal, family,
+ or household purposes, or (2) anything designed or sold for incorporation
+ into a dwelling. In determining whether a product is a consumer product,
+ doubtful cases shall be resolved in favor of coverage. For a particular
+ product received by a particular user, "normally used" refers to a
+ typical or common use of that class of product, regardless of the status
+ of the particular user or of the way in which the particular user
+ actually uses, or expects or is expected to use, the product. A product
+ is a consumer product regardless of whether the product has substantial
+ commercial, industrial or non-consumer uses, unless such uses represent
+ the only significant mode of use of the product.
+
+ "Installation Information" for a User Product means any methods,
+ procedures, authorization keys, or other information required to install
+ and execute modified versions of a covered work in that User Product from
+ a modified version of its Corresponding Source. The information must
+ suffice to ensure that the continued functioning of the modified object
+ code is in no case prevented or interfered with solely because
+ modification has been made.
+
+ If you convey an object code work under this section in, or with, or
+ specifically for use in, a User Product, and the conveying occurs as
+ part of a transaction in which the right of possession and use of the
+ User Product is transferred to the recipient in perpetuity or for a
+ fixed term (regardless of how the transaction is characterized), the
+ Corresponding Source conveyed under this section must be accompanied
+ by the Installation Information. But this requirement does not apply
+ if neither you nor any third party retains the ability to install
+ modified object code on the User Product (for example, the work has
+ been installed in ROM).
+
+ The requirement to provide Installation Information does not include a
+ requirement to continue to provide support service, warranty, or updates
+ for a work that has been modified or installed by the recipient, or for
+ the User Product in which it has been modified or installed. Access to a
+ network may be denied when the modification itself materially and
+ adversely affects the operation of the network or violates the rules and
+ protocols for communication across the network.
+
+ Corresponding Source conveyed, and Installation Information provided,
+ in accord with this section must be in a format that is publicly
+ documented (and with an implementation available to the public in
+ source code form), and must require no special password or key for
+ unpacking, reading or copying.
+
+ 7. Additional Terms.
+
+ "Additional permissions" are terms that supplement the terms of this
+ License by making exceptions from one or more of its conditions.
+ Additional permissions that are applicable to the entire Program shall
+ be treated as though they were included in this License, to the extent
+ that they are valid under applicable law. If additional permissions
+ apply only to part of the Program, that part may be used separately
+ under those permissions, but the entire Program remains governed by
+ this License without regard to the additional permissions.
+
+ When you convey a copy of a covered work, you may at your option
+ remove any additional permissions from that copy, or from any part of
+ it. (Additional permissions may be written to require their own
+ removal in certain cases when you modify the work.) You may place
+ additional permissions on material, added by you to a covered work,
+ for which you have or can give appropriate copyright permission.
+
+ Notwithstanding any other provision of this License, for material you
+ add to a covered work, you may (if authorized by the copyright holders of
+ that material) supplement the terms of this License with terms:
+
+ a) Disclaiming warranty or limiting liability differently from the
+ terms of sections 15 and 16 of this License; or
+
+ b) Requiring preservation of specified reasonable legal notices or
+ author attributions in that material or in the Appropriate Legal
+ Notices displayed by works containing it; or
+
+ c) Prohibiting misrepresentation of the origin of that material, or
+ requiring that modified versions of such material be marked in
+ reasonable ways as different from the original version; or
+
+ d) Limiting the use for publicity purposes of names of licensors or
+ authors of the material; or
+
+ e) Declining to grant rights under trademark law for use of some
+ trade names, trademarks, or service marks; or
+
+ f) Requiring indemnification of licensors and authors of that
+ material by anyone who conveys the material (or modified versions of
+ it) with contractual assumptions of liability to the recipient, for
+ any liability that these contractual assumptions directly impose on
+ those licensors and authors.
+
+ All other non-permissive additional terms are considered "further
+ restrictions" within the meaning of section 10. If the Program as you
+ received it, or any part of it, contains a notice stating that it is
+ governed by this License along with a term that is a further
+ restriction, you may remove that term. If a license document contains
+ a further restriction but permits relicensing or conveying under this
+ License, you may add to a covered work material governed by the terms
+ of that license document, provided that the further restriction does
+ not survive such relicensing or conveying.
+
+ If you add terms to a covered work in accord with this section, you
+ must place, in the relevant source files, a statement of the
+ additional terms that apply to those files, or a notice indicating
+ where to find the applicable terms.
+
+ Additional terms, permissive or non-permissive, may be stated in the
+ form of a separately written license, or stated as exceptions;
+ the above requirements apply either way.
+
+ 8. Termination.
+
+ You may not propagate or modify a covered work except as expressly
+ provided under this License. Any attempt otherwise to propagate or
+ modify it is void, and will automatically terminate your rights under
+ this License (including any patent licenses granted under the third
+ paragraph of section 11).
+
+ However, if you cease all violation of this License, then your
+ license from a particular copyright holder is reinstated (a)
+ provisionally, unless and until the copyright holder explicitly and
+ finally terminates your license, and (b) permanently, if the copyright
+ holder fails to notify you of the violation by some reasonable means
+ prior to 60 days after the cessation.
+
+ Moreover, your license from a particular copyright holder is
+ reinstated permanently if the copyright holder notifies you of the
+ violation by some reasonable means, this is the first time you have
+ received notice of violation of this License (for any work) from that
+ copyright holder, and you cure the violation prior to 30 days after
+ your receipt of the notice.
+
+ Termination of your rights under this section does not terminate the
+ licenses of parties who have received copies or rights from you under
+ this License. If your rights have been terminated and not permanently
+ reinstated, you do not qualify to receive new licenses for the same
+ material under section 10.
+
+ 9. Acceptance Not Required for Having Copies.
+
+ You are not required to accept this License in order to receive or
+ run a copy of the Program. Ancillary propagation of a covered work
+ occurring solely as a consequence of using peer-to-peer transmission
+ to receive a copy likewise does not require acceptance. However,
+ nothing other than this License grants you permission to propagate or
+ modify any covered work. These actions infringe copyright if you do
+ not accept this License. Therefore, by modifying or propagating a
+ covered work, you indicate your acceptance of this License to do so.
+
+ 10. Automatic Licensing of Downstream Recipients.
+
+ Each time you convey a covered work, the recipient automatically
+ receives a license from the original licensors, to run, modify and
+ propagate that work, subject to this License. You are not responsible
+ for enforcing compliance by third parties with this License.
+
+ An "entity transaction" is a transaction transferring control of an
+ organization, or substantially all assets of one, or subdividing an
+ organization, or merging organizations. If propagation of a covered
+ work results from an entity transaction, each party to that
+ transaction who receives a copy of the work also receives whatever
+ licenses to the work the party's predecessor in interest had or could
+ give under the previous paragraph, plus a right to possession of the
+ Corresponding Source of the work from the predecessor in interest, if
+ the predecessor has it or can get it with reasonable efforts.
+
+ You may not impose any further restrictions on the exercise of the
+ rights granted or affirmed under this License. For example, you may
+ not impose a license fee, royalty, or other charge for exercise of
+ rights granted under this License, and you may not initiate litigation
+ (including a cross-claim or counterclaim in a lawsuit) alleging that
+ any patent claim is infringed by making, using, selling, offering for
+ sale, or importing the Program or any portion of it.
+
+ 11. Patents.
+
+ A "contributor" is a copyright holder who authorizes use under this
+ License of the Program or a work on which the Program is based. The
+ work thus licensed is called the contributor's "contributor version".
+
+ A contributor's "essential patent claims" are all patent claims
+ owned or controlled by the contributor, whether already acquired or
+ hereafter acquired, that would be infringed by some manner, permitted
+ by this License, of making, using, or selling its contributor version,
+ but do not include claims that would be infringed only as a
+ consequence of further modification of the contributor version. For
+ purposes of this definition, "control" includes the right to grant
+ patent sublicenses in a manner consistent with the requirements of
+ this License.
+
+ Each contributor grants you a non-exclusive, worldwide, royalty-free
+ patent license under the contributor's essential patent claims, to
+ make, use, sell, offer for sale, import and otherwise run, modify and
+ propagate the contents of its contributor version.
+
+ In the following three paragraphs, a "patent license" is any express
+ agreement or commitment, however denominated, not to enforce a patent
+ (such as an express permission to practice a patent or covenant not to
+ sue for patent infringement). To "grant" such a patent license to a
+ party means to make such an agreement or commitment not to enforce a
+ patent against the party.
+
+ If you convey a covered work, knowingly relying on a patent license,
+ and the Corresponding Source of the work is not available for anyone
+ to copy, free of charge and under the terms of this License, through a
+ publicly available network server or other readily accessible means,
+ then you must either (1) cause the Corresponding Source to be so
+ available, or (2) arrange to deprive yourself of the benefit of the
+ patent license for this particular work, or (3) arrange, in a manner
+ consistent with the requirements of this License, to extend the patent
+ license to downstream recipients. "Knowingly relying" means you have
+ actual knowledge that, but for the patent license, your conveying the
+ covered work in a country, or your recipient's use of the covered work
+ in a country, would infringe one or more identifiable patents in that
+ country that you have reason to believe are valid.
+
+ If, pursuant to or in connection with a single transaction or
+ arrangement, you convey, or propagate by procuring conveyance of, a
+ covered work, and grant a patent license to some of the parties
+ receiving the covered work authorizing them to use, propagate, modify
+ or convey a specific copy of the covered work, then the patent license
+ you grant is automatically extended to all recipients of the covered
+ work and works based on it.
+
+ A patent license is "discriminatory" if it does not include within
+ the scope of its coverage, prohibits the exercise of, or is
+ conditioned on the non-exercise of one or more of the rights that are
+ specifically granted under this License. You may not convey a covered
+ work if you are a party to an arrangement with a third party that is
+ in the business of distributing software, under which you make payment
+ to the third party based on the extent of your activity of conveying
+ the work, and under which the third party grants, to any of the
+ parties who would receive the covered work from you, a discriminatory
+ patent license (a) in connection with copies of the covered work
+ conveyed by you (or copies made from those copies), or (b) primarily
+ for and in connection with specific products or compilations that
+ contain the covered work, unless you entered into that arrangement,
+ or that patent license was granted, prior to 28 March 2007.
+
+ Nothing in this License shall be construed as excluding or limiting
+ any implied license or other defenses to infringement that may
+ otherwise be available to you under applicable patent law.
+
+ 12. No Surrender of Others' Freedom.
+
+ If conditions are imposed on you (whether by court order, agreement or
+ otherwise) that contradict the conditions of this License, they do not
+ excuse you from the conditions of this License. If you cannot convey a
+ covered work so as to satisfy simultaneously your obligations under this
+ License and any other pertinent obligations, then as a consequence you may
+ not convey it at all. For example, if you agree to terms that obligate you
+ to collect a royalty for further conveying from those to whom you convey
+ the Program, the only way you could satisfy both those terms and this
+ License would be to refrain entirely from conveying the Program.
+
+ 13. Use with the GNU Affero General Public License.
+
+ Notwithstanding any other provision of this License, you have
+ permission to link or combine any covered work with a work licensed
+ under version 3 of the GNU Affero General Public License into a single
+ combined work, and to convey the resulting work. The terms of this
+ License will continue to apply to the part which is the covered work,
+ but the special requirements of the GNU Affero General Public License,
+ section 13, concerning interaction through a network will apply to the
+ combination as such.
+
+ 14. Revised Versions of this License.
+
+ The Free Software Foundation may publish revised and/or new versions of
+ the GNU General Public License from time to time. Such new versions will
+ be similar in spirit to the present version, but may differ in detail to
+ address new problems or concerns.
+
+ Each version is given a distinguishing version number. If the
+ Program specifies that a certain numbered version of the GNU General
+ Public License "or any later version" applies to it, you have the
+ option of following the terms and conditions either of that numbered
+ version or of any later version published by the Free Software
+ Foundation. If the Program does not specify a version number of the
+ GNU General Public License, you may choose any version ever published
+ by the Free Software Foundation.
+
+ If the Program specifies that a proxy can decide which future
+ versions of the GNU General Public License can be used, that proxy's
+ public statement of acceptance of a version permanently authorizes you
+ to choose that version for the Program.
+
+ Later license versions may give you additional or different
+ permissions. However, no additional obligations are imposed on any
+ author or copyright holder as a result of your choosing to follow a
+ later version.
+
+ 15. Disclaimer of Warranty.
+
+ THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
+ APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
+ HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
+ OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
+ THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+ PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
+ IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
+ ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
+
+ 16. Limitation of Liability.
+
+ IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
+ WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
+ THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
+ GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
+ USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
+ DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
+ PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
+ EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
+ SUCH DAMAGES.
+
+ 17. Interpretation of Sections 15 and 16.
+
+ If the disclaimer of warranty and limitation of liability provided
+ above cannot be given local legal effect according to their terms,
+ reviewing courts shall apply local law that most closely approximates
+ an absolute waiver of all civil liability in connection with the
+ Program, unless a warranty or assumption of liability accompanies a
+ copy of the Program in return for a fee.
+
+ END OF TERMS AND CONDITIONS
+
+ How to Apply These Terms to Your New Programs
+
+ If you develop a new program, and you want it to be of the greatest
+ possible use to the public, the best way to achieve this is to make it
+ free software which everyone can redistribute and change under these terms.
+
+ To do so, attach the following notices to the program. It is safest
+ to attach them to the start of each source file to most effectively
+ state the exclusion of warranty; and each file should have at least
+ the "copyright" line and a pointer to where the full notice is found.
+
+ <one line to give the program's name and a brief idea of what it does.>
+ Copyright (C) <year> <name of author>
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+ Also add information on how to contact you by electronic and paper mail.
+
+ If the program does terminal interaction, make it output a short
+ notice like this when it starts in an interactive mode:
+
+ <program> Copyright (C) <year> <name of author>
+ This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
+ This is free software, and you are welcome to redistribute it
+ under certain conditions; type `show c' for details.
+
+ The hypothetical commands `show w' and `show c' should show the appropriate
+ parts of the General Public License. Of course, your program's commands
+ might be different; for a GUI interface, you would use an "about box".
+
+ You should also get your employer (if you work as a programmer) or school,
+ if any, to sign a "copyright disclaimer" for the program, if necessary.
+ For more information on this, and how to apply and follow the GNU GPL, see
+ <https://www.gnu.org/licenses/>.
+
+ The GNU General Public License does not permit incorporating your program
+ into proprietary programs. If your program is a subroutine library, you
+ may consider it more useful to permit linking proprietary applications with
+ the library. If this is what you want to do, use the GNU Lesser General
+ Public License instead of this License. But first, please read
+ <https://www.gnu.org/licenses/why-not-lgpl.html>.
README.md CHANGED
@@ -1,9 +1,9 @@
- ---
- title: Animated SDXL T2I with LoRAs
- emoji: 🖼
- colorFrom: purple
- colorTo: red
- sdk: gradio
- app_file: app.py
- pinned: true
- ---
+ ---
+ title: Animated T2I with LoRAs
+ emoji: 🖼
+ colorFrom: purple
+ colorTo: red
+ sdk: gradio
+ app_file: app.py
+ pinned: true
+ ---
alembic.ini ADDED
@@ -0,0 +1,84 @@
+ # A generic, single database configuration.
+
+ [alembic]
+ # path to migration scripts
+ # Use forward slashes (/) also on windows to provide an os agnostic path
+ script_location = alembic_db
+
+ # template used to generate migration file names; The default value is %%(rev)s_%%(slug)s
+ # Uncomment the line below if you want the files to be prepended with date and time
+ # see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file
+ # for all available tokens
+ # file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s
+
+ # sys.path path, will be prepended to sys.path if present.
+ # defaults to the current working directory.
+ prepend_sys_path = .
+
+ # timezone to use when rendering the date within the migration file
+ # as well as the filename.
+ # If specified, requires the python>=3.9 or backports.zoneinfo library and tzdata library.
+ # Any required deps can be installed by adding `alembic[tz]` to the pip requirements
+ # string value is passed to ZoneInfo()
+ # leave blank for localtime
+ # timezone =
+
+ # max length of characters to apply to the "slug" field
+ # truncate_slug_length = 40
+
+ # set to 'true' to run the environment during
+ # the 'revision' command, regardless of autogenerate
+ # revision_environment = false
+
+ # set to 'true' to allow .pyc and .pyo files without
+ # a source .py file to be detected as revisions in the
+ # versions/ directory
+ # sourceless = false
+
+ # version location specification; This defaults
+ # to alembic_db/versions. When using multiple version
+ # directories, initial revisions must be specified with --version-path.
+ # The path separator used here should be the separator specified by "version_path_separator" below.
+ # version_locations = %(here)s/bar:%(here)s/bat:alembic_db/versions
+
+ # version path separator; As mentioned above, this is the character used to split
+ # version_locations. The default within new alembic.ini files is "os", which uses os.pathsep.
+ # If this key is omitted entirely, it falls back to the legacy behavior of splitting on spaces and/or commas.
+ # Valid values for version_path_separator are:
+ #
+ # version_path_separator = :
+ # version_path_separator = ;
+ # version_path_separator = space
+ # version_path_separator = newline
+ #
+ # Use os.pathsep. Default configuration used for new projects.
+ version_path_separator = os
+
+ # set to 'true' to search source files recursively
+ # in each "version_locations" directory
+ # new in Alembic version 1.10
+ # recursive_version_locations = false
+
+ # the output encoding used when revision files
+ # are written from script.py.mako
+ # output_encoding = utf-8
+
+ sqlalchemy.url = sqlite:///user/comfyui.db
+
+
+ [post_write_hooks]
+ # post_write_hooks defines scripts or Python functions that are run
+ # on newly generated revision scripts. See the documentation for further
+ # detail and examples
+
+ # format using "black" - use the console_scripts runner, against the "black" entrypoint
+ # hooks = black
+ # black.type = console_scripts
+ # black.entrypoint = black
+ # black.options = -l 79 REVISION_SCRIPT_FILENAME
+
+ # lint with attempts to fix using "ruff" - use the exec runner, execute a binary
+ # hooks = ruff
+ # ruff.type = exec
+ # ruff.executable = %(here)s/.venv/bin/ruff
+ # ruff.options = check --fix REVISION_SCRIPT_FILENAME
alembic_db/README.md ADDED
@@ -0,0 +1,4 @@
+ ## Generate new revision
+
+ 1. Update models in `/app/database/models.py`
+ 2. Run `alembic revision --autogenerate -m "{your message}"`
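For illustration, a minimal sketch of that two-step flow driven from Python rather than the shell; the revision message is hypothetical, and this assumes the `alembic` package is installed and is run from the repository root where `alembic.ini` lives:

```python
# Hypothetical sketch: generate and apply a migration programmatically.
# alembic.config.main is the same entry point the `alembic` CLI uses.
from alembic.config import main as alembic_main

# 1. After editing app/database/models.py, autogenerate a revision
#    (message is invented for illustration):
alembic_main(argv=["revision", "--autogenerate", "-m", "add example table"])

# 2. Apply the new revision to the database configured in alembic.ini
#    (sqlite:///user/comfyui.db per the file above):
alembic_main(argv=["upgrade", "head"])
```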
alembic_db/env.py ADDED
@@ -0,0 +1,64 @@
+ from sqlalchemy import engine_from_config
+ from sqlalchemy import pool
+
+ from alembic import context
+
+ # this is the Alembic Config object, which provides
+ # access to the values within the .ini file in use.
+ config = context.config
+
+
+ from app.database.models import Base
+ target_metadata = Base.metadata
+
+ # other values from the config, defined by the needs of env.py,
+ # can be acquired:
+ # my_important_option = config.get_main_option("my_important_option")
+ # ... etc.
+
+
+ def run_migrations_offline() -> None:
+     """Run migrations in 'offline' mode.
+     This configures the context with just a URL
+     and not an Engine, though an Engine is acceptable
+     here as well. By skipping the Engine creation
+     we don't even need a DBAPI to be available.
+     Calls to context.execute() here emit the given string to the
+     script output.
+     """
+     url = config.get_main_option("sqlalchemy.url")
+     context.configure(
+         url=url,
+         target_metadata=target_metadata,
+         literal_binds=True,
+         dialect_opts={"paramstyle": "named"},
+     )
+
+     with context.begin_transaction():
+         context.run_migrations()
+
+
+ def run_migrations_online() -> None:
+     """Run migrations in 'online' mode.
+     In this scenario we need to create an Engine
+     and associate a connection with the context.
+     """
+     connectable = engine_from_config(
+         config.get_section(config.config_ini_section, {}),
+         prefix="sqlalchemy.",
+         poolclass=pool.NullPool,
+     )
+
+     with connectable.connect() as connection:
+         context.configure(
+             connection=connection, target_metadata=target_metadata
+         )
+
+         with context.begin_transaction():
+             context.run_migrations()
+
+
+ if context.is_offline_mode():
+     run_migrations_offline()
+ else:
+     run_migrations_online()
alembic_db/script.py.mako ADDED
@@ -0,0 +1,28 @@
+ """${message}
+
+ Revision ID: ${up_revision}
+ Revises: ${down_revision | comma,n}
+ Create Date: ${create_date}
+
+ """
+ from typing import Sequence, Union
+
+ from alembic import op
+ import sqlalchemy as sa
+ ${imports if imports else ""}
+
+ # revision identifiers, used by Alembic.
+ revision: str = ${repr(up_revision)}
+ down_revision: Union[str, None] = ${repr(down_revision)}
+ branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
+ depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}
+
+
+ def upgrade() -> None:
+     """Upgrade schema."""
+     ${upgrades if upgrades else "pass"}
+
+
+ def downgrade() -> None:
+     """Downgrade schema."""
+     ${downgrades if downgrades else "pass"}
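As a hedged illustration of what this Mako template produces, here is a hypothetical rendered revision file; the revision ID, date, and table are invented, and the real `upgrade`/`downgrade` bodies depend on the autogenerate diff:

```python
"""add example table

Revision ID: 1a2b3c4d5e6f
Revises:
Create Date: 2025-01-01 00:00:00.000000

"""
from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa

# revision identifiers, used by Alembic.
revision: str = '1a2b3c4d5e6f'
down_revision: Union[str, None] = None
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    """Upgrade schema."""
    # Hypothetical autogenerated operation.
    op.create_table(
        "example",
        sa.Column("id", sa.Integer, primary_key=True),
        sa.Column("name", sa.String(64), nullable=False),
    )


def downgrade() -> None:
    """Downgrade schema."""
    op.drop_table("example")
```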
api_server/__init__.py ADDED
File without changes
api_server/routes/__init__.py ADDED
File without changes
api_server/routes/internal/README.md ADDED
@@ -0,0 +1,3 @@
+ # ComfyUI Internal Routes
+
+ All routes under the `/internal` path are designated for **internal use by ComfyUI only**. These routes are not intended for use by external applications and may change at any time without notice.
api_server/routes/internal/__init__.py ADDED
File without changes
api_server/routes/internal/internal_routes.py ADDED
@@ -0,0 +1,73 @@
1
+ from aiohttp import web
2
+ from typing import Optional
3
+ from folder_paths import folder_names_and_paths, get_directory_by_type
4
+ from api_server.services.terminal_service import TerminalService
5
+ import app.logger
6
+ import os
7
+
8
+ class InternalRoutes:
9
+ '''
10
+ The top level web router for internal routes: /internal/*
11
+ The endpoints here should NOT be depended upon. It is for ComfyUI frontend use only.
12
+ Check README.md for more information.
13
+ '''
14
+
15
+ def __init__(self, prompt_server):
16
+ self.routes: web.RouteTableDef = web.RouteTableDef()
17
+ self._app: Optional[web.Application] = None
18
+ self.prompt_server = prompt_server
19
+ self.terminal_service = TerminalService(prompt_server)
20
+
21
+ def setup_routes(self):
22
+ @self.routes.get('/logs')
23
+ async def get_logs(request):
24
+ return web.json_response("".join([(l["t"] + " - " + l["m"]) for l in app.logger.get_logs()]))
25
+
26
+ @self.routes.get('/logs/raw')
27
+ async def get_raw_logs(request):
28
+ self.terminal_service.update_size()
29
+ return web.json_response({
30
+ "entries": list(app.logger.get_logs()),
31
+ "size": {"cols": self.terminal_service.cols, "rows": self.terminal_service.rows}
32
+ })
33
+
34
+ @self.routes.patch('/logs/subscribe')
35
+ async def subscribe_logs(request):
36
+ json_data = await request.json()
37
+ client_id = json_data["clientId"]
38
+ enabled = json_data["enabled"]
39
+ if enabled:
40
+ self.terminal_service.subscribe(client_id)
41
+ else:
42
+ self.terminal_service.unsubscribe(client_id)
43
+
44
+ return web.Response(status=200)
45
+
46
+
47
+ @self.routes.get('/folder_paths')
48
+ async def get_folder_paths(request):
49
+ response = {}
50
+ for key in folder_names_and_paths:
51
+ response[key] = folder_names_and_paths[key][0]
52
+ return web.json_response(response)
53
+
54
+ @self.routes.get('/files/{directory_type}')
55
+ async def get_files(request: web.Request) -> web.Response:
56
+ directory_type = request.match_info['directory_type']
57
+ if directory_type not in ("output", "input", "temp"):
58
+ return web.json_response({"error": "Invalid directory type"}, status=400)
59
+
60
+ directory = get_directory_by_type(directory_type)
61
+ sorted_files = sorted(
62
+ (entry for entry in os.scandir(directory) if entry.is_file()),
63
+ key=lambda entry: -entry.stat().st_mtime
64
+ )
65
+ return web.json_response([entry.name for entry in sorted_files], status=200)
66
+
67
+
68
+ def get_app(self):
69
+ if self._app is None:
70
+ self._app = web.Application()
71
+ self.setup_routes()
72
+ self._app.add_routes(self.routes)
73
+ return self._app
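
A minimal sketch of how this router is typically attached to the main server application; `prompt_server` stands in for the running `PromptServer` instance, and the mount point mirrors the `/internal/*` convention above:

```python
# Sketch: mount the internal routes as an aiohttp sub-application.
from aiohttp import web

def mount_internal(parent_app: web.Application, prompt_server) -> None:
    internal = InternalRoutes(prompt_server)
    parent_app.add_subapp("/internal", internal.get_app())
```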
api_server/services/__init__.py ADDED
File without changes
api_server/services/terminal_service.py ADDED
@@ -0,0 +1,60 @@
1
+ from app.logger import on_flush
2
+ import os
3
+ import shutil
4
+
5
+
6
+ class TerminalService:
7
+ def __init__(self, server):
8
+ self.server = server
9
+ self.cols = None
10
+ self.rows = None
11
+ self.subscriptions = set()
12
+ on_flush(self.send_messages)
13
+
14
+ def get_terminal_size(self):
15
+ try:
16
+ size = os.get_terminal_size()
17
+ return (size.columns, size.lines)
18
+ except OSError:
19
+ try:
20
+ size = shutil.get_terminal_size()
21
+ return (size.columns, size.lines)
22
+ except OSError:
23
+ return (80, 24) # fallback to 80x24
24
+
25
+ def update_size(self):
26
+ columns, lines = self.get_terminal_size()
27
+ changed = False
28
+
29
+ if columns != self.cols:
30
+ self.cols = columns
31
+ changed = True
32
+
33
+ if lines != self.rows:
34
+ self.rows = lines
35
+ changed = True
36
+
37
+ if changed:
38
+ return {"cols": self.cols, "rows": self.rows}
39
+
40
+ return None
41
+
42
+ def subscribe(self, client_id):
43
+ self.subscriptions.add(client_id)
44
+
45
+ def unsubscribe(self, client_id):
46
+ self.subscriptions.discard(client_id)
47
+
48
+ def send_messages(self, entries):
49
+ if not entries or not self.subscriptions:
50
+ return
51
+
52
+ new_size = self.update_size()
53
+
54
+ for client_id in self.subscriptions.copy(): # iterate over a copy to avoid "Set changed size during iteration"
55
+ if client_id not in self.server.sockets:
56
+ # Automatically unsubscribe if the socket has disconnected
57
+ self.unsubscribe(client_id)
58
+ continue
59
+
60
+ self.server.send_sync("logs", {"entries": entries, "size": new_size}, client_id)
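
A small usage sketch; `StubServer` is a hypothetical stand-in exposing only the two attributes `send_messages()` actually touches (`sockets` and `send_sync`), and it assumes the `app.logger` module is importable so `on_flush` can register the callback:

```python
# Hypothetical stub mirroring the PromptServer surface TerminalService uses.
class StubServer:
    def __init__(self):
        self.sockets = {"client-1": object()}  # connected websockets by client id

    def send_sync(self, event, data, client_id):
        print(event, data["size"], len(data["entries"]), client_id)

service = TerminalService(StubServer())
service.subscribe("client-1")
service.send_messages([{"t": "12:00:00", "m": "hello\n"}])  # pushes one log batch
```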
api_server/utils/file_operations.py ADDED
@@ -0,0 +1,42 @@
1
+ import os
2
+ from typing import List, Union, TypedDict, Literal
3
+ from typing_extensions import TypeGuard
4
+ class FileInfo(TypedDict):
5
+ name: str
6
+ path: str
7
+ type: Literal["file"]
8
+ size: int
9
+
10
+ class DirectoryInfo(TypedDict):
11
+ name: str
12
+ path: str
13
+ type: Literal["directory"]
14
+
15
+ FileSystemItem = Union[FileInfo, DirectoryInfo]
16
+
17
+ def is_file_info(item: FileSystemItem) -> TypeGuard[FileInfo]:
18
+ return item["type"] == "file"
19
+
20
+ class FileSystemOperations:
21
+ @staticmethod
22
+ def walk_directory(directory: str) -> List[FileSystemItem]:
23
+ file_list: List[FileSystemItem] = []
24
+ for root, dirs, files in os.walk(directory):
25
+ for name in files:
26
+ file_path = os.path.join(root, name)
27
+ relative_path = os.path.relpath(file_path, directory)
28
+ file_list.append({
29
+ "name": name,
30
+ "path": relative_path,
31
+ "type": "file",
32
+ "size": os.path.getsize(file_path)
33
+ })
34
+ for name in dirs:
35
+ dir_path = os.path.join(root, name)
36
+ relative_path = os.path.relpath(dir_path, directory)
37
+ file_list.append({
38
+ "name": name,
39
+ "path": relative_path,
40
+ "type": "directory"
41
+ })
42
+ return file_list
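
Usage sketch: the `TypeGuard` lets a static type checker narrow each item, so `item["size"]` is only accessed on entries it knows are `FileInfo`:

```python
# Sum file sizes in a directory tree using the helpers above.
items = FileSystemOperations.walk_directory("/tmp")
total = sum(item["size"] for item in items if is_file_info(item))
print(f"{total} bytes across {len(items)} entries")
```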
app.py CHANGED
@@ -1,621 +1,553 @@
1
- import spaces
2
- import gradio as gr
3
- import numpy as np
4
- import PIL.Image
5
- from PIL import Image, PngImagePlugin
6
- import random
7
- from diffusers import StableDiffusionXLPipeline, EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, DPMSolverMultistepScheduler, DDIMScheduler, UniPCMultistepScheduler, HeunDiscreteScheduler, LMSDiscreteScheduler
8
- import torch
9
- from compel import Compel, ReturnedEmbeddingsType
10
- import requests
11
- import os
12
- import re
13
- import gc
14
- import hashlib
15
- from huggingface_hub import hf_hub_download, snapshot_download
16
- import time
17
-
18
- # This dummy function is required to pass the Hugging Face Spaces startup check for GPU apps.
19
- @spaces.GPU(duration=60)
20
- def dummy_gpu_for_startup():
21
- print("Dummy function for startup check executed. This is normal.")
22
- return "Startup check passed."
23
-
24
- # --- Constants ---
25
- MAX_LORAS = 5
26
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
27
- MAX_SEED = np.iinfo(np.int64).max
28
- MAX_IMAGE_SIZE = 1216
29
- SAMPLER_MAP = {
30
- "Euler a": EulerAncestralDiscreteScheduler,
31
- "Euler": EulerDiscreteScheduler,
32
- "DPM++ 2M Karras": DPMSolverMultistepScheduler,
33
- "DDIM": DDIMScheduler,
34
- "UniPC": UniPCMultistepScheduler,
35
- "Heun": HeunDiscreteScheduler,
36
- "LMS": LMSDiscreteScheduler,
37
- }
38
- SCHEDULE_TYPE_MAP = ["Default", "Karras", "Uniform", "SGM Uniform"]
39
- LORA_SOURCE_CHOICES = ["Civitai", "TensorArt", "Custom URL", "File"]
40
- DEFAULT_SCHEDULE_TYPE = "Default"
41
- DEFAULT_SAMPLER = "Euler a"
42
- DEFAULT_NEGATIVE_PROMPT = "monochrome, (low quality, worst quality:1.2), 3d, watermark, signature, ugly, poorly drawn,"
43
- DOWNLOAD_DIR = "/tmp/loras"
44
- os.makedirs(DOWNLOAD_DIR, exist_ok=True)
45
-
46
- # --- Model Lists ---
47
- MODEL_LIST = [
48
- "dhead/wai-nsfw-illustrious-sdxl-v140-sdxl",
49
- "Laxhar/noobai-XL-Vpred-1.0",
50
- "John6666/hassaku-xl-illustrious-v30-sdxl",
51
- "RedRayz/hikari_noob_v-pred_1.2.2",
52
- "bluepen5805/noob_v_pencil-XL",
53
- "Laxhar/noobai-XL-1.1"
54
- ]
55
-
56
- # --- Model Display Name Mapping ---
57
- MODEL_DISPLAY_NAME_MAP = {
58
- "dhead/wai-nsfw-illustrious-sdxl-v140-sdxl": "WAI0731/wai-nsfw-illustrious-sdxl-v140-sdxl",
59
- "Laxhar/noobai-XL-Vpred-1.0": "Laxhar/noobai-XL-Vpred-1.0",
60
- "John6666/hassaku-xl-illustrious-v30-sdxl": "Ikena/hassaku-xl-illustrious-v30-sdxl",
61
- "RedRayz/hikari_noob_v-pred_1.2.2": "RedRayz/hikari_noob_v-pred_1.2.2",
62
- "bluepen5805/noob_v_pencil-XL": "bluepen5805/noob_v_pencil-XL",
63
- "Laxhar/noobai-XL-1.1": "Laxhar/noobai-XL-1.1"
64
- }
65
- DISPLAY_NAME_TO_BACKEND_MAP = {v: k for k, v in MODEL_DISPLAY_NAME_MAP.items()}
66
-
67
- # --- List of V-Prediction Models ---
68
- V_PREDICTION_MODELS = [
69
- "Laxhar/noobai-XL-Vpred-1.0",
70
- "RedRayz/hikari_noob_v-pred_1.2.2",
71
- "bluepen5805/noob_v_pencil-XL"
72
- ]
73
-
74
- # --- Dictionary for single-file models now stores the filename ---
75
- SINGLE_FILE_MODELS = {
76
- "bluepen5805/noob_v_pencil-XL": "noob_v_pencil-XL-v3.0.0.safetensors"
77
- }
78
-
79
- # --- Model Hash to Name Mapping ---
80
- HASH_TO_MODEL_MAP = {
81
- "bdb59bac77": "dhead/wai-nsfw-illustrious-sdxl-v140-sdxl",
82
- "ea349eeae8": "Laxhar/noobai-XL-Vpred-1.0",
83
- "b4fb5f829a": "John6666/hassaku-xl-illustrious-v30-sdxl",
84
- "6681e8e4b1": "Laxhar/noobai-XL-1.1",
85
- "90b7911a78": "bluepen5805/noob_v_pencil-XL",
86
- "874170688a": "RedRayz/hikari_noob_v-pred_1.2.2"
87
- }
88
- MODEL_TO_HASH_MAP = {v: k for k, v in HASH_TO_MODEL_MAP.items()}
89
-
90
-
91
- def download_all_base_models_on_startup():
92
- """Downloads all base models listed in MODEL_LIST when the app starts."""
93
- print("--- Starting pre-download of all base models ---")
94
- for model_name in MODEL_LIST:
95
- try:
96
- print(f"Downloading: {model_name}...")
97
- start_time = time.time()
98
- if model_name in SINGLE_FILE_MODELS:
99
- filename = SINGLE_FILE_MODELS[model_name]
100
- hf_hub_download(repo_id=model_name, filename=filename)
101
- else:
102
- snapshot_download(repo_id=model_name, ignore_patterns=["*.onnx", "*.flax"])
103
- end_time = time.time()
104
- print(f"✅ Successfully downloaded {model_name} in {end_time - start_time:.2f} seconds.")
105
- except Exception as e:
106
- print(f"❌ Failed to download {model_name}: {e}")
107
- finally:
108
- gc.collect()
109
- if torch.cuda.is_available():
110
- torch.cuda.empty_cache()
111
- print("--- Finished pre-downloading all base models ---")
112
-
113
- def get_civitai_file_info(version_id):
114
- """Gets the file metadata for a model version via the Civitai API."""
115
- api_url = f"https://civitai.com/api/v1/model-versions/{version_id}"
116
- try:
117
- response = requests.get(api_url, timeout=10)
118
- response.raise_for_status()
119
- data = response.json()
120
- for file_data in data.get('files', []):
121
- if file_data.get('type') == 'Model' and file_data['name'].endswith('.safetensors'):
122
- return file_data
123
- if data.get('files'):
124
- return data['files'][0]
125
- return None
126
- except Exception as e:
127
- print(f"Could not get file info from Civitai API: {e}")
128
- return None
129
-
130
- def get_tensorart_file_info(model_id):
131
- """Gets the file metadata for a model via the TensorArt API."""
132
- api_url = f"https://tensor.art/api/v1/models/{model_id}"
133
- try:
134
- response = requests.get(api_url, timeout=10)
135
- response.raise_for_status()
136
- data = response.json()
137
- model_versions = data.get('modelVersions', [])
138
- if not model_versions: return None
139
- for file_data in model_versions[0].get('files', []):
140
- if file_data['name'].endswith('.safetensors'):
141
- return file_data
142
- return model_versions[0]['files'][0] if model_versions[0].get('files') else None
143
- except Exception as e:
144
- print(f"Could not get file info from TensorArt API: {e}")
145
- return None
146
-
147
- def download_file(url, save_path, api_key=None, progress=None, desc=""):
148
- """Downloads a file, skipping if it already exists."""
149
- if os.path.exists(save_path):
150
- return f"File already exists: {os.path.basename(save_path)}"
151
-
152
- headers = {}
153
- if api_key and api_key.strip():
154
- headers['Authorization'] = f'Bearer {api_key}'
155
-
156
- try:
157
- if progress: progress(0, desc=desc)
158
- response = requests.get(url, stream=True, headers=headers, timeout=15)
159
- response.raise_for_status()
160
- total_size = int(response.headers.get('content-length', 0))
161
-
162
- with open(save_path, "wb") as f:
163
- downloaded = 0
164
- for chunk in response.iter_content(chunk_size=8192):
165
- f.write(chunk)
166
- if progress and total_size > 0:
167
- downloaded += len(chunk)
168
- progress(downloaded / total_size, desc=desc)
169
- return f"Successfully downloaded: {os.path.basename(save_path)}"
170
- except Exception as e:
171
- if os.path.exists(save_path): os.remove(save_path)
172
- return f"Download failed for {os.path.basename(save_path)}: {e}"
173
-
174
- def get_lora_path(source, id_or_url, civitai_key, tensorart_key, progress):
175
- """Determines the local path for a LoRA, downloading it if necessary."""
176
- if not id_or_url or not id_or_url.strip(): return None, "No ID/URL provided."
177
-
178
- if source == "Civitai":
179
- version_id = id_or_url.strip()
180
- local_path = os.path.join(DOWNLOAD_DIR, f"civitai_{version_id}.safetensors")
181
- if os.path.exists(local_path): return local_path, "File already exists."
182
- file_info = get_civitai_file_info(version_id)
183
- api_key_to_use = civitai_key
184
- source_name = f"Civitai ID {version_id}"
185
- elif source == "TensorArt":
186
- model_id = id_or_url.strip()
187
- local_path = os.path.join(DOWNLOAD_DIR, f"tensorart_{model_id}.safetensors")
188
- if os.path.exists(local_path): return local_path, "File already exists."
189
- file_info = get_tensorart_file_info(model_id)
190
- api_key_to_use = tensorart_key
191
- source_name = f"TensorArt ID {model_id}"
192
- elif source == "Custom URL":
193
- url = id_or_url.strip()
194
- url_hash = hashlib.md5(url.encode()).hexdigest()
195
- local_path = os.path.join(DOWNLOAD_DIR, f"custom_{url_hash}.safetensors")
196
- if os.path.exists(local_path): return local_path, "File already exists."
197
- file_info = {'downloadUrl': url}
198
- api_key_to_use = None
199
- source_name = f"URL {url[:30]}..."
200
- else:
201
- return None, "Invalid source."
202
-
203
- if not file_info: return None, f"Could not get file info for {source_name}."
204
- download_url = file_info.get('downloadUrl')
205
- if not download_url: return None, f"Could not get download link for {source_name}."
206
-
207
- status = download_file(download_url, local_path, api_key=api_key_to_use, progress=progress, desc=f"Downloading {source_name}")
208
- if "Successfully" in status:
209
- return local_path, status
210
- return None, status
211
-
212
-
213
- def pre_download_loras(civitai_api_key, tensorart_api_key, *lora_data, progress=gr.Progress(track_tqdm=True)):
214
- sources, ids, scales, files = lora_data[0::4], lora_data[1::4], lora_data[2::4], lora_data[3::4]
215
- status_log = []
216
-
217
- active_loras_to_download = [
218
- (src, lora_id) for src, lora_id, scale, f in zip(sources, ids, scales, files)
219
- if src in ["Civitai", "TensorArt", "Custom URL"] and lora_id and lora_id.strip() and f is None
220
- ]
221
-
222
- if not active_loras_to_download:
223
- return "No remote LoRAs specified for pre-downloading."
224
-
225
- for i, (source, lora_id) in enumerate(active_loras_to_download):
226
- progress(i / len(active_loras_to_download), desc=f"Processing {source} ID: {lora_id}")
227
- _, status = get_lora_path(source, lora_id, civitai_api_key, tensorart_api_key, progress)
228
- status_log.append(f"* {source} ID {lora_id}: {status}")
229
-
230
- return "\n".join(status_log)
231
-
232
-
233
- def process_long_prompt(compel_proc, prompt, negative_prompt=""):
234
- """Uses Compel to process prompts that may be too long for the standard tokenizer."""
235
- try:
236
- conditioning, pooled = compel_proc([prompt, negative_prompt])
237
- return conditioning, pooled
238
- except Exception:
239
- return None, None
240
-
241
-
242
- def _infer_logic(base_model_name, prompt, negative_prompt, seed, batch_size, width, height, guidance_scale, num_inference_steps,
243
- sampler, schedule_type,
244
- civitai_api_key, tensorart_api_key,
245
- *lora_data,
246
- progress=gr.Progress(track_tqdm=True)):
247
-
248
- pipe = None
249
- try:
250
- progress(0, desc=f"Loading model: {base_model_name}")
251
-
252
- if base_model_name in SINGLE_FILE_MODELS:
253
- filename = SINGLE_FILE_MODELS[base_model_name]
254
- local_path = hf_hub_download(repo_id=base_model_name, filename=filename)
255
- pipe = StableDiffusionXLPipeline.from_single_file(local_path, torch_dtype=torch.float16, use_safetensors=True)
256
- else:
257
- pipe = StableDiffusionXLPipeline.from_pretrained(base_model_name, torch_dtype=torch.float16, use_safetensors=True)
258
- pipe.to(device)
259
-
260
- batch_size = int(batch_size)
261
- seed = int(seed)
262
- pipe.unload_lora_weights()
263
-
264
- scheduler_class = SAMPLER_MAP.get(sampler, EulerAncestralDiscreteScheduler)
265
- scheduler_config = pipe.scheduler.config
266
-
267
- if base_model_name in V_PREDICTION_MODELS: scheduler_config['prediction_type'] = 'v_prediction'
268
- else: scheduler_config['prediction_type'] = 'epsilon'
269
-
270
- scheduler_kwargs = {}
271
- if schedule_type == "Karras" or (schedule_type == "Default" and sampler == "DPM++ 2M Karras"):
272
- scheduler_kwargs['use_karras_sigmas'] = True
273
- elif schedule_type == "Uniform": scheduler_kwargs['use_karras_sigmas'] = False
274
- elif schedule_type == "SGM Uniform": scheduler_kwargs['algorithm_type'] = 'sgm_uniform'
275
- pipe.scheduler = scheduler_class.from_config(scheduler_config, **scheduler_kwargs)
276
-
277
- compel = Compel(tokenizer=[pipe.tokenizer, pipe.tokenizer_2], text_encoder=[pipe.text_encoder, pipe.text_encoder_2],
278
- returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
279
- requires_pooled=[False, True], truncate_long_prompts=False)
280
-
281
- sources, ids, scales, files = lora_data[0::4], lora_data[1::4], lora_data[2::4], lora_data[3::4]
282
- active_loras, active_lora_names_for_meta = [], []
283
-
284
- for i, (source, lora_id, scale, custom_file) in enumerate(zip(sources, ids, scales, files)):
285
- if scale > 0:
286
- local_lora_path = None
287
- lora_name_for_meta = "Unknown LoRA"
288
-
289
- if custom_file is not None:
290
- local_lora_path = custom_file.name
291
- lora_name_for_meta = f"Custom LoRA ({os.path.basename(local_lora_path)}, Weight: {scale})"
292
- elif lora_id and lora_id.strip():
293
- progress(0.05 + (i * 0.05), desc=f"Handling LoRA {i+1} ({source})")
294
- local_lora_path, _ = get_lora_path(source, lora_id, civitai_api_key, tensorart_api_key, progress)
295
- lora_name_for_meta = f"{source} LoRA (ID: {lora_id}, Weight: {scale})"
296
-
297
- if local_lora_path and os.path.exists(local_lora_path):
298
- adapter_name = f"lora_{i+1}"
299
- pipe.load_lora_weights(local_lora_path, adapter_name=adapter_name)
300
- active_loras.append((adapter_name, scale))
301
- active_lora_names_for_meta.append(lora_name_for_meta)
302
- else:
303
- print(f"Skipping LoRA {i+1} as file could not be found or downloaded.")
304
-
305
- if active_loras:
306
- adapter_names, adapter_weights = zip(*active_loras)
307
- pipe.set_adapters(list(adapter_names), list(adapter_weights))
308
-
309
- conditioning, pooled = process_long_prompt(compel, prompt, negative_prompt)
310
-
311
- pipe_args = {"guidance_scale": guidance_scale, "num_inference_steps": num_inference_steps, "width": width, "height": height}
312
- output_images = []
313
- loras_string = f"LoRAs: [{', '.join(active_lora_names_for_meta)}]" if active_lora_names_for_meta else ""
314
-
315
- for i in range(batch_size):
316
- progress(i / batch_size, desc=f"Generating image {i+1}/{batch_size}")
317
- current_seed = seed if i == 0 and seed != -1 else random.randint(0, MAX_SEED)
318
- generator = torch.Generator(device=device).manual_seed(current_seed)
319
- pipe_args["generator"] = generator
320
-
321
- if conditioning is not None:
322
- image = pipe(prompt_embeds=conditioning[0:1], pooled_prompt_embeds=pooled[0:1], negative_prompt_embeds=conditioning[1:2], negative_pooled_prompt_embeds=pooled[1:2], **pipe_args).images[0]
323
- else:
324
- image = pipe(prompt=prompt, negative_prompt=negative_prompt, **pipe_args).images[0]
325
-
326
- model_hash = MODEL_TO_HASH_MAP.get(base_model_name, "N/A")
327
- params_string = f"{prompt}\nNegative prompt: {negative_prompt}\n"
328
- params_string += f"Steps: {num_inference_steps}, Sampler: {sampler}, Schedule type: {schedule_type}, CFG scale: {guidance_scale}, Seed: {current_seed}, Size: {width}x{height}, Base Model: {base_model_name}, Model hash: {model_hash}, {loras_string}".strip()
329
- image.info = {'parameters': params_string}
330
- output_images.append(image)
331
-
332
- return output_images
333
-
334
- except Exception as e:
335
- print(f"An error occurred during generation: {e}")
336
- error_str = str(e).lower()
337
- if "dora_scale" in error_str and "not compatible in diffusers" in error_str:
338
- raise gr.Error("This LoRA appears to be a DoRA model. Diffusers currently has limited support for this format, which may cause errors.")
339
- raise gr.Error(f"Generation failed: {e}")
340
- finally:
341
- if pipe is not None:
342
- pipe.disable_lora()
343
- del pipe
344
- gc.collect()
345
- if torch.cuda.is_available(): torch.cuda.empty_cache()
346
-
347
- def infer(base_model_display_name, prompt, negative_prompt, seed, batch_size, width, height, guidance_scale, num_inference_steps,
348
- sampler, schedule_type, civitai_api_key, tensorart_api_key, zero_gpu_duration, *lora_data,
349
- progress=gr.Progress(track_tqdm=True)):
350
-
351
- base_model_name = DISPLAY_NAME_TO_BACKEND_MAP.get(base_model_display_name, base_model_display_name)
352
- duration = 60
353
- if zero_gpu_duration and int(zero_gpu_duration) > 0: duration = int(zero_gpu_duration)
354
- print(f"Using ZeroGPU duration: {duration} seconds")
355
-
356
- decorated_infer_logic = spaces.GPU(duration=duration)(_infer_logic)
357
-
358
- return decorated_infer_logic(
359
- base_model_name, prompt, negative_prompt, seed, batch_size, width, height, guidance_scale, num_inference_steps,
360
- sampler, schedule_type, civitai_api_key, tensorart_api_key, *lora_data, progress=progress
361
- )
362
-
363
- def _parse_parameters(params_text):
364
- data = {'lora_ids': [''] * MAX_LORAS, 'lora_scales': [0.0] * MAX_LORAS}
365
- lines = params_text.strip().split('\n')
366
- data['prompt'] = lines[0]
367
- data['negative_prompt'] = lines[1].replace("Negative prompt:", "").strip() if len(lines) > 1 and lines[1].startswith("Negative prompt:") else ""
368
- params_line = lines[2] if len(lines) > 2 else ""
369
-
370
- def find_param(key, default, cast_type=str):
371
- match = re.search(fr"\b{key}: ([^,]+?)(,|$)", params_line)
372
- return cast_type(match.group(1).strip()) if match else default
373
-
374
- data['steps'] = find_param("Steps", 28, int)
375
- data['sampler'] = find_param("Sampler", DEFAULT_SAMPLER, str)
376
- data['schedule_type'] = find_param("Schedule type", DEFAULT_SCHEDULE_TYPE, str)
377
- data['cfg_scale'] = find_param("CFG scale", 7.0, float)
378
- data['seed'] = find_param("Seed", -1, int)
379
- data['base_model'] = find_param("Base Model", MODEL_LIST[0], str)
380
- data['model_hash'] = find_param("Model hash", None, str)
381
-
382
- size_match = re.search(r"Size: (\d+)x(\d+)", params_line); data['width'], data['height'] = (int(size_match.group(1)), int(size_match.group(2))) if size_match else (1024, 1024)
383
- return data
384
-
385
- def get_png_info(image):
386
- if image is None: return "", "", "Please upload an image first."
387
- params = image.info.get('parameters', None)
388
- if not params: return "", "", "No metadata found in the image."
389
- try:
390
- parsed_data = _parse_parameters(params)
391
- lines = params.strip().split('\n')
392
- other_params_text = lines[2] if len(lines) > 2 else ""
393
- other_params_display = "\n".join([p.strip() for p in other_params_text.split(',')])
394
- return parsed_data.get('prompt', ''), parsed_data.get('negative_prompt', ''), other_params_display
395
- except Exception as e:
396
- return "", "", f"Error parsing metadata: {e}\n\nRaw metadata:\n{params}"
397
-
398
- def send_info_to_txt2img(image):
399
- if image is None or not (params := image.info.get('parameters', '')):
400
- num_lora_params = MAX_LORAS * 4
401
- num_other_params = 12
402
- num_api_keys = 2
403
- return [gr.update()] * (num_other_params + num_api_keys + num_lora_params + 1)
404
-
405
- data = _parse_parameters(params)
406
-
407
- model_from_hash = HASH_TO_MODEL_MAP.get(data.get('model_hash'))
408
- backend_base_model = model_from_hash if model_from_hash else data.get('base_model', MODEL_LIST[0])
409
-
410
- final_display_model = MODEL_DISPLAY_NAME_MAP.get(backend_base_model, backend_base_model)
411
- final_sampler = data.get('sampler', DEFAULT_SAMPLER)
412
-
413
- schedule_from_png = data.get('schedule_type', DEFAULT_SCHEDULE_TYPE)
414
- final_schedule_type = schedule_from_png if schedule_from_png in SCHEDULE_TYPE_MAP else DEFAULT_SCHEDULE_TYPE
415
-
416
- updates = [final_display_model, data['prompt'], data['negative_prompt'], data['seed'], gr.update(), gr.update(), data['width'], data['height'],
417
- data['cfg_scale'], data['steps'], final_sampler, final_schedule_type, gr.update(), gr.update()]
418
-
419
- for i in range(MAX_LORAS):
420
- updates.extend([gr.update(), gr.update(), gr.update(), gr.update()])
421
- updates.append(gr.Tabs(selected=0))
422
- return updates
423
-
424
- # --- Execute model download on startup ---
425
- download_all_base_models_on_startup()
426
-
427
-
428
- with gr.Blocks(css="#col-container {margin: 0 auto; max-width: 1024px;}") as demo:
429
- gr.Markdown("# Animated SDXL T2I with LoRAs")
430
- with gr.Tabs(elem_id="tabs_container") as tabs:
431
- with gr.TabItem("txt2img", id=0):
432
- gr.Markdown("<div style='background-color: #282828; color: #a0aec0; padding: 10px; border-radius: 5px; margin-bottom: 15px;'>💡 <b>Tip:</b> Pre-downloading LoRAs before 'Run' can maximize ZeroGPU time.</div>")
433
- with gr.Column(elem_id="col-container"):
434
- with gr.Row():
435
- with gr.Column(scale=3):
436
- default_backend_model = "Laxhar/noobai-XL-Vpred-1.0"
437
- default_display_name = MODEL_DISPLAY_NAME_MAP.get(default_backend_model, default_backend_model)
438
- base_model_name_input = gr.Dropdown(label="Base Model", choices=list(MODEL_DISPLAY_NAME_MAP.values()), value=default_display_name)
439
- with gr.Column(scale=1):
440
- predownload_lora_button = gr.Button("Pre-download LoRAs")
441
- run_button = gr.Button("Run", variant="primary")
442
-
443
- predownload_status = gr.Markdown("")
444
- prompt = gr.Text(label="Prompt", lines=3, placeholder="Enter your prompt")
445
- negative_prompt = gr.Text(label="Negative prompt", lines=3, placeholder="Enter a negative prompt", value=DEFAULT_NEGATIVE_PROMPT)
446
-
447
- with gr.Row():
448
- with gr.Column(scale=2):
449
- with gr.Row():
450
- width = gr.Slider(label="Width", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024)
451
- height = gr.Slider(label="Height", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024)
452
- with gr.Row():
453
- sampler = gr.Dropdown(label="Sampling method", choices=list(SAMPLER_MAP.keys()), value=DEFAULT_SAMPLER)
454
- schedule_type = gr.Dropdown(label="Schedule type", choices=SCHEDULE_TYPE_MAP, value=DEFAULT_SCHEDULE_TYPE)
455
- with gr.Row():
456
- guidance_scale = gr.Slider(label="CFG Scale", minimum=0.0, maximum=20.0, step=0.1, value=7)
457
- num_inference_steps = gr.Slider(label="Sampling steps", minimum=1, maximum=50, step=1, value=28)
458
- with gr.Column(scale=1):
459
- result = gr.Gallery(label="Result", show_label=False, elem_id="result_gallery", columns=2, object_fit="contain", height="auto")
460
-
461
- with gr.Row():
462
- seed = gr.Number(label="Seed (-1 for random)", value=-1, precision=0)
463
- batch_size = gr.Slider(label="Batch size", minimum=1, maximum=8, step=1, value=1)
464
- zero_gpu_duration = gr.Number(label="ZeroGPU Duration (s)", value=None, placeholder="Default: 60s", info="Optional: Leave empty for default (60s), max to 120")
465
-
466
- with gr.Accordion("LoRA Settings", open=False):
467
- gr.Markdown("⚠️ **Responsible Use Notice:** Please avoid excessive, rapid, or automated (scripted) use of the pre-download LoRA feature. Overt misuse may lead to service disruption. Thank you for your cooperation.")
468
-
469
- gr.Markdown("For LoRAs that require login to download, you may need to enter the corresponding API Key.")
470
- with gr.Row():
471
- with gr.Column(scale=1):
472
- gr.Markdown("**Civitai API Key**")
473
- civitai_api_key = gr.Textbox(show_label=False, placeholder="Enter your Civitai API Key here", type="password", container=False)
474
- with gr.Column(scale=1):
475
- gr.Markdown("**TensorArt API Key**")
476
- tensorart_api_key = gr.Textbox(show_label=False, placeholder="Enter your TensorArt API Key here", type="password", container=False)
477
-
478
- gr.Markdown("---")
479
- gr.Markdown("For each LoRA, choose a source, provide an ID/URL, or upload a file.")
480
-
481
- gr.Markdown("""
482
- <div style='background-color: #282828; color: #a0aec0; padding: 10px; border-radius: 5px; margin-top: 10px; margin-bottom: 15px;'>
483
- <b>Input Examples:</b>
484
- <ul>
485
- <li><b>Civitai:</b> Enter the <b>Model Version ID</b>, not the Model ID. Example: <code>133755</code> (Found in the URL, e.g., <code>civitai.com/models/122136?modelVersionId=<b>133755</b></code>)</li>
486
- <li><b>TensorArt:</b> Enter the <b>Model ID</b>. Example: <code>706684852832599558</code> (Found in the URL, e.g., <code>tensor.art/models/<b>706684852832599558</b></code>)</li>
487
- <li><b>Custom URL:</b> Provide a direct download link to a <code>.safetensors</code> file. Example: <code>https://huggingface.co/path/to/your/lora.safetensors</code></li>
488
- <li><b>File:</b> Use the "Upload" button. The source will be set automatically.</li>
489
- </ul>
490
- </div>
491
- """)
492
-
493
- gr.Markdown("""
494
- <div style='background-color: #282828; color: #a0aec0; padding: 10px; border-radius: 5px; margin-bottom: 15px;'>
495
- <b>TODO:</b>
496
- <ul style='margin-bottom: 0;'>
497
- <li>When uploading a local LoRA, the page may not respond, but it is transferring. Please be patient. This issue is pending a fix.</li>
498
- </ul>
499
- </div>
500
- """)
501
-
502
- lora_rows = []
503
- lora_source_inputs, lora_id_inputs, lora_scale_inputs, lora_upload_buttons = [], [], [], []
504
-
505
- for i in range(MAX_LORAS):
506
- with gr.Row(visible=(i == 0)) as row:
507
- with gr.Column(scale=1, min_width=120):
508
- lora_source = gr.Dropdown(label=f"LoRA {i+1} Source", choices=LORA_SOURCE_CHOICES, value="Civitai")
509
- with gr.Column(scale=2, min_width=160):
510
- lora_id = gr.Textbox(label="ID / URL / Uploaded File", placeholder="e.g.: 133755")
511
- with gr.Column(scale=2, min_width=220):
512
- lora_scale = gr.Slider(label="Weight", minimum=0.0, maximum=2.0, step=0.05, value=0.0)
513
- with gr.Column(scale=1, min_width=80):
514
- lora_upload = gr.UploadButton("Upload", file_types=[".safetensors"])
515
-
516
- lora_rows.append(row)
517
- lora_source_inputs.append(lora_source)
518
- lora_id_inputs.append(lora_id)
519
- lora_scale_inputs.append(lora_scale)
520
- lora_upload_buttons.append(lora_upload)
521
-
522
- lora_upload.upload(
523
- fn=lambda f: (os.path.basename(f.name), "File") if f else (gr.update(), gr.update()),
524
- inputs=[lora_upload],
525
- outputs=[lora_id, lora_source]
526
- )
527
-
528
- with gr.Row():
529
- add_lora_button = gr.Button("✚ Add LoRA", variant="secondary")
530
- delete_lora_button = gr.Button("➖ Delete LoRA", variant="secondary", visible=False)
531
-
532
- lora_count_state = gr.State(value=1)
533
- all_lora_components_flat = [item for sublist in zip(lora_source_inputs, lora_id_inputs, lora_scale_inputs, lora_upload_buttons) for item in sublist]
534
-
535
-
536
- with gr.TabItem("PNG Info", id=1):
537
- with gr.Column(elem_id="col-container"):
538
- gr.Markdown("Upload a generated image to view its generation data.")
539
- info_image_input = gr.Image(type="pil", label="Upload Image")
540
- with gr.Row():
541
- info_get_button = gr.Button("Get Info", variant="secondary")
542
- send_to_txt2img_button = gr.Button("Send to txt2img", variant="primary")
543
- gr.Markdown("### Positive Prompt"); info_prompt_output = gr.Textbox(lines=3, interactive=False, show_label=False)
544
- gr.Markdown("### Negative Prompt"); info_neg_prompt_output = gr.Textbox(lines=3, interactive=False, show_label=False)
545
- gr.Markdown("### Other Parameters"); info_params_output = gr.Textbox(lines=5, interactive=False, show_label=False)
546
-
547
- gr.Markdown("<div style='text-align: center; margin-top: 20px;'>Made by <a href='https://civitai.com/user/RioShiina'>RioShiina</a> with ❤️</div>")
548
-
549
- # --- Event Handlers ---
550
- def add_lora_row(current_count):
551
- current_count = int(current_count)
552
- if current_count < MAX_LORAS:
553
- return {
554
- lora_count_state: current_count + 1,
555
- lora_rows[current_count]: gr.update(visible=True),
556
- delete_lora_button: gr.update(visible=True),
557
- add_lora_button: gr.update(visible=False) if (current_count + 1 == MAX_LORAS) else gr.update(visible=True)
558
- }
559
- return {}
560
-
561
- def delete_lora_row(current_count):
562
- current_count = int(current_count)
563
- if current_count > 1:
564
- row_index_to_hide = current_count - 1
565
- return {
566
- lora_count_state: current_count - 1,
567
- lora_rows[row_index_to_hide]: gr.update(visible=False),
568
- lora_id_inputs[row_index_to_hide]: gr.update(value=""),
569
- lora_scale_inputs[row_index_to_hide]: gr.update(value=0.0),
570
- add_lora_button: gr.update(visible=True),
571
- delete_lora_button: gr.update(visible=False) if (current_count - 1 == 1) else gr.update(visible=True)
572
- }
573
- return {}
574
-
575
- def start_lora_predownload():
576
- return "⏳ Downloading... please wait. This may take a moment."
577
-
578
- predownload_lora_button.click(
579
- fn=start_lora_predownload,
580
- inputs=None,
581
- outputs=[predownload_status],
582
- queue=False
583
- ).then(
584
- fn=pre_download_loras,
585
- inputs=[civitai_api_key, tensorart_api_key, *all_lora_components_flat],
586
- outputs=[predownload_status]
587
- )
588
-
589
- add_lora_button.click(
590
- fn=add_lora_row,
591
- inputs=[lora_count_state],
592
- outputs=[lora_count_state, add_lora_button, delete_lora_button, *lora_rows]
593
- )
594
-
595
- delete_lora_button.click(
596
- fn=delete_lora_row,
597
- inputs=[lora_count_state],
598
- outputs=[
599
- lora_count_state,
600
- add_lora_button,
601
- delete_lora_button,
602
- *lora_rows,
603
- *lora_id_inputs,
604
- *lora_scale_inputs
605
- ]
606
- )
607
-
608
- run_button_inputs = [base_model_name_input, prompt, negative_prompt, seed, batch_size, width, height, guidance_scale, num_inference_steps, sampler, schedule_type, civitai_api_key, tensorart_api_key, zero_gpu_duration, *all_lora_components_flat]
609
- run_button.click(fn=infer, inputs=run_button_inputs, outputs=[result])
610
-
611
- info_get_button.click(fn=get_png_info, inputs=[info_image_input], outputs=[info_prompt_output, info_neg_prompt_output, info_params_output])
612
-
613
- txt2img_outputs = [
614
- base_model_name_input, prompt, negative_prompt, seed, batch_size,
615
- zero_gpu_duration, width, height, guidance_scale, num_inference_steps,
616
- sampler, schedule_type, civitai_api_key, tensorart_api_key,
617
- *all_lora_components_flat, tabs
618
- ]
619
- send_to_txt2img_button.click(fn=send_info_to_txt2img, inputs=[info_image_input], outputs=txt2img_outputs)
620
-
621
- demo.queue().launch()
 
1
+ import os
2
+ import random
3
+ import sys
4
+ from typing import Sequence, Mapping, Any, Union
5
+ import torch
6
+ import gradio as gr
7
+ from PIL import Image
8
+ from huggingface_hub import hf_hub_download
9
+ import spaces
10
+ from comfy import model_management # We need to import this early
11
+ import gc
12
+ import requests
13
+ import re
14
+ import hashlib
15
+ import shutil
16
+
17
+ # --- Startup Dummy Function ---
18
+ @spaces.GPU(duration=60)
19
+ def dummy_gpu_for_startup():
20
+ print("Dummy function for startup check executed. This is normal.")
21
+ return "Startup check passed."
22
+
23
+ # --- ComfyUI Backend Setup ---
24
+ def find_path(name: str, path: str = None) -> str:
25
+ if path is None: path = os.getcwd()
26
+ if name in os.listdir(path): return os.path.join(path, name)
27
+ parent_directory = os.path.dirname(path)
28
+ if parent_directory == path: return None
29
+ return find_path(name, parent_directory)
30
+
31
+ def add_comfyui_directory_to_sys_path() -> None:
32
+ comfyui_path = find_path("ComfyUI")
33
+ if comfyui_path and os.path.isdir(comfyui_path):
34
+ sys.path.append(comfyui_path)
35
+ print(f"'{comfyui_path}' added to sys.path")
36
+
37
+ def add_extra_model_paths() -> None:
38
+ try: from main import load_extra_path_config
39
+ except ImportError: from utils.extra_config import load_extra_path_config
40
+ extra_model_paths = find_path("extra_model_paths.yaml")
41
+ if extra_model_paths: load_extra_path_config(extra_model_paths)
42
+ else: print("Could not find extra_model_paths.yaml")
43
+
44
+ add_comfyui_directory_to_sys_path()
45
+ add_extra_model_paths()
46
+
47
+ # Monkey-patch for Sage Attention
48
+ print("Attempting to monkey-patch ComfyUI for Sage Attention...")
49
+ try:
50
+ model_management.sage_attention_enabled = lambda: True
51
+ model_management.pytorch_attention_enabled = lambda: False
52
+ print("Successfully monkey-patched model_management for Sage Attention.")
53
+ except Exception as e:
54
+ print(f"An error occurred during monkey-patching: {e}")
55
+
56
+ # --- Constants & Configuration ---
57
+ CHECKPOINT_DIR = "models/checkpoints"
58
+ LORA_DIR = "models/loras"
59
+ os.makedirs(CHECKPOINT_DIR, exist_ok=True)
60
+ os.makedirs(LORA_DIR, exist_ok=True)
61
+
62
+ # --- Model Definitions with Hashes ---
63
+ # Format: {Display Name: (Repo ID, Filename, Type, Hash)}
64
+ MODEL_MAP_ILLUSTRIOUS = {
65
+ "Laxhar/noobai-XL-Vpred-1.0": ("Laxhar/noobai-XL-Vpred-1.0", "NoobAI-XL-Vpred-v1.0.safetensors", "SDXL", "ea349eeae8"),
66
+ "Laxhar/noobai-XL-1.1": ("Laxhar/noobai-XL-1.1", "NoobAI-XL-v1.1.safetensors", "SDXL", "6681e8e4b1"),
67
+ "WAI0731/wai-nsfw-illustrious-sdxl-v140": ("Ine007/waiNSFWIllustrious_v140", "waiNSFWIllustrious_v140.safetensors", "SDXL", "bdb59bac77"),
68
+ "Ikena/hassaku-xl-illustrious-v30": ("misri/hassakuXLIllustrious_v30", "hassakuXLIllustrious_v30.safetensors", "SDXL", "b4fb5f829a"),
69
+ "bluepen5805/noob_v_pencil-XL": ("bluepen5805/noob_v_pencil-XL", "noob_v_pencil-XL-v3.0.0.safetensors", "SDXL", "90b7911a78"),
70
+ "RedRayz/hikari_noob_v-pred_1.2.2": ("RedRayz/hikari_noob_v-pred_1.2.2", "Hikari_Noob_v-pred_1.2.2.safetensors", "SDXL", "874170688a"),
71
+ }
72
+ MODEL_MAP_ANIMAGINE = {
73
+ "cagliostrolab/animagine-xl-4.0": ("cagliostrolab/animagine-xl-4.0", "animagine-xl-4.0.safetensors", "SDXL", "6327eca98b"),
74
+ "cagliostrolab/animagine-xl-3.1": ("cagliostrolab/animagine-xl-3.1", "animagine-xl-3.1.safetensors", "SDXL", "e3c47aedb0"),
75
+ }
76
+ MODEL_MAP_PONY = {
77
+ "PurpleSmartAI/Pony_Diffusion_V6_XL": ("LyliaEngine/Pony_Diffusion_V6_XL", "ponyDiffusionV6XL_v6StartWithThisOne.safetensors", "SDXL", "67ab2fd8ec"),
78
+ }
79
+ MODEL_MAP_SD15 = {
80
+ "Yuno779/anything-v3": ("ckpt/anything-v3.0", "Anything-V3.0-pruned.safetensors", "SD1.5", "ddd565f806"),
81
+ }
82
+
83
+ # --- Combined Maps for Global Lookup ---
84
+ ALL_MODEL_MAP = {**MODEL_MAP_ILLUSTRIOUS, **MODEL_MAP_ANIMAGINE, **MODEL_MAP_PONY, **MODEL_MAP_SD15}
85
+ MODEL_TYPE_MAP = {k: v[2] for k, v in ALL_MODEL_MAP.items()}
86
+ DISPLAY_NAME_TO_HASH_MAP = {k: v[3] for k, v in ALL_MODEL_MAP.items()}
87
+ HASH_TO_DISPLAY_NAME_MAP = {v[3]: k for k, v in ALL_MODEL_MAP.items()}
88
+
89
+ # --- UI Defaults ---
90
+ DEFAULT_NEGATIVE_PROMPT = "monochrome, (low quality, worst quality:1.2), 3d, watermark, signature, ugly, poorly drawn,"
91
+ MAX_LORAS = 5
92
+ LORA_SOURCE_CHOICES = ["Civitai", "TensorArt", "Custom URL", "File"]
93
+
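+ # ComfyUI nodes return tuples (or dicts carrying a "result" sequence); this helper fetches index from either shape, returning None on a miss.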
94
+ def get_value_at_index(obj: Union[Sequence, Mapping], index: int) -> Any:
95
+ try: return obj[index]
96
+ except (KeyError, IndexError):
97
+ try: return obj["result"][index]
98
+ except (KeyError, IndexError): return None
99
+
100
+ def import_custom_nodes() -> None:
101
+ import asyncio, execution, server
102
+ from nodes import init_extra_nodes
103
+ loop = asyncio.new_event_loop()
104
+ asyncio.set_event_loop(loop)
105
+ server_instance = server.PromptServer(loop)
106
+ execution.PromptQueue(server_instance)
107
+ loop.run_until_complete(init_extra_nodes())
108
+
109
+ # --- Import ComfyUI Nodes & Get Choices ---
110
+ from nodes import CheckpointLoaderSimple, EmptyLatentImage, KSampler, VAEDecode, SaveImage, NODE_CLASS_MAPPINGS
111
+ import_custom_nodes()
112
+ CLIPTextEncodeSDXL = NODE_CLASS_MAPPINGS['CLIPTextEncodeSDXL']
113
+ CLIPTextEncode = NODE_CLASS_MAPPINGS['CLIPTextEncode']
114
+ LoraLoader = NODE_CLASS_MAPPINGS['LoraLoader']
115
+ CLIPSetLastLayer = NODE_CLASS_MAPPINGS['CLIPSetLastLayer']
116
+ try:
117
+ SAMPLER_CHOICES = KSampler.INPUT_TYPES()["required"]["sampler_name"][0]
118
+ SCHEDULER_CHOICES = KSampler.INPUT_TYPES()["required"]["scheduler"][0]
119
+ except Exception:
120
+ SAMPLER_CHOICES = ['euler', 'dpmpp_2m_sde_gpu']
121
+ SCHEDULER_CHOICES = ['normal', 'karras']
122
+
123
+ # --- Instantiate Node Objects ---
124
+ checkpointloadersimple = CheckpointLoaderSimple(); cliptextencodesdxl = CLIPTextEncodeSDXL()
125
+ cliptextencode_sd15 = CLIPTextEncode(); emptylatentimage = EmptyLatentImage()
126
+ ksampler = KSampler(); vaedecode = VAEDecode(); saveimage = SaveImage(); loraloader = LoraLoader()
127
+ clipsetlastlayer = CLIPSetLastLayer()
128
+
129
+ # --- LoRA & File Utils ---
130
+ def get_civitai_file_info(version_id):
131
+ api_url = f"https://civitai.com/api/v1/model-versions/{version_id}"
132
+ try:
133
+ response = requests.get(api_url, timeout=10); response.raise_for_status(); data = response.json()
134
+ for file_data in data.get('files', []):
135
+ if file_data.get('type') == 'Model' and file_data['name'].endswith('.safetensors'): return file_data
136
+ if data.get('files'): return data['files'][0]
137
+ except Exception: return None
138
+
139
+ def get_tensorart_file_info(model_id):
140
+ api_url = f"https://tensor.art/api/v1/models/{model_id}"
141
+ try:
142
+ response = requests.get(api_url, timeout=10); response.raise_for_status(); data = response.json()
143
+ model_versions = data.get('modelVersions', [])
144
+ if not model_versions: return None
145
+ for file_data in model_versions[0].get('files', []):
146
+ if file_data['name'].endswith('.safetensors'): return file_data
147
+ return model_versions[0]['files'][0] if model_versions[0].get('files') else None
148
+ except Exception: return None
149
+
150
+ def download_file(url, save_path, api_key=None, progress=None, desc=""):
151
+ if os.path.exists(save_path): return f"File already exists: {os.path.basename(save_path)}"
152
+ headers = {'Authorization': f'Bearer {api_key}'} if api_key and api_key.strip() else {}
153
+ try:
154
+ if progress: progress(0, desc=desc)
155
+ response = requests.get(url, stream=True, headers=headers, timeout=15); response.raise_for_status()
156
+ total_size = int(response.headers.get('content-length', 0))
157
+ with open(save_path, "wb") as f:
158
+ downloaded = 0
159
+ for chunk in response.iter_content(chunk_size=8192):
160
+ f.write(chunk)
161
+ if progress and total_size > 0: downloaded += len(chunk); progress(downloaded / total_size, desc=desc)
162
+ return f"Successfully downloaded: {os.path.basename(save_path)}"
163
+ except Exception as e:
164
+ if os.path.exists(save_path): os.remove(save_path)
165
+ return f"Download failed for {os.path.basename(save_path)}: {e}"
166
+
167
+ def get_lora_path(source, id_or_url, civitai_key, tensorart_key, progress):
168
+ if not id_or_url or not id_or_url.strip(): return None, "No ID/URL provided."
169
+ if source == "Civitai":
170
+ version_id = id_or_url.strip(); local_path = os.path.join(LORA_DIR, f"civitai_{version_id}.safetensors"); file_info = get_civitai_file_info(version_id) if not os.path.exists(local_path) else None; api_key_to_use = civitai_key; source_name = f"Civitai ID {version_id}"  # skip the API call when already cached
171
+ elif source == "TensorArt":
172
+ model_id = id_or_url.strip(); local_path = os.path.join(LORA_DIR, f"tensorart_{model_id}.safetensors"); file_info = get_tensorart_file_info(model_id) if not os.path.exists(local_path) else None; api_key_to_use = tensorart_key; source_name = f"TensorArt ID {model_id}"  # skip the API call when already cached
173
+ elif source == "Custom URL":
174
+ url = id_or_url.strip(); url_hash = hashlib.md5(url.encode()).hexdigest(); local_path = os.path.join(LORA_DIR, f"custom_{url_hash}.safetensors"); file_info, api_key_to_use = {'downloadUrl': url}, None; source_name = f"URL {url[:30]}..."
175
+ else: return None, "Invalid source."
176
+ if os.path.exists(local_path): return local_path, "File already exists."
177
+ if not file_info or not file_info.get('downloadUrl'): return None, f"Could not get download link for {source_name}."
178
+ status = download_file(file_info['downloadUrl'], local_path, api_key_to_use, progress=progress, desc=f"Downloading {source_name}")
179
+ return (local_path, status) if "Successfully" in status else (None, status)
180
+
181
+ def pre_download_loras(civitai_api_key, tensorart_api_key, *lora_data, progress=gr.Progress(track_tqdm=True)):
182
+ sources, ids, _, files = lora_data[0::4], lora_data[1::4], lora_data[2::4], lora_data[3::4]
183
+ active_loras = [(s, i) for s, i, f in zip(sources, ids, files) if s in ["Civitai", "TensorArt", "Custom URL"] and i and i.strip() and f is None]
184
+ if not active_loras: return "No remote LoRAs specified for pre-downloading."
185
+ log = [f"* {s} ID {i}: {get_lora_path(s, i, civitai_api_key, tensorart_api_key, progress)[1]}" for s, i in active_loras]
186
+ return "\n".join(log)
187
+
188
+ # --- Model Management & Core Logic ---
189
+ current_loaded_model_name = None; loaded_checkpoint_tuple = None
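+ # Simple one-model cache: reloading is skipped when the same checkpoint is requested again; otherwise everything is unloaded first to free VRAM.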
190
+ def load_model(model_display_name: str, progress=gr.Progress()):
191
+ global current_loaded_model_name, loaded_checkpoint_tuple
192
+ if model_display_name == current_loaded_model_name and loaded_checkpoint_tuple: return loaded_checkpoint_tuple
193
+ if loaded_checkpoint_tuple: model_management.unload_all_models(); loaded_checkpoint_tuple = None; gc.collect(); torch.cuda.empty_cache()
194
+
195
+ repo_id, filename, _, _ = ALL_MODEL_MAP[model_display_name]
196
+ local_file_path = os.path.join(CHECKPOINT_DIR, filename)
197
+
198
+ if not os.path.exists(local_file_path):
199
+ progress(0, desc=f"Downloading model: {model_display_name}")
200
+ hf_hub_download(repo_id=repo_id, filename=filename, local_dir=CHECKPOINT_DIR, local_dir_use_symlinks=False)
201
+
202
+ progress(0.5, desc=f"Loading '{filename}'")
203
+ MODEL_TUPLE = checkpointloadersimple.load_checkpoint(ckpt_name=filename)
204
+ model_management.load_models_gpu([get_value_at_index(MODEL_TUPLE, 0)])
205
+ current_loaded_model_name = model_display_name; loaded_checkpoint_tuple = MODEL_TUPLE
206
+ progress(1.0, desc="Model loaded"); return loaded_checkpoint_tuple
207
+
208
+ def _generate_image_logic(model_display_name: str, positive_prompt: str, negative_prompt: str,
209
+ seed: int, batch_size: int, width: int, height: int, guidance_scale: float, num_inference_steps: int,
210
+ sampler_name: str, scheduler: str, civitai_api_key: str, tensorart_api_key: str, *lora_data,
211
+ progress=gr.Progress(track_tqdm=True)):
212
+ output_images = []
213
+ is_sd15 = MODEL_TYPE_MAP.get(model_display_name) == "SD1.5"
214
+ clip_skip = 1
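+ # The SD1.5 path appends clip_skip as one extra trailing value after the flattened LoRA inputs; peel it off before the 4-stride slicing below.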
215
+ if is_sd15 and len(lora_data) > MAX_LORAS * 4:
216
+ clip_skip = int(lora_data[-1])
217
+ lora_data = lora_data[:-1]
218
+
219
+ with torch.inference_mode():
220
+ model_tuple = load_model(model_display_name, progress)
221
+ model, clip, vae = (get_value_at_index(model_tuple, i) for i in range(3))
222
+
223
+ if is_sd15:
224
+ clip = get_value_at_index(clipsetlastlayer.set_last_layer(clip=clip, stop_at_clip_layer=-clip_skip), 0)
225
+
226
+ active_loras_for_meta = []
227
+ sources, ids, scales, files = lora_data[0::4], lora_data[1::4], lora_data[2::4], lora_data[3::4]
228
+ for i, (source, lora_id, scale, custom_file) in enumerate(zip(sources, ids, scales, files)):
229
+ if scale > 0:
230
+ lora_filename = None
231
+ if custom_file:
232
+ lora_filename = os.path.basename(custom_file.name)
233
+ shutil.copy(custom_file.name, LORA_DIR)
234
+ elif lora_id and lora_id.strip():
235
+ local_path, _ = get_lora_path(source, lora_id, civitai_api_key, tensorart_api_key, progress)
236
+ if local_path: lora_filename = os.path.basename(local_path)
237
+
238
+ if lora_filename:
239
+ lora_tuple = loraloader.load_lora(model=model, clip=clip, lora_name=lora_filename, strength_model=scale, strength_clip=scale)
240
+ model, clip = get_value_at_index(lora_tuple, 0), get_value_at_index(lora_tuple, 1)
241
+ active_loras_for_meta.append(f"{source} {lora_id}:{scale}")
242
+
243
+ loras_string = f"LoRAs: [{', '.join(active_loras_for_meta)}]" if active_loras_for_meta else ""
244
+
245
+ if is_sd15:
246
+ pos_cond = cliptextencode_sd15.encode(text=positive_prompt, clip=clip)
247
+ neg_cond = cliptextencode_sd15.encode(text=negative_prompt, clip=clip)
248
+ else:
249
+ pos_cond = cliptextencodesdxl.encode(width=width, height=height, text_g=positive_prompt, text_l=positive_prompt, clip=clip, target_width=width, target_height=height, crop_w=0, crop_h=0)
250
+ neg_cond = cliptextencodesdxl.encode(width=width, height=height, text_g=negative_prompt, text_l=negative_prompt, clip=clip, target_width=width, target_height=height, crop_w=0, crop_h=0)
251
+
252
+ start_seed = seed if seed != -1 else random.randint(0, 2**64 - 1)
253
+
254
+ latent = emptylatentimage.generate(width=width, height=height, batch_size=batch_size)
255
+
256
+ sampled = ksampler.sample(
257
+ seed=start_seed,
258
+ steps=num_inference_steps,
259
+ cfg=guidance_scale,
260
+ sampler_name=sampler_name,
261
+ scheduler=scheduler,
262
+ denoise=1.0,
263
+ model=model,
264
+ positive=get_value_at_index(pos_cond, 0),
265
+ negative=get_value_at_index(neg_cond, 0),
266
+ latent_image=get_value_at_index(latent, 0)
267
+ )
268
+
269
+ decoded_images_tensor = get_value_at_index(vaedecode.decode(samples=get_value_at_index(sampled, 0), vae=vae), 0)
270
+
271
+ for i in range(decoded_images_tensor.shape[0]):
272
+ img_tensor = decoded_images_tensor[i]
273
+ pil_image = Image.fromarray((img_tensor.cpu().numpy() * 255.0).astype("uint8"))
274
+
275
+ current_seed = start_seed + i
276
+
277
+ model_hash = DISPLAY_NAME_TO_HASH_MAP.get(model_display_name, "N/A")
278
+ params_string = f"{positive_prompt}\nNegative prompt: {negative_prompt}\n"
279
+ params_string += f"Steps: {num_inference_steps}, Sampler: {sampler_name}, Scheduler: {scheduler}, CFG scale: {guidance_scale}, Seed: {current_seed}, Size: {width}x{height}, Base Model: {model_display_name}, Model hash: {model_hash}"
280
+ if is_sd15: params_string += f", Clip skip: {clip_skip}"
281
+ params_string += f", {loras_string}"
282
+ pil_image.info = {'parameters': params_string.strip()}
283
+
284
+ output_images.append(pil_image)
285
+
286
+ return output_images
287
+
288
+ def generate_image_wrapper(*args, **kwargs):
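+ # args[0:11] are the generation parameters, args[11] is the optional ZeroGPU duration, and args[12:] are the flattened LoRA inputs (plus a trailing clip_skip on the SD1.5 tab).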
289
+ logic_args_list = list(args[:11])
290
+ zero_gpu_duration = args[11]
291
+ logic_args_list.extend(args[12:])
292
+ duration = 60
293
+ try:
294
+ if zero_gpu_duration and int(zero_gpu_duration) > 0:
295
+ duration = int(zero_gpu_duration)
296
+ except (ValueError, TypeError):
297
+ pass
298
+ return spaces.GPU(duration=duration)(_generate_image_logic)(*logic_args_list, **kwargs)
299
+
300
+
301
+ # --- PNG Info & UI Logic ---
302
+ def _parse_parameters(params_text):
303
+ data = {}; lines = params_text.strip().split('\n'); data['prompt'] = lines[0]
304
+ data['negative_prompt'] = lines[1].replace("Negative prompt:", "").strip() if len(lines) > 1 and lines[1].startswith("Negative prompt:") else ""
305
+ params_line = '\n'.join(lines[2:])
306
+ def find_param(key, default, cast_type=str):
307
+ match = re.search(fr"\b{key}: ([^,]+?)(,|$|\n)", params_line)
308
+ return cast_type(match.group(1).strip()) if match else default
309
+ data['steps'] = find_param("Steps", 28, int); data['sampler'] = find_param("Sampler", SAMPLER_CHOICES[0], str)
310
+ data['scheduler'] = find_param("Scheduler", SCHEDULER_CHOICES[0], str); data['cfg_scale'] = find_param("CFG scale", 7.5, float)
311
+ data['seed'] = find_param("Seed", -1, int); data['clip_skip'] = find_param("Clip skip", 1, int)
312
+ data['base_model'] = find_param("Base Model", list(ALL_MODEL_MAP.keys())[0], str); data['model_hash'] = find_param("Model hash", None, str)
313
+ size_match = re.search(r"Size: (\d+)x(\d+)", params_line)
314
+ data['width'], data['height'] = (int(size_match.group(1)), int(size_match.group(2))) if size_match else (1024, 1024)
315
+ return data
316
+
317
+ def get_png_info(image):
318
+ if not image or not (params := image.info.get('parameters')): return "", "", "No metadata found in the image."
319
+ parsed_data = _parse_parameters(params)
320
+ other_params_text = "\n".join([p.strip() for p in '\n'.join(params.strip().split('\n')[2:]).split(',')])
321
+ return parsed_data.get('prompt', ''), parsed_data.get('negative_prompt', ''), other_params_text
322
+
323
+ def apply_data_to_ui(data, target_tab):
324
+ final_sampler = data.get('sampler') if data.get('sampler') in SAMPLER_CHOICES else SAMPLER_CHOICES[0]
325
+ default_scheduler = 'normal' if 'normal' in SCHEDULER_CHOICES else SCHEDULER_CHOICES[0]
326
+ final_scheduler = data.get('scheduler') if data.get('scheduler') in SCHEDULER_CHOICES else default_scheduler
327
+
328
+ updates = {}
329
+ base_model_name = data.get('base_model')
330
+
331
+ if target_tab == "Illustrious":
332
+ if base_model_name in MODEL_MAP_ILLUSTRIOUS:
333
+ updates.update({base_model_name_input_illustrious: base_model_name})
334
+ updates.update({prompt_illustrious: data['prompt'], negative_prompt_illustrious: data['negative_prompt'], seed_illustrious: data['seed'], width_illustrious: data['width'], height_illustrious: data['height'], guidance_scale_illustrious: data['cfg_scale'], num_inference_steps_illustrious: data['steps'], sampler_illustrious: final_sampler, schedule_type_illustrious: final_scheduler, model_tabs: gr.Tabs(selected=0)})
335
+ elif target_tab == "Animagine":
336
+ if base_model_name in MODEL_MAP_ANIMAGINE:
337
+ updates.update({base_model_name_input_animagine: base_model_name})
338
+ updates.update({prompt_animagine: data['prompt'], negative_prompt_animagine: data['negative_prompt'], seed_animagine: data['seed'], width_animagine: data['width'], height_animagine: data['height'], guidance_scale_animagine: data['cfg_scale'], num_inference_steps_animagine: data['steps'], sampler_animagine: final_sampler, schedule_type_animagine: final_scheduler, model_tabs: gr.Tabs(selected=1)})
339
+ elif target_tab == "Pony":
340
+ if base_model_name in MODEL_MAP_PONY:
341
+ updates.update({base_model_name_input_pony: base_model_name})
342
+ updates.update({prompt_pony: data['prompt'], negative_prompt_pony: data['negative_prompt'], seed_pony: data['seed'], width_pony: data['width'], height_pony: data['height'], guidance_scale_pony: data['cfg_scale'], num_inference_steps_pony: data['steps'], sampler_pony: final_sampler, schedule_type_pony: final_scheduler, model_tabs: gr.Tabs(selected=2)})
343
+ elif target_tab == "SD1.5":
344
+ if base_model_name in MODEL_MAP_SD15:
345
+ updates.update({base_model_name_input_sd15: base_model_name})
346
+ updates.update({prompt_sd15: data['prompt'], negative_prompt_sd15: data['negative_prompt'], seed_sd15: data['seed'], width_sd15: data['width'], height_sd15: data['height'], guidance_scale_sd15: data['cfg_scale'], num_inference_steps_sd15: data['steps'], sampler_sd15: final_sampler, schedule_type_sd15: final_scheduler, clip_skip_sd15: data.get('clip_skip', 1), model_tabs: gr.Tabs(selected=3)})
347
+
348
+ updates[tabs] = gr.Tabs(selected=0)
349
+ return updates
350
+
351
+ def send_info_to_tab(image, target_tab):
352
+ if not image or not image.info.get('parameters', ''): return {comp: gr.update() for comp in all_ui_components}
353
+ data = _parse_parameters(image.info['parameters'])
354
+ return apply_data_to_ui(data, target_tab)
355
+
356
+ def send_info_by_hash(image):
357
+ if not image or not image.info.get('parameters', ''): return {comp: gr.update() for comp in all_ui_components}
358
+ data = _parse_parameters(image.info['parameters'])
359
+ model_hash = data.get('model_hash')
360
+ display_name = HASH_TO_DISPLAY_NAME_MAP.get(model_hash)
361
+
362
+ if not display_name:
363
+ raise gr.Error("Model hash not found in this app's model list. The original model name from the PNG will be used if it exists in the target tab.")
364
+
365
+ if display_name in MODEL_MAP_ILLUSTRIOUS: target_tab = "Illustrious"
366
+ elif display_name in MODEL_MAP_ANIMAGINE: target_tab = "Animagine"
367
+ elif display_name in MODEL_MAP_PONY: target_tab = "Pony"
368
+ elif display_name in MODEL_MAP_SD15: target_tab = "SD1.5"
369
+ else:
370
+ raise gr.Error("Cannot determine the correct tab for this model.")
371
+
372
+ data['base_model'] = display_name
373
+ return apply_data_to_ui(data, target_tab)
374
+
375
+ # --- UI Generation Functions ---
376
+ def create_lora_settings_ui():
377
+ with gr.Accordion("LoRA Settings", open=False):
378
+ gr.Markdown("⚠️ **Responsible Use Notice:** Please avoid excessive, rapid, or automated (scripted) use of the pre-download LoRA feature. Overt misuse may lead to service disruption. Thank you for your cooperation.")
379
+ gr.Markdown("For LoRAs that require login to download, you may need to enter the corresponding API Key.")
380
+ with gr.Row():
381
+ civitai_api_key = gr.Textbox(label="Civitai API Key", placeholder="Enter your Civitai API Key", type="password", scale=1)
382
+ tensorart_api_key = gr.Textbox(label="TensorArt API Key", placeholder="Enter your TensorArt API Key", type="password", scale=1)
383
+ gr.Markdown("---")
384
+ gr.Markdown("For each LoRA, choose a source, provide an ID/URL, or upload a file.")
385
+ gr.Markdown("""
386
+ <div style='background-color: #282828; color: #a0aec0; padding: 10px; border-radius: 5px; margin-top: 10px; margin-bottom: 15px;'>
387
+ <b>Input Examples:</b>
388
+ <ul>
389
+ <li><b>Civitai:</b> Enter the <b>Model Version ID</b>, not the Model ID. Example: <code>133755</code> (Found in the URL, e.g., <code>civitai.com/models/122136?modelVersionId=<b>133755</b></code>)</li>
390
+ <li><b>TensorArt:</b> Enter the <b>Model ID</b>. Example: <code>706684852832599558</code> (Found in the URL, e.g., <code>tensor.art/models/<b>706684852832599558</b></code>)</li>
391
+ <li><b>Custom URL:</b> Provide a direct download link to a <code>.safetensors</code> file. Example: <code>https://huggingface.co/path/to/your/lora.safetensors</code></li>
392
+ <li><b>File:</b> Use the "Upload" button. The source will be set automatically.</li>
393
+ </ul>
394
+ </div>
395
+ """)
396
+ gr.Markdown("""
397
+ <div style='background-color: #282828; color: #a0aec0; padding: 10px; border-radius: 5px; margin-bottom: 15px;'>
398
+ <b>Notice:</b>
399
+ <ul style='margin-bottom: 0;'>
400
+ <li>With Gradio, the page may become unresponsive until a file is fully uploaded. Please be patient and wait for the process to complete.</li>
401
+ </ul>
402
+ </div>
403
+ """)
404
+ lora_rows, sources, ids, scales, uploads = [], [], [], [], []
405
+ for i in range(MAX_LORAS):
406
+ with gr.Row(visible=(i == 0)) as row:
407
+ source = gr.Dropdown(label=f"LoRA {i+1} Source", choices=LORA_SOURCE_CHOICES, value="Civitai", scale=1)
408
+ lora_id = gr.Textbox(label="ID / URL / File", placeholder="e.g.: 133755", scale=2)
409
+ scale = gr.Slider(label="Weight", minimum=0.0, maximum=2.0, step=0.05, value=0.0, scale=2)
410
+ upload = gr.UploadButton("Upload", file_types=[".safetensors"], scale=1)
411
+ lora_rows.append(row); sources.append(source); ids.append(lora_id); scales.append(scale); uploads.append(upload)
412
+ upload.upload(fn=lambda f: (os.path.basename(f.name), "File") if f else (gr.update(), gr.update()), inputs=[upload], outputs=[lora_id, source])
413
+ with gr.Row(): add_button = gr.Button("✚ Add LoRA"); delete_button = gr.Button("➖ Delete LoRA", visible=False)
414
+ count_state = gr.State(value=1)
415
+ all_components = [item for sublist in zip(sources, ids, scales, uploads) for item in sublist]
416
+ return (civitai_api_key, tensorart_api_key, lora_rows, sources, ids, scales, uploads, add_button, delete_button, count_state, all_components)
417
+
418
+ def download_all_models_on_startup():
419
+ """Downloads all base models listed in ALL_MODEL_MAP when the app starts."""
420
+ print("--- Starting pre-download of all base models ---")
421
+ for model_display_name, model_info in ALL_MODEL_MAP.items():
422
+ repo_id, filename, _, _ = model_info
423
+ local_file_path = os.path.join(CHECKPOINT_DIR, filename)
424
+
425
+ if os.path.exists(local_file_path):
426
+ print(f"✅ Model '{filename}' already exists. Skipping download.")
427
+ continue
428
+
429
+ try:
430
+ print(f"Downloading: {model_display_name} ({filename})...")
431
+ hf_hub_download(
432
+ repo_id=repo_id,
433
+ filename=filename,
434
+ local_dir=CHECKPOINT_DIR,
435
+ local_dir_use_symlinks=False
436
+ )
437
+ print(f"✅ Successfully downloaded {filename}.")
438
+ except Exception as e:
439
+ print(f"❌ Failed to download {filename} from {repo_id}: {e}")
440
+ print("--- Finished pre-downloading all base models ---")
441
+
442
+ # --- Execute model download on startup ---
443
+ download_all_models_on_startup()
444
+
445
+ # --- Gradio UI ---
446
+ with gr.Blocks(css="#col-container {margin: 0 auto; max-width: 1024px;}") as demo:
447
+ gr.Markdown("# Animated T2I with LoRAs")
448
+ with gr.Tabs(elem_id="tabs_container") as tabs:
449
+ with gr.TabItem("txt2img", id=0):
450
+ with gr.Tabs() as model_tabs:
451
+ for tab_name, model_map, defaults in [
452
+ ("Illustrious", MODEL_MAP_ILLUSTRIOUS, {'w': 1024, 'h': 1024, 'cs_vis': False, 'cs_val': 1}),
453
+ ("Animagine", MODEL_MAP_ANIMAGINE, {'w': 1024, 'h': 1024, 'cs_vis': False, 'cs_val': 1}),
454
+ ("Pony", MODEL_MAP_PONY, {'w': 1024, 'h': 1024, 'cs_vis': False, 'cs_val': 1}),
455
+ ("SD1.5", MODEL_MAP_SD15, {'w': 512, 'h': 768, 'cs_vis': True, 'cs_val': 1})
456
+ ]:
457
+ with gr.TabItem(tab_name):
458
+ gr.Markdown("💡 **Tip:** Pre-downloading LoRAs before 'Run' can maximize ZeroGPU time.")
459
+ with gr.Column():
460
+ with gr.Row():
461
+ base_model = gr.Dropdown(label="Base Model", choices=list(model_map.keys()), value=list(model_map.keys())[0], scale=3)
462
+ with gr.Column(scale=1): predownload_lora = gr.Button("Pre-download LoRAs"); run = gr.Button("Run", variant="primary")
463
+ predownload_status = gr.Markdown("")
464
+ prompt = gr.Text(label="Prompt", lines=3, placeholder="Enter your prompt")
465
+ neg_prompt = gr.Text(label="Negative prompt", lines=3, value=DEFAULT_NEGATIVE_PROMPT)
466
+ with gr.Row():
467
+ with gr.Column(scale=2):
468
+ with gr.Row(): width = gr.Slider(label="Width", minimum=256, maximum=2048, step=64, value=defaults['w']); height = gr.Slider(label="Height", minimum=256, maximum=2048, step=64, value=defaults['h'])
469
+ with gr.Row():
470
+ sampler = gr.Dropdown(label="Sampling method", choices=SAMPLER_CHOICES, value=SAMPLER_CHOICES[0])
471
+ default_scheduler = 'normal' if 'normal' in SCHEDULER_CHOICES else SCHEDULER_CHOICES[0]
472
+ scheduler = gr.Dropdown(label="Scheduler", choices=SCHEDULER_CHOICES, value=default_scheduler)
473
+ with gr.Row(): cfg = gr.Slider(label="CFG Scale", minimum=0.0, maximum=20.0, step=0.1, value=7.5); steps = gr.Slider(label="Sampling steps", minimum=1, maximum=50, step=1, value=28)
474
+ with gr.Column(scale=1): result = gr.Gallery(label="Result", show_label=False, columns=2, object_fit="contain", height="auto")
475
+ with gr.Row():
476
+ seed = gr.Number(label="Seed (-1 for random)", value=-1, precision=0)
477
+ batch_size = gr.Slider(label="Batch size", minimum=1, maximum=8, step=1, value=1)
478
+ clip_skip = gr.Slider(label="Clip Skip", minimum=1, maximum=2, step=1, value=defaults['cs_val'], visible=defaults['cs_vis'])
479
+ zero_gpu = gr.Number(label="ZeroGPU Duration (s)", value=None, placeholder="Default: 60s", info="Optional: Leave empty for default (60s), max to 120")
480
+ lora_settings = create_lora_settings_ui()
481
+
482
+ # Assign specific variables for event handlers
483
+ if tab_name == "Illustrious":
484
+ base_model_name_input_illustrious, prompt_illustrious, negative_prompt_illustrious, seed_illustrious, batch_size_illustrious, width_illustrious, height_illustrious, guidance_scale_illustrious, num_inference_steps_illustrious, sampler_illustrious, schedule_type_illustrious, zero_gpu_duration_illustrious, result_illustrious = base_model, prompt, neg_prompt, seed, batch_size, width, height, cfg, steps, sampler, scheduler, zero_gpu, result
485
+ civitai_api_key_illustrious, tensorart_api_key_illustrious, lora_rows_illustrious, _, lora_id_inputs_illustrious, lora_scale_inputs_illustrious, _, add_lora_button_illustrious, delete_lora_button_illustrious, lora_count_state_illustrious, all_lora_components_flat_illustrious = lora_settings
486
+ predownload_lora_button_illustrious, run_button_illustrious, predownload_status_illustrious = predownload_lora, run, predownload_status
487
+ elif tab_name == "Animagine":
488
+ base_model_name_input_animagine, prompt_animagine, negative_prompt_animagine, seed_animagine, batch_size_animagine, width_animagine, height_animagine, guidance_scale_animagine, num_inference_steps_animagine, sampler_animagine, schedule_type_animagine, zero_gpu_duration_animagine, result_animagine = base_model, prompt, neg_prompt, seed, batch_size, width, height, cfg, steps, sampler, scheduler, zero_gpu, result
489
+ civitai_api_key_animagine, tensorart_api_key_animagine, lora_rows_animagine, _, lora_id_inputs_animagine, lora_scale_inputs_animagine, _, add_lora_button_animagine, delete_lora_button_animagine, lora_count_state_animagine, all_lora_components_flat_animagine = lora_settings
490
+ predownload_lora_button_animagine, run_button_animagine, predownload_status_animagine = predownload_lora, run, predownload_status
491
+ elif tab_name == "Pony":
492
+ base_model_name_input_pony, prompt_pony, negative_prompt_pony, seed_pony, batch_size_pony, width_pony, height_pony, guidance_scale_pony, num_inference_steps_pony, sampler_pony, schedule_type_pony, zero_gpu_duration_pony, result_pony = base_model, prompt, neg_prompt, seed, batch_size, width, height, cfg, steps, sampler, scheduler, zero_gpu, result
493
+ civitai_api_key_pony, tensorart_api_key_pony, lora_rows_pony, _, lora_id_inputs_pony, lora_scale_inputs_pony, _, add_lora_button_pony, delete_lora_button_pony, lora_count_state_pony, all_lora_components_flat_pony = lora_settings
494
+ predownload_lora_button_pony, run_button_pony, predownload_status_pony = predownload_lora, run, predownload_status
495
+ elif tab_name == "SD1.5":
496
+ base_model_name_input_sd15, prompt_sd15, negative_prompt_sd15, seed_sd15, batch_size_sd15, width_sd15, height_sd15, guidance_scale_sd15, num_inference_steps_sd15, sampler_sd15, schedule_type_sd15, clip_skip_sd15, zero_gpu_duration_sd15, result_sd15 = base_model, prompt, neg_prompt, seed, batch_size, width, height, cfg, steps, sampler, scheduler, clip_skip, zero_gpu, result
497
+ civitai_api_key_sd15, tensorart_api_key_sd15, lora_rows_sd15, _, lora_id_inputs_sd15, lora_scale_inputs_sd15, _, add_lora_button_sd15, delete_lora_button_sd15, lora_count_state_sd15, all_lora_components_flat_sd15 = lora_settings
498
+ predownload_lora_button_sd15, run_button_sd15, predownload_status_sd15 = predownload_lora, run, predownload_status
499
+ with gr.TabItem("PNG Info", id=1):
500
+ with gr.Column():
501
+ info_image_input = gr.Image(type="pil", label="Upload Image", height=512)
502
+ with gr.Row():
503
+ info_get_button = gr.Button("Get Info")
504
+ send_by_hash_button = gr.Button("Send to txt2img by Model Hash", variant="primary")
505
+ with gr.Row():
506
+ send_to_illustrious_button = gr.Button("Send to Illustrious")
507
+ send_to_animagine_button = gr.Button("Send to Animagine")
508
+ send_to_pony_button = gr.Button("Send to Pony")
509
+ send_to_sd15_button = gr.Button("Send to SD1.5")
510
+ gr.Markdown("### Positive Prompt"); info_prompt_output = gr.Textbox(lines=3, interactive=False, show_label=False)
511
+ gr.Markdown("### Negative Prompt"); info_neg_prompt_output = gr.Textbox(lines=3, interactive=False, show_label=False)
512
+ gr.Markdown("### Other Parameters"); info_params_output = gr.Textbox(lines=5, interactive=False, show_label=False)
513
+ gr.Markdown("<div style='text-align: center; margin-top: 20px;'>Made by <a href='https://civitai.com/user/RioShiina'>RioShiina</a> with ❤️</div>")
514
+
515
+ # --- Event Handlers ---
516
+ def create_lora_event_handlers(lora_rows, count_state, add_button, del_button, lora_ids, lora_scales):
517
+ def add_lora_row(c): return {count_state: c+1, lora_rows[c]: gr.update(visible=True), del_button: gr.update(visible=True), add_button: gr.update(visible=c+1 < MAX_LORAS)}
518
+ def del_lora_row(c): c-=1; return {count_state: c, lora_rows[c]: gr.update(visible=False), lora_ids[c]: "", lora_scales[c]: 0.0, add_button: gr.update(visible=True), del_button: gr.update(visible=c > 1)}
519
+ add_button.click(add_lora_row, [count_state], [count_state, add_button, del_button, *lora_rows])
520
+ del_button.click(del_lora_row, [count_state], [count_state, add_button, del_button, *lora_rows, *lora_ids, *lora_scales])
521
+
522
+ create_lora_event_handlers(lora_rows_illustrious, lora_count_state_illustrious, add_lora_button_illustrious, delete_lora_button_illustrious, lora_id_inputs_illustrious, lora_scale_inputs_illustrious)
523
+ predownload_lora_button_illustrious.click(lambda: "⏳ Downloading...", None, [predownload_status_illustrious]).then(pre_download_loras, [civitai_api_key_illustrious, tensorart_api_key_illustrious, *all_lora_components_flat_illustrious], [predownload_status_illustrious])
524
+ run_button_illustrious.click(generate_image_wrapper, [base_model_name_input_illustrious, prompt_illustrious, negative_prompt_illustrious, seed_illustrious, batch_size_illustrious, width_illustrious, height_illustrious, guidance_scale_illustrious, num_inference_steps_illustrious, sampler_illustrious, schedule_type_illustrious, zero_gpu_duration_illustrious, civitai_api_key_illustrious, tensorart_api_key_illustrious, *all_lora_components_flat_illustrious], [result_illustrious])
525
+
526
+ create_lora_event_handlers(lora_rows_animagine, lora_count_state_animagine, add_lora_button_animagine, delete_lora_button_animagine, lora_id_inputs_animagine, lora_scale_inputs_animagine)
527
+ predownload_lora_button_animagine.click(lambda: "⏳ Downloading...", None, [predownload_status_animagine]).then(pre_download_loras, [civitai_api_key_animagine, tensorart_api_key_animagine, *all_lora_components_flat_animagine], [predownload_status_animagine])
528
+ run_button_animagine.click(generate_image_wrapper, [base_model_name_input_animagine, prompt_animagine, negative_prompt_animagine, seed_animagine, batch_size_animagine, width_animagine, height_animagine, guidance_scale_animagine, num_inference_steps_animagine, sampler_animagine, schedule_type_animagine, zero_gpu_duration_animagine, civitai_api_key_animagine, tensorart_api_key_animagine, *all_lora_components_flat_animagine], [result_animagine])
529
+
530
+ create_lora_event_handlers(lora_rows_pony, lora_count_state_pony, add_lora_button_pony, delete_lora_button_pony, lora_id_inputs_pony, lora_scale_inputs_pony)
531
+ predownload_lora_button_pony.click(lambda: "⏳ Downloading...", None, [predownload_status_pony]).then(pre_download_loras, [civitai_api_key_pony, tensorart_api_key_pony, *all_lora_components_flat_pony], [predownload_status_pony])
532
+ run_button_pony.click(generate_image_wrapper, [base_model_name_input_pony, prompt_pony, negative_prompt_pony, seed_pony, batch_size_pony, width_pony, height_pony, guidance_scale_pony, num_inference_steps_pony, sampler_pony, schedule_type_pony, zero_gpu_duration_pony, civitai_api_key_pony, tensorart_api_key_pony, *all_lora_components_flat_pony], [result_pony])
533
+
534
+ create_lora_event_handlers(lora_rows_sd15, lora_count_state_sd15, add_lora_button_sd15, delete_lora_button_sd15, lora_id_inputs_sd15, lora_scale_inputs_sd15)
535
+ predownload_lora_button_sd15.click(lambda: "⏳ Downloading...", None, [predownload_status_sd15]).then(pre_download_loras, [civitai_api_key_sd15, tensorart_api_key_sd15, *all_lora_components_flat_sd15], [predownload_status_sd15])
536
+ run_button_sd15.click(generate_image_wrapper, [base_model_name_input_sd15, prompt_sd15, negative_prompt_sd15, seed_sd15, batch_size_sd15, width_sd15, height_sd15, guidance_scale_sd15, num_inference_steps_sd15, sampler_sd15, schedule_type_sd15, zero_gpu_duration_sd15, civitai_api_key_sd15, tensorart_api_key_sd15, *all_lora_components_flat_sd15, clip_skip_sd15], [result_sd15])
537
+
538
+ info_get_button.click(get_png_info, [info_image_input], [info_prompt_output, info_neg_prompt_output, info_params_output])
539
+ all_ui_components = [
540
+ base_model_name_input_illustrious, prompt_illustrious, negative_prompt_illustrious, seed_illustrious, width_illustrious, height_illustrious, guidance_scale_illustrious, num_inference_steps_illustrious, sampler_illustrious, schedule_type_illustrious,
541
+ base_model_name_input_animagine, prompt_animagine, negative_prompt_animagine, seed_animagine, width_animagine, height_animagine, guidance_scale_animagine, num_inference_steps_animagine, sampler_animagine, schedule_type_animagine,
542
+ base_model_name_input_pony, prompt_pony, negative_prompt_pony, seed_pony, width_pony, height_pony, guidance_scale_pony, num_inference_steps_pony, sampler_pony, schedule_type_pony,
543
+ base_model_name_input_sd15, prompt_sd15, negative_prompt_sd15, seed_sd15, width_sd15, height_sd15, guidance_scale_sd15, num_inference_steps_sd15, sampler_sd15, schedule_type_sd15, clip_skip_sd15,
544
+ tabs, model_tabs
545
+ ]
546
+ send_to_illustrious_button.click(lambda img: send_info_to_tab(img, "Illustrious"), [info_image_input], all_ui_components)
547
+ send_to_animagine_button.click(lambda img: send_info_to_tab(img, "Animagine"), [info_image_input], all_ui_components)
548
+ send_to_pony_button.click(lambda img: send_info_to_tab(img, "Pony"), [info_image_input], all_ui_components)
549
+ send_to_sd15_button.click(lambda img: send_info_to_tab(img, "SD1.5"), [info_image_input], all_ui_components)
550
+ send_by_hash_button.click(send_info_by_hash, [info_image_input], all_ui_components)
551
+
552
+ if __name__ == "__main__":
553
+ demo.queue().launch()
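For context, the PNG Info handlers above assume an A1111-style `parameters` text chunk embedded in the PNG and delegate the actual parsing to `_parse_parameters`, which is defined earlier in app.py (outside this hunk). Below is a minimal, self-contained sketch of turning such a string into the `data` dict that `apply_data_to_ui` consumes; the sample string, key names, and defaults are illustrative assumptions, not the app's exact implementation.

```python
# Hypothetical sketch; the real _parse_parameters in app.py may differ.
sample = (
    "1girl, masterpiece, best quality\n"
    "Negative prompt: lowres, bad anatomy\n"
    "Steps: 28, Sampler: Euler a, Schedule type: Karras, CFG scale: 7.5, "
    "Seed: 12345, Size: 1024x1024, Model hash: 1a2b3c4d, Clip skip: 2"
)

def parse_parameters_sketch(text: str) -> dict:
    prompt_lines, negative, params_line = [], "", ""
    for line in text.split("\n"):
        if line.startswith("Negative prompt:"):
            negative = line[len("Negative prompt:"):].strip()
        elif line.startswith("Steps:"):
            params_line = line
        else:
            prompt_lines.append(line)
    # The final line holds comma-separated "Key: value" pairs
    kv = {}
    for pair in params_line.split(","):
        if ":" in pair:
            k, v = pair.split(":", 1)
            kv[k.strip()] = v.strip()
    data = {"prompt": "\n".join(prompt_lines).strip(), "negative_prompt": negative}
    if "Size" in kv:
        w, h = kv["Size"].split("x")
        data["width"], data["height"] = int(w), int(h)
    data.update({
        "steps": int(kv.get("Steps", 28)),
        "cfg_scale": float(kv.get("CFG scale", 7.5)),
        "seed": int(kv.get("Seed", -1)),
        "sampler": kv.get("Sampler"),
        "scheduler": kv.get("Schedule type"),
        "model_hash": kv.get("Model hash"),
        "clip_skip": int(kv.get("Clip skip", 1)),
    })
    return data

print(parse_parameters_sketch(sample)["seed"])  # -> 12345
```

Note the sketch ignores multi-line negative prompts and sampler/scheduler name normalization, both of which a production parser would need.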
app/__init__.py ADDED
File without changes
app/app_settings.py ADDED
@@ -0,0 +1,65 @@
+ import os
+ import json
+ from aiohttp import web
+ import logging
+
+
+ class AppSettings():
+     def __init__(self, user_manager):
+         self.user_manager = user_manager
+
+     def get_settings(self, request):
+         try:
+             file = self.user_manager.get_request_user_filepath(
+                 request,
+                 "comfy.settings.json"
+             )
+         except KeyError as e:
+             logging.error("User settings not found.")
+             raise web.HTTPUnauthorized() from e
+         if os.path.isfile(file):
+             try:
+                 with open(file) as f:
+                     return json.load(f)
+             except Exception:
+                 logging.error(f"The user settings file is corrupted: {file}")
+                 return {}
+         else:
+             return {}
+
+     def save_settings(self, request, settings):
+         file = self.user_manager.get_request_user_filepath(
+             request, "comfy.settings.json")
+         with open(file, "w") as f:
+             f.write(json.dumps(settings, indent=4))
+
+     def add_routes(self, routes):
+         @routes.get("/settings")
+         async def get_settings(request):
+             return web.json_response(self.get_settings(request))
+
+         @routes.get("/settings/{id}")
+         async def get_setting(request):
+             value = None
+             settings = self.get_settings(request)
+             setting_id = request.match_info.get("id", None)
+             if setting_id and setting_id in settings:
+                 value = settings[setting_id]
+             return web.json_response(value)
+
+         @routes.post("/settings")
+         async def post_settings(request):
+             settings = self.get_settings(request)
+             new_settings = await request.json()
+             self.save_settings(request, {**settings, **new_settings})
+             return web.Response(status=200)
+
+         @routes.post("/settings/{id}")
+         async def post_setting(request):
+             setting_id = request.match_info.get("id", None)
+             if not setting_id:
+                 return web.Response(status=400)
+             settings = self.get_settings(request)
+             settings[setting_id] = await request.json()
+             self.save_settings(request, settings)
+             return web.Response(status=200)
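The four routes above amount to a small per-user settings API over comfy.settings.json. A hedged client sketch, assuming a ComfyUI server on its usual default address and the `requests` package installed; the setting id is purely illustrative:

```python
import requests

BASE = "http://127.0.0.1:8188"  # assumed default ComfyUI address

# Merge a batch of settings into the user's comfy.settings.json
requests.post(f"{BASE}/settings", json={"Example.Setting": True})  # illustrative id

# Overwrite a single setting by id
requests.post(f"{BASE}/settings/Example.Setting", json=False)

# Read everything back, then one value (null/None if the id is unknown)
print(requests.get(f"{BASE}/settings").json())
print(requests.get(f"{BASE}/settings/Example.Setting").json())
```

In a multi-user deployment the user is selected via the `comfy-user` request header (see app/user_manager.py below).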
app/custom_node_manager.py ADDED
@@ -0,0 +1,145 @@
+ from __future__ import annotations
+
+ import os
+ import folder_paths
+ import glob
+ from aiohttp import web
+ import json
+ import logging
+ from functools import lru_cache
+
+ from utils.json_util import merge_json_recursive
+
+
+ # Extra locale files to load into main.json
+ EXTRA_LOCALE_FILES = [
+     "nodeDefs.json",
+     "commands.json",
+     "settings.json",
+ ]
+
+
+ def safe_load_json_file(file_path: str) -> dict:
+     if not os.path.exists(file_path):
+         return {}
+
+     try:
+         with open(file_path, "r", encoding="utf-8") as f:
+             return json.load(f)
+     except json.JSONDecodeError:
+         logging.error(f"Error loading {file_path}")
+         return {}
+
+
+ class CustomNodeManager:
+     @lru_cache(maxsize=1)
+     def build_translations(self):
+         """Load all custom nodes' translations during initialization. Translations are
+         expected to be loaded from the `locales/` folder.
+
+         The folder structure is expected to be the following:
+         - custom_nodes/
+             - custom_node_1/
+                 - locales/
+                     - en/
+                         - main.json
+                         - commands.json
+                         - settings.json
+
+         Returned translations are expected to be in the following format:
+         {
+             "en": {
+                 "nodeDefs": {...},
+                 "commands": {...},
+                 "settings": {...},
+                 ...{other main.json keys}
+             }
+         }
+         """
+
+         translations = {}
+
+         for folder in folder_paths.get_folder_paths("custom_nodes"):
+             # Sort glob results for deterministic ordering
+             for custom_node_dir in sorted(glob.glob(os.path.join(folder, "*/"))):
+                 locales_dir = os.path.join(custom_node_dir, "locales")
+                 if not os.path.exists(locales_dir):
+                     continue
+
+                 for lang_dir in glob.glob(os.path.join(locales_dir, "*/")):
+                     lang_code = os.path.basename(os.path.dirname(lang_dir))
+
+                     if lang_code not in translations:
+                         translations[lang_code] = {}
+
+                     # Load main.json
+                     main_file = os.path.join(lang_dir, "main.json")
+                     node_translations = safe_load_json_file(main_file)
+
+                     # Load extra locale files
+                     for extra_file in EXTRA_LOCALE_FILES:
+                         extra_file_path = os.path.join(lang_dir, extra_file)
+                         key = extra_file.split(".")[0]
+                         json_data = safe_load_json_file(extra_file_path)
+                         if json_data:
+                             node_translations[key] = json_data
+
+                     if node_translations:
+                         translations[lang_code] = merge_json_recursive(
+                             translations[lang_code], node_translations
+                         )
+
+         return translations
+
+     def add_routes(self, routes, webapp, loadedModules):
+
+         example_workflow_folder_names = ["example_workflows", "example", "examples", "workflow", "workflows"]
+
+         @routes.get("/workflow_templates")
+         async def get_workflow_templates(request):
+             """Returns a web response that contains the map of custom_nodes names and their associated workflow templates. The ones without templates are omitted."""
+
+             files = []
+
+             for folder in folder_paths.get_folder_paths("custom_nodes"):
+                 for folder_name in example_workflow_folder_names:
+                     pattern = os.path.join(folder, f"*/{folder_name}/*.json")
+                     matched_files = glob.glob(pattern)
+                     files.extend(matched_files)
+
+             workflow_templates_dict = (
+                 {}
+             )  # custom_nodes folder name -> example workflow names
+             for file in files:
+                 custom_nodes_name = os.path.basename(
+                     os.path.dirname(os.path.dirname(file))
+                 )
+                 workflow_name = os.path.splitext(os.path.basename(file))[0]
+                 workflow_templates_dict.setdefault(custom_nodes_name, []).append(
+                     workflow_name
+                 )
+             return web.json_response(workflow_templates_dict)
+
+         # Serve workflow templates from custom nodes.
+         for module_name, module_dir in loadedModules:
+             for folder_name in example_workflow_folder_names:
+                 workflows_dir = os.path.join(module_dir, folder_name)
+
+                 if os.path.exists(workflows_dir):
+                     if folder_name != "example_workflows":
+                         logging.debug(
+                             "Found example workflow folder '%s' for custom node '%s', consider renaming it to 'example_workflows'",
+                             folder_name, module_name)
+
+                     webapp.add_routes(
+                         [
+                             web.static(
+                                 "/api/workflow_templates/" + module_name, workflows_dir
+                             )
+                         ]
+                     )
+
+         @routes.get("/i18n")
+         async def get_i18n(request):
+             """Returns translations from all custom nodes' locales folders."""
+             return web.json_response(self.build_translations())
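`build_translations` depends on `merge_json_recursive` from utils/json_util.py, which is not part of this hunk. A minimal sketch of the merge semantics the code above relies on (nested dicts merged key by key, with later values winning); the real helper may handle extra cases such as lists:

```python
def merge_json_recursive(base: dict, update: dict) -> dict:
    """Return a new dict with `update` merged into `base`; nested dicts merge recursively."""
    merged = dict(base)
    for key, value in update.items():
        if isinstance(merged.get(key), dict) and isinstance(value, dict):
            merged[key] = merge_json_recursive(merged[key], value)
        else:
            merged[key] = value
    return merged

# Two locale fragments for the same language combine instead of clobbering each other:
assert merge_json_recursive(
    {"en": {"commands": {"a": "1"}}},
    {"en": {"commands": {"b": "2"}}},
) == {"en": {"commands": {"a": "1", "b": "2"}}}
```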
app/database/db.py ADDED
@@ -0,0 +1,112 @@
+ import logging
+ import os
+ import shutil
+ from app.logger import log_startup_warning
+ from utils.install_util import get_missing_requirements_message
+ from comfy.cli_args import args
+
+ _DB_AVAILABLE = False
+ Session = None
+
+
+ try:
+     from alembic import command
+     from alembic.config import Config
+     from alembic.runtime.migration import MigrationContext
+     from alembic.script import ScriptDirectory
+     from sqlalchemy import create_engine
+     from sqlalchemy.orm import sessionmaker
+
+     _DB_AVAILABLE = True
+ except ImportError as e:
+     log_startup_warning(
+         f"""
+ ------------------------------------------------------------------------
+ Error importing dependencies: {e}
+ {get_missing_requirements_message()}
+ This error is happening because ComfyUI now uses a local sqlite database.
+ ------------------------------------------------------------------------
+ """.strip()
+     )
+
+
+ def dependencies_available():
+     """
+     Temporary function to check if the dependencies are available
+     """
+     return _DB_AVAILABLE
+
+
+ def can_create_session():
+     """
+     Temporary function to check if the database is available to create a session
+     During initial release there may be environmental issues (or missing dependencies) that prevent the database from being created
+     """
+     return dependencies_available() and Session is not None
+
+
+ def get_alembic_config():
+     root_path = os.path.join(os.path.dirname(__file__), "../..")
+     config_path = os.path.abspath(os.path.join(root_path, "alembic.ini"))
+     scripts_path = os.path.abspath(os.path.join(root_path, "alembic_db"))
+
+     config = Config(config_path)
+     config.set_main_option("script_location", scripts_path)
+     config.set_main_option("sqlalchemy.url", args.database_url)
+
+     return config
+
+
+ def get_db_path():
+     url = args.database_url
+     if url.startswith("sqlite:///"):
+         return url.split("///")[1]
+     else:
+         raise ValueError(f"Unsupported database URL '{url}'.")
+
+
+ def init_db():
+     db_url = args.database_url
+     logging.debug(f"Database URL: {db_url}")
+     db_path = get_db_path()
+     db_exists = os.path.exists(db_path)
+
+     config = get_alembic_config()
+
+     # Check if we need to upgrade
+     engine = create_engine(db_url)
+     conn = engine.connect()
+
+     context = MigrationContext.configure(conn)
+     current_rev = context.get_current_revision()
+
+     script = ScriptDirectory.from_config(config)
+     target_rev = script.get_current_head()
+
+     if target_rev is None:
+         logging.warning("No target revision found.")
+     elif current_rev != target_rev:
+         # Backup the database pre upgrade
+         backup_path = db_path + ".bkp"
+         if db_exists:
+             shutil.copy(db_path, backup_path)
+         else:
+             backup_path = None
+
+         try:
+             command.upgrade(config, target_rev)
+             logging.info(f"Database upgraded from {current_rev} to {target_rev}")
+         except Exception as e:
+             if backup_path:
+                 # Restore the database from backup if upgrade fails
+                 shutil.copy(backup_path, db_path)
+                 os.remove(backup_path)
+             logging.exception("Error upgrading database: ")
+             raise e
+
+     global Session
+     Session = sessionmaker(bind=engine)
+
+
+ def create_session():
+     return Session()
app/database/models.py ADDED
@@ -0,0 +1,14 @@
+ from sqlalchemy.orm import declarative_base
+
+ Base = declarative_base()
+
+
+ def to_dict(obj):
+     fields = obj.__table__.columns.keys()
+     # Note: columns whose value is falsy (None, 0, False, "") are omitted from the result.
+     return {
+         field: (val.to_dict() if hasattr(val, "to_dict") else val)
+         for field in fields
+         if (val := getattr(obj, field))
+     }
+
+ # TODO: Define models here
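Since the module currently defines no models, here is a hypothetical example of how a model would plug into `Base` and `to_dict`; the table and columns are invented for illustration and are not part of this commit:

```python
from sqlalchemy import Column, Integer, String
from sqlalchemy.orm import declarative_base

Base = declarative_base()  # stand-in for app.database.models.Base

class ExampleTag(Base):  # hypothetical; this commit defines no models yet
    __tablename__ = "example_tags"
    id = Column(Integer, primary_key=True)
    name = Column(String, nullable=False)

def to_dict(obj):  # same helper as in app/database/models.py above
    fields = obj.__table__.columns.keys()
    return {
        field: (val.to_dict() if hasattr(val, "to_dict") else val)
        for field in fields
        if (val := getattr(obj, field))
    }

print(to_dict(ExampleTag(id=1, name="style")))  # {'id': 1, 'name': 'style'}
```

Because of the falsy-value filter noted above, a row with `id=0` or `name=""` would drop those keys from the dict.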
app/frontend_management.py ADDED
@@ -0,0 +1,361 @@
+ from __future__ import annotations
+ import argparse
+ import logging
+ import os
+ import re
+ import sys
+ import tempfile
+ import zipfile
+ import importlib.resources
+ from dataclasses import dataclass
+ from functools import cached_property
+ from pathlib import Path
+ from typing import TypedDict, Optional
+ from importlib.metadata import version
+
+ import requests
+ from typing_extensions import NotRequired
+
+ from utils.install_util import get_missing_requirements_message, requirements_path
+
+ from comfy.cli_args import DEFAULT_VERSION_STRING
+ import app.logger
+
+
+ def frontend_install_warning_message():
+     return f"""
+ {get_missing_requirements_message()}
+
+ This error is happening because the ComfyUI frontend is no longer shipped as part of the main repo but as a pip package instead.
+ """.strip()
+
+ def parse_version(version: str) -> tuple[int, int, int]:
+     return tuple(map(int, version.split(".")))
+
+ def is_valid_version(version: str) -> bool:
+     """Validate if a string is a valid semantic version (X.Y.Z format)."""
+     pattern = r"^(\d+)\.(\d+)\.(\d+)$"
+     return bool(re.match(pattern, version))
+
+ def get_installed_frontend_version():
+     """Get the currently installed frontend package version."""
+     frontend_version_str = version("comfyui-frontend-package")
+     return frontend_version_str
+
+ def get_required_frontend_version():
+     """Get the required frontend version from requirements.txt."""
+     try:
+         with open(requirements_path, "r", encoding="utf-8") as f:
+             for line in f:
+                 line = line.strip()
+                 if line.startswith("comfyui-frontend-package=="):
+                     version_str = line.split("==")[-1]
+                     if not is_valid_version(version_str):
+                         logging.error(f"Invalid version format in requirements.txt: {version_str}")
+                         return None
+                     return version_str
+         logging.error("comfyui-frontend-package not found in requirements.txt")
+         return None
+     except FileNotFoundError:
+         logging.error("requirements.txt not found. Cannot determine required frontend version.")
+         return None
+     except Exception as e:
+         logging.error(f"Error reading requirements.txt: {e}")
+         return None
+
+ def check_frontend_version():
+     """Check if the frontend version is up to date."""
+
+     try:
+         frontend_version_str = get_installed_frontend_version()
+         frontend_version = parse_version(frontend_version_str)
+         required_frontend_str = get_required_frontend_version()
+         required_frontend = parse_version(required_frontend_str)
+         if frontend_version < required_frontend:
+             app.logger.log_startup_warning(
+                 f"""
+ ________________________________________________________________________
+ WARNING WARNING WARNING WARNING WARNING
+
+ Installed frontend version {".".join(map(str, frontend_version))} is lower than the recommended version {".".join(map(str, required_frontend))}.
+
+ {frontend_install_warning_message()}
+ ________________________________________________________________________
+ """.strip()
+             )
+         else:
+             logging.info("ComfyUI frontend version: {}".format(frontend_version_str))
+     except Exception as e:
+         logging.error(f"Failed to check frontend version: {e}")
+
+
+ REQUEST_TIMEOUT = 10  # seconds
+
+
+ class Asset(TypedDict):
+     url: str
+
+
+ class Release(TypedDict):
+     id: int
+     tag_name: str
+     name: str
+     prerelease: bool
+     created_at: str
+     published_at: str
+     body: str
+     assets: NotRequired[list[Asset]]
+
+
+ @dataclass
+ class FrontEndProvider:
+     owner: str
+     repo: str
+
+     @property
+     def folder_name(self) -> str:
+         return f"{self.owner}_{self.repo}"
+
+     @property
+     def release_url(self) -> str:
+         return f"https://api.github.com/repos/{self.owner}/{self.repo}/releases"
+
+     @cached_property
+     def all_releases(self) -> list[Release]:
+         releases = []
+         api_url = self.release_url
+         while api_url:
+             response = requests.get(api_url, timeout=REQUEST_TIMEOUT)
+             response.raise_for_status()  # Raises an HTTPError if the response was an error
+             releases.extend(response.json())
+             # GitHub uses the Link header to provide pagination links. Check if it exists and update api_url accordingly.
+             if "next" in response.links:
+                 api_url = response.links["next"]["url"]
+             else:
+                 api_url = None
+         return releases
+
+     @cached_property
+     def latest_release(self) -> Release:
+         latest_release_url = f"{self.release_url}/latest"
+         response = requests.get(latest_release_url, timeout=REQUEST_TIMEOUT)
+         response.raise_for_status()  # Raises an HTTPError if the response was an error
+         return response.json()
+
+     @cached_property
+     def latest_prerelease(self) -> Release:
+         """Get the latest pre-release version - even if it's older than the latest release"""
+         prereleases = [release for release in self.all_releases if release["prerelease"]]
+
+         if not prereleases:
+             raise ValueError("No pre-releases found")
+
+         # GitHub returns releases in reverse chronological order, so the first is the latest
+         return prereleases[0]
+
+     def get_release(self, version: str) -> Release:
+         if version == "latest":
+             return self.latest_release
+         elif version == "prerelease":
+             return self.latest_prerelease
+         else:
+             for release in self.all_releases:
+                 if release["tag_name"] in [version, f"v{version}"]:
+                     return release
+             raise ValueError(f"Version {version} not found in releases")
+
+
+ def download_release_asset_zip(release: Release, destination_path: str) -> None:
+     """Download dist.zip from github release."""
+     asset_url = None
+     for asset in release.get("assets", []):
+         if asset["name"] == "dist.zip":
+             asset_url = asset["url"]
+             break
+
+     if not asset_url:
+         raise ValueError("dist.zip not found in the release assets")
+
+     # Use a temporary file to download the zip content
+     with tempfile.TemporaryFile() as tmp_file:
+         headers = {"Accept": "application/octet-stream"}
+         response = requests.get(
+             asset_url, headers=headers, allow_redirects=True, timeout=REQUEST_TIMEOUT
+         )
+         response.raise_for_status()  # Ensure we got a successful response
+
+         # Write the content to the temporary file
+         tmp_file.write(response.content)
+
+         # Go back to the beginning of the temporary file
+         tmp_file.seek(0)
+
+         # Extract the zip file content to the destination path
+         with zipfile.ZipFile(tmp_file, "r") as zip_ref:
+             zip_ref.extractall(destination_path)
+
+
+ class FrontendManager:
+     CUSTOM_FRONTENDS_ROOT = str(Path(__file__).parents[1] / "web_custom_versions")
+
+     @classmethod
+     def get_required_frontend_version(cls) -> str:
+         """Get the required frontend package version."""
+         return get_required_frontend_version()
+
+     @classmethod
+     def default_frontend_path(cls) -> str:
+         try:
+             import comfyui_frontend_package
+
+             return str(importlib.resources.files(comfyui_frontend_package) / "static")
+         except ImportError:
+             logging.error(
+                 f"""
+ ********** ERROR ***********
+
+ comfyui-frontend-package is not installed.
+
+ {frontend_install_warning_message()}
+
+ ********** ERROR ***********
+ """.strip()
+             )
+             sys.exit(-1)
+
+     @classmethod
+     def templates_path(cls) -> str:
+         try:
+             import comfyui_workflow_templates
+
+             return str(
+                 importlib.resources.files(comfyui_workflow_templates) / "templates"
+             )
+         except ImportError:
+             logging.error(
+                 f"""
+ ********** ERROR ***********
+
+ comfyui-workflow-templates is not installed.
+
+ {frontend_install_warning_message()}
+
+ ********** ERROR ***********
+ """.strip()
+             )
+
+     @classmethod
+     def embedded_docs_path(cls) -> str:
+         """Get the path to embedded documentation"""
+         try:
+             import comfyui_embedded_docs
+
+             return str(
+                 importlib.resources.files(comfyui_embedded_docs) / "docs"
+             )
+         except ImportError:
+             logging.info("comfyui-embedded-docs package not found")
+             return None
+
+     @classmethod
+     def parse_version_string(cls, value: str) -> tuple[str, str, str]:
+         """
+         Args:
+             value (str): The version string to parse.
+
+         Returns:
+             tuple[str, str, str]: A tuple containing repo owner, repo name, and version.
+
+         Raises:
+             argparse.ArgumentTypeError: If the version string is invalid.
+         """
+         VERSION_PATTERN = r"^([a-zA-Z0-9][a-zA-Z0-9-]{0,38})/([a-zA-Z0-9_.-]+)@(v?\d+\.\d+\.\d+[-._a-zA-Z0-9]*|latest|prerelease)$"
+         match_result = re.match(VERSION_PATTERN, value)
+         if match_result is None:
+             raise argparse.ArgumentTypeError(f"Invalid version string: {value}")
+
+         return match_result.group(1), match_result.group(2), match_result.group(3)
+
+     @classmethod
+     def init_frontend_unsafe(
+         cls, version_string: str, provider: Optional[FrontEndProvider] = None
+     ) -> str:
+         """
+         Initializes the frontend for the specified version.
+
+         Args:
+             version_string (str): The version string.
+             provider (FrontEndProvider, optional): The provider to use. Defaults to None.
+
+         Returns:
+             str: The path to the initialized frontend.
+
+         Raises:
+             Exception: If there is an error during the initialization process.
+             Main error sources are request timeouts and invalid URLs.
+         """
+         if version_string == DEFAULT_VERSION_STRING:
+             check_frontend_version()
+             return cls.default_frontend_path()
+
+         repo_owner, repo_name, version = cls.parse_version_string(version_string)
+
+         if version.startswith("v"):
+             expected_path = str(
+                 Path(cls.CUSTOM_FRONTENDS_ROOT)
+                 / f"{repo_owner}_{repo_name}"
+                 / version.lstrip("v")
+             )
+             if os.path.exists(expected_path):
+                 logging.info(
+                     f"Using existing copy of specific frontend version tag: {repo_owner}/{repo_name}@{version}"
+                 )
+                 return expected_path
+
+         logging.info(
+             f"Initializing frontend: {repo_owner}/{repo_name}@{version}, requesting version details from GitHub..."
+         )
+
+         provider = provider or FrontEndProvider(repo_owner, repo_name)
+         release = provider.get_release(version)
+
+         semantic_version = release["tag_name"].lstrip("v")
+         web_root = str(
+             Path(cls.CUSTOM_FRONTENDS_ROOT) / provider.folder_name / semantic_version
+         )
+         if not os.path.exists(web_root):
+             try:
+                 os.makedirs(web_root, exist_ok=True)
+                 logging.info(
+                     "Downloading frontend(%s) version(%s) to (%s)",
+                     provider.folder_name,
+                     semantic_version,
+                     web_root,
+                 )
+                 logging.debug(release)
+                 download_release_asset_zip(release, destination_path=web_root)
+             finally:
+                 # Clean up the directory if it is empty, i.e. the download failed
+                 if not os.listdir(web_root):
+                     os.rmdir(web_root)
+
+         return web_root
+
+     @classmethod
+     def init_frontend(cls, version_string: str) -> str:
+         """
+         Initializes the frontend with the specified version string.
+
+         Args:
+             version_string (str): The version string to initialize the frontend with.
+
+         Returns:
+             str: The path of the initialized frontend.
+         """
+         try:
+             return cls.init_frontend_unsafe(version_string)
+         except Exception as e:
+             logging.error("Failed to initialize frontend: %s", e)
+             logging.info("Falling back to the default frontend.")
+             check_frontend_version()
+             return cls.default_frontend_path()
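To make the accepted frontend version strings concrete, the following snippet exercises the same VERSION_PATTERN that `parse_version_string` uses above (the input strings are illustrative):

```python
import re

VERSION_PATTERN = r"^([a-zA-Z0-9][a-zA-Z0-9-]{0,38})/([a-zA-Z0-9_.-]+)@(v?\d+\.\d+\.\d+[-._a-zA-Z0-9]*|latest|prerelease)$"

for value in ["Comfy-Org/ComfyUI_frontend@latest",
              "Comfy-Org/ComfyUI_frontend@v1.2.3",
              "not-a-version"]:
    m = re.match(VERSION_PATTERN, value)
    print(value, "->", m.groups() if m else "rejected")
# Comfy-Org/ComfyUI_frontend@latest -> ('Comfy-Org', 'ComfyUI_frontend', 'latest')
# Comfy-Org/ComfyUI_frontend@v1.2.3 -> ('Comfy-Org', 'ComfyUI_frontend', 'v1.2.3')
# not-a-version -> rejected
```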
app/logger.py ADDED
@@ -0,0 +1,98 @@
+ from collections import deque
+ from datetime import datetime
+ import io
+ import logging
+ import sys
+ import threading
+
+ logs = None
+ stdout_interceptor = None
+ stderr_interceptor = None
+
+
+ class LogInterceptor(io.TextIOWrapper):
+     def __init__(self, stream, *args, **kwargs):
+         buffer = stream.buffer
+         encoding = stream.encoding
+         super().__init__(buffer, *args, **kwargs, encoding=encoding, line_buffering=stream.line_buffering)
+         self._lock = threading.Lock()
+         self._flush_callbacks = []
+         self._logs_since_flush = []
+
+     def write(self, data):
+         entry = {"t": datetime.now().isoformat(), "m": data}
+         with self._lock:
+             self._logs_since_flush.append(entry)
+
+             # Simple handling for cr to overwrite the last output if it isn't a full line,
+             # else the logs just fill up with progress messages
+             if isinstance(data, str) and data.startswith("\r") and logs and not logs[-1]["m"].endswith("\n"):
+                 logs.pop()
+             logs.append(entry)
+         super().write(data)
+
+     def flush(self):
+         super().flush()
+         for cb in self._flush_callbacks:
+             cb(self._logs_since_flush)
+             self._logs_since_flush = []
+
+     def on_flush(self, callback):
+         self._flush_callbacks.append(callback)
+
+
+ def get_logs():
+     return logs
+
+
+ def on_flush(callback):
+     if stdout_interceptor is not None:
+         stdout_interceptor.on_flush(callback)
+     if stderr_interceptor is not None:
+         stderr_interceptor.on_flush(callback)
+
+ def setup_logger(log_level: str = 'INFO', capacity: int = 300, use_stdout: bool = False):
+     global logs
+     if logs:
+         return
+
+     # Override output streams and log to buffer
+     logs = deque(maxlen=capacity)
+
+     global stdout_interceptor
+     global stderr_interceptor
+     stdout_interceptor = sys.stdout = LogInterceptor(sys.stdout)
+     stderr_interceptor = sys.stderr = LogInterceptor(sys.stderr)
+
+     # Setup default global logger
+     logger = logging.getLogger()
+     logger.setLevel(log_level)
+
+     stream_handler = logging.StreamHandler()
+     stream_handler.setFormatter(logging.Formatter("%(message)s"))
+
+     if use_stdout:
+         # Only errors and critical to stderr
+         stream_handler.addFilter(lambda record: record.levelno >= logging.ERROR)
+
+         # Lesser to stdout
+         stdout_handler = logging.StreamHandler(sys.stdout)
+         stdout_handler.setFormatter(logging.Formatter("%(message)s"))
+         stdout_handler.addFilter(lambda record: record.levelno < logging.ERROR)
+         logger.addHandler(stdout_handler)
+
+     logger.addHandler(stream_handler)
+
+
+ STARTUP_WARNINGS = []
+
+
+ def log_startup_warning(msg):
+     logging.warning(msg)
+     STARTUP_WARNINGS.append(msg)
+
+
+ def print_startup_warnings():
+     for s in STARTUP_WARNINGS:
+         logging.warning(s)
+     STARTUP_WARNINGS.clear()
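A sketch of how a consumer might use this module's buffering hooks, assuming `app.logger` is importable from the repo root; the callback body is illustrative (ComfyUI's internal log routes follow a similar pattern, but this exact usage is an assumption):

```python
import app.logger as logger

logger.setup_logger(log_level="INFO", capacity=300)

def handle_flush(entries):
    # Each entry is {"t": iso_timestamp, "m": message}
    for entry in entries:
        pass  # e.g. forward to a websocket subscriber

logger.on_flush(handle_flush)      # subscribe to both stdout and stderr interceptors
print("hello")                     # intercepted, buffered, and still written through
print(len(logger.get_logs()))      # ring buffer holds at most `capacity` entries
```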
app/model_manager.py ADDED
@@ -0,0 +1,195 @@
+ from __future__ import annotations
+
+ import os
+ import base64
+ import json
+ import time
+ import logging
+ import folder_paths
+ import glob
+ import comfy.utils
+ from aiohttp import web
+ from PIL import Image
+ from io import BytesIO
+ from folder_paths import map_legacy, filter_files_extensions, filter_files_content_types
+
+
+ class ModelFileManager:
+     def __init__(self) -> None:
+         self.cache: dict[str, tuple[list[dict], dict[str, float], float]] = {}
+
+     def get_cache(self, key: str, default=None) -> tuple[list[dict], dict[str, float], float] | None:
+         return self.cache.get(key, default)
+
+     def set_cache(self, key: str, value: tuple[list[dict], dict[str, float], float]):
+         self.cache[key] = value
+
+     def clear_cache(self):
+         self.cache.clear()
+
+     def add_routes(self, routes):
+         # NOTE: This is an experiment to replace `/models`
+         @routes.get("/experiment/models")
+         async def get_model_folders(request):
+             model_types = list(folder_paths.folder_names_and_paths.keys())
+             folder_black_list = ["configs", "custom_nodes"]
+             output_folders: list[dict] = []
+             for folder in model_types:
+                 if folder in folder_black_list:
+                     continue
+                 output_folders.append({"name": folder, "folders": folder_paths.get_folder_paths(folder)})
+             return web.json_response(output_folders)
+
+         # NOTE: This is an experiment to replace `/models/{folder}`
+         @routes.get("/experiment/models/{folder}")
+         async def get_all_models(request):
+             folder = request.match_info.get("folder", None)
+             if folder not in folder_paths.folder_names_and_paths:
+                 return web.Response(status=404)
+             files = self.get_model_file_list(folder)
+             return web.json_response(files)
+
+         @routes.get("/experiment/models/preview/{folder}/{path_index}/{filename:.*}")
+         async def get_model_preview(request):
+             folder_name = request.match_info.get("folder", None)
+             path_index = int(request.match_info.get("path_index", None))
+             filename = request.match_info.get("filename", None)
+
+             if folder_name not in folder_paths.folder_names_and_paths:
+                 return web.Response(status=404)
+
+             folders = folder_paths.folder_names_and_paths[folder_name]
+             folder = folders[0][path_index]
+             full_filename = os.path.join(folder, filename)
+
+             previews = self.get_model_previews(full_filename)
+             default_preview = previews[0] if len(previews) > 0 else None
+             if default_preview is None or (isinstance(default_preview, str) and not os.path.isfile(default_preview)):
+                 return web.Response(status=404)
+
+             try:
+                 with Image.open(default_preview) as img:
+                     img_bytes = BytesIO()
+                     img.save(img_bytes, format="WEBP")
+                     img_bytes.seek(0)
+                     return web.Response(body=img_bytes.getvalue(), content_type="image/webp")
+             except Exception:
+                 return web.Response(status=404)
+
+     def get_model_file_list(self, folder_name: str):
+         folder_name = map_legacy(folder_name)
+         folders = folder_paths.folder_names_and_paths[folder_name]
+         output_list: list[dict] = []
+
+         for index, folder in enumerate(folders[0]):
+             if not os.path.isdir(folder):
+                 continue
+             out = self.cache_model_file_list_(folder)
+             if out is None:
+                 out = self.recursive_search_models_(folder, index)
+                 self.set_cache(folder, out)
+             output_list.extend(out[0])
+
+         return output_list
+
+     def cache_model_file_list_(self, folder: str):
+         model_file_list_cache = self.get_cache(folder)
+
+         if model_file_list_cache is None:
+             return None
+         if not os.path.isdir(folder):
+             return None
+         # Invalidate the cache when the root folder or any recorded subdirectory changed on disk
+         for dir_path, time_modified in model_file_list_cache[1].items():
+             if not os.path.isdir(dir_path) or os.path.getmtime(dir_path) != time_modified:
+                 return None
+
+         return model_file_list_cache
+
+     def recursive_search_models_(self, directory: str, pathIndex: int) -> tuple[list[dict], dict[str, float], float]:
+         if not os.path.isdir(directory):
+             return [], {}, time.perf_counter()
+
+         excluded_dir_names = [".git"]
+         # TODO use settings
+         include_hidden_files = False
+
+         result: list[dict] = []
+         # Record mtimes of the root and every subdirectory so the cache can be invalidated later
+         dirs: dict[str, float] = {directory: os.path.getmtime(directory)}
+
+         for dirpath, subdirs, filenames in os.walk(directory, followlinks=True, topdown=True):
+             subdirs[:] = [d for d in subdirs if d not in excluded_dir_names]
+             if not include_hidden_files:
+                 subdirs[:] = [d for d in subdirs if not d.startswith(".")]
+                 filenames = [f for f in filenames if not f.startswith(".")]
+
+             filenames = filter_files_extensions(filenames, folder_paths.supported_pt_extensions)
+
+             for file_name in filenames:
+                 try:
+                     full_path = os.path.join(dirpath, file_name)
+                     relative_path = os.path.relpath(full_path, directory)
+
+                     # Get file metadata
+                     file_info = {
+                         "name": relative_path,
+                         "pathIndex": pathIndex,
+                         "modified": os.path.getmtime(full_path),  # modification time
+                         "created": os.path.getctime(full_path),  # creation time
+                         "size": os.path.getsize(full_path)  # file size
+                     }
+                     result.append(file_info)
+
+                 except Exception as e:
+                     logging.warning(f"Warning: Unable to access {file_name}. Error: {e}. Skipping this file.")
+                     continue
+
+             for d in subdirs:
+                 path: str = os.path.join(dirpath, d)
+                 try:
+                     dirs[path] = os.path.getmtime(path)
+                 except FileNotFoundError:
+                     logging.warning(f"Warning: Unable to access {path}. Skipping this path.")
+                     continue
+
+         return result, dirs, time.perf_counter()
+
+     def get_model_previews(self, filepath: str) -> list[str | BytesIO]:
+         dirname = os.path.dirname(filepath)
+
+         if not os.path.exists(dirname):
+             return []
+
+         basename = os.path.splitext(filepath)[0]
+         match_files = glob.glob(f"{basename}.*", recursive=False)
+         image_files = filter_files_content_types(match_files, "image")
+         safetensors_file = next(filter(lambda x: x.endswith(".safetensors"), match_files), None)
+         safetensors_metadata = {}
+
+         result: list[str | BytesIO] = []
+
+         for filename in image_files:
+             _basename = os.path.splitext(filename)[0]
+             if _basename == basename:
+                 result.append(filename)
+             if _basename == f"{basename}.preview":
+                 result.append(filename)
+
+         if safetensors_file:
+             safetensors_filepath = os.path.join(dirname, safetensors_file)
+             header = comfy.utils.safetensors_header(safetensors_filepath, max_size=8*1024*1024)
+             if header:
+                 safetensors_metadata = json.loads(header)
+             safetensors_images = safetensors_metadata.get("__metadata__", {}).get("ssmd_cover_images", None)
+             if safetensors_images:
+                 safetensors_images = json.loads(safetensors_images)
+                 for image in safetensors_images:
+                     result.append(BytesIO(base64.b64decode(image)))
+
+         return result
+
+     def __exit__(self, exc_type, exc_value, traceback):
+         self.clear_cache()
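A hedged client-side sketch of the three experimental routes above, assuming a ComfyUI server on its default address (the exact route prefix may differ depending on how the routes are mounted):

```python
import requests

BASE = "http://127.0.0.1:8188"  # assumed default ComfyUI address

# Folder types and their search paths
folders = requests.get(f"{BASE}/experiment/models").json()
print([f["name"] for f in folders])  # e.g. ["checkpoints", "loras", ...]

# Files for one folder type; each item has name, pathIndex, modified, created, size
for item in requests.get(f"{BASE}/experiment/models/checkpoints").json()[:3]:
    print(item["name"], item["pathIndex"], item["size"])

# Preview image (WEBP) for a specific file, if a sibling image or embedded cover exists:
#   GET /experiment/models/preview/{folder}/{path_index}/{relative filename}
```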
app/user_manager.py ADDED
@@ -0,0 +1,438 @@
+ from __future__ import annotations
+ import json
+ import os
+ import re
+ import uuid
+ import glob
+ import shutil
+ import logging
+ from aiohttp import web
+ from urllib import parse
+ from comfy.cli_args import args
+ import folder_paths
+ from .app_settings import AppSettings
+ from typing import TypedDict
+
+ default_user = "default"
+
+
+ class FileInfo(TypedDict):
+     path: str
+     size: int
+     modified: float
+     created: float
+
+
+ def get_file_info(path: str, relative_to: str) -> FileInfo:
+     return {
+         "path": os.path.relpath(path, relative_to).replace(os.sep, '/'),
+         "size": os.path.getsize(path),
+         "modified": os.path.getmtime(path),
+         "created": os.path.getctime(path)
+     }
+
+
+ class UserManager():
+     def __init__(self):
+         user_directory = folder_paths.get_user_directory()
+
+         self.settings = AppSettings(self)
+         if not os.path.exists(user_directory):
+             os.makedirs(user_directory, exist_ok=True)
+             if not args.multi_user:
+                 logging.warning("****** User settings have been changed to be stored on the server instead of browser storage. ******")
+                 logging.warning("****** For multi-user setups add the --multi-user CLI argument to enable multiple user profiles. ******")
+
+         if args.multi_user:
+             if os.path.isfile(self.get_users_file()):
+                 with open(self.get_users_file()) as f:
+                     self.users = json.load(f)
+             else:
+                 self.users = {}
+         else:
+             self.users = {"default": "default"}
+
+     def get_users_file(self):
+         return os.path.join(folder_paths.get_user_directory(), "users.json")
+
+     def get_request_user_id(self, request):
+         user = "default"
+         if args.multi_user and "comfy-user" in request.headers:
+             user = request.headers["comfy-user"]
+
+         if user not in self.users:
+             raise KeyError("Unknown user: " + user)
+
+         return user
+
+     def get_request_user_filepath(self, request, file, type="userdata", create_dir=True):
+         user_directory = folder_paths.get_user_directory()
+
+         if type == "userdata":
+             root_dir = user_directory
+         else:
+             raise KeyError("Unknown filepath type: " + type)
+
+         user = self.get_request_user_id(request)
+         path = user_root = os.path.abspath(os.path.join(root_dir, user))
+
+         # prevent leaving /{type}
+         if os.path.commonpath((root_dir, user_root)) != root_dir:
+             return None
+
+         if file is not None:
+             # Check if filename is url encoded
+             if "%" in file:
+                 file = parse.unquote(file)
+
+             # prevent leaving /{type}/{user}
+             path = os.path.abspath(os.path.join(user_root, file))
+             if os.path.commonpath((user_root, path)) != user_root:
+                 return None
+
+         parent = os.path.split(path)[0]
+
+         if create_dir and not os.path.exists(parent):
+             os.makedirs(parent, exist_ok=True)
+
+         return path
+
+     def add_user(self, name):
+         name = name.strip()
+         if not name:
+             raise ValueError("username not provided")
+         user_id = re.sub("[^a-zA-Z0-9-_]+", '-', name)
+         user_id = user_id + "_" + str(uuid.uuid4())
+
+         self.users[user_id] = name
+
+         with open(self.get_users_file(), "w") as f:
+             json.dump(self.users, f)
+
+         return user_id
+
+     def add_routes(self, routes):
+         self.settings.add_routes(routes)
+
+         @routes.get("/users")
+         async def get_users(request):
+             if args.multi_user:
+                 return web.json_response({"storage": "server", "users": self.users})
+             else:
+                 user_dir = self.get_request_user_filepath(request, None, create_dir=False)
+                 return web.json_response({
+                     "storage": "server",
+                     "migrated": os.path.exists(user_dir)
+                 })
+
+         @routes.post("/users")
+         async def post_users(request):
+             body = await request.json()
+             username = body["username"]
+             if username in self.users.values():
+                 return web.json_response({"error": "Duplicate username."}, status=400)
+
+             user_id = self.add_user(username)
+             return web.json_response(user_id)
+
+         @routes.get("/userdata")
+         async def listuserdata(request):
+             """
+             List user data files in a specified directory.
+
+             This endpoint allows listing files in a user's data directory, with options for recursion,
+             full file information, and path splitting.
+
+             Query Parameters:
+             - dir (required): The directory to list files from.
+             - recurse (optional): If "true", recursively list files in subdirectories.
+             - full_info (optional): If "true", return detailed file information (path, size, modified time).
+             - split (optional): If "true", split file paths into components (only applies when full_info is false).
+
+             Returns:
+             - 400: If the 'dir' parameter is missing.
+             - 403: If the requested path is not allowed.
+             - 404: If the requested directory does not exist.
+             - 200: JSON response with the list of files or file information.
+
+             The response format depends on the query parameters:
+             - Default: List of relative file paths.
+             - full_info=true: List of dictionaries with file details.
+             - split=true (and full_info=false): List of lists, each containing path components.
+             """
+             directory = request.rel_url.query.get('dir', '')
+             if not directory:
+                 return web.Response(status=400, text="Directory not provided")
+
+             path = self.get_request_user_filepath(request, directory)
+             if not path:
+                 return web.Response(status=403, text="Invalid directory")
+
+             if not os.path.exists(path):
+                 return web.Response(status=404, text="Directory not found")
+
+             recurse = request.rel_url.query.get('recurse', '').lower() == "true"
+             full_info = request.rel_url.query.get('full_info', '').lower() == "true"
+             split_path = request.rel_url.query.get('split', '').lower() == "true"
+
+             # Use different patterns based on whether we're recursing or not
+             if recurse:
+                 pattern = os.path.join(glob.escape(path), '**', '*')
+             else:
+                 pattern = os.path.join(glob.escape(path), '*')
+
+             def process_full_path(full_path: str) -> FileInfo | str | list[str]:
+                 if full_info:
+                     return get_file_info(full_path, path)
+
+                 rel_path = os.path.relpath(full_path, path).replace(os.sep, '/')
+                 if split_path:
+                     return [rel_path] + rel_path.split('/')
+
+                 return rel_path
+
+             results = [
+                 process_full_path(full_path)
+                 for full_path in glob.glob(pattern, recursive=recurse)
+                 if os.path.isfile(full_path)
+             ]
+
+             return web.json_response(results)
+
+         @routes.get("/v2/userdata")
+         async def list_userdata_v2(request):
+             """
+             List files and directories in a user's data directory.
+
+             This endpoint provides a structured listing of contents within a specified
+             subdirectory of the user's data storage.
+
+             Query Parameters:
+             - path (optional): The relative path within the user's data directory
+                                to list. Defaults to the root ('').
+
+             Returns:
+             - 400: If the requested path is invalid, outside the user's data directory, or is not a directory.
216
+ - 404: If the requested path does not exist.
217
+ - 403: If the user is invalid.
218
+ - 500: If there is an error reading the directory contents.
219
+ - 200: JSON response containing a list of file and directory objects.
220
+ Each object includes:
221
+ - name: The name of the file or directory.
222
+ - type: 'file' or 'directory'.
223
+ - path: The relative path from the user's data root.
224
+ - size (for files): The size in bytes.
225
+ - modified (for files): The last modified timestamp (Unix epoch).
226
+ """
227
+ requested_rel_path = request.rel_url.query.get('path', '')
228
+
229
+ # URL-decode the path parameter
230
+ try:
231
+ requested_rel_path = parse.unquote(requested_rel_path)
232
+ except Exception as e:
233
+ logging.warning(f"Failed to decode path parameter: {requested_rel_path}, Error: {e}")
234
+ return web.Response(status=400, text="Invalid characters in path parameter")
235
+
236
+
237
+ # Check user validity and get the absolute path for the requested directory
238
+ try:
239
+ base_user_path = self.get_request_user_filepath(request, None, create_dir=False)
240
+
241
+ if requested_rel_path:
242
+ target_abs_path = self.get_request_user_filepath(request, requested_rel_path, create_dir=False)
243
+ else:
244
+ target_abs_path = base_user_path
245
+
246
+ except KeyError as e:
247
+ # Invalid user detected by get_request_user_id inside get_request_user_filepath
248
+ logging.warning(f"Access denied for user: {e}")
249
+ return web.Response(status=403, text="Invalid user specified in request")
250
+
251
+
252
+ if not target_abs_path:
253
+ # Path traversal or other issue detected by get_request_user_filepath
254
+ return web.Response(status=400, text="Invalid path requested")
255
+
256
+ # Handle cases where the user directory or target path doesn't exist
257
+ if not os.path.exists(target_abs_path):
258
+ # Check if it's the base user directory that's missing (new user case)
259
+ if target_abs_path == base_user_path:
260
+ # It's okay if the base user directory doesn't exist yet, return empty list
261
+ return web.json_response([])
262
+ else:
263
+ # A specific subdirectory was requested but doesn't exist
264
+ return web.Response(status=404, text="Requested path not found")
265
+
266
+ if not os.path.isdir(target_abs_path):
267
+ return web.Response(status=400, text="Requested path is not a directory")
268
+
269
+ results = []
270
+ try:
271
+ for root, dirs, files in os.walk(target_abs_path, topdown=True):
272
+ # Process directories
273
+ for dir_name in dirs:
274
+ dir_path = os.path.join(root, dir_name)
275
+ rel_path = os.path.relpath(dir_path, base_user_path).replace(os.sep, '/')
276
+ results.append({
277
+ "name": dir_name,
278
+ "path": rel_path,
279
+ "type": "directory"
280
+ })
281
+
282
+ # Process files
283
+ for file_name in files:
284
+ file_path = os.path.join(root, file_name)
285
+ rel_path = os.path.relpath(file_path, base_user_path).replace(os.sep, '/')
286
+ entry_info = {
287
+ "name": file_name,
288
+ "path": rel_path,
289
+ "type": "file"
290
+ }
291
+ try:
292
+ stats = os.stat(file_path) # Use os.stat for potentially better performance with os.walk
293
+ entry_info["size"] = stats.st_size
294
+ entry_info["modified"] = stats.st_mtime
295
+ except OSError as stat_error:
296
+ logging.warning(f"Could not stat file {file_path}: {stat_error}")
297
+ pass # Include file with available info
298
+ results.append(entry_info)
299
+ except OSError as e:
300
+ logging.error(f"Error listing directory {target_abs_path}: {e}")
301
+ return web.Response(status=500, text="Error reading directory contents")
302
+
303
+ # Sort results alphabetically, directories first then files
304
+ results.sort(key=lambda x: (x['type'] != 'directory', x['name'].lower()))
305
+
306
+ return web.json_response(results)
307
+
308
+ def get_user_data_path(request, check_exists = False, param = "file"):
309
+ file = request.match_info.get(param, None)
310
+ if not file:
311
+ return web.Response(status=400)
312
+
313
+ path = self.get_request_user_filepath(request, file)
314
+ if not path:
315
+ return web.Response(status=403)
316
+
317
+ if check_exists and not os.path.exists(path):
318
+ return web.Response(status=404)
319
+
320
+ return path
321
+
322
+ @routes.get("/userdata/{file}")
323
+ async def getuserdata(request):
324
+ path = get_user_data_path(request, check_exists=True)
325
+ if not isinstance(path, str):
326
+ return path
327
+
328
+ return web.FileResponse(path)
329
+
330
+ @routes.post("/userdata/{file}")
331
+ async def post_userdata(request):
332
+ """
333
+ Upload or update a user data file.
334
+
335
+ This endpoint handles file uploads to a user's data directory, with options for
336
+ controlling overwrite behavior and response format.
337
+
338
+ Query Parameters:
339
+ - overwrite (optional): If "false", prevents overwriting existing files. Defaults to "true".
340
+ - full_info (optional): If "true", returns detailed file information (path, size, modified time).
341
+ If "false", returns only the relative file path.
342
+
343
+ Path Parameters:
344
+ - file: The target file path (URL encoded if necessary).
345
+
346
+ Returns:
347
+ - 400: If 'file' parameter is missing.
348
+ - 403: If the requested path is not allowed.
349
+ - 409: If overwrite=false and the file already exists.
350
+ - 200: JSON response with either:
351
+ - Full file information (if full_info=true)
352
+ - Relative file path (if full_info=false)
353
+
354
+ The request body should contain the raw file content to be written.
355
+ """
356
+ path = get_user_data_path(request)
357
+ if not isinstance(path, str):
358
+ return path
359
+
360
+ overwrite = request.query.get("overwrite", 'true') != "false"
361
+ full_info = request.query.get('full_info', 'false').lower() == "true"
362
+
363
+ if not overwrite and os.path.exists(path):
364
+ return web.Response(status=409, text="File already exists")
365
+
366
+ body = await request.read()
367
+
368
+ with open(path, "wb") as f:
369
+ f.write(body)
370
+
371
+ user_path = self.get_request_user_filepath(request, None)
372
+ if full_info:
373
+ resp = get_file_info(path, user_path)
374
+ else:
375
+ resp = os.path.relpath(path, user_path)
376
+
377
+ return web.json_response(resp)
378
+
379
+ @routes.delete("/userdata/{file}")
380
+ async def delete_userdata(request):
381
+ path = get_user_data_path(request, check_exists=True)
382
+ if not isinstance(path, str):
383
+ return path
384
+
385
+ os.remove(path)
386
+
387
+ return web.Response(status=204)
388
+
389
+ @routes.post("/userdata/{file}/move/{dest}")
390
+ async def move_userdata(request):
391
+ """
392
+ Move or rename a user data file.
393
+
394
+ This endpoint handles moving or renaming files within a user's data directory, with options for
395
+ controlling overwrite behavior and response format.
396
+
397
+ Path Parameters:
398
+ - file: The source file path (URL encoded if necessary)
399
+ - dest: The destination file path (URL encoded if necessary)
400
+
401
+ Query Parameters:
402
+ - overwrite (optional): If "false", prevents overwriting existing files. Defaults to "true".
403
+ - full_info (optional): If "true", returns detailed file information (path, size, modified time).
404
+ If "false", returns only the relative file path.
405
+
406
+ Returns:
407
+ - 400: If either 'file' or 'dest' parameter is missing
408
+ - 403: If either requested path is not allowed
409
+ - 404: If the source file does not exist
410
+ - 409: If overwrite=false and the destination file already exists
411
+ - 200: JSON response with either:
412
+ - Full file information (if full_info=true)
413
+ - Relative file path (if full_info=false)
414
+ """
415
+ source = get_user_data_path(request, check_exists=True)
416
+ if not isinstance(source, str):
417
+ return source
418
+
419
+ dest = get_user_data_path(request, check_exists=False, param="dest")
420
+ if not isinstance(source, str):
421
+ return dest
422
+
423
+ overwrite = request.query.get("overwrite", 'true') != "false"
424
+ full_info = request.query.get('full_info', 'false').lower() == "true"
425
+
426
+ if not overwrite and os.path.exists(dest):
427
+ return web.Response(status=409, text="File already exists")
428
+
429
+ logging.info(f"moving '{source}' -> '{dest}'")
430
+ shutil.move(source, dest)
431
+
432
+ user_path = self.get_request_user_filepath(request, None)
433
+ if full_info:
434
+ resp = get_file_info(dest, user_path)
435
+ else:
436
+ resp = os.path.relpath(dest, user_path)
437
+
438
+ return web.json_response(resp)
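For reference, a minimal client-side sketch of the endpoints defined above. Assumptions not taken from the source: a local server at 127.0.0.1:8188, the third-party `requests` package, and the hypothetical target file workflows/demo.json; the comfy-user header only matters when the server runs with --multi-user.

    import requests

    BASE = "http://127.0.0.1:8188"
    HEADERS = {"comfy-user": "default"}  # honored only in --multi-user mode

    # Upload raw bytes; the handler returns 409 when overwrite=false and the file exists.
    r = requests.post(BASE + "/userdata/workflows%2Fdemo.json",
                      params={"overwrite": "false", "full_info": "true"},
                      data=b'{"nodes": []}', headers=HEADERS)
    print(r.status_code, r.json() if r.ok else r.text)

    # Structured listing of the user's data root via the v2 endpoint.
    print(requests.get(BASE + "/v2/userdata", params={"path": ""}, headers=HEADERS).json())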
comfy/checkpoint_pickle.py ADDED
@@ -0,0 +1,13 @@
+ import pickle
+
+ load = pickle.load
+
+ class Empty:
+     pass
+
+ class Unpickler(pickle.Unpickler):
+     def find_class(self, module, name):
+         # TODO: safe unpickle
+         if module.startswith("pytorch_lightning"):
+             return Empty
+         return super().find_class(module, name)
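A hedged sketch of how this module is meant to be used: torch.load accepts a pickle_module argument, so passing this module makes any pytorch_lightning.* global resolve to the Empty stub instead of importing the real class. The checkpoint path is hypothetical, and newer PyTorch versions may additionally require weights_only=False for a custom pickle module.

    import torch
    import comfy.checkpoint_pickle

    # pytorch_lightning callbacks/classes embedded in the checkpoint become Empty
    state = torch.load("model.ckpt", map_location="cpu",
                       pickle_module=comfy.checkpoint_pickle)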
comfy/cldm/cldm.py ADDED
@@ -0,0 +1,433 @@
+ #taken from: https://github.com/lllyasviel/ControlNet
+ #and modified
+
+ import torch
+ import torch.nn as nn
+
+ from ..ldm.modules.diffusionmodules.util import (
+     timestep_embedding,
+ )
+
+ from ..ldm.modules.attention import SpatialTransformer
+ from ..ldm.modules.diffusionmodules.openaimodel import UNetModel, TimestepEmbedSequential, ResBlock, Downsample
+ from ..ldm.util import exists
+ from .control_types import UNION_CONTROLNET_TYPES
+ from collections import OrderedDict
+ import comfy.ops
+ from comfy.ldm.modules.attention import optimized_attention
+
+ class OptimizedAttention(nn.Module):
+     def __init__(self, c, nhead, dropout=0.0, dtype=None, device=None, operations=None):
+         super().__init__()
+         self.heads = nhead
+         self.c = c
+
+         self.in_proj = operations.Linear(c, c * 3, bias=True, dtype=dtype, device=device)
+         self.out_proj = operations.Linear(c, c, bias=True, dtype=dtype, device=device)
+
+     def forward(self, x):
+         x = self.in_proj(x)
+         q, k, v = x.split(self.c, dim=2)
+         out = optimized_attention(q, k, v, self.heads)
+         return self.out_proj(out)
+
+ class QuickGELU(nn.Module):
+     def forward(self, x: torch.Tensor):
+         return x * torch.sigmoid(1.702 * x)
+
+ class ResBlockUnionControlnet(nn.Module):
+     def __init__(self, dim, nhead, dtype=None, device=None, operations=None):
+         super().__init__()
+         self.attn = OptimizedAttention(dim, nhead, dtype=dtype, device=device, operations=operations)
+         self.ln_1 = operations.LayerNorm(dim, dtype=dtype, device=device)
+         self.mlp = nn.Sequential(
+             OrderedDict([("c_fc", operations.Linear(dim, dim * 4, dtype=dtype, device=device)), ("gelu", QuickGELU()),
+                          ("c_proj", operations.Linear(dim * 4, dim, dtype=dtype, device=device))]))
+         self.ln_2 = operations.LayerNorm(dim, dtype=dtype, device=device)
+
+     def attention(self, x: torch.Tensor):
+         return self.attn(x)
+
+     def forward(self, x: torch.Tensor):
+         x = x + self.attention(self.ln_1(x))
+         x = x + self.mlp(self.ln_2(x))
+         return x
+
+ class ControlledUnetModel(UNetModel):
+     #implemented in the ldm unet
+     pass
+
+ class ControlNet(nn.Module):
+     def __init__(
+         self,
+         image_size,
+         in_channels,
+         model_channels,
+         hint_channels,
+         num_res_blocks,
+         dropout=0,
+         channel_mult=(1, 2, 4, 8),
+         conv_resample=True,
+         dims=2,
+         num_classes=None,
+         use_checkpoint=False,
+         dtype=torch.float32,
+         num_heads=-1,
+         num_head_channels=-1,
+         num_heads_upsample=-1,
+         use_scale_shift_norm=False,
+         resblock_updown=False,
+         use_new_attention_order=False,
+         use_spatial_transformer=False,    # custom transformer support
+         transformer_depth=1,              # custom transformer support
+         context_dim=None,                 # custom transformer support
+         n_embed=None,                     # custom support for prediction of discrete ids into codebook of first stage vq model
+         legacy=True,
+         disable_self_attentions=None,
+         num_attention_blocks=None,
+         disable_middle_self_attn=False,
+         use_linear_in_transformer=False,
+         adm_in_channels=None,
+         transformer_depth_middle=None,
+         transformer_depth_output=None,
+         attn_precision=None,
+         union_controlnet_num_control_type=None,
+         device=None,
+         operations=comfy.ops.disable_weight_init,
+         **kwargs,
+     ):
+         super().__init__()
+         assert use_spatial_transformer == True, "use_spatial_transformer has to be true"
+         if use_spatial_transformer:
+             assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
+
+         if context_dim is not None:
+             assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
+             # from omegaconf.listconfig import ListConfig
+             # if type(context_dim) == ListConfig:
+             #     context_dim = list(context_dim)
+
+         if num_heads_upsample == -1:
+             num_heads_upsample = num_heads
+
+         if num_heads == -1:
+             assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
+
+         if num_head_channels == -1:
+             assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
+
+         self.dims = dims
+         self.image_size = image_size
+         self.in_channels = in_channels
+         self.model_channels = model_channels
+
+         if isinstance(num_res_blocks, int):
+             self.num_res_blocks = len(channel_mult) * [num_res_blocks]
+         else:
+             if len(num_res_blocks) != len(channel_mult):
+                 raise ValueError("provide num_res_blocks either as an int (globally constant) or "
+                                  "as a list/tuple (per-level) with the same length as channel_mult")
+             self.num_res_blocks = num_res_blocks
+
+         if disable_self_attentions is not None:
+             # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
+             assert len(disable_self_attentions) == len(channel_mult)
+         if num_attention_blocks is not None:
+             assert len(num_attention_blocks) == len(self.num_res_blocks)
+             assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
+
+         transformer_depth = transformer_depth[:]
+
+         self.dropout = dropout
+         self.channel_mult = channel_mult
+         self.conv_resample = conv_resample
+         self.num_classes = num_classes
+         self.use_checkpoint = use_checkpoint
+         self.dtype = dtype
+         self.num_heads = num_heads
+         self.num_head_channels = num_head_channels
+         self.num_heads_upsample = num_heads_upsample
+         self.predict_codebook_ids = n_embed is not None
+
+         time_embed_dim = model_channels * 4
+         self.time_embed = nn.Sequential(
+             operations.Linear(model_channels, time_embed_dim, dtype=self.dtype, device=device),
+             nn.SiLU(),
+             operations.Linear(time_embed_dim, time_embed_dim, dtype=self.dtype, device=device),
+         )
+
+         if self.num_classes is not None:
+             if isinstance(self.num_classes, int):
+                 self.label_emb = nn.Embedding(num_classes, time_embed_dim)
+             elif self.num_classes == "continuous":
+                 self.label_emb = nn.Linear(1, time_embed_dim)
+             elif self.num_classes == "sequential":
+                 assert adm_in_channels is not None
+                 self.label_emb = nn.Sequential(
+                     nn.Sequential(
+                         operations.Linear(adm_in_channels, time_embed_dim, dtype=self.dtype, device=device),
+                         nn.SiLU(),
+                         operations.Linear(time_embed_dim, time_embed_dim, dtype=self.dtype, device=device),
+                     )
+                 )
+             else:
+                 raise ValueError()
+
+         self.input_blocks = nn.ModuleList(
+             [
+                 TimestepEmbedSequential(
+                     operations.conv_nd(dims, in_channels, model_channels, 3, padding=1, dtype=self.dtype, device=device)
+                 )
+             ]
+         )
+         self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels, operations=operations, dtype=self.dtype, device=device)])
+
+         self.input_hint_block = TimestepEmbedSequential(
+             operations.conv_nd(dims, hint_channels, 16, 3, padding=1, dtype=self.dtype, device=device),
+             nn.SiLU(),
+             operations.conv_nd(dims, 16, 16, 3, padding=1, dtype=self.dtype, device=device),
+             nn.SiLU(),
+             operations.conv_nd(dims, 16, 32, 3, padding=1, stride=2, dtype=self.dtype, device=device),
+             nn.SiLU(),
+             operations.conv_nd(dims, 32, 32, 3, padding=1, dtype=self.dtype, device=device),
+             nn.SiLU(),
+             operations.conv_nd(dims, 32, 96, 3, padding=1, stride=2, dtype=self.dtype, device=device),
+             nn.SiLU(),
+             operations.conv_nd(dims, 96, 96, 3, padding=1, dtype=self.dtype, device=device),
+             nn.SiLU(),
+             operations.conv_nd(dims, 96, 256, 3, padding=1, stride=2, dtype=self.dtype, device=device),
+             nn.SiLU(),
+             operations.conv_nd(dims, 256, model_channels, 3, padding=1, dtype=self.dtype, device=device)
+         )
+
+         self._feature_size = model_channels
+         input_block_chans = [model_channels]
+         ch = model_channels
+         ds = 1
+         for level, mult in enumerate(channel_mult):
+             for nr in range(self.num_res_blocks[level]):
+                 layers = [
+                     ResBlock(
+                         ch,
+                         time_embed_dim,
+                         dropout,
+                         out_channels=mult * model_channels,
+                         dims=dims,
+                         use_checkpoint=use_checkpoint,
+                         use_scale_shift_norm=use_scale_shift_norm,
+                         dtype=self.dtype,
+                         device=device,
+                         operations=operations,
+                     )
+                 ]
+                 ch = mult * model_channels
+                 num_transformers = transformer_depth.pop(0)
+                 if num_transformers > 0:
+                     if num_head_channels == -1:
+                         dim_head = ch // num_heads
+                     else:
+                         num_heads = ch // num_head_channels
+                         dim_head = num_head_channels
+                     if legacy:
+                         #num_heads = 1
+                         dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
+                     if exists(disable_self_attentions):
+                         disabled_sa = disable_self_attentions[level]
+                     else:
+                         disabled_sa = False
+
+                     if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
+                         layers.append(
+                             SpatialTransformer(
+                                 ch, num_heads, dim_head, depth=num_transformers, context_dim=context_dim,
+                                 disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
+                                 use_checkpoint=use_checkpoint, attn_precision=attn_precision, dtype=self.dtype, device=device, operations=operations
+                             )
+                         )
+                 self.input_blocks.append(TimestepEmbedSequential(*layers))
+                 self.zero_convs.append(self.make_zero_conv(ch, operations=operations, dtype=self.dtype, device=device))
+                 self._feature_size += ch
+                 input_block_chans.append(ch)
+             if level != len(channel_mult) - 1:
+                 out_ch = ch
+                 self.input_blocks.append(
+                     TimestepEmbedSequential(
+                         ResBlock(
+                             ch,
+                             time_embed_dim,
+                             dropout,
+                             out_channels=out_ch,
+                             dims=dims,
+                             use_checkpoint=use_checkpoint,
+                             use_scale_shift_norm=use_scale_shift_norm,
+                             down=True,
+                             dtype=self.dtype,
+                             device=device,
+                             operations=operations
+                         )
+                         if resblock_updown
+                         else Downsample(
+                             ch, conv_resample, dims=dims, out_channels=out_ch, dtype=self.dtype, device=device, operations=operations
+                         )
+                     )
+                 )
+                 ch = out_ch
+                 input_block_chans.append(ch)
+                 self.zero_convs.append(self.make_zero_conv(ch, operations=operations, dtype=self.dtype, device=device))
+                 ds *= 2
+                 self._feature_size += ch
+
+         if num_head_channels == -1:
+             dim_head = ch // num_heads
+         else:
+             num_heads = ch // num_head_channels
+             dim_head = num_head_channels
+         if legacy:
+             #num_heads = 1
+             dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
+         mid_block = [
+             ResBlock(
+                 ch,
+                 time_embed_dim,
+                 dropout,
+                 dims=dims,
+                 use_checkpoint=use_checkpoint,
+                 use_scale_shift_norm=use_scale_shift_norm,
+                 dtype=self.dtype,
+                 device=device,
+                 operations=operations
+             )]
+         if transformer_depth_middle >= 0:
+             mid_block += [SpatialTransformer(  # always uses a self-attn
+                             ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim,
+                             disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
+                             use_checkpoint=use_checkpoint, attn_precision=attn_precision, dtype=self.dtype, device=device, operations=operations
+                         ),
+                         ResBlock(
+                             ch,
+                             time_embed_dim,
+                             dropout,
+                             dims=dims,
+                             use_checkpoint=use_checkpoint,
+                             use_scale_shift_norm=use_scale_shift_norm,
+                             dtype=self.dtype,
+                             device=device,
+                             operations=operations
+                         )]
+         self.middle_block = TimestepEmbedSequential(*mid_block)
+         self.middle_block_out = self.make_zero_conv(ch, operations=operations, dtype=self.dtype, device=device)
+         self._feature_size += ch
+
+         if union_controlnet_num_control_type is not None:
+             self.num_control_type = union_controlnet_num_control_type
+             num_trans_channel = 320
+             num_trans_head = 8
+             num_trans_layer = 1
+             num_proj_channel = 320
+             # task_scale_factor = num_trans_channel ** 0.5
+             self.task_embedding = nn.Parameter(torch.empty(self.num_control_type, num_trans_channel, dtype=self.dtype, device=device))
+
+             self.transformer_layes = nn.Sequential(*[ResBlockUnionControlnet(num_trans_channel, num_trans_head, dtype=self.dtype, device=device, operations=operations) for _ in range(num_trans_layer)])
+             self.spatial_ch_projs = operations.Linear(num_trans_channel, num_proj_channel, dtype=self.dtype, device=device)
+             #-----------------------------------------------------------------------------------------------------
+
+             control_add_embed_dim = 256
+             class ControlAddEmbedding(nn.Module):
+                 def __init__(self, in_dim, out_dim, num_control_type, dtype=None, device=None, operations=None):
+                     super().__init__()
+                     self.num_control_type = num_control_type
+                     self.in_dim = in_dim
+                     self.linear_1 = operations.Linear(in_dim * num_control_type, out_dim, dtype=dtype, device=device)
+                     self.linear_2 = operations.Linear(out_dim, out_dim, dtype=dtype, device=device)
+
+                 def forward(self, control_type, dtype, device):
+                     c_type = torch.zeros((self.num_control_type,), device=device)
+                     c_type[control_type] = 1.0
+                     c_type = timestep_embedding(c_type.flatten(), self.in_dim, repeat_only=False).to(dtype).reshape((-1, self.num_control_type * self.in_dim))
+                     return self.linear_2(torch.nn.functional.silu(self.linear_1(c_type)))
+
+             self.control_add_embedding = ControlAddEmbedding(control_add_embed_dim, time_embed_dim, self.num_control_type, dtype=self.dtype, device=device, operations=operations)
+         else:
+             self.task_embedding = None
+             self.control_add_embedding = None
+
+     def union_controlnet_merge(self, hint, control_type, emb, context):
+         # Equivalent to: https://github.com/xinsir6/ControlNetPlus/tree/main
+         inputs = []
+         condition_list = []
+
+         for idx in range(min(1, len(control_type))):
+             controlnet_cond = self.input_hint_block(hint[idx], emb, context)
+             feat_seq = torch.mean(controlnet_cond, dim=(2, 3))
+             if idx < len(control_type):
+                 feat_seq += self.task_embedding[control_type[idx]].to(dtype=feat_seq.dtype, device=feat_seq.device)
+
+             inputs.append(feat_seq.unsqueeze(1))
+             condition_list.append(controlnet_cond)
+
+         x = torch.cat(inputs, dim=1)
+         x = self.transformer_layes(x)
+         controlnet_cond_fuser = None
+         for idx in range(len(control_type)):
+             alpha = self.spatial_ch_projs(x[:, idx])
+             alpha = alpha.unsqueeze(-1).unsqueeze(-1)
+             o = condition_list[idx] + alpha
+             if controlnet_cond_fuser is None:
+                 controlnet_cond_fuser = o
+             else:
+                 controlnet_cond_fuser += o
+         return controlnet_cond_fuser
+
+     def make_zero_conv(self, channels, operations=None, dtype=None, device=None):
+         return TimestepEmbedSequential(operations.conv_nd(self.dims, channels, channels, 1, padding=0, dtype=dtype, device=device))
+
+     def forward(self, x, hint, timesteps, context, y=None, **kwargs):
+         t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
+         emb = self.time_embed(t_emb)
+
+         guided_hint = None
+         if self.control_add_embedding is not None: #Union Controlnet
+             control_type = kwargs.get("control_type", [])
+
+             if any([c >= self.num_control_type for c in control_type]):
+                 max_type = max(control_type)
+                 max_type_name = {
+                     v: k for k, v in UNION_CONTROLNET_TYPES.items()
+                 }[max_type]
+                 raise ValueError(
+                     f"Control type {max_type_name}({max_type}) is out of range for the number of control types " +
+                     f"({self.num_control_type}) supported.\n" +
+                     "Please consider using the ProMax ControlNet Union model.\n" +
+                     "https://huggingface.co/xinsir/controlnet-union-sdxl-1.0/tree/main"
+                 )
+
+             emb += self.control_add_embedding(control_type, emb.dtype, emb.device)
+             if len(control_type) > 0:
+                 if len(hint.shape) < 5:
+                     hint = hint.unsqueeze(dim=0)
+                 guided_hint = self.union_controlnet_merge(hint, control_type, emb, context)
+
+         if guided_hint is None:
+             guided_hint = self.input_hint_block(hint, emb, context)
+
+         out_output = []
+         out_middle = []
+
+         if self.num_classes is not None:
+             assert y.shape[0] == x.shape[0]
+             emb = emb + self.label_emb(y)
+
+         h = x
+         for module, zero_conv in zip(self.input_blocks, self.zero_convs):
+             if guided_hint is not None:
+                 h = module(h, emb, context)
+                 h += guided_hint
+                 guided_hint = None
+             else:
+                 h = module(h, emb, context)
+             out_output.append(zero_conv(h, emb, context))
+
+         h = self.middle_block(h, emb, context)
+         out_middle.append(self.middle_block_out(h, emb, context))
+
+         return {"middle": out_middle, "output": out_output}
+
comfy/cldm/control_types.py ADDED
@@ -0,0 +1,10 @@
+ UNION_CONTROLNET_TYPES = {
+     "openpose": 0,
+     "depth": 1,
+     "hed/pidi/scribble/ted": 2,
+     "canny/lineart/anime_lineart/mlsd": 3,
+     "normal": 4,
+     "segment": 5,
+     "tile": 6,
+     "repaint": 7,
+ }
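A small sketch tying this table to the union branch of ControlNet.forward in cldm.py above, which reads the indices from kwargs.get("control_type", []). The controlnet instance and its tensor inputs are assumed to exist and are not defined here.

    from comfy.cldm.control_types import UNION_CONTROLNET_TYPES

    control_type = [UNION_CONTROLNET_TYPES["depth"]]
    # out = controlnet(x, hint, timesteps, context, y=y, control_type=control_type)
    # Indices >= num_control_type trigger the ValueError raised in cldm.py's forward().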
comfy/cldm/dit_embedder.py ADDED
@@ -0,0 +1,120 @@
+ import math
+ from typing import Optional
+
+ import torch
+ import torch.nn as nn
+
+ from comfy.ldm.modules.diffusionmodules.mmdit import DismantledBlock, PatchEmbed, VectorEmbedder, TimestepEmbedder, get_2d_sincos_pos_embed_torch
+
+
+ class ControlNetEmbedder(nn.Module):
+
+     def __init__(
+         self,
+         img_size: int,
+         patch_size: int,
+         in_chans: int,
+         attention_head_dim: int,
+         num_attention_heads: int,
+         adm_in_channels: int,
+         num_layers: int,
+         main_model_double: int,
+         double_y_emb: bool,
+         device: torch.device,
+         dtype: torch.dtype,
+         pos_embed_max_size: Optional[int] = None,
+         operations = None,
+     ):
+         super().__init__()
+         self.main_model_double = main_model_double
+         self.dtype = dtype
+         self.hidden_size = num_attention_heads * attention_head_dim
+         self.patch_size = patch_size
+         self.x_embedder = PatchEmbed(
+             img_size=img_size,
+             patch_size=patch_size,
+             in_chans=in_chans,
+             embed_dim=self.hidden_size,
+             strict_img_size=pos_embed_max_size is None,
+             device=device,
+             dtype=dtype,
+             operations=operations,
+         )
+
+         self.t_embedder = TimestepEmbedder(self.hidden_size, dtype=dtype, device=device, operations=operations)
+
+         self.double_y_emb = double_y_emb
+         if self.double_y_emb:
+             self.orig_y_embedder = VectorEmbedder(
+                 adm_in_channels, self.hidden_size, dtype, device, operations=operations
+             )
+             self.y_embedder = VectorEmbedder(
+                 self.hidden_size, self.hidden_size, dtype, device, operations=operations
+             )
+         else:
+             self.y_embedder = VectorEmbedder(
+                 adm_in_channels, self.hidden_size, dtype, device, operations=operations
+             )
+
+         self.transformer_blocks = nn.ModuleList(
+             DismantledBlock(
+                 hidden_size=self.hidden_size, num_heads=num_attention_heads, qkv_bias=True,
+                 dtype=dtype, device=device, operations=operations
+             )
+             for _ in range(num_layers)
+         )
+
+         # self.use_y_embedder = pooled_projection_dim != self.time_text_embed.text_embedder.linear_1.in_features
+         # TODO double check this logic when 8b
+         self.use_y_embedder = True
+
+         self.controlnet_blocks = nn.ModuleList([])
+         for _ in range(len(self.transformer_blocks)):
+             controlnet_block = operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device)
+             self.controlnet_blocks.append(controlnet_block)
+
+         self.pos_embed_input = PatchEmbed(
+             img_size=img_size,
+             patch_size=patch_size,
+             in_chans=in_chans,
+             embed_dim=self.hidden_size,
+             strict_img_size=False,
+             device=device,
+             dtype=dtype,
+             operations=operations,
+         )
+
+     def forward(
+         self,
+         x: torch.Tensor,
+         timesteps: torch.Tensor,
+         y: Optional[torch.Tensor] = None,
+         context: Optional[torch.Tensor] = None,
+         hint = None,
+     ) -> dict:
+         x_shape = list(x.shape)
+         x = self.x_embedder(x)
+         if not self.double_y_emb:
+             h = (x_shape[-2] + 1) // self.patch_size
+             w = (x_shape[-1] + 1) // self.patch_size
+             x += get_2d_sincos_pos_embed_torch(self.hidden_size, w, h, device=x.device)
+         c = self.t_embedder(timesteps, dtype=x.dtype)
+         if y is not None and self.y_embedder is not None:
+             if self.double_y_emb:
+                 y = self.orig_y_embedder(y)
+             y = self.y_embedder(y)
+             c = c + y
+
+         x = x + self.pos_embed_input(hint)
+
+         block_out = ()
+
+         repeat = math.ceil(self.main_model_double / len(self.transformer_blocks))
+         for i in range(len(self.transformer_blocks)):
+             out = self.transformer_blocks[i](x, c)
+             if not self.double_y_emb:
+                 x = out
+             block_out += (self.controlnet_blocks[i](out),) * repeat
+
+         return {"output": block_out}
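The repeat factor above fans a shallow embedder out across the deeper main model's double blocks. A quick check of the arithmetic, with illustrative numbers that are not taken from any real config:

    import math

    main_model_double, num_layers = 19, 5
    repeat = math.ceil(main_model_double / num_layers)  # -> 4
    # block_out then holds num_layers * repeat = 20 projected tensors; the
    # caller presumably consumes one per double block and ignores any surplus.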
comfy/cldm/mmdit.py ADDED
@@ -0,0 +1,81 @@
+ import torch
+ from typing import Optional
+ import comfy.ldm.modules.diffusionmodules.mmdit
+
+ class ControlNet(comfy.ldm.modules.diffusionmodules.mmdit.MMDiT):
+     def __init__(
+         self,
+         num_blocks = None,
+         control_latent_channels = None,
+         dtype = None,
+         device = None,
+         operations = None,
+         **kwargs,
+     ):
+         super().__init__(dtype=dtype, device=device, operations=operations, final_layer=False, num_blocks=num_blocks, **kwargs)
+         # controlnet_blocks
+         self.controlnet_blocks = torch.nn.ModuleList([])
+         for _ in range(len(self.joint_blocks)):
+             self.controlnet_blocks.append(operations.Linear(self.hidden_size, self.hidden_size, device=device, dtype=dtype))
+
+         if control_latent_channels is None:
+             control_latent_channels = self.in_channels
+
+         self.pos_embed_input = comfy.ldm.modules.diffusionmodules.mmdit.PatchEmbed(
+             None,
+             self.patch_size,
+             control_latent_channels,
+             self.hidden_size,
+             bias=True,
+             strict_img_size=False,
+             dtype=dtype,
+             device=device,
+             operations=operations
+         )
+
+     def forward(
+         self,
+         x: torch.Tensor,
+         timesteps: torch.Tensor,
+         y: Optional[torch.Tensor] = None,
+         context: Optional[torch.Tensor] = None,
+         hint = None,
+     ) -> dict:
+
+         # weird sd3 controlnet specific stuff
+         y = torch.zeros_like(y)
+
+         if self.context_processor is not None:
+             context = self.context_processor(context)
+
+         hw = x.shape[-2:]
+         x = self.x_embedder(x) + self.cropped_pos_embed(hw, device=x.device).to(dtype=x.dtype, device=x.device)
+         x += self.pos_embed_input(hint)
+
+         c = self.t_embedder(timesteps, dtype=x.dtype)
+         if y is not None and self.y_embedder is not None:
+             y = self.y_embedder(y)
+             c = c + y
+
+         if context is not None:
+             context = self.context_embedder(context)
+
+         output = []
+
+         blocks = len(self.joint_blocks)
+         for i in range(blocks):
+             context, x = self.joint_blocks[i](
+                 context,
+                 x,
+                 c=c,
+                 use_checkpoint=self.use_checkpoint,
+             )
+
+             out = self.controlnet_blocks[i](x)
+             count = self.depth // blocks
+             if i == blocks - 1:
+                 count -= 1
+             for _ in range(count):
+                 output.append(out)
+
+         return {"output": output}
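A worked example of the output fan-out above, using illustrative numbers only:

    depth, blocks = 24, 12     # hypothetical main-model depth and joint-block count
    count = depth // blocks    # 2: each block's projection is appended twice
    total = blocks * count - 1 # 23: the last block contributes one fewer (count -= 1)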
comfy/cli_args.py ADDED
@@ -0,0 +1,237 @@
+ import argparse
+ import enum
+ import os
+ import comfy.options
+
+
+ class EnumAction(argparse.Action):
+     """
+     Argparse action for handling Enums
+     """
+     def __init__(self, **kwargs):
+         # Pop off the type value
+         enum_type = kwargs.pop("type", None)
+
+         # Ensure an Enum subclass is provided
+         if enum_type is None:
+             raise ValueError("type must be assigned an Enum when using EnumAction")
+         if not issubclass(enum_type, enum.Enum):
+             raise TypeError("type must be an Enum when using EnumAction")
+
+         # Generate choices from the Enum
+         choices = tuple(e.value for e in enum_type)
+         kwargs.setdefault("choices", choices)
+         kwargs.setdefault("metavar", f"[{','.join(list(choices))}]")
+
+         super(EnumAction, self).__init__(**kwargs)
+
+         self._enum = enum_type
+
+     def __call__(self, parser, namespace, values, option_string=None):
+         # Convert value back into an Enum
+         value = self._enum(values)
+         setattr(namespace, self.dest, value)
+
+
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument("--listen", type=str, default="127.0.0.1", metavar="IP", nargs="?", const="0.0.0.0,::", help="Specify the IP address to listen on (default: 127.0.0.1). You can give a list of ip addresses by separating them with a comma like: 127.2.2.2,127.3.3.3 If --listen is provided without an argument, it defaults to 0.0.0.0,:: (listens on all ipv4 and ipv6).")
+ parser.add_argument("--port", type=int, default=8188, help="Set the listen port.")
+ parser.add_argument("--tls-keyfile", type=str, help="Path to TLS (SSL) key file. Enables TLS, making the app accessible at https://...; requires --tls-certfile to function.")
+ parser.add_argument("--tls-certfile", type=str, help="Path to TLS (SSL) certificate file. Enables TLS, making the app accessible at https://...; requires --tls-keyfile to function.")
+ parser.add_argument("--enable-cors-header", type=str, default=None, metavar="ORIGIN", nargs="?", const="*", help="Enable CORS (Cross-Origin Resource Sharing) with optional origin or allow all with default '*'.")
+ parser.add_argument("--max-upload-size", type=float, default=100, help="Set the maximum upload size in MB.")
+
+ parser.add_argument("--base-directory", type=str, default=None, help="Set the ComfyUI base directory for models, custom_nodes, input, output, temp, and user directories.")
+ parser.add_argument("--extra-model-paths-config", type=str, default=None, metavar="PATH", nargs='+', action='append', help="Load one or more extra_model_paths.yaml files.")
+ parser.add_argument("--output-directory", type=str, default=None, help="Set the ComfyUI output directory. Overrides --base-directory.")
+ parser.add_argument("--temp-directory", type=str, default=None, help="Set the ComfyUI temp directory (default is in the ComfyUI directory). Overrides --base-directory.")
+ parser.add_argument("--input-directory", type=str, default=None, help="Set the ComfyUI input directory. Overrides --base-directory.")
+ parser.add_argument("--auto-launch", action="store_true", help="Automatically launch ComfyUI in the default browser.")
+ parser.add_argument("--disable-auto-launch", action="store_true", help="Disable auto launching the browser.")
+ parser.add_argument("--cuda-device", type=int, default=None, metavar="DEVICE_ID", help="Set the id of the cuda device this instance will use. All other devices will not be visible.")
+ parser.add_argument("--default-device", type=int, default=None, metavar="DEFAULT_DEVICE_ID", help="Set the id of the default device; all other devices will stay visible.")
+ cm_group = parser.add_mutually_exclusive_group()
+ cm_group.add_argument("--cuda-malloc", action="store_true", help="Enable cudaMallocAsync (enabled by default for torch 2.0 and up).")
+ cm_group.add_argument("--disable-cuda-malloc", action="store_true", help="Disable cudaMallocAsync.")
+
+
+ fp_group = parser.add_mutually_exclusive_group()
+ fp_group.add_argument("--force-fp32", action="store_true", help="Force fp32 (If this makes your GPU work better please report it).")
+ fp_group.add_argument("--force-fp16", action="store_true", help="Force fp16.")
+
+ fpunet_group = parser.add_mutually_exclusive_group()
+ fpunet_group.add_argument("--fp32-unet", action="store_true", help="Run the diffusion model in fp32.")
+ fpunet_group.add_argument("--fp64-unet", action="store_true", help="Run the diffusion model in fp64.")
+ fpunet_group.add_argument("--bf16-unet", action="store_true", help="Run the diffusion model in bf16.")
+ fpunet_group.add_argument("--fp16-unet", action="store_true", help="Run the diffusion model in fp16.")
+ fpunet_group.add_argument("--fp8_e4m3fn-unet", action="store_true", help="Store unet weights in fp8_e4m3fn.")
+ fpunet_group.add_argument("--fp8_e5m2-unet", action="store_true", help="Store unet weights in fp8_e5m2.")
+ fpunet_group.add_argument("--fp8_e8m0fnu-unet", action="store_true", help="Store unet weights in fp8_e8m0fnu.")
+
+ fpvae_group = parser.add_mutually_exclusive_group()
+ fpvae_group.add_argument("--fp16-vae", action="store_true", help="Run the VAE in fp16, might cause black images.")
+ fpvae_group.add_argument("--fp32-vae", action="store_true", help="Run the VAE in full precision fp32.")
+ fpvae_group.add_argument("--bf16-vae", action="store_true", help="Run the VAE in bf16.")
+
+ parser.add_argument("--cpu-vae", action="store_true", help="Run the VAE on the CPU.")
+
+ fpte_group = parser.add_mutually_exclusive_group()
+ fpte_group.add_argument("--fp8_e4m3fn-text-enc", action="store_true", help="Store text encoder weights in fp8 (e4m3fn variant).")
+ fpte_group.add_argument("--fp8_e5m2-text-enc", action="store_true", help="Store text encoder weights in fp8 (e5m2 variant).")
+ fpte_group.add_argument("--fp16-text-enc", action="store_true", help="Store text encoder weights in fp16.")
+ fpte_group.add_argument("--fp32-text-enc", action="store_true", help="Store text encoder weights in fp32.")
+ fpte_group.add_argument("--bf16-text-enc", action="store_true", help="Store text encoder weights in bf16.")
+
+ parser.add_argument("--force-channels-last", action="store_true", help="Force channels last format when inferencing the models.")
+
+ parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE", const=-1, help="Use torch-directml.")
+
+ parser.add_argument("--oneapi-device-selector", type=str, default=None, metavar="SELECTOR_STRING", help="Sets the oneAPI device(s) this instance will use.")
+ parser.add_argument("--disable-ipex-optimize", action="store_true", help="Disables the ipex.optimize default when loading models with Intel's Extension for PyTorch.")
+ parser.add_argument("--supports-fp8-compute", action="store_true", help="ComfyUI will act as if the device supports fp8 compute.")
+
+ class LatentPreviewMethod(enum.Enum):
+     NoPreviews = "none"
+     Auto = "auto"
+     Latent2RGB = "latent2rgb"
+     TAESD = "taesd"
+
+ parser.add_argument("--preview-method", type=LatentPreviewMethod, default=LatentPreviewMethod.NoPreviews, help="Default preview method for sampler nodes.", action=EnumAction)
+
+ parser.add_argument("--preview-size", type=int, default=512, help="Sets the maximum preview size for sampler nodes.")
+
+ cache_group = parser.add_mutually_exclusive_group()
+ cache_group.add_argument("--cache-classic", action="store_true", help="Use the old style (aggressive) caching.")
+ cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM.")
+ cache_group.add_argument("--cache-none", action="store_true", help="Reduced RAM/VRAM usage at the expense of executing every node for each run.")
+
+ attn_group = parser.add_mutually_exclusive_group()
+ attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.")
+ attn_group.add_argument("--use-quad-cross-attention", action="store_true", help="Use the sub-quadratic cross attention optimization. Ignored when xformers is used.")
+ attn_group.add_argument("--use-pytorch-cross-attention", action="store_true", help="Use the new pytorch 2.0 cross attention function.")
+ attn_group.add_argument("--use-sage-attention", action="store_true", help="Use sage attention.")
+ attn_group.add_argument("--use-flash-attention", action="store_true", help="Use FlashAttention.")
+
+ parser.add_argument("--disable-xformers", action="store_true", help="Disable xformers.")
+
+ upcast = parser.add_mutually_exclusive_group()
+ upcast.add_argument("--force-upcast-attention", action="store_true", help="Force enable attention upcasting, please report if it fixes black images.")
+ upcast.add_argument("--dont-upcast-attention", action="store_true", help="Disable all upcasting of attention. Should be unnecessary except for debugging.")
+
+
+ vram_group = parser.add_mutually_exclusive_group()
+ vram_group.add_argument("--gpu-only", action="store_true", help="Store and run everything (text encoders/CLIP models, etc...) on the GPU.")
+ vram_group.add_argument("--highvram", action="store_true", help="By default models will be unloaded to CPU memory after being used. This option keeps them in GPU memory.")
+ vram_group.add_argument("--normalvram", action="store_true", help="Used to force normal vram use if lowvram gets automatically enabled.")
+ vram_group.add_argument("--lowvram", action="store_true", help="Split the unet in parts to use less vram.")
+ vram_group.add_argument("--novram", action="store_true", help="When lowvram isn't enough.")
+ vram_group.add_argument("--cpu", action="store_true", help="To use the CPU for everything (slow).")
+
+ parser.add_argument("--reserve-vram", type=float, default=None, help="Set the amount of vram in GB you want to reserve for use by your OS/other software. By default some amount is reserved depending on your OS.")
+
+ parser.add_argument("--async-offload", action="store_true", help="Use async weight offloading.")
+
+ parser.add_argument("--default-hashing-function", type=str, choices=['md5', 'sha1', 'sha256', 'sha512'], default='sha256', help="Allows you to choose the hash function to use for duplicate filename / contents comparison. Default is sha256.")
+
+ parser.add_argument("--disable-smart-memory", action="store_true", help="Force ComfyUI to aggressively offload to regular ram instead of keeping models in vram when it can.")
+ parser.add_argument("--deterministic", action="store_true", help="Make pytorch use slower deterministic algorithms when it can. Note that this might not make images deterministic in all cases.")
+
+ class PerformanceFeature(enum.Enum):
+     Fp16Accumulation = "fp16_accumulation"
+     Fp8MatrixMultiplication = "fp8_matrix_mult"
+     CublasOps = "cublas_ops"
+
+ parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. --fast with no arguments enables everything. You can pass a list of specific optimizations if you only want to enable specific ones. Current valid optimizations: fp16_accumulation fp8_matrix_mult cublas_ops")
+
+ parser.add_argument("--mmap-torch-files", action="store_true", help="Use mmap when loading ckpt/pt files.")
+ parser.add_argument("--disable-mmap", action="store_true", help="Don't use mmap when loading safetensors.")
+
+ parser.add_argument("--dont-print-server", action="store_true", help="Don't print server output.")
+ parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI.")
+ parser.add_argument("--windows-standalone-build", action="store_true", help="Windows standalone build: Enable convenient things that most people using the standalone windows build will probably enjoy (like auto opening the page on startup).")
+
+ parser.add_argument("--disable-metadata", action="store_true", help="Disable saving prompt metadata in files.")
+ parser.add_argument("--disable-all-custom-nodes", action="store_true", help="Disable loading all custom nodes.")
+ parser.add_argument("--whitelist-custom-nodes", type=str, nargs='+', default=[], help="Specify custom node folders to load even when --disable-all-custom-nodes is enabled.")
+ parser.add_argument("--disable-api-nodes", action="store_true", help="Disable loading all api nodes.")
+
+ parser.add_argument("--multi-user", action="store_true", help="Enables per-user storage.")
+
+ parser.add_argument("--verbose", default='INFO', const='DEBUG', nargs="?", choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], help='Set the logging level.')
+ parser.add_argument("--log-stdout", action="store_true", help="Send normal process output to stdout instead of stderr (default).")
+
+ # The default built-in provider hosted under web/
+ DEFAULT_VERSION_STRING = "comfyanonymous/ComfyUI@latest"
+
+ parser.add_argument(
+     "--front-end-version",
+     type=str,
+     default=DEFAULT_VERSION_STRING,
+     help="""
+     Specifies the version of the frontend to be used. This command needs internet connectivity to query and
+     download available frontend implementations from GitHub releases.
+
+     The version string should be in the format of:
+     [repoOwner]/[repoName]@[version]
+     where version is one of: "latest" or a valid version number (e.g. "1.0.0")
+     """,
+ )
+
+ def is_valid_directory(path: str) -> str:
+     """Validate if the given path is a directory, and check permissions."""
+     if not os.path.exists(path):
+         raise argparse.ArgumentTypeError(f"The path '{path}' does not exist.")
+     if not os.path.isdir(path):
+         raise argparse.ArgumentTypeError(f"'{path}' is not a directory.")
+     if not os.access(path, os.R_OK):
+         raise argparse.ArgumentTypeError(f"You do not have read permissions for '{path}'.")
+     return path
+
+ parser.add_argument(
+     "--front-end-root",
+     type=is_valid_directory,
+     default=None,
+     help="The local filesystem path to the directory where the frontend is located. Overrides --front-end-version.",
+ )
+
+ parser.add_argument("--user-directory", type=is_valid_directory, default=None, help="Set the ComfyUI user directory with an absolute path. Overrides --base-directory.")
+
+ parser.add_argument("--enable-compress-response-body", action="store_true", help="Enable compressing response body.")
+
+ parser.add_argument(
+     "--comfy-api-base",
+     type=str,
+     default="https://api.comfy.org",
+     help="Set the base URL for the ComfyUI API. (default: https://api.comfy.org)",
+ )
+
+ database_default_path = os.path.abspath(
+     os.path.join(os.path.dirname(__file__), "..", "user", "comfyui.db")
+ )
+ parser.add_argument("--database-url", type=str, default=f"sqlite:///{database_default_path}", help="Specify the database URL, e.g. for an in-memory database you can use 'sqlite:///:memory:'.")
+
+ if comfy.options.args_parsing:
+     args = parser.parse_args()
+ else:
+     args = parser.parse_args([])
+
+ if args.windows_standalone_build:
+     args.auto_launch = True
+
+ if args.disable_auto_launch:
+     args.auto_launch = False
+
+ if args.force_fp16:
+     args.fp16_unet = True
+
+
+ # If '--fast' is not provided, use an empty set
+ if args.fast is None:
+     args.fast = set()
+ # If '--fast' is provided with an empty list, enable all optimizations
+ elif args.fast == []:
+     args.fast = set(PerformanceFeature)
+ # If '--fast' is provided with a list of performance features, use that list
+ else:
+     args.fast = set(args.fast)
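A minimal sketch of how these definitions behave in isolation, assuming comfy.options.args_parsing is left disabled so importing the module does not consume sys.argv (the module-level parse then runs on an empty list, and the --fast post-processing widens a bare flag later):

    from comfy.cli_args import parser, LatentPreviewMethod

    ns = parser.parse_args(["--preview-method", "taesd", "--fast"])
    assert ns.preview_method is LatentPreviewMethod.TAESD  # EnumAction converted the string
    assert ns.fast == []  # bare --fast parses to []; the module later maps it to set(PerformanceFeature)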
comfy/clip_config_bigg.json ADDED
@@ -0,0 +1,23 @@
+ {
+   "architectures": [
+     "CLIPTextModel"
+   ],
+   "attention_dropout": 0.0,
+   "bos_token_id": 0,
+   "dropout": 0.0,
+   "eos_token_id": 49407,
+   "hidden_act": "gelu",
+   "hidden_size": 1280,
+   "initializer_factor": 1.0,
+   "initializer_range": 0.02,
+   "intermediate_size": 5120,
+   "layer_norm_eps": 1e-05,
+   "max_position_embeddings": 77,
+   "model_type": "clip_text_model",
+   "num_attention_heads": 20,
+   "num_hidden_layers": 32,
+   "pad_token_id": 1,
+   "projection_dim": 1280,
+   "torch_dtype": "float32",
+   "vocab_size": 49408
+ }
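This config pairs with CLIPTextModel_ in comfy/clip_model.py below, whose constructor reads exactly these keys (num_hidden_layers, hidden_size, num_attention_heads, intermediate_size, hidden_act, max_position_embeddings, eos_token_id). A hedged construction sketch, assuming the file is read from the repository root:

    import json
    import torch
    import comfy.ops
    from comfy.clip_model import CLIPTextModel_

    with open("comfy/clip_config_bigg.json") as f:
        config = json.load(f)

    # build the bigG text encoder skeleton; weights would be loaded separately
    model = CLIPTextModel_(config, dtype=torch.float32, device="cpu",
                           operations=comfy.ops.disable_weight_init)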
comfy/clip_model.py ADDED
@@ -0,0 +1,244 @@
+ import torch
+ from comfy.ldm.modules.attention import optimized_attention_for_device
+ import comfy.ops
+
+ class CLIPAttention(torch.nn.Module):
+     def __init__(self, embed_dim, heads, dtype, device, operations):
+         super().__init__()
+
+         self.heads = heads
+         self.q_proj = operations.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
+         self.k_proj = operations.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
+         self.v_proj = operations.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
+
+         self.out_proj = operations.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
+
+     def forward(self, x, mask=None, optimized_attention=None):
+         q = self.q_proj(x)
+         k = self.k_proj(x)
+         v = self.v_proj(x)
+
+         out = optimized_attention(q, k, v, self.heads, mask)
+         return self.out_proj(out)
+
+ ACTIVATIONS = {"quick_gelu": lambda a: a * torch.sigmoid(1.702 * a),
+                "gelu": torch.nn.functional.gelu,
+                "gelu_pytorch_tanh": lambda a: torch.nn.functional.gelu(a, approximate="tanh"),
+                }
+
+ class CLIPMLP(torch.nn.Module):
+     def __init__(self, embed_dim, intermediate_size, activation, dtype, device, operations):
+         super().__init__()
+         self.fc1 = operations.Linear(embed_dim, intermediate_size, bias=True, dtype=dtype, device=device)
+         self.activation = ACTIVATIONS[activation]
+         self.fc2 = operations.Linear(intermediate_size, embed_dim, bias=True, dtype=dtype, device=device)
+
+     def forward(self, x):
+         x = self.fc1(x)
+         x = self.activation(x)
+         x = self.fc2(x)
+         return x
+
+ class CLIPLayer(torch.nn.Module):
+     def __init__(self, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations):
+         super().__init__()
+         self.layer_norm1 = operations.LayerNorm(embed_dim, dtype=dtype, device=device)
+         self.self_attn = CLIPAttention(embed_dim, heads, dtype, device, operations)
+         self.layer_norm2 = operations.LayerNorm(embed_dim, dtype=dtype, device=device)
+         self.mlp = CLIPMLP(embed_dim, intermediate_size, intermediate_activation, dtype, device, operations)
+
+     def forward(self, x, mask=None, optimized_attention=None):
+         x += self.self_attn(self.layer_norm1(x), mask, optimized_attention)
+         x += self.mlp(self.layer_norm2(x))
+         return x
+
+
+ class CLIPEncoder(torch.nn.Module):
+     def __init__(self, num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations):
+         super().__init__()
+         self.layers = torch.nn.ModuleList([CLIPLayer(embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations) for i in range(num_layers)])
+
+     def forward(self, x, mask=None, intermediate_output=None):
+         optimized_attention = optimized_attention_for_device(x.device, mask=mask is not None, small_input=True)
+
+         if intermediate_output is not None:
+             if intermediate_output < 0:
+                 intermediate_output = len(self.layers) + intermediate_output
+
+         intermediate = None
+         for i, l in enumerate(self.layers):
+             x = l(x, mask, optimized_attention)
+             if i == intermediate_output:
+                 intermediate = x.clone()
+         return x, intermediate
+
+ class CLIPEmbeddings(torch.nn.Module):
+     def __init__(self, embed_dim, vocab_size=49408, num_positions=77, dtype=None, device=None, operations=None):
+         super().__init__()
+         self.token_embedding = operations.Embedding(vocab_size, embed_dim, dtype=dtype, device=device)
+         self.position_embedding = operations.Embedding(num_positions, embed_dim, dtype=dtype, device=device)
+
+     def forward(self, input_tokens, dtype=torch.float32):
+         return self.token_embedding(input_tokens, out_dtype=dtype) + comfy.ops.cast_to(self.position_embedding.weight, dtype=dtype, device=input_tokens.device)
+
+
+ class CLIPTextModel_(torch.nn.Module):
+     def __init__(self, config_dict, dtype, device, operations):
+         num_layers = config_dict["num_hidden_layers"]
+         embed_dim = config_dict["hidden_size"]
+         heads = config_dict["num_attention_heads"]
+         intermediate_size = config_dict["intermediate_size"]
+         intermediate_activation = config_dict["hidden_act"]
+         num_positions = config_dict["max_position_embeddings"]
+         self.eos_token_id = config_dict["eos_token_id"]
+
+         super().__init__()
+         self.embeddings = CLIPEmbeddings(embed_dim, num_positions=num_positions, dtype=dtype, device=device, operations=operations)
+         self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations)
+         self.final_layer_norm = operations.LayerNorm(embed_dim, dtype=dtype, device=device)
+
+     def forward(self, input_tokens=None, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=torch.float32):
+         if embeds is not None:
+             x = embeds + comfy.ops.cast_to(self.embeddings.position_embedding.weight, dtype=dtype, device=embeds.device)
+         else:
+             x = self.embeddings(input_tokens, dtype=dtype)
+
+         mask = None
+         if attention_mask is not None:
108
+ mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1])
109
+ mask = mask.masked_fill(mask.to(torch.bool), -torch.finfo(x.dtype).max)
110
+
111
+ causal_mask = torch.full((x.shape[1], x.shape[1]), -torch.finfo(x.dtype).max, dtype=x.dtype, device=x.device).triu_(1)
112
+
113
+ if mask is not None:
114
+ mask += causal_mask
115
+ else:
116
+ mask = causal_mask
117
+
118
+ x, i = self.encoder(x, mask=mask, intermediate_output=intermediate_output)
119
+ x = self.final_layer_norm(x)
120
+ if i is not None and final_layer_norm_intermediate:
121
+ i = self.final_layer_norm(i)
122
+
123
+ if num_tokens is not None:
124
+ pooled_output = x[list(range(x.shape[0])), list(map(lambda a: a - 1, num_tokens))]
125
+ else:
126
+ pooled_output = x[torch.arange(x.shape[0], device=x.device), (torch.round(input_tokens).to(dtype=torch.int, device=x.device) == self.eos_token_id).int().argmax(dim=-1),]
127
+ return x, i, pooled_output
128
+
129
+ class CLIPTextModel(torch.nn.Module):
130
+ def __init__(self, config_dict, dtype, device, operations):
131
+ super().__init__()
132
+ self.num_layers = config_dict["num_hidden_layers"]
133
+ self.text_model = CLIPTextModel_(config_dict, dtype, device, operations)
134
+ embed_dim = config_dict["hidden_size"]
135
+ self.text_projection = operations.Linear(embed_dim, embed_dim, bias=False, dtype=dtype, device=device)
136
+ self.dtype = dtype
137
+
138
+ def get_input_embeddings(self):
139
+ return self.text_model.embeddings.token_embedding
140
+
141
+ def set_input_embeddings(self, embeddings):
142
+ self.text_model.embeddings.token_embedding = embeddings
143
+
144
+ def forward(self, *args, **kwargs):
145
+ x = self.text_model(*args, **kwargs)
146
+ out = self.text_projection(x[2])
147
+ return (x[0], x[1], out, x[2])
148
+
149
+
150
+ class CLIPVisionEmbeddings(torch.nn.Module):
151
+ def __init__(self, embed_dim, num_channels=3, patch_size=14, image_size=224, model_type="", dtype=None, device=None, operations=None):
152
+ super().__init__()
153
+
154
+ num_patches = (image_size // patch_size) ** 2
155
+ if model_type == "siglip_vision_model":
156
+ self.class_embedding = None
157
+ patch_bias = True
158
+ else:
159
+ num_patches = num_patches + 1
160
+ self.class_embedding = torch.nn.Parameter(torch.empty(embed_dim, dtype=dtype, device=device))
161
+ patch_bias = False
162
+
163
+ self.patch_embedding = operations.Conv2d(
164
+ in_channels=num_channels,
165
+ out_channels=embed_dim,
166
+ kernel_size=patch_size,
167
+ stride=patch_size,
168
+ bias=patch_bias,
169
+ dtype=dtype,
170
+ device=device
171
+ )
172
+
173
+ self.position_embedding = operations.Embedding(num_patches, embed_dim, dtype=dtype, device=device)
174
+
175
+ def forward(self, pixel_values):
176
+ embeds = self.patch_embedding(pixel_values).flatten(2).transpose(1, 2)
177
+ if self.class_embedding is not None:
178
+ embeds = torch.cat([comfy.ops.cast_to_input(self.class_embedding, embeds).expand(pixel_values.shape[0], 1, -1), embeds], dim=1)
179
+ return embeds + comfy.ops.cast_to_input(self.position_embedding.weight, embeds)
180
+
181
+
182
+ class CLIPVision(torch.nn.Module):
183
+ def __init__(self, config_dict, dtype, device, operations):
184
+ super().__init__()
185
+ num_layers = config_dict["num_hidden_layers"]
186
+ embed_dim = config_dict["hidden_size"]
187
+ heads = config_dict["num_attention_heads"]
188
+ intermediate_size = config_dict["intermediate_size"]
189
+ intermediate_activation = config_dict["hidden_act"]
190
+ model_type = config_dict["model_type"]
191
+
192
+ self.embeddings = CLIPVisionEmbeddings(embed_dim, config_dict["num_channels"], config_dict["patch_size"], config_dict["image_size"], model_type=model_type, dtype=dtype, device=device, operations=operations)
193
+ if model_type == "siglip_vision_model":
194
+ self.pre_layrnorm = lambda a: a
195
+ self.output_layernorm = True
196
+ else:
197
+ self.pre_layrnorm = operations.LayerNorm(embed_dim)
198
+ self.output_layernorm = False
199
+ self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations)
200
+ self.post_layernorm = operations.LayerNorm(embed_dim)
201
+
202
+ def forward(self, pixel_values, attention_mask=None, intermediate_output=None):
203
+ x = self.embeddings(pixel_values)
204
+ x = self.pre_layrnorm(x)
205
+ #TODO: attention_mask?
206
+ x, i = self.encoder(x, mask=None, intermediate_output=intermediate_output)
207
+ if self.output_layernorm:
208
+ x = self.post_layernorm(x)
209
+ pooled_output = x
210
+ else:
211
+ pooled_output = self.post_layernorm(x[:, 0, :])
212
+ return x, i, pooled_output
213
+
214
+ class LlavaProjector(torch.nn.Module):
215
+ def __init__(self, in_dim, out_dim, dtype, device, operations):
216
+ super().__init__()
217
+ self.linear_1 = operations.Linear(in_dim, out_dim, bias=True, device=device, dtype=dtype)
218
+ self.linear_2 = operations.Linear(out_dim, out_dim, bias=True, device=device, dtype=dtype)
219
+
220
+ def forward(self, x):
221
+ return self.linear_2(torch.nn.functional.gelu(self.linear_1(x[:, 1:])))
222
+
223
+ class CLIPVisionModelProjection(torch.nn.Module):
224
+ def __init__(self, config_dict, dtype, device, operations):
225
+ super().__init__()
226
+ self.vision_model = CLIPVision(config_dict, dtype, device, operations)
227
+ if "projection_dim" in config_dict:
228
+ self.visual_projection = operations.Linear(config_dict["hidden_size"], config_dict["projection_dim"], bias=False)
229
+ else:
230
+ self.visual_projection = lambda a: a
231
+
232
+ if "llava3" == config_dict.get("projector_type", None):
233
+ self.multi_modal_projector = LlavaProjector(config_dict["hidden_size"], 4096, dtype, device, operations)
234
+ else:
235
+ self.multi_modal_projector = None
236
+
237
+ def forward(self, *args, **kwargs):
238
+ x = self.vision_model(*args, **kwargs)
239
+ out = self.visual_projection(x[2])
240
+ projected = None
241
+ if self.multi_modal_projector is not None:
242
+ projected = self.multi_modal_projector(x[1])
243
+
244
+ return (x[0], x[1], out, projected)
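The pooled-output indexing in `CLIPTextModel_.forward` above is worth unpacking: when `num_tokens` is not supplied, the hidden state is taken at the position of the *first* EOS token in each row, found with the `(tokens == eos_id).int().argmax(dim=-1)` trick (the boolean comparison makes matches 1 and everything else 0, and `argmax` returns the first index of the maximum). A standalone sketch of just that indexing, with made-up token values:

```python
import torch

eos_token_id = 49407
# Two token rows of length 6; the first EOS sits at positions 3 and 5.
tokens = torch.tensor([[49406, 320, 1125, 49407, 49407, 49407],
                       [49406, 320, 1125, 2368, 269, 49407]])
first_eos = (tokens == eos_token_id).int().argmax(dim=-1)
print(first_eos)  # tensor([3, 5]) -- index of the first EOS per row

hidden = torch.randn(2, 6, 8)  # stand-in for the final hidden states
pooled = hidden[torch.arange(2), first_eos]  # shape (2, 8), one vector per row
```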
comfy/clip_vision.py ADDED
@@ -0,0 +1,150 @@
+ from .utils import load_torch_file, transformers_convert, state_dict_prefix_replace
+ import os
+ import torch
+ import json
+ import logging
+
+ import comfy.ops
+ import comfy.model_patcher
+ import comfy.model_management
+ import comfy.utils
+ import comfy.clip_model
+ import comfy.image_encoders.dino2
+
+ class Output:
+     def __getitem__(self, key):
+         return getattr(self, key)
+     def __setitem__(self, key, item):
+         setattr(self, key, item)
+
+ def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711], crop=True):
+     image = image[:, :, :, :3] if image.shape[3] > 3 else image
+     mean = torch.tensor(mean, device=image.device, dtype=image.dtype)
+     std = torch.tensor(std, device=image.device, dtype=image.dtype)
+     image = image.movedim(-1, 1)
+     if not (image.shape[2] == size and image.shape[3] == size):
+         if crop:
+             scale = (size / min(image.shape[2], image.shape[3]))
+             scale_size = (round(scale * image.shape[2]), round(scale * image.shape[3]))
+         else:
+             scale_size = (size, size)
+
+         image = torch.nn.functional.interpolate(image, size=scale_size, mode="bicubic", antialias=True)
+         h = (image.shape[2] - size)//2
+         w = (image.shape[3] - size)//2
+         image = image[:,:,h:h+size,w:w+size]
+     image = torch.clip((255. * image), 0, 255).round() / 255.0
+     return (image - mean.view([3,1,1])) / std.view([3,1,1])
+
+ IMAGE_ENCODERS = {
+     "clip_vision_model": comfy.clip_model.CLIPVisionModelProjection,
+     "siglip_vision_model": comfy.clip_model.CLIPVisionModelProjection,
+     "dinov2": comfy.image_encoders.dino2.Dinov2Model,
+ }
+
+ class ClipVisionModel():
+     def __init__(self, json_config):
+         with open(json_config) as f:
+             config = json.load(f)
+
+         self.image_size = config.get("image_size", 224)
+         self.image_mean = config.get("image_mean", [0.48145466, 0.4578275, 0.40821073])
+         self.image_std = config.get("image_std", [0.26862954, 0.26130258, 0.27577711])
+         model_class = IMAGE_ENCODERS.get(config.get("model_type", "clip_vision_model"))
+         self.load_device = comfy.model_management.text_encoder_device()
+         offload_device = comfy.model_management.text_encoder_offload_device()
+         self.dtype = comfy.model_management.text_encoder_dtype(self.load_device)
+         self.model = model_class(config, self.dtype, offload_device, comfy.ops.manual_cast)
+         self.model.eval()
+
+         self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
+
+     def load_sd(self, sd):
+         return self.model.load_state_dict(sd, strict=False)
+
+     def get_sd(self):
+         return self.model.state_dict()
+
+     def encode_image(self, image, crop=True):
+         comfy.model_management.load_model_gpu(self.patcher)
+         pixel_values = clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std, crop=crop).float()
+         out = self.model(pixel_values=pixel_values, intermediate_output=-2)
+
+         outputs = Output()
+         outputs["last_hidden_state"] = out[0].to(comfy.model_management.intermediate_device())
+         outputs["image_embeds"] = out[2].to(comfy.model_management.intermediate_device())
+         outputs["penultimate_hidden_states"] = out[1].to(comfy.model_management.intermediate_device())
+         outputs["mm_projected"] = out[3]
+         return outputs
+
+ def convert_to_transformers(sd, prefix):
+     sd_k = sd.keys()
+     if "{}transformer.resblocks.0.attn.in_proj_weight".format(prefix) in sd_k:
+         keys_to_replace = {
+             "{}class_embedding".format(prefix): "vision_model.embeddings.class_embedding",
+             "{}conv1.weight".format(prefix): "vision_model.embeddings.patch_embedding.weight",
+             "{}positional_embedding".format(prefix): "vision_model.embeddings.position_embedding.weight",
+             "{}ln_post.bias".format(prefix): "vision_model.post_layernorm.bias",
+             "{}ln_post.weight".format(prefix): "vision_model.post_layernorm.weight",
+             "{}ln_pre.bias".format(prefix): "vision_model.pre_layrnorm.bias",
+             "{}ln_pre.weight".format(prefix): "vision_model.pre_layrnorm.weight",
+         }
+
+         for x in keys_to_replace:
+             if x in sd_k:
+                 sd[keys_to_replace[x]] = sd.pop(x)
+
+         if "{}proj".format(prefix) in sd_k:
+             sd['visual_projection.weight'] = sd.pop("{}proj".format(prefix)).transpose(0, 1)
+
+         sd = transformers_convert(sd, prefix, "vision_model.", 48)
+     else:
+         replace_prefix = {prefix: ""}
+         sd = state_dict_prefix_replace(sd, replace_prefix)
+     return sd
+
+ def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
+     if convert_keys:
+         sd = convert_to_transformers(sd, prefix)
+     if "vision_model.encoder.layers.47.layer_norm1.weight" in sd:
+         json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_g.json")
+     elif "vision_model.encoder.layers.30.layer_norm1.weight" in sd:
+         json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_h.json")
+     elif "vision_model.encoder.layers.22.layer_norm1.weight" in sd:
+         embed_shape = sd["vision_model.embeddings.position_embedding.weight"].shape[0]
+         if sd["vision_model.encoder.layers.0.layer_norm1.weight"].shape[0] == 1152:
+             if embed_shape == 729:
+                 json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_siglip_384.json")
+             elif embed_shape == 1024:
+                 json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_siglip_512.json")
+             else:
+                 return None  # unrecognized SigLIP variant; avoid referencing an unbound json_config below
+         elif embed_shape == 577:
+             if "multi_modal_projector.linear_1.bias" in sd:
+                 json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl_336_llava.json")
+             else:
+                 json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl_336.json")
+         else:
+             json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl.json")
+     elif "embeddings.patch_embeddings.projection.weight" in sd:
+         json_config = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "image_encoders"), "dino2_giant.json")
+     else:
+         return None
+
+     clip = ClipVisionModel(json_config)
+     m, u = clip.load_sd(sd)
+     if len(m) > 0:
+         logging.warning("missing clip vision: {}".format(m))
+     u = set(u)
+     keys = list(sd.keys())
+     for k in keys:
+         if k not in u:
+             sd.pop(k)
+     return clip
+
+ def load(ckpt_path):
+     sd = load_torch_file(ckpt_path)
+     if "visual.transformer.resblocks.0.attn.in_proj_weight" in sd:
+         return load_clipvision_from_sd(sd, prefix="visual.", convert_keys=True)
+     else:
+         return load_clipvision_from_sd(sd)
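For orientation, a minimal usage sketch of this module, using only the `load` and `encode_image` entry points defined above; the checkpoint path is hypothetical, and `load` returns `None` when the state dict is not recognized:

```python
import torch
import comfy.clip_vision

# Hypothetical checkpoint path; any supported CLIP/SigLIP vision weights work.
clip = comfy.clip_vision.load("models/clip_vision/clip_vit_h.safetensors")
if clip is not None:
    # ComfyUI images are [B, H, W, C] floats in the 0..1 range.
    image = torch.rand(1, 512, 768, 3)
    out = clip.encode_image(image)
    print(out["image_embeds"].shape, out["penultimate_hidden_states"].shape)
```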
comfy/clip_vision_config_g.json ADDED
@@ -0,0 +1,18 @@
+ {
+   "attention_dropout": 0.0,
+   "dropout": 0.0,
+   "hidden_act": "gelu",
+   "hidden_size": 1664,
+   "image_size": 224,
+   "initializer_factor": 1.0,
+   "initializer_range": 0.02,
+   "intermediate_size": 8192,
+   "layer_norm_eps": 1e-05,
+   "model_type": "clip_vision_model",
+   "num_attention_heads": 16,
+   "num_channels": 3,
+   "num_hidden_layers": 48,
+   "patch_size": 14,
+   "projection_dim": 1280,
+   "torch_dtype": "float32"
+ }
comfy/clip_vision_config_h.json ADDED
@@ -0,0 +1,18 @@
+ {
+   "attention_dropout": 0.0,
+   "dropout": 0.0,
+   "hidden_act": "gelu",
+   "hidden_size": 1280,
+   "image_size": 224,
+   "initializer_factor": 1.0,
+   "initializer_range": 0.02,
+   "intermediate_size": 5120,
+   "layer_norm_eps": 1e-05,
+   "model_type": "clip_vision_model",
+   "num_attention_heads": 16,
+   "num_channels": 3,
+   "num_hidden_layers": 32,
+   "patch_size": 14,
+   "projection_dim": 1024,
+   "torch_dtype": "float32"
+ }
comfy/clip_vision_config_vitl.json ADDED
@@ -0,0 +1,18 @@
+ {
+   "attention_dropout": 0.0,
+   "dropout": 0.0,
+   "hidden_act": "quick_gelu",
+   "hidden_size": 1024,
+   "image_size": 224,
+   "initializer_factor": 1.0,
+   "initializer_range": 0.02,
+   "intermediate_size": 4096,
+   "layer_norm_eps": 1e-05,
+   "model_type": "clip_vision_model",
+   "num_attention_heads": 16,
+   "num_channels": 3,
+   "num_hidden_layers": 24,
+   "patch_size": 14,
+   "projection_dim": 768,
+   "torch_dtype": "float32"
+ }
comfy/clip_vision_config_vitl_336.json ADDED
@@ -0,0 +1,18 @@
+ {
+   "attention_dropout": 0.0,
+   "dropout": 0.0,
+   "hidden_act": "quick_gelu",
+   "hidden_size": 1024,
+   "image_size": 336,
+   "initializer_factor": 1.0,
+   "initializer_range": 0.02,
+   "intermediate_size": 4096,
+   "layer_norm_eps": 1e-5,
+   "model_type": "clip_vision_model",
+   "num_attention_heads": 16,
+   "num_channels": 3,
+   "num_hidden_layers": 24,
+   "patch_size": 14,
+   "projection_dim": 768,
+   "torch_dtype": "float32"
+ }
comfy/clip_vision_config_vitl_336_llava.json ADDED
@@ -0,0 +1,19 @@
+ {
+   "attention_dropout": 0.0,
+   "dropout": 0.0,
+   "hidden_act": "quick_gelu",
+   "hidden_size": 1024,
+   "image_size": 336,
+   "initializer_factor": 1.0,
+   "initializer_range": 0.02,
+   "intermediate_size": 4096,
+   "layer_norm_eps": 1e-5,
+   "model_type": "clip_vision_model",
+   "num_attention_heads": 16,
+   "num_channels": 3,
+   "num_hidden_layers": 24,
+   "patch_size": 14,
+   "projection_dim": 768,
+   "projector_type": "llava3",
+   "torch_dtype": "float32"
+ }
comfy/clip_vision_siglip_384.json ADDED
@@ -0,0 +1,13 @@
+ {
+   "num_channels": 3,
+   "hidden_act": "gelu_pytorch_tanh",
+   "hidden_size": 1152,
+   "image_size": 384,
+   "intermediate_size": 4304,
+   "model_type": "siglip_vision_model",
+   "num_attention_heads": 16,
+   "num_hidden_layers": 27,
+   "patch_size": 14,
+   "image_mean": [0.5, 0.5, 0.5],
+   "image_std": [0.5, 0.5, 0.5]
+ }
comfy/clip_vision_siglip_512.json ADDED
@@ -0,0 +1,13 @@
+ {
+   "num_channels": 3,
+   "hidden_act": "gelu_pytorch_tanh",
+   "hidden_size": 1152,
+   "image_size": 512,
+   "intermediate_size": 4304,
+   "model_type": "siglip_vision_model",
+   "num_attention_heads": 16,
+   "num_hidden_layers": 27,
+   "patch_size": 16,
+   "image_mean": [0.5, 0.5, 0.5],
+   "image_std": [0.5, 0.5, 0.5]
+ }
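These vision configs are what `load_clipvision_from_sd` above selects between, and the `embed_shape` values it tests fall straight out of the geometry in each file: the position-embedding count is `(image_size // patch_size) ** 2`, plus one class token for non-SigLIP models (see `CLIPVisionEmbeddings`). A quick arithmetic check against the values shown in these configs:

```python
def num_positions(image_size, patch_size, class_token):
    # Square grid of patches, plus an optional class token.
    return (image_size // patch_size) ** 2 + (1 if class_token else 0)

assert num_positions(336, 14, class_token=True) == 577    # ViT-L/14 @ 336
assert num_positions(384, 14, class_token=False) == 729   # SigLIP @ 384
assert num_positions(512, 16, class_token=False) == 1024  # SigLIP @ 512
```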
comfy/comfy_types/README.md ADDED
@@ -0,0 +1,43 @@
+ # Comfy Typing
+ ## Type hinting for ComfyUI Node development
+
+ This module provides type hinting and concrete convenience types for node developers.
+ If cloned to the custom_nodes directory of ComfyUI, types can be imported using:
+
+ ```python
+ from comfy.comfy_types import IO, ComfyNodeABC, CheckLazyMixin, InputTypeDict
+
+ class ExampleNode(ComfyNodeABC):
+     @classmethod
+     def INPUT_TYPES(s) -> InputTypeDict:
+         return {"required": {}}
+ ```
+
+ A full example is in [examples/example_nodes.py](examples/example_nodes.py).
+
+ # Types
+ A few primary types are documented below. More complete information is available via the docstrings on each type.
+
+ ## `IO`
+
+ A string enum of the built-in data types plus a few custom ones. It includes the following special types and their requisite plumbing (see the sketch after this list):
+
+ - `ANY`: `"*"`
+ - `NUMBER`: `"FLOAT,INT"`
+ - `PRIMITIVE`: `"STRING,FLOAT,INT,BOOLEAN"`
+
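A minimal sketch of the special types in practice; the node itself is hypothetical, but `IO.NUMBER` ("FLOAT,INT") is exactly the combined type listed above, so either numeric type can connect to the input:

```python
from comfy.comfy_types import IO, ComfyNodeABC, InputTypeDict

class ScaleNumber(ComfyNodeABC):
    """Hypothetical node: multiplies a FLOAT-or-INT input by two."""
    CATEGORY = "examples"
    RETURN_TYPES = (IO.FLOAT,)
    FUNCTION = "execute"

    @classmethod
    def INPUT_TYPES(s) -> InputTypeDict:
        # IO.NUMBER lets either a FLOAT or an INT connection attach here.
        return {"required": {"value": (IO.NUMBER, {})}}

    def execute(self, value):
        return (float(value) * 2.0,)
```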
+ ## `ComfyNodeABC`
+
+ An abstract base class for nodes, offering type-hinting / autocomplete, and somewhat-alright docstrings.
+
+ ### Type hinting for `INPUT_TYPES`
+
+ ![INPUT_TYPES auto-completion in Visual Studio Code](examples/input_types.png)
+
+ ### `INPUT_TYPES` return dict
+
+ ![INPUT_TYPES return value type hinting in Visual Studio Code](examples/required_hint.png)
+
+ ### Options for individual inputs
+
+ ![INPUT_TYPES return value option auto-completion in Visual Studio Code](examples/input_options.png)
comfy/comfy_types/__init__.py ADDED
@@ -0,0 +1,46 @@
+ import torch
+ from typing import Callable, Protocol, TypedDict, Optional, List
+ from .node_typing import IO, InputTypeDict, ComfyNodeABC, CheckLazyMixin, FileLocator
+
+
+ class UnetApplyFunction(Protocol):
+     """Function signature protocol on comfy.model_base.BaseModel.apply_model"""
+
+     def __call__(self, x: torch.Tensor, t: torch.Tensor, **kwargs) -> torch.Tensor:
+         pass
+
+
+ class UnetApplyConds(TypedDict):
+     """Optional conditions for unet apply function."""
+
+     c_concat: Optional[torch.Tensor]
+     c_crossattn: Optional[torch.Tensor]
+     control: Optional[torch.Tensor]
+     transformer_options: Optional[dict]
+
+
+ class UnetParams(TypedDict):
+     # Tensor of shape [B, C, H, W]
+     input: torch.Tensor
+     # Tensor of shape [B]
+     timestep: torch.Tensor
+     c: UnetApplyConds
+     # List of [0, 1], [0], [1], ...
+     # 0 means conditional, 1 means unconditional
+     cond_or_uncond: List[int]
+
+
+ UnetWrapperFunction = Callable[[UnetApplyFunction, UnetParams], torch.Tensor]
+
+
+ __all__ = [
+     "UnetWrapperFunction",
+     UnetApplyConds.__name__,
+     UnetParams.__name__,
+     UnetApplyFunction.__name__,
+     IO.__name__,
+     InputTypeDict.__name__,
+     ComfyNodeABC.__name__,
+     CheckLazyMixin.__name__,
+     FileLocator.__name__,
+ ]
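A `UnetWrapperFunction` receives the model's `apply_model` callable plus the bundled `UnetParams` and must return the denoised tensor. A minimal pass-through sketch under the type definitions above; the wrapper name is hypothetical, and the commented installation line assumes a `ModelPatcher` instance named `model`:

```python
import torch
from comfy.comfy_types import UnetWrapperFunction, UnetParams, UnetApplyFunction

def passthrough_wrapper(apply_model: UnetApplyFunction, params: UnetParams) -> torch.Tensor:
    # Unpack the bundled arguments and delegate unchanged; a real wrapper
    # would adjust the input, timestep, or conds here before calling through.
    return apply_model(params["input"], params["timestep"], **params["c"])

wrapper: UnetWrapperFunction = passthrough_wrapper
# Typical installation on a ModelPatcher instance:
# model.set_model_unet_function_wrapper(wrapper)
```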
comfy/comfy_types/examples/example_nodes.py ADDED
@@ -0,0 +1,28 @@
+ from comfy.comfy_types import IO, ComfyNodeABC, InputTypeDict
+ from inspect import cleandoc
+
+
+ class ExampleNode(ComfyNodeABC):
+     """An example node that just adds 1 to an input integer.
+
+     * Requires a modern IDE to provide any benefit (i.e. an IDE configured with analysis paths, etc.).
+     * This node is intended as an example for developers only.
+     """
+
+     DESCRIPTION = cleandoc(__doc__)
+     CATEGORY = "examples"
+
+     @classmethod
+     def INPUT_TYPES(s) -> InputTypeDict:
+         return {
+             "required": {
+                 "input_int": (IO.INT, {"defaultInput": True}),
+             }
+         }
+
+     RETURN_TYPES = (IO.INT,)
+     RETURN_NAMES = ("input_plus_one",)
+     FUNCTION = "execute"
+
+     def execute(self, input_int: int):
+         return (input_int + 1,)
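To actually appear in ComfyUI, a node class like this still has to be registered; by convention a custom node package exposes `NODE_CLASS_MAPPINGS` (and optionally `NODE_DISPLAY_NAME_MAPPINGS`) from its `__init__.py`. A sketch, assuming the class above is importable from the package:

```python
from .example_nodes import ExampleNode

# Maps the internal node name to its class; ComfyUI scans this at startup.
NODE_CLASS_MAPPINGS = {
    "ExampleNode": ExampleNode,
}
# Optional: nicer title shown in the UI.
NODE_DISPLAY_NAME_MAPPINGS = {
    "ExampleNode": "Example (add one)",
}
```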
comfy/comfy_types/examples/input_options.png ADDED
comfy/comfy_types/examples/input_types.png ADDED
comfy/comfy_types/examples/required_hint.png ADDED