jbilcke-hf HF Staff committed on
Commit 3cc1e25 · verified · 1 Parent(s): 3180dc0

Upload 430 files

This view is limited to 50 files because it contains too many changes. See raw diff.
Files changed (50)
  1. .gitattributes +3 -0
  2. FAQ.md +10 -0
  3. LICENSE +21 -0
  4. README.md +473 -8
  5. assets/VAE_test1.jpg +3 -0
  6. assets/glif.svg +40 -0
  7. assets/lora_ease_ui.png +3 -0
  8. build_and_push_docker +29 -0
  9. build_and_push_docker_dev +21 -0
  10. config/examples/extract.example.yml +75 -0
  11. config/examples/generate.example.yaml +60 -0
  12. config/examples/mod_lora_scale.yaml +48 -0
  13. config/examples/modal/modal_train_lora_flux_24gb.yaml +96 -0
  14. config/examples/modal/modal_train_lora_flux_schnell_24gb.yaml +98 -0
  15. config/examples/train_flex_redux.yaml +112 -0
  16. config/examples/train_full_fine_tune_flex.yaml +107 -0
  17. config/examples/train_full_fine_tune_lumina.yaml +99 -0
  18. config/examples/train_lora_chroma_24gb.yaml +104 -0
  19. config/examples/train_lora_flex2_24gb.yaml +165 -0
  20. config/examples/train_lora_flex_24gb.yaml +101 -0
  21. config/examples/train_lora_flux_24gb.yaml +96 -0
  22. config/examples/train_lora_flux_kontext_24gb.yaml +106 -0
  23. config/examples/train_lora_flux_schnell_24gb.yaml +98 -0
  24. config/examples/train_lora_hidream_48.yaml +112 -0
  25. config/examples/train_lora_lumina.yaml +96 -0
  26. config/examples/train_lora_omnigen2_24gb.yaml +94 -0
  27. config/examples/train_lora_sd35_large_24gb.yaml +97 -0
  28. config/examples/train_lora_wan21_14b_24gb.yaml +101 -0
  29. config/examples/train_lora_wan21_1b_24gb.yaml +90 -0
  30. config/examples/train_slider.example.yml +230 -0
  31. docker-compose.yml +25 -0
  32. docker/Dockerfile +83 -0
  33. docker/start.sh +70 -0
  34. extensions/example/ExampleMergeModels.py +129 -0
  35. extensions/example/__init__.py +25 -0
  36. extensions/example/config/config.example.yaml +48 -0
  37. extensions_built_in/.DS_Store +0 -0
  38. extensions_built_in/advanced_generator/Img2ImgGenerator.py +256 -0
  39. extensions_built_in/advanced_generator/PureLoraGenerator.py +102 -0
  40. extensions_built_in/advanced_generator/ReferenceGenerator.py +212 -0
  41. extensions_built_in/advanced_generator/__init__.py +59 -0
  42. extensions_built_in/advanced_generator/config/train.example.yaml +91 -0
  43. extensions_built_in/concept_replacer/ConceptReplacer.py +151 -0
  44. extensions_built_in/concept_replacer/__init__.py +26 -0
  45. extensions_built_in/concept_replacer/config/train.example.yaml +91 -0
  46. extensions_built_in/dataset_tools/DatasetTools.py +20 -0
  47. extensions_built_in/dataset_tools/SuperTagger.py +196 -0
  48. extensions_built_in/dataset_tools/SyncFromCollection.py +131 -0
  49. extensions_built_in/dataset_tools/__init__.py +43 -0
  50. extensions_built_in/dataset_tools/tools/caption.py +53 -0
.gitattributes CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ assets/lora_ease_ui.png filter=lfs diff=lfs merge=lfs -text
+ assets/VAE_test1.jpg filter=lfs diff=lfs merge=lfs -text
+ toolkit/timestep_weighing/flex_timestep_weights_plot.png filter=lfs diff=lfs merge=lfs -text
FAQ.md ADDED
@@ -0,0 +1,10 @@
# FAQ

WIP. Will continue to add things as they are needed.

## FLUX.1 Training

#### How much VRAM is required to train a lora on FLUX.1?

24GB minimum is required.
LICENSE ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2024 Ostris, LLC

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
README.md CHANGED
@@ -1,11 +1,476 @@
---
- title: Ai Toolkit
- emoji: 🚀
- colorFrom: red
- colorTo: gray
- sdk: docker
- pinned: false
- short_description: Ostris AI Toolkit running as a HF space
---

- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference

# AI Toolkit by Ostris

AI Toolkit is an all-in-one training suite for diffusion models. I try to support all the latest image and video models on consumer-grade hardware. It can be run as a GUI or CLI, and it is designed to be easy to use while still having every feature imaginable.

## Support My Work

If you enjoy my projects or use them commercially, please consider sponsoring me. Every bit helps! 💖

[Sponsor on GitHub](https://github.com/orgs/ostris) | [Support on Patreon](https://www.patreon.com/ostris) | [Donate on PayPal](https://www.paypal.com/donate/?hosted_button_id=9GEFUKC8T9R9W)

### Current Sponsors

All of these people / organizations are the ones who selflessly make this project possible. Thank you!!

_Last updated: 2025-08-08 17:01 UTC_

+ <p align="center">
18
+ <a href="https://x.com/NuxZoe" target="_blank" rel="noopener noreferrer"><img src="https://pbs.twimg.com/profile_images/1919488160125616128/QAZXTMEj_400x400.png" alt="a16z" width="200" height="200" style="border-radius:8px;margin:5px;display: inline-block;"></a>
19
+ <a href="https://github.com/replicate" target="_blank" rel="noopener noreferrer"><img src="https://avatars.githubusercontent.com/u/60410876?v=4" alt="Replicate" width="200" height="200" style="border-radius:8px;margin:5px;display: inline-block;"></a>
20
+ <a href="https://github.com/huggingface" target="_blank" rel="noopener noreferrer"><img src="https://avatars.githubusercontent.com/u/25720743?v=4" alt="Hugging Face" width="200" height="200" style="border-radius:8px;margin:5px;display: inline-block;"></a>
21
+ <a href="https://github.com/josephrocca" target="_blank" rel="noopener noreferrer"><img src="https://avatars.githubusercontent.com/u/1167575?u=92d92921b4cb5c8c7e225663fed53c4b41897736&v=4" alt="josephrocca" width="200" height="200" style="border-radius:8px;margin:5px;display: inline-block;"></a>
22
+ <img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/162524101/81a72689c3754ac5b9e38612ce5ce914/eyJ3IjoyMDB9/1.png?token-hash=JHRjAxd2XxV1aXIUijj-l65pfTnLoefYSvgNPAsw2lI%3D" alt="Prasanth Veerina" width="200" height="200" style="border-radius:8px;margin:5px;display: inline-block;">
23
+ <a href="https://github.com/weights-ai" target="_blank" rel="noopener noreferrer"><img src="https://avatars.githubusercontent.com/u/185568492?v=4" alt="Weights" width="200" height="200" style="border-radius:8px;margin:5px;display: inline-block;"></a>
24
+ </p>
25
+ <hr style="width:100%;border:none;height:2px;background:#ddd;margin:30px 0;">
26
+ <p align="center">
27
+ <img src="https://c8.patreon.com/4/200/93304/J" alt="Joseph Rocca" width="150" height="150" style="border-radius:8px;margin:5px;display: inline-block;">
28
+ <img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/161471720/dd330b4036d44a5985ed5985c12a5def/eyJ3IjoyMDB9/1.jpeg?token-hash=k1f4Vv7TevzYa9tqlzAjsogYmkZs8nrXQohPCDGJGkc%3D" alt="Vladimir Sotnikov" width="150" height="150" style="border-radius:8px;margin:5px;display: inline-block;">
29
+ <img src="https://c8.patreon.com/4/200/33158543/C" alt="clement Delangue" width="150" height="150" style="border-radius:8px;margin:5px;display: inline-block;">
30
+ <img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/8654302/b0f5ebedc62a47c4b56222693e1254e9/eyJ3IjoyMDB9/2.jpeg?token-hash=suI7_QjKUgWpdPuJPaIkElkTrXfItHlL8ZHLPT-w_d4%3D" alt="Misch Strotz" width="150" height="150" style="border-radius:8px;margin:5px;display: inline-block;">
31
+ <img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/120239481/49b1ce70d3d24704b8ec34de24ec8f55/eyJ3IjoyMDB9/1.jpeg?token-hash=o0y1JqSXqtGvVXnxb06HMXjQXs6OII9yMMx5WyyUqT4%3D" alt="nitish PNR" width="150" height="150" style="border-radius:8px;margin:5px;display: inline-block;">
32
+ </p>
33
+ <hr style="width:100%;border:none;height:2px;background:#ddd;margin:30px 0;">
34
+ <p align="center">
35
+ <img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/2298192/1228b69bd7d7481baf3103315183250d/eyJ3IjoyMDB9/1.jpg?token-hash=opN1e4r4Nnvqbtr8R9HI8eyf9m5F50CiHDOdHzb4UcA%3D" alt="Mohamed Oumoumad" width="100" height="100" style="border-radius:8px;margin:5px;display: inline-block;">
36
+ <img src="https://c8.patreon.com/4/200/548524/S" alt="Steve Hanff" width="100" height="100" style="border-radius:8px;margin:5px;display: inline-block;">
37
+ <img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/152118848/3b15a43d71714552b5ed1c9f84e66adf/eyJ3IjoyMDB9/1.png?token-hash=MKf3sWHz0MFPm_OAFjdsNvxoBfN5B5l54mn1ORdlRy8%3D" alt="Kristjan Retter" width="100" height="100" style="border-radius:8px;margin:5px;display: inline-block;">
38
+ <img src="https://c8.patreon.com/4/200/83319230/M" alt="Miguel Lara" width="100" height="100" style="border-radius:8px;margin:5px;display: inline-block;">
39
+ <img src="https://c8.patreon.com/4/200/8449560/P" alt="Patron" width="100" height="100" style="border-radius:8px;margin:5px;display: inline-block;">
40
+ <a href="https://x.com/NuxZoe" target="_blank" rel="noopener noreferrer"><img src="https://pbs.twimg.com/profile_images/1916482710069014528/RDLnPRSg_400x400.jpg" alt="tungsten" width="100" height="100" style="border-radius:8px;margin:5px;display: inline-block;"></a>
41
+ <img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/169502989/220069e79ce745b29237e94c22a729df/eyJ3IjoyMDB9/1.png?token-hash=E8E2JOqx66k2zMtYUw8Gy57dw-gVqA6OPpdCmWFFSFw%3D" alt="Timothy Bielec" width="100" height="100" style="border-radius:8px;margin:5px;display: inline-block;">
42
+ <img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/34200989/58ae95ebda0640c8b7a91b4fa31357aa/eyJ3IjoyMDB9/1.jpeg?token-hash=4mVDM1kCYGauYa33zLG14_g0oj9_UjDK_-Qp4zk42GE%3D" alt="Noah Miller" width="100" height="100" style="border-radius:8px;margin:5px;display: inline-block;">
43
+ <img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/27288932/6c35d2d961ee4e14a7a368c990791315/eyJ3IjoyMDB9/1.jpeg?token-hash=TGIto_PGEG2NEKNyqwzEnRStOkhrjb3QlMhHA3raKJY%3D" alt="David Garrido" width="100" height="100" style="border-radius:8px;margin:5px;display: inline-block;">
44
+ <a href="https://x.com/RalFingerLP" target="_blank" rel="noopener noreferrer"><img src="https://pbs.twimg.com/profile_images/919595465041162241/ZU7X3T5k_400x400.jpg" alt="RalFinger" width="100" height="100" style="border-radius:8px;margin:5px;display: inline-block;"></a>
45
+ </p>
46
+ <hr style="width:100%;border:none;height:2px;background:#ddd;margin:30px 0;">
47
+ <p align="center">
48
+ <a href="http://www.ir-ltd.net" target="_blank" rel="noopener noreferrer"><img src="https://pbs.twimg.com/profile_images/1602579392198283264/6Tm2GYus_400x400.jpg" alt="IR-Entertainment Ltd" width="60" height="60" style="border-radius:8px;margin:5px;display: inline-block;"></a>
49
+ <img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/9547341/bb35d9a222fd460e862e960ba3eacbaf/eyJ3IjoyMDB9/1.jpeg?token-hash=Q2XGDvkCbiONeWNxBCTeTMOcuwTjOaJ8Z-CAf5xq3Hs%3D" alt="Travis Harrington" width="60" height="60" style="border-radius:8px;margin:5px;display: inline-block;">
50
+ <img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/98811435/3a3632d1795b4c2b9f8f0270f2f6a650/eyJ3IjoyMDB9/1.jpeg?token-hash=657rzuJ0bZavMRZW3XZ-xQGqm3Vk6FkMZgFJVMCOPdk%3D" alt="EmmanuelMr18" width="60" height="60" style="border-radius:8px;margin:5px;display: inline-block;">
51
+ <img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/81275465/1e4148fe9c47452b838949d02dd9a70f/eyJ3IjoyMDB9/1.jpeg?token-hash=YAX1ucxybpCIujUCXfdwzUQkttIn3c7pfi59uaFPSwM%3D" alt="Aaron Amortegui" width="60" height="60" style="border-radius:8px;margin:5px;display: inline-block;">
52
+ <img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/155963250/6f8fd7075c3b4247bfeb054ba49172d6/eyJ3IjoyMDB9/1.png?token-hash=z81EHmdU2cqSrwa9vJmZTV3h0LG-z9Qakhxq34FrYT4%3D" alt="Un Defined" width="60" height="60" style="border-radius:8px;margin:5px;display: inline-block;">
53
+ <img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/45562978/0de33cf52ec642ae8a2f612cddec4ca6/eyJ3IjoyMDB9/1.jpeg?token-hash=aD4debMD5ZQjqTII6s4zYSgVK2-bdQt9p3eipi0bENs%3D" alt="Jack English" width="60" height="60" style="border-radius:8px;margin:5px;display: inline-block;">
54
+ <img src="https://c8.patreon.com/4/200/27791680/J" alt="Jean-Tristan Marin" width="60" height="60" style="border-radius:8px;margin:5px;display: inline-block;">
55
+ <img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/570742/4ceb33453a5a4745b430a216aba9280f/eyJ3IjoyMDB9/1.jpg?token-hash=nPcJ2zj3sloND9jvbnbYnob2vMXRnXdRuujthqDLWlU%3D" alt="Al H" width="60" height="60" style="border-radius:8px;margin:5px;display: inline-block;">
56
+ <img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/82763/f99cc484361d4b9d94fe4f0814ada303/eyJ3IjoyMDB9/1.jpeg?token-hash=A3JWlBNL0b24FFWb-FCRDAyhs-OAxg-zrhfBXP_axuU%3D" alt="Doron Adler" width="60" height="60" style="border-radius:8px;margin:5px;display: inline-block;">
57
+ <img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/103077711/bb215761cc004e80bd9cec7d4bcd636d/eyJ3IjoyMDB9/2.jpeg?token-hash=3U8kdZSUpnmeYIDVK4zK9TTXFpnAud_zOwBRXx18018%3D" alt="John Dopamine" width="60" height="60" style="border-radius:8px;margin:5px;display: inline-block;">
58
+ <img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/99036356/7ae9c4d80e604e739b68cca12ee2ed01/eyJ3IjoyMDB9/3.png?token-hash=ZhsBMoTOZjJ-Y6h5NOmU5MT-vDb2fjK46JDlpEehkVQ%3D" alt="Noctre" width="60" height="60" style="border-radius:8px;margin:5px;display: inline-block;">
59
+ <img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/141098579/1a9f0a1249d447a7a0df718a57343912/eyJ3IjoyMDB9/2.png?token-hash=_n-AQmPgY0FP9zCGTIEsr5ka4Y7YuaMkt3qL26ZqGg8%3D" alt="The Local Lab" width="60" height="60" style="border-radius:8px;margin:5px;display: inline-block;">
60
+ <img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/93348210/5c650f32a0bc481d80900d2674528777/eyJ3IjoyMDB9/1.jpeg?token-hash=0jiknRw3jXqYWW6En8bNfuHgVDj4LI_rL7lSS4-_xlo%3D" alt="Armin Behjati" width="60" height="60" style="border-radius:8px;margin:5px;display: inline-block;">
61
+ <img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/134129880/680c7e14cd1a4d1a9face921fb010f88/eyJ3IjoyMDB9/1.png?token-hash=5fqqHE6DCTbt7gDQL7VRcWkV71jF7FvWcLhpYl5aMXA%3D" alt="Bharat Prabhakar" width="60" height="60" style="border-radius:8px;margin:5px;display: inline-block;">
62
+ <img src="https://c8.patreon.com/4/200/70218846/C" alt="Cosmosis" width="60" height="60" style="border-radius:8px;margin:5px;display: inline-block;">
63
+ <img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/30931983/54ab4e4ceab946e79a6418d205f9ed51/eyJ3IjoyMDB9/1.png?token-hash=j2phDrgd6IWuqKqNIDbq9fR2B3fMF-GUCQSdETS1w5Y%3D" alt="HestoySeghuro ." width="60" height="60" style="border-radius:8px;margin:5px;display: inline-block;">
64
+ <img src="https://c8.patreon.com/4/200/4105384/J" alt="Jack Blakely" width="60" height="60" style="border-radius:8px;margin:5px;display: inline-block;">
65
+ <img src="https://c8.patreon.com/4/200/4541423/S" alt="Sören " width="60" height="60" style="border-radius:8px;margin:5px;display: inline-block;">
66
+ <a href="https://www.youtube.com/@happyme7055" target="_blank" rel="noopener noreferrer"><img src="https://yt3.googleusercontent.com/ytc/AIdro_mFqhIRk99SoEWY2gvSvVp6u1SkCGMkRqYQ1OlBBeoOVp8=s160-c-k-c0x00ffffff-no-rj" alt="Marcus Rass" width="60" height="60" style="border-radius:8px;margin:5px;display: inline-block;"></a>
67
+ <img src="https://c8.patreon.com/4/200/53077895/M" alt="Marc" width="60" height="60" style="border-radius:8px;margin:5px;display: inline-block;">
68
+ <img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/157407541/bb9d80cffdab4334ad78366060561520/eyJ3IjoyMDB9/2.png?token-hash=WYz-U_9zabhHstOT5UIa5jBaoFwrwwqyWxWEzIR2m_c%3D" alt="Tokio Studio srl IT10640050968" width="60" height="60" style="border-radius:8px;margin:5px;display: inline-block;">
69
+ <img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/44568304/a9d83a0e786b41b4bdada150f7c9271c/eyJ3IjoyMDB9/1.jpeg?token-hash=FtxnwrSrknQUQKvDRv2rqPceX2EF23eLq4pNQYM_fmw%3D" alt="Albert Bukoski" width="60" height="60" style="border-radius:8px;margin:5px;display: inline-block;">
70
+ <img src="https://c8.patreon.com/4/200/5048649/B" alt="Ben Ward" width="60" height="60" style="border-radius:8px;margin:5px;display: inline-block;">
71
+ <img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/111904990/08b1cf65be6a4de091c9b73b693b3468/eyJ3IjoyMDB9/1.png?token-hash=_Odz6RD3CxtubEHbUxYujcjw6zAajbo3w8TRz249VBA%3D" alt="Brian Smith" width="60" height="60" style="border-radius:8px;margin:5px;display: inline-block;">
72
+ <img src="https://c8.patreon.com/4/200/494309/J" alt="Julian Tsependa" width="60" height="60" style="border-radius:8px;margin:5px;display: inline-block;">
73
+ <img src="https://c8.patreon.com/4/200/5602036/K" alt="Kelevra" width="60" height="60" style="border-radius:8px;margin:5px;display: inline-block;">
74
+ <img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/159203973/36c817f941ac4fa18103a4b8c0cb9cae/eyJ3IjoyMDB9/1.png?token-hash=zkt72HW3EoiIEAn3LSk9gJPBsXfuTVcc4rRBS3CeR8w%3D" alt="Marko jak" width="60" height="60" style="border-radius:8px;margin:5px;display: inline-block;">
75
+ <img src="https://c8.patreon.com/4/200/24653779/R" alt="RayHell" width="60" height="60" style="border-radius:8px;margin:5px;display: inline-block;">
76
+ <img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/76566911/6485eaf5ec6249a7b524ee0b979372f0/eyJ3IjoyMDB9/1.jpeg?token-hash=mwCSkTelDBaengG32NkN0lVl5mRjB-cwo6-a47wnOsU%3D" alt="the biitz" width="60" height="60" style="border-radius:8px;margin:5px;display: inline-block;">
77
+ <img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/32633822/1ab5612efe80417cbebfe91e871fc052/eyJ3IjoyMDB9/1.png?token-hash=pOS_IU3b3RL5-iL96A3Xqoj2bQ-dDo4RUkBylcMED_s%3D" alt="Zack Abrams" width="60" height="60" style="border-radius:8px;margin:5px;display: inline-block;">
78
+ <img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/97985240/3d1d0e6905d045aba713e8132cab4a30/eyJ3IjoyMDB9/1.png?token-hash=fRavvbO_yqWKA_OsJb5DzjfKZ1Yt-TG-ihMoeVBvlcM%3D" alt="עומר מכלוף" width="60" height="60" style="border-radius:8px;margin:5px;display: inline-block;">
79
+ <a href="https://github.com/julien-blanchon" target="_blank" rel="noopener noreferrer"><img src="https://avatars.githubusercontent.com/u/11278197?v=4" alt="Blanchon" width="60" height="60" style="border-radius:8px;margin:5px;display: inline-block;"></a>
80
+ <img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/11198131/e696d9647feb4318bcf16243c2425805/eyJ3IjoyMDB9/1.jpeg?token-hash=c2c2p1SaiX86iXAigvGRvzm4jDHvIFCg298A49nIfUM%3D" alt="Nicholas Agranoff" width="60" height="60" style="border-radius:8px;margin:5px;display: inline-block;">
81
+ <img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/785333/bdb9ede5765d42e5a2021a86eebf0d8f/eyJ3IjoyMDB9/2.jpg?token-hash=l_rajMhxTm6wFFPn7YdoKBxeUqhdRXKdy6_8SGCuNsE%3D" alt="Sapjes " width="60" height="60" style="border-radius:8px;margin:5px;display: inline-block;">
82
+ <img src="https://c8.patreon.com/4/200/2446176/S" alt="Scott VanKirk" width="60" height="60" style="border-radius:8px;margin:5px;display: inline-block;">
83
+ <img src="https://c8.patreon.com/4/200/83034/W" alt="william tatum" width="60" height="60" style="border-radius:8px;margin:5px;display: inline-block;">
84
+ <img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/138787189/2b5662dcb638466282ac758e3ac651b4/eyJ3IjoyMDB9/1.png?token-hash=zwj7MScO18vhDxhKt6s5q4gdeNJM3xCLuhSt8zlqlZs%3D" alt="Антон Антонио" width="60" height="60" style="border-radius:8px;margin:5px;display: inline-block;">
85
+ <img src="https://c8.patreon.com/4/200/30530914/T" alt="Techer " width="60" height="60" style="border-radius:8px;margin:5px;display: inline-block;">
86
+ <img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/25209707/36ae876d662d4d85aaf162b6d67d31e7/eyJ3IjoyMDB9/1.png?token-hash=Zows_A6uqlY5jClhfr4Y3QfMnDKVkS3mbxNHUDkVejo%3D" alt="fjioq8" width="60" height="60" style="border-radius:8px;margin:5px;display: inline-block;">
87
+ <img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/46680573/ee3d99c04a674dd5a8e1ecfb926db6a2/eyJ3IjoyMDB9/1.jpeg?token-hash=cgD4EXyfZMPnXIrcqWQ5jGqzRUfqjPafb9yWfZUPB4Q%3D" alt="Neil Murray" width="60" height="60" style="border-radius:8px;margin:5px;display: inline-block;">
88
+ <img src="https://ostris.com/wp-content/uploads/2025/08/supporter_default.jpg" alt="Joakim Sällström" width="60" height="60" style="border-radius:8px;margin:5px;display: inline-block;">
89
+ <img src="https://c8.patreon.com/4/200/63510241/A" alt="Andrew Park" width="60" height="60" style="border-radius:8px;margin:5px;display: inline-block;">
90
+ <a href="https://github.com/Spikhalskiy" target="_blank" rel="noopener noreferrer"><img src="https://avatars.githubusercontent.com/u/532108?u=2464983638afea8caf4cd9f0e4a7bc3e6a63bb0a&v=4" alt="Dmitry Spikhalsky" width="60" height="60" style="border-radius:8px;margin:5px;display: inline-block;"></a>
91
+ <img src="https://c8.patreon.com/4/200/88567307/E" alt="el Chavo" width="60" height="60" style="border-radius:8px;margin:5px;display: inline-block;">
92
+ <img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/117569999/55f75c57f95343e58402529cec852b26/eyJ3IjoyMDB9/1.jpeg?token-hash=squblHZH4-eMs3gI46Uqu1oTOK9sQ-0gcsFdZcB9xQg%3D" alt="James Thompson" width="60" height="60" style="border-radius:8px;margin:5px;display: inline-block;">
93
+ <img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/66157709/6fe70df085e24464995a1a9293a53760/eyJ3IjoyMDB9/1.jpeg?token-hash=eqe0wvg6JfbRUGMKpL_x3YPI5Ppf18aUUJe2EzADU-g%3D" alt="Joey Santana" width="60" height="60" style="border-radius:8px;margin:5px;display: inline-block;">
94
+ <img src="https://ostris.com/wp-content/uploads/2025/08/supporter_default.jpg" alt="Heikki Rinkinen" width="60" height="60" style="border-radius:8px;margin:5px;display: inline-block;">
95
+ <img src="https://c8.patreon.com/4/200/6175608/B" alt="Bobbie " width="60" height="60" style="border-radius:8px;margin:5px;display: inline-block;">
96
+ <a href="https://github.com/Slartibart23" target="_blank" rel="noopener noreferrer"><img src="https://avatars.githubusercontent.com/u/133593860?u=31217adb2522fb295805824ffa7e14e8f0fca6fa&v=4" alt="Slarti" width="60" height="60" style="border-radius:8px;margin:5px;display: inline-block;"></a>
97
+ <img src="https://ostris.com/wp-content/uploads/2025/08/supporter_default.jpg" alt="Tommy Falkowski" width="60" height="60" style="border-radius:8px;margin:5px;display: inline-block;">
98
+ <img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/28533016/e8f6044ccfa7483f87eeaa01c894a773/eyJ3IjoyMDB9/2.png?token-hash=ak-h3JWB50hyenCavcs32AAPw6nNhmH2nBFKpdk5hvM%3D" alt="William Tatum" width="60" height="60" style="border-radius:8px;margin:5px;display: inline-block;">
99
+ <img src="https://ostris.com/wp-content/uploads/2025/08/supporter_default.jpg" alt="Karol Stępień" width="60" height="60" style="border-radius:8px;margin:5px;display: inline-block;">
100
+ <img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/156564939/17dbfd45c59d4cf29853d710cb0c5d6f/eyJ3IjoyMDB9/1.png?token-hash=e6wXA_S8cgJeEDI9eJK934eB0TiM8mxJm9zW_VH0gDU%3D" alt="Hans Untch" width="60" height="60" style="border-radius:8px;margin:5px;display: inline-block;">
101
+ <img src="https://c8.patreon.com/4/200/59408413/B" alt="ByteC" width="60" height="60" style="border-radius:8px;margin:5px;display: inline-block;">
102
+ <img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/3712451/432e22a355494ec0a1ea1927ff8d452e/eyJ3IjoyMDB9/7.jpeg?token-hash=OpQ9SAfVQ4Un9dSYlGTHuApZo5GlJ797Mo0DtVtMOSc%3D" alt="David Shorey" width="60" height="60" style="border-radius:8px;margin:5px;display: inline-block;">
103
+ <img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/53634141/c1441f6c605344bbaef885d4272977bb/eyJ3IjoyMDB9/1.JPG?token-hash=Aizd6AxQhY3n6TBE5AwCVeSwEBbjALxQmu6xqc08qBo%3D" alt="Jana Spacelight" width="60" height="60" style="border-radius:8px;margin:5px;display: inline-block;">
104
+ <img src="https://c8.patreon.com/4/200/11180426/J" alt="jarrett towe" width="60" height="60" style="border-radius:8px;margin:5px;display: inline-block;">
105
+ <img src="https://c8.patreon.com/4/200/21828017/J" alt="Jim" width="60" height="60" style="border-radius:8px;margin:5px;display: inline-block;">
106
+ <img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/63232055/2300b4ab370341b5b476902c9b8218ee/eyJ3IjoyMDB9/1.png?token-hash=R9Nb4O0aLBRwxT1cGHUMThlvf6A2MD5SO88lpZBdH7M%3D" alt="Marek P" width="60" height="60" style="border-radius:8px;margin:5px;display: inline-block;">
107
+ <img src="https://c8.patreon.com/4/200/9944625/P" alt="Pomoe " width="60" height="60" style="border-radius:8px;margin:5px;display: inline-block;">
108
+ <img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/25047900/423e4cb73aba457f8f9c6e5582eddaeb/eyJ3IjoyMDB9/1.jpeg?token-hash=81RvQXBbT66usxqtyWum9Ul4oBn3qHK1cM71IvthC-U%3D" alt="Ruairi Robinson" width="60" height="60" style="border-radius:8px;margin:5px;display: inline-block;">
109
+ <img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/178476551/0b9e83efcd234df5a6bea30d59e6c1cd/eyJ3IjoyMDB9/1.png?token-hash=3XoYMrMxk-K6GelM22mE-FwkjFulX9hpIL7QI3wO2jI%3D" alt="Timmy" width="60" height="60" style="border-radius:8px;margin:5px;display: inline-block;">
110
+ <img src="https://c8.patreon.com/4/200/10876902/T" alt="Tyssel" width="60" height="60" style="border-radius:8px;margin:5px;display: inline-block;">
111
+ <img src="https://ostris.com/wp-content/uploads/2025/08/supporter_default.jpg" alt="Juan Franco" width="60" height="60" style="border-radius:8px;margin:5px;display: inline-block;">
112
+ </p>

---

## Installation

Requirements:
- python >3.10
- Nvidia GPU with enough VRAM to do what you need
- python venv
- git

Linux:
```bash
git clone https://github.com/ostris/ai-toolkit.git
cd ai-toolkit
python3 -m venv venv
source venv/bin/activate
# install torch first
pip3 install --no-cache-dir torch==2.7.0 torchvision==0.22.0 torchaudio==2.7.0 --index-url https://download.pytorch.org/whl/cu126
pip3 install -r requirements.txt
```

Windows:

If you are having issues on Windows, I recommend using the easy install script at [https://github.com/Tavris1/AI-Toolkit-Easy-Install](https://github.com/Tavris1/AI-Toolkit-Easy-Install)

```bash
git clone https://github.com/ostris/ai-toolkit.git
cd ai-toolkit
python -m venv venv
.\venv\Scripts\activate
pip install --no-cache-dir torch==2.7.0 torchvision==0.22.0 torchaudio==2.7.0 --index-url https://download.pytorch.org/whl/cu126
pip install -r requirements.txt
```


# AI Toolkit UI

<img src="https://ostris.com/wp-content/uploads/2025/02/toolkit-ui.jpg" alt="AI Toolkit UI" width="100%">

The AI Toolkit UI is a web interface for the AI Toolkit. It lets you easily start, stop, and monitor jobs, and train models with a few clicks. You can also set a token for the UI to prevent unauthorized access, so it is mostly safe to run on an exposed server.

## Running the UI

Requirements:
- Node.js > 18

The UI does not need to be kept running for the jobs to run. It is only needed to start/stop/monitor jobs. The commands below
will install / update the UI and its dependencies and start the UI.

```bash
cd ui
npm run build_and_start
```

You can now access the UI at `http://localhost:8675` or `http://<your-ip>:8675` if you are running it on a server.

## Securing the UI

If you are hosting the UI on a cloud provider or any network that is not secure, I highly recommend securing it with an auth token.
You can do this by setting the environment variable `AI_TOOLKIT_AUTH` to a super secure password. This token will be required to access
the UI. You can set it when starting the UI like so:

```bash
# Linux
AI_TOOLKIT_AUTH=super_secure_password npm run build_and_start

# Windows
set AI_TOOLKIT_AUTH=super_secure_password && npm run build_and_start

# Windows PowerShell
$env:AI_TOOLKIT_AUTH="super_secure_password"; npm run build_and_start
```


## FLUX.1 Training

### Tutorial

To get started quickly, check out [@araminta_k](https://x.com/araminta_k)'s tutorial on [Finetuning Flux Dev on a 3090](https://www.youtube.com/watch?v=HzGW_Kyermg) with 24GB VRAM.


### Requirements
You currently need a GPU with **at least 24GB of VRAM** to train FLUX.1. If you are using it as the GPU that drives
your monitors, you probably need to set the flag `low_vram: true` in the config file under `model:`. This will quantize
the model on the CPU and should allow training with monitors attached. Users have gotten it to work on Windows with WSL,
but there are some reports of a bug when running on Windows natively.
I have only tested on Linux for now. This is still extremely experimental,
and a lot of quantizing and tricks had to happen to get it to fit on 24GB at all.

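As a reference, here is a minimal sketch of what that `model` section can look like, using the keys from the example configs in this commit (FLUX.1-dev is just the example base model; swap in whatever you are training):

```yaml
model:
  name_or_path: "black-forest-labs/FLUX.1-dev"
  is_flux: true
  quantize: true   # run 8bit mixed precision
  low_vram: true   # quantize on the CPU; slower, but works with monitors attached
```
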
### FLUX.1-dev

FLUX.1-dev has a non-commercial license, which means anything you train will inherit the
non-commercial license. It is also a gated model, so you need to accept the license on HF before using it.
Otherwise, this will fail. Here are the required steps to set up access:

1. Sign into HF and accept the model access here [black-forest-labs/FLUX.1-dev](https://huggingface.co/black-forest-labs/FLUX.1-dev)
2. Make a file named `.env` in the root of this folder
3. [Get a READ key from huggingface](https://huggingface.co/settings/tokens/new?) and add it to the `.env` file like so `HF_TOKEN=your_key_here`

### FLUX.1-schnell

FLUX.1-schnell is Apache 2.0. Anything trained on it can be licensed however you want, and it does not require an HF_TOKEN to train.
However, it does require a special adapter to train with it, [ostris/FLUX.1-schnell-training-adapter](https://huggingface.co/ostris/FLUX.1-schnell-training-adapter).
It is also highly experimental. For best overall quality, training on FLUX.1-dev is recommended.

To use it, you just need to add the assistant to the `model` section of your config file like so:

```yaml
model:
  name_or_path: "black-forest-labs/FLUX.1-schnell"
  assistant_lora_path: "ostris/FLUX.1-schnell-training-adapter"
  is_flux: true
  quantize: true
```

You also need to adjust your sample steps, since schnell does not require as many:

```yaml
sample:
  guidance_scale: 1 # schnell does not do guidance
  sample_steps: 4 # 1 - 4 works well
```

### Training
1. Copy the example config file located at `config/examples/train_lora_flux_24gb.yaml` (`config/examples/train_lora_flux_schnell_24gb.yaml` for schnell) to the `config` folder and rename it to `whatever_you_want.yml`
2. Edit the file following the comments in the file
3. Run the file like so `python run.py config/whatever_you_want.yml`

A folder with the name from the config file will be created inside the training folder when you start. It will have all
checkpoints and images in it. You can stop the training at any time using ctrl+c, and when you resume, it will pick back up
from the last checkpoint.

IMPORTANT: If you press ctrl+c while it is saving, it will likely corrupt that checkpoint. So wait until it is done saving.

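The checkpoints mentioned above are controlled by the `save` section of your config; a minimal sketch using the values from the example configs in this commit:

```yaml
save:
  dtype: float16            # precision to save
  save_every: 250           # write a checkpoint every this many steps
  max_step_saves_to_keep: 4 # how many intermittent saves to keep
```
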
### Need help?

Please do not open a bug report unless it is a bug in the code. You are welcome to [Join my Discord](https://discord.gg/VXmU2f5WEU)
and ask for help there. However, please refrain from PMing me directly with general questions or support requests. Ask in the Discord
and I will answer when I can.

## Gradio UI

To get started training locally with a custom UI, once you have followed the steps above and `ai-toolkit` is installed:

```bash
cd ai-toolkit # in case you are not yet in the ai-toolkit folder
huggingface-cli login # provide a `write` token to publish your LoRA at the end
python flux_train_ui.py
```

You will instantiate a UI that will let you upload your images, caption them, train, and publish your LoRA
![image](assets/lora_ease_ui.png)


## Training in RunPod
Example RunPod template: **runpod/pytorch:2.2.0-py3.10-cuda12.1.1-devel-ubuntu22.04**
> You need a minimum of 24GB VRAM; pick a GPU by your preference.

#### Example config ($0.5/hr):
- 1x A40 (48 GB VRAM)
- 19 vCPU 100 GB RAM

#### Custom overrides (you need some storage to clone FLUX.1, store datasets, store trained models and samples):
- ~120 GB Disk
- ~120 GB Pod Volume
- Start Jupyter Notebook

### 1. Setup
```
git clone https://github.com/ostris/ai-toolkit.git
cd ai-toolkit
git submodule update --init --recursive
python -m venv venv
source venv/bin/activate
pip install torch
pip install -r requirements.txt
pip install --upgrade accelerate transformers diffusers huggingface_hub # optional, run it if you run into issues
```
### 2. Upload your dataset
- Create a new folder in the root, name it `dataset` or whatever you like.
- Drag and drop your .jpg, .jpeg, or .png images and .txt files inside the newly created dataset folder.

### 3. Log in to Hugging Face with an Access Token
- Get a READ token from [here](https://huggingface.co/settings/tokens) and request access to the Flux.1-dev model from [here](https://huggingface.co/black-forest-labs/FLUX.1-dev).
- Run ```huggingface-cli login``` and paste your token.

### 4. Training
- Copy an example config file located at ```config/examples``` to the config folder and rename it to ```whatever_you_want.yml```.
- Edit the config following the comments in the file.
- Change ```folder_path: "/path/to/images/folder"``` to your dataset path like ```folder_path: "/workspace/ai-toolkit/your-dataset"```.
- Run the file: ```python run.py config/whatever_you_want.yml```.

### Screenshot from RunPod
<img width="1728" alt="RunPod Training Screenshot" src="https://github.com/user-attachments/assets/53a1b8ef-92fa-4481-81a7-bde45a14a7b5">

## Training in Modal

### 1. Setup
#### ai-toolkit:
```
git clone https://github.com/ostris/ai-toolkit.git
cd ai-toolkit
git submodule update --init --recursive
python -m venv venv
source venv/bin/activate
pip install torch
pip install -r requirements.txt
pip install --upgrade accelerate transformers diffusers huggingface_hub # optional, run it if you run into issues
```
#### Modal:
- Run `pip install modal` to install the modal Python package.
- Run `modal setup` to authenticate (if this doesn't work, try `python -m modal setup`).

#### Hugging Face:
- Get a READ token from [here](https://huggingface.co/settings/tokens) and request access to the Flux.1-dev model from [here](https://huggingface.co/black-forest-labs/FLUX.1-dev).
- Run `huggingface-cli login` and paste your token.

### 2. Upload your dataset
- Drag and drop your dataset folder containing the .jpg, .jpeg, or .png images and .txt files into `ai-toolkit`.

### 3. Configs
- Copy an example config file located at ```config/examples/modal``` to the `config` folder and rename it to ```whatever_you_want.yml```.
- Edit the config following the comments in the file, **<ins>be careful and follow the example `/root/ai-toolkit` paths</ins>** (see the path sketch below).

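For reference, this is roughly how those paths look in `config/examples/modal/modal_train_lora_flux_24gb.yaml` from this commit (the dataset folder name is a placeholder):

```yaml
config:
  process:
    - type: 'sd_trainer'
      training_folder: "/root/ai-toolkit/modal_output"  # must match MOUNT_DIR from run_modal.py
      datasets:
        - folder_path: "/root/ai-toolkit/your-dataset"  # dataset lives inside ai-toolkit; /root is for Modal
```
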
### 4. Edit run_modal.py
- Set your entire local `ai-toolkit` path at `code_mount = modal.Mount.from_local_dir` like:

```
code_mount = modal.Mount.from_local_dir("/Users/username/ai-toolkit", remote_path="/root/ai-toolkit")
```
- Choose a `GPU` and `Timeout` in `@app.function` _(default is A100 40GB and 2 hour timeout)_.

### 5. Training
- Run the config file in your terminal: `modal run run_modal.py --config-file-list-str=/root/ai-toolkit/config/whatever_you_want.yml`.
- You can monitor your training in your local terminal, or on [modal.com](https://modal.com/).
- Models, samples and the optimizer will be stored in `Storage > flux-lora-models`.

### 6. Saving the model
- Check the contents of the volume by running `modal volume ls flux-lora-models`.
- Download the content by running `modal volume get flux-lora-models your-model-name`.
- Example: `modal volume get flux-lora-models my_first_flux_lora_v1`.

### Screenshot from Modal

<img width="1728" alt="Modal Training Screenshot" src="https://github.com/user-attachments/assets/7497eb38-0090-49d6-8ad9-9c8ea7b5388b">

---

## Dataset Preparation

Datasets generally need to be a folder containing images and associated text files. Currently, the only supported
formats are jpg, jpeg, and png. Webp currently has issues. The text files should be named the same as the images
but with a `.txt` extension, for example `image2.jpg` and `image2.txt`. The text file should contain only the caption.
You can add the word `[trigger]` in the caption file, and if you have `trigger_word` in your config, it will be automatically
replaced. A minimal dataset config is sketched below.

Images are never upscaled, but they are downscaled and placed in buckets for batching. **You do not need to crop/resize your images**.
The loader will automatically resize them and can handle varying aspect ratios.

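For reference, here is a minimal sketch of how a dataset and trigger word are wired into a training config, using the keys from the example configs in this commit (the folder path and trigger word are placeholders):

```yaml
trigger_word: "p3r5on"             # optional; replaces [trigger] in captions
datasets:
  - folder_path: "/path/to/images/folder"
    caption_ext: "txt"             # image2.jpg pairs with image2.txt
    caption_dropout_rate: 0.05     # drop the caption 5% of the time
    resolution: [ 512, 768, 1024 ] # images are bucketed into these resolutions
```
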
## Training Specific Layers

To train specific layers with LoRA, you can use the `only_if_contains` network kwargs. For instance, if you want to train only the 2 layers
used by The Last Ben, [mentioned in this post](https://x.com/__TheBen/status/1829554120270987740), you can adjust your
network kwargs like so:

```yaml
network:
  type: "lora"
  linear: 128
  linear_alpha: 128
  network_kwargs:
    only_if_contains:
      - "transformer.single_transformer_blocks.7.proj_out"
      - "transformer.single_transformer_blocks.20.proj_out"
```

The naming conventions of the layers are in diffusers format, so checking the state dict of a model will reveal
the suffix of the name of the layers you want to train. You can also use this method to only train specific groups of weights.
For instance, to only train the `single_transformer` for FLUX.1, you can use the following:

```yaml
network:
  type: "lora"
  linear: 128
  linear_alpha: 128
  network_kwargs:
    only_if_contains:
      - "transformer.single_transformer_blocks."
```

You can also exclude layers by their names by using the `ignore_if_contains` network kwarg. So to exclude all the single transformer blocks:

```yaml
network:
  type: "lora"
  linear: 128
  linear_alpha: 128
  network_kwargs:
    ignore_if_contains:
      - "transformer.single_transformer_blocks."
```

`ignore_if_contains` takes priority over `only_if_contains`. So if a weight is covered by both,
it will be ignored.

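To make that priority rule concrete, here is a hypothetical sketch that covers a layer with both keys; because `ignore_if_contains` wins, block 7's `proj_out` would not be trained:

```yaml
network:
  type: "lora"
  linear: 128
  linear_alpha: 128
  network_kwargs:
    only_if_contains:
      - "transformer.single_transformer_blocks."           # would normally include block 7
    ignore_if_contains:
      - "transformer.single_transformer_blocks.7.proj_out" # excluded anyway; ignore takes priority
```
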
## LoKr Training

To learn more about LoKr, read [KohakuBlueleaf/LyCORIS](https://github.com/KohakuBlueleaf/LyCORIS/blob/main/docs/Guidelines.md). To train a LoKr model, you can adjust the network type in the config file like so:

```yaml
network:
  type: "lokr"
  lokr_full_rank: true
  lokr_factor: 8
```

Everything else should work the same, including layer targeting (see the sketch below).

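For example, a sketch (not an official example config) that combines LoKr with the layer targeting from the previous section:

```yaml
network:
  type: "lokr"
  lokr_full_rank: true
  lokr_factor: 8
  network_kwargs:
    only_if_contains:
      - "transformer.single_transformer_blocks."
```
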
## Updates

Only larger updates are listed here. There are usually smaller daily updates that are omitted.

### Jul 17, 2025
- Make it easy to add control images to the samples in the UI

### Jul 11, 2025
- Added better video config settings to the UI for video models.
- Added Wan I2V training to the UI

### June 29, 2025
- Fixed issue where Kontext forced sizes on sampling

### June 26, 2025
- Added support for FLUX.1 Kontext training
- Added support for instruction dataset training

### June 25, 2025
- Added support for OmniGen2 training

### June 17, 2025
- Performance optimizations for batch preparation
- Added some docs via a popup for items in the simple UI explaining what settings do. Still a WIP

### June 16, 2025
- Hide control images in the UI when viewing datasets
- WIP on mean flow loss

### June 12, 2025
- Fixed issue that resulted in blank captions in the dataloader

### June 10, 2025
- Decided to keep track of updates in the readme
- Added support for SDXL in the UI
- Added support for SD 1.5 in the UI
- Fixed UI Wan 2.1 14b name bug
- Added support for conv training in the UI for models that support it

assets/VAE_test1.jpg ADDED

Git LFS Details

  • SHA256: 879fcb537d039408d7aada297b7397420132684f0106edacc1205fb5cc839476
  • Pointer size: 132 Bytes
  • Size of remote file: 1.51 MB
assets/glif.svg ADDED
assets/lora_ease_ui.png ADDED

Git LFS Details

  • SHA256: f647b9fe90cc96db2aa84d1cb25a73b60ffcc5394822f99e9dac27d373f89d79
  • Pointer size: 131 Bytes
  • Size of remote file: 349 kB
build_and_push_docker ADDED
@@ -0,0 +1,29 @@
#!/usr/bin/env bash

# Extract version from version.py
if [ -f "version.py" ]; then
    VERSION=$(python3 -c "from version import VERSION; print(VERSION)")
    echo "Building version: $VERSION"
else
    echo "Error: version.py not found. Please create a version.py file with VERSION defined."
    exit 1
fi

echo "Docker builds from the repo, not this dir. Make sure changes are pushed to the repo."
echo "Building version: $VERSION and latest"
# wait 2 seconds
sleep 2

# Build the image with cache busting
docker build --build-arg CACHEBUST=$(date +%s) -t aitoolkit:$VERSION -f docker/Dockerfile .

# Tag with version and latest
docker tag aitoolkit:$VERSION ostris/aitoolkit:$VERSION
docker tag aitoolkit:$VERSION ostris/aitoolkit:latest

# Push both tags
echo "Pushing images to Docker Hub..."
docker push ostris/aitoolkit:$VERSION
docker push ostris/aitoolkit:latest

echo "Successfully built and pushed ostris/aitoolkit:$VERSION and ostris/aitoolkit:latest"
build_and_push_docker_dev ADDED
@@ -0,0 +1,21 @@
#!/usr/bin/env bash

VERSION=dev
GIT_COMMIT=dev

echo "Docker builds from the repo, not this dir. Make sure changes are pushed to the repo."
echo "Building version: $VERSION"
# wait 2 seconds
sleep 2

# Build the image with cache busting
docker build --build-arg CACHEBUST=$(date +%s) -t aitoolkit:$VERSION -f docker/Dockerfile .

# Tag with the dev version
docker tag aitoolkit:$VERSION ostris/aitoolkit:$VERSION

# Push the tag
echo "Pushing image to Docker Hub..."
docker push ostris/aitoolkit:$VERSION

echo "Successfully built and pushed ostris/aitoolkit:$VERSION"
config/examples/extract.example.yml ADDED
@@ -0,0 +1,75 @@
---
# this is in yaml format. You can use json if you prefer
# I like both but yaml is easier to read and write
# plus it has comments which is nice for documentation
job: extract # tells the runner what to do
config:
  # the name will be used to create a folder in the output folder
  # it will also replace any [name] token in the rest of this config
  name: name_of_your_model
  # can be a hugging face model, a .ckpt, or a .safetensors
  base_model: "/path/to/base/model.safetensors"
  # can be a hugging face model, a .ckpt, or a .safetensors
  extract_model: "/path/to/model/to/extract/trained.safetensors"
  # we will create a folder here with the name above. This will create /path/to/output/folder/name_of_your_model
  output_folder: "/path/to/output/folder"
  is_v2: false
  dtype: fp16 # saved dtype
  device: cpu # cpu, cuda:0, etc

  # processes can be chained like this to run multiple in a row
  # they must all use the same models above, but this is great for testing different
  # sizes and types of extractions. It is much faster as we already have the models loaded
  process:
    # process 1
    - type: locon # locon or lora (locon is lycoris)
      filename: "[name]_64_32.safetensors" # will be put in output folder
      dtype: fp16
      mode: fixed
      linear: 64
      conv: 32

    # process 2
    - type: locon
      output_path: "/absolute/path/for/this/output.safetensors" # can be absolute
      mode: ratio
      linear: 0.2
      conv: 0.2

    # process 3
    - type: locon
      filename: "[name]_ratio_02.safetensors"
      mode: quantile
      linear: 0.5
      conv: 0.5

    # process 4
    - type: lora # traditional lora extraction (lierla) with linear layers only
      filename: "[name]_4.safetensors"
      mode: fixed # fixed, ratio, quantile supported for lora as well
      linear: 4 # lora dim or rank
      # no conv for lora

    # process 5
    - type: lora
      filename: "[name]_q05.safetensors"
      mode: quantile
      linear: 0.5

# you can put any information you want here, and it will be saved in the model
# the below is an example. I recommend doing trigger words at a minimum
# in the metadata. The software will include this plus some other information
meta:
  name: "[name]" # [name] gets replaced with the name above
  description: A short description of your model
  trigger_words:
    - put
    - trigger
    - words
    - here
  version: '0.1'
  creator:
    name: Your Name
    website: https://yourwebsite.com
  any: All meta data above is arbitrary, it can be whatever you want.
config/examples/generate.example.yaml ADDED
@@ -0,0 +1,60 @@
---

job: generate # tells the runner what to do
config:
  name: "generate" # this is not really used anywhere currently but required by runner
  process:
    # process 1
    - type: to_folder # process images to a folder
      output_folder: "output/gen"
      device: cuda:0 # cpu, cuda:0, etc
      generate:
        # these are your defaults; you can override most of them with flags
        sampler: "ddpm" # ignored for now, will add later though ddpm is used regardless for now
        width: 1024
        height: 1024
        neg: "cartoon, fake, drawing, illustration, cgi, animated, anime"
        seed: -1 # -1 is random
        guidance_scale: 7
        sample_steps: 20
        ext: ".png" # .png, .jpg, .jpeg, .webp

        # here are the flags you can use for prompts. Always start with
        # your prompt first then add these flags after. You can use as many as you
        # like, e.g.
        # photo of a baseball --n painting, ugly --w 1024 --h 1024 --seed 42 --cfg 7 --steps 20
        # we will try to support all sd-scripts flags where we can

        # FROM SD-SCRIPTS
        # --n Treat everything until the next option as a negative prompt.
        # --w Specify the width of the generated image.
        # --h Specify the height of the generated image.
        # --d Specify the seed for the generated image.
        # --l Specify the CFG scale for the generated image.
        # --s Specify the number of steps during generation.

        # OURS and some QOL additions
        # --p2 Prompt for the second text encoder (SDXL only)
        # --n2 Negative prompt for the second text encoder (SDXL only)
        # --gr Specify the guidance rescale for the generated image (SDXL only)
        # --seed Specify the seed for the generated image, same as --d
        # --cfg Specify the CFG scale for the generated image, same as --l
        # --steps Specify the number of steps during generation, same as --s

        prompt_file: false # if true a txt file will be created next to images with prompt strings used
        # prompts can also be a path to a text file with one prompt per line
        # prompts: "/path/to/prompts.txt"
        prompts:
          - "photo of batman"
          - "photo of superman"
          - "photo of spiderman"
          - "photo of a superhero --n batman superman spiderman"

      model:
        # huggingface name, relative path from the project, or absolute path to .safetensors or .ckpt
        # name_or_path: "runwayml/stable-diffusion-v1-5"
        name_or_path: "/mnt/Models/stable-diffusion/models/stable-diffusion/Ostris/Ostris_Real_v1.safetensors"
        is_v2: false # for v2 models
        is_v_pred: false # for v-prediction models (most v2 models)
        is_xl: false # for SDXL models
        dtype: bf16
config/examples/mod_lora_scale.yaml ADDED
@@ -0,0 +1,48 @@
---
job: mod
config:
  name: name_of_your_model_v1
  process:
    - type: rescale_lora
      # path to your current lora model
      input_path: "/path/to/lora/lora.safetensors"
      # output path for your new lora model, can be the same as input_path to replace
      output_path: "/path/to/lora/output_lora_v1.safetensors"
      # replaces meta with the meta below (plus minimum meta fields)
      # if false, we will leave the meta alone except for updating hashes (sd-script hashes)
      replace_meta: true
      # how to adjust, we can scale the up_down weights or the alpha
      # up_down is the default and probably the best, they will both net the same outputs
      # would only affect rare NaN cases and maybe merging with old merge tools
      scale_target: 'up_down'
      # precision to save, fp16 is the default and standard
      save_dtype: fp16
      # current_weight is the ideal weight you use as a multiplier when using the lora
      # IE in automatic1111 <lora:my_lora:6.0> the 6.0 is the current_weight
      # you can do negatives here too if you want to flip the lora
      current_weight: 6.0
      # target_weight is the ideal weight you want to use as a multiplier when using the lora
      # instead of the one above. IE in automatic1111 instead of using <lora:my_lora:6.0>
      # we want to use <lora:my_lora:1.0> so 1.0 is the target_weight
      target_weight: 1.0

      # base model for the lora
      # this is just used to add meta so automatic1111 knows which model it is for
      # assume v1.5 if these are not set
      is_xl: false
      is_v2: false
meta:
  # this is only used if you set replace_meta to true above
  name: "[name]" # [name] gets replaced with the name above
  description: A short description of your lora
  trigger_words:
    - put
    - trigger
    - words
    - here
  version: '0.1'
  creator:
    name: Your Name
    website: https://yourwebsite.com
  any: All meta data above is arbitrary, it can be whatever you want.
config/examples/modal/modal_train_lora_flux_24gb.yaml ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ job: extension
3
+ config:
4
+ # this name will be the folder and filename name
5
+ name: "my_first_flux_lora_v1"
6
+ process:
7
+ - type: 'sd_trainer'
8
+ # root folder to save training sessions/samples/weights
9
+ training_folder: "/root/ai-toolkit/modal_output" # must match MOUNT_DIR from run_modal.py
10
+ # uncomment to see performance stats in the terminal every N steps
11
+ # performance_log_every: 1000
12
+ device: cuda:0
13
+ # if a trigger word is specified, it will be added to captions of training data if it does not already exist
14
+ # alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word
15
+ # trigger_word: "p3r5on"
16
+ network:
17
+ type: "lora"
18
+ linear: 16
19
+ linear_alpha: 16
20
+ save:
21
+ dtype: float16 # precision to save
22
+ save_every: 250 # save every this many steps
23
+ max_step_saves_to_keep: 4 # how many intermittent saves to keep
24
+ datasets:
25
+ # datasets are a folder of images. captions need to be txt files with the same name as the image
26
+ # for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently
27
+ # images will automatically be resized and bucketed into the resolution specified
28
+ # on windows, escape back slashes with another backslash so
29
+ # "C:\\path\\to\\images\\folder"
30
+ # your dataset must be placed in /ai-toolkit and /root is for modal to find the dir:
31
+ - folder_path: "/root/ai-toolkit/your-dataset"
32
+ caption_ext: "txt"
33
+ caption_dropout_rate: 0.05 # will drop out the caption 5% of time
34
+ shuffle_tokens: false # shuffle caption order, split by commas
35
+ cache_latents_to_disk: true # leave this true unless you know what you're doing
36
+ resolution: [ 512, 768, 1024 ] # flux enjoys multiple resolutions
37
+ train:
38
+ batch_size: 1
39
+ steps: 2000 # total number of steps to train 500 - 4000 is a good range
40
+ gradient_accumulation_steps: 1
41
+ train_unet: true
42
+ train_text_encoder: false # probably won't work with flux
43
+ gradient_checkpointing: true # need this on unless you have a ton of vram
44
+ noise_scheduler: "flowmatch" # for training only
45
+ optimizer: "adamw8bit"
46
+ lr: 1e-4
47
+ # uncomment this to skip the pre training sample
48
+ # skip_first_sample: true
49
+ # uncomment to completely disable sampling
50
+ # disable_sampling: true
51
+ # uncomment to use new bell curved weighting. Experimental but may produce better results
52
+ # linear_timesteps: true
53
+
54
+ # ema will smooth out learning, but could slow it down. Recommended to leave on.
55
+ ema_config:
56
+ use_ema: true
57
+ ema_decay: 0.99
58
+
59
+ # will probably need this if gpu supports it for flux, other dtypes may not work correctly
60
+ dtype: bf16
61
+ model:
62
+ # huggingface model name or path
63
+ # if you get an error, or get stuck while downloading,
64
+ # check https://github.com/ostris/ai-toolkit/issues/84, download the model locally and
65
+ # place it like "/root/ai-toolkit/FLUX.1-dev"
66
+ name_or_path: "black-forest-labs/FLUX.1-dev"
67
+ is_flux: true
68
+ quantize: true # run 8bit mixed precision
69
+ # low_vram: true # uncomment this if the GPU is connected to your monitors. It will use less vram to quantize, but is slower.
70
+ sample:
71
+ sampler: "flowmatch" # must match train.noise_scheduler
72
+ sample_every: 250 # sample every this many steps
73
+ width: 1024
74
+ height: 1024
75
+ prompts:
76
+ # you can add [trigger] to the prompts here and it will be replaced with the trigger word
77
+ # - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"
78
+ - "woman with red hair, playing chess at the park, bomb going off in the background"
79
+ - "a woman holding a coffee cup, in a beanie, sitting at a cafe"
80
+ - "a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini"
81
+ - "a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background"
82
+ - "a bear building a log cabin in the snow covered mountains"
83
+ - "woman playing the guitar, on stage, singing a song, laser lights, punk rocker"
84
+ - "hipster man with a beard, building a chair, in a wood shop"
85
+ - "photo of a man, white background, medium shot, modeling clothing, studio lighting, white backdrop"
86
+ - "a man holding a sign that says, 'this is a sign'"
87
+ - "a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle"
88
+ neg: "" # not used on flux
89
+ seed: 42
90
+ walk_seed: true
91
+ guidance_scale: 4
92
+ sample_steps: 20
93
+ # you can add any additional meta info here. [name] is replaced with config name at top
94
+ meta:
95
+ name: "[name]"
96
+ version: '1.0'
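Every example in this commit uses the dataset layout described in the comments above: a flat folder of jpg/jpeg/png images, each paired with a .txt caption of the same name. A small sanity-check sketch for that layout; the helper below is illustrative, not part of the toolkit.

```python
from pathlib import Path

IMAGE_EXTS = {".jpg", ".jpeg", ".png"}  # the only formats the comments list as supported

def check_dataset(folder: str) -> None:
    """Report images that are missing a same-named .txt caption file."""
    images = [p for p in Path(folder).iterdir() if p.suffix.lower() in IMAGE_EXTS]
    missing = [p.name for p in images if not p.with_suffix(".txt").exists()]
    print(f"{len(images)} images, {len(missing)} missing captions")
    for name in missing:
        print("  no caption for:", name)

check_dataset("/root/ai-toolkit/your-dataset")
```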
config/examples/modal/modal_train_lora_flux_schnell_24gb.yaml ADDED
@@ -0,0 +1,98 @@
1
+ ---
2
+ job: extension
3
+ config:
4
+ # this name will be the folder and filename name
5
+ name: "my_first_flux_lora_v1"
6
+ process:
7
+ - type: 'sd_trainer'
8
+ # root folder to save training sessions/samples/weights
9
+ training_folder: "/root/ai-toolkit/modal_output" # must match MOUNT_DIR from run_modal.py
10
+ # uncomment to see performance stats in the terminal every N steps
11
+ # performance_log_every: 1000
12
+ device: cuda:0
13
+ # if a trigger word is specified, it will be added to captions of training data if it does not already exist
14
+ # alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word
15
+ # trigger_word: "p3r5on"
16
+ network:
17
+ type: "lora"
18
+ linear: 16
19
+ linear_alpha: 16
20
+ save:
21
+ dtype: float16 # precision to save
22
+ save_every: 250 # save every this many steps
23
+ max_step_saves_to_keep: 4 # how many intermittent saves to keep
24
+ datasets:
25
+ # datasets are a folder of images. captions need to be txt files with the same name as the image
26
+ # for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently
27
+ # images will automatically be resized and bucketed into the resolution specified
28
+ # on windows, escape back slashes with another backslash so
29
+ # "C:\\path\\to\\images\\folder"
30
+ # your dataset must be placed in /ai-toolkit; the /root prefix is where modal mounts it so it can find the dir:
31
+ - folder_path: "/root/ai-toolkit/your-dataset"
32
+ caption_ext: "txt"
33
+ caption_dropout_rate: 0.05 # will drop out the caption 5% of time
34
+ shuffle_tokens: false # shuffle caption order, split by commas
35
+ cache_latents_to_disk: true # leave this true unless you know what you're doing
36
+ resolution: [ 512, 768, 1024 ] # flux enjoys multiple resolutions
37
+ train:
38
+ batch_size: 1
39
+ steps: 2000 # total number of steps to train 500 - 4000 is a good range
40
+ gradient_accumulation_steps: 1
41
+ train_unet: true
42
+ train_text_encoder: false # probably won't work with flux
43
+ gradient_checkpointing: true # need this on unless you have a ton of vram
44
+ noise_scheduler: "flowmatch" # for training only
45
+ optimizer: "adamw8bit"
46
+ lr: 1e-4
47
+ # uncomment this to skip the pre training sample
48
+ # skip_first_sample: true
49
+ # uncomment to completely disable sampling
50
+ # disable_sampling: true
51
+ # uncomment to use new bell curved weighting. Experimental but may produce better results
52
+ # linear_timesteps: true
53
+
54
+ # ema will smooth out learning, but could slow it down. Recommended to leave on.
55
+ ema_config:
56
+ use_ema: true
57
+ ema_decay: 0.99
58
+
59
+ # will probably need this if gpu supports it for flux, other dtypes may not work correctly
60
+ dtype: bf16
61
+ model:
62
+ # huggingface model name or path
63
+ # if you get an error, or get stuck while downloading,
64
+ # check https://github.com/ostris/ai-toolkit/issues/84, download the models locally and
65
+ # place them like "/root/ai-toolkit/FLUX.1-schnell" and "/root/ai-toolkit/FLUX.1-schnell-training-adapter"
66
+ name_or_path: "black-forest-labs/FLUX.1-schnell"
67
+ assistant_lora_path: "ostris/FLUX.1-schnell-training-adapter" # Required for flux schnell training
68
+ is_flux: true
69
+ quantize: true # run 8bit mixed precision
70
+ # low_vram is painfully slow to fuse in the adapter; avoid it unless absolutely necessary
71
+ # low_vram: true # uncomment this if the GPU is connected to your monitors. It will use less vram to quantize, but is slower.
72
+ sample:
73
+ sampler: "flowmatch" # must match train.noise_scheduler
74
+ sample_every: 250 # sample every this many steps
75
+ width: 1024
76
+ height: 1024
77
+ prompts:
78
+ # you can add [trigger] to the prompts here and it will be replaced with the trigger word
79
+ # - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"
80
+ - "woman with red hair, playing chess at the park, bomb going off in the background"
81
+ - "a woman holding a coffee cup, in a beanie, sitting at a cafe"
82
+ - "a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini"
83
+ - "a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background"
84
+ - "a bear building a log cabin in the snow covered mountains"
85
+ - "woman playing the guitar, on stage, singing a song, laser lights, punk rocker"
86
+ - "hipster man with a beard, building a chair, in a wood shop"
87
+ - "photo of a man, white background, medium shot, modeling clothing, studio lighting, white backdrop"
88
+ - "a man holding a sign that says, 'this is a sign'"
89
+ - "a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle"
90
+ neg: "" # not used on flux
91
+ seed: 42
92
+ walk_seed: true
93
+ guidance_scale: 1 # schnell does not do guidance
94
+ sample_steps: 4 # 1 - 4 works well
95
+ # you can add any additional meta info here. [name] is replaced with config name at top
96
+ meta:
97
+ name: "[name]"
98
+ version: '1.0'
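Both Modal configs describe the same caption handling: an optional trigger_word is added to captions that do not already contain it, a literal [trigger] token is replaced with it, and caption_dropout_rate blanks the caption a fraction of the time. A rough sketch of those rules as described; where the trigger word lands when appended is an assumption, and the trainer's own code may differ.

```python
import random

def prepare_caption(caption: str, trigger_word: str | None = None, dropout_rate: float = 0.05) -> str:
    if random.random() < dropout_rate:
        return ""                                    # caption dropped 5% of the time by default
    if trigger_word:
        if "[trigger]" in caption:
            caption = caption.replace("[trigger]", trigger_word)
        elif trigger_word not in caption:
            caption = f"{trigger_word}, {caption}"   # assumed placement when appending
    return caption

print(prepare_caption("[trigger] sitting at a cafe in a beanie", trigger_word="p3r5on"))
```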
config/examples/train_flex_redux.yaml ADDED
@@ -0,0 +1,112 @@
1
+ ---
2
+ job: extension
3
+ config:
4
+ # this name will be the folder and filename name
5
+ name: "my_first_flex_redux_finetune_v1"
6
+ process:
7
+ - type: 'sd_trainer'
8
+ # root folder to save training sessions/samples/weights
9
+ training_folder: "output"
10
+ # uncomment to see performance stats in the terminal every N steps
11
+ # performance_log_every: 1000
12
+ device: cuda:0
13
+ adapter:
14
+ type: "redux"
15
+ # you can finetune an existing adapter or start from scratch. Set to null to start from scratch
16
+ name_or_path: '/local/path/to/redux_adapter_to_finetune.safetensors'
17
+ # name_or_path: null
18
+ # image_encoder_path: 'google/siglip-so400m-patch14-384' # Flux.1 redux adapter
19
+ image_encoder_path: 'google/siglip2-so400m-patch16-512' # Flex.1 512 redux adapter
20
+ # image_encoder_arch: 'siglip' # for Flux.1
21
+ image_encoder_arch: 'siglip2'
22
+ # You need a control input for each sample. It is best to use square images for both
23
+ test_img_path:
24
+ - "/path/to/x_01.jpg"
25
+ - "/path/to/x_02.jpg"
26
+ - "/path/to/x_03.jpg"
27
+ - "/path/to/x_04.jpg"
28
+ - "/path/to/x_05.jpg"
29
+ - "/path/to/x_06.jpg"
30
+ - "/path/to/x_07.jpg"
31
+ - "/path/to/x_08.jpg"
32
+ - "/path/to/x_09.jpg"
33
+ - "/path/to/x_10.jpg"
34
+ clip_layer: 'last_hidden_state'
35
+ train: true
36
+ save:
37
+ dtype: bf16 # precision to save
38
+ save_every: 250 # save every this many steps
39
+ max_step_saves_to_keep: 4
40
+ datasets:
41
+ # datasets are a folder of images. captions need to be txt files with the same name as the image
42
+ # for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently
43
+ # images will automatically be resized and bucketed into the resolution specified
44
+ # on windows, escape back slashes with another backslash so
45
+ # "C:\\path\\to\\images\\folder"
46
+ - folder_path: "/path/to/images/folder"
47
+ # clip_image_path is the directory containing your control images. They must have the same filename as their train image (the extension does not matter)
48
+ # for normal redux, we are just recreating the same image, so you can use the same folder path above
49
+ clip_image_path: "/path/to/control/images/folder"
50
+ caption_ext: "txt"
51
+ caption_dropout_rate: 0.05 # will drop out the caption 5% of time
52
+ resolution: [ 512, 768, 1024 ] # flex enjoys multiple resolutions
53
+ train:
54
+ # this is what I used for the 24GB card, but feel free to adjust
55
+ # total batch size is 6 here
56
+ batch_size: 3
57
+ gradient_accumulation: 2
58
+
59
+ # captions are not needed for this training; we cache a blank prompt and rely on the vision encoder
60
+ unload_text_encoder: true
61
+
62
+ loss_type: "mse"
63
+ train_unet: true
64
+ train_text_encoder: false
65
+ steps: 4000000 # I set this very high and stop when I like the results
66
+ content_or_style: balanced # content, style, balanced
67
+ gradient_checkpointing: true
68
+ noise_scheduler: "flowmatch" # or "ddpm", "lms", "euler_a"
69
+ timestep_type: "flux_shift"
70
+ optimizer: "adamw8bit"
71
+ lr: 1e-4
72
+
73
+ # this is for Flex.1, comment this out for FLUX.1-dev
74
+ bypass_guidance_embedding: true
75
+
76
+ dtype: bf16
77
+ ema_config:
78
+ use_ema: true
79
+ ema_decay: 0.99
80
+ model:
81
+ name_or_path: "ostris/Flex.1-alpha"
82
+ is_flux: true
83
+ quantize: true
84
+ text_encoder_bits: 8
85
+ sample:
86
+ sampler: "flowmatch" # must match train.noise_scheduler
87
+ sample_every: 250 # sample every this many steps
88
+ width: 1024
89
+ height: 1024
90
+ # I leave half blank to test prompted and unprompted
91
+ prompts:
92
+ - "woman with red hair, playing chess at the park, bomb going off in the background"
93
+ - "a woman holding a coffee cup, in a beanie, sitting at a cafe"
94
+ - "a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini"
95
+ - "a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background"
96
+ - "a bear building a log cabin in the snow covered mountains"
97
+ - ""
98
+ - ""
99
+ - ""
100
+ - ""
101
+ - ""
102
+ neg: ""
103
+ seed: 42
104
+ walk_seed: true
105
+ guidance_scale: 4
106
+ sample_steps: 25
107
+ network_multiplier: 1.0
108
+
109
+ # you can add any additional meta info here. [name] is replaced with config name at top
110
+ meta:
111
+ name: "[name]"
112
+ version: '1.0'
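For the redux config above, every training image needs a control image in clip_image_path with the same filename (the extension may differ). A short sketch that builds that pairing by filename stem, purely for pre-flight checking and not taken from the toolkit:

```python
from pathlib import Path

IMAGE_EXTS = {".jpg", ".jpeg", ".png"}

def pair_control_images(train_dir: str, clip_image_dir: str) -> dict[str, str]:
    """Map each training image to the control image sharing its filename stem."""
    controls = {p.stem: p for p in Path(clip_image_dir).iterdir() if p.is_file()}
    pairs = {}
    for img in Path(train_dir).iterdir():
        if img.suffix.lower() not in IMAGE_EXTS:
            continue
        if img.stem in controls:
            pairs[str(img)] = str(controls[img.stem])
        else:
            print("no control image for", img.name)
    return pairs
```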
config/examples/train_full_fine_tune_flex.yaml ADDED
@@ -0,0 +1,107 @@
1
+ ---
2
+ # This configuration requires 48GB of VRAM or more to operate
3
+ job: extension
4
+ config:
5
+ # this name will be the folder and filename name
6
+ name: "my_first_flex_finetune_v1"
7
+ process:
8
+ - type: 'sd_trainer'
9
+ # root folder to save training sessions/samples/weights
10
+ training_folder: "output"
11
+ # uncomment to see performance stats in the terminal every N steps
12
+ # performance_log_every: 1000
13
+ device: cuda:0
14
+ # if a trigger word is specified, it will be added to captions of training data if it does not already exist
15
+ # alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word
16
+ # trigger_word: "p3r5on"
17
+ save:
18
+ dtype: bf16 # precision to save
19
+ save_every: 250 # save every this many steps
20
+ max_step_saves_to_keep: 2 # how many intermittent saves to keep
21
+ save_format: 'diffusers' # 'diffusers'
22
+ datasets:
23
+ # datasets are a folder of images. captions need to be txt files with the same name as the image
24
+ # for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently
25
+ # images will automatically be resized and bucketed into the resolution specified
26
+ # on windows, escape back slashes with another backslash so
27
+ # "C:\\path\\to\\images\\folder"
28
+ - folder_path: "/path/to/images/folder"
29
+ caption_ext: "txt"
30
+ caption_dropout_rate: 0.05 # will drop out the caption 5% of time
31
+ shuffle_tokens: false # shuffle caption order, split by commas
32
+ # cache_latents_to_disk: true # leave this true unless you know what you're doing
33
+ resolution: [ 512, 768, 1024 ] # flex enjoys multiple resolutions
34
+ train:
35
+ batch_size: 1
36
+ # IMPORTANT! For Flex, you must bypass the guidance embedder during training
37
+ bypass_guidance_embedding: true
38
+
39
+ # can be 'sigmoid', 'linear', or 'lognorm_blend'
40
+ timestep_type: 'sigmoid'
41
+
42
+ steps: 2000 # total number of steps to train 500 - 4000 is a good range
43
+ gradient_accumulation: 1
44
+ train_unet: true
45
+ train_text_encoder: false # probably won't work with flex
46
+ gradient_checkpointing: true # need this on unless you have a ton of vram
47
+ noise_scheduler: "flowmatch" # for training only
48
+ optimizer: "adafactor"
49
+ lr: 3e-5
50
+
51
+ # Parameter swapping can reduce vram requirements. Set the factor from 1.0 to 0.0.
52
+ # 0.1 means 10% of parameters are active at each step. Only works with adafactor
53
+
54
+ # do_paramiter_swapping: true
55
+ # paramiter_swapping_factor: 0.9
56
+
57
+ # uncomment this to skip the pre training sample
58
+ # skip_first_sample: true
59
+ # uncomment to completely disable sampling
60
+ # disable_sampling: true
61
+
62
+ # ema will smooth out learning, but could slow it down. Recommended to leave on if you have the vram
63
+ ema_config:
64
+ use_ema: true
65
+ ema_decay: 0.99
66
+
67
+ # will probably need this if gpu supports it for flex, other dtypes may not work correctly
68
+ dtype: bf16
69
+ model:
70
+ # huggingface model name or path
71
+ name_or_path: "ostris/Flex.1-alpha"
72
+ is_flux: true # flex is flux architecture
73
+ # full finetuning quantized models is a crapshoot and results in subpar outputs
74
+ # quantize: true
75
+ # you can quantize just the T5 text encoder here to save vram
76
+ quantize_te: true
77
+ # only train the transformer blocks
78
+ only_if_contains:
79
+ - "transformer.transformer_blocks."
80
+ - "transformer.single_transformer_blocks."
81
+ sample:
82
+ sampler: "flowmatch" # must match train.noise_scheduler
83
+ sample_every: 250 # sample every this many steps
84
+ width: 1024
85
+ height: 1024
86
+ prompts:
87
+ # you can add [trigger] to the prompts here and it will be replaced with the trigger word
88
+ # - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"
89
+ - "woman with red hair, playing chess at the park, bomb going off in the background"
90
+ - "a woman holding a coffee cup, in a beanie, sitting at a cafe"
91
+ - "a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini"
92
+ - "a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background"
93
+ - "a bear building a log cabin in the snow covered mountains"
94
+ - "woman playing the guitar, on stage, singing a song, laser lights, punk rocker"
95
+ - "hipster man with a beard, building a chair, in a wood shop"
96
+ - "photo of a man, white background, medium shot, modeling clothing, studio lighting, white backdrop"
97
+ - "a man holding a sign that says, 'this is a sign'"
98
+ - "a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle"
99
+ neg: "" # not used on flex
100
+ seed: 42
101
+ walk_seed: true
102
+ guidance_scale: 4
103
+ sample_steps: 25
104
+ # you can add any additional meta info here. [name] is replaced with config name at top
105
+ meta:
106
+ name: "[name]"
107
+ version: '1.0'
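The commented-out do_paramiter_swapping option above trains only a fraction of the parameters on each step to cut optimizer memory. The toy sketch below illustrates that idea by toggling requires_grad on a random ~10% of parameter tensors per step; the toolkit's adafactor-specific implementation will differ.

```python
import random
import torch

def swap_active_parameters(model: torch.nn.Module, active_fraction: float = 0.1) -> None:
    """Mark roughly active_fraction of the parameter tensors as trainable for the next step."""
    params = list(model.parameters())
    k = max(1, int(len(params) * active_fraction))
    active = set(random.sample(range(len(params)), k))
    for i, p in enumerate(params):
        p.requires_grad_(i in active)

# call once per step, before the forward pass:
# swap_active_parameters(transformer, active_fraction=0.1)
```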
config/examples/train_full_fine_tune_lumina.yaml ADDED
@@ -0,0 +1,99 @@
1
+ ---
2
+ # This configuration requires 24GB of VRAM or more to operate
3
+ job: extension
4
+ config:
5
+ # this name will be the folder and filename name
6
+ name: "my_first_lumina_finetune_v1"
7
+ process:
8
+ - type: 'sd_trainer'
9
+ # root folder to save training sessions/samples/weights
10
+ training_folder: "output"
11
+ # uncomment to see performance stats in the terminal every N steps
12
+ # performance_log_every: 1000
13
+ device: cuda:0
14
+ # if a trigger word is specified, it will be added to captions of training data if it does not already exist
15
+ # alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word
16
+ # trigger_word: "p3r5on"
17
+ save:
18
+ dtype: bf16 # precision to save
19
+ save_every: 250 # save every this many steps
20
+ max_step_saves_to_keep: 2 # how many intermittent saves to keep
21
+ save_format: 'diffusers' # 'diffusers'
22
+ datasets:
23
+ # datasets are a folder of images. captions need to be txt files with the same name as the image
24
+ # for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently
25
+ # images will automatically be resized and bucketed into the resolution specified
26
+ # on windows, escape back slashes with another backslash so
27
+ # "C:\\path\\to\\images\\folder"
28
+ - folder_path: "/path/to/images/folder"
29
+ caption_ext: "txt"
30
+ caption_dropout_rate: 0.05 # will drop out the caption 5% of time
31
+ shuffle_tokens: false # shuffle caption order, split by commas
32
+ # cache_latents_to_disk: true # leave this true unless you know what you're doing
33
+ resolution: [ 512, 768, 1024 ] # lumina2 enjoys multiple resolutions
34
+ train:
35
+ batch_size: 1
36
+
37
+ # can be 'sigmoid', 'linear', or 'lumina2_shift'
38
+ timestep_type: 'lumina2_shift'
39
+
40
+ steps: 2000 # total number of steps to train 500 - 4000 is a good range
41
+ gradient_accumulation: 1
42
+ train_unet: true
43
+ train_text_encoder: false # probably won't work with lumina2
44
+ gradient_checkpointing: true # need this on unless you have a ton of vram
45
+ noise_scheduler: "flowmatch" # for training only
46
+ optimizer: "adafactor"
47
+ lr: 3e-5
48
+
49
+ # Parameter swapping can reduce vram requirements. Set the factor from 1.0 to 0.0.
50
+ # 0.1 means 10% of parameters are active at each step. Only works with adafactor
51
+
52
+ # do_paramiter_swapping: true
53
+ # paramiter_swapping_factor: 0.9
54
+
55
+ # uncomment this to skip the pre training sample
56
+ # skip_first_sample: true
57
+ # uncomment to completely disable sampling
58
+ # disable_sampling: true
59
+
60
+ # ema will smooth out learning, but could slow it down. Recommended to leave on if you have the vram
61
+ # ema_config:
62
+ # use_ema: true
63
+ # ema_decay: 0.99
64
+
65
+ # will probably need this if gpu supports it for lumina2, other dtypes may not work correctly
66
+ dtype: bf16
67
+ model:
68
+ # huggingface model name or path
69
+ name_or_path: "Alpha-VLLM/Lumina-Image-2.0"
70
+ is_lumina2: true # lumina2 architecture
71
+ # you can quantize just the Gemma2 text encoder here to save vram
72
+ quantize_te: true
73
+ sample:
74
+ sampler: "flowmatch" # must match train.noise_scheduler
75
+ sample_every: 250 # sample every this many steps
76
+ width: 1024
77
+ height: 1024
78
+ prompts:
79
+ # you can add [trigger] to the prompts here and it will be replaced with the trigger word
80
+ # - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"
81
+ - "woman with red hair, playing chess at the park, bomb going off in the background"
82
+ - "a woman holding a coffee cup, in a beanie, sitting at a cafe"
83
+ - "a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini"
84
+ - "a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background"
85
+ - "a bear building a log cabin in the snow covered mountains"
86
+ - "woman playing the guitar, on stage, singing a song, laser lights, punk rocker"
87
+ - "hipster man with a beard, building a chair, in a wood shop"
88
+ - "photo of a cat that is half black and half orange tabby, split down the middle. The cat has on a blue tophat. They are holding a martini glass with a pink ball of yarn in it with green knitting needles sticking out, in one paw. In the other paw, they are holding a DVD case for a movie titled, \"This is a test\" that has a golden robot on it. In the background is a busy night club with a giant mushroom man dancing with a bear."
89
+ - "a man holding a sign that says, 'this is a sign'"
90
+ - "a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle"
91
+ neg: ""
92
+ seed: 42
93
+ walk_seed: true
94
+ guidance_scale: 4.0
95
+ sample_steps: 25
96
+ # you can add any additional meta info here. [name] is replaced with config name at top
97
+ meta:
98
+ name: "[name]"
99
+ version: '1.0'
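Several configs in this commit expose timestep_type ('linear', 'sigmoid', and model-specific shifts such as 'lumina2_shift' or 'flux_shift'). As a rough illustration only: linear draws timesteps uniformly, sigmoid squashes a normal draw toward the middle, and a shift remaps the draw toward noisier timesteps. The shift formula below is the common flow-matching shift and is an assumption about what these options do, not taken from the toolkit.

```python
import torch

def sample_timesteps(batch: int, kind: str = "sigmoid", shift: float = 3.0) -> torch.Tensor:
    if kind == "linear":
        t = torch.rand(batch)                      # uniform in (0, 1)
    elif kind == "sigmoid":
        t = torch.sigmoid(torch.randn(batch))      # bunched around t = 0.5
    elif kind == "shift":
        t = torch.rand(batch)
        t = shift * t / (1.0 + (shift - 1.0) * t)  # pushes mass toward the noisy end
    else:
        raise ValueError(f"unknown timestep_type: {kind}")
    return t

print(sample_timesteps(4, "shift"))
```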
config/examples/train_lora_chroma_24gb.yaml ADDED
@@ -0,0 +1,104 @@
1
+ ---
2
+ job: extension
3
+ config:
4
+ # this name will be the folder and filename name
5
+ name: "my_first_chroma_lora_v1"
6
+ process:
7
+ - type: 'sd_trainer'
8
+ # root folder to save training sessions/samples/weights
9
+ training_folder: "output"
10
+ # uncomment to see performance stats in the terminal every N steps
11
+ # performance_log_every: 1000
12
+ device: cuda:0
13
+ # if a trigger word is specified, it will be added to captions of training data if it does not already exist
14
+ # alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word
15
+ # trigger_word: "p3r5on"
16
+ network:
17
+ type: "lora"
18
+ linear: 16
19
+ linear_alpha: 16
20
+ save:
21
+ dtype: float16 # precision to save
22
+ save_every: 250 # save every this many steps
23
+ max_step_saves_to_keep: 4 # how many intermittent saves to keep
24
+ push_to_hub: false #change this to True to push your trained model to Hugging Face.
25
+ # You can either set up a HF_TOKEN env variable or you'll be prompted to log-in
26
+ # hf_repo_id: your-username/your-model-slug
27
+ # hf_private: true #whether the repo is private or public
28
+ datasets:
29
+ # datasets are a folder of images. captions need to be txt files with the same name as the image
30
+ # for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently
31
+ # images will automatically be resized and bucketed into the resolution specified
32
+ # on windows, escape back slashes with another backslash so
33
+ # "C:\\path\\to\\images\\folder"
34
+ - folder_path: "/path/to/images/folder"
35
+ caption_ext: "txt"
36
+ caption_dropout_rate: 0.05 # will drop out the caption 5% of time
37
+ shuffle_tokens: false # shuffle caption order, split by commas
38
+ cache_latents_to_disk: true # leave this true unless you know what you're doing
39
+ resolution: [ 512, 768, 1024 ] # chroma enjoys multiple resolutions
40
+ train:
41
+ batch_size: 1
42
+ steps: 2000 # total number of steps to train 500 - 4000 is a good range
43
+ gradient_accumulation: 1
44
+ train_unet: true
45
+ train_text_encoder: false # probably won't work with chroma
46
+ gradient_checkpointing: true # need this on unless you have a ton of vram
47
+ noise_scheduler: "flowmatch" # for training only
48
+ optimizer: "adamw8bit"
49
+ lr: 1e-4
50
+ # uncomment this to skip the pre training sample
51
+ # skip_first_sample: true
52
+ # uncomment to completely disable sampling
53
+ # disable_sampling: true
54
+ # uncomment to use new bell curved weighting. Experimental but may produce better results
55
+ # linear_timesteps: true
56
+
57
+ # ema will smooth out learning, but could slow it down. Recommended to leave on.
58
+ ema_config:
59
+ use_ema: true
60
+ ema_decay: 0.99
61
+
62
+ # will probably need this if gpu supports it for chroma, other dtypes may not work correctly
63
+ dtype: bf16
64
+ model:
65
+ # Download whichever model you prefer from the Chroma repo
66
+ # https://huggingface.co/lodestones/Chroma/tree/main
67
+ # point to it here.
68
+ # name_or_path: "/path/to/chroma/chroma-unlocked-vVERSION.safetensors"
69
+
70
+ # using lodestones/Chroma will automatically use the latest version
71
+ name_or_path: "lodestones/Chroma"
72
+
73
+ # # You can also select a version of Chroma like so
74
+ # name_or_path: "lodestones/Chroma/v28"
75
+
76
+ arch: "chroma"
77
+ quantize: true # run 8bit mixed precision
78
+ sample:
79
+ sampler: "flowmatch" # must match train.noise_scheduler
80
+ sample_every: 250 # sample every this many steps
81
+ width: 1024
82
+ height: 1024
83
+ prompts:
84
+ # you can add [trigger] to the prompts here and it will be replaced with the trigger word
85
+ # - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"
86
+ - "woman with red hair, playing chess at the park, bomb going off in the background"
87
+ - "a woman holding a coffee cup, in a beanie, sitting at a cafe"
88
+ - "a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini"
89
+ - "a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background"
90
+ - "a bear building a log cabin in the snow covered mountains"
91
+ - "woman playing the guitar, on stage, singing a song, laser lights, punk rocker"
92
+ - "hipster man with a beard, building a chair, in a wood shop"
93
+ - "photo of a man, white background, medium shot, modeling clothing, studio lighting, white backdrop"
94
+ - "a man holding a sign that says, 'this is a sign'"
95
+ - "a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle"
96
+ neg: "" # negative prompt, optional
97
+ seed: 42
98
+ walk_seed: true
99
+ guidance_scale: 4
100
+ sample_steps: 25
101
+ # you can add any additional meta info here. [name] is replaced with config name at top
102
+ meta:
103
+ name: "[name]"
104
+ version: '1.0'
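The Chroma config above accepts either the lodestones/Chroma repo id or a local .safetensors path. If you want the local-file route, a short huggingface_hub sketch; the filename is a guess at the repo's naming and should be checked against the file listing at https://huggingface.co/lodestones/Chroma/tree/main.

```python
from huggingface_hub import hf_hub_download

# downloads one checkpoint into the local HF cache and returns its path;
# the exact filename is illustrative -- check the repo before using it
local_path = hf_hub_download(
    repo_id="lodestones/Chroma",
    filename="chroma-unlocked-v28.safetensors",
)
print(local_path)  # use this path as model.name_or_path in the yaml above
```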
config/examples/train_lora_flex2_24gb.yaml ADDED
@@ -0,0 +1,165 @@
1
+ # Note, Flex2 is a highly experimental WIP model. Finetuning a model with built in controls and inpainting has not
2
+ # been done before, so you will be experimenting with me on how to do it. This is my recommended setup, but this is highly
3
+ # subject to change as we learn more about how Flex2 works.
4
+
5
+ ---
6
+ job: extension
7
+ config:
8
+ # this name will be the folder and filename name
9
+ name: "my_first_flex2_lora_v1"
10
+ process:
11
+ - type: 'sd_trainer'
12
+ # root folder to save training sessions/samples/weights
13
+ training_folder: "output"
14
+ # uncomment to see performance stats in the terminal every N steps
15
+ # performance_log_every: 1000
16
+ device: cuda:0
17
+ # if a trigger word is specified, it will be added to captions of training data if it does not already exist
18
+ # alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word
19
+ # trigger_word: "p3r5on"
20
+ network:
21
+ type: "lora"
22
+ linear: 32
23
+ linear_alpha: 32
24
+ save:
25
+ dtype: float16 # precision to save
26
+ save_every: 250 # save every this many steps
27
+ max_step_saves_to_keep: 4 # how many intermittent saves to keep
28
+ push_to_hub: false #change this to True to push your trained model to Hugging Face.
29
+ # You can either set up a HF_TOKEN env variable or you'll be prompted to log-in
30
+ # hf_repo_id: your-username/your-model-slug
31
+ # hf_private: true #whether the repo is private or public
32
+ datasets:
33
+ # datasets are a folder of images. captions need to be txt files with the same name as the image
34
+ # for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently
35
+ # images will automatically be resized and bucketed into the resolution specified
36
+ # on windows, escape back slashes with another backslash so
37
+ # "C:\\path\\to\\images\\folder"
38
+ - folder_path: "/path/to/images/folder"
39
+ # Flex2 is trained with controls and inpainting. If you want the model to truly understand how the
40
+ # controls function with your dataset, it is a good idea to keep doing controls during training.
41
+ # this will automatically generate the controls for you before training. The current script is not
42
+ # fully optimized so this could be rather slow for large datasets, but it caches them to disk so it
43
+ # only needs to be done once. If you want to skip this step, you can set the controls to [] and it will be skipped
44
+ controls:
45
+ - "depth"
46
+ - "line"
47
+ - "pose"
48
+ - "inpaint"
49
+
50
+ # you can make custom inpainting images as well. These images must be webp or png format with an alpha.
51
+ # just erase the part of the image you want to inpaint and save it as a webp or png. Again, erase your
52
+ # training target (the person, if you are training a person). The automatic controls above with inpaint will
53
+ # just run a background remover mask and erase the foreground, which works well for subjects.
54
+
55
+ # inpaint_path: "/my/inpaint/images"
56
+
57
+ # you can also specify existing control image pairs. It can handle multiple groups and will randomly
58
+ # select one for each step.
59
+
60
+ # control_path:
61
+ # - "/my/custom/control/images"
62
+ # - "/my/custom/control/images2"
63
+
64
+ caption_ext: "txt"
65
+ caption_dropout_rate: 0.05 # will drop out the caption 5% of time
66
+ resolution: [ 512, 768, 1024 ] # flex2 enjoys multiple resolutions
67
+ train:
68
+ batch_size: 1
69
+ # IMPORTANT! For Flex2, you must bypass the guidance embedder during training
70
+ bypass_guidance_embedding: true
71
+
72
+ steps: 3000 # total number of steps to train 500 - 4000 is a good range
73
+ gradient_accumulation: 1
74
+ train_unet: true
75
+ train_text_encoder: false # probably won't work with flex2
76
+ gradient_checkpointing: true # need this on unless you have a ton of vram
77
+ noise_scheduler: "flowmatch" # for training only
78
+ # shift works well for training fast and learning composition and style.
79
+ # for just subject, you may want to change this to sigmoid
80
+ timestep_type: 'shift' # 'linear', 'sigmoid', 'shift'
81
+ optimizer: "adamw8bit"
82
+ lr: 1e-4
83
+
84
+ optimizer_params:
85
+ weight_decay: 1e-5
86
+ # uncomment this to skip the pre training sample
87
+ # skip_first_sample: true
88
+ # uncomment to completely disable sampling
89
+ # disable_sampling: true
90
+ # uncomment to use new bell curved weighting. Experimental but may produce better results
91
+ # linear_timesteps: true
92
+
93
+ # ema will smooth out learning, but could slow it down. Defaults off
94
+ ema_config:
95
+ use_ema: false
96
+ ema_decay: 0.99
97
+
98
+ # will probably need this if gpu supports it for flex, other dtypes may not work correctly
99
+ dtype: bf16
100
+ model:
101
+ # huggingface model name or path
102
+ name_or_path: "ostris/Flex.2-preview"
103
+ arch: "flex2"
104
+ quantize: true # run 8bit mixed precision
105
+ quantize_te: true
106
+
107
+ # you can pass special training info for controls to the model here
108
+ # percentages are decimal based so 0.0 is 0% and 1.0 is 100% of the time.
109
+ model_kwargs:
110
+ # inverts the inpainting mask, good to learn outpainting as well, recommended 0.0 for characters
111
+ invert_inpaint_mask_chance: 0.5
112
+ # this will do a normal t2i training step without inpaint when dropped out. Recommended if you want
113
+ # your lora to be able to run inference with and without inpainting.
114
+ inpaint_dropout: 0.5
115
+ # randomly drops out the control image. Dropout recommended if you want it to work without controls as well.
116
+ control_dropout: 0.5
117
+ # does a random inpaint blob. Usually a good idea to keep. Without it, the model will learn to always 100%
118
+ # fill the inpaint area with your subject. This is not always a good thing.
119
+ inpaint_random_chance: 0.5
120
+ # generates random inpaint blobs if you did not provide an inpaint image for your dataset. Inpaint breaks down fast
121
+ # if you are not training with it. Controls are a little more robust and can be left out,
122
+ # but when in doubt, always leave this on
123
+ do_random_inpainting: false
124
+ # does random blurring of the inpaint mask. Helps prevent weird edge artifacts for real world inpainting. Leave on.
125
+ random_blur_mask: true
126
+ # applies a small amount of random dilation and restriction to the inpaint mask. Helps with edge artifacts.
127
+ # Leave on.
128
+ random_dialate_mask: true
129
+ sample:
130
+ sampler: "flowmatch" # must match train.noise_scheduler
131
+ sample_every: 250 # sample every this many steps
132
+ width: 1024
133
+ height: 1024
134
+ prompts:
135
+ # you can add [trigger] to the prompts here and it will be replaced with the trigger word
136
+ # - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"
137
+
138
+ # you can use a single inpaint or single control image on your samples.
139
+ # for controls, the ctrl_idx is 1, the images can be any name and image format.
140
+ # use either a pose/line/depth image or whatever you are training with. An example is
141
+ # - "photo of [trigger] --ctrl_idx 1 --ctrl_img /path/to/control/image.jpg"
142
+
143
+ # for an inpainting image, it must be png/webp. Erase the part of the image you want to inpaint
144
+ # IMPORTANT! the inpaint images must be ctrl_idx 0 and have .inpaint.{ext} in the name for this to work right.
145
+ # - "photo of [trigger] --ctrl_idx 0 --ctrl_img /path/to/inpaint/image.inpaint.png"
146
+
147
+ - "woman with red hair, playing chess at the park, bomb going off in the background"
148
+ - "a woman holding a coffee cup, in a beanie, sitting at a cafe"
149
+ - "a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini"
150
+ - "a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background"
151
+ - "a bear building a log cabin in the snow covered mountains"
152
+ - "woman playing the guitar, on stage, singing a song, laser lights, punk rocker"
153
+ - "hipster man with a beard, building a chair, in a wood shop"
154
+ - "photo of a man, white background, medium shot, modeling clothing, studio lighting, white backdrop"
155
+ - "a man holding a sign that says, 'this is a sign'"
156
+ - "a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle"
157
+ neg: "" # not used on flex2
158
+ seed: 42
159
+ walk_seed: true
160
+ guidance_scale: 4
161
+ sample_steps: 25
162
+ # you can add any additional meta info here. [name] is replaced with config name at top
163
+ meta:
164
+ name: "[name]"
165
+ version: '1.0'
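The Flex2 comments above say a custom inpaint image is just the training image saved as PNG/WebP with the training target erased from the alpha channel. A small Pillow sketch that does exactly that, given a white-on-black mask of the region to erase; the mask file is an assumption of this example, not something the toolkit requires.

```python
from PIL import Image, ImageChops

def make_inpaint_image(image_path: str, mask_path: str, out_path: str) -> None:
    """Erase the masked (white) region from the image's alpha and save it with transparency."""
    img = Image.open(image_path).convert("RGBA")
    mask = Image.open(mask_path).convert("L").resize(img.size)
    alpha = img.getchannel("A")
    img.putalpha(ImageChops.subtract(alpha, mask))  # white mask pixels become fully transparent
    img.save(out_path)

make_inpaint_image("photo_01.jpg", "photo_01_mask.png", "photo_01.inpaint.png")
```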
config/examples/train_lora_flex_24gb.yaml ADDED
@@ -0,0 +1,101 @@
1
+ ---
2
+ job: extension
3
+ config:
4
+ # this name will be the folder and filename name
5
+ name: "my_first_flex_lora_v1"
6
+ process:
7
+ - type: 'sd_trainer'
8
+ # root folder to save training sessions/samples/weights
9
+ training_folder: "output"
10
+ # uncomment to see performance stats in the terminal every N steps
11
+ # performance_log_every: 1000
12
+ device: cuda:0
13
+ # if a trigger word is specified, it will be added to captions of training data if it does not already exist
14
+ # alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word
15
+ # trigger_word: "p3r5on"
16
+ network:
17
+ type: "lora"
18
+ linear: 16
19
+ linear_alpha: 16
20
+ save:
21
+ dtype: float16 # precision to save
22
+ save_every: 250 # save every this many steps
23
+ max_step_saves_to_keep: 4 # how many intermittent saves to keep
24
+ push_to_hub: false #change this to True to push your trained model to Hugging Face.
25
+ # You can either set up a HF_TOKEN env variable or you'll be prompted to log-in
26
+ # hf_repo_id: your-username/your-model-slug
27
+ # hf_private: true #whether the repo is private or public
28
+ datasets:
29
+ # datasets are a folder of images. captions need to be txt files with the same name as the image
30
+ # for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently
31
+ # images will automatically be resized and bucketed into the resolution specified
32
+ # on windows, escape back slashes with another backslash so
33
+ # "C:\\path\\to\\images\\folder"
34
+ - folder_path: "/path/to/images/folder"
35
+ caption_ext: "txt"
36
+ caption_dropout_rate: 0.05 # will drop out the caption 5% of time
37
+ shuffle_tokens: false # shuffle caption order, split by commas
38
+ cache_latents_to_disk: true # leave this true unless you know what you're doing
39
+ resolution: [ 512, 768, 1024 ] # flex enjoys multiple resolutions
40
+ train:
41
+ batch_size: 1
42
+ # IMPORTANT! For Flex, you must bypass the guidance embedder during training
43
+ bypass_guidance_embedding: true
44
+
45
+ steps: 2000 # total number of steps to train 500 - 4000 is a good range
46
+ gradient_accumulation: 1
47
+ train_unet: true
48
+ train_text_encoder: false # probably won't work with flex
49
+ gradient_checkpointing: true # need this on unless you have a ton of vram
50
+ noise_scheduler: "flowmatch" # for training only
51
+ optimizer: "adamw8bit"
52
+ lr: 1e-4
53
+ # uncomment this to skip the pre training sample
54
+ # skip_first_sample: true
55
+ # uncomment to completely disable sampling
56
+ # disable_sampling: true
57
+ # uncomment to use new bell curved weighting. Experimental but may produce better results
58
+ # linear_timesteps: true
59
+
60
+ # ema will smooth out learning, but could slow it down. Recommended to leave on.
61
+ ema_config:
62
+ use_ema: true
63
+ ema_decay: 0.99
64
+
65
+ # will probably need this if gpu supports it for flex, other dtypes may not work correctly
66
+ dtype: bf16
67
+ model:
68
+ # huggingface model name or path
69
+ name_or_path: "ostris/Flex.1-alpha"
70
+ is_flux: true
71
+ quantize: true # run 8bit mixed precision
72
+ quantize_kwargs:
73
+ exclude:
74
+ - "*time_text_embed*" # exclude the time text embedder from quantization
75
+ sample:
76
+ sampler: "flowmatch" # must match train.noise_scheduler
77
+ sample_every: 250 # sample every this many steps
78
+ width: 1024
79
+ height: 1024
80
+ prompts:
81
+ # you can add [trigger] to the prompts here and it will be replaced with the trigger word
82
+ # - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"
83
+ - "woman with red hair, playing chess at the park, bomb going off in the background"
84
+ - "a woman holding a coffee cup, in a beanie, sitting at a cafe"
85
+ - "a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini"
86
+ - "a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background"
87
+ - "a bear building a log cabin in the snow covered mountains"
88
+ - "woman playing the guitar, on stage, singing a song, laser lights, punk rocker"
89
+ - "hipster man with a beard, building a chair, in a wood shop"
90
+ - "photo of a man, white background, medium shot, modeling clothing, studio lighting, white backdrop"
91
+ - "a man holding a sign that says, 'this is a sign'"
92
+ - "a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle"
93
+ neg: "" # not used on flex
94
+ seed: 42
95
+ walk_seed: true
96
+ guidance_scale: 4
97
+ sample_steps: 25
98
+ # you can add any additional meta info here. [name] is replaced with config name at top
99
+ meta:
100
+ name: "[name]"
101
+ version: '1.0'
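The quantize_kwargs.exclude list above uses glob-style patterns such as "*time_text_embed*" to keep matching modules out of 8-bit quantization. A tiny fnmatch sketch of how such patterns select module names; how the toolkit's quantizer applies them internally may differ.

```python
from fnmatch import fnmatch

exclude_patterns = ["*time_text_embed*"]

def should_quantize(module_name: str) -> bool:
    return not any(fnmatch(module_name, pattern) for pattern in exclude_patterns)

for name in [
    "transformer.time_text_embed.timestep_embedder.linear_1",
    "transformer.transformer_blocks.0.attn.to_q",
]:
    print(name, "->", "quantize" if should_quantize(name) else "skip")
```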
config/examples/train_lora_flux_24gb.yaml ADDED
@@ -0,0 +1,96 @@
1
+ ---
2
+ job: extension
3
+ config:
4
+ # this name will be the folder and filename name
5
+ name: "my_first_flux_lora_v1"
6
+ process:
7
+ - type: 'sd_trainer'
8
+ # root folder to save training sessions/samples/weights
9
+ training_folder: "output"
10
+ # uncomment to see performance stats in the terminal every N steps
11
+ # performance_log_every: 1000
12
+ device: cuda:0
13
+ # if a trigger word is specified, it will be added to captions of training data if it does not already exist
14
+ # alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word
15
+ # trigger_word: "p3r5on"
16
+ network:
17
+ type: "lora"
18
+ linear: 16
19
+ linear_alpha: 16
20
+ save:
21
+ dtype: float16 # precision to save
22
+ save_every: 250 # save every this many steps
23
+ max_step_saves_to_keep: 4 # how many intermittent saves to keep
24
+ push_to_hub: false #change this to True to push your trained model to Hugging Face.
25
+ # You can either set up a HF_TOKEN env variable or you'll be prompted to log-in
26
+ # hf_repo_id: your-username/your-model-slug
27
+ # hf_private: true #whether the repo is private or public
28
+ datasets:
29
+ # datasets are a folder of images. captions need to be txt files with the same name as the image
30
+ # for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently
31
+ # images will automatically be resized and bucketed into the resolution specified
32
+ # on windows, escape back slashes with another backslash so
33
+ # "C:\\path\\to\\images\\folder"
34
+ - folder_path: "/path/to/images/folder"
35
+ caption_ext: "txt"
36
+ caption_dropout_rate: 0.05 # will drop out the caption 5% of time
37
+ shuffle_tokens: false # shuffle caption order, split by commas
38
+ cache_latents_to_disk: true # leave this true unless you know what you're doing
39
+ resolution: [ 512, 768, 1024 ] # flux enjoys multiple resolutions
40
+ train:
41
+ batch_size: 1
42
+ steps: 2000 # total number of steps to train 500 - 4000 is a good range
43
+ gradient_accumulation_steps: 1
44
+ train_unet: true
45
+ train_text_encoder: false # probably won't work with flux
46
+ gradient_checkpointing: true # need this on unless you have a ton of vram
47
+ noise_scheduler: "flowmatch" # for training only
48
+ optimizer: "adamw8bit"
49
+ lr: 1e-4
50
+ # uncomment this to skip the pre training sample
51
+ # skip_first_sample: true
52
+ # uncomment to completely disable sampling
53
+ # disable_sampling: true
54
+ # uncomment to use new bell curved weighting. Experimental but may produce better results
55
+ # linear_timesteps: true
56
+
57
+ # ema will smooth out learning, but could slow it down. Recommended to leave on.
58
+ ema_config:
59
+ use_ema: true
60
+ ema_decay: 0.99
61
+
62
+ # will probably need this if gpu supports it for flux, other dtypes may not work correctly
63
+ dtype: bf16
64
+ model:
65
+ # huggingface model name or path
66
+ name_or_path: "black-forest-labs/FLUX.1-dev"
67
+ is_flux: true
68
+ quantize: true # run 8bit mixed precision
69
+ # low_vram: true # uncomment this if the GPU is connected to your monitors. It will use less vram to quantize, but is slower.
70
+ sample:
71
+ sampler: "flowmatch" # must match train.noise_scheduler
72
+ sample_every: 250 # sample every this many steps
73
+ width: 1024
74
+ height: 1024
75
+ prompts:
76
+ # you can add [trigger] to the prompts here and it will be replaced with the trigger word
77
+ # - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"
78
+ - "woman with red hair, playing chess at the park, bomb going off in the background"
79
+ - "a woman holding a coffee cup, in a beanie, sitting at a cafe"
80
+ - "a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini"
81
+ - "a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background"
82
+ - "a bear building a log cabin in the snow covered mountains"
83
+ - "woman playing the guitar, on stage, singing a song, laser lights, punk rocker"
84
+ - "hipster man with a beard, building a chair, in a wood shop"
85
+ - "photo of a man, white background, medium shot, modeling clothing, studio lighting, white backdrop"
86
+ - "a man holding a sign that says, 'this is a sign'"
87
+ - "a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle"
88
+ neg: "" # not used on flux
89
+ seed: 42
90
+ walk_seed: true
91
+ guidance_scale: 4
92
+ sample_steps: 20
93
+ # you can add any additional meta info here. [name] is replaced with config name at top
94
+ meta:
95
+ name: "[name]"
96
+ version: '1.0'
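All of these examples share the same nested layout (a top-level config with a single process entry holding the train / model / sample sections), which makes it easy to stamp out variants programmatically. A hedged PyYAML sketch; the output path is arbitrary, and the variant is launched the same way as the original example config.

```python
import yaml

with open("config/examples/train_lora_flux_24gb.yaml") as f:
    cfg = yaml.safe_load(f)

proc = cfg["config"]["process"][0]
cfg["config"]["name"] = "my_first_flux_lora_v2"
proc["train"]["lr"] = 5e-5          # try a gentler learning rate
proc["train"]["steps"] = 3000

with open("config/my_first_flux_lora_v2.yaml", "w") as f:
    yaml.safe_dump(cfg, f, sort_keys=False)
```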
config/examples/train_lora_flux_kontext_24gb.yaml ADDED
@@ -0,0 +1,106 @@
1
+ ---
2
+ job: extension
3
+ config:
4
+ # this name will be the folder and filename name
5
+ name: "my_first_flux_kontext_lora_v1"
6
+ process:
7
+ - type: 'sd_trainer'
8
+ # root folder to save training sessions/samples/weights
9
+ training_folder: "output"
10
+ # uncomment to see performance stats in the terminal every N steps
11
+ # performance_log_every: 1000
12
+ device: cuda:0
13
+ # if a trigger word is specified, it will be added to captions of training data if it does not already exist
14
+ # alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word
15
+ # trigger_word: "p3r5on"
16
+ network:
17
+ type: "lora"
18
+ linear: 16
19
+ linear_alpha: 16
20
+ save:
21
+ dtype: float16 # precision to save
22
+ save_every: 250 # save every this many steps
23
+ max_step_saves_to_keep: 4 # how many intermittent saves to keep
24
+ push_to_hub: false #change this to True to push your trained model to Hugging Face.
25
+ # You can either set up a HF_TOKEN env variable or you'll be prompted to log-in
26
+ # hf_repo_id: your-username/your-model-slug
27
+ # hf_private: true #whether the repo is private or public
28
+ datasets:
29
+ # datasets are a folder of images. captions need to be txt files with the same name as the image
30
+ # for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently
31
+ # images will automatically be resized and bucketed into the resolution specified
32
+ # on windows, escape back slashes with another backslash so
33
+ # "C:\\path\\to\\images\\folder"
34
+ - folder_path: "/path/to/images/folder"
35
+ # control_path is the folder of input images for kontext when using a paired dataset. These are the source images you want to change.
36
+ # You can comment this out and only use normal images if you don't have a paired dataset.
37
+ # Control images need to match the filenames on the folder path but in
38
+ # a different folder. These do not need captions.
39
+ control_path: "/path/to/control/folder"
40
+ caption_ext: "txt"
41
+ caption_dropout_rate: 0.05 # will drop out the caption 5% of time
42
+ shuffle_tokens: false # shuffle caption order, split by commas
43
+ cache_latents_to_disk: true # leave this true unless you know what you're doing
44
+ # Kontext runs input images at 2x the latent size. It may OOM at 1024 resolution with 24GB vram.
45
+ resolution: [ 512, 768 ] # flux enjoys multiple resolutions
46
+ # resolution: [ 512, 768, 1024 ]
47
+ train:
48
+ batch_size: 1
49
+ steps: 3000 # total number of steps to train 500 - 4000 is a good range
50
+ gradient_accumulation_steps: 1
51
+ train_unet: true
52
+ train_text_encoder: false # probably won't work with flux
53
+ gradient_checkpointing: true # need this on unless you have a ton of vram
54
+ noise_scheduler: "flowmatch" # for training only
55
+ optimizer: "adamw8bit"
56
+ lr: 1e-4
57
+ timestep_type: "weighted" # sigmoid, linear, or weighted.
58
+ # uncomment this to skip the pre training sample
59
+ # skip_first_sample: true
60
+ # uncomment to completely disable sampling
61
+ # disable_sampling: true
62
+
63
+ # ema will smooth out learning, but could slow it down.
64
+
65
+ # ema_config:
66
+ # use_ema: true
67
+ # ema_decay: 0.99
68
+
69
+ # will probably need this if gpu supports it for flux, other dtypes may not work correctly
70
+ dtype: bf16
71
+ model:
72
+ # huggingface model name or path. This model is gated.
73
+ # visit https://huggingface.co/black-forest-labs/FLUX.1-Kontext-dev to accept the terms and conditions
74
+ # and then you can use this model.
75
+ name_or_path: "black-forest-labs/FLUX.1-Kontext-dev"
76
+ arch: "flux_kontext"
77
+ quantize: true # run 8bit mixed precision
78
+ # low_vram: true # uncomment this if the GPU is connected to your monitors. It will use less vram to quantize, but is slower.
79
+ sample:
80
+ sampler: "flowmatch" # must match train.noise_scheduler
81
+ sample_every: 250 # sample every this many steps
82
+ width: 1024
83
+ height: 1024
84
+ prompts:
85
+ # you can add [trigger] to the prompts here and it will be replaced with the trigger word
86
+ # the --ctrl_img path is the one loaded to apply the kontext editing to
87
+ # - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"
88
+ - "make the person smile --ctrl_img /path/to/control/folder/person1.jpg"
89
+ - "give the person an afro --ctrl_img /path/to/control/folder/person1.jpg"
90
+ - "turn this image into a cartoon --ctrl_img /path/to/control/folder/person1.jpg"
91
+ - "put this person in an action film --ctrl_img /path/to/control/folder/person1.jpg"
92
+ - "make this person a rapper in a rap music video --ctrl_img /path/to/control/folder/person1.jpg"
93
+ - "make the person smile --ctrl_img /path/to/control/folder/person1.jpg"
94
+ - "give the person an afro --ctrl_img /path/to/control/folder/person1.jpg"
95
+ - "turn this image into a cartoon --ctrl_img /path/to/control/folder/person1.jpg"
96
+ - "put this person in an action film --ctrl_img /path/to/control/folder/person1.jpg"
97
+ - "make this person a rapper in a rap music video --ctrl_img /path/to/control/folder/person1.jpg"
98
+ neg: "" # not used on flux
99
+ seed: 42
100
+ walk_seed: true
101
+ guidance_scale: 4
102
+ sample_steps: 20
103
+ # you can add any additional meta info here. [name] is replaced with config name at top
104
+ meta:
105
+ name: "[name]"
106
+ version: '1.0'
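The Kontext sample prompts above carry their control image inline via a --ctrl_img flag. A small sketch for splitting that flag back out of a prompt string, e.g. to verify the referenced paths exist before training; the toolkit's own prompt parsing may differ.

```python
import shlex

def split_prompt_flags(prompt: str) -> tuple[str, dict[str, str]]:
    """Return the plain prompt text and any trailing --key value pairs."""
    tokens = shlex.split(prompt)
    text, flags = [], {}
    i = 0
    while i < len(tokens):
        if tokens[i].startswith("--") and i + 1 < len(tokens):
            flags[tokens[i][2:]] = tokens[i + 1]
            i += 2
        else:
            text.append(tokens[i])
            i += 1
    return " ".join(text), flags

print(split_prompt_flags("make the person smile --ctrl_img /path/to/control/folder/person1.jpg"))
```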
config/examples/train_lora_flux_schnell_24gb.yaml ADDED
@@ -0,0 +1,98 @@
1
+ ---
2
+ job: extension
3
+ config:
4
+ # this name will be the folder and filename name
5
+ name: "my_first_flux_lora_v1"
6
+ process:
7
+ - type: 'sd_trainer'
8
+ # root folder to save training sessions/samples/weights
9
+ training_folder: "output"
10
+ # uncomment to see performance stats in the terminal every N steps
11
+ # performance_log_every: 1000
12
+ device: cuda:0
13
+ # if a trigger word is specified, it will be added to captions of training data if it does not already exist
14
+ # alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word
15
+ # trigger_word: "p3r5on"
16
+ network:
17
+ type: "lora"
18
+ linear: 16
19
+ linear_alpha: 16
20
+ save:
21
+ dtype: float16 # precision to save
22
+ save_every: 250 # save every this many steps
23
+ max_step_saves_to_keep: 4 # how many intermittent saves to keep
24
+ push_to_hub: false #change this to True to push your trained model to Hugging Face.
25
+ # You can either set up a HF_TOKEN env variable or you'll be prompted to log-in
26
+ # hf_repo_id: your-username/your-model-slug
27
+ # hf_private: true #whether the repo is private or public
28
+ datasets:
29
+ # datasets are a folder of images. captions need to be txt files with the same name as the image
30
+ # for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently
31
+ # images will automatically be resized and bucketed into the resolution specified
32
+ # on windows, escape back slashes with another backslash so
33
+ # "C:\\path\\to\\images\\folder"
34
+ - folder_path: "/path/to/images/folder"
35
+ caption_ext: "txt"
36
+ caption_dropout_rate: 0.05 # will drop out the caption 5% of time
37
+ shuffle_tokens: false # shuffle caption order, split by commas
38
+ cache_latents_to_disk: true # leave this true unless you know what you're doing
39
+ resolution: [ 512, 768, 1024 ] # flux enjoys multiple resolutions
40
+ train:
41
+ batch_size: 1
42
+ steps: 2000 # total number of steps to train 500 - 4000 is a good range
43
+ gradient_accumulation_steps: 1
44
+ train_unet: true
45
+ train_text_encoder: false # probably won't work with flux
46
+ gradient_checkpointing: true # need this on unless you have a ton of vram
47
+ noise_scheduler: "flowmatch" # for training only
48
+ optimizer: "adamw8bit"
49
+ lr: 1e-4
50
+ # uncomment this to skip the pre training sample
51
+ # skip_first_sample: true
52
+ # uncomment to completely disable sampling
53
+ # disable_sampling: true
54
+ # uncomment to use new bell curved weighting. Experimental but may produce better results
55
+ # linear_timesteps: true
56
+
57
+ # ema will smooth out learning, but could slow it down. Recommended to leave on.
58
+ ema_config:
59
+ use_ema: true
60
+ ema_decay: 0.99
61
+
62
+ # will probably need this if gpu supports it for flux, other dtypes may not work correctly
63
+ dtype: bf16
64
+ model:
65
+ # huggingface model name or path
66
+ name_or_path: "black-forest-labs/FLUX.1-schnell"
67
+ assistant_lora_path: "ostris/FLUX.1-schnell-training-adapter" # Required for flux schnell training
68
+ is_flux: true
69
+ quantize: true # run 8bit mixed precision
70
+ # low_vram is painfully slow to fuse in the adapter; avoid it unless absolutely necessary
71
+ # low_vram: true # uncomment this if the GPU is connected to your monitors. It will use less vram to quantize, but is slower.
72
+ sample:
73
+ sampler: "flowmatch" # must match train.noise_scheduler
74
+ sample_every: 250 # sample every this many steps
75
+ width: 1024
76
+ height: 1024
77
+ prompts:
78
+ # you can add [trigger] to the prompts here and it will be replaced with the trigger word
79
+ # - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"\
80
+ - "woman with red hair, playing chess at the park, bomb going off in the background"
81
+ - "a woman holding a coffee cup, in a beanie, sitting at a cafe"
82
+ - "a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini"
83
+ - "a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background"
84
+ - "a bear building a log cabin in the snow covered mountains"
85
+ - "woman playing the guitar, on stage, singing a song, laser lights, punk rocker"
86
+ - "hipster man with a beard, building a chair, in a wood shop"
87
+ - "photo of a man, white background, medium shot, modeling clothing, studio lighting, white backdrop"
88
+ - "a man holding a sign that says, 'this is a sign'"
89
+ - "a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle"
90
+ neg: "" # not used on flux
91
+ seed: 42
92
+ walk_seed: true
93
+ guidance_scale: 1 # schnell does not do guidance
94
+ sample_steps: 4 # 1 - 4 works well
95
+ # you can add any additional meta info here. [name] is replaced with config name at top
96
+ meta:
97
+ name: "[name]"
98
+ version: '1.0'
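Note: a minimal sketch of running the schnell example above, assuming the repository's run.py entry point and that you copy the example before editing it (file names here are illustrative):

    # copy the example, point folder_path at your dataset, then launch
    cp config/examples/train_lora_flux_schnell_24gb.yaml config/my_flux_schnell_lora.yaml
    python run.py config/my_flux_schnell_lora.yaml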
config/examples/train_lora_hidream_48.yaml ADDED
@@ -0,0 +1,112 @@
1
+ # HiDream training is still highly experimental. The settings here will take ~35.2GB of vram to train.
2
+ # It is not possible to train on a single 24GB card yet, but I am working on it. If you have more VRAM
3
+ # I highly recommend first disabling quantization on the model itself if you can. You can leave the TEs quantized.
4
+ # HiDream has a mixture of experts that may need special training considerations that I have not
+ # implemented properly. The current implementation seems to work well for LoRA training, but
6
+ # may not be effective for longer training runs. The implementation could change in future updates
7
+ # so your results may vary when this happens.
8
+
9
+ ---
10
+ job: extension
11
+ config:
12
+ # this name will be the folder and filename name
13
+ name: "my_first_hidream_lora_v1"
14
+ process:
15
+ - type: 'sd_trainer'
16
+ # root folder to save training sessions/samples/weights
17
+ training_folder: "output"
18
+ # uncomment to see performance stats in the terminal every N steps
19
+ # performance_log_every: 1000
20
+ device: cuda:0
21
+ # if a trigger word is specified, it will be added to captions of training data if it does not already exist
22
+ # alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word
23
+ # trigger_word: "p3r5on"
24
+ network:
25
+ type: "lora"
26
+ linear: 32
27
+ linear_alpha: 32
28
+ network_kwargs:
29
+ # it is probably best to ignore the mixture of experts since only 2 are active each block. It works if you activate it, but I wouldn't.
30
+ # proper training of it is not fully implemented
31
+ ignore_if_contains:
32
+ - "ff_i.experts"
33
+ - "ff_i.gate"
34
+ save:
35
+ dtype: bfloat16 # precision to save
36
+ save_every: 250 # save every this many steps
37
+ max_step_saves_to_keep: 4 # how many intermittent saves to keep
38
+ datasets:
39
+ # datasets are a folder of images. captions need to be txt files with the same name as the image
40
+ # for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently
41
+ # images will automatically be resized and bucketed into the resolution specified
42
+ # on windows, escape back slashes with another backslash so
43
+ # "C:\\path\\to\\images\\folder"
44
+ - folder_path: "/path/to/images/folder"
45
+ caption_ext: "txt"
46
+ caption_dropout_rate: 0.05 # will drop out the caption 5% of time
47
+ resolution: [ 512, 768, 1024 ] # hidream enjoys multiple resolutions
48
+ train:
49
+ batch_size: 1
50
+ steps: 3000 # total number of steps to train 500 - 4000 is a good range
51
+ gradient_accumulation_steps: 1
52
+ train_unet: true
53
+ train_text_encoder: false # wont work with hidream
54
+ gradient_checkpointing: true # need this on unless you have a ton of vram
55
+ noise_scheduler: "flowmatch" # for training only
56
+ timestep_type: shift # sigmoid, shift, linear
57
+ optimizer: "adamw8bit"
58
+ lr: 2e-4
59
+ # uncomment this to skip the pre training sample
60
+ # skip_first_sample: true
61
+ # uncomment to completely disable sampling
62
+ # disable_sampling: true
63
+ # uncomment to use new bell curved weighting. Experimental but may produce better results
64
+ # linear_timesteps: true
65
+
66
+ # ema will smooth out learning, but could slow it down. Defaults off
67
+ ema_config:
68
+ use_ema: false
69
+ ema_decay: 0.99
70
+
71
+ # will probably need this if gpu supports it for hidream, other dtypes may not work correctly
72
+ dtype: bf16
73
+ model:
74
+ # the transformer will get grabbed from this hf repo
75
+ # warning ONLY train on Full. The dev and fast models are distilled and will break
76
+ name_or_path: "HiDream-ai/HiDream-I1-Full"
77
+ # the extras will be grabbed from this hf repo. (text encoder, vae)
78
+ extras_name_or_path: "HiDream-ai/HiDream-I1-Full"
79
+ arch: "hidream"
80
+ # both need to be quantized to train on 48GB currently
81
+ quantize: true
82
+ quantize_te: true
83
+ model_kwargs:
84
+ # llama is a gated model. It defaults to the unsloth version, but you can set the llama path here
85
+ llama_model_path: "unsloth/Meta-Llama-3.1-8B-Instruct"
86
+ sample:
87
+ sampler: "flowmatch" # must match train.noise_scheduler
88
+ sample_every: 250 # sample every this many steps
89
+ width: 1024
90
+ height: 1024
91
+ prompts:
92
+ # you can add [trigger] to the prompts here and it will be replaced with the trigger word
93
+ # - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"\
94
+ - "woman with red hair, playing chess at the park, bomb going off in the background"
95
+ - "a woman holding a coffee cup, in a beanie, sitting at a cafe"
96
+ - "a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini"
97
+ - "a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background"
98
+ - "a bear building a log cabin in the snow covered mountains"
99
+ - "woman playing the guitar, on stage, singing a song, laser lights, punk rocker"
100
+ - "hipster man with a beard, building a chair, in a wood shop"
101
+ - "photo of a man, white background, medium shot, modeling clothing, studio lighting, white backdrop"
102
+ - "a man holding a sign that says, 'this is a sign'"
103
+ - "a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle"
104
+ neg: ""
105
+ seed: 42
106
+ walk_seed: true
107
+ guidance_scale: 4
108
+ sample_steps: 25
109
+ # you can add any additional meta info here. [name] is replaced with config name at top
110
+ meta:
111
+ name: "[name]"
112
+ version: '1.0'
config/examples/train_lora_lumina.yaml ADDED
@@ -0,0 +1,96 @@
1
+ ---
2
+ # This configuration requires 20GB of VRAM or more to operate
3
+ job: extension
4
+ config:
5
+ # this name will be the folder and filename name
6
+ name: "my_first_lumina_lora_v1"
7
+ process:
8
+ - type: 'sd_trainer'
9
+ # root folder to save training sessions/samples/weights
10
+ training_folder: "output"
11
+ # uncomment to see performance stats in the terminal every N steps
12
+ # performance_log_every: 1000
13
+ device: cuda:0
14
+ # if a trigger word is specified, it will be added to captions of training data if it does not already exist
15
+ # alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word
16
+ # trigger_word: "p3r5on"
17
+ network:
18
+ type: "lora"
19
+ linear: 16
20
+ linear_alpha: 16
21
+ save:
22
+ dtype: bf16 # precision to save
23
+ save_every: 250 # save every this many steps
24
+ max_step_saves_to_keep: 2 # how many intermittent saves to keep
25
+ save_format: 'diffusers' # 'diffusers'
26
+ datasets:
27
+ # datasets are a folder of images. captions need to be txt files with the same name as the image
28
+ # for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently
29
+ # images will automatically be resized and bucketed into the resolution specified
30
+ # on windows, escape back slashes with another backslash so
31
+ # "C:\\path\\to\\images\\folder"
32
+ - folder_path: "/path/to/images/folder"
33
+ caption_ext: "txt"
34
+ caption_dropout_rate: 0.05 # will drop out the caption 5% of time
35
+ shuffle_tokens: false # shuffle caption order, split by commas
36
+ # cache_latents_to_disk: true # leave this true unless you know what you're doing
37
+ resolution: [ 512, 768, 1024 ] # lumina2 enjoys multiple resolutions
38
+ train:
39
+ batch_size: 1
40
+
41
+ # can be 'sigmoid', 'linear', or 'lumina2_shift'
42
+ timestep_type: 'lumina2_shift'
43
+
44
+ steps: 2000 # total number of steps to train 500 - 4000 is a good range
45
+ gradient_accumulation: 1
46
+ train_unet: true
47
+ train_text_encoder: false # probably won't work with lumina2
48
+ gradient_checkpointing: true # need this on unless you have a ton of vram
49
+ noise_scheduler: "flowmatch" # for training only
50
+ optimizer: "adamw8bit"
51
+ lr: 1e-4
52
+ # uncomment this to skip the pre training sample
53
+ # skip_first_sample: true
54
+ # uncomment to completely disable sampling
55
+ # disable_sampling: true
56
+
57
+ # ema will smooth out learning, but could slow it down. Recommended to leave on if you have the vram
58
+ ema_config:
59
+ use_ema: true
60
+ ema_decay: 0.99
61
+
62
+ # will probably need this if gpu supports it for lumina2, other dtypes may not work correctly
63
+ dtype: bf16
64
+ model:
65
+ # huggingface model name or path
66
+ name_or_path: "Alpha-VLLM/Lumina-Image-2.0"
67
+ is_lumina2: true # lumina2 architecture
68
+ # you can quantize just the Gemma2 text encoder here to save vram
69
+ quantize_te: true
70
+ sample:
71
+ sampler: "flowmatch" # must match train.noise_scheduler
72
+ sample_every: 250 # sample every this many steps
73
+ width: 1024
74
+ height: 1024
75
+ prompts:
76
+ # you can add [trigger] to the prompts here and it will be replaced with the trigger word
77
+ # - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"\
78
+ - "woman with red hair, playing chess at the park, bomb going off in the background"
79
+ - "a woman holding a coffee cup, in a beanie, sitting at a cafe"
80
+ - "a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini"
81
+ - "a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background"
82
+ - "a bear building a log cabin in the snow covered mountains"
83
+ - "woman playing the guitar, on stage, singing a song, laser lights, punk rocker"
84
+ - "hipster man with a beard, building a chair, in a wood shop"
85
+ - "photo of a cat that is half black and half orange tabby, split down the middle. The cat has on a blue tophat. They are holding a martini glass with a pink ball of yarn in it with green knitting needles sticking out, in one paw. In the other paw, they are holding a DVD case for a movie titled, \"This is a test\" that has a golden robot on it. In the background is a busy night club with a giant mushroom man dancing with a bear."
86
+ - "a man holding a sign that says, 'this is a sign'"
87
+ - "a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle"
88
+ neg: ""
89
+ seed: 42
90
+ walk_seed: true
91
+ guidance_scale: 4.0
92
+ sample_steps: 25
93
+ # you can add any additional meta info here. [name] is replaced with config name at top
94
+ meta:
95
+ name: "[name]"
96
+ version: '1.0'
config/examples/train_lora_omnigen2_24gb.yaml ADDED
@@ -0,0 +1,94 @@
1
+ ---
2
+ job: extension
3
+ config:
4
+ # this name will be the folder and filename name
5
+ name: "my_first_omnigen2_lora_v1"
6
+ process:
7
+ - type: 'sd_trainer'
8
+ # root folder to save training sessions/samples/weights
9
+ training_folder: "output"
10
+ # uncomment to see performance stats in the terminal every N steps
11
+ # performance_log_every: 1000
12
+ device: cuda:0
13
+ # if a trigger word is specified, it will be added to captions of training data if it does not already exist
14
+ # alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word
15
+ # trigger_word: "p3r5on"
16
+ network:
17
+ type: "lora"
18
+ linear: 16
19
+ linear_alpha: 16
20
+ save:
21
+ dtype: float16 # precision to save
22
+ save_every: 250 # save every this many steps
23
+ max_step_saves_to_keep: 4 # how many intermittent saves to keep
24
+ push_to_hub: false #change this to True to push your trained model to Hugging Face.
25
+ # You can either set up a HF_TOKEN env variable or you'll be prompted to log-in
26
+ # hf_repo_id: your-username/your-model-slug
27
+ # hf_private: true #whether the repo is private or public
28
+ datasets:
29
+ # datasets are a folder of images. captions need to be txt files with the same name as the image
30
+ # for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently
31
+ # images will automatically be resized and bucketed into the resolution specified
32
+ # on windows, escape back slashes with another backslash so
33
+ # "C:\\path\\to\\images\\folder"
34
+ - folder_path: "/path/to/images/folder"
35
+ caption_ext: "txt"
36
+ caption_dropout_rate: 0.05 # will drop out the caption 5% of time
37
+ shuffle_tokens: false # shuffle caption order, split by commas
38
+ cache_latents_to_disk: true # leave this true unless you know what you're doing
39
+ resolution: [ 512, 768, 1024 ] # omnigen2 should work with multiple resolutions
40
+ train:
41
+ batch_size: 1
42
+ steps: 3000 # total number of steps to train 500 - 4000 is a good range
43
+ gradient_accumulation: 1
44
+ train_unet: true
45
+ train_text_encoder: false # probably won't work with omnigen2
46
+ gradient_checkpointing: true # need this on unless you have a ton of vram
47
+ noise_scheduler: "flowmatch" # for training only
48
+ optimizer: "adamw8bit"
49
+ lr: 1e-4
50
+ timestep_type: 'sigmoid' # sigmoid, linear, shift
51
+ # uncomment this to skip the pre training sample
52
+ # skip_first_sample: true
53
+ # uncomment to completely disable sampling
54
+ # disable_sampling: true
55
+
56
+ # ema will smooth out learning, but could slow it down.
57
+ # ema_config:
58
+ # use_ema: true
59
+ # ema_decay: 0.99
60
+
61
+ # will probably need this if gpu supports it for omnigen2, other dtypes may not work correctly
62
+ dtype: bf16
63
+ model:
64
+ name_or_path: "OmniGen2/OmniGen2"
65
+ arch: "omnigen2"
66
+ quantize_te: true # quantize only the TE (text encoder)
67
+ # quantize: true # quantize transformer
68
+ sample:
69
+ sampler: "flowmatch" # must match train.noise_scheduler
70
+ sample_every: 250 # sample every this many steps
71
+ width: 1024
72
+ height: 1024
73
+ prompts:
74
+ # you can add [trigger] to the prompts here and it will be replaced with the trigger word
75
+ # - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"\
76
+ - "woman with red hair, playing chess at the park, bomb going off in the background"
77
+ - "a woman holding a coffee cup, in a beanie, sitting at a cafe"
78
+ - "a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini"
79
+ - "a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background"
80
+ - "a bear building a log cabin in the snow covered mountains"
81
+ - "woman playing the guitar, on stage, singing a song, laser lights, punk rocker"
82
+ - "hipster man with a beard, building a chair, in a wood shop"
83
+ - "photo of a man, white background, medium shot, modeling clothing, studio lighting, white backdrop"
84
+ - "a man holding a sign that says, 'this is a sign'"
85
+ - "a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle"
86
+ neg: "" # negative prompt, optional
87
+ seed: 42
88
+ walk_seed: true
89
+ guidance_scale: 4
90
+ sample_steps: 25
91
+ # you can add any additional meta info here. [name] is replaced with config name at top
92
+ meta:
93
+ name: "[name]"
94
+ version: '1.0'
config/examples/train_lora_sd35_large_24gb.yaml ADDED
@@ -0,0 +1,97 @@
1
+ ---
2
+ # NOTE!! THIS IS CURRENTLY EXPERIMENTAL AND UNDER DEVELOPMENT. SOME THINGS WILL CHANGE
3
+ job: extension
4
+ config:
5
+ # this name will be the folder and filename name
6
+ name: "my_first_sd3l_lora_v1"
7
+ process:
8
+ - type: 'sd_trainer'
9
+ # root folder to save training sessions/samples/weights
10
+ training_folder: "output"
11
+ # uncomment to see performance stats in the terminal every N steps
12
+ # performance_log_every: 1000
13
+ device: cuda:0
14
+ # if a trigger word is specified, it will be added to captions of training data if it does not already exist
15
+ # alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word
16
+ # trigger_word: "p3r5on"
17
+ network:
18
+ type: "lora"
19
+ linear: 16
20
+ linear_alpha: 16
21
+ save:
22
+ dtype: float16 # precision to save
23
+ save_every: 250 # save every this many steps
24
+ max_step_saves_to_keep: 4 # how many intermittent saves to keep
25
+ push_to_hub: false #change this to True to push your trained model to Hugging Face.
26
+ # You can either set up a HF_TOKEN env variable or you'll be prompted to log-in
27
+ # hf_repo_id: your-username/your-model-slug
28
+ # hf_private: true #whether the repo is private or public
29
+ datasets:
30
+ # datasets are a folder of images. captions need to be txt files with the same name as the image
31
+ # for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently
32
+ # images will automatically be resized and bucketed into the resolution specified
33
+ # on windows, escape back slashes with another backslash so
34
+ # "C:\\path\\to\\images\\folder"
35
+ - folder_path: "/path/to/images/folder"
36
+ caption_ext: "txt"
37
+ caption_dropout_rate: 0.05 # will drop out the caption 5% of time
38
+ shuffle_tokens: false # shuffle caption order, split by commas
39
+ cache_latents_to_disk: true # leave this true unless you know what you're doing
40
+ resolution: [ 1024 ]
41
+ train:
42
+ batch_size: 1
43
+ steps: 2000 # total number of steps to train 500 - 4000 is a good range
44
+ gradient_accumulation_steps: 1
45
+ train_unet: true
46
+ train_text_encoder: false # May not fully work with SD3 yet
47
+ gradient_checkpointing: true # need this on unless you have a ton of vram
48
+ noise_scheduler: "flowmatch"
49
+ timestep_type: "linear" # linear or sigmoid
50
+ optimizer: "adamw8bit"
51
+ lr: 1e-4
52
+ # uncomment this to skip the pre training sample
53
+ # skip_first_sample: true
54
+ # uncomment to completely disable sampling
55
+ # disable_sampling: true
56
+ # uncomment to use new bell curved weighting. Experimental but may produce better results
57
+ # linear_timesteps: true
58
+
59
+ # ema will smooth out learning, but could slow it down. Recommended to leave on.
60
+ ema_config:
61
+ use_ema: true
62
+ ema_decay: 0.99
63
+
64
+ # will probably need this if gpu supports it for sd3, other dtypes may not work correctly
65
+ dtype: bf16
66
+ model:
67
+ # huggingface model name or path
68
+ name_or_path: "stabilityai/stable-diffusion-3.5-large"
69
+ is_v3: true
70
+ quantize: true # run 8bit mixed precision
71
+ sample:
72
+ sampler: "flowmatch" # must match train.noise_scheduler
73
+ sample_every: 250 # sample every this many steps
74
+ width: 1024
75
+ height: 1024
76
+ prompts:
77
+ # you can add [trigger] to the prompts here and it will be replaced with the trigger word
78
+ # - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"\
79
+ - "woman with red hair, playing chess at the park, bomb going off in the background"
80
+ - "a woman holding a coffee cup, in a beanie, sitting at a cafe"
81
+ - "a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini"
82
+ - "a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background"
83
+ - "a bear building a log cabin in the snow covered mountains"
84
+ - "woman playing the guitar, on stage, singing a song, laser lights, punk rocker"
85
+ - "hipster man with a beard, building a chair, in a wood shop"
86
+ - "photo of a man, white background, medium shot, modeling clothing, studio lighting, white backdrop"
87
+ - "a man holding a sign that says, 'this is a sign'"
88
+ - "a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle"
89
+ neg: ""
90
+ seed: 42
91
+ walk_seed: true
92
+ guidance_scale: 4
93
+ sample_steps: 25
94
+ # you can add any additional meta info here. [name] is replaced with config name at top
95
+ meta:
96
+ name: "[name]"
97
+ version: '1.0'
config/examples/train_lora_wan21_14b_24gb.yaml ADDED
@@ -0,0 +1,101 @@
1
+ # IMPORTANT: The Wan2.1 14B model is huge. This config should work on 24GB GPUs. It cannot
2
+ # support keeping the text encoder on GPU while training with 24GB, so it is only good
3
+ # for training on a single prompt, for example a person with a trigger word.
4
+ # To train on captions, you need more vram for now.
5
+ ---
6
+ job: extension
7
+ config:
8
+ # this name will be the folder and filename name
9
+ name: "my_first_wan21_14b_lora_v1"
10
+ process:
11
+ - type: 'sd_trainer'
12
+ # root folder to save training sessions/samples/weights
13
+ training_folder: "output"
14
+ # uncomment to see performance stats in the terminal every N steps
15
+ # performance_log_every: 1000
16
+ device: cuda:0
17
+ # if a trigger word is specified, it will be added to captions of training data if it does not already exist
18
+ # alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word
19
+ # this is probably needed for 24GB cards when offloading TE to CPU
20
+ trigger_word: "p3r5on"
21
+ network:
22
+ type: "lora"
23
+ linear: 32
24
+ linear_alpha: 32
25
+ save:
26
+ dtype: float16 # precision to save
27
+ save_every: 250 # save every this many steps
28
+ max_step_saves_to_keep: 4 # how many intermittent saves to keep
29
+ push_to_hub: false #change this to True to push your trained model to Hugging Face.
30
+ # You can either set up a HF_TOKEN env variable or you'll be prompted to log-in
31
+ # hf_repo_id: your-username/your-model-slug
32
+ # hf_private: true #whether the repo is private or public
33
+ datasets:
34
+ # datasets are a folder of images. captions need to be txt files with the same name as the image
35
+ # for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently
36
+ # images will automatically be resized and bucketed into the resolution specified
37
+ # on windows, escape back slashes with another backslash so
38
+ # "C:\\path\\to\\images\\folder"
39
+ # AI-Toolkit does not currently support video datasets, we will train on 1 frame at a time
40
+ # it works well for characters, but not as well for "actions"
41
+ - folder_path: "/path/to/images/folder"
42
+ caption_ext: "txt"
43
+ caption_dropout_rate: 0.05 # will drop out the caption 5% of time
44
+ shuffle_tokens: false # shuffle caption order, split by commas
45
+ cache_latents_to_disk: true # leave this true unless you know what you're doing
46
+ resolution: [ 632 ] # will be around 480p
47
+ train:
48
+ batch_size: 1
49
+ steps: 2000 # total number of steps to train 500 - 4000 is a good range
50
+ gradient_accumulation: 1
51
+ train_unet: true
52
+ train_text_encoder: false # probably won't work with wan
53
+ gradient_checkpointing: true # need this on unless you have a ton of vram
54
+ noise_scheduler: "flowmatch" # for training only
55
+ timestep_type: 'sigmoid'
56
+ optimizer: "adamw8bit"
57
+ lr: 1e-4
58
+ optimizer_params:
59
+ weight_decay: 1e-4
60
+ # uncomment this to skip the pre training sample
61
+ # skip_first_sample: true
62
+ # uncomment to completely disable sampling
63
+ # disable_sampling: true
64
+ # ema will smooth out learning, but could slow it down. Recommended to leave on.
65
+ ema_config:
66
+ use_ema: true
67
+ ema_decay: 0.99
68
+ dtype: bf16
69
+ # required for 24GB cards
70
+ # this will encode your trigger word and use those embeddings for every image in the dataset
71
+ unload_text_encoder: true
72
+ model:
73
+ # huggingface model name or path
74
+ name_or_path: "Wan-AI/Wan2.1-T2V-14B-Diffusers"
75
+ arch: 'wan21'
76
+ # these settings will save as much vram as possible
77
+ quantize: true
78
+ quantize_te: true
79
+ low_vram: true
80
+ sample:
81
+ sampler: "flowmatch"
82
+ sample_every: 250 # sample every this many steps
83
+ width: 832
84
+ height: 480
85
+ num_frames: 40
86
+ fps: 15
87
+ # samples take a long time. so use them sparingly
88
+ # samples will be animated webp files, if you don't see them animated, open in a browser.
89
+ prompts:
90
+ # you can add [trigger] to the prompts here and it will be replaced with the trigger word
91
+ # - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"\
92
+ - "woman playing the guitar, on stage, singing a song, laser lights, punk rocker"
93
+ neg: ""
94
+ seed: 42
95
+ walk_seed: true
96
+ guidance_scale: 5
97
+ sample_steps: 30
98
+ # you can add any additional meta info here. [name] is replaced with config name at top
99
+ meta:
100
+ name: "[name]"
101
+ version: '1.0'
config/examples/train_lora_wan21_1b_24gb.yaml ADDED
@@ -0,0 +1,90 @@
1
+ ---
2
+ job: extension
3
+ config:
4
+ # this name will be the folder and filename name
5
+ name: "my_first_wan21_1b_lora_v1"
6
+ process:
7
+ - type: 'sd_trainer'
8
+ # root folder to save training sessions/samples/weights
9
+ training_folder: "output"
10
+ # uncomment to see performance stats in the terminal every N steps
11
+ # performance_log_every: 1000
12
+ device: cuda:0
13
+ # if a trigger word is specified, it will be added to captions of training data if it does not already exist
14
+ # alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word
15
+ # trigger_word: "p3r5on"
16
+ network:
17
+ type: "lora"
18
+ linear: 32
19
+ linear_alpha: 32
20
+ save:
21
+ dtype: float16 # precision to save
22
+ save_every: 250 # save every this many steps
23
+ max_step_saves_to_keep: 4 # how many intermittent saves to keep
24
+ push_to_hub: false #change this to True to push your trained model to Hugging Face.
25
+ # You can either set up a HF_TOKEN env variable or you'll be prompted to log-in
26
+ # hf_repo_id: your-username/your-model-slug
27
+ # hf_private: true #whether the repo is private or public
28
+ datasets:
29
+ # datasets are a folder of images. captions need to be txt files with the same name as the image
30
+ # for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently
31
+ # images will automatically be resized and bucketed into the resolution specified
32
+ # on windows, escape back slashes with another backslash so
33
+ # "C:\\path\\to\\images\\folder"
34
+ # AI-Toolkit does not currently support video datasets, we will train on 1 frame at a time
35
+ # it works well for characters, but not as well for "actions"
36
+ - folder_path: "/path/to/images/folder"
37
+ caption_ext: "txt"
38
+ caption_dropout_rate: 0.05 # will drop out the caption 5% of time
39
+ shuffle_tokens: false # shuffle caption order, split by commas
40
+ cache_latents_to_disk: true # leave this true unless you know what you're doing
41
+ resolution: [ 632 ] # will be around 480p
42
+ train:
43
+ batch_size: 1
44
+ steps: 2000 # total number of steps to train 500 - 4000 is a good range
45
+ gradient_accumulation: 1
46
+ train_unet: true
47
+ train_text_encoder: false # probably won't work with wan
48
+ gradient_checkpointing: true # need this on unless you have a ton of vram
49
+ noise_scheduler: "flowmatch" # for training only
50
+ timestep_type: 'sigmoid'
51
+ optimizer: "adamw8bit"
52
+ lr: 1e-4
53
+ optimizer_params:
54
+ weight_decay: 1e-4
55
+ # uncomment this to skip the pre training sample
56
+ # skip_first_sample: true
57
+ # uncomment to completely disable sampling
58
+ # disable_sampling: true
59
+ # ema will smooth out learning, but could slow it down. Recommended to leave on.
60
+ ema_config:
61
+ use_ema: true
62
+ ema_decay: 0.99
63
+ dtype: bf16
64
+ model:
65
+ # huggingface model name or path
66
+ name_or_path: "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
67
+ arch: 'wan21'
68
+ quantize_te: true # saves vram
69
+ sample:
70
+ sampler: "flowmatch"
71
+ sample_every: 250 # sample every this many steps
72
+ width: 832
73
+ height: 480
74
+ num_frames: 40
75
+ fps: 15
76
+ # samples take a long time. so use them sparingly
77
+ # samples will be animated webp files, if you don't see them animated, open in a browser.
78
+ prompts:
79
+ # you can add [trigger] to the prompts here and it will be replaced with the trigger word
80
+ # - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"\
81
+ - "woman playing the guitar, on stage, singing a song, laser lights, punk rocker"
82
+ neg: ""
83
+ seed: 42
84
+ walk_seed: true
85
+ guidance_scale: 5
86
+ sample_steps: 30
87
+ # you can add any additional meta info here. [name] is replaced with config name at top
88
+ meta:
89
+ name: "[name]"
90
+ version: '1.0'
config/examples/train_slider.example.yml ADDED
@@ -0,0 +1,230 @@
1
+ ---
2
+ # This is in yaml format. You can use json if you prefer
3
+ # I like both but yaml is easier to write
4
+ # Plus it has comments which is nice for documentation
5
+ # This is the config I use on my sliders, It is solid and tested
6
+ job: train
7
+ config:
8
+ # the name will be used to create a folder in the output folder
9
+ # it will also replace any [name] token in the rest of this config
10
+ name: detail_slider_v1
11
+ # folder will be created with name above in folder below
12
+ # it can be relative to the project root or absolute
13
+ training_folder: "output/LoRA"
14
+ device: cuda:0 # cpu, cuda:0, etc
15
+ # for tensorboard logging, we will make a subfolder for this job
16
+ log_dir: "output/.tensorboard"
17
+ # you can stack processes for other jobs, It is not tested with sliders though
18
+ # just use one for now
19
+ process:
20
+ - type: slider # tells runner to run the slider process
21
+ # network is the LoRA network for a slider, I recommend to leave this be
22
+ network:
23
+ # network type lierla is traditional LoRA that works everywhere, only linear layers
24
+ type: "lierla"
25
+ # rank / dim of the network. Bigger is not always better. Especially for sliders. 8 is good
26
+ linear: 8
27
+ linear_alpha: 4 # Do about half of rank
28
+ # training config
29
+ train:
30
+ # this is also used in sampling. Stick with ddpm unless you know what you are doing
31
+ noise_scheduler: "ddpm" # "ddpm", "lms", or "euler_a"
32
+ # how many steps to train. More is not always better. I rarely go over 1000
33
+ steps: 500
34
+ # I have had good results with 4e-4 to 1e-4 at 500 steps
35
+ lr: 2e-4
36
+ # enables gradient checkpoint, saves vram, leave it on
37
+ gradient_checkpointing: true
38
+ # train the unet. I recommend leaving this true
39
+ train_unet: true
40
+ # train the text encoder. I don't recommend this unless you have a special use case
41
+ # for sliders we are adjusting representation of the concept (unet),
42
+ # not the description of it (text encoder)
43
+ train_text_encoder: false
44
+ # same as from sd-scripts, not fully tested but should speed up training
45
+ min_snr_gamma: 5.0
46
+ # just leave unless you know what you are doing
47
+ # also supports "dadaptation" but set lr to 1 if you use that,
48
+ # but it learns too fast and I don't recommend it
49
+ optimizer: "adamw"
50
+ # only constant for now
51
+ lr_scheduler: "constant"
52
+ # we randomly denoise a random num of steps from 1 to this number
53
+ # while training. Just leave it
54
+ max_denoising_steps: 40
55
+ # works great at 1. I do 1 even with my 4090.
56
+ # higher may not work right with newer single batch stacking code anyway
57
+ batch_size: 1
58
+ # bf16 works best if your GPU supports it (modern)
59
+ dtype: bf16 # fp32, bf16, fp16
60
+ # if you have it, use it. It is faster and better
61
+ # torch 2.0 doesn't need xformers anymore, only use it if you are on a lower version
62
+ # xformers: true
63
+ # I don't recommend using unless you are trying to make a darker lora. Then do 0.1 MAX
64
+ # although, the way we train sliders is comparative, so it probably won't work anyway
65
+ noise_offset: 0.0
66
+ # noise_offset: 0.0357 # SDXL was trained with offset of 0.0357. So use that when training on SDXL
67
+
68
+ # the model to train the LoRA network on
69
+ model:
70
+ # huggingface name, path relative to the project root, or absolute path to .safetensors or .ckpt
71
+ name_or_path: "runwayml/stable-diffusion-v1-5"
72
+ is_v2: false # for v2 models
73
+ is_v_pred: false # for v-prediction models (most v2 models)
74
+ # has some issues with the dual text encoder and the way we train sliders
75
+ # it works, but weights probably need to be higher to see the effect.
76
+ is_xl: false # for SDXL models
77
+
78
+ # saving config
79
+ save:
80
+ dtype: float16 # precision to save. I recommend float16
81
+ save_every: 50 # save every this many steps
82
+ # this will remove step counts more than this number
83
+ # allows you to save more often in case of a crash without filling up your drive
84
+ max_step_saves_to_keep: 2
85
+
86
+ # sampling config
87
+ sample:
88
+ # must match train.noise_scheduler, this is not used here
89
+ # but may be in future and in other processes
90
+ sampler: "ddpm"
91
+ # sample every this many steps
92
+ sample_every: 20
93
+ # image size
94
+ width: 512
95
+ height: 512
96
+ # prompts to use for sampling. Do as many as you want, but it slows down training
97
+ # pick ones that will best represent the concept you are trying to adjust
98
+ # allows some flags after the prompt
99
+ # --m [number] # network multiplier. LoRA weight. -3 for the negative slide, 3 for the positive
100
+ # slide are good tests. will inherit sample.network_multiplier if not set
101
+ # --n [string] # negative prompt, will inherit sample.neg if not set
102
+ # Only 75 tokens allowed currently
103
+ # I like to do a wide positive and negative spread so I can see a good range and stop
104
+ # early if the network is breaking down
105
+ prompts:
106
+ - "a woman in a coffee shop, black hat, blonde hair, blue jacket --m -5"
107
+ - "a woman in a coffee shop, black hat, blonde hair, blue jacket --m -3"
108
+ - "a woman in a coffee shop, black hat, blonde hair, blue jacket --m 3"
109
+ - "a woman in a coffee shop, black hat, blonde hair, blue jacket --m 5"
110
+ - "a golden retriever sitting on a leather couch, --m -5"
111
+ - "a golden retriever sitting on a leather couch --m -3"
112
+ - "a golden retriever sitting on a leather couch --m 3"
113
+ - "a golden retriever sitting on a leather couch --m 5"
114
+ - "a man with a beard and red flannel shirt, wearing vr goggles, walking into traffic --m -5"
115
+ - "a man with a beard and red flannel shirt, wearing vr goggles, walking into traffic --m -3"
116
+ - "a man with a beard and red flannel shirt, wearing vr goggles, walking into traffic --m 3"
117
+ - "a man with a beard and red flannel shirt, wearing vr goggles, walking into traffic --m 5"
118
+ # negative prompt used on all prompts above as default if they don't have one
119
+ neg: "cartoon, fake, drawing, illustration, cgi, animated, anime, monochrome"
120
+ # seed for sampling. 42 is the answer for everything
121
+ seed: 42
122
+ # walks the seed so s1 is 42, s2 is 43, s3 is 44, etc
123
+ # will start over on next sample_every so s1 is always seed
124
+ # works well if you use same prompt but want different results
125
+ walk_seed: false
126
+ # cfg scale (4 to 10 is good)
127
+ guidance_scale: 7
128
+ # sampler steps (20 to 30 is good)
129
+ sample_steps: 20
130
+ # default network multiplier for all prompts
131
+ # since we are training a slider, I recommend overriding this with --m [number]
132
+ # in the prompts above to get both sides of the slider
133
+ network_multiplier: 1.0
134
+
135
+ # logging information
136
+ logging:
137
+ log_every: 10 # log every this many steps
138
+ use_wandb: false # not supported yet
139
+ verbose: false # probably don't need this unless you are debugging
140
+
141
+ # slider training config, best for last
142
+ slider:
143
+ # resolutions to train on. [ width, height ]. This is less important for sliders
144
+ # as we are not teaching the model anything it doesn't already know
145
+ # but must be a size it understands [ 512, 512 ] for sd_v1.5 and [ 768, 768 ] for sd_v2.1
146
+ # and [ 1024, 1024 ] for sd_xl
147
+ # you can do as many as you want here
148
+ resolutions:
149
+ - [ 512, 512 ]
150
+ # - [ 512, 768 ]
151
+ # - [ 768, 768 ]
152
+ # slider training uses 4 combined steps for a single round. This will do it in one gradient
153
+ # step. It is highly optimized and shouldn't take anymore vram than doing without it,
154
+ # since we break down batches for gradient accumulation now. so just leave it on.
155
+ batch_full_slide: true
156
+ # These are the concepts to train on. You can do as many as you want here,
157
+ # but they can conflict with or outweigh each other. Other than experimenting, I recommend
158
+ # just doing one for good results
159
+ targets:
160
+ # target_class is the base concept we are adjusting the representation of
161
+ # for example, if we are adjusting the representation of a person, we would use "person"
162
+ # if we are adjusting the representation of a cat, we would use "cat" It is not
163
+ # a keyword necessarily but what the model understands the concept to represent.
164
+ # "person" will affect men, women, children, etc but will not affect cats, dogs, etc
165
+ # it is the models base general understanding of the concept and everything it represents
166
+ # you can leave it blank to affect everything. In this example, we are adjusting
167
+ # detail, so we will leave it blank to affect everything
168
+ - target_class: ""
169
+ # positive is the prompt for the positive side of the slider.
170
+ # It is the concept that will be excited and amplified in the model when we slide the slider
171
+ # to the positive side and forgotten / inverted when we slide
172
+ # the slider to the negative side. It is generally best to include the target_class in
173
+ # the prompt. You want it to be the extreme of what you want to train on. For example,
174
+ # if you want to train on fat people, you would use "an extremely fat, morbidly obese person"
175
+ # as the prompt. Not just "fat person"
176
+ # max 75 tokens for now
177
+ positive: "high detail, 8k, intricate, detailed, high resolution, high res, high quality"
178
+ # negative is the prompt for the negative side of the slider and works the same as positive
179
+ # it does not necessarily work the same as a negative prompt when generating images
180
+ # these need to be polar opposites.
181
+ # max 76 tokens for now
182
+ negative: "blurry, boring, fuzzy, low detail, low resolution, low res, low quality"
183
+ # the loss for this target is multiplied by this number.
184
+ # if you are doing more than one target it may be good to set less important ones
185
+ # to a lower number like 0.1 so they don't outweigh the primary target
186
+ weight: 1.0
187
+ # shuffle the prompts split by the comma. We will run every combination randomly
188
+ # this will make the LoRA more robust. You probably want this on unless prompt order
189
+ # is important for some reason
190
+ shuffle: true
191
+
192
+
193
+ # anchors are prompts that we will try to hold on to while training the slider
194
+ # these are NOT necessary and can prevent the slider from converging if not done right
195
+ # leave them off if you are having issues, but they can help lock the network
196
+ # on certain concepts to help prevent catastrophic forgetting
197
+ # you want these to generate an image that is not your target_class, but close to it
198
+ # is fine as long as it does not directly overlap it.
199
+ # For example, if you are training on a person smiling,
200
+ # you could use "a person with a face mask" as an anchor. It is a person, the image is the same
201
+ # regardless if they are smiling or not, however, the closer the concept is to the target_class
202
+ # the less the multiplier needs to be. Keep multipliers less than 1.0 for anchors usually
203
+ # for close concepts, you want to be closer to 0.1 or 0.2
204
+ # these will slow down training. I am leaving them off for the demo
205
+
206
+ # anchors:
207
+ # - prompt: "a woman"
208
+ # neg_prompt: "animal"
209
+ # # the multiplier applied to the LoRA when this is run.
210
+ # # higher will give it more weight but also help keep the lora from collapsing
211
+ # multiplier: 1.0
212
+ # - prompt: "a man"
213
+ # neg_prompt: "animal"
214
+ # multiplier: 1.0
215
+ # - prompt: "a person"
216
+ # neg_prompt: "animal"
217
+ # multiplier: 1.0
218
+
219
+ # You can put any information you want here, and it will be saved in the model.
220
+ # The below is an example, but you can put your grocery list in it if you want.
221
+ # It is saved in the model so be aware of that. The software will include this
222
+ # plus some other information for you automatically
223
+ meta:
224
+ # [name] gets replaced with the name above
225
+ name: "[name]"
226
+ # version: '1.0'
227
+ # creator:
228
+ # name: Your Name
229
+ # email: [email protected]
230
+ # website: https://your.website
docker-compose.yml ADDED
@@ -0,0 +1,25 @@
1
+ version: "3.8"
2
+
3
+ services:
4
+ ai-toolkit:
5
+ image: ostris/aitoolkit:latest
6
+ restart: unless-stopped
7
+ ports:
8
+ - "8675:8675"
9
+ volumes:
10
+ - ~/.cache/huggingface/hub:/root/.cache/huggingface/hub
11
+ - ./aitk_db.db:/app/ai-toolkit/aitk_db.db
12
+ - ./datasets:/app/ai-toolkit/datasets
13
+ - ./output:/app/ai-toolkit/output
14
+ - ./config:/app/ai-toolkit/config
15
+ environment:
16
+ - AI_TOOLKIT_AUTH=${AI_TOOLKIT_AUTH:-password}
17
+ - NODE_ENV=production
18
+ - TZ=UTC
19
+ deploy:
20
+ resources:
21
+ reservations:
22
+ devices:
23
+ - driver: nvidia
24
+ count: all
25
+ capabilities: [gpu]
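Note: a minimal sketch of bringing this stack up, assuming Docker with the NVIDIA container toolkit installed and this docker-compose.yml in the current directory; the AI_TOOLKIT_AUTH variable and the 8675 port mapping come from the file above:

    # set the UI password and start in the background
    AI_TOOLKIT_AUTH=choose-a-strong-password docker compose up -d
    # the UI should then be reachable at http://localhost:8675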
docker/Dockerfile ADDED
@@ -0,0 +1,83 @@
1
+ FROM nvidia/cuda:12.8.1-devel-ubuntu22.04
2
+
3
+ LABEL authors="jaret"
4
+
5
+ # Set noninteractive to avoid timezone prompts
6
+ ENV DEBIAN_FRONTEND=noninteractive
7
+
8
+ # ref https://en.wikipedia.org/wiki/CUDA
9
+ ENV TORCH_CUDA_ARCH_LIST="8.0 8.6 8.9 9.0 10.0 12.0"
10
+
11
+ # Install dependencies
12
+ RUN apt-get update && apt-get install --no-install-recommends -y \
13
+ git \
14
+ curl \
15
+ build-essential \
16
+ cmake \
17
+ wget \
18
+ python3.10 \
19
+ python3-pip \
20
+ python3-dev \
21
+ python3-setuptools \
22
+ python3-wheel \
23
+ python3-venv \
24
+ ffmpeg \
25
+ tmux \
26
+ htop \
27
+ nvtop \
28
+ python3-opencv \
29
+ openssh-client \
30
+ openssh-server \
31
+ openssl \
32
+ rsync \
33
+ unzip \
34
+ && apt-get clean \
35
+ && rm -rf /var/lib/apt/lists/*
36
+
37
+ # Install nodejs
38
+ WORKDIR /tmp
39
+ RUN curl -sL https://deb.nodesource.com/setup_23.x -o nodesource_setup.sh && \
40
+ bash nodesource_setup.sh && \
41
+ apt-get update && \
42
+ apt-get install -y nodejs && \
43
+ apt-get clean && \
44
+ rm -rf /var/lib/apt/lists/*
45
+
46
+ WORKDIR /app
47
+
48
+ # Set aliases for python and pip
49
+ RUN ln -s /usr/bin/python3 /usr/bin/python
50
+
51
+ # install pytorch before cache bust to avoid redownloading pytorch
52
+ RUN pip install --pre --no-cache-dir torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128
53
+
54
+ # Fix cache busting by moving CACHEBUST to right before git clone
55
+ ARG CACHEBUST=1234
56
+ ARG GIT_COMMIT=main
57
+ RUN echo "Cache bust: ${CACHEBUST}" && \
58
+ git clone https://github.com/ostris/ai-toolkit.git && \
59
+ cd ai-toolkit && \
60
+ git checkout ${GIT_COMMIT}
61
+
62
+ WORKDIR /app/ai-toolkit
63
+
64
+ # Install Python dependencies
65
+ RUN pip install --no-cache-dir -r requirements.txt && \
66
+ pip install --pre --no-cache-dir torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128 --force && \
67
+ pip install setuptools==69.5.1 --no-cache-dir
68
+
69
+ # Build UI
70
+ WORKDIR /app/ai-toolkit/ui
71
+ RUN npm install && \
72
+ npm run build && \
73
+ npm run update_db
74
+
75
+ # Expose port (assuming the application runs on port 3000)
76
+ EXPOSE 8675
77
+
78
+ WORKDIR /
79
+
80
+ COPY docker/start.sh /start.sh
81
+ RUN chmod +x /start.sh
82
+
83
+ CMD ["/start.sh"]
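Note: a hedged example of building this image from the repository root; the CACHEBUST and GIT_COMMIT build args above control when the git clone layer is invalidated and which commit gets checked out (the tag is illustrative):

    docker build -f docker/Dockerfile \
      --build-arg CACHEBUST=$(date +%s) \
      --build-arg GIT_COMMIT=main \
      -t ostris/aitoolkit:latest .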
docker/start.sh ADDED
@@ -0,0 +1,70 @@
1
+ #!/bin/bash
2
+ set -e # Exit the script if any statement returns a non-true return value
3
+
4
+ # ref https://github.com/runpod/containers/blob/main/container-template/start.sh
5
+
6
+ # ---------------------------------------------------------------------------- #
7
+ # Function Definitions #
8
+ # ---------------------------------------------------------------------------- #
9
+
10
+
11
+ # Setup ssh
12
+ setup_ssh() {
13
+ if [[ $PUBLIC_KEY ]]; then
14
+ echo "Setting up SSH..."
15
+ mkdir -p ~/.ssh
16
+ echo "$PUBLIC_KEY" >> ~/.ssh/authorized_keys
17
+ chmod 700 -R ~/.ssh
18
+
19
+ if [ ! -f /etc/ssh/ssh_host_rsa_key ]; then
20
+ ssh-keygen -t rsa -f /etc/ssh/ssh_host_rsa_key -q -N ''
21
+ echo "RSA key fingerprint:"
22
+ ssh-keygen -lf /etc/ssh/ssh_host_rsa_key.pub
23
+ fi
24
+
25
+ if [ ! -f /etc/ssh/ssh_host_dsa_key ]; then
26
+ ssh-keygen -t dsa -f /etc/ssh/ssh_host_dsa_key -q -N ''
27
+ echo "DSA key fingerprint:"
28
+ ssh-keygen -lf /etc/ssh/ssh_host_dsa_key.pub
29
+ fi
30
+
31
+ if [ ! -f /etc/ssh/ssh_host_ecdsa_key ]; then
32
+ ssh-keygen -t ecdsa -f /etc/ssh/ssh_host_ecdsa_key -q -N ''
33
+ echo "ECDSA key fingerprint:"
34
+ ssh-keygen -lf /etc/ssh/ssh_host_ecdsa_key.pub
35
+ fi
36
+
37
+ if [ ! -f /etc/ssh/ssh_host_ed25519_key ]; then
38
+ ssh-keygen -t ed25519 -f /etc/ssh/ssh_host_ed25519_key -q -N ''
39
+ echo "ED25519 key fingerprint:"
40
+ ssh-keygen -lf /etc/ssh/ssh_host_ed25519_key.pub
41
+ fi
42
+
43
+ service ssh start
44
+
45
+ echo "SSH host keys:"
46
+ for key in /etc/ssh/*.pub; do
47
+ echo "Key: $key"
48
+ ssh-keygen -lf $key
49
+ done
50
+ fi
51
+ }
52
+
53
+ # Export env vars
54
+ export_env_vars() {
55
+ echo "Exporting environment variables..."
56
+ printenv | grep -E '^RUNPOD_|^PATH=|^_=' | awk -F = '{ print "export " $1 "=\"" $2 "\"" }' >> /etc/rp_environment
57
+ echo 'source /etc/rp_environment' >> ~/.bashrc
58
+ }
59
+
60
+ # ---------------------------------------------------------------------------- #
61
+ # Main Program #
62
+ # ---------------------------------------------------------------------------- #
63
+
64
+
65
+ echo "Pod Started"
66
+
67
+ setup_ssh
68
+ export_env_vars
69
+ echo "Starting AI Toolkit UI..."
70
+ cd /app/ai-toolkit/ui && npm run start
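Note: the SSH setup above only runs when a PUBLIC_KEY environment variable is present, so a sketch of launching the container with SSH access enabled might look like this (the port mapping and image tag are assumptions):

    docker run -d --gpus all \
      -p 8675:8675 -p 2222:22 \
      -e PUBLIC_KEY="$(cat ~/.ssh/id_ed25519.pub)" \
      ostris/aitoolkit:latest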
extensions/example/ExampleMergeModels.py ADDED
@@ -0,0 +1,129 @@
1
+ import torch
2
+ import gc
3
+ from collections import OrderedDict
4
+ from typing import TYPE_CHECKING
5
+ from jobs.process import BaseExtensionProcess
6
+ from toolkit.config_modules import ModelConfig
7
+ from toolkit.stable_diffusion_model import StableDiffusion
8
+ from toolkit.train_tools import get_torch_dtype
9
+ from tqdm import tqdm
10
+
11
+ # Type check imports. Prevents circular imports
12
+ if TYPE_CHECKING:
13
+ from jobs import ExtensionJob
14
+
15
+
16
+ # extend standard config classes to add weight
17
+ class ModelInputConfig(ModelConfig):
18
+ def __init__(self, **kwargs):
19
+ super().__init__(**kwargs)
20
+ self.weight = kwargs.get('weight', 1.0)
21
+ # overwrite default dtype unless user specifies otherwise
22
+ # float 32 will give up better precision on the merging functions
23
+ self.dtype: str = kwargs.get('dtype', 'float32')
24
+
25
+
26
+ def flush():
27
+ torch.cuda.empty_cache()
28
+ gc.collect()
29
+
30
+
31
+ # this is our main class process
32
+ class ExampleMergeModels(BaseExtensionProcess):
33
+ def __init__(
34
+ self,
35
+ process_id: int,
36
+ job: 'ExtensionJob',
37
+ config: OrderedDict
38
+ ):
39
+ super().__init__(process_id, job, config)
40
+ # this is the setup process, do not do process intensive stuff here, just variable setup and
41
+ # checking requirements. This is called before the run() function
42
+ # no loading models or anything like that, it is just for setting up the process
43
+ # all of your process intensive stuff should be done in the run() function
44
+ # config will have everything from the process item in the config file
45
+
46
+ # convenience methods exist on BaseProcess to get config values
47
+ # if required is set to true and the value is not found it will throw an error
48
+ # you can pass a default value to get_conf() as well if it was not in the config file
49
+ # as well as a type to cast the value to
50
+ self.save_path = self.get_conf('save_path', required=True)
51
+ self.save_dtype = self.get_conf('save_dtype', default='float16', as_type=get_torch_dtype)
52
+ self.device = self.get_conf('device', default='cpu', as_type=torch.device)
53
+
54
+ # build models to merge list
55
+ models_to_merge = self.get_conf('models_to_merge', required=True, as_type=list)
56
+ # build list of ModelInputConfig objects. I find it is a good idea to make a class for each config
57
+ # this way you can add methods to it and it is easier to read and code. There are a lot of
58
+ # inbuilt config classes located in toolkit.config_modules as well
59
+ self.models_to_merge = [ModelInputConfig(**model) for model in models_to_merge]
60
+ # setup is complete. Don't load anything else here, just setup variables and stuff
61
+
62
+ # this is the entire run process be sure to call super().run() first
63
+ def run(self):
64
+ # always call first
65
+ super().run()
66
+ print(f"Running process: {self.__class__.__name__}")
67
+
68
+ # let's adjust our weights first to normalize them so the total is 1.0
69
+ total_weight = sum([model.weight for model in self.models_to_merge])
70
+ weight_adjust = 1.0 / total_weight
71
+ for model in self.models_to_merge:
72
+ model.weight *= weight_adjust
73
+
74
+ output_model: StableDiffusion = None
75
+ # let's do the merge, it is a good idea to use tqdm to show progress
76
+ for model_config in tqdm(self.models_to_merge, desc="Merging models"):
77
+ # setup model class with our helper class
78
+ sd_model = StableDiffusion(
79
+ device=self.device,
80
+ model_config=model_config,
81
+ dtype="float32"
82
+ )
83
+ # load the model
84
+ sd_model.load_model()
85
+
86
+ # adjust the weight of the text encoder
87
+ if isinstance(sd_model.text_encoder, list):
88
+ # sdxl model
89
+ for text_encoder in sd_model.text_encoder:
90
+ for key, value in text_encoder.state_dict().items():
91
+ value *= model_config.weight
92
+ else:
93
+ # normal model
94
+ for key, value in sd_model.text_encoder.state_dict().items():
95
+ value *= model_config.weight
96
+ # adjust the weights of the unet
97
+ for key, value in sd_model.unet.state_dict().items():
98
+ value *= model_config.weight
99
+
100
+ if output_model is None:
101
+ # use this one as the base
102
+ output_model = sd_model
103
+ else:
104
+ # merge the models
105
+ # text encoder
106
+ if isinstance(output_model.text_encoder, list):
107
+ # sdxl model
108
+ for i, text_encoder in enumerate(output_model.text_encoder):
109
+ for key, value in text_encoder.state_dict().items():
110
+ value += sd_model.text_encoder[i].state_dict()[key]
111
+ else:
112
+ # normal model
113
+ for key, value in output_model.text_encoder.state_dict().items():
114
+ value += sd_model.text_encoder.state_dict()[key]
115
+ # unet
116
+ for key, value in output_model.unet.state_dict().items():
117
+ value += sd_model.unet.state_dict()[key]
118
+
119
+ # remove the model to free memory
120
+ del sd_model
121
+ flush()
122
+
123
+ # merge loop is done, let's save the model
124
+ print(f"Saving merged model to {self.save_path}")
125
+ output_model.save(self.save_path, meta=self.meta, save_dtype=self.save_dtype)
126
+ print(f"Saved merged model to {self.save_path}")
127
+ # do cleanup here
128
+ del output_model
129
+ flush()
extensions/example/__init__.py ADDED
@@ -0,0 +1,25 @@
1
+ # This is an example extension for custom training. It is great for experimenting with new ideas.
2
+ from toolkit.extension import Extension
3
+
4
+
5
+ # We make a subclass of Extension
6
+ class ExampleMergeExtension(Extension):
7
+ # uid must be unique, it is how the extension is identified
8
+ uid = "example_merge_extension"
9
+
10
+ # name is the name of the extension for printing
11
+ name = "Example Merge Extension"
12
+
13
+ # This is where your process class is loaded
14
+ # keep your imports in here so they don't slow down the rest of the program
15
+ @classmethod
16
+ def get_process(cls):
17
+ # import your process class here so it is only loaded when needed and return it
18
+ from .ExampleMergeModels import ExampleMergeModels
19
+ return ExampleMergeModels
20
+
21
+
22
+ AI_TOOLKIT_EXTENSIONS = [
23
+ # you can put a list of extensions here
24
+ ExampleMergeExtension
25
+ ]
extensions/example/config/config.example.yaml ADDED
@@ -0,0 +1,48 @@
+ ---
+ # Always include at least one example config file to show how to use your extension.
+ # use plenty of comments so users know how to use it and what everything does
+ 
+ # all extensions will use this job name
+ job: extension
+ config:
+   name: 'my_awesome_merge'
+   process:
+     # Put your example processes here. This will be passed
+     # to your extension process in the config argument.
+     # the type MUST match your extension uid
+     - type: "example_merge_extension"
+       # save path for the merged model
+       save_path: "output/merge/[name].safetensors"
+       # save type
+       dtype: fp16
+       # device to run it on
+       device: cuda:0
+       # input models can only be SD1.x and SD2.x models for this example (currently)
+       models_to_merge:
+         # weights are relative, total weights will be normalized
+         # for example. If you have 2 models with weight 1.0, they will
+         # both be weighted 0.5. If you have 1 model with weight 1.0 and
+         # another with weight 2.0, the first will be weighted 1/3 and the
+         # second will be weighted 2/3
+         - name_or_path: "input/model1.safetensors"
+           weight: 1.0
+         - name_or_path: "input/model2.safetensors"
+           weight: 1.0
+         - name_or_path: "input/model3.safetensors"
+           weight: 0.3
+         - name_or_path: "input/model4.safetensors"
+           weight: 1.0
+ 
+ 
+ # you can put any information you want here, and it will be saved in the model
+ # the below is an example. I recommend doing trigger words at a minimum
+ # in the metadata. The software will include this plus some other information
+ meta:
+   name: "[name]" # [name] gets replaced with the name above
+   description: A short description of your model
+   version: '0.1'
+   creator:
+     name: Your Name
+     email: [email protected]
+     website: https://yourwebsite.com
+   any: All meta data above is arbitrary, it can be whatever you want.
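Note: the relative-weight behavior described in the models_to_merge comments is plain normalization; a quick sketch of the arithmetic for the example weights above (illustrative only):

    weights = [1.0, 1.0, 0.3, 1.0]             # as listed in the config
    total = sum(weights)                        # 3.3
    normalized = [w / total for w in weights]   # ~[0.303, 0.303, 0.091, 0.303], sums to 1.0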
extensions_built_in/.DS_Store ADDED
Binary file (8.2 kB).
extensions_built_in/advanced_generator/Img2ImgGenerator.py ADDED
@@ -0,0 +1,256 @@
1
+ import math
2
+ import os
3
+ import random
4
+ from collections import OrderedDict
5
+ from typing import List
6
+
7
+ import numpy as np
8
+ from PIL import Image
9
+ from diffusers import T2IAdapter
10
+ from diffusers.utils.torch_utils import randn_tensor
11
+ from torch.utils.data import DataLoader
12
+ from diffusers import StableDiffusionXLImg2ImgPipeline, PixArtSigmaPipeline
13
+ from tqdm import tqdm
14
+
15
+ from toolkit.config_modules import ModelConfig, GenerateImageConfig, preprocess_dataset_raw_config, DatasetConfig
16
+ from toolkit.data_transfer_object.data_loader import FileItemDTO, DataLoaderBatchDTO
17
+ from toolkit.sampler import get_sampler
18
+ from toolkit.stable_diffusion_model import StableDiffusion
19
+ import gc
20
+ import torch
21
+ from jobs.process import BaseExtensionProcess
22
+ from toolkit.data_loader import get_dataloader_from_datasets
23
+ from toolkit.train_tools import get_torch_dtype
24
+ from controlnet_aux.midas import MidasDetector
25
+ from diffusers.utils import load_image
26
+ from torchvision.transforms import ToTensor
27
+
28
+
29
+ def flush():
30
+ torch.cuda.empty_cache()
31
+ gc.collect()
32
+
33
+
34
+
35
+
36
+
37
+ class GenerateConfig:
38
+
39
+ def __init__(self, **kwargs):
40
+ self.prompts: List[str]
41
+ self.sampler = kwargs.get('sampler', 'ddpm')
42
+ self.neg = kwargs.get('neg', '')
43
+ self.seed = kwargs.get('seed', -1)
44
+ self.walk_seed = kwargs.get('walk_seed', False)
45
+ self.guidance_scale = kwargs.get('guidance_scale', 7)
46
+ self.sample_steps = kwargs.get('sample_steps', 20)
47
+ self.guidance_rescale = kwargs.get('guidance_rescale', 0.0)
48
+ self.ext = kwargs.get('ext', 'png')
49
+ self.denoise_strength = kwargs.get('denoise_strength', 0.5)
50
+ self.trigger_word = kwargs.get('trigger_word', None)
51
+
52
+
53
+ class Img2ImgGenerator(BaseExtensionProcess):
54
+
55
+ def __init__(self, process_id: int, job, config: OrderedDict):
56
+ super().__init__(process_id, job, config)
57
+ self.output_folder = self.get_conf('output_folder', required=True)
58
+ self.copy_inputs_to = self.get_conf('copy_inputs_to', None)
59
+ self.device = self.get_conf('device', 'cuda')
60
+ self.model_config = ModelConfig(**self.get_conf('model', required=True))
61
+ self.generate_config = GenerateConfig(**self.get_conf('generate', required=True))
62
+ self.is_latents_cached = True
63
+ raw_datasets = self.get_conf('datasets', None)
64
+ if raw_datasets is not None and len(raw_datasets) > 0:
65
+ raw_datasets = preprocess_dataset_raw_config(raw_datasets)
66
+ self.datasets = None
67
+ self.datasets_reg = None
68
+ self.dtype = self.get_conf('dtype', 'float16')
69
+ self.torch_dtype = get_torch_dtype(self.dtype)
70
+ self.params = []
71
+ if raw_datasets is not None and len(raw_datasets) > 0:
72
+ for raw_dataset in raw_datasets:
73
+ dataset = DatasetConfig(**raw_dataset)
74
+ is_caching = dataset.cache_latents or dataset.cache_latents_to_disk
75
+ if not is_caching:
76
+ self.is_latents_cached = False
77
+ if dataset.is_reg:
78
+ if self.datasets_reg is None:
79
+ self.datasets_reg = []
80
+ self.datasets_reg.append(dataset)
81
+ else:
82
+ if self.datasets is None:
83
+ self.datasets = []
84
+ self.datasets.append(dataset)
85
+
86
+ self.progress_bar = None
87
+ self.sd = StableDiffusion(
88
+ device=self.device,
89
+ model_config=self.model_config,
90
+ dtype=self.dtype,
91
+ )
92
+ print(f"Using device {self.device}")
93
+ self.data_loader: DataLoader = None
94
+ self.adapter: T2IAdapter = None
95
+
96
+ def to_pil(self, img):
97
+ # image comes in -1 to 1. convert to a PIL RGB image
98
+ img = (img + 1) / 2
99
+ img = img.clamp(0, 1)
100
+ img = img[0].permute(1, 2, 0).cpu().numpy()
101
+ img = (img * 255).astype(np.uint8)
102
+ image = Image.fromarray(img)
103
+ return image
104
+
105
+ def run(self):
106
+ with torch.no_grad():
107
+ super().run()
108
+ print("Loading model...")
109
+ self.sd.load_model()
110
+ device = torch.device(self.device)
111
+
112
+ if self.model_config.is_xl:
113
+ pipe = StableDiffusionXLImg2ImgPipeline(
114
+ vae=self.sd.vae,
115
+ unet=self.sd.unet,
116
+ text_encoder=self.sd.text_encoder[0],
117
+ text_encoder_2=self.sd.text_encoder[1],
118
+ tokenizer=self.sd.tokenizer[0],
119
+ tokenizer_2=self.sd.tokenizer[1],
120
+ scheduler=get_sampler(self.generate_config.sampler),
121
+ ).to(device, dtype=self.torch_dtype)
122
+ elif self.model_config.is_pixart:
123
+ pipe = self.sd.pipeline.to(device, dtype=self.torch_dtype)
124
+ else:
125
+ raise NotImplementedError("Only SDXL and PixArt models are supported")
126
+ pipe.set_progress_bar_config(disable=True)
127
+
128
+ # pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
129
+ # midas_depth = torch.compile(midas_depth, mode="reduce-overhead", fullgraph=True)
130
+
131
+ self.data_loader = get_dataloader_from_datasets(self.datasets, 1, self.sd)
132
+
133
+ num_batches = len(self.data_loader)
134
+ pbar = tqdm(total=num_batches, desc="Generating images")
135
+ seed = self.generate_config.seed
136
+ # load images from datasets, use tqdm
137
+ for i, batch in enumerate(self.data_loader):
138
+ batch: DataLoaderBatchDTO = batch
139
+
140
+ gen_seed = seed if seed > 0 else random.randint(0, 2 ** 32 - 1)
141
+ generator = torch.manual_seed(gen_seed)
142
+
143
+ file_item: FileItemDTO = batch.file_items[0]
144
+ img_path = file_item.path
145
+ img_filename = os.path.basename(img_path)
146
+ img_filename_no_ext = os.path.splitext(img_filename)[0]
147
+ img_filename = img_filename_no_ext + '.' + self.generate_config.ext
148
+ output_path = os.path.join(self.output_folder, img_filename)
149
+ output_caption_path = os.path.join(self.output_folder, img_filename_no_ext + '.txt')
150
+
151
+ if self.copy_inputs_to is not None:
152
+ output_inputs_path = os.path.join(self.copy_inputs_to, img_filename)
153
+ output_inputs_caption_path = os.path.join(self.copy_inputs_to, img_filename_no_ext + '.txt')
154
+ else:
155
+ output_inputs_path = None
156
+ output_inputs_caption_path = None
157
+
158
+ caption = batch.get_caption_list()[0]
159
+ if self.generate_config.trigger_word is not None:
160
+ caption = caption.replace('[trigger]', self.generate_config.trigger_word)
161
+
162
+ img: torch.Tensor = batch.tensor.clone()
163
+ image = self.to_pil(img)
164
+
165
+ # image.save(output_depth_path)
166
+ if self.model_config.is_pixart:
167
+ pipe: PixArtSigmaPipeline = pipe
168
+
169
+ # Encode the full image once
170
+ encoded_image = pipe.vae.encode(
171
+ pipe.image_processor.preprocess(image).to(device=pipe.device, dtype=pipe.dtype))
172
+ if hasattr(encoded_image, "latent_dist"):
173
+ latents = encoded_image.latent_dist.sample(generator)
174
+ elif hasattr(encoded_image, "latents"):
175
+ latents = encoded_image.latents
176
+ else:
177
+ raise AttributeError("Could not access latents of provided encoder_output")
178
+ latents = pipe.vae.config.scaling_factor * latents
179
+
180
+ # latents = self.sd.encode_images(img)
181
+
182
+ # self.sd.noise_scheduler.set_timesteps(self.generate_config.sample_steps)
183
+ # start_step = math.floor(self.generate_config.sample_steps * self.generate_config.denoise_strength)
184
+ # timestep = self.sd.noise_scheduler.timesteps[start_step].unsqueeze(0)
185
+ # timestep = timestep.to(device, dtype=torch.int32)
186
+ # latent = latent.to(device, dtype=self.torch_dtype)
187
+ # noise = torch.randn_like(latent, device=device, dtype=self.torch_dtype)
188
+ # latent = self.sd.add_noise(latent, noise, timestep)
189
+ # timesteps_to_use = self.sd.noise_scheduler.timesteps[start_step + 1:]
190
+ batch_size = 1
191
+ num_images_per_prompt = 1
192
+
193
+ shape = (batch_size, pipe.transformer.config.in_channels, image.height // pipe.vae_scale_factor,
194
+ image.width // pipe.vae_scale_factor)
195
+ noise = randn_tensor(shape, generator=generator, device=pipe.device, dtype=pipe.dtype)
196
+
197
+ # noise = torch.randn_like(latents, device=device, dtype=self.torch_dtype)
198
+ num_inference_steps = self.generate_config.sample_steps
199
+ strength = self.generate_config.denoise_strength
200
+ # Get timesteps
201
+ init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
202
+ t_start = max(num_inference_steps - init_timestep, 0)
203
+ pipe.scheduler.set_timesteps(num_inference_steps, device="cpu")
204
+ timesteps = pipe.scheduler.timesteps[t_start:]
205
+ timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
206
+ latents = pipe.scheduler.add_noise(latents, noise, timestep)
207
+
208
+ gen_images = pipe.__call__(
209
+ prompt=caption,
210
+ negative_prompt=self.generate_config.neg,
211
+ latents=latents,
212
+ timesteps=timesteps,
213
+ width=image.width,
214
+ height=image.height,
215
+ num_inference_steps=num_inference_steps,
216
+ num_images_per_prompt=num_images_per_prompt,
217
+ guidance_scale=self.generate_config.guidance_scale,
218
+ # strength=self.generate_config.denoise_strength,
219
+ use_resolution_binning=False,
220
+ output_type="np"
221
+ ).images[0]
222
+ gen_images = (gen_images * 255).clip(0, 255).astype(np.uint8)
223
+ gen_images = Image.fromarray(gen_images)
224
+ else:
225
+ pipe: StableDiffusionXLImg2ImgPipeline = pipe
226
+
227
+ gen_images = pipe.__call__(
228
+ prompt=caption,
229
+ negative_prompt=self.generate_config.neg,
230
+ image=image,
231
+ num_inference_steps=self.generate_config.sample_steps,
232
+ guidance_scale=self.generate_config.guidance_scale,
233
+ strength=self.generate_config.denoise_strength,
234
+ ).images[0]
235
+ os.makedirs(os.path.dirname(output_path), exist_ok=True)
236
+ gen_images.save(output_path)
237
+
238
+ # save caption
239
+ with open(output_caption_path, 'w') as f:
240
+ f.write(caption)
241
+
242
+ if output_inputs_path is not None:
243
+ os.makedirs(os.path.dirname(output_inputs_path), exist_ok=True)
244
+ image.save(output_inputs_path)
245
+ with open(output_inputs_caption_path, 'w') as f:
246
+ f.write(caption)
247
+
248
+ pbar.update(1)
249
+ batch.cleanup()
250
+
251
+ pbar.close()
252
+ print("Done generating images")
253
+ # cleanup
254
+ del self.sd
255
+ gc.collect()
256
+ torch.cuda.empty_cache()
extensions_built_in/advanced_generator/PureLoraGenerator.py ADDED
@@ -0,0 +1,102 @@
1
+ import os
2
+ from collections import OrderedDict
3
+
4
+ from toolkit.config_modules import ModelConfig, GenerateImageConfig, SampleConfig, LoRMConfig
5
+ from toolkit.lorm import ExtractMode, convert_diffusers_unet_to_lorm
6
+ from toolkit.sd_device_states_presets import get_train_sd_device_state_preset
7
+ from toolkit.stable_diffusion_model import StableDiffusion
8
+ import gc
9
+ import torch
10
+ from jobs.process import BaseExtensionProcess
11
+ from toolkit.train_tools import get_torch_dtype
12
+
13
+
14
+ def flush():
15
+ torch.cuda.empty_cache()
16
+ gc.collect()
17
+
18
+
19
+ class PureLoraGenerator(BaseExtensionProcess):
20
+
21
+ def __init__(self, process_id: int, job, config: OrderedDict):
22
+ super().__init__(process_id, job, config)
23
+ self.output_folder = self.get_conf('output_folder', required=True)
24
+ self.device = self.get_conf('device', 'cuda')
25
+ self.device_torch = torch.device(self.device)
26
+ self.model_config = ModelConfig(**self.get_conf('model', required=True))
27
+ self.generate_config = SampleConfig(**self.get_conf('sample', required=True))
28
+ self.dtype = self.get_conf('dtype', 'float16')
29
+ self.torch_dtype = get_torch_dtype(self.dtype)
30
+ lorm_config = self.get_conf('lorm', None)
31
+ self.lorm_config = LoRMConfig(**lorm_config) if lorm_config is not None else None
32
+
33
+ self.device_state_preset = get_train_sd_device_state_preset(
34
+ device=torch.device(self.device),
35
+ )
36
+
37
+ self.progress_bar = None
38
+ self.sd = StableDiffusion(
39
+ device=self.device,
40
+ model_config=self.model_config,
41
+ dtype=self.dtype,
42
+ )
43
+
44
+ def run(self):
45
+ super().run()
46
+ print("Loading model...")
47
+ with torch.no_grad():
48
+ self.sd.load_model()
49
+ self.sd.unet.eval()
50
+ self.sd.unet.to(self.device_torch)
51
+ if isinstance(self.sd.text_encoder, list):
52
+ for te in self.sd.text_encoder:
53
+ te.eval()
54
+ te.to(self.device_torch)
55
+ else:
56
+ self.sd.text_encoder.eval()
57
+ self.sd.to(self.device_torch)
58
+
59
+ print(f"Converting to LoRM UNet")
60
+ # replace the unet with LoRMUnet
61
+ convert_diffusers_unet_to_lorm(
62
+ self.sd.unet,
63
+ config=self.lorm_config,
64
+ )
65
+
66
+ sample_folder = os.path.join(self.output_folder)
67
+ gen_img_config_list = []
68
+
69
+ sample_config = self.generate_config
70
+ start_seed = sample_config.seed
71
+ current_seed = start_seed
72
+ for i in range(len(sample_config.prompts)):
73
+ if sample_config.walk_seed:
74
+ current_seed = start_seed + i
75
+
76
+ filename = f"[time]_[count].{self.generate_config.ext}"
77
+ output_path = os.path.join(sample_folder, filename)
78
+ prompt = sample_config.prompts[i]
79
+ extra_args = {}
80
+ gen_img_config_list.append(GenerateImageConfig(
81
+ prompt=prompt, # it will autoparse the prompt
82
+ width=sample_config.width,
83
+ height=sample_config.height,
84
+ negative_prompt=sample_config.neg,
85
+ seed=current_seed,
86
+ guidance_scale=sample_config.guidance_scale,
87
+ guidance_rescale=sample_config.guidance_rescale,
88
+ num_inference_steps=sample_config.sample_steps,
89
+ network_multiplier=sample_config.network_multiplier,
90
+ output_path=output_path,
91
+ output_ext=sample_config.ext,
92
+ adapter_conditioning_scale=sample_config.adapter_conditioning_scale,
93
+ **extra_args
94
+ ))
95
+
96
+ # send to be generated
97
+ self.sd.generate_images(gen_img_config_list, sampler=sample_config.sampler)
98
+ print("Done generating images")
99
+ # cleanup
100
+ del self.sd
101
+ gc.collect()
102
+ torch.cuda.empty_cache()
extensions_built_in/advanced_generator/ReferenceGenerator.py ADDED
@@ -0,0 +1,212 @@
1
+ import os
2
+ import random
3
+ from collections import OrderedDict
4
+ from typing import List
5
+
6
+ import numpy as np
7
+ from PIL import Image
8
+ from diffusers import T2IAdapter
9
+ from torch.utils.data import DataLoader
10
+ from diffusers import StableDiffusionXLAdapterPipeline, StableDiffusionAdapterPipeline
11
+ from tqdm import tqdm
12
+
13
+ from toolkit.config_modules import ModelConfig, GenerateImageConfig, preprocess_dataset_raw_config, DatasetConfig
14
+ from toolkit.data_transfer_object.data_loader import FileItemDTO, DataLoaderBatchDTO
15
+ from toolkit.sampler import get_sampler
16
+ from toolkit.stable_diffusion_model import StableDiffusion
17
+ import gc
18
+ import torch
19
+ from jobs.process import BaseExtensionProcess
20
+ from toolkit.data_loader import get_dataloader_from_datasets
21
+ from toolkit.train_tools import get_torch_dtype
22
+ from controlnet_aux.midas import MidasDetector
23
+ from diffusers.utils import load_image
24
+
25
+
26
+ def flush():
27
+ torch.cuda.empty_cache()
28
+ gc.collect()
29
+
30
+
31
+ class GenerateConfig:
32
+
33
+ def __init__(self, **kwargs):
34
+ self.prompts: List[str]
35
+ self.sampler = kwargs.get('sampler', 'ddpm')
36
+ self.neg = kwargs.get('neg', '')
37
+ self.seed = kwargs.get('seed', -1)
38
+ self.walk_seed = kwargs.get('walk_seed', False)
39
+ self.t2i_adapter_path = kwargs.get('t2i_adapter_path', None)
40
+ self.guidance_scale = kwargs.get('guidance_scale', 7)
41
+ self.sample_steps = kwargs.get('sample_steps', 20)
42
+ self.prompt_2 = kwargs.get('prompt_2', None)
43
+ self.neg_2 = kwargs.get('neg_2', None)
44
+ self.prompts = kwargs.get('prompts', None)
45
+ self.guidance_rescale = kwargs.get('guidance_rescale', 0.0)
46
+ self.ext = kwargs.get('ext', 'png')
47
+ self.adapter_conditioning_scale = kwargs.get('adapter_conditioning_scale', 1.0)
48
+ if kwargs.get('shuffle', False):
49
+ # shuffle the prompts
50
+ random.shuffle(self.prompts)
51
+
52
+
53
+ class ReferenceGenerator(BaseExtensionProcess):
54
+
55
+ def __init__(self, process_id: int, job, config: OrderedDict):
56
+ super().__init__(process_id, job, config)
57
+ self.output_folder = self.get_conf('output_folder', required=True)
58
+ self.device = self.get_conf('device', 'cuda')
59
+ self.model_config = ModelConfig(**self.get_conf('model', required=True))
60
+ self.generate_config = GenerateConfig(**self.get_conf('generate', required=True))
61
+ self.is_latents_cached = True
62
+ raw_datasets = self.get_conf('datasets', None)
63
+ if raw_datasets is not None and len(raw_datasets) > 0:
64
+ raw_datasets = preprocess_dataset_raw_config(raw_datasets)
65
+ self.datasets = None
66
+ self.datasets_reg = None
67
+ self.dtype = self.get_conf('dtype', 'float16')
68
+ self.torch_dtype = get_torch_dtype(self.dtype)
69
+ self.params = []
70
+ if raw_datasets is not None and len(raw_datasets) > 0:
71
+ for raw_dataset in raw_datasets:
72
+ dataset = DatasetConfig(**raw_dataset)
73
+ is_caching = dataset.cache_latents or dataset.cache_latents_to_disk
74
+ if not is_caching:
75
+ self.is_latents_cached = False
76
+ if dataset.is_reg:
77
+ if self.datasets_reg is None:
78
+ self.datasets_reg = []
79
+ self.datasets_reg.append(dataset)
80
+ else:
81
+ if self.datasets is None:
82
+ self.datasets = []
83
+ self.datasets.append(dataset)
84
+
85
+ self.progress_bar = None
86
+ self.sd = StableDiffusion(
87
+ device=self.device,
88
+ model_config=self.model_config,
89
+ dtype=self.dtype,
90
+ )
91
+ print(f"Using device {self.device}")
92
+ self.data_loader: DataLoader = None
93
+ self.adapter: T2IAdapter = None
94
+
95
+ def run(self):
96
+ super().run()
97
+ print("Loading model...")
98
+ self.sd.load_model()
99
+ device = torch.device(self.device)
100
+
101
+ if self.generate_config.t2i_adapter_path is not None:
102
+ self.adapter = T2IAdapter.from_pretrained(
103
+ self.generate_config.t2i_adapter_path,
104
+ torch_dtype=self.torch_dtype,
105
+ variant="fp16"
106
+ ).to(device)
107
+
108
+ midas_depth = MidasDetector.from_pretrained(
109
+ "valhalla/t2iadapter-aux-models", filename="dpt_large_384.pt", model_type="dpt_large"
110
+ ).to(device)
111
+
112
+ if self.model_config.is_xl:
113
+ pipe = StableDiffusionXLAdapterPipeline(
114
+ vae=self.sd.vae,
115
+ unet=self.sd.unet,
116
+ text_encoder=self.sd.text_encoder[0],
117
+ text_encoder_2=self.sd.text_encoder[1],
118
+ tokenizer=self.sd.tokenizer[0],
119
+ tokenizer_2=self.sd.tokenizer[1],
120
+ scheduler=get_sampler(self.generate_config.sampler),
121
+ adapter=self.adapter,
122
+ ).to(device, dtype=self.torch_dtype)
123
+ else:
124
+ pipe = StableDiffusionAdapterPipeline(
125
+ vae=self.sd.vae,
126
+ unet=self.sd.unet,
127
+ text_encoder=self.sd.text_encoder,
128
+ tokenizer=self.sd.tokenizer,
129
+ scheduler=get_sampler(self.generate_config.sampler),
130
+ safety_checker=None,
131
+ feature_extractor=None,
132
+ requires_safety_checker=False,
133
+ adapter=self.adapter,
134
+ ).to(device, dtype=self.torch_dtype)
135
+ pipe.set_progress_bar_config(disable=True)
136
+
137
+ pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
138
+ # midas_depth = torch.compile(midas_depth, mode="reduce-overhead", fullgraph=True)
139
+
140
+ self.data_loader = get_dataloader_from_datasets(self.datasets, 1, self.sd)
141
+
142
+ num_batches = len(self.data_loader)
143
+ pbar = tqdm(total=num_batches, desc="Generating images")
144
+ seed = self.generate_config.seed
145
+ # load images from datasets, use tqdm
146
+ for i, batch in enumerate(self.data_loader):
147
+ batch: DataLoaderBatchDTO = batch
148
+
149
+ file_item: FileItemDTO = batch.file_items[0]
150
+ img_path = file_item.path
151
+ img_filename = os.path.basename(img_path)
152
+ img_filename_no_ext = os.path.splitext(img_filename)[0]
153
+ output_path = os.path.join(self.output_folder, img_filename)
154
+ output_caption_path = os.path.join(self.output_folder, img_filename_no_ext + '.txt')
155
+ output_depth_path = os.path.join(self.output_folder, img_filename_no_ext + '.depth.png')
156
+
157
+ caption = batch.get_caption_list()[0]
158
+
159
+ img: torch.Tensor = batch.tensor.clone()
160
+ # image comes in -1 to 1. convert to a PIL RGB image
161
+ img = (img + 1) / 2
162
+ img = img.clamp(0, 1)
163
+ img = img[0].permute(1, 2, 0).cpu().numpy()
164
+ img = (img * 255).astype(np.uint8)
165
+ image = Image.fromarray(img)
166
+
167
+ width, height = image.size
168
+ min_res = min(width, height)
169
+
170
+ if self.generate_config.walk_seed:
171
+ seed = seed + 1
172
+
173
+ if self.generate_config.seed == -1:
174
+ # random
175
+ seed = random.randint(0, 1000000)
176
+
177
+ torch.manual_seed(seed)
178
+ torch.cuda.manual_seed(seed)
179
+
180
+ # generate depth map
181
+ image = midas_depth(
182
+ image,
183
+ detect_resolution=min_res, # do 512 ?
184
+ image_resolution=min_res
185
+ )
186
+
187
+ # image.save(output_depth_path)
188
+
189
+ gen_images = pipe(
190
+ prompt=caption,
191
+ negative_prompt=self.generate_config.neg,
192
+ image=image,
193
+ num_inference_steps=self.generate_config.sample_steps,
194
+ adapter_conditioning_scale=self.generate_config.adapter_conditioning_scale,
195
+ guidance_scale=self.generate_config.guidance_scale,
196
+ ).images[0]
197
+ os.makedirs(os.path.dirname(output_path), exist_ok=True)
198
+ gen_images.save(output_path)
199
+
200
+ # save caption
201
+ with open(output_caption_path, 'w') as f:
202
+ f.write(caption)
203
+
204
+ pbar.update(1)
205
+ batch.cleanup()
206
+
207
+ pbar.close()
208
+ print("Done generating images")
209
+ # cleanup
210
+ del self.sd
211
+ gc.collect()
212
+ torch.cuda.empty_cache()
extensions_built_in/advanced_generator/__init__.py ADDED
@@ -0,0 +1,59 @@
+ # This is an example extension for custom training. It is great for experimenting with new ideas.
+ from toolkit.extension import Extension
+ 
+ 
+ # This is for generic training (LoRA, Dreambooth, FineTuning)
+ class AdvancedReferenceGeneratorExtension(Extension):
+     # uid must be unique, it is how the extension is identified
+     uid = "reference_generator"
+ 
+     # name is the name of the extension for printing
+     name = "Reference Generator"
+ 
+     # This is where your process class is loaded
+     # keep your imports in here so they don't slow down the rest of the program
+     @classmethod
+     def get_process(cls):
+         # import your process class here so it is only loaded when needed and return it
+         from .ReferenceGenerator import ReferenceGenerator
+         return ReferenceGenerator
+ 
+ 
+ # This is for generic training (LoRA, Dreambooth, FineTuning)
+ class PureLoraGenerator(Extension):
+     # uid must be unique, it is how the extension is identified
+     uid = "pure_lora_generator"
+ 
+     # name is the name of the extension for printing
+     name = "Pure LoRA Generator"
+ 
+     # This is where your process class is loaded
+     # keep your imports in here so they don't slow down the rest of the program
+     @classmethod
+     def get_process(cls):
+         # import your process class here so it is only loaded when needed and return it
+         from .PureLoraGenerator import PureLoraGenerator
+         return PureLoraGenerator
+ 
+ 
+ # This is for generic training (LoRA, Dreambooth, FineTuning)
+ class Img2ImgGeneratorExtension(Extension):
+     # uid must be unique, it is how the extension is identified
+     uid = "batch_img2img"
+ 
+     # name is the name of the extension for printing
+     name = "Img2ImgGeneratorExtension"
+ 
+     # This is where your process class is loaded
+     # keep your imports in here so they don't slow down the rest of the program
+     @classmethod
+     def get_process(cls):
+         # import your process class here so it is only loaded when needed and return it
+         from .Img2ImgGenerator import Img2ImgGenerator
+         return Img2ImgGenerator
+ 
+ 
+ AI_TOOLKIT_EXTENSIONS = [
+     # you can put a list of extensions here
+     AdvancedReferenceGeneratorExtension, PureLoraGenerator, Img2ImgGeneratorExtension
+ ]
extensions_built_in/advanced_generator/config/train.example.yaml ADDED
@@ -0,0 +1,91 @@
1
+ ---
2
+ job: extension
3
+ config:
4
+ name: test_v1
5
+ process:
6
+ - type: 'textual_inversion_trainer'
7
+ training_folder: "out/TI"
8
+ device: cuda:0
9
+ # for tensorboard logging
10
+ log_dir: "out/.tensorboard"
11
+ embedding:
12
+ trigger: "your_trigger_here"
13
+ tokens: 12
14
+ init_words: "man with short brown hair"
15
+ save_format: "safetensors" # 'safetensors' or 'pt'
16
+ save:
17
+ dtype: float16 # precision to save
18
+ save_every: 100 # save every this many steps
19
+ max_step_saves_to_keep: 5 # only affects step counts
20
+ datasets:
21
+ - folder_path: "/path/to/dataset"
22
+ caption_ext: "txt"
23
+ default_caption: "[trigger]"
24
+ buckets: true
25
+ resolution: 512
26
+ train:
27
+ noise_scheduler: "ddpm" # "ddpm", "lms", or "euler_a"
28
+ steps: 3000
29
+ weight_jitter: 0.0
30
+ lr: 5e-5
31
+ train_unet: false
32
+ gradient_checkpointing: true
33
+ train_text_encoder: false
34
+ optimizer: "adamw"
35
+ # optimizer: "prodigy"
36
+ optimizer_params:
37
+ weight_decay: 1e-2
38
+ lr_scheduler: "constant"
39
+ max_denoising_steps: 1000
40
+ batch_size: 4
41
+ dtype: bf16
42
+ xformers: true
43
+ min_snr_gamma: 5.0
44
+ # skip_first_sample: true
45
+ noise_offset: 0.0 # not needed for this
46
+ model:
47
+ # objective reality v2
48
+ name_or_path: "https://civitai.com/models/128453?modelVersionId=142465"
49
+ is_v2: false # for v2 models
50
+ is_xl: false # for SDXL models
51
+ is_v_pred: false # for v-prediction models (most v2 models)
52
+ sample:
53
+ sampler: "ddpm" # must match train.noise_scheduler
54
+ sample_every: 100 # sample every this many steps
55
+ width: 512
56
+ height: 512
57
+ prompts:
58
+ - "photo of [trigger] laughing"
59
+ - "photo of [trigger] smiling"
60
+ - "[trigger] close up"
61
+ - "dark scene [trigger] frozen"
62
+ - "[trigger] nighttime"
63
+ - "a painting of [trigger]"
64
+ - "a drawing of [trigger]"
65
+ - "a cartoon of [trigger]"
66
+ - "[trigger] pixar style"
67
+ - "[trigger] costume"
68
+ neg: ""
69
+ seed: 42
70
+ walk_seed: false
71
+ guidance_scale: 7
72
+ sample_steps: 20
73
+ network_multiplier: 1.0
74
+
75
+ logging:
76
+ log_every: 10 # log every this many steps
77
+ use_wandb: false # not supported yet
78
+ verbose: false
79
+
80
+ # You can put any information you want here, and it will be saved in the model.
81
+ # The below is an example, but you can put your grocery list in it if you want.
82
+ # It is saved in the model so be aware of that. The software will include this
83
+ # plus some other information for you automatically
84
+ meta:
85
+ # [name] gets replaced with the name above
86
+ name: "[name]"
87
+ # version: '1.0'
88
+ # creator:
89
+ # name: Your Name
90
+ # email: [email protected]
91
+ # website: https://your.website
extensions_built_in/concept_replacer/ConceptReplacer.py ADDED
@@ -0,0 +1,151 @@
1
+ import random
2
+ from collections import OrderedDict
3
+ from torch.utils.data import DataLoader
4
+ from toolkit.prompt_utils import concat_prompt_embeds, split_prompt_embeds
5
+ from toolkit.stable_diffusion_model import StableDiffusion, BlankNetwork
6
+ from toolkit.train_tools import get_torch_dtype, apply_snr_weight
7
+ import gc
8
+ import torch
9
+ from jobs.process import BaseSDTrainProcess
10
+
11
+
12
+ def flush():
13
+ torch.cuda.empty_cache()
14
+ gc.collect()
15
+
16
+
17
+ class ConceptReplacementConfig:
18
+ def __init__(self, **kwargs):
19
+ self.concept: str = kwargs.get('concept', '')
20
+ self.replacement: str = kwargs.get('replacement', '')
21
+
22
+
23
+ class ConceptReplacer(BaseSDTrainProcess):
24
+
25
+ def __init__(self, process_id: int, job, config: OrderedDict, **kwargs):
26
+ super().__init__(process_id, job, config, **kwargs)
27
+ replacement_list = self.config.get('replacements', [])
28
+ self.replacement_list = [ConceptReplacementConfig(**x) for x in replacement_list]
29
+
30
+ def before_model_load(self):
31
+ pass
32
+
33
+ def hook_before_train_loop(self):
34
+ self.sd.vae.eval()
35
+ self.sd.vae.to(self.device_torch)
36
+
37
+ # textual inversion
38
+ if self.embedding is not None:
39
+ # set text encoder to train. Not sure if this is necessary but diffusers example did it
40
+ self.sd.text_encoder.train()
41
+
42
+ def hook_train_loop(self, batch):
43
+ with torch.no_grad():
44
+ dtype = get_torch_dtype(self.train_config.dtype)
45
+ noisy_latents, noise, timesteps, conditioned_prompts, imgs = self.process_general_training_batch(batch)
46
+ network_weight_list = batch.get_network_weight_list()
47
+
48
+ # have a blank network so we can wrap it in a context and set multipliers without checking every time
49
+ if self.network is not None:
50
+ network = self.network
51
+ else:
52
+ network = BlankNetwork()
53
+
54
+ batch_replacement_list = []
55
+ # get a random replacement for each prompt
56
+ for prompt in conditioned_prompts:
57
+ replacement = random.choice(self.replacement_list)
58
+ batch_replacement_list.append(replacement)
59
+
60
+ # build out prompts
61
+ concept_prompts = []
62
+ replacement_prompts = []
63
+ for idx, replacement in enumerate(batch_replacement_list):
64
+ prompt = conditioned_prompts[idx]
65
+
66
+ # insert shuffled concept at beginning and end of prompt
67
+ shuffled_concept = [x.strip() for x in replacement.concept.split(',')]
68
+ random.shuffle(shuffled_concept)
69
+ shuffled_concept = ', '.join(shuffled_concept)
70
+ concept_prompts.append(f"{shuffled_concept}, {prompt}, {shuffled_concept}")
71
+
72
+ # insert replacement at beginning and end of prompt
73
+ shuffled_replacement = [x.strip() for x in replacement.replacement.split(',')]
74
+ random.shuffle(shuffled_replacement)
75
+ shuffled_replacement = ', '.join(shuffled_replacement)
76
+ replacement_prompts.append(f"{shuffled_replacement}, {prompt}, {shuffled_replacement}")
77
+
78
+ # predict the replacement without network
79
+ conditional_embeds = self.sd.encode_prompt(replacement_prompts).to(self.device_torch, dtype=dtype)
80
+
81
+ replacement_pred = self.sd.predict_noise(
82
+ latents=noisy_latents.to(self.device_torch, dtype=dtype),
83
+ conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype),
84
+ timestep=timesteps,
85
+ guidance_scale=1.0,
86
+ )
87
+
88
+ del conditional_embeds
89
+ replacement_pred = replacement_pred.detach()
90
+
91
+ self.optimizer.zero_grad()
92
+ flush()
93
+
94
+ # text encoding
95
+ grad_on_text_encoder = False
96
+ if self.train_config.train_text_encoder:
97
+ grad_on_text_encoder = True
98
+
99
+ if self.embedding:
100
+ grad_on_text_encoder = True
101
+
102
+ # set the weights
103
+ network.multiplier = network_weight_list
104
+
105
+ # activate network if it exists
106
+ with network:
107
+ with torch.set_grad_enabled(grad_on_text_encoder):
108
+ # embed the prompts
109
+ conditional_embeds = self.sd.encode_prompt(concept_prompts).to(self.device_torch, dtype=dtype)
110
+ if not grad_on_text_encoder:
111
+ # detach the embeddings
112
+ conditional_embeds = conditional_embeds.detach()
113
+ self.optimizer.zero_grad()
114
+ flush()
115
+
116
+ noise_pred = self.sd.predict_noise(
117
+ latents=noisy_latents.to(self.device_torch, dtype=dtype),
118
+ conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype),
119
+ timestep=timesteps,
120
+ guidance_scale=1.0,
121
+ )
122
+
123
+ loss = torch.nn.functional.mse_loss(noise_pred.float(), replacement_pred.float(), reduction="none")
124
+ loss = loss.mean([1, 2, 3])
125
+
126
+ if self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001:
127
+ # add min_snr_gamma
128
+ loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.min_snr_gamma)
129
+
130
+ loss = loss.mean()
131
+
132
+ # back propagate loss to free ram
133
+ loss.backward()
134
+ flush()
135
+
136
+ # apply gradients
137
+ self.optimizer.step()
138
+ self.optimizer.zero_grad()
139
+ self.lr_scheduler.step()
140
+
141
+ if self.embedding is not None:
142
+ # Let's make sure we don't update any embedding weights besides the newly added token
143
+ self.embedding.restore_embeddings()
144
+
145
+ loss_dict = OrderedDict(
146
+ {'loss': loss.item()}
147
+ )
148
+ # reset network multiplier
149
+ network.multiplier = 1.0
150
+
151
+ return loss_dict
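Note: hook_train_loop above is a distillation-style objective: the frozen model's prediction for the replacement prompts becomes the regression target for the network's prediction on the concept prompts. Stripped to its core, the loss is roughly the following (schematic sketch, not the process's actual code):

    import torch
    import torch.nn.functional as F

    def concept_replacement_loss(pred_on_concept: torch.Tensor, pred_on_replacement: torch.Tensor) -> torch.Tensor:
        # target comes from the frozen/base model and must not receive gradients
        target = pred_on_replacement.detach()
        # per-sample MSE over channel/height/width, then mean over the batch
        per_sample = F.mse_loss(pred_on_concept.float(), target.float(), reduction="none").mean(dim=[1, 2, 3])
        return per_sample.mean()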
extensions_built_in/concept_replacer/__init__.py ADDED
@@ -0,0 +1,26 @@
+ # This is an example extension for custom training. It is great for experimenting with new ideas.
+ from toolkit.extension import Extension
+ 
+ 
+ # This is for generic training (LoRA, Dreambooth, FineTuning)
+ class ConceptReplacerExtension(Extension):
+     # uid must be unique, it is how the extension is identified
+     uid = "concept_replacer"
+ 
+     # name is the name of the extension for printing
+     name = "Concept Replacer"
+ 
+     # This is where your process class is loaded
+     # keep your imports in here so they don't slow down the rest of the program
+     @classmethod
+     def get_process(cls):
+         # import your process class here so it is only loaded when needed and return it
+         from .ConceptReplacer import ConceptReplacer
+         return ConceptReplacer
+ 
+ 
+ 
+ AI_TOOLKIT_EXTENSIONS = [
+     # you can put a list of extensions here
+     ConceptReplacerExtension,
+ ]
extensions_built_in/concept_replacer/config/train.example.yaml ADDED
@@ -0,0 +1,91 @@
1
+ ---
2
+ job: extension
3
+ config:
4
+ name: test_v1
5
+ process:
6
+ - type: 'textual_inversion_trainer'
7
+ training_folder: "out/TI"
8
+ device: cuda:0
9
+ # for tensorboard logging
10
+ log_dir: "out/.tensorboard"
11
+ embedding:
12
+ trigger: "your_trigger_here"
13
+ tokens: 12
14
+ init_words: "man with short brown hair"
15
+ save_format: "safetensors" # 'safetensors' or 'pt'
16
+ save:
17
+ dtype: float16 # precision to save
18
+ save_every: 100 # save every this many steps
19
+ max_step_saves_to_keep: 5 # only affects step counts
20
+ datasets:
21
+ - folder_path: "/path/to/dataset"
22
+ caption_ext: "txt"
23
+ default_caption: "[trigger]"
24
+ buckets: true
25
+ resolution: 512
26
+ train:
27
+ noise_scheduler: "ddpm" # "ddpm", "lms", or "euler_a"
28
+ steps: 3000
29
+ weight_jitter: 0.0
30
+ lr: 5e-5
31
+ train_unet: false
32
+ gradient_checkpointing: true
33
+ train_text_encoder: false
34
+ optimizer: "adamw"
35
+ # optimizer: "prodigy"
36
+ optimizer_params:
37
+ weight_decay: 1e-2
38
+ lr_scheduler: "constant"
39
+ max_denoising_steps: 1000
40
+ batch_size: 4
41
+ dtype: bf16
42
+ xformers: true
43
+ min_snr_gamma: 5.0
44
+ # skip_first_sample: true
45
+ noise_offset: 0.0 # not needed for this
46
+ model:
47
+ # objective reality v2
48
+ name_or_path: "https://civitai.com/models/128453?modelVersionId=142465"
49
+ is_v2: false # for v2 models
50
+ is_xl: false # for SDXL models
51
+ is_v_pred: false # for v-prediction models (most v2 models)
52
+ sample:
53
+ sampler: "ddpm" # must match train.noise_scheduler
54
+ sample_every: 100 # sample every this many steps
55
+ width: 512
56
+ height: 512
57
+ prompts:
58
+ - "photo of [trigger] laughing"
59
+ - "photo of [trigger] smiling"
60
+ - "[trigger] close up"
61
+ - "dark scene [trigger] frozen"
62
+ - "[trigger] nighttime"
63
+ - "a painting of [trigger]"
64
+ - "a drawing of [trigger]"
65
+ - "a cartoon of [trigger]"
66
+ - "[trigger] pixar style"
67
+ - "[trigger] costume"
68
+ neg: ""
69
+ seed: 42
70
+ walk_seed: false
71
+ guidance_scale: 7
72
+ sample_steps: 20
73
+ network_multiplier: 1.0
74
+
75
+ logging:
76
+ log_every: 10 # log every this many steps
77
+ use_wandb: false # not supported yet
78
+ verbose: false
79
+
80
+ # You can put any information you want here, and it will be saved in the model.
81
+ # The below is an example, but you can put your grocery list in it if you want.
82
+ # It is saved in the model so be aware of that. The software will include this
83
+ # plus some other information for you automatically
84
+ meta:
85
+ # [name] gets replaced with the name above
86
+ name: "[name]"
87
+ # version: '1.0'
88
+ # creator:
89
+ # name: Your Name
90
+ # email: [email protected]
91
+ # website: https://your.website
extensions_built_in/dataset_tools/DatasetTools.py ADDED
@@ -0,0 +1,20 @@
+ from collections import OrderedDict
+ import gc
+ import torch
+ from jobs.process import BaseExtensionProcess
+ 
+ 
+ def flush():
+     torch.cuda.empty_cache()
+     gc.collect()
+ 
+ 
+ class DatasetTools(BaseExtensionProcess):
+ 
+     def __init__(self, process_id: int, job, config: OrderedDict):
+         super().__init__(process_id, job, config)
+ 
+     def run(self):
+         super().run()
+ 
+         raise NotImplementedError("This extension is not yet implemented")
extensions_built_in/dataset_tools/SuperTagger.py ADDED
@@ -0,0 +1,196 @@
1
+ import copy
2
+ import json
3
+ import os
4
+ from collections import OrderedDict
5
+ import gc
6
+ import traceback
7
+ import torch
8
+ from PIL import Image, ImageOps
9
+ from tqdm import tqdm
10
+
11
+ from .tools.dataset_tools_config_modules import RAW_DIR, TRAIN_DIR, Step, ImgInfo
12
+ from .tools.fuyu_utils import FuyuImageProcessor
13
+ from .tools.image_tools import load_image, ImageProcessor, resize_to_max
14
+ from .tools.llava_utils import LLaVAImageProcessor
15
+ from .tools.caption import default_long_prompt, default_short_prompt, default_replacements
16
+ from jobs.process import BaseExtensionProcess
17
+ from .tools.sync_tools import get_img_paths
18
+
19
+ img_ext = ['.jpg', '.jpeg', '.png', '.webp']
20
+
21
+
22
+ def flush():
23
+ torch.cuda.empty_cache()
24
+ gc.collect()
25
+
26
+
27
+ VERSION = 2
28
+
29
+
30
+ class SuperTagger(BaseExtensionProcess):
31
+
32
+ def __init__(self, process_id: int, job, config: OrderedDict):
33
+ super().__init__(process_id, job, config)
34
+ parent_dir = config.get('parent_dir', None)
35
+ self.dataset_paths: list[str] = config.get('dataset_paths', [])
36
+ self.device = config.get('device', 'cuda')
37
+ self.steps: list[Step] = config.get('steps', [])
38
+ self.caption_method = config.get('caption_method', 'llava:default')
39
+ self.caption_prompt = config.get('caption_prompt', default_long_prompt)
40
+ self.caption_short_prompt = config.get('caption_short_prompt', default_short_prompt)
41
+ self.force_reprocess_img = config.get('force_reprocess_img', False)
42
+ self.caption_replacements = config.get('caption_replacements', default_replacements)
43
+ self.caption_short_replacements = config.get('caption_short_replacements', default_replacements)
44
+ self.master_dataset_dict = OrderedDict()
45
+ self.dataset_master_config_file = config.get('dataset_master_config_file', None)
46
+ if parent_dir is not None and len(self.dataset_paths) == 0:
47
+ # find all folders in the parent_dir
48
+ self.dataset_paths = [
49
+ os.path.join(parent_dir, folder)
50
+ for folder in os.listdir(parent_dir)
51
+ if os.path.isdir(os.path.join(parent_dir, folder))
52
+ ]
53
+ else:
54
+ # make sure they exist
55
+ for dataset_path in self.dataset_paths:
56
+ if not os.path.exists(dataset_path):
57
+ raise ValueError(f"Dataset path does not exist: {dataset_path}")
58
+
59
+ print(f"Found {len(self.dataset_paths)} dataset paths")
60
+
61
+ self.image_processor: ImageProcessor = self.get_image_processor()
62
+
63
+ def get_image_processor(self):
64
+ if self.caption_method.startswith('llava'):
65
+ return LLaVAImageProcessor(device=self.device)
66
+ elif self.caption_method.startswith('fuyu'):
67
+ return FuyuImageProcessor(device=self.device)
68
+ else:
69
+ raise ValueError(f"Unknown caption method: {self.caption_method}")
70
+
71
+ def process_image(self, img_path: str):
72
+ root_img_dir = os.path.dirname(os.path.dirname(img_path))
73
+ filename = os.path.basename(img_path)
74
+ filename_no_ext = os.path.splitext(filename)[0]
75
+ train_dir = os.path.join(root_img_dir, TRAIN_DIR)
76
+ train_img_path = os.path.join(train_dir, filename)
77
+ json_path = os.path.join(train_dir, f"{filename_no_ext}.json")
78
+
79
+ # check if json exists, if it does load it as image info
80
+ if os.path.exists(json_path):
81
+ with open(json_path, 'r') as f:
82
+ img_info = ImgInfo(**json.load(f))
83
+ else:
84
+ img_info = ImgInfo()
85
+
86
+ # always send steps first in case other processes need them
87
+ img_info.add_steps(copy.deepcopy(self.steps))
88
+ img_info.set_version(VERSION)
89
+ img_info.set_caption_method(self.caption_method)
90
+
91
+ image: Image = None
92
+ caption_image: Image = None
93
+
94
+ did_update_image = False
95
+
96
+ # trigger reprocess of steps
97
+ if self.force_reprocess_img:
98
+ img_info.trigger_image_reprocess()
99
+
100
+ # set the image as updated if it does not exist on disk
101
+ if not os.path.exists(train_img_path):
102
+ did_update_image = True
103
+ image = load_image(img_path)
104
+ if img_info.force_image_process:
105
+ did_update_image = True
106
+ image = load_image(img_path)
107
+
108
+ # go through the needed steps
109
+ for step in copy.deepcopy(img_info.state.steps_to_complete):
110
+ if step == 'caption':
111
+ # load image
112
+ if image is None:
113
+ image = load_image(img_path)
114
+ if caption_image is None:
115
+ caption_image = resize_to_max(image, 1024, 1024)
116
+
117
+ if not self.image_processor.is_loaded:
118
+ print('Loading Model. Takes a while, especially the first time')
119
+ self.image_processor.load_model()
120
+
121
+ img_info.caption = self.image_processor.generate_caption(
122
+ image=caption_image,
123
+ prompt=self.caption_prompt,
124
+ replacements=self.caption_replacements
125
+ )
126
+ img_info.mark_step_complete(step)
127
+ elif step == 'caption_short':
128
+ # load image
129
+ if image is None:
130
+ image = load_image(img_path)
131
+
132
+ if caption_image is None:
133
+ caption_image = resize_to_max(image, 1024, 1024)
134
+
135
+ if not self.image_processor.is_loaded:
136
+ print('Loading Model. Takes a while, especially the first time')
137
+ self.image_processor.load_model()
138
+ img_info.caption_short = self.image_processor.generate_caption(
139
+ image=caption_image,
140
+ prompt=self.caption_short_prompt,
141
+ replacements=self.caption_short_replacements
142
+ )
143
+ img_info.mark_step_complete(step)
144
+ elif step == 'contrast_stretch':
145
+ # load image
146
+ if image is None:
147
+ image = load_image(img_path)
148
+ image = ImageOps.autocontrast(image, cutoff=(0.1, 0), preserve_tone=True)
149
+ did_update_image = True
150
+ img_info.mark_step_complete(step)
151
+ else:
152
+ raise ValueError(f"Unknown step: {step}")
153
+
154
+ os.makedirs(os.path.dirname(train_img_path), exist_ok=True)
155
+ if did_update_image:
156
+ image.save(train_img_path)
157
+
158
+ if img_info.is_dirty:
159
+ with open(json_path, 'w') as f:
160
+ json.dump(img_info.to_dict(), f, indent=4)
161
+
162
+ if self.dataset_master_config_file:
163
+ # add to master dict
164
+ self.master_dataset_dict[train_img_path] = img_info.to_dict()
165
+
166
+ def run(self):
167
+ super().run()
168
+ imgs_to_process = []
169
+ # find all images
170
+ for dataset_path in self.dataset_paths:
171
+ raw_dir = os.path.join(dataset_path, RAW_DIR)
172
+ raw_image_paths = get_img_paths(raw_dir)
173
+ for raw_image_path in raw_image_paths:
174
+ imgs_to_process.append(raw_image_path)
175
+
176
+ if len(imgs_to_process) == 0:
177
+ print(f"No images to process")
178
+ else:
179
+ print(f"Found {len(imgs_to_process)} to process")
180
+
181
+ for img_path in tqdm(imgs_to_process, desc="Processing images"):
182
+ try:
183
+ self.process_image(img_path)
184
+ except Exception:
185
+ # print full stack trace
186
+ print(traceback.format_exc())
187
+ continue
188
+ # self.process_image(img_path)
189
+
190
+ if self.dataset_master_config_file is not None:
191
+ # save it as json
192
+ with open(self.dataset_master_config_file, 'w') as f:
193
+ json.dump(self.master_dataset_dict, f, indent=4)
194
+
195
+ del self.image_processor
196
+ flush()
extensions_built_in/dataset_tools/SyncFromCollection.py ADDED
@@ -0,0 +1,131 @@
1
+ import os
2
+ import shutil
3
+ from collections import OrderedDict
4
+ import gc
5
+ from typing import List
6
+
7
+ import torch
8
+ from tqdm import tqdm
9
+
10
+ from .tools.dataset_tools_config_modules import DatasetSyncCollectionConfig, RAW_DIR, NEW_DIR
11
+ from .tools.sync_tools import get_unsplash_images, get_pexels_images, get_local_image_file_names, download_image, \
12
+ get_img_paths
13
+ from jobs.process import BaseExtensionProcess
14
+
15
+
16
+ def flush():
17
+ torch.cuda.empty_cache()
18
+ gc.collect()
19
+
20
+
21
+ class SyncFromCollection(BaseExtensionProcess):
22
+
23
+ def __init__(self, process_id: int, job, config: OrderedDict):
24
+ super().__init__(process_id, job, config)
25
+
26
+ self.min_width = config.get('min_width', 1024)
27
+ self.min_height = config.get('min_height', 1024)
28
+
29
+ # add our min_width and min_height to each dataset config if they don't exist
30
+ for dataset_config in config.get('dataset_sync', []):
31
+ if 'min_width' not in dataset_config:
32
+ dataset_config['min_width'] = self.min_width
33
+ if 'min_height' not in dataset_config:
34
+ dataset_config['min_height'] = self.min_height
35
+
36
+ self.dataset_configs: List[DatasetSyncCollectionConfig] = [
37
+ DatasetSyncCollectionConfig(**dataset_config)
38
+ for dataset_config in config.get('dataset_sync', [])
39
+ ]
40
+ print(f"Found {len(self.dataset_configs)} dataset configs")
41
+
42
+ def move_new_images(self, root_dir: str):
43
+ raw_dir = os.path.join(root_dir, RAW_DIR)
44
+ new_dir = os.path.join(root_dir, NEW_DIR)
45
+ new_images = get_img_paths(new_dir)
46
+
47
+ for img_path in new_images:
48
+ # move to raw
49
+ new_path = os.path.join(raw_dir, os.path.basename(img_path))
50
+ shutil.move(img_path, new_path)
51
+
52
+ # remove new dir
53
+ shutil.rmtree(new_dir)
54
+
55
+ def sync_dataset(self, config: DatasetSyncCollectionConfig):
56
+ if config.host == 'unsplash':
57
+ get_images = get_unsplash_images
58
+ elif config.host == 'pexels':
59
+ get_images = get_pexels_images
60
+ else:
61
+ raise ValueError(f"Unknown host: {config.host}")
62
+
63
+ results = {
64
+ 'num_downloaded': 0,
65
+ 'num_skipped': 0,
66
+ 'bad': 0,
67
+ 'total': 0,
68
+ }
69
+
70
+ photos = get_images(config)
71
+ raw_dir = os.path.join(config.directory, RAW_DIR)
72
+ new_dir = os.path.join(config.directory, NEW_DIR)
73
+ raw_images = get_local_image_file_names(raw_dir)
74
+ new_images = get_local_image_file_names(new_dir)
75
+
76
+ for photo in tqdm(photos, desc=f"{config.host}-{config.collection_id}"):
77
+ try:
78
+ if photo.filename not in raw_images and photo.filename not in new_images:
79
+ download_image(photo, new_dir, min_width=self.min_width, min_height=self.min_height)
80
+ results['num_downloaded'] += 1
81
+ else:
82
+ results['num_skipped'] += 1
83
+ except Exception as e:
84
+ print(f" - BAD({photo.id}): {e}")
85
+ results['bad'] += 1
86
+ continue
87
+ results['total'] += 1
88
+
89
+ return results
90
+
91
+ def print_results(self, results):
92
+ print(
93
+ f" - new:{results['num_downloaded']}, old:{results['num_skipped']}, bad:{results['bad']} total:{results['total']}")
94
+
95
+ def run(self):
96
+ super().run()
97
+ print(f"Syncing {len(self.dataset_configs)} datasets")
98
+ all_results = None
99
+ failed_datasets = []
100
+ for dataset_config in tqdm(self.dataset_configs, desc="Syncing datasets", leave=True):
101
+ try:
102
+ results = self.sync_dataset(dataset_config)
103
+ if all_results is None:
104
+ all_results = {**results}
105
+ else:
106
+ for key, value in results.items():
107
+ all_results[key] += value
108
+
109
+ self.print_results(results)
110
+ except Exception as e:
111
+ print(f" - FAILED: {e}")
112
+ if 'response' in e.__dict__:
113
+ error = f"{e.response.status_code}: {e.response.text}"
114
+ print(f" - {error}")
115
+ failed_datasets.append({'dataset': dataset_config, 'error': error})
116
+ else:
117
+ failed_datasets.append({'dataset': dataset_config, 'error': str(e)})
118
+ continue
119
+
120
+ print("Moving new images to raw")
121
+ for dataset_config in self.dataset_configs:
122
+ self.move_new_images(dataset_config.directory)
123
+
124
+ print("Done syncing datasets")
125
+ self.print_results(all_results)
126
+
127
+ if len(failed_datasets) > 0:
128
+ print(f"Failed to sync {len(failed_datasets)} datasets")
129
+ for failed in failed_datasets:
130
+ print(f" - {failed['dataset'].host}-{failed['dataset'].collection_id}")
131
+ print(f" - ERR: {failed['error']}")
extensions_built_in/dataset_tools/__init__.py ADDED
+ from toolkit.extension import Extension
+ 
+ 
+ class DatasetToolsExtension(Extension):
+     uid = "dataset_tools"
+ 
+     # name is the name of the extension for printing
+     name = "Dataset Tools"
+ 
+     # This is where your process class is loaded
+     # keep your imports in here so they don't slow down the rest of the program
+     @classmethod
+     def get_process(cls):
+         # import your process class here so it is only loaded when needed and return it
+         from .DatasetTools import DatasetTools
+         return DatasetTools
+ 
+ 
+ class SyncFromCollectionExtension(Extension):
+     uid = "sync_from_collection"
+     name = "Sync from Collection"
+ 
+     @classmethod
+     def get_process(cls):
+         # import your process class here so it is only loaded when needed and return it
+         from .SyncFromCollection import SyncFromCollection
+         return SyncFromCollection
+ 
+ 
+ class SuperTaggerExtension(Extension):
+     uid = "super_tagger"
+     name = "Super Tagger"
+ 
+     @classmethod
+     def get_process(cls):
+         # import your process class here so it is only loaded when needed and return it
+         from .SuperTagger import SuperTagger
+         return SuperTagger
+ 
+ 
+ AI_TOOLKIT_EXTENSIONS = [
+     SyncFromCollectionExtension, DatasetToolsExtension, SuperTaggerExtension
+ ]
extensions_built_in/dataset_tools/tools/caption.py ADDED
@@ -0,0 +1,53 @@
+ 
+ caption_manipulation_steps = ['caption', 'caption_short']
+ 
+ default_long_prompt = 'caption this image. describe every single thing in the image in detail. Do not include any unnecessary words in your description for the sake of good grammar. I want many short statements that serve the single purpose of giving the most thorough description of items as possible in the smallest, comma separated way possible. be sure to describe people\'s moods, clothing, the environment, lighting, colors, and everything.'
+ default_short_prompt = 'caption this image in less than ten words'
+ 
+ default_replacements = [
+     ("the image features", ""),
+     ("the image shows", ""),
+     ("the image depicts", ""),
+     ("the image is", ""),
+     ("in this image", ""),
+     ("in the image", ""),
+ ]
+ 
+ 
+ def clean_caption(cap, replacements=None):
+     if replacements is None:
+         replacements = default_replacements
+ 
+     # remove any newlines
+     cap = cap.replace("\n", ", ")
+     cap = cap.replace("\r", ", ")
+     cap = cap.replace(".", ",")
+     cap = cap.replace("\"", "")
+ 
+     # remove unicode characters
+     cap = cap.encode('ascii', 'ignore').decode('ascii')
+ 
+     # make lowercase
+     cap = cap.lower()
+     # remove any extra spaces
+     cap = " ".join(cap.split())
+ 
+     for replacement in replacements:
+         if replacement[0].startswith('*'):
+             # we are removing all text if it starts with this and the rest matches
+             search_text = replacement[0][1:]
+             if cap.startswith(search_text):
+                 cap = ""
+         else:
+             cap = cap.replace(replacement[0].lower(), replacement[1].lower())
+ 
+     cap_list = cap.split(",")
+     # trim whitespace
+     cap_list = [c.strip() for c in cap_list]
+     # remove empty strings
+     cap_list = [c for c in cap_list if c != ""]
+     # remove duplicates
+     cap_list = list(dict.fromkeys(cap_list))
+     # join back together
+     cap = ", ".join(cap_list)
+     return cap
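Note: a quick usage example for clean_caption, traced by hand through the logic above:

    raw = "The image features a man with short brown hair. He is smiling."
    print(clean_caption(raw))
    # -> a man with short brown hair, he is smiling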