Spaces:
Paused
Paused
Upload 430 files
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +3 -0
- FAQ.md +10 -0
- LICENSE +21 -0
- README.md +473 -8
- assets/VAE_test1.jpg +3 -0
- assets/glif.svg +40 -0
- assets/lora_ease_ui.png +3 -0
- build_and_push_docker +29 -0
- build_and_push_docker_dev +21 -0
- config/examples/extract.example.yml +75 -0
- config/examples/generate.example.yaml +60 -0
- config/examples/mod_lora_scale.yaml +48 -0
- config/examples/modal/modal_train_lora_flux_24gb.yaml +96 -0
- config/examples/modal/modal_train_lora_flux_schnell_24gb.yaml +98 -0
- config/examples/train_flex_redux.yaml +112 -0
- config/examples/train_full_fine_tune_flex.yaml +107 -0
- config/examples/train_full_fine_tune_lumina.yaml +99 -0
- config/examples/train_lora_chroma_24gb.yaml +104 -0
- config/examples/train_lora_flex2_24gb.yaml +165 -0
- config/examples/train_lora_flex_24gb.yaml +101 -0
- config/examples/train_lora_flux_24gb.yaml +96 -0
- config/examples/train_lora_flux_kontext_24gb.yaml +106 -0
- config/examples/train_lora_flux_schnell_24gb.yaml +98 -0
- config/examples/train_lora_hidream_48.yaml +112 -0
- config/examples/train_lora_lumina.yaml +96 -0
- config/examples/train_lora_omnigen2_24gb.yaml +94 -0
- config/examples/train_lora_sd35_large_24gb.yaml +97 -0
- config/examples/train_lora_wan21_14b_24gb.yaml +101 -0
- config/examples/train_lora_wan21_1b_24gb.yaml +90 -0
- config/examples/train_slider.example.yml +230 -0
- docker-compose.yml +25 -0
- docker/Dockerfile +83 -0
- docker/start.sh +70 -0
- extensions/example/ExampleMergeModels.py +129 -0
- extensions/example/__init__.py +25 -0
- extensions/example/config/config.example.yaml +48 -0
- extensions_built_in/.DS_Store +0 -0
- extensions_built_in/advanced_generator/Img2ImgGenerator.py +256 -0
- extensions_built_in/advanced_generator/PureLoraGenerator.py +102 -0
- extensions_built_in/advanced_generator/ReferenceGenerator.py +212 -0
- extensions_built_in/advanced_generator/__init__.py +59 -0
- extensions_built_in/advanced_generator/config/train.example.yaml +91 -0
- extensions_built_in/concept_replacer/ConceptReplacer.py +151 -0
- extensions_built_in/concept_replacer/__init__.py +26 -0
- extensions_built_in/concept_replacer/config/train.example.yaml +91 -0
- extensions_built_in/dataset_tools/DatasetTools.py +20 -0
- extensions_built_in/dataset_tools/SuperTagger.py +196 -0
- extensions_built_in/dataset_tools/SyncFromCollection.py +131 -0
- extensions_built_in/dataset_tools/__init__.py +43 -0
- extensions_built_in/dataset_tools/tools/caption.py +53 -0
.gitattributes
CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
assets/lora_ease_ui.png filter=lfs diff=lfs merge=lfs -text
|
37 |
+
assets/VAE_test1.jpg filter=lfs diff=lfs merge=lfs -text
|
38 |
+
toolkit/timestep_weighing/flex_timestep_weights_plot.png filter=lfs diff=lfs merge=lfs -text
|
FAQ.md
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# FAQ
|
2 |
+
|
3 |
+
WIP. Will continue to add things as they are needed.
|
4 |
+
|
5 |
+
## FLUX.1 Training
|
6 |
+
|
7 |
+
#### How much VRAM is required to train a lora on FLUX.1?
|
8 |
+
|
9 |
+
24GB minimum is required.
|
10 |
+
|
LICENSE
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MIT License
|
2 |
+
|
3 |
+
Copyright (c) 2024 Ostris, LLC
|
4 |
+
|
5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
of this software and associated documentation files (the "Software"), to deal
|
7 |
+
in the Software without restriction, including without limitation the rights
|
8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
copies of the Software, and to permit persons to whom the Software is
|
10 |
+
furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
The above copyright notice and this permission notice shall be included in all
|
13 |
+
copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
SOFTWARE.
|
README.md
CHANGED
@@ -1,11 +1,476 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
---
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
9 |
---
|
10 |
|
11 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# AI Toolkit by Ostris
|
2 |
+
|
3 |
+
AI Toolkit is an all in one training suite for diffusion models. I try to support all the latest models on consumer grade hardware. Image and video models. It can be run as a GUI or CLI. It is designed to be easy to use but still have every feature imaginable.
|
4 |
+
|
5 |
+
## Support My Work
|
6 |
+
|
7 |
+
If you enjoy my projects or use them commercially, please consider sponsoring me. Every bit helps! 💖
|
8 |
+
|
9 |
+
[Sponsor on GitHub](https://github.com/orgs/ostris) | [Support on Patreon](https://www.patreon.com/ostris) | [Donate on PayPal](https://www.paypal.com/donate/?hosted_button_id=9GEFUKC8T9R9W)
|
10 |
+
|
11 |
+
### Current Sponsors
|
12 |
+
|
13 |
+
All of these people / organizations are the ones who selflessly make this project possible. Thank you!!
|
14 |
+
|
15 |
+
_Last updated: 2025-08-08 17:01 UTC_
|
16 |
+
|
17 |
+
<p align="center">
|
18 |
+
<a href="https://x.com/NuxZoe" target="_blank" rel="noopener noreferrer"><img src="https://pbs.twimg.com/profile_images/1919488160125616128/QAZXTMEj_400x400.png" alt="a16z" width="200" height="200" style="border-radius:8px;margin:5px;display: inline-block;"></a>
|
19 |
+
<a href="https://github.com/replicate" target="_blank" rel="noopener noreferrer"><img src="https://avatars.githubusercontent.com/u/60410876?v=4" alt="Replicate" width="200" height="200" style="border-radius:8px;margin:5px;display: inline-block;"></a>
|
20 |
+
<a href="https://github.com/huggingface" target="_blank" rel="noopener noreferrer"><img src="https://avatars.githubusercontent.com/u/25720743?v=4" alt="Hugging Face" width="200" height="200" style="border-radius:8px;margin:5px;display: inline-block;"></a>
|
21 |
+
<a href="https://github.com/josephrocca" target="_blank" rel="noopener noreferrer"><img src="https://avatars.githubusercontent.com/u/1167575?u=92d92921b4cb5c8c7e225663fed53c4b41897736&v=4" alt="josephrocca" width="200" height="200" style="border-radius:8px;margin:5px;display: inline-block;"></a>
|
22 |
+
<img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/162524101/81a72689c3754ac5b9e38612ce5ce914/eyJ3IjoyMDB9/1.png?token-hash=JHRjAxd2XxV1aXIUijj-l65pfTnLoefYSvgNPAsw2lI%3D" alt="Prasanth Veerina" width="200" height="200" style="border-radius:8px;margin:5px;display: inline-block;">
|
23 |
+
<a href="https://github.com/weights-ai" target="_blank" rel="noopener noreferrer"><img src="https://avatars.githubusercontent.com/u/185568492?v=4" alt="Weights" width="200" height="200" style="border-radius:8px;margin:5px;display: inline-block;"></a>
|
24 |
+
</p>
|
25 |
+
<hr style="width:100%;border:none;height:2px;background:#ddd;margin:30px 0;">
|
26 |
+
<p align="center">
|
27 |
+
<img src="https://c8.patreon.com/4/200/93304/J" alt="Joseph Rocca" width="150" height="150" style="border-radius:8px;margin:5px;display: inline-block;">
|
28 |
+
<img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/161471720/dd330b4036d44a5985ed5985c12a5def/eyJ3IjoyMDB9/1.jpeg?token-hash=k1f4Vv7TevzYa9tqlzAjsogYmkZs8nrXQohPCDGJGkc%3D" alt="Vladimir Sotnikov" width="150" height="150" style="border-radius:8px;margin:5px;display: inline-block;">
|
29 |
+
<img src="https://c8.patreon.com/4/200/33158543/C" alt="clement Delangue" width="150" height="150" style="border-radius:8px;margin:5px;display: inline-block;">
|
30 |
+
<img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/8654302/b0f5ebedc62a47c4b56222693e1254e9/eyJ3IjoyMDB9/2.jpeg?token-hash=suI7_QjKUgWpdPuJPaIkElkTrXfItHlL8ZHLPT-w_d4%3D" alt="Misch Strotz" width="150" height="150" style="border-radius:8px;margin:5px;display: inline-block;">
|
31 |
+
<img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/120239481/49b1ce70d3d24704b8ec34de24ec8f55/eyJ3IjoyMDB9/1.jpeg?token-hash=o0y1JqSXqtGvVXnxb06HMXjQXs6OII9yMMx5WyyUqT4%3D" alt="nitish PNR" width="150" height="150" style="border-radius:8px;margin:5px;display: inline-block;">
|
32 |
+
</p>
|
33 |
+
<hr style="width:100%;border:none;height:2px;background:#ddd;margin:30px 0;">
|
34 |
+
<p align="center">
|
35 |
+
<img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/2298192/1228b69bd7d7481baf3103315183250d/eyJ3IjoyMDB9/1.jpg?token-hash=opN1e4r4Nnvqbtr8R9HI8eyf9m5F50CiHDOdHzb4UcA%3D" alt="Mohamed Oumoumad" width="100" height="100" style="border-radius:8px;margin:5px;display: inline-block;">
|
36 |
+
<img src="https://c8.patreon.com/4/200/548524/S" alt="Steve Hanff" width="100" height="100" style="border-radius:8px;margin:5px;display: inline-block;">
|
37 |
+
<img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/152118848/3b15a43d71714552b5ed1c9f84e66adf/eyJ3IjoyMDB9/1.png?token-hash=MKf3sWHz0MFPm_OAFjdsNvxoBfN5B5l54mn1ORdlRy8%3D" alt="Kristjan Retter" width="100" height="100" style="border-radius:8px;margin:5px;display: inline-block;">
|
38 |
+
<img src="https://c8.patreon.com/4/200/83319230/M" alt="Miguel Lara" width="100" height="100" style="border-radius:8px;margin:5px;display: inline-block;">
|
39 |
+
<img src="https://c8.patreon.com/4/200/8449560/P" alt="Patron" width="100" height="100" style="border-radius:8px;margin:5px;display: inline-block;">
|
40 |
+
<a href="https://x.com/NuxZoe" target="_blank" rel="noopener noreferrer"><img src="https://pbs.twimg.com/profile_images/1916482710069014528/RDLnPRSg_400x400.jpg" alt="tungsten" width="100" height="100" style="border-radius:8px;margin:5px;display: inline-block;"></a>
|
41 |
+
<img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/169502989/220069e79ce745b29237e94c22a729df/eyJ3IjoyMDB9/1.png?token-hash=E8E2JOqx66k2zMtYUw8Gy57dw-gVqA6OPpdCmWFFSFw%3D" alt="Timothy Bielec" width="100" height="100" style="border-radius:8px;margin:5px;display: inline-block;">
|
42 |
+
<img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/34200989/58ae95ebda0640c8b7a91b4fa31357aa/eyJ3IjoyMDB9/1.jpeg?token-hash=4mVDM1kCYGauYa33zLG14_g0oj9_UjDK_-Qp4zk42GE%3D" alt="Noah Miller" width="100" height="100" style="border-radius:8px;margin:5px;display: inline-block;">
|
43 |
+
<img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/27288932/6c35d2d961ee4e14a7a368c990791315/eyJ3IjoyMDB9/1.jpeg?token-hash=TGIto_PGEG2NEKNyqwzEnRStOkhrjb3QlMhHA3raKJY%3D" alt="David Garrido" width="100" height="100" style="border-radius:8px;margin:5px;display: inline-block;">
|
44 |
+
<a href="https://x.com/RalFingerLP" target="_blank" rel="noopener noreferrer"><img src="https://pbs.twimg.com/profile_images/919595465041162241/ZU7X3T5k_400x400.jpg" alt="RalFinger" width="100" height="100" style="border-radius:8px;margin:5px;display: inline-block;"></a>
|
45 |
+
</p>
|
46 |
+
<hr style="width:100%;border:none;height:2px;background:#ddd;margin:30px 0;">
|
47 |
+
<p align="center">
|
48 |
+
<a href="http://www.ir-ltd.net" target="_blank" rel="noopener noreferrer"><img src="https://pbs.twimg.com/profile_images/1602579392198283264/6Tm2GYus_400x400.jpg" alt="IR-Entertainment Ltd" width="60" height="60" style="border-radius:8px;margin:5px;display: inline-block;"></a>
|
49 |
+
<img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/9547341/bb35d9a222fd460e862e960ba3eacbaf/eyJ3IjoyMDB9/1.jpeg?token-hash=Q2XGDvkCbiONeWNxBCTeTMOcuwTjOaJ8Z-CAf5xq3Hs%3D" alt="Travis Harrington" width="60" height="60" style="border-radius:8px;margin:5px;display: inline-block;">
|
50 |
+
<img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/98811435/3a3632d1795b4c2b9f8f0270f2f6a650/eyJ3IjoyMDB9/1.jpeg?token-hash=657rzuJ0bZavMRZW3XZ-xQGqm3Vk6FkMZgFJVMCOPdk%3D" alt="EmmanuelMr18" width="60" height="60" style="border-radius:8px;margin:5px;display: inline-block;">
|
51 |
+
<img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/81275465/1e4148fe9c47452b838949d02dd9a70f/eyJ3IjoyMDB9/1.jpeg?token-hash=YAX1ucxybpCIujUCXfdwzUQkttIn3c7pfi59uaFPSwM%3D" alt="Aaron Amortegui" width="60" height="60" style="border-radius:8px;margin:5px;display: inline-block;">
|
52 |
+
<img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/155963250/6f8fd7075c3b4247bfeb054ba49172d6/eyJ3IjoyMDB9/1.png?token-hash=z81EHmdU2cqSrwa9vJmZTV3h0LG-z9Qakhxq34FrYT4%3D" alt="Un Defined" width="60" height="60" style="border-radius:8px;margin:5px;display: inline-block;">
|
53 |
+
<img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/45562978/0de33cf52ec642ae8a2f612cddec4ca6/eyJ3IjoyMDB9/1.jpeg?token-hash=aD4debMD5ZQjqTII6s4zYSgVK2-bdQt9p3eipi0bENs%3D" alt="Jack English" width="60" height="60" style="border-radius:8px;margin:5px;display: inline-block;">
|
54 |
+
<img src="https://c8.patreon.com/4/200/27791680/J" alt="Jean-Tristan Marin" width="60" height="60" style="border-radius:8px;margin:5px;display: inline-block;">
|
55 |
+
<img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/570742/4ceb33453a5a4745b430a216aba9280f/eyJ3IjoyMDB9/1.jpg?token-hash=nPcJ2zj3sloND9jvbnbYnob2vMXRnXdRuujthqDLWlU%3D" alt="Al H" width="60" height="60" style="border-radius:8px;margin:5px;display: inline-block;">
|
56 |
+
<img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/82763/f99cc484361d4b9d94fe4f0814ada303/eyJ3IjoyMDB9/1.jpeg?token-hash=A3JWlBNL0b24FFWb-FCRDAyhs-OAxg-zrhfBXP_axuU%3D" alt="Doron Adler" width="60" height="60" style="border-radius:8px;margin:5px;display: inline-block;">
|
57 |
+
<img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/103077711/bb215761cc004e80bd9cec7d4bcd636d/eyJ3IjoyMDB9/2.jpeg?token-hash=3U8kdZSUpnmeYIDVK4zK9TTXFpnAud_zOwBRXx18018%3D" alt="John Dopamine" width="60" height="60" style="border-radius:8px;margin:5px;display: inline-block;">
|
58 |
+
<img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/99036356/7ae9c4d80e604e739b68cca12ee2ed01/eyJ3IjoyMDB9/3.png?token-hash=ZhsBMoTOZjJ-Y6h5NOmU5MT-vDb2fjK46JDlpEehkVQ%3D" alt="Noctre" width="60" height="60" style="border-radius:8px;margin:5px;display: inline-block;">
|
59 |
+
<img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/141098579/1a9f0a1249d447a7a0df718a57343912/eyJ3IjoyMDB9/2.png?token-hash=_n-AQmPgY0FP9zCGTIEsr5ka4Y7YuaMkt3qL26ZqGg8%3D" alt="The Local Lab" width="60" height="60" style="border-radius:8px;margin:5px;display: inline-block;">
|
60 |
+
<img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/93348210/5c650f32a0bc481d80900d2674528777/eyJ3IjoyMDB9/1.jpeg?token-hash=0jiknRw3jXqYWW6En8bNfuHgVDj4LI_rL7lSS4-_xlo%3D" alt="Armin Behjati" width="60" height="60" style="border-radius:8px;margin:5px;display: inline-block;">
|
61 |
+
<img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/134129880/680c7e14cd1a4d1a9face921fb010f88/eyJ3IjoyMDB9/1.png?token-hash=5fqqHE6DCTbt7gDQL7VRcWkV71jF7FvWcLhpYl5aMXA%3D" alt="Bharat Prabhakar" width="60" height="60" style="border-radius:8px;margin:5px;display: inline-block;">
|
62 |
+
<img src="https://c8.patreon.com/4/200/70218846/C" alt="Cosmosis" width="60" height="60" style="border-radius:8px;margin:5px;display: inline-block;">
|
63 |
+
<img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/30931983/54ab4e4ceab946e79a6418d205f9ed51/eyJ3IjoyMDB9/1.png?token-hash=j2phDrgd6IWuqKqNIDbq9fR2B3fMF-GUCQSdETS1w5Y%3D" alt="HestoySeghuro ." width="60" height="60" style="border-radius:8px;margin:5px;display: inline-block;">
|
64 |
+
<img src="https://c8.patreon.com/4/200/4105384/J" alt="Jack Blakely" width="60" height="60" style="border-radius:8px;margin:5px;display: inline-block;">
|
65 |
+
<img src="https://c8.patreon.com/4/200/4541423/S" alt="Sören " width="60" height="60" style="border-radius:8px;margin:5px;display: inline-block;">
|
66 |
+
<a href="https://www.youtube.com/@happyme7055" target="_blank" rel="noopener noreferrer"><img src="https://yt3.googleusercontent.com/ytc/AIdro_mFqhIRk99SoEWY2gvSvVp6u1SkCGMkRqYQ1OlBBeoOVp8=s160-c-k-c0x00ffffff-no-rj" alt="Marcus Rass" width="60" height="60" style="border-radius:8px;margin:5px;display: inline-block;"></a>
|
67 |
+
<img src="https://c8.patreon.com/4/200/53077895/M" alt="Marc" width="60" height="60" style="border-radius:8px;margin:5px;display: inline-block;">
|
68 |
+
<img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/157407541/bb9d80cffdab4334ad78366060561520/eyJ3IjoyMDB9/2.png?token-hash=WYz-U_9zabhHstOT5UIa5jBaoFwrwwqyWxWEzIR2m_c%3D" alt="Tokio Studio srl IT10640050968" width="60" height="60" style="border-radius:8px;margin:5px;display: inline-block;">
|
69 |
+
<img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/44568304/a9d83a0e786b41b4bdada150f7c9271c/eyJ3IjoyMDB9/1.jpeg?token-hash=FtxnwrSrknQUQKvDRv2rqPceX2EF23eLq4pNQYM_fmw%3D" alt="Albert Bukoski" width="60" height="60" style="border-radius:8px;margin:5px;display: inline-block;">
|
70 |
+
<img src="https://c8.patreon.com/4/200/5048649/B" alt="Ben Ward" width="60" height="60" style="border-radius:8px;margin:5px;display: inline-block;">
|
71 |
+
<img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/111904990/08b1cf65be6a4de091c9b73b693b3468/eyJ3IjoyMDB9/1.png?token-hash=_Odz6RD3CxtubEHbUxYujcjw6zAajbo3w8TRz249VBA%3D" alt="Brian Smith" width="60" height="60" style="border-radius:8px;margin:5px;display: inline-block;">
|
72 |
+
<img src="https://c8.patreon.com/4/200/494309/J" alt="Julian Tsependa" width="60" height="60" style="border-radius:8px;margin:5px;display: inline-block;">
|
73 |
+
<img src="https://c8.patreon.com/4/200/5602036/K" alt="Kelevra" width="60" height="60" style="border-radius:8px;margin:5px;display: inline-block;">
|
74 |
+
<img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/159203973/36c817f941ac4fa18103a4b8c0cb9cae/eyJ3IjoyMDB9/1.png?token-hash=zkt72HW3EoiIEAn3LSk9gJPBsXfuTVcc4rRBS3CeR8w%3D" alt="Marko jak" width="60" height="60" style="border-radius:8px;margin:5px;display: inline-block;">
|
75 |
+
<img src="https://c8.patreon.com/4/200/24653779/R" alt="RayHell" width="60" height="60" style="border-radius:8px;margin:5px;display: inline-block;">
|
76 |
+
<img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/76566911/6485eaf5ec6249a7b524ee0b979372f0/eyJ3IjoyMDB9/1.jpeg?token-hash=mwCSkTelDBaengG32NkN0lVl5mRjB-cwo6-a47wnOsU%3D" alt="the biitz" width="60" height="60" style="border-radius:8px;margin:5px;display: inline-block;">
|
77 |
+
<img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/32633822/1ab5612efe80417cbebfe91e871fc052/eyJ3IjoyMDB9/1.png?token-hash=pOS_IU3b3RL5-iL96A3Xqoj2bQ-dDo4RUkBylcMED_s%3D" alt="Zack Abrams" width="60" height="60" style="border-radius:8px;margin:5px;display: inline-block;">
|
78 |
+
<img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/97985240/3d1d0e6905d045aba713e8132cab4a30/eyJ3IjoyMDB9/1.png?token-hash=fRavvbO_yqWKA_OsJb5DzjfKZ1Yt-TG-ihMoeVBvlcM%3D" alt="עומר מכלוף" width="60" height="60" style="border-radius:8px;margin:5px;display: inline-block;">
|
79 |
+
<a href="https://github.com/julien-blanchon" target="_blank" rel="noopener noreferrer"><img src="https://avatars.githubusercontent.com/u/11278197?v=4" alt="Blanchon" width="60" height="60" style="border-radius:8px;margin:5px;display: inline-block;"></a>
|
80 |
+
<img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/11198131/e696d9647feb4318bcf16243c2425805/eyJ3IjoyMDB9/1.jpeg?token-hash=c2c2p1SaiX86iXAigvGRvzm4jDHvIFCg298A49nIfUM%3D" alt="Nicholas Agranoff" width="60" height="60" style="border-radius:8px;margin:5px;display: inline-block;">
|
81 |
+
<img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/785333/bdb9ede5765d42e5a2021a86eebf0d8f/eyJ3IjoyMDB9/2.jpg?token-hash=l_rajMhxTm6wFFPn7YdoKBxeUqhdRXKdy6_8SGCuNsE%3D" alt="Sapjes " width="60" height="60" style="border-radius:8px;margin:5px;display: inline-block;">
|
82 |
+
<img src="https://c8.patreon.com/4/200/2446176/S" alt="Scott VanKirk" width="60" height="60" style="border-radius:8px;margin:5px;display: inline-block;">
|
83 |
+
<img src="https://c8.patreon.com/4/200/83034/W" alt="william tatum" width="60" height="60" style="border-radius:8px;margin:5px;display: inline-block;">
|
84 |
+
<img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/138787189/2b5662dcb638466282ac758e3ac651b4/eyJ3IjoyMDB9/1.png?token-hash=zwj7MScO18vhDxhKt6s5q4gdeNJM3xCLuhSt8zlqlZs%3D" alt="Антон Антонио" width="60" height="60" style="border-radius:8px;margin:5px;display: inline-block;">
|
85 |
+
<img src="https://c8.patreon.com/4/200/30530914/T" alt="Techer " width="60" height="60" style="border-radius:8px;margin:5px;display: inline-block;">
|
86 |
+
<img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/25209707/36ae876d662d4d85aaf162b6d67d31e7/eyJ3IjoyMDB9/1.png?token-hash=Zows_A6uqlY5jClhfr4Y3QfMnDKVkS3mbxNHUDkVejo%3D" alt="fjioq8" width="60" height="60" style="border-radius:8px;margin:5px;display: inline-block;">
|
87 |
+
<img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/46680573/ee3d99c04a674dd5a8e1ecfb926db6a2/eyJ3IjoyMDB9/1.jpeg?token-hash=cgD4EXyfZMPnXIrcqWQ5jGqzRUfqjPafb9yWfZUPB4Q%3D" alt="Neil Murray" width="60" height="60" style="border-radius:8px;margin:5px;display: inline-block;">
|
88 |
+
<img src="https://ostris.com/wp-content/uploads/2025/08/supporter_default.jpg" alt="Joakim Sällström" width="60" height="60" style="border-radius:8px;margin:5px;display: inline-block;">
|
89 |
+
<img src="https://c8.patreon.com/4/200/63510241/A" alt="Andrew Park" width="60" height="60" style="border-radius:8px;margin:5px;display: inline-block;">
|
90 |
+
<a href="https://github.com/Spikhalskiy" target="_blank" rel="noopener noreferrer"><img src="https://avatars.githubusercontent.com/u/532108?u=2464983638afea8caf4cd9f0e4a7bc3e6a63bb0a&v=4" alt="Dmitry Spikhalsky" width="60" height="60" style="border-radius:8px;margin:5px;display: inline-block;"></a>
|
91 |
+
<img src="https://c8.patreon.com/4/200/88567307/E" alt="el Chavo" width="60" height="60" style="border-radius:8px;margin:5px;display: inline-block;">
|
92 |
+
<img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/117569999/55f75c57f95343e58402529cec852b26/eyJ3IjoyMDB9/1.jpeg?token-hash=squblHZH4-eMs3gI46Uqu1oTOK9sQ-0gcsFdZcB9xQg%3D" alt="James Thompson" width="60" height="60" style="border-radius:8px;margin:5px;display: inline-block;">
|
93 |
+
<img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/66157709/6fe70df085e24464995a1a9293a53760/eyJ3IjoyMDB9/1.jpeg?token-hash=eqe0wvg6JfbRUGMKpL_x3YPI5Ppf18aUUJe2EzADU-g%3D" alt="Joey Santana" width="60" height="60" style="border-radius:8px;margin:5px;display: inline-block;">
|
94 |
+
<img src="https://ostris.com/wp-content/uploads/2025/08/supporter_default.jpg" alt="Heikki Rinkinen" width="60" height="60" style="border-radius:8px;margin:5px;display: inline-block;">
|
95 |
+
<img src="https://c8.patreon.com/4/200/6175608/B" alt="Bobbie " width="60" height="60" style="border-radius:8px;margin:5px;display: inline-block;">
|
96 |
+
<a href="https://github.com/Slartibart23" target="_blank" rel="noopener noreferrer"><img src="https://avatars.githubusercontent.com/u/133593860?u=31217adb2522fb295805824ffa7e14e8f0fca6fa&v=4" alt="Slarti" width="60" height="60" style="border-radius:8px;margin:5px;display: inline-block;"></a>
|
97 |
+
<img src="https://ostris.com/wp-content/uploads/2025/08/supporter_default.jpg" alt="Tommy Falkowski" width="60" height="60" style="border-radius:8px;margin:5px;display: inline-block;">
|
98 |
+
<img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/28533016/e8f6044ccfa7483f87eeaa01c894a773/eyJ3IjoyMDB9/2.png?token-hash=ak-h3JWB50hyenCavcs32AAPw6nNhmH2nBFKpdk5hvM%3D" alt="William Tatum" width="60" height="60" style="border-radius:8px;margin:5px;display: inline-block;">
|
99 |
+
<img src="https://ostris.com/wp-content/uploads/2025/08/supporter_default.jpg" alt="Karol Stępień" width="60" height="60" style="border-radius:8px;margin:5px;display: inline-block;">
|
100 |
+
<img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/156564939/17dbfd45c59d4cf29853d710cb0c5d6f/eyJ3IjoyMDB9/1.png?token-hash=e6wXA_S8cgJeEDI9eJK934eB0TiM8mxJm9zW_VH0gDU%3D" alt="Hans Untch" width="60" height="60" style="border-radius:8px;margin:5px;display: inline-block;">
|
101 |
+
<img src="https://c8.patreon.com/4/200/59408413/B" alt="ByteC" width="60" height="60" style="border-radius:8px;margin:5px;display: inline-block;">
|
102 |
+
<img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/3712451/432e22a355494ec0a1ea1927ff8d452e/eyJ3IjoyMDB9/7.jpeg?token-hash=OpQ9SAfVQ4Un9dSYlGTHuApZo5GlJ797Mo0DtVtMOSc%3D" alt="David Shorey" width="60" height="60" style="border-radius:8px;margin:5px;display: inline-block;">
|
103 |
+
<img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/53634141/c1441f6c605344bbaef885d4272977bb/eyJ3IjoyMDB9/1.JPG?token-hash=Aizd6AxQhY3n6TBE5AwCVeSwEBbjALxQmu6xqc08qBo%3D" alt="Jana Spacelight" width="60" height="60" style="border-radius:8px;margin:5px;display: inline-block;">
|
104 |
+
<img src="https://c8.patreon.com/4/200/11180426/J" alt="jarrett towe" width="60" height="60" style="border-radius:8px;margin:5px;display: inline-block;">
|
105 |
+
<img src="https://c8.patreon.com/4/200/21828017/J" alt="Jim" width="60" height="60" style="border-radius:8px;margin:5px;display: inline-block;">
|
106 |
+
<img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/63232055/2300b4ab370341b5b476902c9b8218ee/eyJ3IjoyMDB9/1.png?token-hash=R9Nb4O0aLBRwxT1cGHUMThlvf6A2MD5SO88lpZBdH7M%3D" alt="Marek P" width="60" height="60" style="border-radius:8px;margin:5px;display: inline-block;">
|
107 |
+
<img src="https://c8.patreon.com/4/200/9944625/P" alt="Pomoe " width="60" height="60" style="border-radius:8px;margin:5px;display: inline-block;">
|
108 |
+
<img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/25047900/423e4cb73aba457f8f9c6e5582eddaeb/eyJ3IjoyMDB9/1.jpeg?token-hash=81RvQXBbT66usxqtyWum9Ul4oBn3qHK1cM71IvthC-U%3D" alt="Ruairi Robinson" width="60" height="60" style="border-radius:8px;margin:5px;display: inline-block;">
|
109 |
+
<img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/178476551/0b9e83efcd234df5a6bea30d59e6c1cd/eyJ3IjoyMDB9/1.png?token-hash=3XoYMrMxk-K6GelM22mE-FwkjFulX9hpIL7QI3wO2jI%3D" alt="Timmy" width="60" height="60" style="border-radius:8px;margin:5px;display: inline-block;">
|
110 |
+
<img src="https://c8.patreon.com/4/200/10876902/T" alt="Tyssel" width="60" height="60" style="border-radius:8px;margin:5px;display: inline-block;">
|
111 |
+
<img src="https://ostris.com/wp-content/uploads/2025/08/supporter_default.jpg" alt="Juan Franco" width="60" height="60" style="border-radius:8px;margin:5px;display: inline-block;">
|
112 |
+
</p>
|
113 |
+
|
114 |
---
|
115 |
+
|
116 |
+
|
117 |
+
|
118 |
+
|
119 |
+
## Installation
|
120 |
+
|
121 |
+
Requirements:
|
122 |
+
- python >3.10
|
123 |
+
- Nvidia GPU with enough ram to do what you need
|
124 |
+
- python venv
|
125 |
+
- git
|
126 |
+
|
127 |
+
|
128 |
+
Linux:
|
129 |
+
```bash
|
130 |
+
git clone https://github.com/ostris/ai-toolkit.git
|
131 |
+
cd ai-toolkit
|
132 |
+
python3 -m venv venv
|
133 |
+
source venv/bin/activate
|
134 |
+
# install torch first
|
135 |
+
pip3 install --no-cache-dir torch==2.7.0 torchvision==0.22.0 torchaudio==2.7.0 --index-url https://download.pytorch.org/whl/cu126
|
136 |
+
pip3 install -r requirements.txt
|
137 |
+
```
|
138 |
+
|
139 |
+
Windows:
|
140 |
+
|
141 |
+
If you are having issues with Windows. I recommend using the easy install script at [https://github.com/Tavris1/AI-Toolkit-Easy-Install](https://github.com/Tavris1/AI-Toolkit-Easy-Install)
|
142 |
+
|
143 |
+
```bash
|
144 |
+
git clone https://github.com/ostris/ai-toolkit.git
|
145 |
+
cd ai-toolkit
|
146 |
+
python -m venv venv
|
147 |
+
.\venv\Scripts\activate
|
148 |
+
pip install --no-cache-dir torch==2.7.0 torchvision==0.22.0 torchaudio==2.7.0 --index-url https://download.pytorch.org/whl/cu126
|
149 |
+
pip install -r requirements.txt
|
150 |
+
```
|
151 |
+
|
152 |
+
|
153 |
+
# AI Toolkit UI
|
154 |
+
|
155 |
+
<img src="https://ostris.com/wp-content/uploads/2025/02/toolkit-ui.jpg" alt="AI Toolkit UI" width="100%">
|
156 |
+
|
157 |
+
The AI Toolkit UI is a web interface for the AI Toolkit. It allows you to easily start, stop, and monitor jobs. It also allows you to easily train models with a few clicks. It also allows you to set a token for the UI to prevent unauthorized access so it is mostly safe to run on an exposed server.
|
158 |
+
|
159 |
+
## Running the UI
|
160 |
+
|
161 |
+
Requirements:
|
162 |
+
- Node.js > 18
|
163 |
+
|
164 |
+
The UI does not need to be kept running for the jobs to run. It is only needed to start/stop/monitor jobs. The commands below
|
165 |
+
will install / update the UI and it's dependencies and start the UI.
|
166 |
+
|
167 |
+
```bash
|
168 |
+
cd ui
|
169 |
+
npm run build_and_start
|
170 |
+
```
|
171 |
+
|
172 |
+
You can now access the UI at `http://localhost:8675` or `http://<your-ip>:8675` if you are running it on a server.
|
173 |
+
|
174 |
+
## Securing the UI
|
175 |
+
|
176 |
+
If you are hosting the UI on a cloud provider or any network that is not secure, I highly recommend securing it with an auth token.
|
177 |
+
You can do this by setting the environment variable `AI_TOOLKIT_AUTH` to super secure password. This token will be required to access
|
178 |
+
the UI. You can set this when starting the UI like so:
|
179 |
+
|
180 |
+
```bash
|
181 |
+
# Linux
|
182 |
+
AI_TOOLKIT_AUTH=super_secure_password npm run build_and_start
|
183 |
+
|
184 |
+
# Windows
|
185 |
+
set AI_TOOLKIT_AUTH=super_secure_password && npm run build_and_start
|
186 |
+
|
187 |
+
# Windows Powershell
|
188 |
+
$env:AI_TOOLKIT_AUTH="super_secure_password"; npm run build_and_start
|
189 |
+
```
|
190 |
+
|
191 |
+
|
192 |
+
## FLUX.1 Training
|
193 |
+
|
194 |
+
### Tutorial
|
195 |
+
|
196 |
+
To get started quickly, check out [@araminta_k](https://x.com/araminta_k) tutorial on [Finetuning Flux Dev on a 3090](https://www.youtube.com/watch?v=HzGW_Kyermg) with 24GB VRAM.
|
197 |
+
|
198 |
+
|
199 |
+
### Requirements
|
200 |
+
You currently need a GPU with **at least 24GB of VRAM** to train FLUX.1. If you are using it as your GPU to control
|
201 |
+
your monitors, you probably need to set the flag `low_vram: true` in the config file under `model:`. This will quantize
|
202 |
+
the model on CPU and should allow it to train with monitors attached. Users have gotten it to work on Windows with WSL,
|
203 |
+
but there are some reports of a bug when running on windows natively.
|
204 |
+
I have only tested on linux for now. This is still extremely experimental
|
205 |
+
and a lot of quantizing and tricks had to happen to get it to fit on 24GB at all.
|
206 |
+
|
207 |
+
### FLUX.1-dev
|
208 |
+
|
209 |
+
FLUX.1-dev has a non-commercial license. Which means anything you train will inherit the
|
210 |
+
non-commercial license. It is also a gated model, so you need to accept the license on HF before using it.
|
211 |
+
Otherwise, this will fail. Here are the required steps to setup a license.
|
212 |
+
|
213 |
+
1. Sign into HF and accept the model access here [black-forest-labs/FLUX.1-dev](https://huggingface.co/black-forest-labs/FLUX.1-dev)
|
214 |
+
2. Make a file named `.env` in the root on this folder
|
215 |
+
3. [Get a READ key from huggingface](https://huggingface.co/settings/tokens/new?) and add it to the `.env` file like so `HF_TOKEN=your_key_here`
|
216 |
+
|
217 |
+
### FLUX.1-schnell
|
218 |
+
|
219 |
+
FLUX.1-schnell is Apache 2.0. Anything trained on it can be licensed however you want and it does not require a HF_TOKEN to train.
|
220 |
+
However, it does require a special adapter to train with it, [ostris/FLUX.1-schnell-training-adapter](https://huggingface.co/ostris/FLUX.1-schnell-training-adapter).
|
221 |
+
It is also highly experimental. For best overall quality, training on FLUX.1-dev is recommended.
|
222 |
+
|
223 |
+
To use it, You just need to add the assistant to the `model` section of your config file like so:
|
224 |
+
|
225 |
+
```yaml
|
226 |
+
model:
|
227 |
+
name_or_path: "black-forest-labs/FLUX.1-schnell"
|
228 |
+
assistant_lora_path: "ostris/FLUX.1-schnell-training-adapter"
|
229 |
+
is_flux: true
|
230 |
+
quantize: true
|
231 |
+
```
|
232 |
+
|
233 |
+
You also need to adjust your sample steps since schnell does not require as many
|
234 |
+
|
235 |
+
```yaml
|
236 |
+
sample:
|
237 |
+
guidance_scale: 1 # schnell does not do guidance
|
238 |
+
sample_steps: 4 # 1 - 4 works well
|
239 |
+
```
|
240 |
+
|
241 |
+
### Training
|
242 |
+
1. Copy the example config file located at `config/examples/train_lora_flux_24gb.yaml` (`config/examples/train_lora_flux_schnell_24gb.yaml` for schnell) to the `config` folder and rename it to `whatever_you_want.yml`
|
243 |
+
2. Edit the file following the comments in the file
|
244 |
+
3. Run the file like so `python run.py config/whatever_you_want.yml`
|
245 |
+
|
246 |
+
A folder with the name and the training folder from the config file will be created when you start. It will have all
|
247 |
+
checkpoints and images in it. You can stop the training at any time using ctrl+c and when you resume, it will pick back up
|
248 |
+
from the last checkpoint.
|
249 |
+
|
250 |
+
IMPORTANT. If you press crtl+c while it is saving, it will likely corrupt that checkpoint. So wait until it is done saving
|
251 |
+
|
252 |
+
### Need help?
|
253 |
+
|
254 |
+
Please do not open a bug report unless it is a bug in the code. You are welcome to [Join my Discord](https://discord.gg/VXmU2f5WEU)
|
255 |
+
and ask for help there. However, please refrain from PMing me directly with general question or support. Ask in the discord
|
256 |
+
and I will answer when I can.
|
257 |
+
|
258 |
+
## Gradio UI
|
259 |
+
|
260 |
+
To get started training locally with a with a custom UI, once you followed the steps above and `ai-toolkit` is installed:
|
261 |
+
|
262 |
+
```bash
|
263 |
+
cd ai-toolkit #in case you are not yet in the ai-toolkit folder
|
264 |
+
huggingface-cli login #provide a `write` token to publish your LoRA at the end
|
265 |
+
python flux_train_ui.py
|
266 |
+
```
|
267 |
+
|
268 |
+
You will instantiate a UI that will let you upload your images, caption them, train and publish your LoRA
|
269 |
+

|
270 |
+
|
271 |
+
|
272 |
+
## Training in RunPod
|
273 |
+
Example RunPod template: **runpod/pytorch:2.2.0-py3.10-cuda12.1.1-devel-ubuntu22.04**
|
274 |
+
> You need a minimum of 24GB VRAM, pick a GPU by your preference.
|
275 |
+
|
276 |
+
#### Example config ($0.5/hr):
|
277 |
+
- 1x A40 (48 GB VRAM)
|
278 |
+
- 19 vCPU 100 GB RAM
|
279 |
+
|
280 |
+
#### Custom overrides (you need some storage to clone FLUX.1, store datasets, store trained models and samples):
|
281 |
+
- ~120 GB Disk
|
282 |
+
- ~120 GB Pod Volume
|
283 |
+
- Start Jupyter Notebook
|
284 |
+
|
285 |
+
### 1. Setup
|
286 |
+
```
|
287 |
+
git clone https://github.com/ostris/ai-toolkit.git
|
288 |
+
cd ai-toolkit
|
289 |
+
git submodule update --init --recursive
|
290 |
+
python -m venv venv
|
291 |
+
source venv/bin/activate
|
292 |
+
pip install torch
|
293 |
+
pip install -r requirements.txt
|
294 |
+
pip install --upgrade accelerate transformers diffusers huggingface_hub #Optional, run it if you run into issues
|
295 |
+
```
|
296 |
+
### 2. Upload your dataset
|
297 |
+
- Create a new folder in the root, name it `dataset` or whatever you like.
|
298 |
+
- Drag and drop your .jpg, .jpeg, or .png images and .txt files inside the newly created dataset folder.
|
299 |
+
|
300 |
+
### 3. Login into Hugging Face with an Access Token
|
301 |
+
- Get a READ token from [here](https://huggingface.co/settings/tokens) and request access to Flux.1-dev model from [here](https://huggingface.co/black-forest-labs/FLUX.1-dev).
|
302 |
+
- Run ```huggingface-cli login``` and paste your token.
|
303 |
+
|
304 |
+
### 4. Training
|
305 |
+
- Copy an example config file located at ```config/examples``` to the config folder and rename it to ```whatever_you_want.yml```.
|
306 |
+
- Edit the config following the comments in the file.
|
307 |
+
- Change ```folder_path: "/path/to/images/folder"``` to your dataset path like ```folder_path: "/workspace/ai-toolkit/your-dataset"```.
|
308 |
+
- Run the file: ```python run.py config/whatever_you_want.yml```.
|
309 |
+
|
310 |
+
### Screenshot from RunPod
|
311 |
+
<img width="1728" alt="RunPod Training Screenshot" src="https://github.com/user-attachments/assets/53a1b8ef-92fa-4481-81a7-bde45a14a7b5">
|
312 |
+
|
313 |
+
## Training in Modal
|
314 |
+
|
315 |
+
### 1. Setup
|
316 |
+
#### ai-toolkit:
|
317 |
+
```
|
318 |
+
git clone https://github.com/ostris/ai-toolkit.git
|
319 |
+
cd ai-toolkit
|
320 |
+
git submodule update --init --recursive
|
321 |
+
python -m venv venv
|
322 |
+
source venv/bin/activate
|
323 |
+
pip install torch
|
324 |
+
pip install -r requirements.txt
|
325 |
+
pip install --upgrade accelerate transformers diffusers huggingface_hub #Optional, run it if you run into issues
|
326 |
+
```
|
327 |
+
#### Modal:
|
328 |
+
- Run `pip install modal` to install the modal Python package.
|
329 |
+
- Run `modal setup` to authenticate (if this doesn’t work, try `python -m modal setup`).
|
330 |
+
|
331 |
+
#### Hugging Face:
|
332 |
+
- Get a READ token from [here](https://huggingface.co/settings/tokens) and request access to Flux.1-dev model from [here](https://huggingface.co/black-forest-labs/FLUX.1-dev).
|
333 |
+
- Run `huggingface-cli login` and paste your token.
|
334 |
+
|
335 |
+
### 2. Upload your dataset
|
336 |
+
- Drag and drop your dataset folder containing the .jpg, .jpeg, or .png images and .txt files in `ai-toolkit`.
|
337 |
+
|
338 |
+
### 3. Configs
|
339 |
+
- Copy an example config file located at ```config/examples/modal``` to the `config` folder and rename it to ```whatever_you_want.yml```.
|
340 |
+
- Edit the config following the comments in the file, **<ins>be careful and follow the example `/root/ai-toolkit` paths</ins>**.
|
341 |
+
|
342 |
+
### 4. Edit run_modal.py
|
343 |
+
- Set your entire local `ai-toolkit` path at `code_mount = modal.Mount.from_local_dir` like:
|
344 |
+
|
345 |
+
```
|
346 |
+
code_mount = modal.Mount.from_local_dir("/Users/username/ai-toolkit", remote_path="/root/ai-toolkit")
|
347 |
+
```
|
348 |
+
- Choose a `GPU` and `Timeout` in `@app.function` _(default is A100 40GB and 2 hour timeout)_.
|
349 |
+
|
350 |
+
### 5. Training
|
351 |
+
- Run the config file in your terminal: `modal run run_modal.py --config-file-list-str=/root/ai-toolkit/config/whatever_you_want.yml`.
|
352 |
+
- You can monitor your training in your local terminal, or on [modal.com](https://modal.com/).
|
353 |
+
- Models, samples and optimizer will be stored in `Storage > flux-lora-models`.
|
354 |
+
|
355 |
+
### 6. Saving the model
|
356 |
+
- Check contents of the volume by running `modal volume ls flux-lora-models`.
|
357 |
+
- Download the content by running `modal volume get flux-lora-models your-model-name`.
|
358 |
+
- Example: `modal volume get flux-lora-models my_first_flux_lora_v1`.
|
359 |
+
|
360 |
+
### Screenshot from Modal
|
361 |
+
|
362 |
+
<img width="1728" alt="Modal Traning Screenshot" src="https://github.com/user-attachments/assets/7497eb38-0090-49d6-8ad9-9c8ea7b5388b">
|
363 |
+
|
364 |
---
|
365 |
|
366 |
+
## Dataset Preparation
|
367 |
+
|
368 |
+
Datasets generally need to be a folder containing images and associated text files. Currently, the only supported
|
369 |
+
formats are jpg, jpeg, and png. Webp currently has issues. The text files should be named the same as the images
|
370 |
+
but with a `.txt` extension. For example `image2.jpg` and `image2.txt`. The text file should contain only the caption.
|
371 |
+
You can add the word `[trigger]` in the caption file and if you have `trigger_word` in your config, it will be automatically
|
372 |
+
replaced.
|
373 |
+
|
374 |
+
Images are never upscaled but they are downscaled and placed in buckets for batching. **You do not need to crop/resize your images**.
|
375 |
+
The loader will automatically resize them and can handle varying aspect ratios.
|
376 |
+
|
377 |
+
|
378 |
+
## Training Specific Layers
|
379 |
+
|
380 |
+
To train specific layers with LoRA, you can use the `only_if_contains` network kwargs. For instance, if you want to train only the 2 layers
|
381 |
+
used by The Last Ben, [mentioned in this post](https://x.com/__TheBen/status/1829554120270987740), you can adjust your
|
382 |
+
network kwargs like so:
|
383 |
+
|
384 |
+
```yaml
|
385 |
+
network:
|
386 |
+
type: "lora"
|
387 |
+
linear: 128
|
388 |
+
linear_alpha: 128
|
389 |
+
network_kwargs:
|
390 |
+
only_if_contains:
|
391 |
+
- "transformer.single_transformer_blocks.7.proj_out"
|
392 |
+
- "transformer.single_transformer_blocks.20.proj_out"
|
393 |
+
```
|
394 |
+
|
395 |
+
The naming conventions of the layers are in diffusers format, so checking the state dict of a model will reveal
|
396 |
+
the suffix of the name of the layers you want to train. You can also use this method to only train specific groups of weights.
|
397 |
+
For instance to only train the `single_transformer` for FLUX.1, you can use the following:
|
398 |
+
|
399 |
+
```yaml
|
400 |
+
network:
|
401 |
+
type: "lora"
|
402 |
+
linear: 128
|
403 |
+
linear_alpha: 128
|
404 |
+
network_kwargs:
|
405 |
+
only_if_contains:
|
406 |
+
- "transformer.single_transformer_blocks."
|
407 |
+
```
|
408 |
+
|
409 |
+
You can also exclude layers by their names by using `ignore_if_contains` network kwarg. So to exclude all the single transformer blocks,
|
410 |
+
|
411 |
+
|
412 |
+
```yaml
|
413 |
+
network:
|
414 |
+
type: "lora"
|
415 |
+
linear: 128
|
416 |
+
linear_alpha: 128
|
417 |
+
network_kwargs:
|
418 |
+
ignore_if_contains:
|
419 |
+
- "transformer.single_transformer_blocks."
|
420 |
+
```
|
421 |
+
|
422 |
+
`ignore_if_contains` takes priority over `only_if_contains`. So if a weight is covered by both,
|
423 |
+
if will be ignored.
|
424 |
+
|
425 |
+
## LoKr Training
|
426 |
+
|
427 |
+
To learn more about LoKr, read more about it at [KohakuBlueleaf/LyCORIS](https://github.com/KohakuBlueleaf/LyCORIS/blob/main/docs/Guidelines.md). To train a LoKr model, you can adjust the network type in the config file like so:
|
428 |
+
|
429 |
+
```yaml
|
430 |
+
network:
|
431 |
+
type: "lokr"
|
432 |
+
lokr_full_rank: true
|
433 |
+
lokr_factor: 8
|
434 |
+
```
|
435 |
+
|
436 |
+
Everything else should work the same including layer targeting.
|
437 |
+
|
438 |
+
|
439 |
+
## Updates
|
440 |
+
|
441 |
+
Only larger updates are listed here. There are usually smaller daily updated that are omitted.
|
442 |
+
|
443 |
+
### Jul 17, 2025
|
444 |
+
- Make it easy to add control images to the samples in the ui
|
445 |
+
|
446 |
+
### Jul 11, 2025
|
447 |
+
- Added better video config settings to the UI for video models.
|
448 |
+
- Added Wan I2V training to the UI
|
449 |
+
|
450 |
+
### June 29, 2025
|
451 |
+
- Fixed issue where Kontext forced sizes on sampling
|
452 |
+
|
453 |
+
### June 26, 2025
|
454 |
+
- Added support for FLUX.1 Kontext training
|
455 |
+
- added support for instruction dataset training
|
456 |
+
|
457 |
+
### June 25, 2025
|
458 |
+
- Added support for OmniGen2 training
|
459 |
+
-
|
460 |
+
### June 17, 2025
|
461 |
+
- Performance optimizations for batch preparation
|
462 |
+
- Added some docs via a popup for items in the simple ui explaining what settings do. Still a WIP
|
463 |
+
|
464 |
+
### June 16, 2025
|
465 |
+
- Hide control images in the UI when viewing datasets
|
466 |
+
- WIP on mean flow loss
|
467 |
+
|
468 |
+
### June 12, 2025
|
469 |
+
- Fixed issue that resulted in blank captions in the dataloader
|
470 |
+
|
471 |
+
### June 10, 2025
|
472 |
+
- Decided to keep track up updates in the readme
|
473 |
+
- Added support for SDXL in the UI
|
474 |
+
- Added support for SD 1.5 in the UI
|
475 |
+
- Fixed UI Wan 2.1 14b name bug
|
476 |
+
- Added support for for conv training in the UI for models that support it
|
assets/VAE_test1.jpg
ADDED
![]() |
Git LFS Details
|
assets/glif.svg
ADDED
|
assets/lora_ease_ui.png
ADDED
![]() |
Git LFS Details
|
build_and_push_docker
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env bash
|
2 |
+
|
3 |
+
# Extract version from version.py
|
4 |
+
if [ -f "version.py" ]; then
|
5 |
+
VERSION=$(python3 -c "from version import VERSION; print(VERSION)")
|
6 |
+
echo "Building version: $VERSION"
|
7 |
+
else
|
8 |
+
echo "Error: version.py not found. Please create a version.py file with VERSION defined."
|
9 |
+
exit 1
|
10 |
+
fi
|
11 |
+
|
12 |
+
echo "Docker builds from the repo, not this dir. Make sure changes are pushed to the repo."
|
13 |
+
echo "Building version: $VERSION and latest"
|
14 |
+
# wait 2 seconds
|
15 |
+
sleep 2
|
16 |
+
|
17 |
+
# Build the image with cache busting
|
18 |
+
docker build --build-arg CACHEBUST=$(date +%s) -t aitoolkit:$VERSION -f docker/Dockerfile .
|
19 |
+
|
20 |
+
# Tag with version and latest
|
21 |
+
docker tag aitoolkit:$VERSION ostris/aitoolkit:$VERSION
|
22 |
+
docker tag aitoolkit:$VERSION ostris/aitoolkit:latest
|
23 |
+
|
24 |
+
# Push both tags
|
25 |
+
echo "Pushing images to Docker Hub..."
|
26 |
+
docker push ostris/aitoolkit:$VERSION
|
27 |
+
docker push ostris/aitoolkit:latest
|
28 |
+
|
29 |
+
echo "Successfully built and pushed ostris/aitoolkit:$VERSION and ostris/aitoolkit:latest"
|
build_and_push_docker_dev
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env bash
|
2 |
+
|
3 |
+
VERSION=dev
|
4 |
+
GIT_COMMIT=dev
|
5 |
+
|
6 |
+
echo "Docker builds from the repo, not this dir. Make sure changes are pushed to the repo."
|
7 |
+
echo "Building version: $VERSION"
|
8 |
+
# wait 2 seconds
|
9 |
+
sleep 2
|
10 |
+
|
11 |
+
# Build the image with cache busting
|
12 |
+
docker build --build-arg CACHEBUST=$(date +%s) -t aitoolkit:$VERSION -f docker/Dockerfile .
|
13 |
+
|
14 |
+
# Tag with version and latest
|
15 |
+
docker tag aitoolkit:$VERSION ostris/aitoolkit:$VERSION
|
16 |
+
|
17 |
+
# Push both tags
|
18 |
+
echo "Pushing images to Docker Hub..."
|
19 |
+
docker push ostris/aitoolkit:$VERSION
|
20 |
+
|
21 |
+
echo "Successfully built and pushed ostris/aitoolkit:$VERSION"
|
config/examples/extract.example.yml
ADDED
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
# this is in yaml format. You can use json if you prefer
|
3 |
+
# I like both but yaml is easier to read and write
|
4 |
+
# plus it has comments which is nice for documentation
|
5 |
+
job: extract # tells the runner what to do
|
6 |
+
config:
|
7 |
+
# the name will be used to create a folder in the output folder
|
8 |
+
# it will also replace any [name] token in the rest of this config
|
9 |
+
name: name_of_your_model
|
10 |
+
# can be hugging face model, a .ckpt, or a .safetensors
|
11 |
+
base_model: "/path/to/base/model.safetensors"
|
12 |
+
# can be hugging face model, a .ckpt, or a .safetensors
|
13 |
+
extract_model: "/path/to/model/to/extract/trained.safetensors"
|
14 |
+
# we will create folder here with name above so. This will create /path/to/output/folder/name_of_your_model
|
15 |
+
output_folder: "/path/to/output/folder"
|
16 |
+
is_v2: false
|
17 |
+
dtype: fp16 # saved dtype
|
18 |
+
device: cpu # cpu, cuda:0, etc
|
19 |
+
|
20 |
+
# processes can be chained like this to run multiple in a row
|
21 |
+
# they must all use same models above, but great for testing different
|
22 |
+
# sizes and typed of extractions. It is much faster as we already have the models loaded
|
23 |
+
process:
|
24 |
+
# process 1
|
25 |
+
- type: locon # locon or lora (locon is lycoris)
|
26 |
+
filename: "[name]_64_32.safetensors" # will be put in output folder
|
27 |
+
dtype: fp16
|
28 |
+
mode: fixed
|
29 |
+
linear: 64
|
30 |
+
conv: 32
|
31 |
+
|
32 |
+
# process 2
|
33 |
+
- type: locon
|
34 |
+
output_path: "/absolute/path/for/this/output.safetensors" # can be absolute
|
35 |
+
mode: ratio
|
36 |
+
linear: 0.2
|
37 |
+
conv: 0.2
|
38 |
+
|
39 |
+
# process 3
|
40 |
+
- type: locon
|
41 |
+
filename: "[name]_ratio_02.safetensors"
|
42 |
+
mode: quantile
|
43 |
+
linear: 0.5
|
44 |
+
conv: 0.5
|
45 |
+
|
46 |
+
# process 4
|
47 |
+
- type: lora # traditional lora extraction (lierla) with linear layers only
|
48 |
+
filename: "[name]_4.safetensors"
|
49 |
+
mode: fixed # fixed, ratio, quantile supported for lora as well
|
50 |
+
linear: 4 # lora dim or rank
|
51 |
+
# no conv for lora
|
52 |
+
|
53 |
+
# process 5
|
54 |
+
- type: lora
|
55 |
+
filename: "[name]_q05.safetensors"
|
56 |
+
mode: quantile
|
57 |
+
linear: 0.5
|
58 |
+
|
59 |
+
# you can put any information you want here, and it will be saved in the model
|
60 |
+
# the below is an example. I recommend doing trigger words at a minimum
|
61 |
+
# in the metadata. The software will include this plus some other information
|
62 |
+
meta:
|
63 |
+
name: "[name]" # [name] gets replaced with the name above
|
64 |
+
description: A short description of your model
|
65 |
+
trigger_words:
|
66 |
+
- put
|
67 |
+
- trigger
|
68 |
+
- words
|
69 |
+
- here
|
70 |
+
version: '0.1'
|
71 |
+
creator:
|
72 |
+
name: Your Name
|
73 |
+
email: [email protected]
|
74 |
+
website: https://yourwebsite.com
|
75 |
+
any: All meta data above is arbitrary, it can be whatever you want.
|
config/examples/generate.example.yaml
ADDED
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
|
3 |
+
job: generate # tells the runner what to do
|
4 |
+
config:
|
5 |
+
name: "generate" # this is not really used anywhere currently but required by runner
|
6 |
+
process:
|
7 |
+
# process 1
|
8 |
+
- type: to_folder # process images to a folder
|
9 |
+
output_folder: "output/gen"
|
10 |
+
device: cuda:0 # cpu, cuda:0, etc
|
11 |
+
generate:
|
12 |
+
# these are your defaults you can override most of them with flags
|
13 |
+
sampler: "ddpm" # ignored for now, will add later though ddpm is used regardless for now
|
14 |
+
width: 1024
|
15 |
+
height: 1024
|
16 |
+
neg: "cartoon, fake, drawing, illustration, cgi, animated, anime"
|
17 |
+
seed: -1 # -1 is random
|
18 |
+
guidance_scale: 7
|
19 |
+
sample_steps: 20
|
20 |
+
ext: ".png" # .png, .jpg, .jpeg, .webp
|
21 |
+
|
22 |
+
# here ate the flags you can use for prompts. Always start with
|
23 |
+
# your prompt first then add these flags after. You can use as many
|
24 |
+
# like
|
25 |
+
# photo of a baseball --n painting, ugly --w 1024 --h 1024 --seed 42 --cfg 7 --steps 20
|
26 |
+
# we will try to support all sd-scripts flags where we can
|
27 |
+
|
28 |
+
# FROM SD-SCRIPTS
|
29 |
+
# --n Treat everything until the next option as a negative prompt.
|
30 |
+
# --w Specify the width of the generated image.
|
31 |
+
# --h Specify the height of the generated image.
|
32 |
+
# --d Specify the seed for the generated image.
|
33 |
+
# --l Specify the CFG scale for the generated image.
|
34 |
+
# --s Specify the number of steps during generation.
|
35 |
+
|
36 |
+
# OURS and some QOL additions
|
37 |
+
# --p2 Prompt for the second text encoder (SDXL only)
|
38 |
+
# --n2 Negative prompt for the second text encoder (SDXL only)
|
39 |
+
# --gr Specify the guidance rescale for the generated image (SDXL only)
|
40 |
+
# --seed Specify the seed for the generated image same as --d
|
41 |
+
# --cfg Specify the CFG scale for the generated image same as --l
|
42 |
+
# --steps Specify the number of steps during generation same as --s
|
43 |
+
|
44 |
+
prompt_file: false # if true a txt file will be created next to images with prompt strings used
|
45 |
+
# prompts can also be a path to a text file with one prompt per line
|
46 |
+
# prompts: "/path/to/prompts.txt"
|
47 |
+
prompts:
|
48 |
+
- "photo of batman"
|
49 |
+
- "photo of superman"
|
50 |
+
- "photo of spiderman"
|
51 |
+
- "photo of a superhero --n batman superman spiderman"
|
52 |
+
|
53 |
+
model:
|
54 |
+
# huggingface name, relative prom project path, or absolute path to .safetensors or .ckpt
|
55 |
+
# name_or_path: "runwayml/stable-diffusion-v1-5"
|
56 |
+
name_or_path: "/mnt/Models/stable-diffusion/models/stable-diffusion/Ostris/Ostris_Real_v1.safetensors"
|
57 |
+
is_v2: false # for v2 models
|
58 |
+
is_v_pred: false # for v-prediction models (most v2 models)
|
59 |
+
is_xl: false # for SDXL models
|
60 |
+
dtype: bf16
|
config/examples/mod_lora_scale.yaml
ADDED
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
job: mod
|
3 |
+
config:
|
4 |
+
name: name_of_your_model_v1
|
5 |
+
process:
|
6 |
+
- type: rescale_lora
|
7 |
+
# path to your current lora model
|
8 |
+
input_path: "/path/to/lora/lora.safetensors"
|
9 |
+
# output path for your new lora model, can be the same as input_path to replace
|
10 |
+
output_path: "/path/to/lora/output_lora_v1.safetensors"
|
11 |
+
# replaces meta with the meta below (plus minimum meta fields)
|
12 |
+
# if false, we will leave the meta alone except for updating hashes (sd-script hashes)
|
13 |
+
replace_meta: true
|
14 |
+
# how to adjust, we can scale the up_down weights or the alpha
|
15 |
+
# up_down is the default and probably the best, they will both net the same outputs
|
16 |
+
# would only affect rare NaN cases and maybe merging with old merge tools
|
17 |
+
scale_target: 'up_down'
|
18 |
+
# precision to save, fp16 is the default and standard
|
19 |
+
save_dtype: fp16
|
20 |
+
# current_weight is the ideal weight you use as a multiplier when using the lora
|
21 |
+
# IE in automatic1111 <lora:my_lora:6.0> the 6.0 is the current_weight
|
22 |
+
# you can do negatives here too if you want to flip the lora
|
23 |
+
current_weight: 6.0
|
24 |
+
# target_weight is the ideal weight you use as a multiplier when using the lora
|
25 |
+
# instead of the one above. IE in automatic1111 instead of using <lora:my_lora:6.0>
|
26 |
+
# we want to use <lora:my_lora:1.0> so 1.0 is the target_weight
|
27 |
+
target_weight: 1.0
|
28 |
+
|
29 |
+
# base model for the lora
|
30 |
+
# this is just used to add meta so automatic111 knows which model it is for
|
31 |
+
# assume v1.5 if these are not set
|
32 |
+
is_xl: false
|
33 |
+
is_v2: false
|
34 |
+
meta:
|
35 |
+
# this is only used if you set replace_meta to true above
|
36 |
+
name: "[name]" # [name] gets replaced with the name above
|
37 |
+
description: A short description of your lora
|
38 |
+
trigger_words:
|
39 |
+
- put
|
40 |
+
- trigger
|
41 |
+
- words
|
42 |
+
- here
|
43 |
+
version: '0.1'
|
44 |
+
creator:
|
45 |
+
name: Your Name
|
46 |
+
email: [email protected]
|
47 |
+
website: https://yourwebsite.com
|
48 |
+
any: All meta data above is arbitrary, it can be whatever you want.
|
config/examples/modal/modal_train_lora_flux_24gb.yaml
ADDED
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
job: extension
|
3 |
+
config:
|
4 |
+
# this name will be the folder and filename name
|
5 |
+
name: "my_first_flux_lora_v1"
|
6 |
+
process:
|
7 |
+
- type: 'sd_trainer'
|
8 |
+
# root folder to save training sessions/samples/weights
|
9 |
+
training_folder: "/root/ai-toolkit/modal_output" # must match MOUNT_DIR from run_modal.py
|
10 |
+
# uncomment to see performance stats in the terminal every N steps
|
11 |
+
# performance_log_every: 1000
|
12 |
+
device: cuda:0
|
13 |
+
# if a trigger word is specified, it will be added to captions of training data if it does not already exist
|
14 |
+
# alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word
|
15 |
+
# trigger_word: "p3r5on"
|
16 |
+
network:
|
17 |
+
type: "lora"
|
18 |
+
linear: 16
|
19 |
+
linear_alpha: 16
|
20 |
+
save:
|
21 |
+
dtype: float16 # precision to save
|
22 |
+
save_every: 250 # save every this many steps
|
23 |
+
max_step_saves_to_keep: 4 # how many intermittent saves to keep
|
24 |
+
datasets:
|
25 |
+
# datasets are a folder of images. captions need to be txt files with the same name as the image
|
26 |
+
# for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently
|
27 |
+
# images will automatically be resized and bucketed into the resolution specified
|
28 |
+
# on windows, escape back slashes with another backslash so
|
29 |
+
# "C:\\path\\to\\images\\folder"
|
30 |
+
# your dataset must be placed in /ai-toolkit and /root is for modal to find the dir:
|
31 |
+
- folder_path: "/root/ai-toolkit/your-dataset"
|
32 |
+
caption_ext: "txt"
|
33 |
+
caption_dropout_rate: 0.05 # will drop out the caption 5% of time
|
34 |
+
shuffle_tokens: false # shuffle caption order, split by commas
|
35 |
+
cache_latents_to_disk: true # leave this true unless you know what you're doing
|
36 |
+
resolution: [ 512, 768, 1024 ] # flux enjoys multiple resolutions
|
37 |
+
train:
|
38 |
+
batch_size: 1
|
39 |
+
steps: 2000 # total number of steps to train 500 - 4000 is a good range
|
40 |
+
gradient_accumulation_steps: 1
|
41 |
+
train_unet: true
|
42 |
+
train_text_encoder: false # probably won't work with flux
|
43 |
+
gradient_checkpointing: true # need the on unless you have a ton of vram
|
44 |
+
noise_scheduler: "flowmatch" # for training only
|
45 |
+
optimizer: "adamw8bit"
|
46 |
+
lr: 1e-4
|
47 |
+
# uncomment this to skip the pre training sample
|
48 |
+
# skip_first_sample: true
|
49 |
+
# uncomment to completely disable sampling
|
50 |
+
# disable_sampling: true
|
51 |
+
# uncomment to use new vell curved weighting. Experimental but may produce better results
|
52 |
+
# linear_timesteps: true
|
53 |
+
|
54 |
+
# ema will smooth out learning, but could slow it down. Recommended to leave on.
|
55 |
+
ema_config:
|
56 |
+
use_ema: true
|
57 |
+
ema_decay: 0.99
|
58 |
+
|
59 |
+
# will probably need this if gpu supports it for flux, other dtypes may not work correctly
|
60 |
+
dtype: bf16
|
61 |
+
model:
|
62 |
+
# huggingface model name or path
|
63 |
+
# if you get an error, or get stuck while downloading,
|
64 |
+
# check https://github.com/ostris/ai-toolkit/issues/84, download the model locally and
|
65 |
+
# place it like "/root/ai-toolkit/FLUX.1-dev"
|
66 |
+
name_or_path: "black-forest-labs/FLUX.1-dev"
|
67 |
+
is_flux: true
|
68 |
+
quantize: true # run 8bit mixed precision
|
69 |
+
# low_vram: true # uncomment this if the GPU is connected to your monitors. It will use less vram to quantize, but is slower.
|
70 |
+
sample:
|
71 |
+
sampler: "flowmatch" # must match train.noise_scheduler
|
72 |
+
sample_every: 250 # sample every this many steps
|
73 |
+
width: 1024
|
74 |
+
height: 1024
|
75 |
+
prompts:
|
76 |
+
# you can add [trigger] to the prompts here and it will be replaced with the trigger word
|
77 |
+
# - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"\
|
78 |
+
- "woman with red hair, playing chess at the park, bomb going off in the background"
|
79 |
+
- "a woman holding a coffee cup, in a beanie, sitting at a cafe"
|
80 |
+
- "a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini"
|
81 |
+
- "a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background"
|
82 |
+
- "a bear building a log cabin in the snow covered mountains"
|
83 |
+
- "woman playing the guitar, on stage, singing a song, laser lights, punk rocker"
|
84 |
+
- "hipster man with a beard, building a chair, in a wood shop"
|
85 |
+
- "photo of a man, white background, medium shot, modeling clothing, studio lighting, white backdrop"
|
86 |
+
- "a man holding a sign that says, 'this is a sign'"
|
87 |
+
- "a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle"
|
88 |
+
neg: "" # not used on flux
|
89 |
+
seed: 42
|
90 |
+
walk_seed: true
|
91 |
+
guidance_scale: 4
|
92 |
+
sample_steps: 20
|
93 |
+
# you can add any additional meta info here. [name] is replaced with config name at top
|
94 |
+
meta:
|
95 |
+
name: "[name]"
|
96 |
+
version: '1.0'
|
config/examples/modal/modal_train_lora_flux_schnell_24gb.yaml
ADDED
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
job: extension
|
3 |
+
config:
|
4 |
+
# this name will be the folder and filename name
|
5 |
+
name: "my_first_flux_lora_v1"
|
6 |
+
process:
|
7 |
+
- type: 'sd_trainer'
|
8 |
+
# root folder to save training sessions/samples/weights
|
9 |
+
training_folder: "/root/ai-toolkit/modal_output" # must match MOUNT_DIR from run_modal.py
|
10 |
+
# uncomment to see performance stats in the terminal every N steps
|
11 |
+
# performance_log_every: 1000
|
12 |
+
device: cuda:0
|
13 |
+
# if a trigger word is specified, it will be added to captions of training data if it does not already exist
|
14 |
+
# alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word
|
15 |
+
# trigger_word: "p3r5on"
|
16 |
+
network:
|
17 |
+
type: "lora"
|
18 |
+
linear: 16
|
19 |
+
linear_alpha: 16
|
20 |
+
save:
|
21 |
+
dtype: float16 # precision to save
|
22 |
+
save_every: 250 # save every this many steps
|
23 |
+
max_step_saves_to_keep: 4 # how many intermittent saves to keep
|
24 |
+
datasets:
|
25 |
+
# datasets are a folder of images. captions need to be txt files with the same name as the image
|
26 |
+
# for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently
|
27 |
+
# images will automatically be resized and bucketed into the resolution specified
|
28 |
+
# on windows, escape back slashes with another backslash so
|
29 |
+
# "C:\\path\\to\\images\\folder"
|
30 |
+
# your dataset must be placed in /ai-toolkit and /root is for modal to find the dir:
|
31 |
+
- folder_path: "/root/ai-toolkit/your-dataset"
|
32 |
+
caption_ext: "txt"
|
33 |
+
caption_dropout_rate: 0.05 # will drop out the caption 5% of time
|
34 |
+
shuffle_tokens: false # shuffle caption order, split by commas
|
35 |
+
cache_latents_to_disk: true # leave this true unless you know what you're doing
|
36 |
+
resolution: [ 512, 768, 1024 ] # flux enjoys multiple resolutions
|
37 |
+
train:
|
38 |
+
batch_size: 1
|
39 |
+
steps: 2000 # total number of steps to train 500 - 4000 is a good range
|
40 |
+
gradient_accumulation_steps: 1
|
41 |
+
train_unet: true
|
42 |
+
train_text_encoder: false # probably won't work with flux
|
43 |
+
gradient_checkpointing: true # need the on unless you have a ton of vram
|
44 |
+
noise_scheduler: "flowmatch" # for training only
|
45 |
+
optimizer: "adamw8bit"
|
46 |
+
lr: 1e-4
|
47 |
+
# uncomment this to skip the pre training sample
|
48 |
+
# skip_first_sample: true
|
49 |
+
# uncomment to completely disable sampling
|
50 |
+
# disable_sampling: true
|
51 |
+
# uncomment to use new vell curved weighting. Experimental but may produce better results
|
52 |
+
# linear_timesteps: true
|
53 |
+
|
54 |
+
# ema will smooth out learning, but could slow it down. Recommended to leave on.
|
55 |
+
ema_config:
|
56 |
+
use_ema: true
|
57 |
+
ema_decay: 0.99
|
58 |
+
|
59 |
+
# will probably need this if gpu supports it for flux, other dtypes may not work correctly
|
60 |
+
dtype: bf16
|
61 |
+
model:
|
62 |
+
# huggingface model name or path
|
63 |
+
# if you get an error, or get stuck while downloading,
|
64 |
+
# check https://github.com/ostris/ai-toolkit/issues/84, download the models locally and
|
65 |
+
# place them like "/root/ai-toolkit/FLUX.1-schnell" and "/root/ai-toolkit/FLUX.1-schnell-training-adapter"
|
66 |
+
name_or_path: "black-forest-labs/FLUX.1-schnell"
|
67 |
+
assistant_lora_path: "ostris/FLUX.1-schnell-training-adapter" # Required for flux schnell training
|
68 |
+
is_flux: true
|
69 |
+
quantize: true # run 8bit mixed precision
|
70 |
+
# low_vram is painfully slow to fuse in the adapter avoid it unless absolutely necessary
|
71 |
+
# low_vram: true # uncomment this if the GPU is connected to your monitors. It will use less vram to quantize, but is slower.
|
72 |
+
sample:
|
73 |
+
sampler: "flowmatch" # must match train.noise_scheduler
|
74 |
+
sample_every: 250 # sample every this many steps
|
75 |
+
width: 1024
|
76 |
+
height: 1024
|
77 |
+
prompts:
|
78 |
+
# you can add [trigger] to the prompts here and it will be replaced with the trigger word
|
79 |
+
# - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"\
|
80 |
+
- "woman with red hair, playing chess at the park, bomb going off in the background"
|
81 |
+
- "a woman holding a coffee cup, in a beanie, sitting at a cafe"
|
82 |
+
- "a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini"
|
83 |
+
- "a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background"
|
84 |
+
- "a bear building a log cabin in the snow covered mountains"
|
85 |
+
- "woman playing the guitar, on stage, singing a song, laser lights, punk rocker"
|
86 |
+
- "hipster man with a beard, building a chair, in a wood shop"
|
87 |
+
- "photo of a man, white background, medium shot, modeling clothing, studio lighting, white backdrop"
|
88 |
+
- "a man holding a sign that says, 'this is a sign'"
|
89 |
+
- "a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle"
|
90 |
+
neg: "" # not used on flux
|
91 |
+
seed: 42
|
92 |
+
walk_seed: true
|
93 |
+
guidance_scale: 1 # schnell does not do guidance
|
94 |
+
sample_steps: 4 # 1 - 4 works well
|
95 |
+
# you can add any additional meta info here. [name] is replaced with config name at top
|
96 |
+
meta:
|
97 |
+
name: "[name]"
|
98 |
+
version: '1.0'
|
config/examples/train_flex_redux.yaml
ADDED
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
job: extension
|
3 |
+
config:
|
4 |
+
# this name will be the folder and filename name
|
5 |
+
name: "my_first_flex_redux_finetune_v1"
|
6 |
+
process:
|
7 |
+
- type: 'sd_trainer'
|
8 |
+
# root folder to save training sessions/samples/weights
|
9 |
+
training_folder: "output"
|
10 |
+
# uncomment to see performance stats in the terminal every N steps
|
11 |
+
# performance_log_every: 1000
|
12 |
+
device: cuda:0
|
13 |
+
adapter:
|
14 |
+
type: "redux"
|
15 |
+
# you can finetune an existing adapter or start from scratch. Set to null to start from scratch
|
16 |
+
name_or_path: '/local/path/to/redux_adapter_to_finetune.safetensors'
|
17 |
+
# name_or_path: null
|
18 |
+
# image_encoder_path: 'google/siglip-so400m-patch14-384' # Flux.1 redux adapter
|
19 |
+
image_encoder_path: 'google/siglip2-so400m-patch16-512' # Flex.1 512 redux adapter
|
20 |
+
# image_encoder_arch: 'siglip' # for Flux.1
|
21 |
+
image_encoder_arch: 'siglip2'
|
22 |
+
# You need a control input for each sample. Best to do squares for both images
|
23 |
+
test_img_path:
|
24 |
+
- "/path/to/x_01.jpg"
|
25 |
+
- "/path/to/x_02.jpg"
|
26 |
+
- "/path/to/x_03.jpg"
|
27 |
+
- "/path/to/x_04.jpg"
|
28 |
+
- "/path/to/x_05.jpg"
|
29 |
+
- "/path/to/x_06.jpg"
|
30 |
+
- "/path/to/x_07.jpg"
|
31 |
+
- "/path/to/x_08.jpg"
|
32 |
+
- "/path/to/x_09.jpg"
|
33 |
+
- "/path/to/x_10.jpg"
|
34 |
+
clip_layer: 'last_hidden_state'
|
35 |
+
train: true
|
36 |
+
save:
|
37 |
+
dtype: bf16 # precision to save
|
38 |
+
save_every: 250 # save every this many steps
|
39 |
+
max_step_saves_to_keep: 4
|
40 |
+
datasets:
|
41 |
+
# datasets are a folder of images. captions need to be txt files with the same name as the image
|
42 |
+
# for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently
|
43 |
+
# images will automatically be resized and bucketed into the resolution specified
|
44 |
+
# on windows, escape back slashes with another backslash so
|
45 |
+
# "C:\\path\\to\\images\\folder"
|
46 |
+
- folder_path: "/path/to/images/folder"
|
47 |
+
# clip_image_path is directory containting your control images. They must have filename as their train image. (extension does not matter)
|
48 |
+
# for normal redux, we are just recreating the same image, so you can use the same folder path above
|
49 |
+
clip_image_path: "/path/to/control/images/folder"
|
50 |
+
caption_ext: "txt"
|
51 |
+
caption_dropout_rate: 0.05 # will drop out the caption 5% of time
|
52 |
+
resolution: [ 512, 768, 1024 ] # flex enjoys multiple resolutions
|
53 |
+
train:
|
54 |
+
# this is what I used for the 24GB card, but feel free to adjust
|
55 |
+
# total batch size is 6 here
|
56 |
+
batch_size: 3
|
57 |
+
gradient_accumulation: 2
|
58 |
+
|
59 |
+
# captions are not needed for this training, we cache a blank proompt and rely on the vision encoder
|
60 |
+
unload_text_encoder: true
|
61 |
+
|
62 |
+
loss_type: "mse"
|
63 |
+
train_unet: true
|
64 |
+
train_text_encoder: false
|
65 |
+
steps: 4000000 # I set this very high and stop when I like the results
|
66 |
+
content_or_style: balanced # content, style, balanced
|
67 |
+
gradient_checkpointing: true
|
68 |
+
noise_scheduler: "flowmatch" # or "ddpm", "lms", "euler_a"
|
69 |
+
timestep_type: "flux_shift"
|
70 |
+
optimizer: "adamw8bit"
|
71 |
+
lr: 1e-4
|
72 |
+
|
73 |
+
# this is for Flex.1, comment this out for FLUX.1-dev
|
74 |
+
bypass_guidance_embedding: true
|
75 |
+
|
76 |
+
dtype: bf16
|
77 |
+
ema_config:
|
78 |
+
use_ema: true
|
79 |
+
ema_decay: 0.99
|
80 |
+
model:
|
81 |
+
name_or_path: "ostris/Flex.1-alpha"
|
82 |
+
is_flux: true
|
83 |
+
quantize: true
|
84 |
+
text_encoder_bits: 8
|
85 |
+
sample:
|
86 |
+
sampler: "flowmatch" # must match train.noise_scheduler
|
87 |
+
sample_every: 250 # sample every this many steps
|
88 |
+
width: 1024
|
89 |
+
height: 1024
|
90 |
+
# I leave half blank to test prompt and unprompted
|
91 |
+
prompts:
|
92 |
+
- "woman with red hair, playing chess at the park, bomb going off in the background"
|
93 |
+
- "a woman holding a coffee cup, in a beanie, sitting at a cafe"
|
94 |
+
- "a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini"
|
95 |
+
- "a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background"
|
96 |
+
- "a bear building a log cabin in the snow covered mountains"
|
97 |
+
- ""
|
98 |
+
- ""
|
99 |
+
- ""
|
100 |
+
- ""
|
101 |
+
- ""
|
102 |
+
neg: ""
|
103 |
+
seed: 42
|
104 |
+
walk_seed: true
|
105 |
+
guidance_scale: 4
|
106 |
+
sample_steps: 25
|
107 |
+
network_multiplier: 1.0
|
108 |
+
|
109 |
+
# you can add any additional meta info here. [name] is replaced with config name at top
|
110 |
+
meta:
|
111 |
+
name: "[name]"
|
112 |
+
version: '1.0'
|
config/examples/train_full_fine_tune_flex.yaml
ADDED
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
# This configuration requires 48GB of VRAM or more to operate
|
3 |
+
job: extension
|
4 |
+
config:
|
5 |
+
# this name will be the folder and filename name
|
6 |
+
name: "my_first_flex_finetune_v1"
|
7 |
+
process:
|
8 |
+
- type: 'sd_trainer'
|
9 |
+
# root folder to save training sessions/samples/weights
|
10 |
+
training_folder: "output"
|
11 |
+
# uncomment to see performance stats in the terminal every N steps
|
12 |
+
# performance_log_every: 1000
|
13 |
+
device: cuda:0
|
14 |
+
# if a trigger word is specified, it will be added to captions of training data if it does not already exist
|
15 |
+
# alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word
|
16 |
+
# trigger_word: "p3r5on"
|
17 |
+
save:
|
18 |
+
dtype: bf16 # precision to save
|
19 |
+
save_every: 250 # save every this many steps
|
20 |
+
max_step_saves_to_keep: 2 # how many intermittent saves to keep
|
21 |
+
save_format: 'diffusers' # 'diffusers'
|
22 |
+
datasets:
|
23 |
+
# datasets are a folder of images. captions need to be txt files with the same name as the image
|
24 |
+
# for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently
|
25 |
+
# images will automatically be resized and bucketed into the resolution specified
|
26 |
+
# on windows, escape back slashes with another backslash so
|
27 |
+
# "C:\\path\\to\\images\\folder"
|
28 |
+
- folder_path: "/path/to/images/folder"
|
29 |
+
caption_ext: "txt"
|
30 |
+
caption_dropout_rate: 0.05 # will drop out the caption 5% of time
|
31 |
+
shuffle_tokens: false # shuffle caption order, split by commas
|
32 |
+
# cache_latents_to_disk: true # leave this true unless you know what you're doing
|
33 |
+
resolution: [ 512, 768, 1024 ] # flex enjoys multiple resolutions
|
34 |
+
train:
|
35 |
+
batch_size: 1
|
36 |
+
# IMPORTANT! For Flex, you must bypass the guidance embedder during training
|
37 |
+
bypass_guidance_embedding: true
|
38 |
+
|
39 |
+
# can be 'sigmoid', 'linear', or 'lognorm_blend'
|
40 |
+
timestep_type: 'sigmoid'
|
41 |
+
|
42 |
+
steps: 2000 # total number of steps to train 500 - 4000 is a good range
|
43 |
+
gradient_accumulation: 1
|
44 |
+
train_unet: true
|
45 |
+
train_text_encoder: false # probably won't work with flex
|
46 |
+
gradient_checkpointing: true # need the on unless you have a ton of vram
|
47 |
+
noise_scheduler: "flowmatch" # for training only
|
48 |
+
optimizer: "adafactor"
|
49 |
+
lr: 3e-5
|
50 |
+
|
51 |
+
# Paramiter swapping can reduce vram requirements. Set factor from 1.0 to 0.0.
|
52 |
+
# 0.1 is 10% of paramiters active at easc step. Only works with adafactor
|
53 |
+
|
54 |
+
# do_paramiter_swapping: true
|
55 |
+
# paramiter_swapping_factor: 0.9
|
56 |
+
|
57 |
+
# uncomment this to skip the pre training sample
|
58 |
+
# skip_first_sample: true
|
59 |
+
# uncomment to completely disable sampling
|
60 |
+
# disable_sampling: true
|
61 |
+
|
62 |
+
# ema will smooth out learning, but could slow it down. Recommended to leave on if you have the vram
|
63 |
+
ema_config:
|
64 |
+
use_ema: true
|
65 |
+
ema_decay: 0.99
|
66 |
+
|
67 |
+
# will probably need this if gpu supports it for flex, other dtypes may not work correctly
|
68 |
+
dtype: bf16
|
69 |
+
model:
|
70 |
+
# huggingface model name or path
|
71 |
+
name_or_path: "ostris/Flex.1-alpha"
|
72 |
+
is_flux: true # flex is flux architecture
|
73 |
+
# full finetuning quantized models is a crapshoot and results in subpar outputs
|
74 |
+
# quantize: true
|
75 |
+
# you can quantize just the T5 text encoder here to save vram
|
76 |
+
quantize_te: true
|
77 |
+
# only train the transformer blocks
|
78 |
+
only_if_contains:
|
79 |
+
- "transformer.transformer_blocks."
|
80 |
+
- "transformer.single_transformer_blocks."
|
81 |
+
sample:
|
82 |
+
sampler: "flowmatch" # must match train.noise_scheduler
|
83 |
+
sample_every: 250 # sample every this many steps
|
84 |
+
width: 1024
|
85 |
+
height: 1024
|
86 |
+
prompts:
|
87 |
+
# you can add [trigger] to the prompts here and it will be replaced with the trigger word
|
88 |
+
# - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"\
|
89 |
+
- "woman with red hair, playing chess at the park, bomb going off in the background"
|
90 |
+
- "a woman holding a coffee cup, in a beanie, sitting at a cafe"
|
91 |
+
- "a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini"
|
92 |
+
- "a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background"
|
93 |
+
- "a bear building a log cabin in the snow covered mountains"
|
94 |
+
- "woman playing the guitar, on stage, singing a song, laser lights, punk rocker"
|
95 |
+
- "hipster man with a beard, building a chair, in a wood shop"
|
96 |
+
- "photo of a man, white background, medium shot, modeling clothing, studio lighting, white backdrop"
|
97 |
+
- "a man holding a sign that says, 'this is a sign'"
|
98 |
+
- "a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle"
|
99 |
+
neg: "" # not used on flex
|
100 |
+
seed: 42
|
101 |
+
walk_seed: true
|
102 |
+
guidance_scale: 4
|
103 |
+
sample_steps: 25
|
104 |
+
# you can add any additional meta info here. [name] is replaced with config name at top
|
105 |
+
meta:
|
106 |
+
name: "[name]"
|
107 |
+
version: '1.0'
|
config/examples/train_full_fine_tune_lumina.yaml
ADDED
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
# This configuration requires 24GB of VRAM or more to operate
|
3 |
+
job: extension
|
4 |
+
config:
|
5 |
+
# this name will be the folder and filename name
|
6 |
+
name: "my_first_lumina_finetune_v1"
|
7 |
+
process:
|
8 |
+
- type: 'sd_trainer'
|
9 |
+
# root folder to save training sessions/samples/weights
|
10 |
+
training_folder: "output"
|
11 |
+
# uncomment to see performance stats in the terminal every N steps
|
12 |
+
# performance_log_every: 1000
|
13 |
+
device: cuda:0
|
14 |
+
# if a trigger word is specified, it will be added to captions of training data if it does not already exist
|
15 |
+
# alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word
|
16 |
+
# trigger_word: "p3r5on"
|
17 |
+
save:
|
18 |
+
dtype: bf16 # precision to save
|
19 |
+
save_every: 250 # save every this many steps
|
20 |
+
max_step_saves_to_keep: 2 # how many intermittent saves to keep
|
21 |
+
save_format: 'diffusers' # 'diffusers'
|
22 |
+
datasets:
|
23 |
+
# datasets are a folder of images. captions need to be txt files with the same name as the image
|
24 |
+
# for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently
|
25 |
+
# images will automatically be resized and bucketed into the resolution specified
|
26 |
+
# on windows, escape back slashes with another backslash so
|
27 |
+
# "C:\\path\\to\\images\\folder"
|
28 |
+
- folder_path: "/path/to/images/folder"
|
29 |
+
caption_ext: "txt"
|
30 |
+
caption_dropout_rate: 0.05 # will drop out the caption 5% of time
|
31 |
+
shuffle_tokens: false # shuffle caption order, split by commas
|
32 |
+
# cache_latents_to_disk: true # leave this true unless you know what you're doing
|
33 |
+
resolution: [ 512, 768, 1024 ] # lumina2 enjoys multiple resolutions
|
34 |
+
train:
|
35 |
+
batch_size: 1
|
36 |
+
|
37 |
+
# can be 'sigmoid', 'linear', or 'lumina2_shift'
|
38 |
+
timestep_type: 'lumina2_shift'
|
39 |
+
|
40 |
+
steps: 2000 # total number of steps to train 500 - 4000 is a good range
|
41 |
+
gradient_accumulation: 1
|
42 |
+
train_unet: true
|
43 |
+
train_text_encoder: false # probably won't work with lumina2
|
44 |
+
gradient_checkpointing: true # need the on unless you have a ton of vram
|
45 |
+
noise_scheduler: "flowmatch" # for training only
|
46 |
+
optimizer: "adafactor"
|
47 |
+
lr: 3e-5
|
48 |
+
|
49 |
+
# Paramiter swapping can reduce vram requirements. Set factor from 1.0 to 0.0.
|
50 |
+
# 0.1 is 10% of paramiters active at easc step. Only works with adafactor
|
51 |
+
|
52 |
+
# do_paramiter_swapping: true
|
53 |
+
# paramiter_swapping_factor: 0.9
|
54 |
+
|
55 |
+
# uncomment this to skip the pre training sample
|
56 |
+
# skip_first_sample: true
|
57 |
+
# uncomment to completely disable sampling
|
58 |
+
# disable_sampling: true
|
59 |
+
|
60 |
+
# ema will smooth out learning, but could slow it down. Recommended to leave on if you have the vram
|
61 |
+
# ema_config:
|
62 |
+
# use_ema: true
|
63 |
+
# ema_decay: 0.99
|
64 |
+
|
65 |
+
# will probably need this if gpu supports it for lumina2, other dtypes may not work correctly
|
66 |
+
dtype: bf16
|
67 |
+
model:
|
68 |
+
# huggingface model name or path
|
69 |
+
name_or_path: "Alpha-VLLM/Lumina-Image-2.0"
|
70 |
+
is_lumina2: true # lumina2 architecture
|
71 |
+
# you can quantize just the Gemma2 text encoder here to save vram
|
72 |
+
quantize_te: true
|
73 |
+
sample:
|
74 |
+
sampler: "flowmatch" # must match train.noise_scheduler
|
75 |
+
sample_every: 250 # sample every this many steps
|
76 |
+
width: 1024
|
77 |
+
height: 1024
|
78 |
+
prompts:
|
79 |
+
# you can add [trigger] to the prompts here and it will be replaced with the trigger word
|
80 |
+
# - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"\
|
81 |
+
- "woman with red hair, playing chess at the park, bomb going off in the background"
|
82 |
+
- "a woman holding a coffee cup, in a beanie, sitting at a cafe"
|
83 |
+
- "a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini"
|
84 |
+
- "a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background"
|
85 |
+
- "a bear building a log cabin in the snow covered mountains"
|
86 |
+
- "woman playing the guitar, on stage, singing a song, laser lights, punk rocker"
|
87 |
+
- "hipster man with a beard, building a chair, in a wood shop"
|
88 |
+
- "photo of a cat that is half black and half orange tabby, split down the middle. The cat has on a blue tophat. They are holding a martini glass with a pink ball of yarn in it with green knitting needles sticking out, in one paw. In the other paw, they are holding a DVD case for a movie titled, \"This is a test\" that has a golden robot on it. In the background is a busy night club with a giant mushroom man dancing with a bear."
|
89 |
+
- "a man holding a sign that says, 'this is a sign'"
|
90 |
+
- "a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle"
|
91 |
+
neg: ""
|
92 |
+
seed: 42
|
93 |
+
walk_seed: true
|
94 |
+
guidance_scale: 4.0
|
95 |
+
sample_steps: 25
|
96 |
+
# you can add any additional meta info here. [name] is replaced with config name at top
|
97 |
+
meta:
|
98 |
+
name: "[name]"
|
99 |
+
version: '1.0'
|
config/examples/train_lora_chroma_24gb.yaml
ADDED
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
job: extension
|
3 |
+
config:
|
4 |
+
# this name will be the folder and filename name
|
5 |
+
name: "my_first_chroma_lora_v1"
|
6 |
+
process:
|
7 |
+
- type: 'sd_trainer'
|
8 |
+
# root folder to save training sessions/samples/weights
|
9 |
+
training_folder: "output"
|
10 |
+
# uncomment to see performance stats in the terminal every N steps
|
11 |
+
# performance_log_every: 1000
|
12 |
+
device: cuda:0
|
13 |
+
# if a trigger word is specified, it will be added to captions of training data if it does not already exist
|
14 |
+
# alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word
|
15 |
+
# trigger_word: "p3r5on"
|
16 |
+
network:
|
17 |
+
type: "lora"
|
18 |
+
linear: 16
|
19 |
+
linear_alpha: 16
|
20 |
+
save:
|
21 |
+
dtype: float16 # precision to save
|
22 |
+
save_every: 250 # save every this many steps
|
23 |
+
max_step_saves_to_keep: 4 # how many intermittent saves to keep
|
24 |
+
push_to_hub: false #change this to True to push your trained model to Hugging Face.
|
25 |
+
# You can either set up a HF_TOKEN env variable or you'll be prompted to log-in
|
26 |
+
# hf_repo_id: your-username/your-model-slug
|
27 |
+
# hf_private: true #whether the repo is private or public
|
28 |
+
datasets:
|
29 |
+
# datasets are a folder of images. captions need to be txt files with the same name as the image
|
30 |
+
# for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently
|
31 |
+
# images will automatically be resized and bucketed into the resolution specified
|
32 |
+
# on windows, escape back slashes with another backslash so
|
33 |
+
# "C:\\path\\to\\images\\folder"
|
34 |
+
- folder_path: "/path/to/images/folder"
|
35 |
+
caption_ext: "txt"
|
36 |
+
caption_dropout_rate: 0.05 # will drop out the caption 5% of time
|
37 |
+
shuffle_tokens: false # shuffle caption order, split by commas
|
38 |
+
cache_latents_to_disk: true # leave this true unless you know what you're doing
|
39 |
+
resolution: [ 512, 768, 1024 ] # chroma enjoys multiple resolutions
|
40 |
+
train:
|
41 |
+
batch_size: 1
|
42 |
+
steps: 2000 # total number of steps to train 500 - 4000 is a good range
|
43 |
+
gradient_accumulation: 1
|
44 |
+
train_unet: true
|
45 |
+
train_text_encoder: false # probably won't work with chroma
|
46 |
+
gradient_checkpointing: true # need the on unless you have a ton of vram
|
47 |
+
noise_scheduler: "flowmatch" # for training only
|
48 |
+
optimizer: "adamw8bit"
|
49 |
+
lr: 1e-4
|
50 |
+
# uncomment this to skip the pre training sample
|
51 |
+
# skip_first_sample: true
|
52 |
+
# uncomment to completely disable sampling
|
53 |
+
# disable_sampling: true
|
54 |
+
# uncomment to use new vell curved weighting. Experimental but may produce better results
|
55 |
+
# linear_timesteps: true
|
56 |
+
|
57 |
+
# ema will smooth out learning, but could slow it down. Recommended to leave on.
|
58 |
+
ema_config:
|
59 |
+
use_ema: true
|
60 |
+
ema_decay: 0.99
|
61 |
+
|
62 |
+
# will probably need this if gpu supports it for chroma, other dtypes may not work correctly
|
63 |
+
dtype: bf16
|
64 |
+
model:
|
65 |
+
# Download the whichever model you prefer from the Chroma repo
|
66 |
+
# https://huggingface.co/lodestones/Chroma/tree/main
|
67 |
+
# point to it here.
|
68 |
+
# name_or_path: "/path/to/chroma/chroma-unlocked-vVERSION.safetensors"
|
69 |
+
|
70 |
+
# using lodestones/Chroma will automatically use the latest version
|
71 |
+
name_or_path: "lodestones/Chroma"
|
72 |
+
|
73 |
+
# # You can also select a version of Chroma like so
|
74 |
+
# name_or_path: "lodestones/Chroma/v28"
|
75 |
+
|
76 |
+
arch: "chroma"
|
77 |
+
quantize: true # run 8bit mixed precision
|
78 |
+
sample:
|
79 |
+
sampler: "flowmatch" # must match train.noise_scheduler
|
80 |
+
sample_every: 250 # sample every this many steps
|
81 |
+
width: 1024
|
82 |
+
height: 1024
|
83 |
+
prompts:
|
84 |
+
# you can add [trigger] to the prompts here and it will be replaced with the trigger word
|
85 |
+
# - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"\
|
86 |
+
- "woman with red hair, playing chess at the park, bomb going off in the background"
|
87 |
+
- "a woman holding a coffee cup, in a beanie, sitting at a cafe"
|
88 |
+
- "a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini"
|
89 |
+
- "a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background"
|
90 |
+
- "a bear building a log cabin in the snow covered mountains"
|
91 |
+
- "woman playing the guitar, on stage, singing a song, laser lights, punk rocker"
|
92 |
+
- "hipster man with a beard, building a chair, in a wood shop"
|
93 |
+
- "photo of a man, white background, medium shot, modeling clothing, studio lighting, white backdrop"
|
94 |
+
- "a man holding a sign that says, 'this is a sign'"
|
95 |
+
- "a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle"
|
96 |
+
neg: "" # negative prompt, optional
|
97 |
+
seed: 42
|
98 |
+
walk_seed: true
|
99 |
+
guidance_scale: 4
|
100 |
+
sample_steps: 25
|
101 |
+
# you can add any additional meta info here. [name] is replaced with config name at top
|
102 |
+
meta:
|
103 |
+
name: "[name]"
|
104 |
+
version: '1.0'
|
config/examples/train_lora_flex2_24gb.yaml
ADDED
@@ -0,0 +1,165 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Note, Flex2 is a highly experimental WIP model. Finetuning a model with built in controls and inpainting has not
|
2 |
+
# been done before, so you will be experimenting with me on how to do it. This is my recommended setup, but this is highly
|
3 |
+
# subject to change as we learn more about how Flex2 works.
|
4 |
+
|
5 |
+
---
|
6 |
+
job: extension
|
7 |
+
config:
|
8 |
+
# this name will be the folder and filename name
|
9 |
+
name: "my_first_flex2_lora_v1"
|
10 |
+
process:
|
11 |
+
- type: 'sd_trainer'
|
12 |
+
# root folder to save training sessions/samples/weights
|
13 |
+
training_folder: "output"
|
14 |
+
# uncomment to see performance stats in the terminal every N steps
|
15 |
+
# performance_log_every: 1000
|
16 |
+
device: cuda:0
|
17 |
+
# if a trigger word is specified, it will be added to captions of training data if it does not already exist
|
18 |
+
# alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word
|
19 |
+
# trigger_word: "p3r5on"
|
20 |
+
network:
|
21 |
+
type: "lora"
|
22 |
+
linear: 32
|
23 |
+
linear_alpha: 32
|
24 |
+
save:
|
25 |
+
dtype: float16 # precision to save
|
26 |
+
save_every: 250 # save every this many steps
|
27 |
+
max_step_saves_to_keep: 4 # how many intermittent saves to keep
|
28 |
+
push_to_hub: false #change this to True to push your trained model to Hugging Face.
|
29 |
+
# You can either set up a HF_TOKEN env variable or you'll be prompted to log-in
|
30 |
+
# hf_repo_id: your-username/your-model-slug
|
31 |
+
# hf_private: true #whether the repo is private or public
|
32 |
+
datasets:
|
33 |
+
# datasets are a folder of images. captions need to be txt files with the same name as the image
|
34 |
+
# for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently
|
35 |
+
# images will automatically be resized and bucketed into the resolution specified
|
36 |
+
# on windows, escape back slashes with another backslash so
|
37 |
+
# "C:\\path\\to\\images\\folder"
|
38 |
+
- folder_path: "/path/to/images/folder"
|
39 |
+
# Flex2 is trained with controls and inpainting. If you want the model to truely understand how the
|
40 |
+
# controls function with your dataset, it is a good idea to keep doing controls during training.
|
41 |
+
# this will automatically generate the controls for you before training. The current script is not
|
42 |
+
# fully optimized so this could be rather slow for large datasets, but it caches them to disk so it
|
43 |
+
# only needs to be done once. If you want to skip this step, you can set the controls to [] and it will
|
44 |
+
controls:
|
45 |
+
- "depth"
|
46 |
+
- "line"
|
47 |
+
- "pose"
|
48 |
+
- "inpaint"
|
49 |
+
|
50 |
+
# you can make custom inpainting images as well. These images must be webp or png format with an alpha.
|
51 |
+
# just erase the part of the image you want to inpaint and save it as a webp or png. Again, erase your
|
52 |
+
# train target. So the person if training a person. The automatic controls above with inpaint will
|
53 |
+
# just run a background remover mask and erase the foreground, which works well for subjects.
|
54 |
+
|
55 |
+
# inpaint_path: "/my/impaint/images"
|
56 |
+
|
57 |
+
# you can also specify existing control image pairs. It can handle multiple groups and will randomly
|
58 |
+
# select one for each step.
|
59 |
+
|
60 |
+
# control_path:
|
61 |
+
# - "/my/custom/control/images"
|
62 |
+
# - "/my/custom/control/images2"
|
63 |
+
|
64 |
+
caption_ext: "txt"
|
65 |
+
caption_dropout_rate: 0.05 # will drop out the caption 5% of time
|
66 |
+
resolution: [ 512, 768, 1024 ] # flex2 enjoys multiple resolutions
|
67 |
+
train:
|
68 |
+
batch_size: 1
|
69 |
+
# IMPORTANT! For Flex2, you must bypass the guidance embedder during training
|
70 |
+
bypass_guidance_embedding: true
|
71 |
+
|
72 |
+
steps: 3000 # total number of steps to train 500 - 4000 is a good range
|
73 |
+
gradient_accumulation: 1
|
74 |
+
train_unet: true
|
75 |
+
train_text_encoder: false # probably won't work with flex2
|
76 |
+
gradient_checkpointing: true # need the on unless you have a ton of vram
|
77 |
+
noise_scheduler: "flowmatch" # for training only
|
78 |
+
# shift works well for training fast and learning composition and style.
|
79 |
+
# for just subject, you may want to change this to sigmoid
|
80 |
+
timestep_type: 'shift' # 'linear', 'sigmoid', 'shift'
|
81 |
+
optimizer: "adamw8bit"
|
82 |
+
lr: 1e-4
|
83 |
+
|
84 |
+
optimizer_params:
|
85 |
+
weight_decay: 1e-5
|
86 |
+
# uncomment this to skip the pre training sample
|
87 |
+
# skip_first_sample: true
|
88 |
+
# uncomment to completely disable sampling
|
89 |
+
# disable_sampling: true
|
90 |
+
# uncomment to use new vell curved weighting. Experimental but may produce better results
|
91 |
+
# linear_timesteps: true
|
92 |
+
|
93 |
+
# ema will smooth out learning, but could slow it down. Defaults off
|
94 |
+
ema_config:
|
95 |
+
use_ema: false
|
96 |
+
ema_decay: 0.99
|
97 |
+
|
98 |
+
# will probably need this if gpu supports it for flex, other dtypes may not work correctly
|
99 |
+
dtype: bf16
|
100 |
+
model:
|
101 |
+
# huggingface model name or path
|
102 |
+
name_or_path: "ostris/Flex.2-preview"
|
103 |
+
arch: "flex2"
|
104 |
+
quantize: true # run 8bit mixed precision
|
105 |
+
quantize_te: true
|
106 |
+
|
107 |
+
# you can pass special training infor for controls to the model here
|
108 |
+
# percentages are decimal based so 0.0 is 0% and 1.0 is 100% of the time.
|
109 |
+
model_kwargs:
|
110 |
+
# inverts the inpainting mask, good to learn outpainting as well, recommended 0.0 for characters
|
111 |
+
invert_inpaint_mask_chance: 0.5
|
112 |
+
# this will do a normal t2i training step without inpaint when dropped out. REcommended if you want
|
113 |
+
# your lora to be able to inference with and without inpainting.
|
114 |
+
inpaint_dropout: 0.5
|
115 |
+
# randomly drops out the control image. Dropout recvommended if your want it to work without controls as well.
|
116 |
+
control_dropout: 0.5
|
117 |
+
# does a random inpaint blob. Usually a good idea to keep. Without it, the model will learn to always 100%
|
118 |
+
# fill the inpaint area with your subject. This is not always a good thing.
|
119 |
+
inpaint_random_chance: 0.5
|
120 |
+
# generates random inpaint blobs if you did not provide an inpaint image for your dataset. Inpaint breaks down fast
|
121 |
+
# if you are not training with it. Controls are a little more robust and can be left out,
|
122 |
+
# but when in doubt, always leave this on
|
123 |
+
do_random_inpainting: false
|
124 |
+
# does random blurring of the inpaint mask. Helps prevent weird edge artifacts for real workd inpainting. Leave on.
|
125 |
+
random_blur_mask: true
|
126 |
+
# applies a small amount of random dialition and restriction to the inpaint mask. Helps with edge artifacts.
|
127 |
+
# Leave on.
|
128 |
+
random_dialate_mask: true
|
129 |
+
sample:
|
130 |
+
sampler: "flowmatch" # must match train.noise_scheduler
|
131 |
+
sample_every: 250 # sample every this many steps
|
132 |
+
width: 1024
|
133 |
+
height: 1024
|
134 |
+
prompts:
|
135 |
+
# you can add [trigger] to the prompts here and it will be replaced with the trigger word
|
136 |
+
# - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"\
|
137 |
+
|
138 |
+
# you can use a single inpaint or single control image on your samples.
|
139 |
+
# for controls, the ctrl_idx is 1, the images can be any name and image format.
|
140 |
+
# use either a pose/line/depth image or whatever you are training with. An example is
|
141 |
+
# - "photo of [trigger] --ctrl_idx 1 --ctrl_img /path/to/control/image.jpg"
|
142 |
+
|
143 |
+
# for an inpainting image, it must be png/webp. Erase the part of the image you want to inpaint
|
144 |
+
# IMPORTANT! the inpaint images must be ctrl_idx 0 and have .inpaint.{ext} in the name for this to work right.
|
145 |
+
# - "photo of [trigger] --ctrl_idx 0 --ctrl_img /path/to/inpaint/image.inpaint.png"
|
146 |
+
|
147 |
+
- "woman with red hair, playing chess at the park, bomb going off in the background"
|
148 |
+
- "a woman holding a coffee cup, in a beanie, sitting at a cafe"
|
149 |
+
- "a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini"
|
150 |
+
- "a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background"
|
151 |
+
- "a bear building a log cabin in the snow covered mountains"
|
152 |
+
- "woman playing the guitar, on stage, singing a song, laser lights, punk rocker"
|
153 |
+
- "hipster man with a beard, building a chair, in a wood shop"
|
154 |
+
- "photo of a man, white background, medium shot, modeling clothing, studio lighting, white backdrop"
|
155 |
+
- "a man holding a sign that says, 'this is a sign'"
|
156 |
+
- "a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle"
|
157 |
+
neg: "" # not used on flex2
|
158 |
+
seed: 42
|
159 |
+
walk_seed: true
|
160 |
+
guidance_scale: 4
|
161 |
+
sample_steps: 25
|
162 |
+
# you can add any additional meta info here. [name] is replaced with config name at top
|
163 |
+
meta:
|
164 |
+
name: "[name]"
|
165 |
+
version: '1.0'
|
config/examples/train_lora_flex_24gb.yaml
ADDED
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
job: extension
|
3 |
+
config:
|
4 |
+
# this name will be the folder and filename name
|
5 |
+
name: "my_first_flex_lora_v1"
|
6 |
+
process:
|
7 |
+
- type: 'sd_trainer'
|
8 |
+
# root folder to save training sessions/samples/weights
|
9 |
+
training_folder: "output"
|
10 |
+
# uncomment to see performance stats in the terminal every N steps
|
11 |
+
# performance_log_every: 1000
|
12 |
+
device: cuda:0
|
13 |
+
# if a trigger word is specified, it will be added to captions of training data if it does not already exist
|
14 |
+
# alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word
|
15 |
+
# trigger_word: "p3r5on"
|
16 |
+
network:
|
17 |
+
type: "lora"
|
18 |
+
linear: 16
|
19 |
+
linear_alpha: 16
|
20 |
+
save:
|
21 |
+
dtype: float16 # precision to save
|
22 |
+
save_every: 250 # save every this many steps
|
23 |
+
max_step_saves_to_keep: 4 # how many intermittent saves to keep
|
24 |
+
push_to_hub: false #change this to True to push your trained model to Hugging Face.
|
25 |
+
# You can either set up a HF_TOKEN env variable or you'll be prompted to log-in
|
26 |
+
# hf_repo_id: your-username/your-model-slug
|
27 |
+
# hf_private: true #whether the repo is private or public
|
28 |
+
datasets:
|
29 |
+
# datasets are a folder of images. captions need to be txt files with the same name as the image
|
30 |
+
# for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently
|
31 |
+
# images will automatically be resized and bucketed into the resolution specified
|
32 |
+
# on windows, escape back slashes with another backslash so
|
33 |
+
# "C:\\path\\to\\images\\folder"
|
34 |
+
- folder_path: "/path/to/images/folder"
|
35 |
+
caption_ext: "txt"
|
36 |
+
caption_dropout_rate: 0.05 # will drop out the caption 5% of time
|
37 |
+
shuffle_tokens: false # shuffle caption order, split by commas
|
38 |
+
cache_latents_to_disk: true # leave this true unless you know what you're doing
|
39 |
+
resolution: [ 512, 768, 1024 ] # flex enjoys multiple resolutions
|
40 |
+
train:
|
41 |
+
batch_size: 1
|
42 |
+
# IMPORTANT! For Flex, you must bypass the guidance embedder during training
|
43 |
+
bypass_guidance_embedding: true
|
44 |
+
|
45 |
+
steps: 2000 # total number of steps to train 500 - 4000 is a good range
|
46 |
+
gradient_accumulation: 1
|
47 |
+
train_unet: true
|
48 |
+
train_text_encoder: false # probably won't work with flex
|
49 |
+
gradient_checkpointing: true # need the on unless you have a ton of vram
|
50 |
+
noise_scheduler: "flowmatch" # for training only
|
51 |
+
optimizer: "adamw8bit"
|
52 |
+
lr: 1e-4
|
53 |
+
# uncomment this to skip the pre training sample
|
54 |
+
# skip_first_sample: true
|
55 |
+
# uncomment to completely disable sampling
|
56 |
+
# disable_sampling: true
|
57 |
+
# uncomment to use new vell curved weighting. Experimental but may produce better results
|
58 |
+
# linear_timesteps: true
|
59 |
+
|
60 |
+
# ema will smooth out learning, but could slow it down. Recommended to leave on.
|
61 |
+
ema_config:
|
62 |
+
use_ema: true
|
63 |
+
ema_decay: 0.99
|
64 |
+
|
65 |
+
# will probably need this if gpu supports it for flex, other dtypes may not work correctly
|
66 |
+
dtype: bf16
|
67 |
+
model:
|
68 |
+
# huggingface model name or path
|
69 |
+
name_or_path: "ostris/Flex.1-alpha"
|
70 |
+
is_flux: true
|
71 |
+
quantize: true # run 8bit mixed precision
|
72 |
+
quantize_kwargs:
|
73 |
+
exclude:
|
74 |
+
- "*time_text_embed*" # exclude the time text embedder from quantization
|
75 |
+
sample:
|
76 |
+
sampler: "flowmatch" # must match train.noise_scheduler
|
77 |
+
sample_every: 250 # sample every this many steps
|
78 |
+
width: 1024
|
79 |
+
height: 1024
|
80 |
+
prompts:
|
81 |
+
# you can add [trigger] to the prompts here and it will be replaced with the trigger word
|
82 |
+
# - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"\
|
83 |
+
- "woman with red hair, playing chess at the park, bomb going off in the background"
|
84 |
+
- "a woman holding a coffee cup, in a beanie, sitting at a cafe"
|
85 |
+
- "a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini"
|
86 |
+
- "a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background"
|
87 |
+
- "a bear building a log cabin in the snow covered mountains"
|
88 |
+
- "woman playing the guitar, on stage, singing a song, laser lights, punk rocker"
|
89 |
+
- "hipster man with a beard, building a chair, in a wood shop"
|
90 |
+
- "photo of a man, white background, medium shot, modeling clothing, studio lighting, white backdrop"
|
91 |
+
- "a man holding a sign that says, 'this is a sign'"
|
92 |
+
- "a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle"
|
93 |
+
neg: "" # not used on flex
|
94 |
+
seed: 42
|
95 |
+
walk_seed: true
|
96 |
+
guidance_scale: 4
|
97 |
+
sample_steps: 25
|
98 |
+
# you can add any additional meta info here. [name] is replaced with config name at top
|
99 |
+
meta:
|
100 |
+
name: "[name]"
|
101 |
+
version: '1.0'
|
config/examples/train_lora_flux_24gb.yaml
ADDED
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
job: extension
|
3 |
+
config:
|
4 |
+
# this name will be the folder and filename name
|
5 |
+
name: "my_first_flux_lora_v1"
|
6 |
+
process:
|
7 |
+
- type: 'sd_trainer'
|
8 |
+
# root folder to save training sessions/samples/weights
|
9 |
+
training_folder: "output"
|
10 |
+
# uncomment to see performance stats in the terminal every N steps
|
11 |
+
# performance_log_every: 1000
|
12 |
+
device: cuda:0
|
13 |
+
# if a trigger word is specified, it will be added to captions of training data if it does not already exist
|
14 |
+
# alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word
|
15 |
+
# trigger_word: "p3r5on"
|
16 |
+
network:
|
17 |
+
type: "lora"
|
18 |
+
linear: 16
|
19 |
+
linear_alpha: 16
|
20 |
+
save:
|
21 |
+
dtype: float16 # precision to save
|
22 |
+
save_every: 250 # save every this many steps
|
23 |
+
max_step_saves_to_keep: 4 # how many intermittent saves to keep
|
24 |
+
push_to_hub: false #change this to True to push your trained model to Hugging Face.
|
25 |
+
# You can either set up a HF_TOKEN env variable or you'll be prompted to log-in
|
26 |
+
# hf_repo_id: your-username/your-model-slug
|
27 |
+
# hf_private: true #whether the repo is private or public
|
28 |
+
datasets:
|
29 |
+
# datasets are a folder of images. captions need to be txt files with the same name as the image
|
30 |
+
# for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently
|
31 |
+
# images will automatically be resized and bucketed into the resolution specified
|
32 |
+
# on windows, escape back slashes with another backslash so
|
33 |
+
# "C:\\path\\to\\images\\folder"
|
34 |
+
- folder_path: "/path/to/images/folder"
|
35 |
+
caption_ext: "txt"
|
36 |
+
caption_dropout_rate: 0.05 # will drop out the caption 5% of time
|
37 |
+
shuffle_tokens: false # shuffle caption order, split by commas
|
38 |
+
cache_latents_to_disk: true # leave this true unless you know what you're doing
|
39 |
+
resolution: [ 512, 768, 1024 ] # flux enjoys multiple resolutions
|
40 |
+
train:
|
41 |
+
batch_size: 1
|
42 |
+
steps: 2000 # total number of steps to train 500 - 4000 is a good range
|
43 |
+
gradient_accumulation_steps: 1
|
44 |
+
train_unet: true
|
45 |
+
train_text_encoder: false # probably won't work with flux
|
46 |
+
gradient_checkpointing: true # need the on unless you have a ton of vram
|
47 |
+
noise_scheduler: "flowmatch" # for training only
|
48 |
+
optimizer: "adamw8bit"
|
49 |
+
lr: 1e-4
|
50 |
+
# uncomment this to skip the pre training sample
|
51 |
+
# skip_first_sample: true
|
52 |
+
# uncomment to completely disable sampling
|
53 |
+
# disable_sampling: true
|
54 |
+
# uncomment to use new vell curved weighting. Experimental but may produce better results
|
55 |
+
# linear_timesteps: true
|
56 |
+
|
57 |
+
# ema will smooth out learning, but could slow it down. Recommended to leave on.
|
58 |
+
ema_config:
|
59 |
+
use_ema: true
|
60 |
+
ema_decay: 0.99
|
61 |
+
|
62 |
+
# will probably need this if gpu supports it for flux, other dtypes may not work correctly
|
63 |
+
dtype: bf16
|
64 |
+
model:
|
65 |
+
# huggingface model name or path
|
66 |
+
name_or_path: "black-forest-labs/FLUX.1-dev"
|
67 |
+
is_flux: true
|
68 |
+
quantize: true # run 8bit mixed precision
|
69 |
+
# low_vram: true # uncomment this if the GPU is connected to your monitors. It will use less vram to quantize, but is slower.
|
70 |
+
sample:
|
71 |
+
sampler: "flowmatch" # must match train.noise_scheduler
|
72 |
+
sample_every: 250 # sample every this many steps
|
73 |
+
width: 1024
|
74 |
+
height: 1024
|
75 |
+
prompts:
|
76 |
+
# you can add [trigger] to the prompts here and it will be replaced with the trigger word
|
77 |
+
# - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"\
|
78 |
+
- "woman with red hair, playing chess at the park, bomb going off in the background"
|
79 |
+
- "a woman holding a coffee cup, in a beanie, sitting at a cafe"
|
80 |
+
- "a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini"
|
81 |
+
- "a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background"
|
82 |
+
- "a bear building a log cabin in the snow covered mountains"
|
83 |
+
- "woman playing the guitar, on stage, singing a song, laser lights, punk rocker"
|
84 |
+
- "hipster man with a beard, building a chair, in a wood shop"
|
85 |
+
- "photo of a man, white background, medium shot, modeling clothing, studio lighting, white backdrop"
|
86 |
+
- "a man holding a sign that says, 'this is a sign'"
|
87 |
+
- "a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle"
|
88 |
+
neg: "" # not used on flux
|
89 |
+
seed: 42
|
90 |
+
walk_seed: true
|
91 |
+
guidance_scale: 4
|
92 |
+
sample_steps: 20
|
93 |
+
# you can add any additional meta info here. [name] is replaced with config name at top
|
94 |
+
meta:
|
95 |
+
name: "[name]"
|
96 |
+
version: '1.0'
|
config/examples/train_lora_flux_kontext_24gb.yaml
ADDED
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
job: extension
|
3 |
+
config:
|
4 |
+
# this name will be the folder and filename name
|
5 |
+
name: "my_first_flux_kontext_lora_v1"
|
6 |
+
process:
|
7 |
+
- type: 'sd_trainer'
|
8 |
+
# root folder to save training sessions/samples/weights
|
9 |
+
training_folder: "output"
|
10 |
+
# uncomment to see performance stats in the terminal every N steps
|
11 |
+
# performance_log_every: 1000
|
12 |
+
device: cuda:0
|
13 |
+
# if a trigger word is specified, it will be added to captions of training data if it does not already exist
|
14 |
+
# alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word
|
15 |
+
# trigger_word: "p3r5on"
|
16 |
+
network:
|
17 |
+
type: "lora"
|
18 |
+
linear: 16
|
19 |
+
linear_alpha: 16
|
20 |
+
save:
|
21 |
+
dtype: float16 # precision to save
|
22 |
+
save_every: 250 # save every this many steps
|
23 |
+
max_step_saves_to_keep: 4 # how many intermittent saves to keep
|
24 |
+
push_to_hub: false #change this to True to push your trained model to Hugging Face.
|
25 |
+
# You can either set up a HF_TOKEN env variable or you'll be prompted to log-in
|
26 |
+
# hf_repo_id: your-username/your-model-slug
|
27 |
+
# hf_private: true #whether the repo is private or public
|
28 |
+
datasets:
|
29 |
+
# datasets are a folder of images. captions need to be txt files with the same name as the image
|
30 |
+
# for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently
|
31 |
+
# images will automatically be resized and bucketed into the resolution specified
|
32 |
+
# on windows, escape back slashes with another backslash so
|
33 |
+
# "C:\\path\\to\\images\\folder"
|
34 |
+
- folder_path: "/path/to/images/folder"
|
35 |
+
# control path is the input images for kontext for a paired dataset. These are the source images you want to change.
|
36 |
+
# You can comment this out and only use normal images if you don't have a paired dataset.
|
37 |
+
# Control images need to match the filenames on the folder path but in
|
38 |
+
# a different folder. These do not need captions.
|
39 |
+
control_path: "/path/to/control/folder"
|
40 |
+
caption_ext: "txt"
|
41 |
+
caption_dropout_rate: 0.05 # will drop out the caption 5% of time
|
42 |
+
shuffle_tokens: false # shuffle caption order, split by commas
|
43 |
+
cache_latents_to_disk: true # leave this true unless you know what you're doing
|
44 |
+
# Kontext runs images in at 2x the latent size. It may OOM at 1024 resolution with 24GB vram.
|
45 |
+
resolution: [ 512, 768 ] # flux enjoys multiple resolutions
|
46 |
+
# resolution: [ 512, 768, 1024 ]
|
47 |
+
train:
|
48 |
+
batch_size: 1
|
49 |
+
steps: 3000 # total number of steps to train 500 - 4000 is a good range
|
50 |
+
gradient_accumulation_steps: 1
|
51 |
+
train_unet: true
|
52 |
+
train_text_encoder: false # probably won't work with flux
|
53 |
+
gradient_checkpointing: true # need the on unless you have a ton of vram
|
54 |
+
noise_scheduler: "flowmatch" # for training only
|
55 |
+
optimizer: "adamw8bit"
|
56 |
+
lr: 1e-4
|
57 |
+
timestep_type: "weighted" # sigmoid, linear, or weighted.
|
58 |
+
# uncomment this to skip the pre training sample
|
59 |
+
# skip_first_sample: true
|
60 |
+
# uncomment to completely disable sampling
|
61 |
+
# disable_sampling: true
|
62 |
+
|
63 |
+
# ema will smooth out learning, but could slow it down.
|
64 |
+
|
65 |
+
# ema_config:
|
66 |
+
# use_ema: true
|
67 |
+
# ema_decay: 0.99
|
68 |
+
|
69 |
+
# will probably need this if gpu supports it for flux, other dtypes may not work correctly
|
70 |
+
dtype: bf16
|
71 |
+
model:
|
72 |
+
# huggingface model name or path. This model is gated.
|
73 |
+
# visit https://huggingface.co/black-forest-labs/FLUX.1-Kontext-dev to accept the terms and conditions
|
74 |
+
# and then you can use this model.
|
75 |
+
name_or_path: "black-forest-labs/FLUX.1-Kontext-dev"
|
76 |
+
arch: "flux_kontext"
|
77 |
+
quantize: true # run 8bit mixed precision
|
78 |
+
# low_vram: true # uncomment this if the GPU is connected to your monitors. It will use less vram to quantize, but is slower.
|
79 |
+
sample:
|
80 |
+
sampler: "flowmatch" # must match train.noise_scheduler
|
81 |
+
sample_every: 250 # sample every this many steps
|
82 |
+
width: 1024
|
83 |
+
height: 1024
|
84 |
+
prompts:
|
85 |
+
# you can add [trigger] to the prompts here and it will be replaced with the trigger word
|
86 |
+
# the --ctrl_img path is the one loaded to apply the kontext editing to
|
87 |
+
# - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"\
|
88 |
+
- "make the person smile --ctrl_img /path/to/control/folder/person1.jpg"
|
89 |
+
- "give the person an afro --ctrl_img /path/to/control/folder/person1.jpg"
|
90 |
+
- "turn this image into a cartoon --ctrl_img /path/to/control/folder/person1.jpg"
|
91 |
+
- "put this person in an action film --ctrl_img /path/to/control/folder/person1.jpg"
|
92 |
+
- "make this person a rapper in a rap music video --ctrl_img /path/to/control/folder/person1.jpg"
|
93 |
+
- "make the person smile --ctrl_img /path/to/control/folder/person1.jpg"
|
94 |
+
- "give the person an afro --ctrl_img /path/to/control/folder/person1.jpg"
|
95 |
+
- "turn this image into a cartoon --ctrl_img /path/to/control/folder/person1.jpg"
|
96 |
+
- "put this person in an action film --ctrl_img /path/to/control/folder/person1.jpg"
|
97 |
+
- "make this person a rapper in a rap music video --ctrl_img /path/to/control/folder/person1.jpg"
|
98 |
+
neg: "" # not used on flux
|
99 |
+
seed: 42
|
100 |
+
walk_seed: true
|
101 |
+
guidance_scale: 4
|
102 |
+
sample_steps: 20
|
103 |
+
# you can add any additional meta info here. [name] is replaced with config name at top
|
104 |
+
meta:
|
105 |
+
name: "[name]"
|
106 |
+
version: '1.0'
|
config/examples/train_lora_flux_schnell_24gb.yaml
ADDED
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
job: extension
|
3 |
+
config:
|
4 |
+
# this name will be the folder and filename name
|
5 |
+
name: "my_first_flux_lora_v1"
|
6 |
+
process:
|
7 |
+
- type: 'sd_trainer'
|
8 |
+
# root folder to save training sessions/samples/weights
|
9 |
+
training_folder: "output"
|
10 |
+
# uncomment to see performance stats in the terminal every N steps
|
11 |
+
# performance_log_every: 1000
|
12 |
+
device: cuda:0
|
13 |
+
# if a trigger word is specified, it will be added to captions of training data if it does not already exist
|
14 |
+
# alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word
|
15 |
+
# trigger_word: "p3r5on"
|
16 |
+
network:
|
17 |
+
type: "lora"
|
18 |
+
linear: 16
|
19 |
+
linear_alpha: 16
|
20 |
+
save:
|
21 |
+
dtype: float16 # precision to save
|
22 |
+
save_every: 250 # save every this many steps
|
23 |
+
max_step_saves_to_keep: 4 # how many intermittent saves to keep
|
24 |
+
push_to_hub: false #change this to True to push your trained model to Hugging Face.
|
25 |
+
# You can either set up a HF_TOKEN env variable or you'll be prompted to log-in
|
26 |
+
# hf_repo_id: your-username/your-model-slug
|
27 |
+
# hf_private: true #whether the repo is private or public
|
28 |
+
datasets:
|
29 |
+
# datasets are a folder of images. captions need to be txt files with the same name as the image
|
30 |
+
# for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently
|
31 |
+
# images will automatically be resized and bucketed into the resolution specified
|
32 |
+
# on windows, escape back slashes with another backslash so
|
33 |
+
# "C:\\path\\to\\images\\folder"
|
34 |
+
- folder_path: "/path/to/images/folder"
|
35 |
+
caption_ext: "txt"
|
36 |
+
caption_dropout_rate: 0.05 # will drop out the caption 5% of time
|
37 |
+
shuffle_tokens: false # shuffle caption order, split by commas
|
38 |
+
cache_latents_to_disk: true # leave this true unless you know what you're doing
|
39 |
+
resolution: [ 512, 768, 1024 ] # flux enjoys multiple resolutions
|
40 |
+
train:
|
41 |
+
batch_size: 1
|
42 |
+
steps: 2000 # total number of steps to train 500 - 4000 is a good range
|
43 |
+
gradient_accumulation_steps: 1
|
44 |
+
train_unet: true
|
45 |
+
train_text_encoder: false # probably won't work with flux
|
46 |
+
gradient_checkpointing: true # need the on unless you have a ton of vram
|
47 |
+
noise_scheduler: "flowmatch" # for training only
|
48 |
+
optimizer: "adamw8bit"
|
49 |
+
lr: 1e-4
|
50 |
+
# uncomment this to skip the pre training sample
|
51 |
+
# skip_first_sample: true
|
52 |
+
# uncomment to completely disable sampling
|
53 |
+
# disable_sampling: true
|
54 |
+
# uncomment to use new bell curved weighting. Experimental but may produce better results
|
55 |
+
# linear_timesteps: true
|
56 |
+
|
57 |
+
# ema will smooth out learning, but could slow it down. Recommended to leave on.
|
58 |
+
ema_config:
|
59 |
+
use_ema: true
|
60 |
+
ema_decay: 0.99
|
61 |
+
|
62 |
+
# will probably need this if gpu supports it for flux, other dtypes may not work correctly
|
63 |
+
dtype: bf16
|
64 |
+
model:
|
65 |
+
# huggingface model name or path
|
66 |
+
name_or_path: "black-forest-labs/FLUX.1-schnell"
|
67 |
+
assistant_lora_path: "ostris/FLUX.1-schnell-training-adapter" # Required for flux schnell training
|
68 |
+
is_flux: true
|
69 |
+
quantize: true # run 8bit mixed precision
|
70 |
+
# low_vram is painfully slow to fuse in the adapter avoid it unless absolutely necessary
|
71 |
+
# low_vram: true # uncomment this if the GPU is connected to your monitors. It will use less vram to quantize, but is slower.
|
72 |
+
sample:
|
73 |
+
sampler: "flowmatch" # must match train.noise_scheduler
|
74 |
+
sample_every: 250 # sample every this many steps
|
75 |
+
width: 1024
|
76 |
+
height: 1024
|
77 |
+
prompts:
|
78 |
+
# you can add [trigger] to the prompts here and it will be replaced with the trigger word
|
79 |
+
# - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"\
|
80 |
+
- "woman with red hair, playing chess at the park, bomb going off in the background"
|
81 |
+
- "a woman holding a coffee cup, in a beanie, sitting at a cafe"
|
82 |
+
- "a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini"
|
83 |
+
- "a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background"
|
84 |
+
- "a bear building a log cabin in the snow covered mountains"
|
85 |
+
- "woman playing the guitar, on stage, singing a song, laser lights, punk rocker"
|
86 |
+
- "hipster man with a beard, building a chair, in a wood shop"
|
87 |
+
- "photo of a man, white background, medium shot, modeling clothing, studio lighting, white backdrop"
|
88 |
+
- "a man holding a sign that says, 'this is a sign'"
|
89 |
+
- "a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle"
|
90 |
+
neg: "" # not used on flux
|
91 |
+
seed: 42
|
92 |
+
walk_seed: true
|
93 |
+
guidance_scale: 1 # schnell does not do guidance
|
94 |
+
sample_steps: 4 # 1 - 4 works well
|
95 |
+
# you can add any additional meta info here. [name] is replaced with config name at top
|
96 |
+
meta:
|
97 |
+
name: "[name]"
|
98 |
+
version: '1.0'
|
config/examples/train_lora_hidream_48.yaml
ADDED
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# HiDream training is still highly experimental. The settings here will take ~35.2GB of vram to train.
|
2 |
+
# It is not possible to train on a single 24GB card yet, but I am working on it. If you have more VRAM
|
3 |
+
# I highly recommend first disabling quantization on the model itself if you can. You can leave the TEs quantized.
|
4 |
+
# HiDream has a mixture of experts that may take special training considerations that I do not
|
5 |
+
# have implemented properly. The current implementation seems to work well for LoRA training, but
|
6 |
+
# may not be effective for longer training runs. The implementation could change in future updates
|
7 |
+
# so your results may vary when this happens.
|
8 |
+
|
9 |
+
---
|
10 |
+
job: extension
|
11 |
+
config:
|
12 |
+
# this name will be the folder and filename name
|
13 |
+
name: "my_first_hidream_lora_v1"
|
14 |
+
process:
|
15 |
+
- type: 'sd_trainer'
|
16 |
+
# root folder to save training sessions/samples/weights
|
17 |
+
training_folder: "output"
|
18 |
+
# uncomment to see performance stats in the terminal every N steps
|
19 |
+
# performance_log_every: 1000
|
20 |
+
device: cuda:0
|
21 |
+
# if a trigger word is specified, it will be added to captions of training data if it does not already exist
|
22 |
+
# alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word
|
23 |
+
# trigger_word: "p3r5on"
|
24 |
+
network:
|
25 |
+
type: "lora"
|
26 |
+
linear: 32
|
27 |
+
linear_alpha: 32
|
28 |
+
network_kwargs:
|
29 |
+
# it is probably best to ignore the mixture of experts since only 2 are active each block. It works activating it, but I wouldnt.
|
30 |
+
# proper training of it is not fully implemented
|
31 |
+
ignore_if_contains:
|
32 |
+
- "ff_i.experts"
|
33 |
+
- "ff_i.gate"
|
34 |
+
save:
|
35 |
+
dtype: bfloat16 # precision to save
|
36 |
+
save_every: 250 # save every this many steps
|
37 |
+
max_step_saves_to_keep: 4 # how many intermittent saves to keep
|
38 |
+
datasets:
|
39 |
+
# datasets are a folder of images. captions need to be txt files with the same name as the image
|
40 |
+
# for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently
|
41 |
+
# images will automatically be resized and bucketed into the resolution specified
|
42 |
+
# on windows, escape back slashes with another backslash so
|
43 |
+
# "C:\\path\\to\\images\\folder"
|
44 |
+
- folder_path: "/path/to/images/folder"
|
45 |
+
caption_ext: "txt"
|
46 |
+
caption_dropout_rate: 0.05 # will drop out the caption 5% of time
|
47 |
+
resolution: [ 512, 768, 1024 ] # hidream enjoys multiple resolutions
|
48 |
+
train:
|
49 |
+
batch_size: 1
|
50 |
+
steps: 3000 # total number of steps to train 500 - 4000 is a good range
|
51 |
+
gradient_accumulation_steps: 1
|
52 |
+
train_unet: true
|
53 |
+
train_text_encoder: false # wont work with hidream
|
54 |
+
gradient_checkpointing: true # need the on unless you have a ton of vram
|
55 |
+
noise_scheduler: "flowmatch" # for training only
|
56 |
+
timestep_type: shift # sigmoid, shift, linear
|
57 |
+
optimizer: "adamw8bit"
|
58 |
+
lr: 2e-4
|
59 |
+
# uncomment this to skip the pre training sample
|
60 |
+
# skip_first_sample: true
|
61 |
+
# uncomment to completely disable sampling
|
62 |
+
# disable_sampling: true
|
63 |
+
# uncomment to use new vell curved weighting. Experimental but may produce better results
|
64 |
+
# linear_timesteps: true
|
65 |
+
|
66 |
+
# ema will smooth out learning, but could slow it down. Defaults off
|
67 |
+
ema_config:
|
68 |
+
use_ema: false
|
69 |
+
ema_decay: 0.99
|
70 |
+
|
71 |
+
# will probably need this if gpu supports it for hidream, other dtypes may not work correctly
|
72 |
+
dtype: bf16
|
73 |
+
model:
|
74 |
+
# the transformer will get grabbed from this hf repo
|
75 |
+
# warning ONLY train on Full. The dev and fast models are distilled and will break
|
76 |
+
name_or_path: "HiDream-ai/HiDream-I1-Full"
|
77 |
+
# the extras will be grabbed from this hf repo. (text encoder, vae)
|
78 |
+
extras_name_or_path: "HiDream-ai/HiDream-I1-Full"
|
79 |
+
arch: "hidream"
|
80 |
+
# both need to be quantized to train on 48GB currently
|
81 |
+
quantize: true
|
82 |
+
quantize_te: true
|
83 |
+
model_kwargs:
|
84 |
+
# llama is a gated model, It defaults to unsloth version, but you can set the llama path here
|
85 |
+
llama_model_path: "unsloth/Meta-Llama-3.1-8B-Instruct"
|
86 |
+
sample:
|
87 |
+
sampler: "flowmatch" # must match train.noise_scheduler
|
88 |
+
sample_every: 250 # sample every this many steps
|
89 |
+
width: 1024
|
90 |
+
height: 1024
|
91 |
+
prompts:
|
92 |
+
# you can add [trigger] to the prompts here and it will be replaced with the trigger word
|
93 |
+
# - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"\
|
94 |
+
- "woman with red hair, playing chess at the park, bomb going off in the background"
|
95 |
+
- "a woman holding a coffee cup, in a beanie, sitting at a cafe"
|
96 |
+
- "a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini"
|
97 |
+
- "a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background"
|
98 |
+
- "a bear building a log cabin in the snow covered mountains"
|
99 |
+
- "woman playing the guitar, on stage, singing a song, laser lights, punk rocker"
|
100 |
+
- "hipster man with a beard, building a chair, in a wood shop"
|
101 |
+
- "photo of a man, white background, medium shot, modeling clothing, studio lighting, white backdrop"
|
102 |
+
- "a man holding a sign that says, 'this is a sign'"
|
103 |
+
- "a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle"
|
104 |
+
neg: ""
|
105 |
+
seed: 42
|
106 |
+
walk_seed: true
|
107 |
+
guidance_scale: 4
|
108 |
+
sample_steps: 25
|
109 |
+
# you can add any additional meta info here. [name] is replaced with config name at top
|
110 |
+
meta:
|
111 |
+
name: "[name]"
|
112 |
+
version: '1.0'
|
config/examples/train_lora_lumina.yaml
ADDED
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
# This configuration requires 20GB of VRAM or more to operate
|
3 |
+
job: extension
|
4 |
+
config:
|
5 |
+
# this name will be the folder and filename name
|
6 |
+
name: "my_first_lumina_lora_v1"
|
7 |
+
process:
|
8 |
+
- type: 'sd_trainer'
|
9 |
+
# root folder to save training sessions/samples/weights
|
10 |
+
training_folder: "output"
|
11 |
+
# uncomment to see performance stats in the terminal every N steps
|
12 |
+
# performance_log_every: 1000
|
13 |
+
device: cuda:0
|
14 |
+
# if a trigger word is specified, it will be added to captions of training data if it does not already exist
|
15 |
+
# alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word
|
16 |
+
# trigger_word: "p3r5on"
|
17 |
+
network:
|
18 |
+
type: "lora"
|
19 |
+
linear: 16
|
20 |
+
linear_alpha: 16
|
21 |
+
save:
|
22 |
+
dtype: bf16 # precision to save
|
23 |
+
save_every: 250 # save every this many steps
|
24 |
+
max_step_saves_to_keep: 2 # how many intermittent saves to keep
|
25 |
+
save_format: 'diffusers' # 'diffusers'
|
26 |
+
datasets:
|
27 |
+
# datasets are a folder of images. captions need to be txt files with the same name as the image
|
28 |
+
# for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently
|
29 |
+
# images will automatically be resized and bucketed into the resolution specified
|
30 |
+
# on windows, escape back slashes with another backslash so
|
31 |
+
# "C:\\path\\to\\images\\folder"
|
32 |
+
- folder_path: "/path/to/images/folder"
|
33 |
+
caption_ext: "txt"
|
34 |
+
caption_dropout_rate: 0.05 # will drop out the caption 5% of time
|
35 |
+
shuffle_tokens: false # shuffle caption order, split by commas
|
36 |
+
# cache_latents_to_disk: true # leave this true unless you know what you're doing
|
37 |
+
resolution: [ 512, 768, 1024 ] # lumina2 enjoys multiple resolutions
|
38 |
+
train:
|
39 |
+
batch_size: 1
|
40 |
+
|
41 |
+
# can be 'sigmoid', 'linear', or 'lumina2_shift'
|
42 |
+
timestep_type: 'lumina2_shift'
|
43 |
+
|
44 |
+
steps: 2000 # total number of steps to train 500 - 4000 is a good range
|
45 |
+
gradient_accumulation: 1
|
46 |
+
train_unet: true
|
47 |
+
train_text_encoder: false # probably won't work with lumina2
|
48 |
+
gradient_checkpointing: true # need the on unless you have a ton of vram
|
49 |
+
noise_scheduler: "flowmatch" # for training only
|
50 |
+
optimizer: "adamw8bit"
|
51 |
+
lr: 1e-4
|
52 |
+
# uncomment this to skip the pre training sample
|
53 |
+
# skip_first_sample: true
|
54 |
+
# uncomment to completely disable sampling
|
55 |
+
# disable_sampling: true
|
56 |
+
|
57 |
+
# ema will smooth out learning, but could slow it down. Recommended to leave on if you have the vram
|
58 |
+
ema_config:
|
59 |
+
use_ema: true
|
60 |
+
ema_decay: 0.99
|
61 |
+
|
62 |
+
# will probably need this if gpu supports it for lumina2, other dtypes may not work correctly
|
63 |
+
dtype: bf16
|
64 |
+
model:
|
65 |
+
# huggingface model name or path
|
66 |
+
name_or_path: "Alpha-VLLM/Lumina-Image-2.0"
|
67 |
+
is_lumina2: true # lumina2 architecture
|
68 |
+
# you can quantize just the Gemma2 text encoder here to save vram
|
69 |
+
quantize_te: true
|
70 |
+
sample:
|
71 |
+
sampler: "flowmatch" # must match train.noise_scheduler
|
72 |
+
sample_every: 250 # sample every this many steps
|
73 |
+
width: 1024
|
74 |
+
height: 1024
|
75 |
+
prompts:
|
76 |
+
# you can add [trigger] to the prompts here and it will be replaced with the trigger word
|
77 |
+
# - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"\
|
78 |
+
- "woman with red hair, playing chess at the park, bomb going off in the background"
|
79 |
+
- "a woman holding a coffee cup, in a beanie, sitting at a cafe"
|
80 |
+
- "a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini"
|
81 |
+
- "a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background"
|
82 |
+
- "a bear building a log cabin in the snow covered mountains"
|
83 |
+
- "woman playing the guitar, on stage, singing a song, laser lights, punk rocker"
|
84 |
+
- "hipster man with a beard, building a chair, in a wood shop"
|
85 |
+
- "photo of a cat that is half black and half orange tabby, split down the middle. The cat has on a blue tophat. They are holding a martini glass with a pink ball of yarn in it with green knitting needles sticking out, in one paw. In the other paw, they are holding a DVD case for a movie titled, \"This is a test\" that has a golden robot on it. In the background is a busy night club with a giant mushroom man dancing with a bear."
|
86 |
+
- "a man holding a sign that says, 'this is a sign'"
|
87 |
+
- "a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle"
|
88 |
+
neg: ""
|
89 |
+
seed: 42
|
90 |
+
walk_seed: true
|
91 |
+
guidance_scale: 4.0
|
92 |
+
sample_steps: 25
|
93 |
+
# you can add any additional meta info here. [name] is replaced with config name at top
|
94 |
+
meta:
|
95 |
+
name: "[name]"
|
96 |
+
version: '1.0'
|
config/examples/train_lora_omnigen2_24gb.yaml
ADDED
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
job: extension
|
3 |
+
config:
|
4 |
+
# this name will be the folder and filename name
|
5 |
+
name: "my_first_omnigen2_lora_v1"
|
6 |
+
process:
|
7 |
+
- type: 'sd_trainer'
|
8 |
+
# root folder to save training sessions/samples/weights
|
9 |
+
training_folder: "output"
|
10 |
+
# uncomment to see performance stats in the terminal every N steps
|
11 |
+
# performance_log_every: 1000
|
12 |
+
device: cuda:0
|
13 |
+
# if a trigger word is specified, it will be added to captions of training data if it does not already exist
|
14 |
+
# alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word
|
15 |
+
# trigger_word: "p3r5on"
|
16 |
+
network:
|
17 |
+
type: "lora"
|
18 |
+
linear: 16
|
19 |
+
linear_alpha: 16
|
20 |
+
save:
|
21 |
+
dtype: float16 # precision to save
|
22 |
+
save_every: 250 # save every this many steps
|
23 |
+
max_step_saves_to_keep: 4 # how many intermittent saves to keep
|
24 |
+
push_to_hub: false #change this to True to push your trained model to Hugging Face.
|
25 |
+
# You can either set up a HF_TOKEN env variable or you'll be prompted to log-in
|
26 |
+
# hf_repo_id: your-username/your-model-slug
|
27 |
+
# hf_private: true #whether the repo is private or public
|
28 |
+
datasets:
|
29 |
+
# datasets are a folder of images. captions need to be txt files with the same name as the image
|
30 |
+
# for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently
|
31 |
+
# images will automatically be resized and bucketed into the resolution specified
|
32 |
+
# on windows, escape back slashes with another backslash so
|
33 |
+
# "C:\\path\\to\\images\\folder"
|
34 |
+
- folder_path: "/path/to/images/folder"
|
35 |
+
caption_ext: "txt"
|
36 |
+
caption_dropout_rate: 0.05 # will drop out the caption 5% of time
|
37 |
+
shuffle_tokens: false # shuffle caption order, split by commas
|
38 |
+
cache_latents_to_disk: true # leave this true unless you know what you're doing
|
39 |
+
resolution: [ 512, 768, 1024 ] # omnigen2 should work with multiple resolutions
|
40 |
+
train:
|
41 |
+
batch_size: 1
|
42 |
+
steps: 3000 # total number of steps to train 500 - 4000 is a good range
|
43 |
+
gradient_accumulation: 1
|
44 |
+
train_unet: true
|
45 |
+
train_text_encoder: false # probably won't work with omnigen2
|
46 |
+
gradient_checkpointing: true # need the on unless you have a ton of vram
|
47 |
+
noise_scheduler: "flowmatch" # for training only
|
48 |
+
optimizer: "adamw8bit"
|
49 |
+
lr: 1e-4
|
50 |
+
timestep_type: 'sigmoid' # sigmoid, linear, shift
|
51 |
+
# uncomment this to skip the pre training sample
|
52 |
+
# skip_first_sample: true
|
53 |
+
# uncomment to completely disable sampling
|
54 |
+
# disable_sampling: true
|
55 |
+
|
56 |
+
# ema will smooth out learning, but could slow it down.
|
57 |
+
# ema_config:
|
58 |
+
# use_ema: true
|
59 |
+
# ema_decay: 0.99
|
60 |
+
|
61 |
+
# will probably need this if gpu supports it for omnigen2, other dtypes may not work correctly
|
62 |
+
dtype: bf16
|
63 |
+
model:
|
64 |
+
name_or_path: "OmniGen2/OmniGen2
|
65 |
+
arch: "omnigen2"
|
66 |
+
quantize_te: true # quantize_only te
|
67 |
+
# quantize: true # quantize transformer
|
68 |
+
sample:
|
69 |
+
sampler: "flowmatch" # must match train.noise_scheduler
|
70 |
+
sample_every: 250 # sample every this many steps
|
71 |
+
width: 1024
|
72 |
+
height: 1024
|
73 |
+
prompts:
|
74 |
+
# you can add [trigger] to the prompts here and it will be replaced with the trigger word
|
75 |
+
# - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"\
|
76 |
+
- "woman with red hair, playing chess at the park, bomb going off in the background"
|
77 |
+
- "a woman holding a coffee cup, in a beanie, sitting at a cafe"
|
78 |
+
- "a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini"
|
79 |
+
- "a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background"
|
80 |
+
- "a bear building a log cabin in the snow covered mountains"
|
81 |
+
- "woman playing the guitar, on stage, singing a song, laser lights, punk rocker"
|
82 |
+
- "hipster man with a beard, building a chair, in a wood shop"
|
83 |
+
- "photo of a man, white background, medium shot, modeling clothing, studio lighting, white backdrop"
|
84 |
+
- "a man holding a sign that says, 'this is a sign'"
|
85 |
+
- "a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle"
|
86 |
+
neg: "" # negative prompt, optional
|
87 |
+
seed: 42
|
88 |
+
walk_seed: true
|
89 |
+
guidance_scale: 4
|
90 |
+
sample_steps: 25
|
91 |
+
# you can add any additional meta info here. [name] is replaced with config name at top
|
92 |
+
meta:
|
93 |
+
name: "[name]"
|
94 |
+
version: '1.0'
|
config/examples/train_lora_sd35_large_24gb.yaml
ADDED
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
# NOTE!! THIS IS CURRENTLY EXPERIMENTAL AND UNDER DEVELOPMENT. SOME THINGS WILL CHANGE
|
3 |
+
job: extension
|
4 |
+
config:
|
5 |
+
# this name will be the folder and filename name
|
6 |
+
name: "my_first_sd3l_lora_v1"
|
7 |
+
process:
|
8 |
+
- type: 'sd_trainer'
|
9 |
+
# root folder to save training sessions/samples/weights
|
10 |
+
training_folder: "output"
|
11 |
+
# uncomment to see performance stats in the terminal every N steps
|
12 |
+
# performance_log_every: 1000
|
13 |
+
device: cuda:0
|
14 |
+
# if a trigger word is specified, it will be added to captions of training data if it does not already exist
|
15 |
+
# alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word
|
16 |
+
# trigger_word: "p3r5on"
|
17 |
+
network:
|
18 |
+
type: "lora"
|
19 |
+
linear: 16
|
20 |
+
linear_alpha: 16
|
21 |
+
save:
|
22 |
+
dtype: float16 # precision to save
|
23 |
+
save_every: 250 # save every this many steps
|
24 |
+
max_step_saves_to_keep: 4 # how many intermittent saves to keep
|
25 |
+
push_to_hub: false #change this to True to push your trained model to Hugging Face.
|
26 |
+
# You can either set up a HF_TOKEN env variable or you'll be prompted to log-in
|
27 |
+
# hf_repo_id: your-username/your-model-slug
|
28 |
+
# hf_private: true #whether the repo is private or public
|
29 |
+
datasets:
|
30 |
+
# datasets are a folder of images. captions need to be txt files with the same name as the image
|
31 |
+
# for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently
|
32 |
+
# images will automatically be resized and bucketed into the resolution specified
|
33 |
+
# on windows, escape back slashes with another backslash so
|
34 |
+
# "C:\\path\\to\\images\\folder"
|
35 |
+
- folder_path: "/path/to/images/folder"
|
36 |
+
caption_ext: "txt"
|
37 |
+
caption_dropout_rate: 0.05 # will drop out the caption 5% of time
|
38 |
+
shuffle_tokens: false # shuffle caption order, split by commas
|
39 |
+
cache_latents_to_disk: true # leave this true unless you know what you're doing
|
40 |
+
resolution: [ 1024 ]
|
41 |
+
train:
|
42 |
+
batch_size: 1
|
43 |
+
steps: 2000 # total number of steps to train 500 - 4000 is a good range
|
44 |
+
gradient_accumulation_steps: 1
|
45 |
+
train_unet: true
|
46 |
+
train_text_encoder: false # May not fully work with SD3 yet
|
47 |
+
gradient_checkpointing: true # need the on unless you have a ton of vram
|
48 |
+
noise_scheduler: "flowmatch"
|
49 |
+
timestep_type: "linear" # linear or sigmoid
|
50 |
+
optimizer: "adamw8bit"
|
51 |
+
lr: 1e-4
|
52 |
+
# uncomment this to skip the pre training sample
|
53 |
+
# skip_first_sample: true
|
54 |
+
# uncomment to completely disable sampling
|
55 |
+
# disable_sampling: true
|
56 |
+
# uncomment to use new vell curved weighting. Experimental but may produce better results
|
57 |
+
# linear_timesteps: true
|
58 |
+
|
59 |
+
# ema will smooth out learning, but could slow it down. Recommended to leave on.
|
60 |
+
ema_config:
|
61 |
+
use_ema: true
|
62 |
+
ema_decay: 0.99
|
63 |
+
|
64 |
+
# will probably need this if gpu supports it for sd3, other dtypes may not work correctly
|
65 |
+
dtype: bf16
|
66 |
+
model:
|
67 |
+
# huggingface model name or path
|
68 |
+
name_or_path: "stabilityai/stable-diffusion-3.5-large"
|
69 |
+
is_v3: true
|
70 |
+
quantize: true # run 8bit mixed precision
|
71 |
+
sample:
|
72 |
+
sampler: "flowmatch" # must match train.noise_scheduler
|
73 |
+
sample_every: 250 # sample every this many steps
|
74 |
+
width: 1024
|
75 |
+
height: 1024
|
76 |
+
prompts:
|
77 |
+
# you can add [trigger] to the prompts here and it will be replaced with the trigger word
|
78 |
+
# - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"\
|
79 |
+
- "woman with red hair, playing chess at the park, bomb going off in the background"
|
80 |
+
- "a woman holding a coffee cup, in a beanie, sitting at a cafe"
|
81 |
+
- "a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini"
|
82 |
+
- "a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background"
|
83 |
+
- "a bear building a log cabin in the snow covered mountains"
|
84 |
+
- "woman playing the guitar, on stage, singing a song, laser lights, punk rocker"
|
85 |
+
- "hipster man with a beard, building a chair, in a wood shop"
|
86 |
+
- "photo of a man, white background, medium shot, modeling clothing, studio lighting, white backdrop"
|
87 |
+
- "a man holding a sign that says, 'this is a sign'"
|
88 |
+
- "a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle"
|
89 |
+
neg: ""
|
90 |
+
seed: 42
|
91 |
+
walk_seed: true
|
92 |
+
guidance_scale: 4
|
93 |
+
sample_steps: 25
|
94 |
+
# you can add any additional meta info here. [name] is replaced with config name at top
|
95 |
+
meta:
|
96 |
+
name: "[name]"
|
97 |
+
version: '1.0'
|
config/examples/train_lora_wan21_14b_24gb.yaml
ADDED
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# IMPORTANT: The Wan2.1 14B model is huge. This config should work on 24GB GPUs. It cannot
|
2 |
+
# support keeping the text encoder on GPU while training with 24GB, so it is only good
|
3 |
+
# for training on a single prompt, for example a person with a trigger word.
|
4 |
+
# to train on captions, you need more vran for now.
|
5 |
+
---
|
6 |
+
job: extension
|
7 |
+
config:
|
8 |
+
# this name will be the folder and filename name
|
9 |
+
name: "my_first_wan21_14b_lora_v1"
|
10 |
+
process:
|
11 |
+
- type: 'sd_trainer'
|
12 |
+
# root folder to save training sessions/samples/weights
|
13 |
+
training_folder: "output"
|
14 |
+
# uncomment to see performance stats in the terminal every N steps
|
15 |
+
# performance_log_every: 1000
|
16 |
+
device: cuda:0
|
17 |
+
# if a trigger word is specified, it will be added to captions of training data if it does not already exist
|
18 |
+
# alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word
|
19 |
+
# this is probably needed for 24GB cards when offloading TE to CPU
|
20 |
+
trigger_word: "p3r5on"
|
21 |
+
network:
|
22 |
+
type: "lora"
|
23 |
+
linear: 32
|
24 |
+
linear_alpha: 32
|
25 |
+
save:
|
26 |
+
dtype: float16 # precision to save
|
27 |
+
save_every: 250 # save every this many steps
|
28 |
+
max_step_saves_to_keep: 4 # how many intermittent saves to keep
|
29 |
+
push_to_hub: false #change this to True to push your trained model to Hugging Face.
|
30 |
+
# You can either set up a HF_TOKEN env variable or you'll be prompted to log-in
|
31 |
+
# hf_repo_id: your-username/your-model-slug
|
32 |
+
# hf_private: true #whether the repo is private or public
|
33 |
+
datasets:
|
34 |
+
# datasets are a folder of images. captions need to be txt files with the same name as the image
|
35 |
+
# for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently
|
36 |
+
# images will automatically be resized and bucketed into the resolution specified
|
37 |
+
# on windows, escape back slashes with another backslash so
|
38 |
+
# "C:\\path\\to\\images\\folder"
|
39 |
+
# AI-Toolkit does not currently support video datasets, we will train on 1 frame at a time
|
40 |
+
# it works well for characters, but not as well for "actions"
|
41 |
+
- folder_path: "/path/to/images/folder"
|
42 |
+
caption_ext: "txt"
|
43 |
+
caption_dropout_rate: 0.05 # will drop out the caption 5% of time
|
44 |
+
shuffle_tokens: false # shuffle caption order, split by commas
|
45 |
+
cache_latents_to_disk: true # leave this true unless you know what you're doing
|
46 |
+
resolution: [ 632 ] # will be around 480p
|
47 |
+
train:
|
48 |
+
batch_size: 1
|
49 |
+
steps: 2000 # total number of steps to train 500 - 4000 is a good range
|
50 |
+
gradient_accumulation: 1
|
51 |
+
train_unet: true
|
52 |
+
train_text_encoder: false # probably won't work with wan
|
53 |
+
gradient_checkpointing: true # need the on unless you have a ton of vram
|
54 |
+
noise_scheduler: "flowmatch" # for training only
|
55 |
+
timestep_type: 'sigmoid'
|
56 |
+
optimizer: "adamw8bit"
|
57 |
+
lr: 1e-4
|
58 |
+
optimizer_params:
|
59 |
+
weight_decay: 1e-4
|
60 |
+
# uncomment this to skip the pre training sample
|
61 |
+
# skip_first_sample: true
|
62 |
+
# uncomment to completely disable sampling
|
63 |
+
# disable_sampling: true
|
64 |
+
# ema will smooth out learning, but could slow it down. Recommended to leave on.
|
65 |
+
ema_config:
|
66 |
+
use_ema: true
|
67 |
+
ema_decay: 0.99
|
68 |
+
dtype: bf16
|
69 |
+
# required for 24GB cards
|
70 |
+
# this will encode your trigger word and use those embeddings for every image in the dataset
|
71 |
+
unload_text_encoder: true
|
72 |
+
model:
|
73 |
+
# huggingface model name or path
|
74 |
+
name_or_path: "Wan-AI/Wan2.1-T2V-14B-Diffusers"
|
75 |
+
arch: 'wan21'
|
76 |
+
# these settings will save as much vram as possible
|
77 |
+
quantize: true
|
78 |
+
quantize_te: true
|
79 |
+
low_vram: true
|
80 |
+
sample:
|
81 |
+
sampler: "flowmatch"
|
82 |
+
sample_every: 250 # sample every this many steps
|
83 |
+
width: 832
|
84 |
+
height: 480
|
85 |
+
num_frames: 40
|
86 |
+
fps: 15
|
87 |
+
# samples take a long time. so use them sparingly
|
88 |
+
# samples will be animated webp files, if you don't see them animated, open in a browser.
|
89 |
+
prompts:
|
90 |
+
# you can add [trigger] to the prompts here and it will be replaced with the trigger word
|
91 |
+
# - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"\
|
92 |
+
- "woman playing the guitar, on stage, singing a song, laser lights, punk rocker"
|
93 |
+
neg: ""
|
94 |
+
seed: 42
|
95 |
+
walk_seed: true
|
96 |
+
guidance_scale: 5
|
97 |
+
sample_steps: 30
|
98 |
+
# you can add any additional meta info here. [name] is replaced with config name at top
|
99 |
+
meta:
|
100 |
+
name: "[name]"
|
101 |
+
version: '1.0'
|
config/examples/train_lora_wan21_1b_24gb.yaml
ADDED
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
job: extension
|
3 |
+
config:
|
4 |
+
# this name will be the folder and filename name
|
5 |
+
name: "my_first_wan21_1b_lora_v1"
|
6 |
+
process:
|
7 |
+
- type: 'sd_trainer'
|
8 |
+
# root folder to save training sessions/samples/weights
|
9 |
+
training_folder: "output"
|
10 |
+
# uncomment to see performance stats in the terminal every N steps
|
11 |
+
# performance_log_every: 1000
|
12 |
+
device: cuda:0
|
13 |
+
# if a trigger word is specified, it will be added to captions of training data if it does not already exist
|
14 |
+
# alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word
|
15 |
+
# trigger_word: "p3r5on"
|
16 |
+
network:
|
17 |
+
type: "lora"
|
18 |
+
linear: 32
|
19 |
+
linear_alpha: 32
|
20 |
+
save:
|
21 |
+
dtype: float16 # precision to save
|
22 |
+
save_every: 250 # save every this many steps
|
23 |
+
max_step_saves_to_keep: 4 # how many intermittent saves to keep
|
24 |
+
push_to_hub: false #change this to True to push your trained model to Hugging Face.
|
25 |
+
# You can either set up a HF_TOKEN env variable or you'll be prompted to log-in
|
26 |
+
# hf_repo_id: your-username/your-model-slug
|
27 |
+
# hf_private: true #whether the repo is private or public
|
28 |
+
datasets:
|
29 |
+
# datasets are a folder of images. captions need to be txt files with the same name as the image
|
30 |
+
# for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently
|
31 |
+
# images will automatically be resized and bucketed into the resolution specified
|
32 |
+
# on windows, escape back slashes with another backslash so
|
33 |
+
# "C:\\path\\to\\images\\folder"
|
34 |
+
# AI-Toolkit does not currently support video datasets, we will train on 1 frame at a time
|
35 |
+
# it works well for characters, but not as well for "actions"
|
36 |
+
- folder_path: "/path/to/images/folder"
|
37 |
+
caption_ext: "txt"
|
38 |
+
caption_dropout_rate: 0.05 # will drop out the caption 5% of time
|
39 |
+
shuffle_tokens: false # shuffle caption order, split by commas
|
40 |
+
cache_latents_to_disk: true # leave this true unless you know what you're doing
|
41 |
+
resolution: [ 632 ] # will be around 480p
|
42 |
+
train:
|
43 |
+
batch_size: 1
|
44 |
+
steps: 2000 # total number of steps to train 500 - 4000 is a good range
|
45 |
+
gradient_accumulation: 1
|
46 |
+
train_unet: true
|
47 |
+
train_text_encoder: false # probably won't work with wan
|
48 |
+
gradient_checkpointing: true # need the on unless you have a ton of vram
|
49 |
+
noise_scheduler: "flowmatch" # for training only
|
50 |
+
timestep_type: 'sigmoid'
|
51 |
+
optimizer: "adamw8bit"
|
52 |
+
lr: 1e-4
|
53 |
+
optimizer_params:
|
54 |
+
weight_decay: 1e-4
|
55 |
+
# uncomment this to skip the pre training sample
|
56 |
+
# skip_first_sample: true
|
57 |
+
# uncomment to completely disable sampling
|
58 |
+
# disable_sampling: true
|
59 |
+
# ema will smooth out learning, but could slow it down. Recommended to leave on.
|
60 |
+
ema_config:
|
61 |
+
use_ema: true
|
62 |
+
ema_decay: 0.99
|
63 |
+
dtype: bf16
|
64 |
+
model:
|
65 |
+
# huggingface model name or path
|
66 |
+
name_or_path: "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
|
67 |
+
arch: 'wan21'
|
68 |
+
quantize_te: true # saves vram
|
69 |
+
sample:
|
70 |
+
sampler: "flowmatch"
|
71 |
+
sample_every: 250 # sample every this many steps
|
72 |
+
width: 832
|
73 |
+
height: 480
|
74 |
+
num_frames: 40
|
75 |
+
fps: 15
|
76 |
+
# samples take a long time. so use them sparingly
|
77 |
+
# samples will be animated webp files, if you don't see them animated, open in a browser.
|
78 |
+
prompts:
|
79 |
+
# you can add [trigger] to the prompts here and it will be replaced with the trigger word
|
80 |
+
# - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"\
|
81 |
+
- "woman playing the guitar, on stage, singing a song, laser lights, punk rocker"
|
82 |
+
neg: ""
|
83 |
+
seed: 42
|
84 |
+
walk_seed: true
|
85 |
+
guidance_scale: 5
|
86 |
+
sample_steps: 30
|
87 |
+
# you can add any additional meta info here. [name] is replaced with config name at top
|
88 |
+
meta:
|
89 |
+
name: "[name]"
|
90 |
+
version: '1.0'
|
config/examples/train_slider.example.yml
ADDED
@@ -0,0 +1,230 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
# This is in yaml format. You can use json if you prefer
|
3 |
+
# I like both but yaml is easier to write
|
4 |
+
# Plus it has comments which is nice for documentation
|
5 |
+
# This is the config I use on my sliders, It is solid and tested
|
6 |
+
job: train
|
7 |
+
config:
|
8 |
+
# the name will be used to create a folder in the output folder
|
9 |
+
# it will also replace any [name] token in the rest of this config
|
10 |
+
name: detail_slider_v1
|
11 |
+
# folder will be created with name above in folder below
|
12 |
+
# it can be relative to the project root or absolute
|
13 |
+
training_folder: "output/LoRA"
|
14 |
+
device: cuda:0 # cpu, cuda:0, etc
|
15 |
+
# for tensorboard logging, we will make a subfolder for this job
|
16 |
+
log_dir: "output/.tensorboard"
|
17 |
+
# you can stack processes for other jobs, It is not tested with sliders though
|
18 |
+
# just use one for now
|
19 |
+
process:
|
20 |
+
- type: slider # tells runner to run the slider process
|
21 |
+
# network is the LoRA network for a slider, I recommend to leave this be
|
22 |
+
network:
|
23 |
+
# network type lierla is traditional LoRA that works everywhere, only linear layers
|
24 |
+
type: "lierla"
|
25 |
+
# rank / dim of the network. Bigger is not always better. Especially for sliders. 8 is good
|
26 |
+
linear: 8
|
27 |
+
linear_alpha: 4 # Do about half of rank
|
28 |
+
# training config
|
29 |
+
train:
|
30 |
+
# this is also used in sampling. Stick with ddpm unless you know what you are doing
|
31 |
+
noise_scheduler: "ddpm" # or "ddpm", "lms", "euler_a"
|
32 |
+
# how many steps to train. More is not always better. I rarely go over 1000
|
33 |
+
steps: 500
|
34 |
+
# I have had good results with 4e-4 to 1e-4 at 500 steps
|
35 |
+
lr: 2e-4
|
36 |
+
# enables gradient checkpoint, saves vram, leave it on
|
37 |
+
gradient_checkpointing: true
|
38 |
+
# train the unet. I recommend leaving this true
|
39 |
+
train_unet: true
|
40 |
+
# train the text encoder. I don't recommend this unless you have a special use case
|
41 |
+
# for sliders we are adjusting representation of the concept (unet),
|
42 |
+
# not the description of it (text encoder)
|
43 |
+
train_text_encoder: false
|
44 |
+
# same as from sd-scripts, not fully tested but should speed up training
|
45 |
+
min_snr_gamma: 5.0
|
46 |
+
# just leave unless you know what you are doing
|
47 |
+
# also supports "dadaptation" but set lr to 1 if you use that,
|
48 |
+
# but it learns too fast and I don't recommend it
|
49 |
+
optimizer: "adamw"
|
50 |
+
# only constant for now
|
51 |
+
lr_scheduler: "constant"
|
52 |
+
# we randomly denoise random num of steps form 1 to this number
|
53 |
+
# while training. Just leave it
|
54 |
+
max_denoising_steps: 40
|
55 |
+
# works great at 1. I do 1 even with my 4090.
|
56 |
+
# higher may not work right with newer single batch stacking code anyway
|
57 |
+
batch_size: 1
|
58 |
+
# bf16 works best if your GPU supports it (modern)
|
59 |
+
dtype: bf16 # fp32, bf16, fp16
|
60 |
+
# if you have it, use it. It is faster and better
|
61 |
+
# torch 2.0 doesnt need xformers anymore, only use if you have lower version
|
62 |
+
# xformers: true
|
63 |
+
# I don't recommend using unless you are trying to make a darker lora. Then do 0.1 MAX
|
64 |
+
# although, the way we train sliders is comparative, so it probably won't work anyway
|
65 |
+
noise_offset: 0.0
|
66 |
+
# noise_offset: 0.0357 # SDXL was trained with offset of 0.0357. So use that when training on SDXL
|
67 |
+
|
68 |
+
# the model to train the LoRA network on
|
69 |
+
model:
|
70 |
+
# huggingface name, relative prom project path, or absolute path to .safetensors or .ckpt
|
71 |
+
name_or_path: "runwayml/stable-diffusion-v1-5"
|
72 |
+
is_v2: false # for v2 models
|
73 |
+
is_v_pred: false # for v-prediction models (most v2 models)
|
74 |
+
# has some issues with the dual text encoder and the way we train sliders
|
75 |
+
# it works bit weights need to probably be higher to see it.
|
76 |
+
is_xl: false # for SDXL models
|
77 |
+
|
78 |
+
# saving config
|
79 |
+
save:
|
80 |
+
dtype: float16 # precision to save. I recommend float16
|
81 |
+
save_every: 50 # save every this many steps
|
82 |
+
# this will remove step counts more than this number
|
83 |
+
# allows you to save more often in case of a crash without filling up your drive
|
84 |
+
max_step_saves_to_keep: 2
|
85 |
+
|
86 |
+
# sampling config
|
87 |
+
sample:
|
88 |
+
# must match train.noise_scheduler, this is not used here
|
89 |
+
# but may be in future and in other processes
|
90 |
+
sampler: "ddpm"
|
91 |
+
# sample every this many steps
|
92 |
+
sample_every: 20
|
93 |
+
# image size
|
94 |
+
width: 512
|
95 |
+
height: 512
|
96 |
+
# prompts to use for sampling. Do as many as you want, but it slows down training
|
97 |
+
# pick ones that will best represent the concept you are trying to adjust
|
98 |
+
# allows some flags after the prompt
|
99 |
+
# --m [number] # network multiplier. LoRA weight. -3 for the negative slide, 3 for the positive
|
100 |
+
# slide are good tests. will inherit sample.network_multiplier if not set
|
101 |
+
# --n [string] # negative prompt, will inherit sample.neg if not set
|
102 |
+
# Only 75 tokens allowed currently
|
103 |
+
# I like to do a wide positive and negative spread so I can see a good range and stop
|
104 |
+
# early if the network is braking down
|
105 |
+
prompts:
|
106 |
+
- "a woman in a coffee shop, black hat, blonde hair, blue jacket --m -5"
|
107 |
+
- "a woman in a coffee shop, black hat, blonde hair, blue jacket --m -3"
|
108 |
+
- "a woman in a coffee shop, black hat, blonde hair, blue jacket --m 3"
|
109 |
+
- "a woman in a coffee shop, black hat, blonde hair, blue jacket --m 5"
|
110 |
+
- "a golden retriever sitting on a leather couch, --m -5"
|
111 |
+
- "a golden retriever sitting on a leather couch --m -3"
|
112 |
+
- "a golden retriever sitting on a leather couch --m 3"
|
113 |
+
- "a golden retriever sitting on a leather couch --m 5"
|
114 |
+
- "a man with a beard and red flannel shirt, wearing vr goggles, walking into traffic --m -5"
|
115 |
+
- "a man with a beard and red flannel shirt, wearing vr goggles, walking into traffic --m -3"
|
116 |
+
- "a man with a beard and red flannel shirt, wearing vr goggles, walking into traffic --m 3"
|
117 |
+
- "a man with a beard and red flannel shirt, wearing vr goggles, walking into traffic --m 5"
|
118 |
+
# negative prompt used on all prompts above as default if they don't have one
|
119 |
+
neg: "cartoon, fake, drawing, illustration, cgi, animated, anime, monochrome"
|
120 |
+
# seed for sampling. 42 is the answer for everything
|
121 |
+
seed: 42
|
122 |
+
# walks the seed so s1 is 42, s2 is 43, s3 is 44, etc
|
123 |
+
# will start over on next sample_every so s1 is always seed
|
124 |
+
# works well if you use same prompt but want different results
|
125 |
+
walk_seed: false
|
126 |
+
# cfg scale (4 to 10 is good)
|
127 |
+
guidance_scale: 7
|
128 |
+
# sampler steps (20 to 30 is good)
|
129 |
+
sample_steps: 20
|
130 |
+
# default network multiplier for all prompts
|
131 |
+
# since we are training a slider, I recommend overriding this with --m [number]
|
132 |
+
# in the prompts above to get both sides of the slider
|
133 |
+
network_multiplier: 1.0
|
134 |
+
|
135 |
+
# logging information
|
136 |
+
logging:
|
137 |
+
log_every: 10 # log every this many steps
|
138 |
+
use_wandb: false # not supported yet
|
139 |
+
verbose: false # probably done need unless you are debugging
|
140 |
+
|
141 |
+
# slider training config, best for last
|
142 |
+
slider:
|
143 |
+
# resolutions to train on. [ width, height ]. This is less important for sliders
|
144 |
+
# as we are not teaching the model anything it doesn't already know
|
145 |
+
# but must be a size it understands [ 512, 512 ] for sd_v1.5 and [ 768, 768 ] for sd_v2.1
|
146 |
+
# and [ 1024, 1024 ] for sd_xl
|
147 |
+
# you can do as many as you want here
|
148 |
+
resolutions:
|
149 |
+
- [ 512, 512 ]
|
150 |
+
# - [ 512, 768 ]
|
151 |
+
# - [ 768, 768 ]
|
152 |
+
# slider training uses 4 combined steps for a single round. This will do it in one gradient
|
153 |
+
# step. It is highly optimized and shouldn't take anymore vram than doing without it,
|
154 |
+
# since we break down batches for gradient accumulation now. so just leave it on.
|
155 |
+
batch_full_slide: true
|
156 |
+
# These are the concepts to train on. You can do as many as you want here,
|
157 |
+
# but they can conflict outweigh each other. Other than experimenting, I recommend
|
158 |
+
# just doing one for good results
|
159 |
+
targets:
|
160 |
+
# target_class is the base concept we are adjusting the representation of
|
161 |
+
# for example, if we are adjusting the representation of a person, we would use "person"
|
162 |
+
# if we are adjusting the representation of a cat, we would use "cat" It is not
|
163 |
+
# a keyword necessarily but what the model understands the concept to represent.
|
164 |
+
# "person" will affect men, women, children, etc but will not affect cats, dogs, etc
|
165 |
+
# it is the models base general understanding of the concept and everything it represents
|
166 |
+
# you can leave it blank to affect everything. In this example, we are adjusting
|
167 |
+
# detail, so we will leave it blank to affect everything
|
168 |
+
- target_class: ""
|
169 |
+
# positive is the prompt for the positive side of the slider.
|
170 |
+
# It is the concept that will be excited and amplified in the model when we slide the slider
|
171 |
+
# to the positive side and forgotten / inverted when we slide
|
172 |
+
# the slider to the negative side. It is generally best to include the target_class in
|
173 |
+
# the prompt. You want it to be the extreme of what you want to train on. For example,
|
174 |
+
# if you want to train on fat people, you would use "an extremely fat, morbidly obese person"
|
175 |
+
# as the prompt. Not just "fat person"
|
176 |
+
# max 75 tokens for now
|
177 |
+
positive: "high detail, 8k, intricate, detailed, high resolution, high res, high quality"
|
178 |
+
# negative is the prompt for the negative side of the slider and works the same as positive
|
179 |
+
# it does not necessarily work the same as a negative prompt when generating images
|
180 |
+
# these need to be polar opposites.
|
181 |
+
# max 76 tokens for now
|
182 |
+
negative: "blurry, boring, fuzzy, low detail, low resolution, low res, low quality"
|
183 |
+
# the loss for this target is multiplied by this number.
|
184 |
+
# if you are doing more than one target it may be good to set less important ones
|
185 |
+
# to a lower number like 0.1 so they don't outweigh the primary target
|
186 |
+
weight: 1.0
|
187 |
+
# shuffle the prompts split by the comma. We will run every combination randomly
|
188 |
+
# this will make the LoRA more robust. You probably want this on unless prompt order
|
189 |
+
# is important for some reason
|
190 |
+
shuffle: true
|
191 |
+
|
192 |
+
|
193 |
+
# anchors are prompts that we will try to hold on to while training the slider
|
194 |
+
# these are NOT necessary and can prevent the slider from converging if not done right
|
195 |
+
# leave them off if you are having issues, but they can help lock the network
|
196 |
+
# on certain concepts to help prevent catastrophic forgetting
|
197 |
+
# you want these to generate an image that is not your target_class, but close to it
|
198 |
+
# is fine as long as it does not directly overlap it.
|
199 |
+
# For example, if you are training on a person smiling,
|
200 |
+
# you could use "a person with a face mask" as an anchor. It is a person, the image is the same
|
201 |
+
# regardless if they are smiling or not, however, the closer the concept is to the target_class
|
202 |
+
# the less the multiplier needs to be. Keep multipliers less than 1.0 for anchors usually
|
203 |
+
# for close concepts, you want to be closer to 0.1 or 0.2
|
204 |
+
# these will slow down training. I am leaving them off for the demo
|
205 |
+
|
206 |
+
# anchors:
|
207 |
+
# - prompt: "a woman"
|
208 |
+
# neg_prompt: "animal"
|
209 |
+
# # the multiplier applied to the LoRA when this is run.
|
210 |
+
# # higher will give it more weight but also help keep the lora from collapsing
|
211 |
+
# multiplier: 1.0
|
212 |
+
# - prompt: "a man"
|
213 |
+
# neg_prompt: "animal"
|
214 |
+
# multiplier: 1.0
|
215 |
+
# - prompt: "a person"
|
216 |
+
# neg_prompt: "animal"
|
217 |
+
# multiplier: 1.0
|
218 |
+
|
219 |
+
# You can put any information you want here, and it will be saved in the model.
|
220 |
+
# The below is an example, but you can put your grocery list in it if you want.
|
221 |
+
# It is saved in the model so be aware of that. The software will include this
|
222 |
+
# plus some other information for you automatically
|
223 |
+
meta:
|
224 |
+
# [name] gets replaced with the name above
|
225 |
+
name: "[name]"
|
226 |
+
# version: '1.0'
|
227 |
+
# creator:
|
228 |
+
# name: Your Name
|
229 |
+
# email: [email protected]
|
230 |
+
# website: https://your.website
|
docker-compose.yml
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
version: "3.8"
|
2 |
+
|
3 |
+
services:
|
4 |
+
ai-toolkit:
|
5 |
+
image: ostris/aitoolkit:latest
|
6 |
+
restart: unless-stopped
|
7 |
+
ports:
|
8 |
+
- "8675:8675"
|
9 |
+
volumes:
|
10 |
+
- ~/.cache/huggingface/hub:/root/.cache/huggingface/hub
|
11 |
+
- ./aitk_db.db:/app/ai-toolkit/aitk_db.db
|
12 |
+
- ./datasets:/app/ai-toolkit/datasets
|
13 |
+
- ./output:/app/ai-toolkit/output
|
14 |
+
- ./config:/app/ai-toolkit/config
|
15 |
+
environment:
|
16 |
+
- AI_TOOLKIT_AUTH=${AI_TOOLKIT_AUTH:-password}
|
17 |
+
- NODE_ENV=production
|
18 |
+
- TZ=UTC
|
19 |
+
deploy:
|
20 |
+
resources:
|
21 |
+
reservations:
|
22 |
+
devices:
|
23 |
+
- driver: nvidia
|
24 |
+
count: all
|
25 |
+
capabilities: [gpu]
|
docker/Dockerfile
ADDED
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
FROM nvidia/cuda:12.8.1-devel-ubuntu22.04
|
2 |
+
|
3 |
+
LABEL authors="jaret"
|
4 |
+
|
5 |
+
# Set noninteractive to avoid timezone prompts
|
6 |
+
ENV DEBIAN_FRONTEND=noninteractive
|
7 |
+
|
8 |
+
# ref https://en.wikipedia.org/wiki/CUDA
|
9 |
+
ENV TORCH_CUDA_ARCH_LIST="8.0 8.6 8.9 9.0 10.0 12.0"
|
10 |
+
|
11 |
+
# Install dependencies
|
12 |
+
RUN apt-get update && apt-get install --no-install-recommends -y \
|
13 |
+
git \
|
14 |
+
curl \
|
15 |
+
build-essential \
|
16 |
+
cmake \
|
17 |
+
wget \
|
18 |
+
python3.10 \
|
19 |
+
python3-pip \
|
20 |
+
python3-dev \
|
21 |
+
python3-setuptools \
|
22 |
+
python3-wheel \
|
23 |
+
python3-venv \
|
24 |
+
ffmpeg \
|
25 |
+
tmux \
|
26 |
+
htop \
|
27 |
+
nvtop \
|
28 |
+
python3-opencv \
|
29 |
+
openssh-client \
|
30 |
+
openssh-server \
|
31 |
+
openssl \
|
32 |
+
rsync \
|
33 |
+
unzip \
|
34 |
+
&& apt-get clean \
|
35 |
+
&& rm -rf /var/lib/apt/lists/*
|
36 |
+
|
37 |
+
# Install nodejs
|
38 |
+
WORKDIR /tmp
|
39 |
+
RUN curl -sL https://deb.nodesource.com/setup_23.x -o nodesource_setup.sh && \
|
40 |
+
bash nodesource_setup.sh && \
|
41 |
+
apt-get update && \
|
42 |
+
apt-get install -y nodejs && \
|
43 |
+
apt-get clean && \
|
44 |
+
rm -rf /var/lib/apt/lists/*
|
45 |
+
|
46 |
+
WORKDIR /app
|
47 |
+
|
48 |
+
# Set aliases for python and pip
|
49 |
+
RUN ln -s /usr/bin/python3 /usr/bin/python
|
50 |
+
|
51 |
+
# install pytorch before cache bust to avoid redownloading pytorch
|
52 |
+
RUN pip install --pre --no-cache-dir torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128
|
53 |
+
|
54 |
+
# Fix cache busting by moving CACHEBUST to right before git clone
|
55 |
+
ARG CACHEBUST=1234
|
56 |
+
ARG GIT_COMMIT=main
|
57 |
+
RUN echo "Cache bust: ${CACHEBUST}" && \
|
58 |
+
git clone https://github.com/ostris/ai-toolkit.git && \
|
59 |
+
cd ai-toolkit && \
|
60 |
+
git checkout ${GIT_COMMIT}
|
61 |
+
|
62 |
+
WORKDIR /app/ai-toolkit
|
63 |
+
|
64 |
+
# Install Python dependencies
|
65 |
+
RUN pip install --no-cache-dir -r requirements.txt && \
|
66 |
+
pip install --pre --no-cache-dir torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128 --force && \
|
67 |
+
pip install setuptools==69.5.1 --no-cache-dir
|
68 |
+
|
69 |
+
# Build UI
|
70 |
+
WORKDIR /app/ai-toolkit/ui
|
71 |
+
RUN npm install && \
|
72 |
+
npm run build && \
|
73 |
+
npm run update_db
|
74 |
+
|
75 |
+
# Expose port (assuming the application runs on port 3000)
|
76 |
+
EXPOSE 8675
|
77 |
+
|
78 |
+
WORKDIR /
|
79 |
+
|
80 |
+
COPY docker/start.sh /start.sh
|
81 |
+
RUN chmod +x /start.sh
|
82 |
+
|
83 |
+
CMD ["/start.sh"]
|
docker/start.sh
ADDED
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
set -e # Exit the script if any statement returns a non-true return value
|
3 |
+
|
4 |
+
# ref https://github.com/runpod/containers/blob/main/container-template/start.sh
|
5 |
+
|
6 |
+
# ---------------------------------------------------------------------------- #
|
7 |
+
# Function Definitions #
|
8 |
+
# ---------------------------------------------------------------------------- #
|
9 |
+
|
10 |
+
|
11 |
+
# Setup ssh
|
12 |
+
setup_ssh() {
|
13 |
+
if [[ $PUBLIC_KEY ]]; then
|
14 |
+
echo "Setting up SSH..."
|
15 |
+
mkdir -p ~/.ssh
|
16 |
+
echo "$PUBLIC_KEY" >> ~/.ssh/authorized_keys
|
17 |
+
chmod 700 -R ~/.ssh
|
18 |
+
|
19 |
+
if [ ! -f /etc/ssh/ssh_host_rsa_key ]; then
|
20 |
+
ssh-keygen -t rsa -f /etc/ssh/ssh_host_rsa_key -q -N ''
|
21 |
+
echo "RSA key fingerprint:"
|
22 |
+
ssh-keygen -lf /etc/ssh/ssh_host_rsa_key.pub
|
23 |
+
fi
|
24 |
+
|
25 |
+
if [ ! -f /etc/ssh/ssh_host_dsa_key ]; then
|
26 |
+
ssh-keygen -t dsa -f /etc/ssh/ssh_host_dsa_key -q -N ''
|
27 |
+
echo "DSA key fingerprint:"
|
28 |
+
ssh-keygen -lf /etc/ssh/ssh_host_dsa_key.pub
|
29 |
+
fi
|
30 |
+
|
31 |
+
if [ ! -f /etc/ssh/ssh_host_ecdsa_key ]; then
|
32 |
+
ssh-keygen -t ecdsa -f /etc/ssh/ssh_host_ecdsa_key -q -N ''
|
33 |
+
echo "ECDSA key fingerprint:"
|
34 |
+
ssh-keygen -lf /etc/ssh/ssh_host_ecdsa_key.pub
|
35 |
+
fi
|
36 |
+
|
37 |
+
if [ ! -f /etc/ssh/ssh_host_ed25519_key ]; then
|
38 |
+
ssh-keygen -t ed25519 -f /etc/ssh/ssh_host_ed25519_key -q -N ''
|
39 |
+
echo "ED25519 key fingerprint:"
|
40 |
+
ssh-keygen -lf /etc/ssh/ssh_host_ed25519_key.pub
|
41 |
+
fi
|
42 |
+
|
43 |
+
service ssh start
|
44 |
+
|
45 |
+
echo "SSH host keys:"
|
46 |
+
for key in /etc/ssh/*.pub; do
|
47 |
+
echo "Key: $key"
|
48 |
+
ssh-keygen -lf $key
|
49 |
+
done
|
50 |
+
fi
|
51 |
+
}
|
52 |
+
|
53 |
+
# Export env vars
|
54 |
+
export_env_vars() {
|
55 |
+
echo "Exporting environment variables..."
|
56 |
+
printenv | grep -E '^RUNPOD_|^PATH=|^_=' | awk -F = '{ print "export " $1 "=\"" $2 "\"" }' >> /etc/rp_environment
|
57 |
+
echo 'source /etc/rp_environment' >> ~/.bashrc
|
58 |
+
}
|
59 |
+
|
60 |
+
# ---------------------------------------------------------------------------- #
|
61 |
+
# Main Program #
|
62 |
+
# ---------------------------------------------------------------------------- #
|
63 |
+
|
64 |
+
|
65 |
+
echo "Pod Started"
|
66 |
+
|
67 |
+
setup_ssh
|
68 |
+
export_env_vars
|
69 |
+
echo "Starting AI Toolkit UI..."
|
70 |
+
cd /app/ai-toolkit/ui && npm run start
|
extensions/example/ExampleMergeModels.py
ADDED
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import gc
|
3 |
+
from collections import OrderedDict
|
4 |
+
from typing import TYPE_CHECKING
|
5 |
+
from jobs.process import BaseExtensionProcess
|
6 |
+
from toolkit.config_modules import ModelConfig
|
7 |
+
from toolkit.stable_diffusion_model import StableDiffusion
|
8 |
+
from toolkit.train_tools import get_torch_dtype
|
9 |
+
from tqdm import tqdm
|
10 |
+
|
11 |
+
# Type check imports. Prevents circular imports
|
12 |
+
if TYPE_CHECKING:
|
13 |
+
from jobs import ExtensionJob
|
14 |
+
|
15 |
+
|
16 |
+
# extend standard config classes to add weight
|
17 |
+
class ModelInputConfig(ModelConfig):
|
18 |
+
def __init__(self, **kwargs):
|
19 |
+
super().__init__(**kwargs)
|
20 |
+
self.weight = kwargs.get('weight', 1.0)
|
21 |
+
# overwrite default dtype unless user specifies otherwise
|
22 |
+
# float 32 will give up better precision on the merging functions
|
23 |
+
self.dtype: str = kwargs.get('dtype', 'float32')
|
24 |
+
|
25 |
+
|
26 |
+
def flush():
|
27 |
+
torch.cuda.empty_cache()
|
28 |
+
gc.collect()
|
29 |
+
|
30 |
+
|
31 |
+
# this is our main class process
|
32 |
+
class ExampleMergeModels(BaseExtensionProcess):
|
33 |
+
def __init__(
|
34 |
+
self,
|
35 |
+
process_id: int,
|
36 |
+
job: 'ExtensionJob',
|
37 |
+
config: OrderedDict
|
38 |
+
):
|
39 |
+
super().__init__(process_id, job, config)
|
40 |
+
# this is the setup process, do not do process intensive stuff here, just variable setup and
|
41 |
+
# checking requirements. This is called before the run() function
|
42 |
+
# no loading models or anything like that, it is just for setting up the process
|
43 |
+
# all of your process intensive stuff should be done in the run() function
|
44 |
+
# config will have everything from the process item in the config file
|
45 |
+
|
46 |
+
# convince methods exist on BaseProcess to get config values
|
47 |
+
# if required is set to true and the value is not found it will throw an error
|
48 |
+
# you can pass a default value to get_conf() as well if it was not in the config file
|
49 |
+
# as well as a type to cast the value to
|
50 |
+
self.save_path = self.get_conf('save_path', required=True)
|
51 |
+
self.save_dtype = self.get_conf('save_dtype', default='float16', as_type=get_torch_dtype)
|
52 |
+
self.device = self.get_conf('device', default='cpu', as_type=torch.device)
|
53 |
+
|
54 |
+
# build models to merge list
|
55 |
+
models_to_merge = self.get_conf('models_to_merge', required=True, as_type=list)
|
56 |
+
# build list of ModelInputConfig objects. I find it is a good idea to make a class for each config
|
57 |
+
# this way you can add methods to it and it is easier to read and code. There are a lot of
|
58 |
+
# inbuilt config classes located in toolkit.config_modules as well
|
59 |
+
self.models_to_merge = [ModelInputConfig(**model) for model in models_to_merge]
|
60 |
+
# setup is complete. Don't load anything else here, just setup variables and stuff
|
61 |
+
|
62 |
+
# this is the entire run process be sure to call super().run() first
|
63 |
+
def run(self):
|
64 |
+
# always call first
|
65 |
+
super().run()
|
66 |
+
print(f"Running process: {self.__class__.__name__}")
|
67 |
+
|
68 |
+
# let's adjust our weights first to normalize them so the total is 1.0
|
69 |
+
total_weight = sum([model.weight for model in self.models_to_merge])
|
70 |
+
weight_adjust = 1.0 / total_weight
|
71 |
+
for model in self.models_to_merge:
|
72 |
+
model.weight *= weight_adjust
|
73 |
+
|
74 |
+
output_model: StableDiffusion = None
|
75 |
+
# let's do the merge, it is a good idea to use tqdm to show progress
|
76 |
+
for model_config in tqdm(self.models_to_merge, desc="Merging models"):
|
77 |
+
# setup model class with our helper class
|
78 |
+
sd_model = StableDiffusion(
|
79 |
+
device=self.device,
|
80 |
+
model_config=model_config,
|
81 |
+
dtype="float32"
|
82 |
+
)
|
83 |
+
# load the model
|
84 |
+
sd_model.load_model()
|
85 |
+
|
86 |
+
# adjust the weight of the text encoder
|
87 |
+
if isinstance(sd_model.text_encoder, list):
|
88 |
+
# sdxl model
|
89 |
+
for text_encoder in sd_model.text_encoder:
|
90 |
+
for key, value in text_encoder.state_dict().items():
|
91 |
+
value *= model_config.weight
|
92 |
+
else:
|
93 |
+
# normal model
|
94 |
+
for key, value in sd_model.text_encoder.state_dict().items():
|
95 |
+
value *= model_config.weight
|
96 |
+
# adjust the weights of the unet
|
97 |
+
for key, value in sd_model.unet.state_dict().items():
|
98 |
+
value *= model_config.weight
|
99 |
+
|
100 |
+
if output_model is None:
|
101 |
+
# use this one as the base
|
102 |
+
output_model = sd_model
|
103 |
+
else:
|
104 |
+
# merge the models
|
105 |
+
# text encoder
|
106 |
+
if isinstance(output_model.text_encoder, list):
|
107 |
+
# sdxl model
|
108 |
+
for i, text_encoder in enumerate(output_model.text_encoder):
|
109 |
+
for key, value in text_encoder.state_dict().items():
|
110 |
+
value += sd_model.text_encoder[i].state_dict()[key]
|
111 |
+
else:
|
112 |
+
# normal model
|
113 |
+
for key, value in output_model.text_encoder.state_dict().items():
|
114 |
+
value += sd_model.text_encoder.state_dict()[key]
|
115 |
+
# unet
|
116 |
+
for key, value in output_model.unet.state_dict().items():
|
117 |
+
value += sd_model.unet.state_dict()[key]
|
118 |
+
|
119 |
+
# remove the model to free memory
|
120 |
+
del sd_model
|
121 |
+
flush()
|
122 |
+
|
123 |
+
# merge loop is done, let's save the model
|
124 |
+
print(f"Saving merged model to {self.save_path}")
|
125 |
+
output_model.save(self.save_path, meta=self.meta, save_dtype=self.save_dtype)
|
126 |
+
print(f"Saved merged model to {self.save_path}")
|
127 |
+
# do cleanup here
|
128 |
+
del output_model
|
129 |
+
flush()
|
extensions/example/__init__.py
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# This is an example extension for custom training. It is great for experimenting with new ideas.
|
2 |
+
from toolkit.extension import Extension
|
3 |
+
|
4 |
+
|
5 |
+
# We make a subclass of Extension
|
6 |
+
class ExampleMergeExtension(Extension):
|
7 |
+
# uid must be unique, it is how the extension is identified
|
8 |
+
uid = "example_merge_extension"
|
9 |
+
|
10 |
+
# name is the name of the extension for printing
|
11 |
+
name = "Example Merge Extension"
|
12 |
+
|
13 |
+
# This is where your process class is loaded
|
14 |
+
# keep your imports in here so they don't slow down the rest of the program
|
15 |
+
@classmethod
|
16 |
+
def get_process(cls):
|
17 |
+
# import your process class here so it is only loaded when needed and return it
|
18 |
+
from .ExampleMergeModels import ExampleMergeModels
|
19 |
+
return ExampleMergeModels
|
20 |
+
|
21 |
+
|
22 |
+
AI_TOOLKIT_EXTENSIONS = [
|
23 |
+
# you can put a list of extensions here
|
24 |
+
ExampleMergeExtension
|
25 |
+
]
|
extensions/example/config/config.example.yaml
ADDED
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
# Always include at least one example config file to show how to use your extension.
|
3 |
+
# use plenty of comments so users know how to use it and what everything does
|
4 |
+
|
5 |
+
# all extensions will use this job name
|
6 |
+
job: extension
|
7 |
+
config:
|
8 |
+
name: 'my_awesome_merge'
|
9 |
+
process:
|
10 |
+
# Put your example processes here. This will be passed
|
11 |
+
# to your extension process in the config argument.
|
12 |
+
# the type MUST match your extension uid
|
13 |
+
- type: "example_merge_extension"
|
14 |
+
# save path for the merged model
|
15 |
+
save_path: "output/merge/[name].safetensors"
|
16 |
+
# save type
|
17 |
+
dtype: fp16
|
18 |
+
# device to run it on
|
19 |
+
device: cuda:0
|
20 |
+
# input models can only be SD1.x and SD2.x models for this example (currently)
|
21 |
+
models_to_merge:
|
22 |
+
# weights are relative, total weights will be normalized
|
23 |
+
# for example. If you have 2 models with weight 1.0, they will
|
24 |
+
# both be weighted 0.5. If you have 1 model with weight 1.0 and
|
25 |
+
# another with weight 2.0, the first will be weighted 1/3 and the
|
26 |
+
# second will be weighted 2/3
|
27 |
+
- name_or_path: "input/model1.safetensors"
|
28 |
+
weight: 1.0
|
29 |
+
- name_or_path: "input/model2.safetensors"
|
30 |
+
weight: 1.0
|
31 |
+
- name_or_path: "input/model3.safetensors"
|
32 |
+
weight: 0.3
|
33 |
+
- name_or_path: "input/model4.safetensors"
|
34 |
+
weight: 1.0
|
35 |
+
|
36 |
+
|
37 |
+
# you can put any information you want here, and it will be saved in the model
|
38 |
+
# the below is an example. I recommend doing trigger words at a minimum
|
39 |
+
# in the metadata. The software will include this plus some other information
|
40 |
+
meta:
|
41 |
+
name: "[name]" # [name] gets replaced with the name above
|
42 |
+
description: A short description of your model
|
43 |
+
version: '0.1'
|
44 |
+
creator:
|
45 |
+
name: Your Name
|
46 |
+
email: [email protected]
|
47 |
+
website: https://yourwebsite.com
|
48 |
+
any: All meta data above is arbitrary, it can be whatever you want.
|
extensions_built_in/.DS_Store
ADDED
Binary file (8.2 kB). View file
|
|
extensions_built_in/advanced_generator/Img2ImgGenerator.py
ADDED
@@ -0,0 +1,256 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import os
|
3 |
+
import random
|
4 |
+
from collections import OrderedDict
|
5 |
+
from typing import List
|
6 |
+
|
7 |
+
import numpy as np
|
8 |
+
from PIL import Image
|
9 |
+
from diffusers import T2IAdapter
|
10 |
+
from diffusers.utils.torch_utils import randn_tensor
|
11 |
+
from torch.utils.data import DataLoader
|
12 |
+
from diffusers import StableDiffusionXLImg2ImgPipeline, PixArtSigmaPipeline
|
13 |
+
from tqdm import tqdm
|
14 |
+
|
15 |
+
from toolkit.config_modules import ModelConfig, GenerateImageConfig, preprocess_dataset_raw_config, DatasetConfig
|
16 |
+
from toolkit.data_transfer_object.data_loader import FileItemDTO, DataLoaderBatchDTO
|
17 |
+
from toolkit.sampler import get_sampler
|
18 |
+
from toolkit.stable_diffusion_model import StableDiffusion
|
19 |
+
import gc
|
20 |
+
import torch
|
21 |
+
from jobs.process import BaseExtensionProcess
|
22 |
+
from toolkit.data_loader import get_dataloader_from_datasets
|
23 |
+
from toolkit.train_tools import get_torch_dtype
|
24 |
+
from controlnet_aux.midas import MidasDetector
|
25 |
+
from diffusers.utils import load_image
|
26 |
+
from torchvision.transforms import ToTensor
|
27 |
+
|
28 |
+
|
29 |
+
def flush():
|
30 |
+
torch.cuda.empty_cache()
|
31 |
+
gc.collect()
|
32 |
+
|
33 |
+
|
34 |
+
|
35 |
+
|
36 |
+
|
37 |
+
class GenerateConfig:
|
38 |
+
|
39 |
+
def __init__(self, **kwargs):
|
40 |
+
self.prompts: List[str]
|
41 |
+
self.sampler = kwargs.get('sampler', 'ddpm')
|
42 |
+
self.neg = kwargs.get('neg', '')
|
43 |
+
self.seed = kwargs.get('seed', -1)
|
44 |
+
self.walk_seed = kwargs.get('walk_seed', False)
|
45 |
+
self.guidance_scale = kwargs.get('guidance_scale', 7)
|
46 |
+
self.sample_steps = kwargs.get('sample_steps', 20)
|
47 |
+
self.guidance_rescale = kwargs.get('guidance_rescale', 0.0)
|
48 |
+
self.ext = kwargs.get('ext', 'png')
|
49 |
+
self.denoise_strength = kwargs.get('denoise_strength', 0.5)
|
50 |
+
self.trigger_word = kwargs.get('trigger_word', None)
|
51 |
+
|
52 |
+
|
53 |
+
class Img2ImgGenerator(BaseExtensionProcess):
|
54 |
+
|
55 |
+
def __init__(self, process_id: int, job, config: OrderedDict):
|
56 |
+
super().__init__(process_id, job, config)
|
57 |
+
self.output_folder = self.get_conf('output_folder', required=True)
|
58 |
+
self.copy_inputs_to = self.get_conf('copy_inputs_to', None)
|
59 |
+
self.device = self.get_conf('device', 'cuda')
|
60 |
+
self.model_config = ModelConfig(**self.get_conf('model', required=True))
|
61 |
+
self.generate_config = GenerateConfig(**self.get_conf('generate', required=True))
|
62 |
+
self.is_latents_cached = True
|
63 |
+
raw_datasets = self.get_conf('datasets', None)
|
64 |
+
if raw_datasets is not None and len(raw_datasets) > 0:
|
65 |
+
raw_datasets = preprocess_dataset_raw_config(raw_datasets)
|
66 |
+
self.datasets = None
|
67 |
+
self.datasets_reg = None
|
68 |
+
self.dtype = self.get_conf('dtype', 'float16')
|
69 |
+
self.torch_dtype = get_torch_dtype(self.dtype)
|
70 |
+
self.params = []
|
71 |
+
if raw_datasets is not None and len(raw_datasets) > 0:
|
72 |
+
for raw_dataset in raw_datasets:
|
73 |
+
dataset = DatasetConfig(**raw_dataset)
|
74 |
+
is_caching = dataset.cache_latents or dataset.cache_latents_to_disk
|
75 |
+
if not is_caching:
|
76 |
+
self.is_latents_cached = False
|
77 |
+
if dataset.is_reg:
|
78 |
+
if self.datasets_reg is None:
|
79 |
+
self.datasets_reg = []
|
80 |
+
self.datasets_reg.append(dataset)
|
81 |
+
else:
|
82 |
+
if self.datasets is None:
|
83 |
+
self.datasets = []
|
84 |
+
self.datasets.append(dataset)
|
85 |
+
|
86 |
+
self.progress_bar = None
|
87 |
+
self.sd = StableDiffusion(
|
88 |
+
device=self.device,
|
89 |
+
model_config=self.model_config,
|
90 |
+
dtype=self.dtype,
|
91 |
+
)
|
92 |
+
print(f"Using device {self.device}")
|
93 |
+
self.data_loader: DataLoader = None
|
94 |
+
self.adapter: T2IAdapter = None
|
95 |
+
|
96 |
+
def to_pil(self, img):
|
97 |
+
# image comes in -1 to 1. convert to a PIL RGB image
|
98 |
+
img = (img + 1) / 2
|
99 |
+
img = img.clamp(0, 1)
|
100 |
+
img = img[0].permute(1, 2, 0).cpu().numpy()
|
101 |
+
img = (img * 255).astype(np.uint8)
|
102 |
+
image = Image.fromarray(img)
|
103 |
+
return image
|
104 |
+
|
105 |
+
def run(self):
|
106 |
+
with torch.no_grad():
|
107 |
+
super().run()
|
108 |
+
print("Loading model...")
|
109 |
+
self.sd.load_model()
|
110 |
+
device = torch.device(self.device)
|
111 |
+
|
112 |
+
if self.model_config.is_xl:
|
113 |
+
pipe = StableDiffusionXLImg2ImgPipeline(
|
114 |
+
vae=self.sd.vae,
|
115 |
+
unet=self.sd.unet,
|
116 |
+
text_encoder=self.sd.text_encoder[0],
|
117 |
+
text_encoder_2=self.sd.text_encoder[1],
|
118 |
+
tokenizer=self.sd.tokenizer[0],
|
119 |
+
tokenizer_2=self.sd.tokenizer[1],
|
120 |
+
scheduler=get_sampler(self.generate_config.sampler),
|
121 |
+
).to(device, dtype=self.torch_dtype)
|
122 |
+
elif self.model_config.is_pixart:
|
123 |
+
pipe = self.sd.pipeline.to(device, dtype=self.torch_dtype)
|
124 |
+
else:
|
125 |
+
raise NotImplementedError("Only XL models are supported")
|
126 |
+
pipe.set_progress_bar_config(disable=True)
|
127 |
+
|
128 |
+
# pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
|
129 |
+
# midas_depth = torch.compile(midas_depth, mode="reduce-overhead", fullgraph=True)
|
130 |
+
|
131 |
+
self.data_loader = get_dataloader_from_datasets(self.datasets, 1, self.sd)
|
132 |
+
|
133 |
+
num_batches = len(self.data_loader)
|
134 |
+
pbar = tqdm(total=num_batches, desc="Generating images")
|
135 |
+
seed = self.generate_config.seed
|
136 |
+
# load images from datasets, use tqdm
|
137 |
+
for i, batch in enumerate(self.data_loader):
|
138 |
+
batch: DataLoaderBatchDTO = batch
|
139 |
+
|
140 |
+
gen_seed = seed if seed > 0 else random.randint(0, 2 ** 32 - 1)
|
141 |
+
generator = torch.manual_seed(gen_seed)
|
142 |
+
|
143 |
+
file_item: FileItemDTO = batch.file_items[0]
|
144 |
+
img_path = file_item.path
|
145 |
+
img_filename = os.path.basename(img_path)
|
146 |
+
img_filename_no_ext = os.path.splitext(img_filename)[0]
|
147 |
+
img_filename = img_filename_no_ext + '.' + self.generate_config.ext
|
148 |
+
output_path = os.path.join(self.output_folder, img_filename)
|
149 |
+
output_caption_path = os.path.join(self.output_folder, img_filename_no_ext + '.txt')
|
150 |
+
|
151 |
+
if self.copy_inputs_to is not None:
|
152 |
+
output_inputs_path = os.path.join(self.copy_inputs_to, img_filename)
|
153 |
+
output_inputs_caption_path = os.path.join(self.copy_inputs_to, img_filename_no_ext + '.txt')
|
154 |
+
else:
|
155 |
+
output_inputs_path = None
|
156 |
+
output_inputs_caption_path = None
|
157 |
+
|
158 |
+
caption = batch.get_caption_list()[0]
|
159 |
+
if self.generate_config.trigger_word is not None:
|
160 |
+
caption = caption.replace('[trigger]', self.generate_config.trigger_word)
|
161 |
+
|
162 |
+
img: torch.Tensor = batch.tensor.clone()
|
163 |
+
image = self.to_pil(img)
|
164 |
+
|
165 |
+
# image.save(output_depth_path)
|
166 |
+
if self.model_config.is_pixart:
|
167 |
+
pipe: PixArtSigmaPipeline = pipe
|
168 |
+
|
169 |
+
# Encode the full image once
|
170 |
+
encoded_image = pipe.vae.encode(
|
171 |
+
pipe.image_processor.preprocess(image).to(device=pipe.device, dtype=pipe.dtype))
|
172 |
+
if hasattr(encoded_image, "latent_dist"):
|
173 |
+
latents = encoded_image.latent_dist.sample(generator)
|
174 |
+
elif hasattr(encoded_image, "latents"):
|
175 |
+
latents = encoded_image.latents
|
176 |
+
else:
|
177 |
+
raise AttributeError("Could not access latents of provided encoder_output")
|
178 |
+
latents = pipe.vae.config.scaling_factor * latents
|
179 |
+
|
180 |
+
# latents = self.sd.encode_images(img)
|
181 |
+
|
182 |
+
# self.sd.noise_scheduler.set_timesteps(self.generate_config.sample_steps)
|
183 |
+
# start_step = math.floor(self.generate_config.sample_steps * self.generate_config.denoise_strength)
|
184 |
+
# timestep = self.sd.noise_scheduler.timesteps[start_step].unsqueeze(0)
|
185 |
+
# timestep = timestep.to(device, dtype=torch.int32)
|
186 |
+
# latent = latent.to(device, dtype=self.torch_dtype)
|
187 |
+
# noise = torch.randn_like(latent, device=device, dtype=self.torch_dtype)
|
188 |
+
# latent = self.sd.add_noise(latent, noise, timestep)
|
189 |
+
# timesteps_to_use = self.sd.noise_scheduler.timesteps[start_step + 1:]
|
190 |
+
batch_size = 1
|
191 |
+
num_images_per_prompt = 1
|
192 |
+
|
193 |
+
shape = (batch_size, pipe.transformer.config.in_channels, image.height // pipe.vae_scale_factor,
|
194 |
+
image.width // pipe.vae_scale_factor)
|
195 |
+
noise = randn_tensor(shape, generator=generator, device=pipe.device, dtype=pipe.dtype)
|
196 |
+
|
197 |
+
# noise = torch.randn_like(latents, device=device, dtype=self.torch_dtype)
|
198 |
+
num_inference_steps = self.generate_config.sample_steps
|
199 |
+
strength = self.generate_config.denoise_strength
|
200 |
+
# Get timesteps
|
201 |
+
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
|
202 |
+
t_start = max(num_inference_steps - init_timestep, 0)
|
203 |
+
pipe.scheduler.set_timesteps(num_inference_steps, device="cpu")
|
204 |
+
timesteps = pipe.scheduler.timesteps[t_start:]
|
205 |
+
timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
|
206 |
+
latents = pipe.scheduler.add_noise(latents, noise, timestep)
|
207 |
+
|
208 |
+
gen_images = pipe.__call__(
|
209 |
+
prompt=caption,
|
210 |
+
negative_prompt=self.generate_config.neg,
|
211 |
+
latents=latents,
|
212 |
+
timesteps=timesteps,
|
213 |
+
width=image.width,
|
214 |
+
height=image.height,
|
215 |
+
num_inference_steps=num_inference_steps,
|
216 |
+
num_images_per_prompt=num_images_per_prompt,
|
217 |
+
guidance_scale=self.generate_config.guidance_scale,
|
218 |
+
# strength=self.generate_config.denoise_strength,
|
219 |
+
use_resolution_binning=False,
|
220 |
+
output_type="np"
|
221 |
+
).images[0]
|
222 |
+
gen_images = (gen_images * 255).clip(0, 255).astype(np.uint8)
|
223 |
+
gen_images = Image.fromarray(gen_images)
|
224 |
+
else:
|
225 |
+
pipe: StableDiffusionXLImg2ImgPipeline = pipe
|
226 |
+
|
227 |
+
gen_images = pipe.__call__(
|
228 |
+
prompt=caption,
|
229 |
+
negative_prompt=self.generate_config.neg,
|
230 |
+
image=image,
|
231 |
+
num_inference_steps=self.generate_config.sample_steps,
|
232 |
+
guidance_scale=self.generate_config.guidance_scale,
|
233 |
+
strength=self.generate_config.denoise_strength,
|
234 |
+
).images[0]
|
235 |
+
os.makedirs(os.path.dirname(output_path), exist_ok=True)
|
236 |
+
gen_images.save(output_path)
|
237 |
+
|
238 |
+
# save caption
|
239 |
+
with open(output_caption_path, 'w') as f:
|
240 |
+
f.write(caption)
|
241 |
+
|
242 |
+
if output_inputs_path is not None:
|
243 |
+
os.makedirs(os.path.dirname(output_inputs_path), exist_ok=True)
|
244 |
+
image.save(output_inputs_path)
|
245 |
+
with open(output_inputs_caption_path, 'w') as f:
|
246 |
+
f.write(caption)
|
247 |
+
|
248 |
+
pbar.update(1)
|
249 |
+
batch.cleanup()
|
250 |
+
|
251 |
+
pbar.close()
|
252 |
+
print("Done generating images")
|
253 |
+
# cleanup
|
254 |
+
del self.sd
|
255 |
+
gc.collect()
|
256 |
+
torch.cuda.empty_cache()
|
extensions_built_in/advanced_generator/PureLoraGenerator.py
ADDED
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from collections import OrderedDict
|
3 |
+
|
4 |
+
from toolkit.config_modules import ModelConfig, GenerateImageConfig, SampleConfig, LoRMConfig
|
5 |
+
from toolkit.lorm import ExtractMode, convert_diffusers_unet_to_lorm
|
6 |
+
from toolkit.sd_device_states_presets import get_train_sd_device_state_preset
|
7 |
+
from toolkit.stable_diffusion_model import StableDiffusion
|
8 |
+
import gc
|
9 |
+
import torch
|
10 |
+
from jobs.process import BaseExtensionProcess
|
11 |
+
from toolkit.train_tools import get_torch_dtype
|
12 |
+
|
13 |
+
|
14 |
+
def flush():
|
15 |
+
torch.cuda.empty_cache()
|
16 |
+
gc.collect()
|
17 |
+
|
18 |
+
|
19 |
+
class PureLoraGenerator(BaseExtensionProcess):
|
20 |
+
|
21 |
+
def __init__(self, process_id: int, job, config: OrderedDict):
|
22 |
+
super().__init__(process_id, job, config)
|
23 |
+
self.output_folder = self.get_conf('output_folder', required=True)
|
24 |
+
self.device = self.get_conf('device', 'cuda')
|
25 |
+
self.device_torch = torch.device(self.device)
|
26 |
+
self.model_config = ModelConfig(**self.get_conf('model', required=True))
|
27 |
+
self.generate_config = SampleConfig(**self.get_conf('sample', required=True))
|
28 |
+
self.dtype = self.get_conf('dtype', 'float16')
|
29 |
+
self.torch_dtype = get_torch_dtype(self.dtype)
|
30 |
+
lorm_config = self.get_conf('lorm', None)
|
31 |
+
self.lorm_config = LoRMConfig(**lorm_config) if lorm_config is not None else None
|
32 |
+
|
33 |
+
self.device_state_preset = get_train_sd_device_state_preset(
|
34 |
+
device=torch.device(self.device),
|
35 |
+
)
|
36 |
+
|
37 |
+
self.progress_bar = None
|
38 |
+
self.sd = StableDiffusion(
|
39 |
+
device=self.device,
|
40 |
+
model_config=self.model_config,
|
41 |
+
dtype=self.dtype,
|
42 |
+
)
|
43 |
+
|
44 |
+
def run(self):
|
45 |
+
super().run()
|
46 |
+
print("Loading model...")
|
47 |
+
with torch.no_grad():
|
48 |
+
self.sd.load_model()
|
49 |
+
self.sd.unet.eval()
|
50 |
+
self.sd.unet.to(self.device_torch)
|
51 |
+
if isinstance(self.sd.text_encoder, list):
|
52 |
+
for te in self.sd.text_encoder:
|
53 |
+
te.eval()
|
54 |
+
te.to(self.device_torch)
|
55 |
+
else:
|
56 |
+
self.sd.text_encoder.eval()
|
57 |
+
self.sd.to(self.device_torch)
|
58 |
+
|
59 |
+
print(f"Converting to LoRM UNet")
|
60 |
+
# replace the unet with LoRMUnet
|
61 |
+
convert_diffusers_unet_to_lorm(
|
62 |
+
self.sd.unet,
|
63 |
+
config=self.lorm_config,
|
64 |
+
)
|
65 |
+
|
66 |
+
sample_folder = os.path.join(self.output_folder)
|
67 |
+
gen_img_config_list = []
|
68 |
+
|
69 |
+
sample_config = self.generate_config
|
70 |
+
start_seed = sample_config.seed
|
71 |
+
current_seed = start_seed
|
72 |
+
for i in range(len(sample_config.prompts)):
|
73 |
+
if sample_config.walk_seed:
|
74 |
+
current_seed = start_seed + i
|
75 |
+
|
76 |
+
filename = f"[time]_[count].{self.generate_config.ext}"
|
77 |
+
output_path = os.path.join(sample_folder, filename)
|
78 |
+
prompt = sample_config.prompts[i]
|
79 |
+
extra_args = {}
|
80 |
+
gen_img_config_list.append(GenerateImageConfig(
|
81 |
+
prompt=prompt, # it will autoparse the prompt
|
82 |
+
width=sample_config.width,
|
83 |
+
height=sample_config.height,
|
84 |
+
negative_prompt=sample_config.neg,
|
85 |
+
seed=current_seed,
|
86 |
+
guidance_scale=sample_config.guidance_scale,
|
87 |
+
guidance_rescale=sample_config.guidance_rescale,
|
88 |
+
num_inference_steps=sample_config.sample_steps,
|
89 |
+
network_multiplier=sample_config.network_multiplier,
|
90 |
+
output_path=output_path,
|
91 |
+
output_ext=sample_config.ext,
|
92 |
+
adapter_conditioning_scale=sample_config.adapter_conditioning_scale,
|
93 |
+
**extra_args
|
94 |
+
))
|
95 |
+
|
96 |
+
# send to be generated
|
97 |
+
self.sd.generate_images(gen_img_config_list, sampler=sample_config.sampler)
|
98 |
+
print("Done generating images")
|
99 |
+
# cleanup
|
100 |
+
del self.sd
|
101 |
+
gc.collect()
|
102 |
+
torch.cuda.empty_cache()
|
extensions_built_in/advanced_generator/ReferenceGenerator.py
ADDED
@@ -0,0 +1,212 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import random
|
3 |
+
from collections import OrderedDict
|
4 |
+
from typing import List
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
from PIL import Image
|
8 |
+
from diffusers import T2IAdapter
|
9 |
+
from torch.utils.data import DataLoader
|
10 |
+
from diffusers import StableDiffusionXLAdapterPipeline, StableDiffusionAdapterPipeline
|
11 |
+
from tqdm import tqdm
|
12 |
+
|
13 |
+
from toolkit.config_modules import ModelConfig, GenerateImageConfig, preprocess_dataset_raw_config, DatasetConfig
|
14 |
+
from toolkit.data_transfer_object.data_loader import FileItemDTO, DataLoaderBatchDTO
|
15 |
+
from toolkit.sampler import get_sampler
|
16 |
+
from toolkit.stable_diffusion_model import StableDiffusion
|
17 |
+
import gc
|
18 |
+
import torch
|
19 |
+
from jobs.process import BaseExtensionProcess
|
20 |
+
from toolkit.data_loader import get_dataloader_from_datasets
|
21 |
+
from toolkit.train_tools import get_torch_dtype
|
22 |
+
from controlnet_aux.midas import MidasDetector
|
23 |
+
from diffusers.utils import load_image
|
24 |
+
|
25 |
+
|
26 |
+
def flush():
|
27 |
+
torch.cuda.empty_cache()
|
28 |
+
gc.collect()
|
29 |
+
|
30 |
+
|
31 |
+
class GenerateConfig:
|
32 |
+
|
33 |
+
def __init__(self, **kwargs):
|
34 |
+
self.prompts: List[str]
|
35 |
+
self.sampler = kwargs.get('sampler', 'ddpm')
|
36 |
+
self.neg = kwargs.get('neg', '')
|
37 |
+
self.seed = kwargs.get('seed', -1)
|
38 |
+
self.walk_seed = kwargs.get('walk_seed', False)
|
39 |
+
self.t2i_adapter_path = kwargs.get('t2i_adapter_path', None)
|
40 |
+
self.guidance_scale = kwargs.get('guidance_scale', 7)
|
41 |
+
self.sample_steps = kwargs.get('sample_steps', 20)
|
42 |
+
self.prompt_2 = kwargs.get('prompt_2', None)
|
43 |
+
self.neg_2 = kwargs.get('neg_2', None)
|
44 |
+
self.prompts = kwargs.get('prompts', None)
|
45 |
+
self.guidance_rescale = kwargs.get('guidance_rescale', 0.0)
|
46 |
+
self.ext = kwargs.get('ext', 'png')
|
47 |
+
self.adapter_conditioning_scale = kwargs.get('adapter_conditioning_scale', 1.0)
|
48 |
+
if kwargs.get('shuffle', False):
|
49 |
+
# shuffle the prompts
|
50 |
+
random.shuffle(self.prompts)
|
51 |
+
|
52 |
+
|
53 |
+
class ReferenceGenerator(BaseExtensionProcess):
|
54 |
+
|
55 |
+
def __init__(self, process_id: int, job, config: OrderedDict):
|
56 |
+
super().__init__(process_id, job, config)
|
57 |
+
self.output_folder = self.get_conf('output_folder', required=True)
|
58 |
+
self.device = self.get_conf('device', 'cuda')
|
59 |
+
self.model_config = ModelConfig(**self.get_conf('model', required=True))
|
60 |
+
self.generate_config = GenerateConfig(**self.get_conf('generate', required=True))
|
61 |
+
self.is_latents_cached = True
|
62 |
+
raw_datasets = self.get_conf('datasets', None)
|
63 |
+
if raw_datasets is not None and len(raw_datasets) > 0:
|
64 |
+
raw_datasets = preprocess_dataset_raw_config(raw_datasets)
|
65 |
+
self.datasets = None
|
66 |
+
self.datasets_reg = None
|
67 |
+
self.dtype = self.get_conf('dtype', 'float16')
|
68 |
+
self.torch_dtype = get_torch_dtype(self.dtype)
|
69 |
+
self.params = []
|
70 |
+
if raw_datasets is not None and len(raw_datasets) > 0:
|
71 |
+
for raw_dataset in raw_datasets:
|
72 |
+
dataset = DatasetConfig(**raw_dataset)
|
73 |
+
is_caching = dataset.cache_latents or dataset.cache_latents_to_disk
|
74 |
+
if not is_caching:
|
75 |
+
self.is_latents_cached = False
|
76 |
+
if dataset.is_reg:
|
77 |
+
if self.datasets_reg is None:
|
78 |
+
self.datasets_reg = []
|
79 |
+
self.datasets_reg.append(dataset)
|
80 |
+
else:
|
81 |
+
if self.datasets is None:
|
82 |
+
self.datasets = []
|
83 |
+
self.datasets.append(dataset)
|
84 |
+
|
85 |
+
self.progress_bar = None
|
86 |
+
self.sd = StableDiffusion(
|
87 |
+
device=self.device,
|
88 |
+
model_config=self.model_config,
|
89 |
+
dtype=self.dtype,
|
90 |
+
)
|
91 |
+
print(f"Using device {self.device}")
|
92 |
+
self.data_loader: DataLoader = None
|
93 |
+
self.adapter: T2IAdapter = None
|
94 |
+
|
95 |
+
def run(self):
|
96 |
+
super().run()
|
97 |
+
print("Loading model...")
|
98 |
+
self.sd.load_model()
|
99 |
+
device = torch.device(self.device)
|
100 |
+
|
101 |
+
if self.generate_config.t2i_adapter_path is not None:
|
102 |
+
self.adapter = T2IAdapter.from_pretrained(
|
103 |
+
self.generate_config.t2i_adapter_path,
|
104 |
+
torch_dtype=self.torch_dtype,
|
105 |
+
varient="fp16"
|
106 |
+
).to(device)
|
107 |
+
|
108 |
+
midas_depth = MidasDetector.from_pretrained(
|
109 |
+
"valhalla/t2iadapter-aux-models", filename="dpt_large_384.pt", model_type="dpt_large"
|
110 |
+
).to(device)
|
111 |
+
|
112 |
+
if self.model_config.is_xl:
|
113 |
+
pipe = StableDiffusionXLAdapterPipeline(
|
114 |
+
vae=self.sd.vae,
|
115 |
+
unet=self.sd.unet,
|
116 |
+
text_encoder=self.sd.text_encoder[0],
|
117 |
+
text_encoder_2=self.sd.text_encoder[1],
|
118 |
+
tokenizer=self.sd.tokenizer[0],
|
119 |
+
tokenizer_2=self.sd.tokenizer[1],
|
120 |
+
scheduler=get_sampler(self.generate_config.sampler),
|
121 |
+
adapter=self.adapter,
|
122 |
+
).to(device, dtype=self.torch_dtype)
|
123 |
+
else:
|
124 |
+
pipe = StableDiffusionAdapterPipeline(
|
125 |
+
vae=self.sd.vae,
|
126 |
+
unet=self.sd.unet,
|
127 |
+
text_encoder=self.sd.text_encoder,
|
128 |
+
tokenizer=self.sd.tokenizer,
|
129 |
+
scheduler=get_sampler(self.generate_config.sampler),
|
130 |
+
safety_checker=None,
|
131 |
+
feature_extractor=None,
|
132 |
+
requires_safety_checker=False,
|
133 |
+
adapter=self.adapter,
|
134 |
+
).to(device, dtype=self.torch_dtype)
|
135 |
+
pipe.set_progress_bar_config(disable=True)
|
136 |
+
|
137 |
+
pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
|
138 |
+
# midas_depth = torch.compile(midas_depth, mode="reduce-overhead", fullgraph=True)
|
139 |
+
|
140 |
+
self.data_loader = get_dataloader_from_datasets(self.datasets, 1, self.sd)
|
141 |
+
|
142 |
+
num_batches = len(self.data_loader)
|
143 |
+
pbar = tqdm(total=num_batches, desc="Generating images")
|
144 |
+
seed = self.generate_config.seed
|
145 |
+
# load images from datasets, use tqdm
|
146 |
+
for i, batch in enumerate(self.data_loader):
|
147 |
+
batch: DataLoaderBatchDTO = batch
|
148 |
+
|
149 |
+
file_item: FileItemDTO = batch.file_items[0]
|
150 |
+
img_path = file_item.path
|
151 |
+
img_filename = os.path.basename(img_path)
|
152 |
+
img_filename_no_ext = os.path.splitext(img_filename)[0]
|
153 |
+
output_path = os.path.join(self.output_folder, img_filename)
|
154 |
+
output_caption_path = os.path.join(self.output_folder, img_filename_no_ext + '.txt')
|
155 |
+
output_depth_path = os.path.join(self.output_folder, img_filename_no_ext + '.depth.png')
|
156 |
+
|
157 |
+
caption = batch.get_caption_list()[0]
|
158 |
+
|
159 |
+
img: torch.Tensor = batch.tensor.clone()
|
160 |
+
# image comes in -1 to 1. convert to a PIL RGB image
|
161 |
+
img = (img + 1) / 2
|
162 |
+
img = img.clamp(0, 1)
|
163 |
+
img = img[0].permute(1, 2, 0).cpu().numpy()
|
164 |
+
img = (img * 255).astype(np.uint8)
|
165 |
+
image = Image.fromarray(img)
|
166 |
+
|
167 |
+
width, height = image.size
|
168 |
+
min_res = min(width, height)
|
169 |
+
|
170 |
+
if self.generate_config.walk_seed:
|
171 |
+
seed = seed + 1
|
172 |
+
|
173 |
+
if self.generate_config.seed == -1:
|
174 |
+
# random
|
175 |
+
seed = random.randint(0, 1000000)
|
176 |
+
|
177 |
+
torch.manual_seed(seed)
|
178 |
+
torch.cuda.manual_seed(seed)
|
179 |
+
|
180 |
+
# generate depth map
|
181 |
+
image = midas_depth(
|
182 |
+
image,
|
183 |
+
detect_resolution=min_res, # do 512 ?
|
184 |
+
image_resolution=min_res
|
185 |
+
)
|
186 |
+
|
187 |
+
# image.save(output_depth_path)
|
188 |
+
|
189 |
+
gen_images = pipe(
|
190 |
+
prompt=caption,
|
191 |
+
negative_prompt=self.generate_config.neg,
|
192 |
+
image=image,
|
193 |
+
num_inference_steps=self.generate_config.sample_steps,
|
194 |
+
adapter_conditioning_scale=self.generate_config.adapter_conditioning_scale,
|
195 |
+
guidance_scale=self.generate_config.guidance_scale,
|
196 |
+
).images[0]
|
197 |
+
os.makedirs(os.path.dirname(output_path), exist_ok=True)
|
198 |
+
gen_images.save(output_path)
|
199 |
+
|
200 |
+
# save caption
|
201 |
+
with open(output_caption_path, 'w') as f:
|
202 |
+
f.write(caption)
|
203 |
+
|
204 |
+
pbar.update(1)
|
205 |
+
batch.cleanup()
|
206 |
+
|
207 |
+
pbar.close()
|
208 |
+
print("Done generating images")
|
209 |
+
# cleanup
|
210 |
+
del self.sd
|
211 |
+
gc.collect()
|
212 |
+
torch.cuda.empty_cache()
|
extensions_built_in/advanced_generator/__init__.py
ADDED
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# This is an example extension for custom training. It is great for experimenting with new ideas.
|
2 |
+
from toolkit.extension import Extension
|
3 |
+
|
4 |
+
|
5 |
+
# This is for generic training (LoRA, Dreambooth, FineTuning)
|
6 |
+
class AdvancedReferenceGeneratorExtension(Extension):
|
7 |
+
# uid must be unique, it is how the extension is identified
|
8 |
+
uid = "reference_generator"
|
9 |
+
|
10 |
+
# name is the name of the extension for printing
|
11 |
+
name = "Reference Generator"
|
12 |
+
|
13 |
+
# This is where your process class is loaded
|
14 |
+
# keep your imports in here so they don't slow down the rest of the program
|
15 |
+
@classmethod
|
16 |
+
def get_process(cls):
|
17 |
+
# import your process class here so it is only loaded when needed and return it
|
18 |
+
from .ReferenceGenerator import ReferenceGenerator
|
19 |
+
return ReferenceGenerator
|
20 |
+
|
21 |
+
|
22 |
+
# This is for generic training (LoRA, Dreambooth, FineTuning)
|
23 |
+
class PureLoraGenerator(Extension):
|
24 |
+
# uid must be unique, it is how the extension is identified
|
25 |
+
uid = "pure_lora_generator"
|
26 |
+
|
27 |
+
# name is the name of the extension for printing
|
28 |
+
name = "Pure LoRA Generator"
|
29 |
+
|
30 |
+
# This is where your process class is loaded
|
31 |
+
# keep your imports in here so they don't slow down the rest of the program
|
32 |
+
@classmethod
|
33 |
+
def get_process(cls):
|
34 |
+
# import your process class here so it is only loaded when needed and return it
|
35 |
+
from .PureLoraGenerator import PureLoraGenerator
|
36 |
+
return PureLoraGenerator
|
37 |
+
|
38 |
+
|
39 |
+
# This is for generic training (LoRA, Dreambooth, FineTuning)
|
40 |
+
class Img2ImgGeneratorExtension(Extension):
|
41 |
+
# uid must be unique, it is how the extension is identified
|
42 |
+
uid = "batch_img2img"
|
43 |
+
|
44 |
+
# name is the name of the extension for printing
|
45 |
+
name = "Img2ImgGeneratorExtension"
|
46 |
+
|
47 |
+
# This is where your process class is loaded
|
48 |
+
# keep your imports in here so they don't slow down the rest of the program
|
49 |
+
@classmethod
|
50 |
+
def get_process(cls):
|
51 |
+
# import your process class here so it is only loaded when needed and return it
|
52 |
+
from .Img2ImgGenerator import Img2ImgGenerator
|
53 |
+
return Img2ImgGenerator
|
54 |
+
|
55 |
+
|
56 |
+
AI_TOOLKIT_EXTENSIONS = [
|
57 |
+
# you can put a list of extensions here
|
58 |
+
AdvancedReferenceGeneratorExtension, PureLoraGenerator, Img2ImgGeneratorExtension
|
59 |
+
]
|
extensions_built_in/advanced_generator/config/train.example.yaml
ADDED
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
job: extension
|
3 |
+
config:
|
4 |
+
name: test_v1
|
5 |
+
process:
|
6 |
+
- type: 'textual_inversion_trainer'
|
7 |
+
training_folder: "out/TI"
|
8 |
+
device: cuda:0
|
9 |
+
# for tensorboard logging
|
10 |
+
log_dir: "out/.tensorboard"
|
11 |
+
embedding:
|
12 |
+
trigger: "your_trigger_here"
|
13 |
+
tokens: 12
|
14 |
+
init_words: "man with short brown hair"
|
15 |
+
save_format: "safetensors" # 'safetensors' or 'pt'
|
16 |
+
save:
|
17 |
+
dtype: float16 # precision to save
|
18 |
+
save_every: 100 # save every this many steps
|
19 |
+
max_step_saves_to_keep: 5 # only affects step counts
|
20 |
+
datasets:
|
21 |
+
- folder_path: "/path/to/dataset"
|
22 |
+
caption_ext: "txt"
|
23 |
+
default_caption: "[trigger]"
|
24 |
+
buckets: true
|
25 |
+
resolution: 512
|
26 |
+
train:
|
27 |
+
noise_scheduler: "ddpm" # or "ddpm", "lms", "euler_a"
|
28 |
+
steps: 3000
|
29 |
+
weight_jitter: 0.0
|
30 |
+
lr: 5e-5
|
31 |
+
train_unet: false
|
32 |
+
gradient_checkpointing: true
|
33 |
+
train_text_encoder: false
|
34 |
+
optimizer: "adamw"
|
35 |
+
# optimizer: "prodigy"
|
36 |
+
optimizer_params:
|
37 |
+
weight_decay: 1e-2
|
38 |
+
lr_scheduler: "constant"
|
39 |
+
max_denoising_steps: 1000
|
40 |
+
batch_size: 4
|
41 |
+
dtype: bf16
|
42 |
+
xformers: true
|
43 |
+
min_snr_gamma: 5.0
|
44 |
+
# skip_first_sample: true
|
45 |
+
noise_offset: 0.0 # not needed for this
|
46 |
+
model:
|
47 |
+
# objective reality v2
|
48 |
+
name_or_path: "https://civitai.com/models/128453?modelVersionId=142465"
|
49 |
+
is_v2: false # for v2 models
|
50 |
+
is_xl: false # for SDXL models
|
51 |
+
is_v_pred: false # for v-prediction models (most v2 models)
|
52 |
+
sample:
|
53 |
+
sampler: "ddpm" # must match train.noise_scheduler
|
54 |
+
sample_every: 100 # sample every this many steps
|
55 |
+
width: 512
|
56 |
+
height: 512
|
57 |
+
prompts:
|
58 |
+
- "photo of [trigger] laughing"
|
59 |
+
- "photo of [trigger] smiling"
|
60 |
+
- "[trigger] close up"
|
61 |
+
- "dark scene [trigger] frozen"
|
62 |
+
- "[trigger] nighttime"
|
63 |
+
- "a painting of [trigger]"
|
64 |
+
- "a drawing of [trigger]"
|
65 |
+
- "a cartoon of [trigger]"
|
66 |
+
- "[trigger] pixar style"
|
67 |
+
- "[trigger] costume"
|
68 |
+
neg: ""
|
69 |
+
seed: 42
|
70 |
+
walk_seed: false
|
71 |
+
guidance_scale: 7
|
72 |
+
sample_steps: 20
|
73 |
+
network_multiplier: 1.0
|
74 |
+
|
75 |
+
logging:
|
76 |
+
log_every: 10 # log every this many steps
|
77 |
+
use_wandb: false # not supported yet
|
78 |
+
verbose: false
|
79 |
+
|
80 |
+
# You can put any information you want here, and it will be saved in the model.
|
81 |
+
# The below is an example, but you can put your grocery list in it if you want.
|
82 |
+
# It is saved in the model so be aware of that. The software will include this
|
83 |
+
# plus some other information for you automatically
|
84 |
+
meta:
|
85 |
+
# [name] gets replaced with the name above
|
86 |
+
name: "[name]"
|
87 |
+
# version: '1.0'
|
88 |
+
# creator:
|
89 |
+
# name: Your Name
|
90 |
+
# email: [email protected]
|
91 |
+
# website: https://your.website
|
extensions_built_in/concept_replacer/ConceptReplacer.py
ADDED
@@ -0,0 +1,151 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import random
|
2 |
+
from collections import OrderedDict
|
3 |
+
from torch.utils.data import DataLoader
|
4 |
+
from toolkit.prompt_utils import concat_prompt_embeds, split_prompt_embeds
|
5 |
+
from toolkit.stable_diffusion_model import StableDiffusion, BlankNetwork
|
6 |
+
from toolkit.train_tools import get_torch_dtype, apply_snr_weight
|
7 |
+
import gc
|
8 |
+
import torch
|
9 |
+
from jobs.process import BaseSDTrainProcess
|
10 |
+
|
11 |
+
|
12 |
+
def flush():
|
13 |
+
torch.cuda.empty_cache()
|
14 |
+
gc.collect()
|
15 |
+
|
16 |
+
|
17 |
+
class ConceptReplacementConfig:
|
18 |
+
def __init__(self, **kwargs):
|
19 |
+
self.concept: str = kwargs.get('concept', '')
|
20 |
+
self.replacement: str = kwargs.get('replacement', '')
|
21 |
+
|
22 |
+
|
23 |
+
class ConceptReplacer(BaseSDTrainProcess):
|
24 |
+
|
25 |
+
def __init__(self, process_id: int, job, config: OrderedDict, **kwargs):
|
26 |
+
super().__init__(process_id, job, config, **kwargs)
|
27 |
+
replacement_list = self.config.get('replacements', [])
|
28 |
+
self.replacement_list = [ConceptReplacementConfig(**x) for x in replacement_list]
|
29 |
+
|
30 |
+
def before_model_load(self):
|
31 |
+
pass
|
32 |
+
|
33 |
+
def hook_before_train_loop(self):
|
34 |
+
self.sd.vae.eval()
|
35 |
+
self.sd.vae.to(self.device_torch)
|
36 |
+
|
37 |
+
# textual inversion
|
38 |
+
if self.embedding is not None:
|
39 |
+
# set text encoder to train. Not sure if this is necessary but diffusers example did it
|
40 |
+
self.sd.text_encoder.train()
|
41 |
+
|
42 |
+
def hook_train_loop(self, batch):
|
43 |
+
with torch.no_grad():
|
44 |
+
dtype = get_torch_dtype(self.train_config.dtype)
|
45 |
+
noisy_latents, noise, timesteps, conditioned_prompts, imgs = self.process_general_training_batch(batch)
|
46 |
+
network_weight_list = batch.get_network_weight_list()
|
47 |
+
|
48 |
+
# have a blank network so we can wrap it in a context and set multipliers without checking every time
|
49 |
+
if self.network is not None:
|
50 |
+
network = self.network
|
51 |
+
else:
|
52 |
+
network = BlankNetwork()
|
53 |
+
|
54 |
+
batch_replacement_list = []
|
55 |
+
# get a random replacement for each prompt
|
56 |
+
for prompt in conditioned_prompts:
|
57 |
+
replacement = random.choice(self.replacement_list)
|
58 |
+
batch_replacement_list.append(replacement)
|
59 |
+
|
60 |
+
# build out prompts
|
61 |
+
concept_prompts = []
|
62 |
+
replacement_prompts = []
|
63 |
+
for idx, replacement in enumerate(batch_replacement_list):
|
64 |
+
prompt = conditioned_prompts[idx]
|
65 |
+
|
66 |
+
# insert shuffled concept at beginning and end of prompt
|
67 |
+
shuffled_concept = [x.strip() for x in replacement.concept.split(',')]
|
68 |
+
random.shuffle(shuffled_concept)
|
69 |
+
shuffled_concept = ', '.join(shuffled_concept)
|
70 |
+
concept_prompts.append(f"{shuffled_concept}, {prompt}, {shuffled_concept}")
|
71 |
+
|
72 |
+
# insert replacement at beginning and end of prompt
|
73 |
+
shuffled_replacement = [x.strip() for x in replacement.replacement.split(',')]
|
74 |
+
random.shuffle(shuffled_replacement)
|
75 |
+
shuffled_replacement = ', '.join(shuffled_replacement)
|
76 |
+
replacement_prompts.append(f"{shuffled_replacement}, {prompt}, {shuffled_replacement}")
|
77 |
+
|
78 |
+
# predict the replacement without network
|
79 |
+
conditional_embeds = self.sd.encode_prompt(replacement_prompts).to(self.device_torch, dtype=dtype)
|
80 |
+
|
81 |
+
replacement_pred = self.sd.predict_noise(
|
82 |
+
latents=noisy_latents.to(self.device_torch, dtype=dtype),
|
83 |
+
conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype),
|
84 |
+
timestep=timesteps,
|
85 |
+
guidance_scale=1.0,
|
86 |
+
)
|
87 |
+
|
88 |
+
del conditional_embeds
|
89 |
+
replacement_pred = replacement_pred.detach()
|
90 |
+
|
91 |
+
self.optimizer.zero_grad()
|
92 |
+
flush()
|
93 |
+
|
94 |
+
# text encoding
|
95 |
+
grad_on_text_encoder = False
|
96 |
+
if self.train_config.train_text_encoder:
|
97 |
+
grad_on_text_encoder = True
|
98 |
+
|
99 |
+
if self.embedding:
|
100 |
+
grad_on_text_encoder = True
|
101 |
+
|
102 |
+
# set the weights
|
103 |
+
network.multiplier = network_weight_list
|
104 |
+
|
105 |
+
# activate network if it exits
|
106 |
+
with network:
|
107 |
+
with torch.set_grad_enabled(grad_on_text_encoder):
|
108 |
+
# embed the prompts
|
109 |
+
conditional_embeds = self.sd.encode_prompt(concept_prompts).to(self.device_torch, dtype=dtype)
|
110 |
+
if not grad_on_text_encoder:
|
111 |
+
# detach the embeddings
|
112 |
+
conditional_embeds = conditional_embeds.detach()
|
113 |
+
self.optimizer.zero_grad()
|
114 |
+
flush()
|
115 |
+
|
116 |
+
noise_pred = self.sd.predict_noise(
|
117 |
+
latents=noisy_latents.to(self.device_torch, dtype=dtype),
|
118 |
+
conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype),
|
119 |
+
timestep=timesteps,
|
120 |
+
guidance_scale=1.0,
|
121 |
+
)
|
122 |
+
|
123 |
+
loss = torch.nn.functional.mse_loss(noise_pred.float(), replacement_pred.float(), reduction="none")
|
124 |
+
loss = loss.mean([1, 2, 3])
|
125 |
+
|
126 |
+
if self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001:
|
127 |
+
# add min_snr_gamma
|
128 |
+
loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.min_snr_gamma)
|
129 |
+
|
130 |
+
loss = loss.mean()
|
131 |
+
|
132 |
+
# back propagate loss to free ram
|
133 |
+
loss.backward()
|
134 |
+
flush()
|
135 |
+
|
136 |
+
# apply gradients
|
137 |
+
self.optimizer.step()
|
138 |
+
self.optimizer.zero_grad()
|
139 |
+
self.lr_scheduler.step()
|
140 |
+
|
141 |
+
if self.embedding is not None:
|
142 |
+
# Let's make sure we don't update any embedding weights besides the newly added token
|
143 |
+
self.embedding.restore_embeddings()
|
144 |
+
|
145 |
+
loss_dict = OrderedDict(
|
146 |
+
{'loss': loss.item()}
|
147 |
+
)
|
148 |
+
# reset network multiplier
|
149 |
+
network.multiplier = 1.0
|
150 |
+
|
151 |
+
return loss_dict
|
extensions_built_in/concept_replacer/__init__.py
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# This is an example extension for custom training. It is great for experimenting with new ideas.
|
2 |
+
from toolkit.extension import Extension
|
3 |
+
|
4 |
+
|
5 |
+
# This is for generic training (LoRA, Dreambooth, FineTuning)
|
6 |
+
class ConceptReplacerExtension(Extension):
|
7 |
+
# uid must be unique, it is how the extension is identified
|
8 |
+
uid = "concept_replacer"
|
9 |
+
|
10 |
+
# name is the name of the extension for printing
|
11 |
+
name = "Concept Replacer"
|
12 |
+
|
13 |
+
# This is where your process class is loaded
|
14 |
+
# keep your imports in here so they don't slow down the rest of the program
|
15 |
+
@classmethod
|
16 |
+
def get_process(cls):
|
17 |
+
# import your process class here so it is only loaded when needed and return it
|
18 |
+
from .ConceptReplacer import ConceptReplacer
|
19 |
+
return ConceptReplacer
|
20 |
+
|
21 |
+
|
22 |
+
|
23 |
+
AI_TOOLKIT_EXTENSIONS = [
|
24 |
+
# you can put a list of extensions here
|
25 |
+
ConceptReplacerExtension,
|
26 |
+
]
|
extensions_built_in/concept_replacer/config/train.example.yaml
ADDED
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
job: extension
|
3 |
+
config:
|
4 |
+
name: test_v1
|
5 |
+
process:
|
6 |
+
- type: 'textual_inversion_trainer'
|
7 |
+
training_folder: "out/TI"
|
8 |
+
device: cuda:0
|
9 |
+
# for tensorboard logging
|
10 |
+
log_dir: "out/.tensorboard"
|
11 |
+
embedding:
|
12 |
+
trigger: "your_trigger_here"
|
13 |
+
tokens: 12
|
14 |
+
init_words: "man with short brown hair"
|
15 |
+
save_format: "safetensors" # 'safetensors' or 'pt'
|
16 |
+
save:
|
17 |
+
dtype: float16 # precision to save
|
18 |
+
save_every: 100 # save every this many steps
|
19 |
+
max_step_saves_to_keep: 5 # only affects step counts
|
20 |
+
datasets:
|
21 |
+
- folder_path: "/path/to/dataset"
|
22 |
+
caption_ext: "txt"
|
23 |
+
default_caption: "[trigger]"
|
24 |
+
buckets: true
|
25 |
+
resolution: 512
|
26 |
+
train:
|
27 |
+
noise_scheduler: "ddpm" # or "ddpm", "lms", "euler_a"
|
28 |
+
steps: 3000
|
29 |
+
weight_jitter: 0.0
|
30 |
+
lr: 5e-5
|
31 |
+
train_unet: false
|
32 |
+
gradient_checkpointing: true
|
33 |
+
train_text_encoder: false
|
34 |
+
optimizer: "adamw"
|
35 |
+
# optimizer: "prodigy"
|
36 |
+
optimizer_params:
|
37 |
+
weight_decay: 1e-2
|
38 |
+
lr_scheduler: "constant"
|
39 |
+
max_denoising_steps: 1000
|
40 |
+
batch_size: 4
|
41 |
+
dtype: bf16
|
42 |
+
xformers: true
|
43 |
+
min_snr_gamma: 5.0
|
44 |
+
# skip_first_sample: true
|
45 |
+
noise_offset: 0.0 # not needed for this
|
46 |
+
model:
|
47 |
+
# objective reality v2
|
48 |
+
name_or_path: "https://civitai.com/models/128453?modelVersionId=142465"
|
49 |
+
is_v2: false # for v2 models
|
50 |
+
is_xl: false # for SDXL models
|
51 |
+
is_v_pred: false # for v-prediction models (most v2 models)
|
52 |
+
sample:
|
53 |
+
sampler: "ddpm" # must match train.noise_scheduler
|
54 |
+
sample_every: 100 # sample every this many steps
|
55 |
+
width: 512
|
56 |
+
height: 512
|
57 |
+
prompts:
|
58 |
+
- "photo of [trigger] laughing"
|
59 |
+
- "photo of [trigger] smiling"
|
60 |
+
- "[trigger] close up"
|
61 |
+
- "dark scene [trigger] frozen"
|
62 |
+
- "[trigger] nighttime"
|
63 |
+
- "a painting of [trigger]"
|
64 |
+
- "a drawing of [trigger]"
|
65 |
+
- "a cartoon of [trigger]"
|
66 |
+
- "[trigger] pixar style"
|
67 |
+
- "[trigger] costume"
|
68 |
+
neg: ""
|
69 |
+
seed: 42
|
70 |
+
walk_seed: false
|
71 |
+
guidance_scale: 7
|
72 |
+
sample_steps: 20
|
73 |
+
network_multiplier: 1.0
|
74 |
+
|
75 |
+
logging:
|
76 |
+
log_every: 10 # log every this many steps
|
77 |
+
use_wandb: false # not supported yet
|
78 |
+
verbose: false
|
79 |
+
|
80 |
+
# You can put any information you want here, and it will be saved in the model.
|
81 |
+
# The below is an example, but you can put your grocery list in it if you want.
|
82 |
+
# It is saved in the model so be aware of that. The software will include this
|
83 |
+
# plus some other information for you automatically
|
84 |
+
meta:
|
85 |
+
# [name] gets replaced with the name above
|
86 |
+
name: "[name]"
|
87 |
+
# version: '1.0'
|
88 |
+
# creator:
|
89 |
+
# name: Your Name
|
90 |
+
# email: [email protected]
|
91 |
+
# website: https://your.website
|
extensions_built_in/dataset_tools/DatasetTools.py
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from collections import OrderedDict
|
2 |
+
import gc
|
3 |
+
import torch
|
4 |
+
from jobs.process import BaseExtensionProcess
|
5 |
+
|
6 |
+
|
7 |
+
def flush():
|
8 |
+
torch.cuda.empty_cache()
|
9 |
+
gc.collect()
|
10 |
+
|
11 |
+
|
12 |
+
class DatasetTools(BaseExtensionProcess):
|
13 |
+
|
14 |
+
def __init__(self, process_id: int, job, config: OrderedDict):
|
15 |
+
super().__init__(process_id, job, config)
|
16 |
+
|
17 |
+
def run(self):
|
18 |
+
super().run()
|
19 |
+
|
20 |
+
raise NotImplementedError("This extension is not yet implemented")
|
extensions_built_in/dataset_tools/SuperTagger.py
ADDED
@@ -0,0 +1,196 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import copy
|
2 |
+
import json
|
3 |
+
import os
|
4 |
+
from collections import OrderedDict
|
5 |
+
import gc
|
6 |
+
import traceback
|
7 |
+
import torch
|
8 |
+
from PIL import Image, ImageOps
|
9 |
+
from tqdm import tqdm
|
10 |
+
|
11 |
+
from .tools.dataset_tools_config_modules import RAW_DIR, TRAIN_DIR, Step, ImgInfo
|
12 |
+
from .tools.fuyu_utils import FuyuImageProcessor
|
13 |
+
from .tools.image_tools import load_image, ImageProcessor, resize_to_max
|
14 |
+
from .tools.llava_utils import LLaVAImageProcessor
|
15 |
+
from .tools.caption import default_long_prompt, default_short_prompt, default_replacements
|
16 |
+
from jobs.process import BaseExtensionProcess
|
17 |
+
from .tools.sync_tools import get_img_paths
|
18 |
+
|
19 |
+
img_ext = ['.jpg', '.jpeg', '.png', '.webp']
|
20 |
+
|
21 |
+
|
22 |
+
def flush():
|
23 |
+
torch.cuda.empty_cache()
|
24 |
+
gc.collect()
|
25 |
+
|
26 |
+
|
27 |
+
VERSION = 2
|
28 |
+
|
29 |
+
|
30 |
+
class SuperTagger(BaseExtensionProcess):
|
31 |
+
|
32 |
+
def __init__(self, process_id: int, job, config: OrderedDict):
|
33 |
+
super().__init__(process_id, job, config)
|
34 |
+
parent_dir = config.get('parent_dir', None)
|
35 |
+
self.dataset_paths: list[str] = config.get('dataset_paths', [])
|
36 |
+
self.device = config.get('device', 'cuda')
|
37 |
+
self.steps: list[Step] = config.get('steps', [])
|
38 |
+
self.caption_method = config.get('caption_method', 'llava:default')
|
39 |
+
self.caption_prompt = config.get('caption_prompt', default_long_prompt)
|
40 |
+
self.caption_short_prompt = config.get('caption_short_prompt', default_short_prompt)
|
41 |
+
self.force_reprocess_img = config.get('force_reprocess_img', False)
|
42 |
+
self.caption_replacements = config.get('caption_replacements', default_replacements)
|
43 |
+
self.caption_short_replacements = config.get('caption_short_replacements', default_replacements)
|
44 |
+
self.master_dataset_dict = OrderedDict()
|
45 |
+
self.dataset_master_config_file = config.get('dataset_master_config_file', None)
|
46 |
+
if parent_dir is not None and len(self.dataset_paths) == 0:
|
47 |
+
# find all folders in the patent_dataset_path
|
48 |
+
self.dataset_paths = [
|
49 |
+
os.path.join(parent_dir, folder)
|
50 |
+
for folder in os.listdir(parent_dir)
|
51 |
+
if os.path.isdir(os.path.join(parent_dir, folder))
|
52 |
+
]
|
53 |
+
else:
|
54 |
+
# make sure they exist
|
55 |
+
for dataset_path in self.dataset_paths:
|
56 |
+
if not os.path.exists(dataset_path):
|
57 |
+
raise ValueError(f"Dataset path does not exist: {dataset_path}")
|
58 |
+
|
59 |
+
print(f"Found {len(self.dataset_paths)} dataset paths")
|
60 |
+
|
61 |
+
self.image_processor: ImageProcessor = self.get_image_processor()
|
62 |
+
|
63 |
+
def get_image_processor(self):
|
64 |
+
if self.caption_method.startswith('llava'):
|
65 |
+
return LLaVAImageProcessor(device=self.device)
|
66 |
+
elif self.caption_method.startswith('fuyu'):
|
67 |
+
return FuyuImageProcessor(device=self.device)
|
68 |
+
else:
|
69 |
+
raise ValueError(f"Unknown caption method: {self.caption_method}")
|
70 |
+
|
71 |
+
def process_image(self, img_path: str):
|
72 |
+
root_img_dir = os.path.dirname(os.path.dirname(img_path))
|
73 |
+
filename = os.path.basename(img_path)
|
74 |
+
filename_no_ext = os.path.splitext(filename)[0]
|
75 |
+
train_dir = os.path.join(root_img_dir, TRAIN_DIR)
|
76 |
+
train_img_path = os.path.join(train_dir, filename)
|
77 |
+
json_path = os.path.join(train_dir, f"{filename_no_ext}.json")
|
78 |
+
|
79 |
+
# check if json exists, if it does load it as image info
|
80 |
+
if os.path.exists(json_path):
|
81 |
+
with open(json_path, 'r') as f:
|
82 |
+
img_info = ImgInfo(**json.load(f))
|
83 |
+
else:
|
84 |
+
img_info = ImgInfo()
|
85 |
+
|
86 |
+
# always send steps first in case other processes need them
|
87 |
+
img_info.add_steps(copy.deepcopy(self.steps))
|
88 |
+
img_info.set_version(VERSION)
|
89 |
+
img_info.set_caption_method(self.caption_method)
|
90 |
+
|
91 |
+
image: Image = None
|
92 |
+
caption_image: Image = None
|
93 |
+
|
94 |
+
did_update_image = False
|
95 |
+
|
96 |
+
# trigger reprocess of steps
|
97 |
+
if self.force_reprocess_img:
|
98 |
+
img_info.trigger_image_reprocess()
|
99 |
+
|
100 |
+
# set the image as updated if it does not exist on disk
|
101 |
+
if not os.path.exists(train_img_path):
|
102 |
+
did_update_image = True
|
103 |
+
image = load_image(img_path)
|
104 |
+
if img_info.force_image_process:
|
105 |
+
did_update_image = True
|
106 |
+
image = load_image(img_path)
|
107 |
+
|
108 |
+
# go through the needed steps
|
109 |
+
for step in copy.deepcopy(img_info.state.steps_to_complete):
|
110 |
+
if step == 'caption':
|
111 |
+
# load image
|
112 |
+
if image is None:
|
113 |
+
image = load_image(img_path)
|
114 |
+
if caption_image is None:
|
115 |
+
caption_image = resize_to_max(image, 1024, 1024)
|
116 |
+
|
117 |
+
if not self.image_processor.is_loaded:
|
118 |
+
print('Loading Model. Takes a while, especially the first time')
|
119 |
+
self.image_processor.load_model()
|
120 |
+
|
121 |
+
img_info.caption = self.image_processor.generate_caption(
|
122 |
+
image=caption_image,
|
123 |
+
prompt=self.caption_prompt,
|
124 |
+
replacements=self.caption_replacements
|
125 |
+
)
|
126 |
+
img_info.mark_step_complete(step)
|
127 |
+
elif step == 'caption_short':
|
128 |
+
# load image
|
129 |
+
if image is None:
|
130 |
+
image = load_image(img_path)
|
131 |
+
|
132 |
+
if caption_image is None:
|
133 |
+
caption_image = resize_to_max(image, 1024, 1024)
|
134 |
+
|
135 |
+
if not self.image_processor.is_loaded:
|
136 |
+
print('Loading Model. Takes a while, especially the first time')
|
137 |
+
self.image_processor.load_model()
|
138 |
+
img_info.caption_short = self.image_processor.generate_caption(
|
139 |
+
image=caption_image,
|
140 |
+
prompt=self.caption_short_prompt,
|
141 |
+
replacements=self.caption_short_replacements
|
142 |
+
)
|
143 |
+
img_info.mark_step_complete(step)
|
144 |
+
elif step == 'contrast_stretch':
|
145 |
+
# load image
|
146 |
+
if image is None:
|
147 |
+
image = load_image(img_path)
|
148 |
+
image = ImageOps.autocontrast(image, cutoff=(0.1, 0), preserve_tone=True)
|
149 |
+
did_update_image = True
|
150 |
+
img_info.mark_step_complete(step)
|
151 |
+
else:
|
152 |
+
raise ValueError(f"Unknown step: {step}")
|
153 |
+
|
154 |
+
os.makedirs(os.path.dirname(train_img_path), exist_ok=True)
|
155 |
+
if did_update_image:
|
156 |
+
image.save(train_img_path)
|
157 |
+
|
158 |
+
if img_info.is_dirty:
|
159 |
+
with open(json_path, 'w') as f:
|
160 |
+
json.dump(img_info.to_dict(), f, indent=4)
|
161 |
+
|
162 |
+
if self.dataset_master_config_file:
|
163 |
+
# add to master dict
|
164 |
+
self.master_dataset_dict[train_img_path] = img_info.to_dict()
|
165 |
+
|
166 |
+
def run(self):
|
167 |
+
super().run()
|
168 |
+
imgs_to_process = []
|
169 |
+
# find all images
|
170 |
+
for dataset_path in self.dataset_paths:
|
171 |
+
raw_dir = os.path.join(dataset_path, RAW_DIR)
|
172 |
+
raw_image_paths = get_img_paths(raw_dir)
|
173 |
+
for raw_image_path in raw_image_paths:
|
174 |
+
imgs_to_process.append(raw_image_path)
|
175 |
+
|
176 |
+
if len(imgs_to_process) == 0:
|
177 |
+
print(f"No images to process")
|
178 |
+
else:
|
179 |
+
print(f"Found {len(imgs_to_process)} to process")
|
180 |
+
|
181 |
+
for img_path in tqdm(imgs_to_process, desc="Processing images"):
|
182 |
+
try:
|
183 |
+
self.process_image(img_path)
|
184 |
+
except Exception:
|
185 |
+
# print full stack trace
|
186 |
+
print(traceback.format_exc())
|
187 |
+
continue
|
188 |
+
# self.process_image(img_path)
|
189 |
+
|
190 |
+
if self.dataset_master_config_file is not None:
|
191 |
+
# save it as json
|
192 |
+
with open(self.dataset_master_config_file, 'w') as f:
|
193 |
+
json.dump(self.master_dataset_dict, f, indent=4)
|
194 |
+
|
195 |
+
del self.image_processor
|
196 |
+
flush()
|
extensions_built_in/dataset_tools/SyncFromCollection.py
ADDED
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import shutil
|
3 |
+
from collections import OrderedDict
|
4 |
+
import gc
|
5 |
+
from typing import List
|
6 |
+
|
7 |
+
import torch
|
8 |
+
from tqdm import tqdm
|
9 |
+
|
10 |
+
from .tools.dataset_tools_config_modules import DatasetSyncCollectionConfig, RAW_DIR, NEW_DIR
|
11 |
+
from .tools.sync_tools import get_unsplash_images, get_pexels_images, get_local_image_file_names, download_image, \
|
12 |
+
get_img_paths
|
13 |
+
from jobs.process import BaseExtensionProcess
|
14 |
+
|
15 |
+
|
16 |
+
def flush():
|
17 |
+
torch.cuda.empty_cache()
|
18 |
+
gc.collect()
|
19 |
+
|
20 |
+
|
21 |
+
class SyncFromCollection(BaseExtensionProcess):
|
22 |
+
|
23 |
+
def __init__(self, process_id: int, job, config: OrderedDict):
|
24 |
+
super().__init__(process_id, job, config)
|
25 |
+
|
26 |
+
self.min_width = config.get('min_width', 1024)
|
27 |
+
self.min_height = config.get('min_height', 1024)
|
28 |
+
|
29 |
+
# add our min_width and min_height to each dataset config if they don't exist
|
30 |
+
for dataset_config in config.get('dataset_sync', []):
|
31 |
+
if 'min_width' not in dataset_config:
|
32 |
+
dataset_config['min_width'] = self.min_width
|
33 |
+
if 'min_height' not in dataset_config:
|
34 |
+
dataset_config['min_height'] = self.min_height
|
35 |
+
|
36 |
+
self.dataset_configs: List[DatasetSyncCollectionConfig] = [
|
37 |
+
DatasetSyncCollectionConfig(**dataset_config)
|
38 |
+
for dataset_config in config.get('dataset_sync', [])
|
39 |
+
]
|
40 |
+
print(f"Found {len(self.dataset_configs)} dataset configs")
|
41 |
+
|
42 |
+
def move_new_images(self, root_dir: str):
|
43 |
+
raw_dir = os.path.join(root_dir, RAW_DIR)
|
44 |
+
new_dir = os.path.join(root_dir, NEW_DIR)
|
45 |
+
new_images = get_img_paths(new_dir)
|
46 |
+
|
47 |
+
for img_path in new_images:
|
48 |
+
# move to raw
|
49 |
+
new_path = os.path.join(raw_dir, os.path.basename(img_path))
|
50 |
+
shutil.move(img_path, new_path)
|
51 |
+
|
52 |
+
# remove new dir
|
53 |
+
shutil.rmtree(new_dir)
|
54 |
+
|
55 |
+
def sync_dataset(self, config: DatasetSyncCollectionConfig):
|
56 |
+
if config.host == 'unsplash':
|
57 |
+
get_images = get_unsplash_images
|
58 |
+
elif config.host == 'pexels':
|
59 |
+
get_images = get_pexels_images
|
60 |
+
else:
|
61 |
+
raise ValueError(f"Unknown host: {config.host}")
|
62 |
+
|
63 |
+
results = {
|
64 |
+
'num_downloaded': 0,
|
65 |
+
'num_skipped': 0,
|
66 |
+
'bad': 0,
|
67 |
+
'total': 0,
|
68 |
+
}
|
69 |
+
|
70 |
+
photos = get_images(config)
|
71 |
+
raw_dir = os.path.join(config.directory, RAW_DIR)
|
72 |
+
new_dir = os.path.join(config.directory, NEW_DIR)
|
73 |
+
raw_images = get_local_image_file_names(raw_dir)
|
74 |
+
new_images = get_local_image_file_names(new_dir)
|
75 |
+
|
76 |
+
for photo in tqdm(photos, desc=f"{config.host}-{config.collection_id}"):
|
77 |
+
try:
|
78 |
+
if photo.filename not in raw_images and photo.filename not in new_images:
|
79 |
+
download_image(photo, new_dir, min_width=self.min_width, min_height=self.min_height)
|
80 |
+
results['num_downloaded'] += 1
|
81 |
+
else:
|
82 |
+
results['num_skipped'] += 1
|
83 |
+
except Exception as e:
|
84 |
+
print(f" - BAD({photo.id}): {e}")
|
85 |
+
results['bad'] += 1
|
86 |
+
continue
|
87 |
+
results['total'] += 1
|
88 |
+
|
89 |
+
return results
|
90 |
+
|
91 |
+
def print_results(self, results):
|
92 |
+
print(
|
93 |
+
f" - new:{results['num_downloaded']}, old:{results['num_skipped']}, bad:{results['bad']} total:{results['total']}")
|
94 |
+
|
95 |
+
def run(self):
|
96 |
+
super().run()
|
97 |
+
print(f"Syncing {len(self.dataset_configs)} datasets")
|
98 |
+
all_results = None
|
99 |
+
failed_datasets = []
|
100 |
+
for dataset_config in tqdm(self.dataset_configs, desc="Syncing datasets", leave=True):
|
101 |
+
try:
|
102 |
+
results = self.sync_dataset(dataset_config)
|
103 |
+
if all_results is None:
|
104 |
+
all_results = {**results}
|
105 |
+
else:
|
106 |
+
for key, value in results.items():
|
107 |
+
all_results[key] += value
|
108 |
+
|
109 |
+
self.print_results(results)
|
110 |
+
except Exception as e:
|
111 |
+
print(f" - FAILED: {e}")
|
112 |
+
if 'response' in e.__dict__:
|
113 |
+
error = f"{e.response.status_code}: {e.response.text}"
|
114 |
+
print(f" - {error}")
|
115 |
+
failed_datasets.append({'dataset': dataset_config, 'error': error})
|
116 |
+
else:
|
117 |
+
failed_datasets.append({'dataset': dataset_config, 'error': str(e)})
|
118 |
+
continue
|
119 |
+
|
120 |
+
print("Moving new images to raw")
|
121 |
+
for dataset_config in self.dataset_configs:
|
122 |
+
self.move_new_images(dataset_config.directory)
|
123 |
+
|
124 |
+
print("Done syncing datasets")
|
125 |
+
self.print_results(all_results)
|
126 |
+
|
127 |
+
if len(failed_datasets) > 0:
|
128 |
+
print(f"Failed to sync {len(failed_datasets)} datasets")
|
129 |
+
for failed in failed_datasets:
|
130 |
+
print(f" - {failed['dataset'].host}-{failed['dataset'].collection_id}")
|
131 |
+
print(f" - ERR: {failed['error']}")
|
extensions_built_in/dataset_tools/__init__.py
ADDED
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from toolkit.extension import Extension
|
2 |
+
|
3 |
+
|
4 |
+
class DatasetToolsExtension(Extension):
|
5 |
+
uid = "dataset_tools"
|
6 |
+
|
7 |
+
# name is the name of the extension for printing
|
8 |
+
name = "Dataset Tools"
|
9 |
+
|
10 |
+
# This is where your process class is loaded
|
11 |
+
# keep your imports in here so they don't slow down the rest of the program
|
12 |
+
@classmethod
|
13 |
+
def get_process(cls):
|
14 |
+
# import your process class here so it is only loaded when needed and return it
|
15 |
+
from .DatasetTools import DatasetTools
|
16 |
+
return DatasetTools
|
17 |
+
|
18 |
+
|
19 |
+
class SyncFromCollectionExtension(Extension):
|
20 |
+
uid = "sync_from_collection"
|
21 |
+
name = "Sync from Collection"
|
22 |
+
|
23 |
+
@classmethod
|
24 |
+
def get_process(cls):
|
25 |
+
# import your process class here so it is only loaded when needed and return it
|
26 |
+
from .SyncFromCollection import SyncFromCollection
|
27 |
+
return SyncFromCollection
|
28 |
+
|
29 |
+
|
30 |
+
class SuperTaggerExtension(Extension):
|
31 |
+
uid = "super_tagger"
|
32 |
+
name = "Super Tagger"
|
33 |
+
|
34 |
+
@classmethod
|
35 |
+
def get_process(cls):
|
36 |
+
# import your process class here so it is only loaded when needed and return it
|
37 |
+
from .SuperTagger import SuperTagger
|
38 |
+
return SuperTagger
|
39 |
+
|
40 |
+
|
41 |
+
AI_TOOLKIT_EXTENSIONS = [
|
42 |
+
SyncFromCollectionExtension, DatasetToolsExtension, SuperTaggerExtension
|
43 |
+
]
|
extensions_built_in/dataset_tools/tools/caption.py
ADDED
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
caption_manipulation_steps = ['caption', 'caption_short']
|
3 |
+
|
4 |
+
default_long_prompt = 'caption this image. describe every single thing in the image in detail. Do not include any unnecessary words in your description for the sake of good grammar. I want many short statements that serve the single purpose of giving the most thorough description if items as possible in the smallest, comma separated way possible. be sure to describe people\'s moods, clothing, the environment, lighting, colors, and everything.'
|
5 |
+
default_short_prompt = 'caption this image in less than ten words'
|
6 |
+
|
7 |
+
default_replacements = [
|
8 |
+
("the image features", ""),
|
9 |
+
("the image shows", ""),
|
10 |
+
("the image depicts", ""),
|
11 |
+
("the image is", ""),
|
12 |
+
("in this image", ""),
|
13 |
+
("in the image", ""),
|
14 |
+
]
|
15 |
+
|
16 |
+
|
17 |
+
def clean_caption(cap, replacements=None):
|
18 |
+
if replacements is None:
|
19 |
+
replacements = default_replacements
|
20 |
+
|
21 |
+
# remove any newlines
|
22 |
+
cap = cap.replace("\n", ", ")
|
23 |
+
cap = cap.replace("\r", ", ")
|
24 |
+
cap = cap.replace(".", ",")
|
25 |
+
cap = cap.replace("\"", "")
|
26 |
+
|
27 |
+
# remove unicode characters
|
28 |
+
cap = cap.encode('ascii', 'ignore').decode('ascii')
|
29 |
+
|
30 |
+
# make lowercase
|
31 |
+
cap = cap.lower()
|
32 |
+
# remove any extra spaces
|
33 |
+
cap = " ".join(cap.split())
|
34 |
+
|
35 |
+
for replacement in replacements:
|
36 |
+
if replacement[0].startswith('*'):
|
37 |
+
# we are removing all text if it starts with this and the rest matches
|
38 |
+
search_text = replacement[0][1:]
|
39 |
+
if cap.startswith(search_text):
|
40 |
+
cap = ""
|
41 |
+
else:
|
42 |
+
cap = cap.replace(replacement[0].lower(), replacement[1].lower())
|
43 |
+
|
44 |
+
cap_list = cap.split(",")
|
45 |
+
# trim whitespace
|
46 |
+
cap_list = [c.strip() for c in cap_list]
|
47 |
+
# remove empty strings
|
48 |
+
cap_list = [c for c in cap_list if c != ""]
|
49 |
+
# remove duplicates
|
50 |
+
cap_list = list(dict.fromkeys(cap_list))
|
51 |
+
# join back together
|
52 |
+
cap = ", ".join(cap_list)
|
53 |
+
return cap
|