Spaces: Running on Zero
Commit · 12edc27
Parent(s): Initial commit
This view is limited to 50 files because it contains too many changes.
- .gitattributes +91 -0
- .gitignore +177 -0
- README.md +14 -0
- app.py +414 -0
- app/APP.md +103 -0
- app/BEN2.py +1394 -0
- app/aspect_ratio_template.py +88 -0
- app/business_logic.py +556 -0
- app/config.py +72 -0
- app/constants.py +35 -0
- app/event_handlers.py +155 -0
- app/examples.py +210 -0
- app/metainfo.py +131 -0
- app/stylesheets.py +1679 -0
- app/ui_components.py +354 -0
- app/utils.py +429 -0
- assets/gradio/pos_aware/001/hypher_params.txt +3 -0
- assets/gradio/pos_aware/001/img_gen.png +3 -0
- assets/gradio/pos_aware/001/img_ref.png +3 -0
- assets/gradio/pos_aware/001/img_target.png +3 -0
- assets/gradio/pos_aware/001/mask_target.png +3 -0
- assets/gradio/pos_aware/002/hypher_params.txt +3 -0
- assets/gradio/pos_aware/002/img_gen.png +3 -0
- assets/gradio/pos_aware/002/img_ref.png +3 -0
- assets/gradio/pos_aware/002/img_target.png +3 -0
- assets/gradio/pos_aware/002/mask_target.png +3 -0
- assets/gradio/pos_aware/003/hypher_params.txt +3 -0
- assets/gradio/pos_aware/003/img_gen.png +3 -0
- assets/gradio/pos_aware/003/img_ref.png +3 -0
- assets/gradio/pos_aware/003/img_target.png +3 -0
- assets/gradio/pos_aware/003/mask_target.png +3 -0
- assets/gradio/pos_aware/004/hypher_params.txt +3 -0
- assets/gradio/pos_aware/004/img_gen.png +3 -0
- assets/gradio/pos_aware/004/img_ref.png +3 -0
- assets/gradio/pos_aware/004/img_target.png +3 -0
- assets/gradio/pos_aware/004/mask_target.png +3 -0
- assets/gradio/pos_aware/005/hypher_params.txt +3 -0
- assets/gradio/pos_aware/005/img_gen.png +3 -0
- assets/gradio/pos_aware/005/img_ref.png +3 -0
- assets/gradio/pos_aware/005/img_target.png +3 -0
- assets/gradio/pos_aware/005/mask_target.png +3 -0
- assets/gradio/pos_free/001/hyper_params.json +1 -0
- assets/gradio/pos_free/001/img_gen.png +3 -0
- assets/gradio/pos_free/001/img_ref.png +3 -0
- assets/gradio/pos_free/001/img_target.png +3 -0
- assets/gradio/pos_free/001/mask_target.png +3 -0
- assets/gradio/pos_free/002/hyper_params.json +1 -0
- assets/gradio/pos_free/002/img_gen.png +3 -0
- assets/gradio/pos_free/002/img_ref.png +3 -0
- assets/gradio/pos_free/002/img_target.png +3 -0
.gitattributes
ADDED
@@ -0,0 +1,91 @@
*.7z filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.bz2 filter=lfs diff=lfs merge=lfs -text
*.ckpt filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.gz filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.joblib filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.mlmodel filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
*.npy filter=lfs diff=lfs merge=lfs -text
*.npz filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.parquet filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pickle filter=lfs diff=lfs merge=lfs -text
*.pkl filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.rar filter=lfs diff=lfs merge=lfs -text
*.safetensors filter=lfs diff=lfs merge=lfs -text
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.tar.* filter=lfs diff=lfs merge=lfs -text
*.tar filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tgz filter=lfs diff=lfs merge=lfs -text
*.wasm filter=lfs diff=lfs merge=lfs -text
*.xz filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
assets/gradio/pos_aware/005 filter=lfs diff=lfs merge=lfs -text
assets/gradio/pos_free filter=lfs diff=lfs merge=lfs -text
assets/gradio/pos_free/001/img_gen.png filter=lfs diff=lfs merge=lfs -text
assets/gradio/pos_free/001/img_ref.png filter=lfs diff=lfs merge=lfs -text
assets/gradio/pos_free/001/mask_target.png filter=lfs diff=lfs merge=lfs -text
assets/gradio/pos_free/003/hypher_params.txt filter=lfs diff=lfs merge=lfs -text
assets/gradio/pos_aware/002/hypher_params.txt filter=lfs diff=lfs merge=lfs -text
assets/gradio/pos_aware/003/hypher_params.txt filter=lfs diff=lfs merge=lfs -text
assets/gradio/pos_aware/004/mask_target.png filter=lfs diff=lfs merge=lfs -text
assets/gradio/pos_free/003/img_ref.png filter=lfs diff=lfs merge=lfs -text
assets/gradio/pos_free/004 filter=lfs diff=lfs merge=lfs -text
assets/gradio/pos_free/004/img_ref.png filter=lfs diff=lfs merge=lfs -text
assets/gradio/pos_aware/001/img_gen.png filter=lfs diff=lfs merge=lfs -text
assets/gradio/pos_aware/001/img_ref.png filter=lfs diff=lfs merge=lfs -text
assets/gradio/pos_aware/001/img_target.png filter=lfs diff=lfs merge=lfs -text
assets/gradio/pos_aware/002/img_ref.png filter=lfs diff=lfs merge=lfs -text
assets/gradio/pos_aware/003/img_gen.png filter=lfs diff=lfs merge=lfs -text
assets/gradio/pos_aware/003/img_ref.png filter=lfs diff=lfs merge=lfs -text
assets/gradio/pos_aware/004/img_gen.png filter=lfs diff=lfs merge=lfs -text
assets/gradio/pos_aware/004/img_ref.png filter=lfs diff=lfs merge=lfs -text
assets/gradio/pos_aware/004 filter=lfs diff=lfs merge=lfs -text
assets/gradio/pos_aware/004/hypher_params.txt filter=lfs diff=lfs merge=lfs -text
assets/gradio/pos_aware/005/img_gen.png filter=lfs diff=lfs merge=lfs -text
assets/gradio/pos_free/001/img_target.png filter=lfs diff=lfs merge=lfs -text
assets/gradio/pos_free/002 filter=lfs diff=lfs merge=lfs -text
assets/gradio/pos_free/002/img_ref.png filter=lfs diff=lfs merge=lfs -text
assets/gradio/pos_free/003/img_target.png filter=lfs diff=lfs merge=lfs -text
assets/gradio/pos_free/003/mask_target.png filter=lfs diff=lfs merge=lfs -text
assets/gradio/pos_aware filter=lfs diff=lfs merge=lfs -text
assets/gradio/pos_aware/001 filter=lfs diff=lfs merge=lfs -text
assets/gradio/pos_aware/005/img_ref.png filter=lfs diff=lfs merge=lfs -text
assets/gradio/pos_free/002/hypher_params.txt filter=lfs diff=lfs merge=lfs -text
assets/gradio/pos_free/002/img_gen.png filter=lfs diff=lfs merge=lfs -text
assets/gradio/pos_free/004/img_gen.png filter=lfs diff=lfs merge=lfs -text
assets/gradio/pos_aware/002/mask_target.png filter=lfs diff=lfs merge=lfs -text
assets/gradio/pos_aware/003 filter=lfs diff=lfs merge=lfs -text
assets/gradio/pos_aware/003/img_target.png filter=lfs diff=lfs merge=lfs -text
assets/gradio/pos_aware/003/mask_target.png filter=lfs diff=lfs merge=lfs -text
assets/gradio/pos_aware/005/img_target.png filter=lfs diff=lfs merge=lfs -text
assets/gradio/pos_free/001 filter=lfs diff=lfs merge=lfs -text
assets/gradio/pos_free/001/hypher_params.txt filter=lfs diff=lfs merge=lfs -text
assets/gradio/pos_free/002/mask_target.png filter=lfs diff=lfs merge=lfs -text
assets/gradio/pos_aware/005/hypher_params.txt filter=lfs diff=lfs merge=lfs -text
assets/gradio/pos_aware/005/mask_target.png filter=lfs diff=lfs merge=lfs -text
assets/gradio/pos_free/002/img_target.png filter=lfs diff=lfs merge=lfs -text
assets/gradio/pos_free/003 filter=lfs diff=lfs merge=lfs -text
assets/gradio/pos_free/003/img_gen.png filter=lfs diff=lfs merge=lfs -text
assets/gradio/pos_free/004/img_target.png filter=lfs diff=lfs merge=lfs -text
assets/gradio/pos_aware/001/hypher_params.txt filter=lfs diff=lfs merge=lfs -text
assets/gradio/pos_aware/001/mask_target.png filter=lfs diff=lfs merge=lfs -text
assets/gradio/pos_aware/002 filter=lfs diff=lfs merge=lfs -text
assets/gradio/pos_aware/004/img_target.png filter=lfs diff=lfs merge=lfs -text
assets/gradio/pos_free/004/hypher_params.txt filter=lfs diff=lfs merge=lfs -text
assets/gradio/pos_free/004/mask_target.png filter=lfs diff=lfs merge=lfs -text
assets/gradio/pos_aware/002/img_gen.png filter=lfs diff=lfs merge=lfs -text
assets/gradio/pos_aware/002/img_target.png filter=lfs diff=lfs merge=lfs -text
.gitignore
ADDED
@@ -0,0 +1,177 @@
# Initially taken from GitHub's Python gitignore file

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# tests and logs
tests/fixtures/cached_*_text.txt
logs/
lightning_logs/
lang_code_data/

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a Python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# vscode
.vs
.vscode

# Pycharm
.idea

# TF code
tensorflow_code

# Models
proc_data
/models

# examples
runs
/runs_old
/wandb
/examples/runs
/examples/**/*.args
/examples/rag/sweep

# data
data
/data
serialization_dir

# emacs
*.*~
debug.env

# vim
.*.swp

# ctags
tags

# pre-commit
.pre-commit*

# .lock
*.lock

# DS_Store (MacOS)
.DS_Store

# RL pipelines may produce mp4 outputs
*.mp4

# dependencies
/transformers

# ruff
.ruff_cache
README.md
ADDED
@@ -0,0 +1,14 @@
---
title: IC Custom
emoji: 🎨
colorFrom: purple
colorTo: blue
sdk: gradio
sdk_version: 5.43.1
app_file: app.py
pinned: false
license: other
short_description: IC-Custom is designed for diverse image customization.
---

Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py
ADDED
@@ -0,0 +1,414 @@
#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""
IC-Custom Gradio Application

This module defines the UI and glue logic to run the IC-Custom pipeline
via Gradio. The code aims to keep UI text user-friendly while keeping the
implementation readable and maintainable.
"""
import os
import sys
import numpy as np
import torch
import gradio as gr
import spaces

from PIL import Image
import time

# Add current directory to path for imports
sys.path.append(os.getcwd() + '/app')

# Import modular components
from config import parse_args, load_config, setup_environment
from ui_components import (
    create_theme, create_css, create_header_section, create_customization_section,
    create_image_input_section, create_prompt_section, create_advanced_options_section,
    create_mask_operation_section, create_output_section, create_examples_section,
    create_citation_section
)
from event_handlers import setup_event_handlers
from business_logic import (
    init_image_target_1, init_image_target_2, init_image_reference,
    undo_seg_points, segmentation, get_point, get_brush,
    dilate_mask, erode_mask, bounding_box,
    change_input_mask_mode, change_custmization_mode, change_seg_ref_mode,
    vlm_auto_generate, vlm_auto_polish, save_results, set_mobile_predictor,
    set_ben2_model, set_vlm_processor, set_vlm_model,
)

# Import other dependencies
from utils import (
    get_sam_predictor, get_vlm, get_ben2_model,
    prepare_input_images, get_mask_type_ids
)
from examples import GRADIO_EXAMPLES, MASK_TGT, IMG_GEN
from ic_custom.pipelines.ic_custom_pipeline import ICCustomPipeline

# Global variables for pipeline and assets cache directory
PIPELINE = None
ASSETS_CACHE_DIR = None

# Force Hugging Face to re-download models and clear cache
os.environ["HF_HUB_FORCE_DOWNLOAD"] = "1"
os.environ["HF_HUB_DISABLE_TELEMETRY"] = "1"
os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf_cache"  # Use temp directory for Spaces
os.environ["HF_HOME"] = "/tmp/hf_home"  # Use temp directory for Spaces


os.environ["GRADIO_TEMP_DIR"] = os.path.abspath(os.path.join(os.path.dirname(__file__), "gradio_cache"))


def set_pipeline(pipeline):
    """Inject pipeline into this module without changing function signatures."""
    global PIPELINE
    PIPELINE = pipeline

def set_assets_cache_dir(assets_cache_dir):
    """Inject assets cache dir into this module without changing function signatures."""
    global ASSETS_CACHE_DIR
    ASSETS_CACHE_DIR = assets_cache_dir


def initialize_models(args, cfg, device, weight_dtype):
    """Initialize all required models."""
    # Load IC-Custom pipeline
    pipeline = ICCustomPipeline(
        clip_path=cfg.checkpoint_config.clip_path if os.path.exists(cfg.checkpoint_config.clip_path) else "clip-vit-large-patch14",
        t5_path=cfg.checkpoint_config.t5_path if os.path.exists(cfg.checkpoint_config.t5_path) else "t5-v1_1-xxl",
        siglip_path=cfg.checkpoint_config.siglip_path if os.path.exists(cfg.checkpoint_config.siglip_path) else "siglip-so400m-patch14-384",
        ae_path=cfg.checkpoint_config.ae_path if os.path.exists(cfg.checkpoint_config.ae_path) else "flux-fill-dev-ae",
        dit_path=cfg.checkpoint_config.dit_path if os.path.exists(cfg.checkpoint_config.dit_path) else "flux-fill-dev-dit",
        redux_path=cfg.checkpoint_config.redux_path if os.path.exists(cfg.checkpoint_config.redux_path) else "flux1-redux-dev",
        lora_path=cfg.checkpoint_config.lora_path if os.path.exists(cfg.checkpoint_config.lora_path) else "dit_lora_0x1561",
        img_txt_in_path=cfg.checkpoint_config.img_txt_in_path if os.path.exists(cfg.checkpoint_config.img_txt_in_path) else "dit_txt_img_in_0x1561",
        boundary_embeddings_path=cfg.checkpoint_config.boundary_embeddings_path if os.path.exists(cfg.checkpoint_config.boundary_embeddings_path) else "dit_boundary_embeddings_0x1561",
        task_register_embeddings_path=cfg.checkpoint_config.task_register_embeddings_path if os.path.exists(cfg.checkpoint_config.task_register_embeddings_path) else "dit_task_register_embeddings_0x1561",
        network_alpha=cfg.model_config.network_alpha,
        double_blocks_idx=cfg.model_config.double_blocks,
        single_blocks_idx=cfg.model_config.single_blocks,
        device=device,
        weight_dtype=weight_dtype,
        offload=True,
    )
    pipeline.set_pipeline_offload(True)
    # pipeline.set_show_progress(True)

    # Load SAM predictor
    mobile_predictor = get_sam_predictor(cfg.checkpoint_config.sam_path, device)

    # Load VLM if enabled
    vlm_processor, vlm_model = None, None
    if args.enable_vlm_for_prompt:
        vlm_processor, vlm_model = get_vlm(
            cfg.checkpoint_config.vlm_path,
            device=device,
            torch_dtype=weight_dtype,
        )

    # Load BEN2 model if enabled
    ben2_model = None
    if args.enable_ben2_for_mask_ref:
        ben2_model = get_ben2_model(cfg.checkpoint_config.ben2_path, device)

    return pipeline, mobile_predictor, vlm_processor, vlm_model, ben2_model

@spaces.GPU(duration=140)
def run_model(
    image_target_state, mask_target_state, image_reference_ori_state,
    image_reference_rmbg_state, prompt, seed, guidance, true_gs, num_steps,
    num_images_per_prompt, use_background_preservation, background_blend_threshold,
    aspect_ratio, custmization_mode, seg_ref_mode, input_mask_mode,
    progress=gr.Progress()
):
    """Run IC-Custom pipeline with current UI state and return images."""
    start_ts = time.time()
    progress(0, desc="Starting generation...")
    # Select reference image and check inputs
    if seg_ref_mode == "Masked Ref":
        image_reference_state = image_reference_rmbg_state
    else:
        image_reference_state = image_reference_ori_state

    if image_reference_state is None:
        gr.Warning('Please upload the reference image')
        return None, seed, gr.update(placeholder="Last Input: " + prompt, value="")

    if image_target_state is None and custmization_mode != "Position-free":
        gr.Warning('Please upload the target image and mask it')
        return None, seed, gr.update(placeholder="Last Input: " + prompt, value="")

    if custmization_mode == "Position-aware" and mask_target_state is None:
        gr.Warning('Please select/draw the target mask')
        return None, seed, gr.update(placeholder=prompt, value="")


    mask_type_ids = get_mask_type_ids(custmization_mode, input_mask_mode)

    from constants import ASPECT_RATIO_TEMPLATE
    output_w, output_h = ASPECT_RATIO_TEMPLATE[aspect_ratio]
    image_reference, image_target, mask_target = prepare_input_images(
        image_reference_state, custmization_mode, image_target_state, mask_target_state,
        width=output_w, height=output_h,
        force_resize_long_edge="long edge" in aspect_ratio,
        return_type="pil"
    )

    gr.Info(f"Output WH resolution: {image_target.size[0]}px x {image_target.size[1]}px")
    # Run the model
    if seed == -1:
        seed = torch.randint(0, 2147483647, (1,)).item()

    width, height = image_target.size[0] + image_reference.size[0], image_target.size[1]


    with torch.no_grad():
        output_img = PIPELINE(
            prompt=prompt, width=width, height=height, guidance=guidance,
            num_steps=num_steps, seed=seed, img_ref=image_reference,
            img_target=image_target, mask_target=mask_target, img_ip=image_reference,
            cond_w_regions=[image_reference.size[0]], mask_type_ids=mask_type_ids,
            use_background_preservation=use_background_preservation,
            background_blend_threshold=background_blend_threshold, true_gs=true_gs,
            neg_prompt="worst quality, normal quality, low quality, low res, blurry,",
            num_images_per_prompt=num_images_per_prompt,
            gradio_progress=progress,
        )


    elapsed = time.time() - start_ts
    progress(1.0, desc=f"Completed in {elapsed:.2f}s!")
    gr.Info(f"Finished in {elapsed:.2f}s")

    return output_img, -1, gr.update(placeholder=f"Last Input ({elapsed:.2f}s): " + prompt, value="")


def example_pipeline(
    image_reference, image_target_1, image_target_2, custmization_mode,
    input_mask_mode, seg_ref_mode, prompt, seed, true_gs, eg_idx,
    num_steps, guidance
):
    """Handle example loading in the UI."""

    if seg_ref_mode == "Full Ref":
        image_reference_ori_state = np.array(image_reference.convert("RGB"))
        image_reference_rmbg_state = None
        image_reference_state = image_reference_ori_state
    else:
        image_reference_rmbg_state = np.array(image_reference.convert("RGB"))
        image_reference_ori_state = None
        image_reference_state = image_reference_rmbg_state

    if custmization_mode == "Position-aware":
        if input_mask_mode == "Precise mask":
            image_target_state = np.array(image_target_1.convert("RGB"))
        else:
            image_target_state = np.array(image_target_2['composite'].convert("RGB"))
        mask_target_state = np.array(Image.open(MASK_TGT[int(eg_idx)]))
    else:  # Position-free mode
        # For Position-free, use the target image from IMG_TGT1 and corresponding mask
        image_target_state = np.array(image_target_1.convert("RGB"))
        mask_target_state = np.array(Image.open(MASK_TGT[int(eg_idx)]))

    mask_target_binary = mask_target_state / 255
    masked_img = image_target_state * mask_target_binary
    masked_img_pil = Image.fromarray(masked_img.astype("uint8"))
    output_mask_pil = Image.fromarray(mask_target_state.astype("uint8"))

    if custmization_mode == "Position-aware":
        mask_gallery = [masked_img_pil, output_mask_pil]
    else:
        mask_gallery = gr.skip()

    result_gallery = [Image.open(IMG_GEN[int(eg_idx)]).convert("RGB")]

    if custmization_mode == "Position-free":
        return (image_reference_ori_state, image_reference_rmbg_state, image_target_state,
                mask_target_state, mask_gallery, result_gallery,
                gr.update(visible=False), gr.update(visible=False))

    if input_mask_mode == "Precise mask":
        return (image_reference_ori_state, image_reference_rmbg_state, image_target_state,
                mask_target_state, mask_gallery, result_gallery,
                gr.update(visible=True), gr.update(visible=False))
    else:
        # Ensure ImageEditor has a proper background so brush + undo work
        try:
            bg_img = image_target_2.get('background') or image_target_2.get('composite')
        except Exception:
            bg_img = image_target_2

        return (
            image_reference_ori_state, image_reference_rmbg_state, image_target_state,
            mask_target_state, mask_gallery, result_gallery,
            gr.update(visible=False),
            gr.update(visible=True, value={"background": bg_img, "layers": [], "composite": bg_img}),
        )


def create_application():
    """Create the main Gradio application."""
    # Create theme and CSS
    theme = create_theme()
    css = create_css()

    with gr.Blocks(theme=theme, css=css) as demo:

        with gr.Column(elem_id="global_glass_container"):

            # Create UI sections
            create_header_section()

            # Hidden components
            eg_idx = gr.Textbox(label="eg_idx", visible=False, value="-1")

            # State variables
            image_target_state = gr.State(value=None)
            mask_target_state = gr.State(value=None)
            image_reference_ori_state = gr.State(value=None)
            image_reference_rmbg_state = gr.State(value=None)
            selected_points = gr.State(value=[])


            # Main UI content with optimized left-right layout
            with gr.Column(elem_id="glass_card"):
                # Top section - Mode selection (full width)
                custmization_mode, md_custmization_mode = create_customization_section()

                # Main layout: Left for inputs, Right for outputs
                with gr.Row(equal_height=False):
                    # LEFT COLUMN - ALL INPUTS
                    with gr.Column(scale=3, min_width=400):
                        # Image input section
                        (image_reference, input_mask_mode, image_target_1, image_target_2,
                         undo_target_seg_button, md_image_reference, md_input_mask_mode,
                         md_target_image) = create_image_input_section()

                        # Text prompt section
                        prompt, vlm_generate_btn, vlm_polish_btn, md_prompt = create_prompt_section()

                        # Advanced options (collapsible)
                        (aspect_ratio, seg_ref_mode, move_to_center, use_background_preservation,
                         background_blend_threshold, seed, num_images_per_prompt, guidance,
                         num_steps, true_gs) = create_advanced_options_section()

                    # RIGHT COLUMN - ALL OUTPUTS
                    with gr.Column(scale=2, min_width=350):
                        # Mask preview and operations
                        (mask_gallery, dilate_button, erode_button, bounding_box_button,
                         md_mask_operation) = create_mask_operation_section()

                        # Generation controls and results
                        result_gallery, submit_button, clear_btn, md_submit = create_output_section()

            with gr.Row(elem_id="glass_card"):
                # Examples section
                examples = create_examples_section(
                    GRADIO_EXAMPLES,
                    inputs=[
                        image_reference,
                        image_target_1,
                        image_target_2,
                        custmization_mode,
                        input_mask_mode,
                        seg_ref_mode,
                        prompt,
                        seed,
                        true_gs,
                        eg_idx,
                        num_steps,
                        guidance
                    ],
                    outputs=[
                        image_reference_ori_state,
                        image_reference_rmbg_state,
                        image_target_state,
                        mask_target_state,
                        mask_gallery,
                        result_gallery,
                        image_target_1,
                        image_target_2,
                    ],
                    fn=example_pipeline,
                )

            with gr.Row(elem_id="glass_card"):
                # Citation section
                create_citation_section()

            # Setup event handlers
            setup_event_handlers(
                ## UI components
                input_mask_mode, image_target_1, image_target_2, undo_target_seg_button,
                custmization_mode, dilate_button, erode_button, bounding_box_button,
                mask_gallery, md_input_mask_mode, md_target_image, md_mask_operation,
                md_prompt, md_submit, result_gallery, image_target_state, mask_target_state,
                seg_ref_mode, image_reference_ori_state, move_to_center,
                image_reference, image_reference_rmbg_state,
                ## Functions
                change_input_mask_mode, change_custmization_mode,
                change_seg_ref_mode,
                init_image_target_1, init_image_target_2, init_image_reference,
                get_point, undo_seg_points,
                get_brush,
                # VLM buttons
                vlm_generate_btn, vlm_polish_btn,
                # VLM functions
                vlm_auto_generate,
                vlm_auto_polish,
                dilate_mask, erode_mask, bounding_box,
                run_model,
                ## Other components
                selected_points, prompt,
                use_background_preservation, background_blend_threshold, seed,
                num_images_per_prompt, guidance, true_gs, num_steps, aspect_ratio,
                submit_button,
                eg_idx,
            )

            # Setup clear button
            clear_btn.add(
                [image_reference, image_target_1, image_target_2, mask_gallery, result_gallery,
                 selected_points, image_target_state, mask_target_state, prompt,
                 image_reference_ori_state, image_reference_rmbg_state]
            )

    return demo


def main():
    """Main entry point for the application."""
    # Parse arguments and load config
    args = parse_args()
    cfg = load_config(args.config)
    setup_environment(args)

    # Initialize device and models
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    weight_dtype = torch.bfloat16

    pipeline, mobile_predictor, vlm_processor, vlm_model, ben2_model = initialize_models(
        args, cfg, device, weight_dtype
    )

    set_pipeline(pipeline)
    set_assets_cache_dir(args.assets_cache_dir)

    # Inject mobile predictor into business logic module so get_point can access it without lambdas
    set_mobile_predictor(mobile_predictor)
    set_ben2_model(ben2_model)
    set_vlm_processor(vlm_processor)
    set_vlm_model(vlm_model)

    # Create and launch the application
    demo = create_application()

    # Launch the demo
    demo.launch(server_port=7860, server_name="0.0.0.0",
                allowed_paths=[os.path.abspath(os.path.join(os.path.dirname(__file__), "gradio_cache")),
                               os.path.abspath(os.path.join(os.path.dirname(__file__), "results"))])


if __name__ == "__main__":
    main()
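app.py above hands heavy objects (pipeline, SAM predictor, VLM, BEN2) to its modules through setter functions such as `set_pipeline` and `set_mobile_predictor`, rather than threading them through every Gradio callback. A minimal sketch of that module-level injection pattern, with illustrative names (`generate` is not part of the repository):

```python
# Sketch of the injection pattern used in app.py; `generate` is illustrative only.
PIPELINE = None

def set_pipeline(pipeline):
    """Store the heavy pipeline once at startup so callbacks can reuse it."""
    global PIPELINE
    PIPELINE = pipeline

def generate(prompt):
    # Callbacks read the injected global instead of receiving the pipeline as an
    # argument, which keeps Gradio handler signatures limited to UI state.
    if PIPELINE is None:
        raise RuntimeError("Pipeline not initialized; call set_pipeline() first.")
    return PIPELINE(prompt=prompt)
```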
app/APP.md
ADDED
@@ -0,0 +1,103 @@
# IC-Custom Application

A sophisticated image customization tool powered by advanced AI models.

> 📺 **App Guide:**
> For a fast overview of how to use the app, watch this video:
> [IC-Custom App Usage Guide (YouTube)](https://www.youtube.com/watch?v=uaiZA3H5RV)

---

## 🚀 Quick Start

```bash
python src/app/app.py \
    --config configs/app/app.yaml \
    --hf_token $HF_TOKEN \
    --hf_cache_dir $HF_CACHE_DIR \
    --assets_cache_dir results/app \
    --enable_ben2_for_mask_ref False \
    --enable_vlm_for_prompt False \
    --save_results True
```

---

## ⚙️ Configuration & CLI Arguments

| Argument | Type | Required | Default | Description |
|----------|------|----------|---------|-------------|
| `--config` | str | ✅ | - | Path to app YAML config file |
| `--hf_token` | str | ❌ | - | Hugging Face access token |
| `--hf_cache_dir` | str | ❌ | `~/.cache/huggingface/hub` | HF assets cache directory |
| `--assets_cache_dir` | str | ❌ | `results/app` | Output images & metadata directory |
| `--save_results` | bool | ❌ | `False` | Save generated results |
| `--enable_ben2_for_mask_ref` | bool | ❌ | `False` | Enable BEN2 background removal |
| `--enable_vlm_for_prompt` | bool | ❌ | `False` | Enable VLM prompt generation |

### Environment Variables

- `HF_TOKEN` ← `--hf_token`
- `HF_HUB_CACHE` ← `--hf_cache_dir`

---

## 📥 Model Downloads

> **Model checkpoints are required before running the app.**
> All required models will be automatically downloaded when you run the app, or you can manually download them and specify paths in `configs/app/app.yaml`.

### Required Models

The following models are **automatically downloaded** when running the app:

| Model | Purpose | Source |
|-------|---------|--------|
| **IC-Custom** | Our customization model | [TencentARC/IC-Custom](https://huggingface.co/TencentARC/IC-Custom) |
| **CLIP** | Vision-language understanding | [openai/clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) |
| **T5** | Text processing | [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) |
| **SigLIP** | Image understanding | [google/siglip-so400m-patch14-384](https://huggingface.co/google/siglip-so400m-patch14-384) |
| **Autoencoder** | Image encoding/decoding | [black-forest-labs/FLUX.1-Fill-dev](https://huggingface.co/black-forest-labs/FLUX.1-Fill-dev/blob/main/ae.safetensors) |
| **DIT** | Diffusion model | [black-forest-labs/FLUX.1-Fill-dev](https://huggingface.co/black-forest-labs/FLUX.1-Fill-dev/blob/main/flux1-fill-dev.safetensors) |
| **Redux** | Image processing | [black-forest-labs/FLUX.1-Redux-dev](https://huggingface.co/black-forest-labs/FLUX.1-Redux-dev) |
| **SAM-vit-h** | Image segmentation | [HCMUE-Research/SAM-vit-h](https://huggingface.co/HCMUE-Research/SAM-vit-h/blob/main/sam_vit_h_4b8939.pth) |

### Optional Models (Selective Download)

**BEN2 and Qwen2.5-VL models are disabled by default** and only downloaded when explicitly enabled:

| Model | Flag | Source | Purpose |
|-------|------|--------|---------|
| **BEN2** | `--enable_ben2_for_mask_ref True` | [PramaLLC/BEN2](https://huggingface.co/PramaLLC/BEN2/blob/main/BEN2_Base.pth) | Background removal |
| **Qwen2.5-VL** | `--enable_vlm_for_prompt True` | [Qwen/Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct) | Prompt generation |

### Manual Configuration

**Alternative**: Manually download all models and specify paths in `configs/app/app.yaml`:

```yaml
checkpoint_config:
  # Required models
  dit_path: "/path/to/flux1-fill-dev.safetensors"
  ae_path: "/path/to/ae.safetensors"
  t5_path: "/path/to/t5-v1_1-xxl"
  clip_path: "/path/to/clip-vit-large-patch14"
  siglip_path: "/path/to/siglip-so400m-patch14-384"
  redux_path: "/path/to/flux1-redux-dev.safetensors"
  # IC-Custom models
  lora_path: "/path/to/dit_lora_0x1561.safetensors"
  img_txt_in_path: "/path/to/dit_txt_img_in_0x1561.safetensors"
  boundary_embeddings_path: "/path/to/dit_boundary_embeddings_0x1561.safetensors"
  task_register_embeddings_path: "/path/to/dit_task_register_embeddings_0x1561.safetensors"
  # APP interactive models
  sam_path: "/path/to/sam_vit_h_4b8939.pth"
  # Optional models
  ben2_path: "/path/to/BEN2_Base.pth"
  vlm_path: "/path/to/Qwen2.5-VL-7B-Instruct"
```

### APP Overview

<p align="center">
  <img src="../../assets/gradio_ui.png" alt="IC-Custom APP" width="80%">
</p>
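APP.md above documents that `--hf_token` and `--hf_cache_dir` are exported as the environment variables `HF_TOKEN` and `HF_HUB_CACHE`. A minimal sketch of how `setup_environment` (defined in app/config.py, which is not shown in this diff) could perform that mapping; the argument names follow the CLI table, everything else is an assumption:

```python
import os

def setup_environment(args):
    # Mirror the mapping stated in APP.md: HF_TOKEN <- --hf_token,
    # HF_HUB_CACHE <- --hf_cache_dir. Only set variables that were provided.
    if getattr(args, "hf_token", None):
        os.environ["HF_TOKEN"] = args.hf_token
    if getattr(args, "hf_cache_dir", None):
        os.environ["HF_HUB_CACHE"] = args.hf_cache_dir
```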
app/BEN2.py
ADDED
@@ -0,0 +1,1394 @@
1 |
+
# Copyright (c) 2025 Prama LLC
|
2 |
+
# SPDX-License-Identifier: MIT
|
3 |
+
|
4 |
+
import math
|
5 |
+
import os
|
6 |
+
import random
|
7 |
+
import subprocess
|
8 |
+
import tempfile
|
9 |
+
import time
|
10 |
+
|
11 |
+
import cv2
|
12 |
+
import numpy as np
|
13 |
+
import torch
|
14 |
+
import torch.nn as nn
|
15 |
+
import torch.nn.functional as F
|
16 |
+
import torch.utils.checkpoint as checkpoint
|
17 |
+
from einops import rearrange
|
18 |
+
from PIL import Image, ImageOps
|
19 |
+
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
|
20 |
+
from torchvision import transforms
|
21 |
+
|
22 |
+
|
23 |
+
def set_random_seed(seed):
|
24 |
+
random.seed(seed)
|
25 |
+
np.random.seed(seed)
|
26 |
+
torch.manual_seed(seed)
|
27 |
+
torch.cuda.manual_seed(seed)
|
28 |
+
torch.cuda.manual_seed_all(seed)
|
29 |
+
torch.backends.cudnn.deterministic = True
|
30 |
+
torch.backends.cudnn.benchmark = False
|
31 |
+
|
32 |
+
|
33 |
+
# set_random_seed(9)
|
34 |
+
|
35 |
+
torch.set_float32_matmul_precision('highest')
|
36 |
+
|
37 |
+
|
38 |
+
class Mlp(nn.Module):
|
39 |
+
""" Multilayer perceptron."""
|
40 |
+
|
41 |
+
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
|
42 |
+
super().__init__()
|
43 |
+
out_features = out_features or in_features
|
44 |
+
hidden_features = hidden_features or in_features
|
45 |
+
self.fc1 = nn.Linear(in_features, hidden_features)
|
46 |
+
self.act = act_layer()
|
47 |
+
self.fc2 = nn.Linear(hidden_features, out_features)
|
48 |
+
self.drop = nn.Dropout(drop)
|
49 |
+
|
50 |
+
def forward(self, x):
|
51 |
+
x = self.fc1(x)
|
52 |
+
x = self.act(x)
|
53 |
+
x = self.drop(x)
|
54 |
+
x = self.fc2(x)
|
55 |
+
x = self.drop(x)
|
56 |
+
return x
|
57 |
+
|
58 |
+
|
59 |
+
def window_partition(x, window_size):
|
60 |
+
"""
|
61 |
+
Args:
|
62 |
+
x: (B, H, W, C)
|
63 |
+
window_size (int): window size
|
64 |
+
Returns:
|
65 |
+
windows: (num_windows*B, window_size, window_size, C)
|
66 |
+
"""
|
67 |
+
B, H, W, C = x.shape
|
68 |
+
x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
|
69 |
+
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
|
70 |
+
return windows
|
71 |
+
|
72 |
+
|
73 |
+
def window_reverse(windows, window_size, H, W):
|
74 |
+
"""
|
75 |
+
Args:
|
76 |
+
windows: (num_windows*B, window_size, window_size, C)
|
77 |
+
window_size (int): Window size
|
78 |
+
H (int): Height of image
|
79 |
+
W (int): Width of image
|
80 |
+
Returns:
|
81 |
+
x: (B, H, W, C)
|
82 |
+
"""
|
83 |
+
B = int(windows.shape[0] / (H * W / window_size / window_size))
|
84 |
+
x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
|
85 |
+
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
|
86 |
+
return x
|
87 |
+
|
88 |
+
|
89 |
+
class WindowAttention(nn.Module):
|
90 |
+
""" Window based multi-head self attention (W-MSA) module with relative position bias.
|
91 |
+
It supports both of shifted and non-shifted window.
|
92 |
+
Args:
|
93 |
+
dim (int): Number of input channels.
|
94 |
+
window_size (tuple[int]): The height and width of the window.
|
95 |
+
num_heads (int): Number of attention heads.
|
96 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
97 |
+
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
|
98 |
+
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
|
99 |
+
proj_drop (float, optional): Dropout ratio of output. Default: 0.0
|
100 |
+
"""
|
101 |
+
|
102 |
+
def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
|
103 |
+
|
104 |
+
super().__init__()
|
105 |
+
self.dim = dim
|
106 |
+
self.window_size = window_size # Wh, Ww
|
107 |
+
self.num_heads = num_heads
|
108 |
+
head_dim = dim // num_heads
|
109 |
+
self.scale = qk_scale or head_dim ** -0.5
|
110 |
+
|
111 |
+
# define a parameter table of relative position bias
|
112 |
+
self.relative_position_bias_table = nn.Parameter(
|
113 |
+
torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
|
114 |
+
|
115 |
+
# get pair-wise relative position index for each token inside the window
|
116 |
+
coords_h = torch.arange(self.window_size[0])
|
117 |
+
coords_w = torch.arange(self.window_size[1])
|
118 |
+
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
|
119 |
+
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
|
120 |
+
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
|
121 |
+
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
|
122 |
+
relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
|
123 |
+
relative_coords[:, :, 1] += self.window_size[1] - 1
|
124 |
+
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
|
125 |
+
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
|
126 |
+
self.register_buffer("relative_position_index", relative_position_index)
|
127 |
+
|
128 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
129 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
130 |
+
self.proj = nn.Linear(dim, dim)
|
131 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
132 |
+
|
133 |
+
trunc_normal_(self.relative_position_bias_table, std=.02)
|
134 |
+
self.softmax = nn.Softmax(dim=-1)
|
135 |
+
|
136 |
+
def forward(self, x, mask=None):
|
137 |
+
""" Forward function.
|
138 |
+
Args:
|
139 |
+
x: input features with shape of (num_windows*B, N, C)
|
140 |
+
mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
|
141 |
+
"""
|
142 |
+
B_, N, C = x.shape
|
143 |
+
qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
144 |
+
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
|
145 |
+
|
146 |
+
q = q * self.scale
|
147 |
+
attn = (q @ k.transpose(-2, -1))
|
148 |
+
|
149 |
+
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
|
150 |
+
self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
|
151 |
+
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
|
152 |
+
attn = attn + relative_position_bias.unsqueeze(0)
|
153 |
+
|
154 |
+
if mask is not None:
|
155 |
+
nW = mask.shape[0]
|
156 |
+
attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
|
157 |
+
attn = attn.view(-1, self.num_heads, N, N)
|
158 |
+
attn = self.softmax(attn)
|
159 |
+
else:
|
160 |
+
attn = self.softmax(attn)
|
161 |
+
|
162 |
+
attn = self.attn_drop(attn)
|
163 |
+
|
164 |
+
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
|
165 |
+
x = self.proj(x)
|
166 |
+
x = self.proj_drop(x)
|
167 |
+
return x
|
168 |
+
|
169 |
+
|
170 |
+
class SwinTransformerBlock(nn.Module):
|
171 |
+
""" Swin Transformer Block.
|
172 |
+
Args:
|
173 |
+
dim (int): Number of input channels.
|
174 |
+
num_heads (int): Number of attention heads.
|
175 |
+
window_size (int): Window size.
|
176 |
+
shift_size (int): Shift size for SW-MSA.
|
177 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
178 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
179 |
+
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
|
180 |
+
drop (float, optional): Dropout rate. Default: 0.0
|
181 |
+
attn_drop (float, optional): Attention dropout rate. Default: 0.0
|
182 |
+
drop_path (float, optional): Stochastic depth rate. Default: 0.0
|
183 |
+
act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
|
184 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
185 |
+
"""
|
186 |
+
|
187 |
+
def __init__(self, dim, num_heads, window_size=7, shift_size=0,
|
188 |
+
mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
|
189 |
+
act_layer=nn.GELU, norm_layer=nn.LayerNorm):
|
190 |
+
super().__init__()
|
191 |
+
self.dim = dim
|
192 |
+
self.num_heads = num_heads
|
193 |
+
self.window_size = window_size
|
194 |
+
self.shift_size = shift_size
|
195 |
+
self.mlp_ratio = mlp_ratio
|
196 |
+
assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
|
197 |
+
|
198 |
+
self.norm1 = norm_layer(dim)
|
199 |
+
self.attn = WindowAttention(
|
200 |
+
dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
|
201 |
+
qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
|
202 |
+
|
203 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
204 |
+
self.norm2 = norm_layer(dim)
|
205 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
206 |
+
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
|
207 |
+
|
208 |
+
self.H = None
|
209 |
+
self.W = None
|
210 |
+
|
211 |
+
def forward(self, x, mask_matrix):
|
212 |
+
""" Forward function.
|
213 |
+
Args:
|
214 |
+
x: Input feature, tensor size (B, H*W, C).
|
215 |
+
H, W: Spatial resolution of the input feature.
|
216 |
+
mask_matrix: Attention mask for cyclic shift.
|
217 |
+
"""
|
218 |
+
B, L, C = x.shape
|
219 |
+
H, W = self.H, self.W
|
220 |
+
assert L == H * W, "input feature has wrong size"
|
221 |
+
|
222 |
+
shortcut = x
|
223 |
+
x = self.norm1(x)
|
224 |
+
x = x.view(B, H, W, C)
|
225 |
+
|
226 |
+
# pad feature maps to multiples of window size
|
227 |
+
pad_l = pad_t = 0
|
228 |
+
pad_r = (self.window_size - W % self.window_size) % self.window_size
|
229 |
+
pad_b = (self.window_size - H % self.window_size) % self.window_size
|
230 |
+
x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
|
231 |
+
_, Hp, Wp, _ = x.shape
|
232 |
+
|
233 |
+
# cyclic shift
|
234 |
+
if self.shift_size > 0:
|
235 |
+
shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
|
236 |
+
attn_mask = mask_matrix
|
237 |
+
else:
|
238 |
+
shifted_x = x
|
239 |
+
attn_mask = None
|
240 |
+
|
241 |
+
# partition windows
|
242 |
+
x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
|
243 |
+
x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
|
244 |
+
|
245 |
+
# W-MSA/SW-MSA
|
246 |
+
attn_windows = self.attn(x_windows, mask=attn_mask) # nW*B, window_size*window_size, C
|
247 |
+
|
248 |
+
# merge windows
|
249 |
+
attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
|
250 |
+
shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp) # B H' W' C
|
251 |
+
|
252 |
+
# reverse cyclic shift
|
253 |
+
if self.shift_size > 0:
|
254 |
+
x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
|
255 |
+
else:
|
256 |
+
x = shifted_x
|
257 |
+
|
258 |
+
if pad_r > 0 or pad_b > 0:
|
259 |
+
x = x[:, :H, :W, :].contiguous()
|
260 |
+
|
261 |
+
x = x.view(B, H * W, C)
|
262 |
+
|
263 |
+
# FFN
|
264 |
+
x = shortcut + self.drop_path(x)
|
265 |
+
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
266 |
+
|
267 |
+
return x
|
268 |
+
|
269 |
+
|
270 |
+
class PatchMerging(nn.Module):
|
271 |
+
""" Patch Merging Layer
|
272 |
+
Args:
|
273 |
+
dim (int): Number of input channels.
|
274 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
275 |
+
"""
|
276 |
+
|
277 |
+
def __init__(self, dim, norm_layer=nn.LayerNorm):
|
278 |
+
super().__init__()
|
279 |
+
self.dim = dim
|
280 |
+
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
|
281 |
+
self.norm = norm_layer(4 * dim)
|
282 |
+
|
283 |
+
def forward(self, x, H, W):
|
284 |
+
""" Forward function.
|
285 |
+
Args:
|
286 |
+
x: Input feature, tensor size (B, H*W, C).
|
287 |
+
H, W: Spatial resolution of the input feature.
|
288 |
+
"""
|
289 |
+
B, L, C = x.shape
|
290 |
+
assert L == H * W, "input feature has wrong size"
|
291 |
+
|
292 |
+
x = x.view(B, H, W, C)
|
293 |
+
|
294 |
+
# padding
|
295 |
+
pad_input = (H % 2 == 1) or (W % 2 == 1)
|
296 |
+
if pad_input:
|
297 |
+
x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
|
298 |
+
|
299 |
+
x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
|
300 |
+
x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
|
301 |
+
x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
|
302 |
+
x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
|
303 |
+
x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
|
304 |
+
x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
|
305 |
+
|
306 |
+
x = self.norm(x)
|
307 |
+
x = self.reduction(x)
|
308 |
+
|
309 |
+
return x
|
310 |
+
|
311 |
+
|
312 |
+
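# A quick shape sketch for the PatchMerging layer above (sizes are arbitrary):
# merge = PatchMerging(dim=96)
# y = merge(torch.randn(2, 56 * 56, 96), 56, 56)
# # y.shape == (2, 28 * 28, 192): spatial resolution halves, channel width doubles.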
class BasicLayer(nn.Module):
|
313 |
+
""" A basic Swin Transformer layer for one stage.
|
314 |
+
Args:
|
315 |
+
dim (int): Number of feature channels
|
316 |
+
depth (int): Number of blocks in this stage.
|
317 |
+
num_heads (int): Number of attention heads.
|
318 |
+
window_size (int): Local window size. Default: 7.
|
319 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
|
320 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
321 |
+
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
|
322 |
+
drop (float, optional): Dropout rate. Default: 0.0
|
323 |
+
attn_drop (float, optional): Attention dropout rate. Default: 0.0
|
324 |
+
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
|
325 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
326 |
+
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
|
327 |
+
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
|
328 |
+
"""
|
329 |
+
|
330 |
+
def __init__(self,
|
331 |
+
dim,
|
332 |
+
depth,
|
333 |
+
num_heads,
|
334 |
+
window_size=7,
|
335 |
+
mlp_ratio=4.,
|
336 |
+
qkv_bias=True,
|
337 |
+
qk_scale=None,
|
338 |
+
drop=0.,
|
339 |
+
attn_drop=0.,
|
340 |
+
drop_path=0.,
|
341 |
+
norm_layer=nn.LayerNorm,
|
342 |
+
downsample=None,
|
343 |
+
use_checkpoint=False):
|
344 |
+
super().__init__()
|
345 |
+
self.window_size = window_size
|
346 |
+
self.shift_size = window_size // 2
|
347 |
+
self.depth = depth
|
348 |
+
self.use_checkpoint = use_checkpoint
|
349 |
+
|
350 |
+
# build blocks
|
351 |
+
self.blocks = nn.ModuleList([
|
352 |
+
SwinTransformerBlock(
|
353 |
+
dim=dim,
|
354 |
+
num_heads=num_heads,
|
355 |
+
window_size=window_size,
|
356 |
+
shift_size=0 if (i % 2 == 0) else window_size // 2,
|
357 |
+
mlp_ratio=mlp_ratio,
|
358 |
+
qkv_bias=qkv_bias,
|
359 |
+
qk_scale=qk_scale,
|
360 |
+
drop=drop,
|
361 |
+
attn_drop=attn_drop,
|
362 |
+
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
|
363 |
+
norm_layer=norm_layer)
|
364 |
+
for i in range(depth)])
|
365 |
+
|
366 |
+
# patch merging layer
|
367 |
+
if downsample is not None:
|
368 |
+
self.downsample = downsample(dim=dim, norm_layer=norm_layer)
|
369 |
+
else:
|
370 |
+
self.downsample = None
|
371 |
+
|
372 |
+
def forward(self, x, H, W):
|
373 |
+
""" Forward function.
|
374 |
+
Args:
|
375 |
+
x: Input feature, tensor size (B, H*W, C).
|
376 |
+
H, W: Spatial resolution of the input feature.
|
377 |
+
"""
|
378 |
+
|
379 |
+
# calculate attention mask for SW-MSA
|
380 |
+
Hp = int(np.ceil(H / self.window_size)) * self.window_size
|
381 |
+
Wp = int(np.ceil(W / self.window_size)) * self.window_size
|
382 |
+
img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1
|
383 |
+
h_slices = (slice(0, -self.window_size),
|
384 |
+
slice(-self.window_size, -self.shift_size),
|
385 |
+
slice(-self.shift_size, None))
|
386 |
+
w_slices = (slice(0, -self.window_size),
|
387 |
+
slice(-self.window_size, -self.shift_size),
|
388 |
+
slice(-self.shift_size, None))
|
389 |
+
cnt = 0
|
390 |
+
for h in h_slices:
|
391 |
+
for w in w_slices:
|
392 |
+
img_mask[:, h, w, :] = cnt
|
393 |
+
cnt += 1
|
394 |
+
|
395 |
+
mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
|
396 |
+
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
|
397 |
+
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
|
398 |
+
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
|
399 |
+
|
400 |
+
for blk in self.blocks:
|
401 |
+
blk.H, blk.W = H, W
|
402 |
+
if self.use_checkpoint:
|
403 |
+
x = checkpoint.checkpoint(blk, x, attn_mask)
|
404 |
+
else:
|
405 |
+
x = blk(x, attn_mask)
|
406 |
+
if self.downsample is not None:
|
407 |
+
x_down = self.downsample(x, H, W)
|
408 |
+
Wh, Ww = (H + 1) // 2, (W + 1) // 2
|
409 |
+
return x, H, W, x_down, Wh, Ww
|
410 |
+
else:
|
411 |
+
return x, H, W, x, H, W
|
412 |
+
|
413 |
+
|
414 |
+
class PatchEmbed(nn.Module):
|
415 |
+
""" Image to Patch Embedding
|
416 |
+
Args:
|
417 |
+
patch_size (int): Patch token size. Default: 4.
|
418 |
+
in_chans (int): Number of input image channels. Default: 3.
|
419 |
+
embed_dim (int): Number of linear projection output channels. Default: 96.
|
420 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: None
|
421 |
+
"""
|
422 |
+
|
423 |
+
def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
|
424 |
+
super().__init__()
|
425 |
+
patch_size = to_2tuple(patch_size)
|
426 |
+
self.patch_size = patch_size
|
427 |
+
|
428 |
+
self.in_chans = in_chans
|
429 |
+
self.embed_dim = embed_dim
|
430 |
+
|
431 |
+
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
|
432 |
+
if norm_layer is not None:
|
433 |
+
self.norm = norm_layer(embed_dim)
|
434 |
+
else:
|
435 |
+
self.norm = None
|
436 |
+
|
437 |
+
def forward(self, x):
|
438 |
+
"""Forward function."""
|
439 |
+
# padding
|
440 |
+
_, _, H, W = x.size()
|
441 |
+
if W % self.patch_size[1] != 0:
|
442 |
+
x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
|
443 |
+
if H % self.patch_size[0] != 0:
|
444 |
+
x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
|
445 |
+
|
446 |
+
x = self.proj(x) # B C Wh Ww
|
447 |
+
if self.norm is not None:
|
448 |
+
Wh, Ww = x.size(2), x.size(3)
|
449 |
+
x = x.flatten(2).transpose(1, 2)
|
450 |
+
x = self.norm(x)
|
451 |
+
x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)
|
452 |
+
|
453 |
+
return x
|
454 |
+
|
455 |
+
|
456 |
+
class SwinTransformer(nn.Module):
|
457 |
+
""" Swin Transformer backbone.
|
458 |
+
A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
|
459 |
+
https://arxiv.org/pdf/2103.14030
|
460 |
+
Args:
|
461 |
+
pretrain_img_size (int): Input image size for training the pretrained model,
|
462 |
+
used in absolute position embedding. Default: 224.
|
463 |
+
patch_size (int | tuple(int)): Patch size. Default: 4.
|
464 |
+
in_chans (int): Number of input image channels. Default: 3.
|
465 |
+
embed_dim (int): Number of linear projection output channels. Default: 96.
|
466 |
+
depths (tuple[int]): Depths of each Swin Transformer stage.
|
467 |
+
num_heads (tuple[int]): Number of attention heads for each stage.
|
468 |
+
window_size (int): Window size. Default: 7.
|
469 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
|
470 |
+
qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
|
471 |
+
qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
|
472 |
+
drop_rate (float): Dropout rate.
|
473 |
+
attn_drop_rate (float): Attention dropout rate. Default: 0.
|
474 |
+
drop_path_rate (float): Stochastic depth rate. Default: 0.2.
|
475 |
+
norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
|
476 |
+
ape (bool): If True, add absolute position embedding to the patch embedding. Default: False.
|
477 |
+
patch_norm (bool): If True, add normalization after patch embedding. Default: True.
|
478 |
+
out_indices (Sequence[int]): Output from which stages.
|
479 |
+
frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
|
480 |
+
-1 means not freezing any parameters.
|
481 |
+
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
|
482 |
+
"""
|
483 |
+
|
484 |
+
def __init__(self,
|
485 |
+
pretrain_img_size=224,
|
486 |
+
patch_size=4,
|
487 |
+
in_chans=3,
|
488 |
+
embed_dim=96,
|
489 |
+
depths=[2, 2, 6, 2],
|
490 |
+
num_heads=[3, 6, 12, 24],
|
491 |
+
window_size=7,
|
492 |
+
mlp_ratio=4.,
|
493 |
+
qkv_bias=True,
|
494 |
+
qk_scale=None,
|
495 |
+
drop_rate=0.,
|
496 |
+
attn_drop_rate=0.,
|
497 |
+
drop_path_rate=0.2,
|
498 |
+
norm_layer=nn.LayerNorm,
|
499 |
+
ape=False,
|
500 |
+
patch_norm=True,
|
501 |
+
out_indices=(0, 1, 2, 3),
|
502 |
+
frozen_stages=-1,
|
503 |
+
use_checkpoint=False):
|
504 |
+
super().__init__()
|
505 |
+
|
506 |
+
self.pretrain_img_size = pretrain_img_size
|
507 |
+
self.num_layers = len(depths)
|
508 |
+
self.embed_dim = embed_dim
|
509 |
+
self.ape = ape
|
510 |
+
self.patch_norm = patch_norm
|
511 |
+
self.out_indices = out_indices
|
512 |
+
self.frozen_stages = frozen_stages
|
513 |
+
|
514 |
+
# split image into non-overlapping patches
|
515 |
+
self.patch_embed = PatchEmbed(
|
516 |
+
patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
|
517 |
+
norm_layer=norm_layer if self.patch_norm else None)
|
518 |
+
|
519 |
+
# absolute position embedding
|
520 |
+
if self.ape:
|
521 |
+
pretrain_img_size = to_2tuple(pretrain_img_size)
|
522 |
+
patch_size = to_2tuple(patch_size)
|
523 |
+
patches_resolution = [pretrain_img_size[0] // patch_size[0], pretrain_img_size[1] // patch_size[1]]
|
524 |
+
|
525 |
+
self.absolute_pos_embed = nn.Parameter(
|
526 |
+
torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1]))
|
527 |
+
trunc_normal_(self.absolute_pos_embed, std=.02)
|
528 |
+
|
529 |
+
self.pos_drop = nn.Dropout(p=drop_rate)
|
530 |
+
|
531 |
+
# stochastic depth
|
532 |
+
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
|
533 |
+
|
534 |
+
# build layers
|
535 |
+
self.layers = nn.ModuleList()
|
536 |
+
for i_layer in range(self.num_layers):
|
537 |
+
layer = BasicLayer(
|
538 |
+
dim=int(embed_dim * 2 ** i_layer),
|
539 |
+
depth=depths[i_layer],
|
540 |
+
num_heads=num_heads[i_layer],
|
541 |
+
window_size=window_size,
|
542 |
+
mlp_ratio=mlp_ratio,
|
543 |
+
qkv_bias=qkv_bias,
|
544 |
+
qk_scale=qk_scale,
|
545 |
+
drop=drop_rate,
|
546 |
+
attn_drop=attn_drop_rate,
|
547 |
+
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
|
548 |
+
norm_layer=norm_layer,
|
549 |
+
downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
|
550 |
+
use_checkpoint=use_checkpoint)
|
551 |
+
self.layers.append(layer)
|
552 |
+
|
553 |
+
num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)]
|
554 |
+
self.num_features = num_features
|
555 |
+
|
556 |
+
# add a norm layer for each output
|
557 |
+
for i_layer in out_indices:
|
558 |
+
layer = norm_layer(num_features[i_layer])
|
559 |
+
layer_name = f'norm{i_layer}'
|
560 |
+
self.add_module(layer_name, layer)
|
561 |
+
|
562 |
+
self._freeze_stages()
|
563 |
+
|
564 |
+
def _freeze_stages(self):
|
565 |
+
if self.frozen_stages >= 0:
|
566 |
+
self.patch_embed.eval()
|
567 |
+
for param in self.patch_embed.parameters():
|
568 |
+
param.requires_grad = False
|
569 |
+
|
570 |
+
if self.frozen_stages >= 1 and self.ape:
|
571 |
+
self.absolute_pos_embed.requires_grad = False
|
572 |
+
|
573 |
+
if self.frozen_stages >= 2:
|
574 |
+
self.pos_drop.eval()
|
575 |
+
for i in range(0, self.frozen_stages - 1):
|
576 |
+
m = self.layers[i]
|
577 |
+
m.eval()
|
578 |
+
for param in m.parameters():
|
579 |
+
param.requires_grad = False
|
580 |
+
|
581 |
+
def forward(self, x):
|
582 |
+
|
583 |
+
x = self.patch_embed(x)
|
584 |
+
|
585 |
+
Wh, Ww = x.size(2), x.size(3)
|
586 |
+
if self.ape:
|
587 |
+
# interpolate the position embedding to the corresponding size
|
588 |
+
absolute_pos_embed = F.interpolate(self.absolute_pos_embed, size=(Wh, Ww), mode='bicubic')
|
589 |
+
x = (x + absolute_pos_embed) # B Wh*Ww C
|
590 |
+
|
591 |
+
outs = [x.contiguous()]
|
592 |
+
x = x.flatten(2).transpose(1, 2)
|
593 |
+
x = self.pos_drop(x)
|
594 |
+
|
595 |
+
for i in range(self.num_layers):
|
596 |
+
layer = self.layers[i]
|
597 |
+
x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)
|
598 |
+
|
599 |
+
if i in self.out_indices:
|
600 |
+
norm_layer = getattr(self, f'norm{i}')
|
601 |
+
x_out = norm_layer(x_out)
|
602 |
+
|
603 |
+
out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous()
|
604 |
+
outs.append(out)
|
605 |
+
|
606 |
+
return tuple(outs)
|
607 |
+
|
608 |
+
|
609 |
+
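# Minimal backbone sketch (batch size and 224x224 input are arbitrary choices; the
# shapes follow from the defaults above):
# backbone = SwinTransformer(embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24])
# feats = backbone(torch.randn(1, 3, 224, 224))
# # feats[0] is the raw patch embedding (1, 96, 56, 56); feats[1:] are the four stage
# # outputs with 96/192/384/768 channels at strides 4/8/16/32.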
def get_activation_fn(activation):
|
610 |
+
"""Return an activation function given a string"""
|
611 |
+
if activation == "gelu":
|
612 |
+
return F.gelu
|
613 |
+
|
614 |
+
raise RuntimeError(F"activation should be gelu, not {activation}.")
|
615 |
+
|
616 |
+
|
617 |
+
def make_cbr(in_dim, out_dim):
|
618 |
+
return nn.Sequential(nn.Conv2d(in_dim, out_dim, kernel_size=3, padding=1), nn.InstanceNorm2d(out_dim), nn.GELU())
|
619 |
+
|
620 |
+
|
621 |
+
def make_cbg(in_dim, out_dim):
|
622 |
+
return nn.Sequential(nn.Conv2d(in_dim, out_dim, kernel_size=3, padding=1), nn.InstanceNorm2d(out_dim), nn.GELU())
|
623 |
+
|
624 |
+
|
625 |
+
def rescale_to(x, scale_factor: float = 2, interpolation='nearest'):
|
626 |
+
return F.interpolate(x, scale_factor=scale_factor, mode=interpolation)
|
627 |
+
|
628 |
+
|
629 |
+
def resize_as(x, y, interpolation='bilinear'):
|
630 |
+
return F.interpolate(x, size=y.shape[-2:], mode=interpolation)
|
631 |
+
|
632 |
+
|
633 |
+
def image2patches(x):
|
634 |
+
"""b c (hg h) (wg w) -> (hg wg b) c h w"""
|
635 |
+
x = rearrange(x, 'b c (hg h) (wg w) -> (hg wg b) c h w', hg=2, wg=2)
|
636 |
+
return x
|
637 |
+
|
638 |
+
|
639 |
+
def patches2image(x):
|
640 |
+
"""(hg wg b) c h w -> b c (hg h) (wg w)"""
|
641 |
+
x = rearrange(x, '(hg wg b) c h w -> b c (hg h) (wg w)', hg=2, wg=2)
|
642 |
+
return x
|
643 |
+
|
644 |
+
|
645 |
+
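# Round-trip sketch for the two helpers above (values are arbitrary):
# x = torch.randn(1, 3, 64, 64)
# patches = image2patches(x)          # (4, 3, 32, 32): the four quadrants of x
# restored = patches2image(patches)   # (1, 3, 64, 64): identical to x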
class PositionEmbeddingSine:
|
646 |
+
def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
|
647 |
+
super().__init__()
|
648 |
+
self.num_pos_feats = num_pos_feats
|
649 |
+
self.temperature = temperature
|
650 |
+
self.normalize = normalize
|
651 |
+
if scale is not None and normalize is False:
|
652 |
+
raise ValueError("normalize should be True if scale is passed")
|
653 |
+
if scale is None:
|
654 |
+
scale = 2 * math.pi
|
655 |
+
self.scale = scale
|
656 |
+
self.dim_t = torch.arange(0, self.num_pos_feats, dtype=torch.float32)
|
657 |
+
|
658 |
+
def __call__(self, b, h, w):
|
659 |
+
device = self.dim_t.device
|
660 |
+
mask = torch.zeros([b, h, w], dtype=torch.bool, device=device)
|
661 |
+
assert mask is not None
|
662 |
+
not_mask = ~mask
|
663 |
+
y_embed = not_mask.cumsum(dim=1, dtype=torch.float32)
|
664 |
+
x_embed = not_mask.cumsum(dim=2, dtype=torch.float32)
|
665 |
+
if self.normalize:
|
666 |
+
eps = 1e-6
|
667 |
+
y_embed = (y_embed - 0.5) / (y_embed[:, -1:, :] + eps) * self.scale
|
668 |
+
x_embed = (x_embed - 0.5) / (x_embed[:, :, -1:] + eps) * self.scale
|
669 |
+
|
670 |
+
dim_t = self.temperature ** (2 * (self.dim_t.to(device) // 2) / self.num_pos_feats)
|
671 |
+
pos_x = x_embed[:, :, :, None] / dim_t
|
672 |
+
pos_y = y_embed[:, :, :, None] / dim_t
|
673 |
+
|
674 |
+
pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
|
675 |
+
pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
|
676 |
+
|
677 |
+
return torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
|
678 |
+
|
679 |
+
|
680 |
+
|
713 |
+
|
714 |
+
|
715 |
+
class MCLM(nn.Module):
|
716 |
+
def __init__(self, d_model, num_heads, pool_ratios=[1, 4, 8]):
|
717 |
+
super(MCLM, self).__init__()
|
718 |
+
self.attention = nn.ModuleList([
|
719 |
+
nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
|
720 |
+
nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
|
721 |
+
nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
|
722 |
+
nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
|
723 |
+
nn.MultiheadAttention(d_model, num_heads, dropout=0.1)
|
724 |
+
])
|
725 |
+
|
726 |
+
self.linear1 = nn.Linear(d_model, d_model * 2)
|
727 |
+
self.linear2 = nn.Linear(d_model * 2, d_model)
|
728 |
+
self.linear3 = nn.Linear(d_model, d_model * 2)
|
729 |
+
self.linear4 = nn.Linear(d_model * 2, d_model)
|
730 |
+
self.norm1 = nn.LayerNorm(d_model)
|
731 |
+
self.norm2 = nn.LayerNorm(d_model)
|
732 |
+
self.dropout = nn.Dropout(0.1)
|
733 |
+
self.dropout1 = nn.Dropout(0.1)
|
734 |
+
self.dropout2 = nn.Dropout(0.1)
|
735 |
+
self.activation = get_activation_fn('gelu')
|
736 |
+
self.pool_ratios = pool_ratios
|
737 |
+
self.p_poses = []
|
738 |
+
self.g_pos = None
|
739 |
+
self.positional_encoding = PositionEmbeddingSine(num_pos_feats=d_model // 2, normalize=True)
|
740 |
+
|
741 |
+
def forward(self, l, g):
|
742 |
+
"""
|
743 |
+
l: 4,c,h,w
|
744 |
+
g: 1,c,h,w
|
745 |
+
"""
|
746 |
+
self.p_poses = []
|
747 |
+
self.g_pos = None
|
748 |
+
b, c, h, w = l.size()
|
749 |
+
# 4,c,h,w -> 1,c,2h,2w
|
750 |
+
concated_locs = rearrange(l, '(hg wg b) c h w -> b c (hg h) (wg w)', hg=2, wg=2)
|
751 |
+
|
752 |
+
pools = []
|
753 |
+
for pool_ratio in self.pool_ratios:
|
754 |
+
# b,c,h,w
|
755 |
+
tgt_hw = (round(h / pool_ratio), round(w / pool_ratio))
|
756 |
+
pool = F.adaptive_avg_pool2d(concated_locs, tgt_hw)
|
757 |
+
pools.append(rearrange(pool, 'b c h w -> (h w) b c'))
|
758 |
+
if self.g_pos is None:
|
759 |
+
pos_emb = self.positional_encoding(pool.shape[0], pool.shape[2], pool.shape[3])
|
760 |
+
pos_emb = rearrange(pos_emb, 'b c h w -> (h w) b c')
|
761 |
+
self.p_poses.append(pos_emb)
|
762 |
+
pools = torch.cat(pools, 0)
|
763 |
+
if self.g_pos is None:
|
764 |
+
self.p_poses = torch.cat(self.p_poses, dim=0)
|
765 |
+
pos_emb = self.positional_encoding(g.shape[0], g.shape[2], g.shape[3])
|
766 |
+
self.g_pos = rearrange(pos_emb, 'b c h w -> (h w) b c')
|
767 |
+
|
768 |
+
device = pools.device
|
769 |
+
self.p_poses = self.p_poses.to(device)
|
770 |
+
self.g_pos = self.g_pos.to(device)
|
771 |
+
|
772 |
+
# attention between glb (q) and multi-scale pooled local crops (k, v)
|
773 |
+
g_hw_b_c = rearrange(g, 'b c h w -> (h w) b c')
|
774 |
+
|
775 |
+
g_hw_b_c = g_hw_b_c + self.dropout1(self.attention[0](g_hw_b_c + self.g_pos, pools + self.p_poses, pools)[0])
|
776 |
+
g_hw_b_c = self.norm1(g_hw_b_c)
|
777 |
+
g_hw_b_c = g_hw_b_c + self.dropout2(self.linear2(self.dropout(self.activation(self.linear1(g_hw_b_c)).clone())))
|
778 |
+
g_hw_b_c = self.norm2(g_hw_b_c)
|
779 |
+
|
780 |
+
# attention between original locs (q) and refreshed glb (k, v)
|
781 |
+
l_hw_b_c = rearrange(l, "b c h w -> (h w) b c")
|
782 |
+
_g_hw_b_c = rearrange(g_hw_b_c, '(h w) b c -> h w b c', h=h, w=w)
|
783 |
+
_g_hw_b_c = rearrange(_g_hw_b_c, "(ng h) (nw w) b c -> (h w) (ng nw b) c", ng=2, nw=2)
|
784 |
+
outputs_re = []
|
785 |
+
for i, (_l, _g) in enumerate(zip(l_hw_b_c.chunk(4, dim=1), _g_hw_b_c.chunk(4, dim=1))):
|
786 |
+
outputs_re.append(self.attention[i + 1](_l, _g, _g)[0]) # (h w) 1 c
|
787 |
+
outputs_re = torch.cat(outputs_re, 1) # (h w) 4 c
|
788 |
+
|
789 |
+
l_hw_b_c = l_hw_b_c + self.dropout1(outputs_re)
|
790 |
+
l_hw_b_c = self.norm1(l_hw_b_c)
|
791 |
+
l_hw_b_c = l_hw_b_c + self.dropout2(self.linear4(self.dropout(self.activation(self.linear3(l_hw_b_c)).clone())))
|
792 |
+
l_hw_b_c = self.norm2(l_hw_b_c)
|
793 |
+
|
794 |
+
l = torch.cat((l_hw_b_c, g_hw_b_c), 1) # hw,b(5),c
|
795 |
+
return rearrange(l, "(h w) b c -> b c h w", h=h, w=w)  # (5, c, h, w)
|
796 |
+
|
797 |
+
|
798 |
+
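# Shape sketch for MCLM (sizes are illustrative):
# mclm = MCLM(d_model=128, num_heads=1)
# fused = mclm(torch.randn(4, 128, 32, 32), torch.randn(1, 128, 32, 32))
# # fused.shape == (5, 128, 32, 32): four refined local crops stacked with the global view.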
class MCRM(nn.Module):
|
799 |
+
def __init__(self, d_model, num_heads, pool_ratios=[4, 8, 16], h=None):
|
800 |
+
super(MCRM, self).__init__()
|
801 |
+
self.attention = nn.ModuleList([
|
802 |
+
nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
|
803 |
+
nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
|
804 |
+
nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
|
805 |
+
nn.MultiheadAttention(d_model, num_heads, dropout=0.1)
|
806 |
+
])
|
807 |
+
self.linear3 = nn.Linear(d_model, d_model * 2)
|
808 |
+
self.linear4 = nn.Linear(d_model * 2, d_model)
|
809 |
+
self.norm1 = nn.LayerNorm(d_model)
|
810 |
+
self.norm2 = nn.LayerNorm(d_model)
|
811 |
+
self.dropout = nn.Dropout(0.1)
|
812 |
+
self.dropout1 = nn.Dropout(0.1)
|
813 |
+
self.dropout2 = nn.Dropout(0.1)
|
814 |
+
self.sigmoid = nn.Sigmoid()
|
815 |
+
self.activation = get_activation_fn('gelu')
|
816 |
+
self.sal_conv = nn.Conv2d(d_model, 1, 1)
|
817 |
+
self.pool_ratios = pool_ratios
|
818 |
+
|
819 |
+
def forward(self, x):
|
820 |
+
device = x.device
|
821 |
+
b, c, h, w = x.size()
|
822 |
+
loc, glb = x.split([4, 1], dim=0) # 4,c,h,w; 1,c,h,w
|
823 |
+
|
824 |
+
patched_glb = rearrange(glb, 'b c (hg h) (wg w) -> (hg wg b) c h w', hg=2, wg=2)
|
825 |
+
|
826 |
+
token_attention_map = self.sigmoid(self.sal_conv(glb))
|
827 |
+
token_attention_map = F.interpolate(token_attention_map, size=patches2image(loc).shape[-2:], mode='nearest')
|
828 |
+
loc = loc * rearrange(token_attention_map, 'b c (hg h) (wg w) -> (hg wg b) c h w', hg=2, wg=2)
|
829 |
+
|
830 |
+
pools = []
|
831 |
+
for pool_ratio in self.pool_ratios:
|
832 |
+
tgt_hw = (round(h / pool_ratio), round(w / pool_ratio))
|
833 |
+
pool = F.adaptive_avg_pool2d(patched_glb, tgt_hw)
|
834 |
+
pools.append(rearrange(pool, 'nl c h w -> nl c (h w)')) # nl(4),c,hw
|
835 |
+
|
836 |
+
pools = rearrange(torch.cat(pools, 2), "nl c nphw -> nl nphw 1 c")
|
837 |
+
loc_ = rearrange(loc, 'nl c h w -> nl (h w) 1 c')
|
838 |
+
|
839 |
+
outputs = []
|
840 |
+
for i, q in enumerate(loc_.unbind(dim=0)): # traverse all local patches
|
841 |
+
v = pools[i]
|
842 |
+
k = v
|
843 |
+
outputs.append(self.attention[i](q, k, v)[0])
|
844 |
+
|
845 |
+
outputs = torch.cat(outputs, 1)
|
846 |
+
src = loc.view(4, c, -1).permute(2, 0, 1) + self.dropout1(outputs)
|
847 |
+
src = self.norm1(src)
|
848 |
+
src = src + self.dropout2(self.linear4(self.dropout(self.activation(self.linear3(src)).clone())))
|
849 |
+
src = self.norm2(src)
|
850 |
+
src = src.permute(1, 2, 0).reshape(4, c, h, w)  # refreshed loc
|
851 |
+
glb = glb + F.interpolate(patches2image(src), size=glb.shape[-2:], mode='nearest')  # refreshed glb
|
852 |
+
|
853 |
+
return torch.cat((src, glb), 0), token_attention_map
|
854 |
+
|
855 |
+
|
856 |
+
class BEN_Base(nn.Module):
|
857 |
+
def __init__(self):
|
858 |
+
super().__init__()
|
859 |
+
|
860 |
+
self.backbone = SwinTransformer(embed_dim=128, depths=[2, 2, 18, 2], num_heads=[4, 8, 16, 32], window_size=12)
|
861 |
+
emb_dim = 128
|
862 |
+
self.sideout5 = nn.Sequential(nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
|
863 |
+
self.sideout4 = nn.Sequential(nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
|
864 |
+
self.sideout3 = nn.Sequential(nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
|
865 |
+
self.sideout2 = nn.Sequential(nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
|
866 |
+
self.sideout1 = nn.Sequential(nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
|
867 |
+
|
868 |
+
self.output5 = make_cbr(1024, emb_dim)
|
869 |
+
self.output4 = make_cbr(512, emb_dim)
|
870 |
+
self.output3 = make_cbr(256, emb_dim)
|
871 |
+
self.output2 = make_cbr(128, emb_dim)
|
872 |
+
self.output1 = make_cbr(128, emb_dim)
|
873 |
+
|
874 |
+
self.multifieldcrossatt = MCLM(emb_dim, 1, [1, 4, 8])
|
875 |
+
self.conv1 = make_cbr(emb_dim, emb_dim)
|
876 |
+
self.conv2 = make_cbr(emb_dim, emb_dim)
|
877 |
+
self.conv3 = make_cbr(emb_dim, emb_dim)
|
878 |
+
self.conv4 = make_cbr(emb_dim, emb_dim)
|
879 |
+
self.dec_blk1 = MCRM(emb_dim, 1, [2, 4, 8])
|
880 |
+
self.dec_blk2 = MCRM(emb_dim, 1, [2, 4, 8])
|
881 |
+
self.dec_blk3 = MCRM(emb_dim, 1, [2, 4, 8])
|
882 |
+
self.dec_blk4 = MCRM(emb_dim, 1, [2, 4, 8])
|
883 |
+
|
884 |
+
self.insmask_head = nn.Sequential(
|
885 |
+
nn.Conv2d(emb_dim, 384, kernel_size=3, padding=1),
|
886 |
+
nn.InstanceNorm2d(384),
|
887 |
+
nn.GELU(),
|
888 |
+
nn.Conv2d(384, 384, kernel_size=3, padding=1),
|
889 |
+
nn.InstanceNorm2d(384),
|
890 |
+
nn.GELU(),
|
891 |
+
nn.Conv2d(384, emb_dim, kernel_size=3, padding=1)
|
892 |
+
)
|
893 |
+
|
894 |
+
self.shallow = nn.Sequential(nn.Conv2d(3, emb_dim, kernel_size=3, padding=1))
|
895 |
+
self.upsample1 = make_cbg(emb_dim, emb_dim)
|
896 |
+
self.upsample2 = make_cbg(emb_dim, emb_dim)
|
897 |
+
self.output = nn.Sequential(nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
|
898 |
+
|
899 |
+
for m in self.modules():
|
900 |
+
if isinstance(m, nn.GELU) or isinstance(m, nn.Dropout):
|
901 |
+
m.inplace = True
|
902 |
+
|
903 |
+
@torch.inference_mode()
|
904 |
+
@torch.autocast(device_type="cuda", dtype=torch.float16)
|
905 |
+
def forward(self, x):
|
906 |
+
real_batch = x.size(0)
|
907 |
+
|
908 |
+
shallow_batch = self.shallow(x)
|
909 |
+
glb_batch = rescale_to(x, scale_factor=0.5, interpolation='bilinear')
|
910 |
+
|
911 |
+
final_input = None
|
912 |
+
for i in range(real_batch):
|
913 |
+
start = i * 4
|
914 |
+
end = (i + 1) * 4
|
915 |
+
loc_batch = image2patches(x[i, :, :, :].unsqueeze(dim=0))
|
916 |
+
input_ = torch.cat((loc_batch, glb_batch[i, :, :, :].unsqueeze(dim=0)), dim=0)
|
917 |
+
|
918 |
+
if final_input is None:
|
919 |
+
final_input = input_
|
920 |
+
else:
|
921 |
+
final_input = torch.cat((final_input, input_), dim=0)
|
922 |
+
|
923 |
+
features = self.backbone(final_input)
|
924 |
+
outputs = []
|
925 |
+
|
926 |
+
for i in range(real_batch):
|
927 |
+
start = i * 5
|
928 |
+
end = (i + 1) * 5
|
929 |
+
|
930 |
+
f4 = features[4][start:end, :, :, :] # shape: [5, C, H, W]
|
931 |
+
f3 = features[3][start:end, :, :, :]
|
932 |
+
f2 = features[2][start:end, :, :, :]
|
933 |
+
f1 = features[1][start:end, :, :, :]
|
934 |
+
f0 = features[0][start:end, :, :, :]
|
935 |
+
e5 = self.output5(f4)
|
936 |
+
e4 = self.output4(f3)
|
937 |
+
e3 = self.output3(f2)
|
938 |
+
e2 = self.output2(f1)
|
939 |
+
e1 = self.output1(f0)
|
940 |
+
loc_e5, glb_e5 = e5.split([4, 1], dim=0)
|
941 |
+
e5 = self.multifieldcrossatt(loc_e5, glb_e5) # (4,128,16,16)
|
942 |
+
|
943 |
+
e4, tokenattmap4 = self.dec_blk4(e4 + resize_as(e5, e4))
|
944 |
+
e4 = self.conv4(e4)
|
945 |
+
e3, tokenattmap3 = self.dec_blk3(e3 + resize_as(e4, e3))
|
946 |
+
e3 = self.conv3(e3)
|
947 |
+
e2, tokenattmap2 = self.dec_blk2(e2 + resize_as(e3, e2))
|
948 |
+
e2 = self.conv2(e2)
|
949 |
+
e1, tokenattmap1 = self.dec_blk1(e1 + resize_as(e2, e1))
|
950 |
+
e1 = self.conv1(e1)
|
951 |
+
|
952 |
+
loc_e1, glb_e1 = e1.split([4, 1], dim=0)
|
953 |
+
|
954 |
+
output1_cat = patches2image(loc_e1) # (1,128,256,256)
|
955 |
+
|
956 |
+
# add glb feat in
|
957 |
+
output1_cat = output1_cat + resize_as(glb_e1, output1_cat)
|
958 |
+
# merge
|
959 |
+
final_output = self.insmask_head(output1_cat) # (1,128,256,256)
|
960 |
+
# shallow feature merge
|
961 |
+
shallow = shallow_batch[i, :, :, :].unsqueeze(dim=0)
|
962 |
+
final_output = final_output + resize_as(shallow, final_output)
|
963 |
+
final_output = self.upsample1(rescale_to(final_output))
|
964 |
+
final_output = rescale_to(final_output + resize_as(shallow, final_output))
|
965 |
+
final_output = self.upsample2(final_output)
|
966 |
+
final_output = self.output(final_output)
|
967 |
+
mask = final_output.sigmoid()
|
968 |
+
outputs.append(mask)
|
969 |
+
|
970 |
+
return torch.cat(outputs, dim=0)
|
971 |
+
|
972 |
+
def loadcheckpoints(self, model_path):
|
973 |
+
model_dict = torch.load(model_path, map_location="cpu", weights_only=True)
|
974 |
+
self.load_state_dict(model_dict['model_state_dict'], strict=True)
|
975 |
+
del model_dict
|
976 |
+
|
977 |
+
def inference(self, image, refine_foreground=False, move_to_center=False):
|
978 |
+
|
979 |
+
# set_random_seed(9)
|
980 |
+
# image = ImageOps.exif_transpose(image)
|
981 |
+
if isinstance(image, Image.Image):
|
982 |
+
image, h, w, original_image = rgb_loader_refiner(image)
|
983 |
+
if torch.cuda.is_available():
|
984 |
+
|
985 |
+
img_tensor = img_transform(image).unsqueeze(0).to(next(self.parameters()).device)
|
986 |
+
else:
|
987 |
+
img_tensor = img_transform32(image).unsqueeze(0).to(next(self.parameters()).device)
|
988 |
+
|
989 |
+
with torch.no_grad():
|
990 |
+
res = self.forward(img_tensor)
|
991 |
+
|
992 |
+
|
993 |
+
# Show Results
|
994 |
+
if refine_foreground:
|
995 |
+
|
996 |
+
|
997 |
+
pred_pil = transforms.ToPILImage()(res.squeeze())
|
998 |
+
image_masked = refine_foreground_process(original_image, pred_pil)
|
999 |
+
|
1000 |
+
image_masked.putalpha(pred_pil.resize(original_image.size))
|
1001 |
+
return image_masked
|
1002 |
+
|
1003 |
+
else:
|
1004 |
+
alpha = postprocess_image(res, im_size=[w, h])
|
1005 |
+
pred_pil = transforms.ToPILImage()(alpha)
|
1006 |
+
mask = pred_pil.resize(original_image.size)
|
1007 |
+
original_image.putalpha(mask)
|
1008 |
+
# mask = Image.fromarray(alpha)
|
1009 |
+
|
1010 |
+
# Set the background to white
|
1011 |
+
white_background = Image.new('RGB', original_image.size, (255, 255, 255))
|
1012 |
+
white_background.paste(original_image, mask=original_image.split()[3])
|
1013 |
+
|
1014 |
+
|
1015 |
+
if move_to_center:
|
1016 |
+
# Get the bounding box of non-transparent pixels
|
1017 |
+
# Get alpha channel and convert to numpy array for processing
|
1018 |
+
alpha_mask = np.array(mask)
|
1019 |
+
|
1020 |
+
# Find coordinates where mask is 255 (foreground)
|
1021 |
+
non_zero_coords = np.where(alpha_mask >= 127.5)
|
1022 |
+
if len(non_zero_coords[0]) > 0:
|
1023 |
+
# Get bounding box from non-zero coordinates
|
1024 |
+
min_y, max_y = non_zero_coords[0].min(), non_zero_coords[0].max()
|
1025 |
+
min_x, max_x = non_zero_coords[1].min(), non_zero_coords[1].max()
|
1026 |
+
|
1027 |
+
# Extract the object region
|
1028 |
+
obj_width = max_x - min_x
|
1029 |
+
obj_height = max_y - min_y
|
1030 |
+
bbox = (min_x, min_y, max_x, max_y)
|
1031 |
+
|
1032 |
+
# Calculate center position
|
1033 |
+
img_width, img_height = white_background.size
|
1034 |
+
center_x = (img_width - obj_width) // 2
|
1035 |
+
center_y = (img_height - obj_height) // 2
|
1036 |
+
|
1037 |
+
# Create new white background
|
1038 |
+
new_background = Image.new('RGB', white_background.size, (255, 255, 255))
|
1039 |
+
|
1040 |
+
# Paste the object at center position
|
1041 |
+
new_background.paste(white_background.crop(bbox), (center_x, center_y))
|
1042 |
+
original_image = new_background
|
1043 |
+
else:
|
1044 |
+
original_image = white_background
|
1045 |
+
else:
|
1046 |
+
original_image = white_background
|
1047 |
+
|
1048 |
+
return original_image
|
1049 |
+
|
1050 |
+
|
1051 |
+
else:
|
1052 |
+
foregrounds = []
|
1053 |
+
for batch in image:
|
1054 |
+
image, h, w, original_image = rgb_loader_refiner(batch)
|
1055 |
+
if torch.cuda.is_available():
|
1056 |
+
|
1057 |
+
img_tensor = img_transform(image).unsqueeze(0).to(next(self.parameters()).device)
|
1058 |
+
else:
|
1059 |
+
img_tensor = img_transform32(image).unsqueeze(0).to(next(self.parameters()).device)
|
1060 |
+
|
1061 |
+
with torch.no_grad():
|
1062 |
+
res = self.forward(img_tensor)
|
1063 |
+
|
1064 |
+
if refine_foreground:
|
1065 |
+
|
1066 |
+
pred_pil = transforms.ToPILImage()(res.squeeze())
|
1067 |
+
image_masked = refine_foreground_process(original_image, pred_pil)
|
1068 |
+
|
1069 |
+
image_masked.putalpha(pred_pil.resize(original_image.size))
|
1070 |
+
|
1071 |
+
foregrounds.append(image_masked)
|
1072 |
+
else:
|
1073 |
+
alpha = postprocess_image(res, im_size=[w, h])
|
1074 |
+
pred_pil = transforms.ToPILImage()(alpha)
|
1075 |
+
mask = pred_pil.resize(original_image.size)
|
1076 |
+
original_image.putalpha(mask)
|
1077 |
+
# mask = Image.fromarray(alpha)
|
1078 |
+
foregrounds.append(original_image)
|
1079 |
+
|
1080 |
+
return foregrounds
|
1081 |
+
|
1082 |
+
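# Minimal single-image sketch for the methods above (checkpoint and image paths are
# placeholders; assumes a CUDA device):
# model = BEN_Base().to("cuda").eval()
# model.loadcheckpoints("./BEN2_Base.pth")  # expects a dict containing 'model_state_dict'
# foreground = model.inference(Image.open("./input.jpg"))  # pass refine_foreground=True for matting refinement
# foreground.save("./foreground.png")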
def segment_video(self, video_path, output_path="./", fps=0, refine_foreground=False, batch=1,
|
1083 |
+
print_frames_processed=True, webm=False, rgb_value=(0, 255, 0)):
|
1084 |
+
|
1085 |
+
"""
|
1086 |
+
Segments the given video to extract the foreground (with alpha) from each frame
|
1087 |
+
and saves the result as either a WebM video (with alpha channel) or MP4 (with a
|
1088 |
+
color background).
|
1089 |
+
|
1090 |
+
Args:
|
1091 |
+
video_path (str):
|
1092 |
+
Path to the input video file.
|
1093 |
+
|
1094 |
+
output_path (str, optional):
|
1095 |
+
Directory (or full path) where the output video and/or files will be saved.
|
1096 |
+
Defaults to "./".
|
1097 |
+
|
1098 |
+
fps (int, optional):
|
1099 |
+
The frames per second (FPS) to use for the output video. If 0 (default), the
|
1100 |
+
original FPS of the input video is used. Otherwise, overrides it.
|
1101 |
+
|
1102 |
+
refine_foreground (bool, optional):
|
1103 |
+
Whether to run an additional “refine foreground” process on each frame.
|
1104 |
+
Defaults to False.
|
1105 |
+
|
1106 |
+
batch (int, optional):
|
1107 |
+
Number of frames to process at once (inference batch size). Large batch sizes
|
1108 |
+
may require more GPU memory. Defaults to 1.
|
1109 |
+
|
1110 |
+
print_frames_processed (bool, optional):
|
1111 |
+
If True (default), prints progress (how many frames have been processed) to
|
1112 |
+
the console.
|
1113 |
+
|
1114 |
+
webm (bool, optional):
|
1115 |
+
If True, exports a WebM video with alpha channel (VP9 / yuva420p).
|
1116 |
+
If False (default), exports an MP4 video composited over a solid color background.
|
1117 |
+
|
1118 |
+
rgb_value (tuple, optional):
|
1119 |
+
The RGB background color (e.g., green screen) used to composite frames when
|
1120 |
+
saving to MP4. Defaults to (0, 255, 0).
|
1121 |
+
|
1122 |
+
Returns:
|
1123 |
+
None. Writes the output video(s) to disk in the specified format.
|
1124 |
+
"""
|
1125 |
+
|
1126 |
+
cap = cv2.VideoCapture(video_path)
|
1127 |
+
if not cap.isOpened():
|
1128 |
+
raise IOError(f"Cannot open video: {video_path}")
|
1129 |
+
|
1130 |
+
original_fps = cap.get(cv2.CAP_PROP_FPS)
|
1131 |
+
original_fps = 30 if original_fps == 0 else original_fps
|
1132 |
+
fps = original_fps if fps == 0 else fps
|
1133 |
+
|
1134 |
+
ret, first_frame = cap.read()
|
1135 |
+
if not ret:
|
1136 |
+
raise ValueError("No frames found in the video.")
|
1137 |
+
height, width = first_frame.shape[:2]
|
1138 |
+
cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
|
1139 |
+
|
1140 |
+
foregrounds = []
|
1141 |
+
frame_idx = 0
|
1142 |
+
processed_count = 0
|
1143 |
+
batch_frames = []
|
1144 |
+
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
1145 |
+
|
1146 |
+
while True:
|
1147 |
+
ret, frame = cap.read()
|
1148 |
+
if not ret:
|
1149 |
+
if batch_frames:
|
1150 |
+
batch_results = self.inference(batch_frames, refine_foreground)
|
1151 |
+
if isinstance(batch_results, Image.Image):
|
1152 |
+
foregrounds.append(batch_results)
|
1153 |
+
else:
|
1154 |
+
foregrounds.extend(batch_results)
|
1155 |
+
if print_frames_processed:
|
1156 |
+
print(f"Processed frames {frame_idx - len(batch_frames) + 1} to {frame_idx} of {total_frames}")
|
1157 |
+
break
|
1158 |
+
|
1159 |
+
# Process every frame instead of using intervals
|
1160 |
+
frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
1161 |
+
pil_frame = Image.fromarray(frame_rgb)
|
1162 |
+
batch_frames.append(pil_frame)
|
1163 |
+
|
1164 |
+
if len(batch_frames) == batch:
|
1165 |
+
batch_results = self.inference(batch_frames, refine_foreground)
|
1166 |
+
if isinstance(batch_results, Image.Image):
|
1167 |
+
foregrounds.append(batch_results)
|
1168 |
+
else:
|
1169 |
+
foregrounds.extend(batch_results)
|
1170 |
+
if print_frames_processed:
|
1171 |
+
print(f"Processed frames {frame_idx - batch + 1} to {frame_idx} of {total_frames}")
|
1172 |
+
batch_frames = []
|
1173 |
+
processed_count += batch
|
1174 |
+
|
1175 |
+
frame_idx += 1
|
1176 |
+
|
1177 |
+
if webm:
|
1178 |
+
alpha_webm_path = os.path.join(output_path, "foreground.webm")
|
1179 |
+
pil_images_to_webm_alpha(foregrounds, alpha_webm_path, fps=fps)
|
1180 |
+
|
1181 |
+
else:
|
1182 |
+
cap.release()
|
1183 |
+
fg_output = os.path.join(output_path, 'foreground.mp4')
|
1184 |
+
|
1185 |
+
pil_images_to_mp4(foregrounds, fg_output, fps=fps, rgb_value=rgb_value)
|
1186 |
+
cv2.destroyAllWindows()
|
1187 |
+
|
1188 |
+
try:
|
1189 |
+
fg_audio_output = os.path.join(output_path, 'foreground_output_with_audio.mp4')
|
1190 |
+
add_audio_to_video(fg_output, video_path, fg_audio_output)
|
1191 |
+
except Exception as e:
|
1192 |
+
print("No audio found in the original video")
|
1193 |
+
print(e)
|
1194 |
+
|
1195 |
+
|
1196 |
+
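# Example call for the video path above (file paths are placeholders):
# model.segment_video("./input.mp4", output_path="./", webm=False, rgb_value=(0, 255, 0))
# # writes ./foreground.mp4 composited over green and, when the source has an audio
# # track, ./foreground_output_with_audio.mp4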
|
1208 |
+
|
1209 |
+
|
1210 |
+
# Define the image transformation
|
1211 |
+
img_transform = transforms.Compose([
|
1212 |
+
transforms.ToTensor(),
|
1213 |
+
transforms.ConvertImageDtype(torch.float16),
|
1214 |
+
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
1215 |
+
])
|
1216 |
+
|
1217 |
+
img_transform32 = transforms.Compose([
|
1218 |
+
transforms.ToTensor(),
|
1219 |
+
transforms.ConvertImageDtype(torch.float32),
|
1220 |
+
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
1221 |
+
])
|
1222 |
+
|
1223 |
+
|
1224 |
+
def pil_images_to_mp4(images, output_path, fps=24, rgb_value=(0, 255, 0)):
|
1225 |
+
"""
|
1226 |
+
Converts an array of PIL images to an MP4 video.
|
1227 |
+
|
1228 |
+
Args:
|
1229 |
+
images: List of PIL images
|
1230 |
+
output_path: Path to save the MP4 file
|
1231 |
+
fps: Frames per second (default: 24)
|
1232 |
+
rgb_value: Background RGB color tuple (default: green (0, 255, 0))
|
1233 |
+
"""
|
1234 |
+
if not images:
|
1235 |
+
raise ValueError("No images provided to convert to MP4.")
|
1236 |
+
|
1237 |
+
width, height = images[0].size
|
1238 |
+
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
|
1239 |
+
video_writer = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
|
1240 |
+
|
1241 |
+
for image in images:
|
1242 |
+
# If image has alpha channel, composite onto the specified background color
|
1243 |
+
if image.mode == 'RGBA':
|
1244 |
+
# Create background image with specified RGB color
|
1245 |
+
background = Image.new('RGB', image.size, rgb_value)
|
1246 |
+
background = background.convert('RGBA')
|
1247 |
+
# Composite the image onto the background
|
1248 |
+
image = Image.alpha_composite(background, image)
|
1249 |
+
image = image.convert('RGB')
|
1250 |
+
else:
|
1251 |
+
# Ensure RGB format for non-alpha images
|
1252 |
+
image = image.convert('RGB')
|
1253 |
+
|
1254 |
+
# Convert to OpenCV format and write
|
1255 |
+
open_cv_image = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
|
1256 |
+
video_writer.write(open_cv_image)
|
1257 |
+
|
1258 |
+
video_writer.release()
|
1259 |
+
|
1260 |
+
|
1261 |
+
def pil_images_to_webm_alpha(images, output_path, fps=30):
|
1262 |
+
"""
|
1263 |
+
Converts a list of PIL RGBA images to a VP9 .webm video with alpha channel.
|
1264 |
+
|
1265 |
+
NOTE: Not all players will display alpha in WebM.
|
1266 |
+
Browsers like Chrome/Firefox typically do support VP9 alpha.
|
1267 |
+
"""
|
1268 |
+
if not images:
|
1269 |
+
raise ValueError("No images provided for WebM with alpha.")
|
1270 |
+
|
1271 |
+
# Ensure output directory exists
|
1272 |
+
os.makedirs(os.path.dirname(output_path), exist_ok=True)
|
1273 |
+
|
1274 |
+
with tempfile.TemporaryDirectory() as tmpdir:
|
1275 |
+
# Save frames as PNG (with alpha)
|
1276 |
+
for idx, img in enumerate(images):
|
1277 |
+
if img.mode != "RGBA":
|
1278 |
+
img = img.convert("RGBA")
|
1279 |
+
out_path = os.path.join(tmpdir, f"{idx:06d}.png")
|
1280 |
+
img.save(out_path, "PNG")
|
1281 |
+
|
1282 |
+
# Construct ffmpeg command
|
1283 |
+
# -c:v libvpx-vp9 => VP9 encoder
|
1284 |
+
# -pix_fmt yuva420p => alpha-enabled pixel format
|
1285 |
+
# -auto-alt-ref 0 => helps preserve alpha frames (libvpx quirk)
|
1286 |
+
ffmpeg_cmd = [
|
1287 |
+
"ffmpeg", "-y",
|
1288 |
+
"-framerate", str(fps),
|
1289 |
+
"-i", os.path.join(tmpdir, "%06d.png"),
|
1290 |
+
"-c:v", "libvpx-vp9",
|
1291 |
+
"-pix_fmt", "yuva420p",
|
1292 |
+
"-auto-alt-ref", "0",
|
1293 |
+
output_path
|
1294 |
+
]
|
1295 |
+
|
1296 |
+
subprocess.run(ffmpeg_cmd, check=True)
|
1297 |
+
|
1298 |
+
print(f"WebM with alpha saved to {output_path}")
|
1299 |
+
|
1300 |
+
|
1301 |
+
def add_audio_to_video(video_without_audio_path, original_video_path, output_path):
|
1302 |
+
"""
|
1303 |
+
Check if the original video has an audio stream. If yes, add it. If not, skip.
|
1304 |
+
"""
|
1305 |
+
# 1) Probe original video for audio streams
|
1306 |
+
probe_command = [
|
1307 |
+
'ffprobe', '-v', 'error',
|
1308 |
+
'-select_streams', 'a:0',
|
1309 |
+
'-show_entries', 'stream=index',
|
1310 |
+
'-of', 'csv=p=0',
|
1311 |
+
original_video_path
|
1312 |
+
]
|
1313 |
+
result = subprocess.run(probe_command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
|
1314 |
+
|
1315 |
+
# result.stdout is empty if no audio stream found
|
1316 |
+
if not result.stdout.strip():
|
1317 |
+
print("No audio track found in original video, skipping audio addition.")
|
1318 |
+
return
|
1319 |
+
|
1320 |
+
print("Audio track detected; proceeding to mux audio.")
|
1321 |
+
# 2) If audio found, run ffmpeg to add it
|
1322 |
+
command = [
|
1323 |
+
'ffmpeg', '-y',
|
1324 |
+
'-i', video_without_audio_path,
|
1325 |
+
'-i', original_video_path,
|
1326 |
+
'-c', 'copy',
|
1327 |
+
'-map', '0:v:0',
|
1328 |
+
'-map', '1:a:0', # we know there's an audio track now
|
1329 |
+
output_path
|
1330 |
+
]
|
1331 |
+
subprocess.run(command, check=True)
|
1332 |
+
print(f"Audio added successfully => {output_path}")
|
1333 |
+
|
1334 |
+
|
1335 |
+
### Thanks to the source: https://huggingface.co/ZhengPeng7/BiRefNet/blob/main/handler.py
|
1336 |
+
def refine_foreground_process(image, mask, r=90):
|
1337 |
+
if mask.size != image.size:
|
1338 |
+
mask = mask.resize(image.size)
|
1339 |
+
image = np.array(image) / 255.0
|
1340 |
+
mask = np.array(mask) / 255.0
|
1341 |
+
estimated_foreground = FB_blur_fusion_foreground_estimator_2(image, mask, r=r)
|
1342 |
+
image_masked = Image.fromarray((estimated_foreground * 255.0).astype(np.uint8))
|
1343 |
+
return image_masked
|
1344 |
+
|
1345 |
+
|
1346 |
+
def FB_blur_fusion_foreground_estimator_2(image, alpha, r=90):
|
1347 |
+
# Thanks to the source: https://github.com/Photoroom/fast-foreground-estimation
|
1348 |
+
alpha = alpha[:, :, None]
|
1349 |
+
F, blur_B = FB_blur_fusion_foreground_estimator(image, image, image, alpha, r)
|
1350 |
+
return FB_blur_fusion_foreground_estimator(image, F, blur_B, alpha, r=6)[0]
|
1351 |
+
|
1352 |
+
|
1353 |
+
def FB_blur_fusion_foreground_estimator(image, F, B, alpha, r=90):
|
1354 |
+
if isinstance(image, Image.Image):
|
1355 |
+
image = np.array(image) / 255.0
|
1356 |
+
blurred_alpha = cv2.blur(alpha, (r, r))[:, :, None]
|
1357 |
+
|
1358 |
+
blurred_FA = cv2.blur(F * alpha, (r, r))
|
1359 |
+
blurred_F = blurred_FA / (blurred_alpha + 1e-5)
|
1360 |
+
|
1361 |
+
blurred_B1A = cv2.blur(B * (1 - alpha), (r, r))
|
1362 |
+
blurred_B = blurred_B1A / ((1 - blurred_alpha) + 1e-5)
|
1363 |
+
F = blurred_F + alpha * \
|
1364 |
+
(image - alpha * blurred_F - (1 - alpha) * blurred_B)
|
1365 |
+
F = np.clip(F, 0, 1)
|
1366 |
+
return F, blurred_B
|
1367 |
+
|
1368 |
+
|
1369 |
+
def postprocess_image(result: torch.Tensor, im_size: list) -> np.ndarray:
|
1370 |
+
result = torch.squeeze(F.interpolate(result, size=im_size, mode='bilinear'), 0)
|
1371 |
+
ma = torch.max(result)
|
1372 |
+
mi = torch.min(result)
|
1373 |
+
result = (result - mi) / (ma - mi)
|
1374 |
+
im_array = (result * 255).permute(1, 2, 0).cpu().data.numpy().astype(np.uint8)
|
1375 |
+
im_array = np.squeeze(im_array)
|
1376 |
+
return im_array
|
1377 |
+
|
1378 |
+
|
1379 |
+
def rgb_loader_refiner(original_image):
|
1380 |
+
h, w = original_image.size
|
1381 |
+
# Apply EXIF orientation
|
1382 |
+
|
1383 |
+
original_image = ImageOps.exif_transpose(original_image)
|
1384 |
+
|
1385 |
+
if original_image.mode != 'RGB':
|
1386 |
+
original_image = original_image.convert('RGB')
|
1387 |
+
|
1388 |
+
image = original_image
|
1389 |
+
# Convert to RGB if necessary
|
1390 |
+
|
1391 |
+
# Resize the image
|
1392 |
+
image = image.resize((1024, 1024), resample=Image.LANCZOS)
|
1393 |
+
|
1394 |
+
return image, h, w, original_image
|
app/aspect_ratio_template.py
ADDED
@@ -0,0 +1,88 @@
1 |
+
# From https://github.com/TencentARC/PhotoMaker/pull/120 written by https://github.com/DiscoNova
|
2 |
+
# Note: Since output width & height need to be divisible by 8, the w & h -values do
|
3 |
+
# not exactly match the stated aspect ratios... but they are "close enough":)
|
4 |
+
|
5 |
+
ASPECT_RATIO_TEMPLATE = [
|
6 |
+
{
|
7 |
+
"name": "Custom (long edge to 1024px)",
|
8 |
+
"w": "",
|
9 |
+
"h": "",
|
10 |
+
},
|
11 |
+
{
|
12 |
+
"name": "Custom",
|
13 |
+
"w": "",
|
14 |
+
"h": "",
|
15 |
+
},
|
16 |
+
{
|
17 |
+
"name": "Instagram (1:1)",
|
18 |
+
"w": 1024,
|
19 |
+
"h": 1024,
|
20 |
+
},
|
21 |
+
{
|
22 |
+
"name": "35mm film / Landscape (3:2)",
|
23 |
+
"w": 1024,
|
24 |
+
"h": 680,
|
25 |
+
},
|
26 |
+
{
|
27 |
+
"name": "35mm film / Portrait (2:3)",
|
28 |
+
"w": 680,
|
29 |
+
"h": 1024,
|
30 |
+
},
|
31 |
+
{
|
32 |
+
"name": "CRT Monitor / Landscape (4:3)",
|
33 |
+
"w": 1024,
|
34 |
+
"h": 768,
|
35 |
+
},
|
36 |
+
{
|
37 |
+
"name": "CRT Monitor / Portrait (3:4)",
|
38 |
+
"w": 768,
|
39 |
+
"h": 1024,
|
40 |
+
},
|
41 |
+
{
|
42 |
+
"name": "Widescreen TV / Landscape (16:9)",
|
43 |
+
"w": 1024,
|
44 |
+
"h": 576,
|
45 |
+
},
|
46 |
+
{
|
47 |
+
"name": "Widescreen TV / Portrait (9:16)",
|
48 |
+
"w": 576,
|
49 |
+
"h": 1024,
|
50 |
+
},
|
51 |
+
{
|
52 |
+
"name": "Widescreen Monitor / Landscape (16:10)",
|
53 |
+
"w": 1024,
|
54 |
+
"h": 640,
|
55 |
+
},
|
56 |
+
{
|
57 |
+
"name": "Widescreen Monitor / Portrait (10:16)",
|
58 |
+
"w": 640,
|
59 |
+
"h": 1024,
|
60 |
+
},
|
61 |
+
{
|
62 |
+
"name": "Cinemascope (2.39:1)",
|
63 |
+
"w": 1024,
|
64 |
+
"h": 424,
|
65 |
+
},
|
66 |
+
{
|
67 |
+
"name": "Widescreen Movie (1.85:1)",
|
68 |
+
"w": 1024,
|
69 |
+
"h": 552,
|
70 |
+
},
|
71 |
+
{
|
72 |
+
"name": "Academy Movie (1.37:1)",
|
73 |
+
"w": 1024,
|
74 |
+
"h": 744,
|
75 |
+
},
|
76 |
+
{
|
77 |
+
"name": "Sheet-print (A-series) / Landscape (297:210)",
|
78 |
+
"w": 1024,
|
79 |
+
"h": 720,
|
80 |
+
},
|
81 |
+
{
|
82 |
+
"name": "Sheet-print (A-series) / Portrait (210:297)",
|
83 |
+
"w": 720,
|
84 |
+
"h": 1024,
|
85 |
+
},
|
86 |
+
]
|
87 |
+
|
88 |
+
ASPECT_RATIO_TEMPLATE = {k["name"]: (k["w"], k["h"]) for k in ASPECT_RATIO_TEMPLATE}
|
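# After the comprehension above, lookups return (w, h) tuples, for example:
# ASPECT_RATIO_TEMPLATE["Instagram (1:1)"]                   # -> (1024, 1024)
# ASPECT_RATIO_TEMPLATE["Widescreen TV / Landscape (16:9)"]  # -> (1024, 576)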
app/business_logic.py
ADDED
@@ -0,0 +1,556 @@
1 |
+
#!/usr/bin/python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
"""
|
4 |
+
Business logic functions for IC-Custom application.
|
5 |
+
"""
|
6 |
+
import numpy as np
|
7 |
+
import torch
|
8 |
+
import cv2
|
9 |
+
import gradio as gr
|
10 |
+
from PIL import Image
|
11 |
+
from datetime import datetime
|
12 |
+
import json
|
13 |
+
import os
|
14 |
+
from scipy.ndimage import binary_dilation, binary_erosion
|
15 |
+
|
16 |
+
from constants import (
|
17 |
+
DEFAULT_BACKGROUND_BLEND_THRESHOLD, DEFAULT_SEED, DEFAULT_NUM_IMAGES,
|
18 |
+
DEFAULT_GUIDANCE, DEFAULT_TRUE_GS, DEFAULT_NUM_STEPS, DEFAULT_ASPECT_RATIO,
|
19 |
+
DEFAULT_DILATION_KERNEL_SIZE, DEFAULT_MARKER_SIZE, DEFAULT_MARKER_THICKNESS,
|
20 |
+
DEFAULT_MASK_ALPHA, DEFAULT_COLOR_ALPHA, TIMESTAMP_FORMAT, SEGMENTATION_COLORS, SEGMENTATION_MARKERS
|
21 |
+
)
|
22 |
+
|
23 |
+
from utils import run_vlm, construct_vlm_gen_prompt, construct_vlm_polish_prompt
|
24 |
+
|
25 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
26 |
+
|
27 |
+
# Global holder for SAM mobile predictor injected from the app layer
|
28 |
+
MOBILE_PREDICTOR = None
|
29 |
+
BEN2_MODEL = None # ben2 model injected from the app layer
|
30 |
+
|
31 |
+
def set_mobile_predictor(predictor):
|
32 |
+
"""Inject SAM mobile predictor into this module without changing function signatures."""
|
33 |
+
global MOBILE_PREDICTOR
|
34 |
+
MOBILE_PREDICTOR = predictor
|
35 |
+
|
36 |
+
def set_ben2_model(ben2_model):
|
37 |
+
"""Inject ben2 model into this module without changing function signatures."""
|
38 |
+
global BEN2_MODEL
|
39 |
+
BEN2_MODEL = ben2_model
|
40 |
+
|
41 |
+
def set_vlm_processor(vlm_processor):
|
42 |
+
"""Inject vlm processor into this module without changing function signatures."""
|
43 |
+
global VLM_PROCESSOR
|
44 |
+
VLM_PROCESSOR = vlm_processor
|
45 |
+
|
46 |
+
def set_vlm_model(vlm_model):
|
47 |
+
"""Inject vlm model into this module without changing function signatures."""
|
48 |
+
global VLM_MODEL
|
49 |
+
VLM_MODEL = vlm_model
|
50 |
+
|
51 |
+
|
52 |
+
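# Sketch of the expected wiring from the app layer (the app-side objects are
# illustrative names, not defined in this module):
# import business_logic
# business_logic.set_mobile_predictor(mobile_predictor)  # SAM predictor for point-based masks
# business_logic.set_ben2_model(ben2_model)              # BEN2 background-removal model
# business_logic.set_vlm_processor(vlm_processor)
# business_logic.set_vlm_model(vlm_model)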
def init_image_target_1(target_image):
|
53 |
+
"""Initialize UI state when a target image is uploaded."""
|
54 |
+
|
55 |
+
# Handle both PIL Image (image_target_1) and ImageEditor dict (image_target_2)
|
56 |
+
try:
|
57 |
+
if isinstance(target_image, dict) and 'composite' in target_image:
|
58 |
+
# ImageEditor format (user-drawn mask)
|
59 |
+
image_target_state = np.array(target_image['composite'].convert("RGB"))
|
60 |
+
else:
|
61 |
+
# PIL Image format (precise mask)
|
62 |
+
image_target_state = np.array(target_image.convert("RGB"))
|
63 |
+
except Exception as e:
|
64 |
+
# If there's an error processing the image, skip initialization
|
65 |
+
return (
|
66 |
+
gr.skip(), gr.skip(), gr.skip(), gr.skip(), gr.skip(), gr.skip(),
|
67 |
+
gr.skip(), gr.skip(), gr.skip(), gr.skip(), gr.skip(), gr.skip(),
|
68 |
+
gr.skip(), gr.skip(), gr.update(value="-1")
|
69 |
+
)
|
70 |
+
|
71 |
+
selected_points = []
|
72 |
+
mask_target_state = None
|
73 |
+
prompt = None
|
74 |
+
mask_gallery = []
|
75 |
+
result_gallery = []
|
76 |
+
use_background_preservation = False
|
77 |
+
background_blend_threshold = DEFAULT_BACKGROUND_BLEND_THRESHOLD
|
78 |
+
seed = DEFAULT_SEED
|
79 |
+
num_images_per_prompt = DEFAULT_NUM_IMAGES
|
80 |
+
guidance = DEFAULT_GUIDANCE
|
81 |
+
true_gs = DEFAULT_TRUE_GS
|
82 |
+
num_steps = DEFAULT_NUM_STEPS
|
83 |
+
aspect_ratio_val = gr.update(value=DEFAULT_ASPECT_RATIO)
|
84 |
+
|
85 |
+
return (image_target_state, selected_points, mask_target_state, prompt,
|
86 |
+
mask_gallery, result_gallery, use_background_preservation,
|
87 |
+
background_blend_threshold, seed, num_images_per_prompt, guidance,
|
88 |
+
true_gs, num_steps, aspect_ratio_val)
|
89 |
+
|
90 |
+
|
91 |
+
def init_image_target_2(target_image):
|
92 |
+
"""Initialize UI state when a target image is uploaded."""
|
93 |
+
# Handle both PIL Image (image_target_1) and ImageEditor dict (image_target_2)
|
94 |
+
try:
|
95 |
+
if isinstance(target_image, dict) and 'composite' in target_image:
|
96 |
+
# ImageEditor format (user-drawn mask)
|
97 |
+
image_target_state = np.array(target_image['composite'].convert("RGB"))
|
98 |
+
else:
|
99 |
+
# PIL Image format (precise mask)
|
100 |
+
image_target_state = np.array(target_image.convert("RGB"))
|
101 |
+
except Exception as e:
|
102 |
+
# If there's an error processing the image, skip initialization
|
103 |
+
return (
|
104 |
+
gr.skip(), gr.skip(), gr.skip(), gr.skip(), gr.skip(), gr.skip(),
|
105 |
+
gr.skip(), gr.skip(), gr.skip(), gr.skip(), gr.skip(), gr.skip(),
|
106 |
+
gr.skip(), gr.skip(), gr.update(value="-1")
|
107 |
+
)
|
108 |
+
|
109 |
+
selected_points = gr.skip()
|
110 |
+
mask_target_state = gr.skip()
|
111 |
+
prompt = gr.skip()
|
112 |
+
mask_gallery = gr.skip()
|
113 |
+
result_gallery = gr.skip()
|
114 |
+
use_background_preservation = gr.skip()
|
115 |
+
background_blend_threshold = gr.skip()
|
116 |
+
seed = gr.skip()
|
117 |
+
num_images_per_prompt = gr.skip()
|
118 |
+
guidance = gr.skip()
|
119 |
+
true_gs = gr.skip()
|
120 |
+
num_steps = gr.skip()
|
121 |
+
aspect_ratio_val = gr.skip()
|
122 |
+
|
123 |
+
return (image_target_state, selected_points, mask_target_state, prompt,
|
124 |
+
mask_gallery, result_gallery, use_background_preservation,
|
125 |
+
background_blend_threshold, seed, num_images_per_prompt, guidance,
|
126 |
+
true_gs, num_steps, aspect_ratio_val)
|
127 |
+
|
128 |
+
|
129 |
+
def init_image_reference(image_reference):
|
130 |
+
"""Initialize all UI states when a reference image is uploaded."""
|
131 |
+
image_reference_state = np.array(image_reference.convert("RGB"))
|
132 |
+
image_reference_ori_state = image_reference_state
|
133 |
+
image_reference_rmbg_state = None
|
134 |
+
image_target_state = None
|
135 |
+
mask_target_state = None
|
136 |
+
prompt = None
|
137 |
+
mask_gallery = []
|
138 |
+
result_gallery = []
|
139 |
+
image_target_1_val = None
|
140 |
+
image_target_2_val = None
|
141 |
+
selected_points = []
|
142 |
+
input_mask_mode_val = gr.update(value="Precise mask")
|
143 |
+
seg_ref_mode_val = gr.update(value="Full Ref")
|
144 |
+
move_to_center = False
|
145 |
+
use_background_preservation = False
|
146 |
+
background_blend_threshold = DEFAULT_BACKGROUND_BLEND_THRESHOLD
|
147 |
+
seed = DEFAULT_SEED
|
148 |
+
num_images_per_prompt = DEFAULT_NUM_IMAGES
|
149 |
+
guidance = DEFAULT_GUIDANCE
|
150 |
+
true_gs = DEFAULT_TRUE_GS
|
151 |
+
num_steps = DEFAULT_NUM_STEPS
|
152 |
+
aspect_ratio_val = gr.update(value=DEFAULT_ASPECT_RATIO)
|
153 |
+
|
154 |
+
return (
|
155 |
+
image_reference_ori_state, image_reference_rmbg_state, image_target_state,
|
156 |
+
mask_target_state, prompt, mask_gallery, result_gallery, image_target_1_val,
|
157 |
+
image_target_2_val, selected_points, input_mask_mode_val, seg_ref_mode_val,
|
158 |
+
move_to_center, use_background_preservation, background_blend_threshold,
|
159 |
+
seed, num_images_per_prompt, guidance, true_gs, num_steps, aspect_ratio_val,
|
160 |
+
)
|
161 |
+
|
162 |
+
|
163 |
+
def undo_seg_points(orig_img, sel_pix):
|
164 |
+
"""Remove the latest segmentation point and recompute the preview mask."""
|
165 |
+
if len(sel_pix) != 0:
|
166 |
+
temp = orig_img.copy()
|
167 |
+
sel_pix.pop()
|
168 |
+
# Show an updated segmentation-mask preview
|
169 |
+
if len(sel_pix) != 0:
|
170 |
+
temp, output_mask = segmentation(temp, sel_pix, MOBILE_PREDICTOR, SEGMENTATION_COLORS, SEGMENTATION_MARKERS)
|
171 |
+
output_mask_pil = Image.fromarray(output_mask.astype("uint8"))
|
172 |
+
masked_img_pil = Image.fromarray(np.where(output_mask > 0, orig_img, 0).astype("uint8"))
|
173 |
+
mask_gallery = [masked_img_pil, output_mask_pil]
|
174 |
+
else:
|
175 |
+
output_mask = None
|
176 |
+
mask_gallery = []
|
177 |
+
return temp.astype(np.uint8), output_mask, mask_gallery
|
178 |
+
else:
|
179 |
+
gr.Warning("Nothing to Undo")
|
180 |
+
return orig_img, None, []
|
181 |
+
|
182 |
+
|
183 |
+
def segmentation(img, sel_pix, mobile_predictor, colors, markers):
|
184 |
+
"""Run SAM-based segmentation given selected points and return previews."""
|
185 |
+
points = []
|
186 |
+
labels = []
|
187 |
+
for p, l in sel_pix:
|
188 |
+
points.append(p)
|
189 |
+
labels.append(l)
|
190 |
+
|
191 |
+
mobile_predictor.set_image(img if isinstance(img, np.ndarray) else np.array(img))
|
192 |
+
with torch.no_grad():
|
193 |
+
masks, _, _ = mobile_predictor.predict(
|
194 |
+
point_coords=np.array(points),
|
195 |
+
point_labels=np.array(labels),
|
196 |
+
multimask_output=False
|
197 |
+
)
|
198 |
+
|
199 |
+
output_mask = np.ones((masks.shape[1], masks.shape[2], 3)) * 255
|
200 |
+
for i in range(3):
|
201 |
+
output_mask[masks[0] == True, i] = 0.0
|
202 |
+
|
203 |
+
mask_all = np.ones((masks.shape[1], masks.shape[2], 3))
|
204 |
+
color_mask = np.random.random((1, 3)).tolist()[0]
|
205 |
+
for i in range(3):
|
206 |
+
mask_all[masks[0] == True, i] = color_mask[i]
|
207 |
+
|
208 |
+
masked_img = img / 255 * DEFAULT_MASK_ALPHA + mask_all * DEFAULT_COLOR_ALPHA
|
209 |
+
masked_img = masked_img * 255
|
210 |
+
|
211 |
+
# Draw points
|
212 |
+
for point, label in sel_pix:
|
213 |
+
cv2.drawMarker(
|
214 |
+
masked_img, point, colors[label],
|
215 |
+
markerType=markers[label],
|
216 |
+
markerSize=DEFAULT_MARKER_SIZE,
|
217 |
+
thickness=DEFAULT_MARKER_THICKNESS
|
218 |
+
)
|
219 |
+
|
220 |
+
return masked_img, output_mask
|
221 |
+
|
222 |
+
|
223 |
+
def get_point(img, sel_pix, evt: gr.SelectData):
|
224 |
+
"""Handle a user click on the target image to add a foreground point."""
|
225 |
+
if evt is None or not hasattr(evt, 'index'):
|
226 |
+
gr.Warning(f"Event object missing index attribute. Event type: {type(evt)}")
|
227 |
+
return img, None, []
|
228 |
+
|
229 |
+
sel_pix.append((evt.index, 1)) # append the foreground_point
|
230 |
+
# Show an updated segmentation-mask preview
|
231 |
+
global MOBILE_PREDICTOR
|
232 |
+
masked_img_seg, output_mask = segmentation(img, sel_pix, MOBILE_PREDICTOR, SEGMENTATION_COLORS, SEGMENTATION_MARKERS)
|
233 |
+
|
234 |
+
# Apply dilation to output_mask
|
235 |
+
output_mask = 1 - output_mask
|
236 |
+
kernel = np.ones((DEFAULT_DILATION_KERNEL_SIZE, DEFAULT_DILATION_KERNEL_SIZE), np.uint8)
|
237 |
+
output_mask = cv2.dilate(output_mask, kernel, iterations=1)
|
238 |
+
output_mask = 1 - output_mask
|
239 |
+
|
240 |
+
output_mask_binary = output_mask / 255
|
241 |
+
|
242 |
+
masked_img_seg = masked_img_seg.astype("uint8")
|
243 |
+
output_mask = output_mask.astype("uint8")
|
244 |
+
|
245 |
+
masked_img = img * output_mask_binary
|
246 |
+
masked_img_pil = Image.fromarray(masked_img.astype("uint8"))
|
247 |
+
output_mask_pil = Image.fromarray(output_mask.astype("uint8"))
|
248 |
+
outputs_gallery = [masked_img_pil, output_mask_pil]
|
249 |
+
|
250 |
+
return masked_img_seg, output_mask, outputs_gallery
|
251 |
+
|
252 |
+
|
253 |
+
def get_brush(img):
|
254 |
+
"""Extract a mask from ImageEditor brush layers or composite/background diff."""
|
255 |
+
if img is None or not isinstance(img, dict):
|
256 |
+
return gr.skip(), gr.skip()
|
257 |
+
|
258 |
+
layers = img.get("layers", [])
|
259 |
+
background = img.get('background', None)
|
260 |
+
composite = img.get('composite', None)
|
261 |
+
|
262 |
+
output_mask = None
|
263 |
+
if layers and layers[0] is not None and background is not None:
|
264 |
+
output_mask = 255 - np.array(layers[0].convert("RGB")).astype(np.uint8)
|
265 |
+
elif composite is not None and background is not None:
|
266 |
+
comp_rgb = np.array(composite.convert("RGB")).astype(np.int16)
|
267 |
+
bg_rgb = np.array(background.convert("RGB")).astype(np.int16)
|
268 |
+
diff = np.abs(comp_rgb - bg_rgb)
|
269 |
+
painted = (diff.sum(axis=2) > 0).astype(np.uint8)
|
270 |
+
output_mask = (1 - painted) * 255
|
271 |
+
output_mask = np.repeat(output_mask[:, :, None], 3, axis=2).astype(np.uint8)
|
272 |
+
else:
|
273 |
+
return gr.skip(), gr.skip()
|
274 |
+
|
275 |
+
if len(np.unique(output_mask)) == 1:
|
276 |
+
return gr.skip(), gr.skip()
|
277 |
+
|
278 |
+
img = np.array(background.convert("RGB")).astype(np.uint8)
|
279 |
+
|
280 |
+
output_mask_binary = output_mask / 255
|
281 |
+
masked_img = img * output_mask_binary
|
282 |
+
masked_img_pil = Image.fromarray(masked_img.astype("uint8"))
|
283 |
+
output_mask_pil = Image.fromarray(output_mask.astype("uint8"))
|
284 |
+
mask_gallery = [masked_img_pil, output_mask_pil]
|
285 |
+
|
286 |
+
return output_mask, mask_gallery
|
287 |
+
|
288 |
+
|
289 |
+
def random_mask_func(mask, dilation_type='square_dilation', dilation_size=20):
|
290 |
+
"""Utility to dilate/erode/box/ellipse expand a binary mask."""
|
291 |
+
binary_mask = mask[:,:,0] < 128
|
292 |
+
|
293 |
+
if dilation_type == 'square_dilation':
|
294 |
+
structure = np.ones((dilation_size, dilation_size), dtype=bool)
|
295 |
+
dilated_mask = binary_dilation(binary_mask, structure=structure)
|
296 |
+
elif dilation_type == 'square_erosion':
|
297 |
+
structure = np.ones((dilation_size, dilation_size), dtype=bool)
|
298 |
+
dilated_mask = binary_erosion(binary_mask, structure=structure)
|
299 |
+
elif dilation_type == 'bounding_box':
|
300 |
+
# Find the bounding rows and columns of the masked region
|
301 |
+
rows, cols = np.where(binary_mask)
|
302 |
+
if len(rows) == 0 or len(cols) == 0:
|
303 |
+
return mask # return original mask if no valid points
|
304 |
+
|
305 |
+
min_row, max_row = np.min(rows), np.max(rows)
|
306 |
+
min_col, max_col = np.min(cols), np.max(cols)
|
307 |
+
|
308 |
+
# Create a bounding box
|
309 |
+
dilated_mask = np.zeros_like(binary_mask, dtype=bool)
|
310 |
+
dilated_mask[min_row:max_row + 1, min_col:max_col + 1] = True
|
311 |
+
|
312 |
+
elif dilation_type == 'bounding_ellipse':
|
313 |
+
# Find the bounding rows and columns of the masked region
|
314 |
+
rows, cols = np.where(binary_mask)
|
315 |
+
if len(rows) == 0 or len(cols) == 0:
|
316 |
+
return mask # return original mask if no valid points
|
317 |
+
|
318 |
+
min_row, max_row = np.min(rows), np.max(rows)
|
319 |
+
min_col, max_col = np.min(cols), np.max(cols)
|
320 |
+
|
321 |
+
# Calculate the center and axis length of the ellipse
|
322 |
+
center = ((min_col + max_col) // 2, (min_row + max_row) // 2)
|
323 |
+
a = (max_col - min_col) // 2 # half long axis
|
324 |
+
b = (max_row - min_row) // 2 # half short axis
|
325 |
+
|
326 |
+
# Create a bounding ellipse
|
327 |
+
y, x = np.ogrid[:mask.shape[0], :mask.shape[1]]
|
328 |
+
ellipse_mask = ((x - center[0])**2 / a**2 + (y - center[1])**2 / b**2) <= 1
|
329 |
+
dilated_mask = np.zeros_like(binary_mask, dtype=bool)
|
330 |
+
dilated_mask[ellipse_mask] = True
|
331 |
+
else:
|
332 |
+
raise ValueError("dilation_type must be 'square_dilation', 'square_erosion', 'bounding_box', or 'bounding_ellipse'")
|
333 |
+
|
334 |
+
# Invert back to the white-background convention and expand to 3 channels
|
335 |
+
dilated_mask = 1 - dilated_mask
|
336 |
+
dilated_mask = np.uint8(dilated_mask[:,:,np.newaxis]) * 255
|
337 |
+
dilated_mask = np.concatenate([dilated_mask, dilated_mask, dilated_mask], axis=2)
|
338 |
+
return dilated_mask
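For reference, a self-contained sketch of the 'square_dilation' branch above, using the same dark-region (<128) convention and SciPy's `binary_dilation`; the toy mask and kernel size are illustrative only:
```python
import numpy as np
from scipy.ndimage import binary_dilation

# Toy 3-channel mask: white background, dark square marking the target region.
mask = np.full((64, 64, 3), 255, dtype=np.uint8)
mask[20:40, 20:40] = 0

binary = mask[:, :, 0] < 128                 # True inside the target region
structure = np.ones((7, 7), dtype=bool)      # square structuring element
dilated = binary_dilation(binary, structure=structure)

# Back to the white-background, 3-channel uint8 convention used above.
out = np.repeat(((~dilated).astype(np.uint8) * 255)[:, :, None], 3, axis=2)
print(out.shape, out.dtype)                  # (64, 64, 3) uint8
```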
|
339 |
+
|
340 |
+
|
341 |
+
def dilate_mask(mask, image):
|
342 |
+
"""Dilate the target mask for robustness and preview the result."""
|
343 |
+
if mask is None:
|
344 |
+
gr.Warning("Please input the target mask first")
|
345 |
+
return None, None
|
346 |
+
|
347 |
+
mask = random_mask_func(mask, dilation_type='square_dilation', dilation_size=DEFAULT_DILATION_KERNEL_SIZE)
|
348 |
+
masked_img = image * (mask > 0)
|
349 |
+
return mask, [masked_img, mask]
|
350 |
+
|
351 |
+
|
352 |
+
def erode_mask(mask, image):
|
353 |
+
"""Erode the target mask and preview the result."""
|
354 |
+
if mask is None:
|
355 |
+
gr.Warning("Please input the target mask first")
|
356 |
+
return None, None
|
357 |
+
|
358 |
+
mask = random_mask_func(mask, dilation_type='square_erosion', dilation_size=DEFAULT_DILATION_KERNEL_SIZE)
|
359 |
+
masked_img = image * (mask > 0)
|
360 |
+
return mask, [masked_img, mask]
|
361 |
+
|
362 |
+
|
363 |
+
def bounding_box(mask, image):
|
364 |
+
"""Create bounding box mask and preview the result."""
|
365 |
+
if mask is None:
|
366 |
+
gr.Warning("Please input the target mask first")
|
367 |
+
return None, None
|
368 |
+
|
369 |
+
mask = random_mask_func(mask, dilation_type='bounding_box', dilation_size=DEFAULT_DILATION_KERNEL_SIZE)
|
370 |
+
masked_img = image * (mask > 0)
|
371 |
+
return mask, [masked_img, mask]
|
372 |
+
|
373 |
+
|
374 |
+
def change_input_mask_mode(input_mask_mode, custmization_mode):
|
375 |
+
"""Change visibility of input mask mode components."""
|
376 |
+
|
377 |
+
if custmization_mode == "Position-free":
|
378 |
+
return (
|
379 |
+
gr.update(visible=False),
|
380 |
+
gr.update(visible=False),
|
381 |
+
gr.update(visible=False),
|
382 |
+
)
|
383 |
+
elif input_mask_mode.lower() == "precise mask":
|
384 |
+
return (
|
385 |
+
gr.update(visible=True),
|
386 |
+
gr.update(visible=False),
|
387 |
+
gr.update(visible=True),
|
388 |
+
)
|
389 |
+
elif input_mask_mode.lower() == "user-drawn mask":
|
390 |
+
return (
|
391 |
+
gr.update(visible=False),
|
392 |
+
gr.update(visible=True),
|
393 |
+
gr.update(visible=False),
|
394 |
+
)
|
395 |
+
else:
|
396 |
+
gr.Warning("Invalid input mask mode")
|
397 |
+
return (
|
398 |
+
gr.skip(), gr.skip(), gr.skip()
|
399 |
+
)
|
400 |
+
|
401 |
+
def change_custmization_mode(custmization_mode, input_mask_mode):
|
402 |
+
"""Change visibility and interactivity based on customization mode."""
|
403 |
+
|
404 |
+
|
405 |
+
if custmization_mode.lower() == "position-free":
|
406 |
+
return (gr.update(interactive=False, visible=False),
|
407 |
+
gr.update(interactive=False, visible=False),
|
408 |
+
gr.update(interactive=False, visible=False),
|
409 |
+
gr.update(interactive=False, visible=False),
|
410 |
+
gr.update(interactive=False, visible=False),
|
411 |
+
gr.update(interactive=False, visible=False),
|
412 |
+
gr.update(value="<s>Select a input mask mode</s>", visible=False),
|
413 |
+
gr.update(value="<s>Input target image & mask (Iterate clicking or brushing until the target is covered)</s>", visible=False),
|
414 |
+
gr.update(value="<s>View or modify the target mask</s>", visible=False),
|
415 |
+
gr.update(value="3. Input text prompt (necessary)"),
|
416 |
+
gr.update(value="4. Submit and view the output"),
|
417 |
+
gr.update(visible=False),
|
418 |
+
gr.update(visible=False),
|
419 |
+
|
420 |
+
)
|
421 |
+
else:
|
422 |
+
if input_mask_mode.lower() == "precise mask":
|
423 |
+
return (gr.update(interactive=True, visible=True),
|
424 |
+
gr.update(interactive=True, visible=False),
|
425 |
+
gr.update(interactive=True, visible=True),
|
426 |
+
gr.update(interactive=True, visible=True),
|
427 |
+
gr.update(interactive=True, visible=True),
|
428 |
+
gr.update(interactive=True, visible=True),
|
429 |
+
gr.update(value="3. Select a input mask mode", visible=True),
|
430 |
+
gr.update(value="4. Input target image & mask (Iterate clicking or brushing until the target is covered)", visible=True),
|
431 |
+
gr.update(value="6. View or modify the target mask", visible=True),
|
432 |
+
gr.update(value="5. Input text prompt (optional)", visible=True),
|
433 |
+
gr.update(value="7. Submit and view the output", visible=True),
|
434 |
+
gr.update(visible=True, value="Precise mask"),
|
435 |
+
gr.update(visible=True),
|
436 |
+
)
|
437 |
+
elif input_mask_mode.lower() == "user-drawn mask":
|
438 |
+
return (gr.update(interactive=True, visible=False),
|
439 |
+
gr.update(interactive=True, visible=True),
|
440 |
+
gr.update(interactive=False, visible=False),
|
441 |
+
gr.update(interactive=True, visible=True),
|
442 |
+
gr.update(interactive=True, visible=True),
|
443 |
+
gr.update(interactive=True, visible=True),
|
444 |
+
gr.update(value="3. Select a input mask mode", visible=True),
|
445 |
+
gr.update(value="4. Input target image & mask (Iterate clicking or brushing until the target is covered)", visible=True),
|
446 |
+
gr.update(value="6. View or modify the target mask", visible=True),
|
447 |
+
gr.update(value="5. Input text prompt (optional)", visible=True),
|
448 |
+
gr.update(value="7. Submit and view the output", visible=True),
|
449 |
+
gr.update(visible=True, value="User-drawn mask"),
|
450 |
+
gr.update(visible=True),
|
451 |
+
)
|
452 |
+
|
453 |
+
|
454 |
+
def change_seg_ref_mode(seg_ref_mode, image_reference_state, move_to_center):
|
455 |
+
"""Change segmentation reference mode and handle background removal."""
|
456 |
+
if image_reference_state is None:
|
457 |
+
gr.Warning("Please upload the reference image first")
|
458 |
+
return None, None
|
459 |
+
|
460 |
+
global BEN2_MODEL
|
461 |
+
|
462 |
+
if seg_ref_mode == "Full Ref":
|
463 |
+
return image_reference_state, None
|
464 |
+
else:
|
465 |
+
if BEN2_MODEL is None:
|
466 |
+
gr.Warning("Please enable ben2 for mask reference first")
|
467 |
+
return gr.skip(), gr.skip()
|
468 |
+
|
469 |
+
image_reference_pil = Image.fromarray(image_reference_state)
|
470 |
+
image_reference_pil_rmbg = BEN2_MODEL.inference(image_reference_pil, move_to_center=move_to_center)
|
471 |
+
image_reference_rmbg = np.array(image_reference_pil_rmbg)
|
472 |
+
return image_reference_rmbg, image_reference_rmbg
|
473 |
+
|
474 |
+
|
475 |
+
def vlm_auto_generate(image_target_state, image_reference_state, mask_target_state,
|
476 |
+
custmization_mode):
|
477 |
+
"""Auto-generate prompt using VLM."""
|
478 |
+
|
479 |
+
global VLM_PROCESSOR, VLM_MODEL
|
480 |
+
|
481 |
+
if custmization_mode == "Position-aware":
|
482 |
+
if image_target_state is None or mask_target_state is None:
|
483 |
+
gr.Warning("Please upload the target image and get mask first")
|
484 |
+
return None
|
485 |
+
|
486 |
+
if image_reference_state is None:
|
487 |
+
gr.Warning("Please upload the reference image first")
|
488 |
+
return None
|
489 |
+
|
490 |
+
if VLM_PROCESSOR is None or VLM_MODEL is None:
|
491 |
+
gr.Warning("Please enable vlm for prompt first")
|
492 |
+
return None
|
493 |
+
|
494 |
+
messages = construct_vlm_gen_prompt(image_target_state, image_reference_state, mask_target_state, custmization_mode)
|
495 |
+
output_text = run_vlm(VLM_PROCESSOR, VLM_MODEL, messages, device=device)
|
496 |
+
return output_text
|
497 |
+
|
498 |
+
|
499 |
+
def vlm_auto_polish(prompt, custmization_mode):
|
500 |
+
"""Auto-polish prompt using VLM."""
|
501 |
+
|
502 |
+
global VLM_PROCESSOR, VLM_MODEL
|
503 |
+
|
504 |
+
if prompt is None:
|
505 |
+
gr.Warning("Please input the text prompt first")
|
506 |
+
return None
|
507 |
+
|
508 |
+
if custmization_mode == "Position-aware":
|
509 |
+
gr.Warning("Polishing only works in position-free mode")
|
510 |
+
return prompt
|
511 |
+
|
512 |
+
|
513 |
+
if VLM_PROCESSOR is None or VLM_MODEL is None:
|
514 |
+
gr.Warning("Please enable vlm for prompt first")
|
515 |
+
return prompt
|
516 |
+
|
517 |
+
messages = construct_vlm_polish_prompt(prompt)
|
518 |
+
output_text = run_vlm(VLM_PROCESSOR, VLM_MODEL, messages, device=device)
|
519 |
+
return output_text
|
520 |
+
|
521 |
+
|
522 |
+
def save_results(output_img, image_reference, image_target, mask_target, prompt,
|
523 |
+
custmization_mode, input_mask_mode, seg_ref_mode, seed, guidance,
|
524 |
+
num_steps, num_images_per_prompt, use_background_preservation,
|
525 |
+
background_blend_threshold, true_gs, assets_cache_dir):
|
526 |
+
"""Save generated results and metadata."""
|
527 |
+
save_name = datetime.now().strftime(TIMESTAMP_FORMAT)
|
528 |
+
results = []
|
529 |
+
|
530 |
+
for i in range(num_images_per_prompt):
|
531 |
+
save_dir = os.path.join(assets_cache_dir, save_name)
|
532 |
+
os.makedirs(save_dir, exist_ok=True)
|
533 |
+
|
534 |
+
output_img[i].save(os.path.join(save_dir, f"img_gen_{i}.png"))
|
535 |
+
image_reference.save(os.path.join(save_dir, f"img_ref_{i}.png"))
|
536 |
+
image_target.save(os.path.join(save_dir, f"img_target_{i}.png"))
|
537 |
+
mask_target.save(os.path.join(save_dir, f"mask_target_{i}.png"))
|
538 |
+
|
539 |
+
with open(os.path.join(save_dir, f"hyper_params_{i}.json"), "w") as f:
|
540 |
+
json.dump({
|
541 |
+
"prompt": prompt,
|
542 |
+
"custmization_mode": custmization_mode,
|
543 |
+
"input_mask_mode": input_mask_mode,
|
544 |
+
"seg_ref_mode": seg_ref_mode,
|
545 |
+
"seed": seed,
|
546 |
+
"guidance": guidance,
|
547 |
+
"num_steps": num_steps,
|
548 |
+
"num_images_per_prompt": num_images_per_prompt,
|
549 |
+
"use_background_preservation": use_background_preservation,
|
550 |
+
"background_blend_threshold": background_blend_threshold,
|
551 |
+
"true_gs": true_gs,
|
552 |
+
}, f)
|
553 |
+
|
554 |
+
results.append(output_img[i])
|
555 |
+
|
556 |
+
return results
|
app/config.py
ADDED
@@ -0,0 +1,72 @@
1 |
+
#!/usr/bin/python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
"""
|
4 |
+
Configuration management for IC-Custom application.
|
5 |
+
"""
|
6 |
+
import os
|
7 |
+
import argparse
|
8 |
+
from omegaconf import OmegaConf
|
9 |
+
|
10 |
+
|
11 |
+
def parse_args():
|
12 |
+
"""Parse command line arguments."""
|
13 |
+
parser = argparse.ArgumentParser(description="IC-Custom App.")
|
14 |
+
parser.add_argument(
|
15 |
+
"--config",
|
16 |
+
type=str,
|
17 |
+
default="configs/app/app.yaml",
|
18 |
+
help="path to config",
|
19 |
+
)
|
20 |
+
parser.add_argument(
|
21 |
+
"--hf_token",
|
22 |
+
type=str,
|
23 |
+
required=False,
|
24 |
+
help="Hugging Face token",
|
25 |
+
)
|
26 |
+
parser.add_argument(
|
27 |
+
"--hf_cache_dir",
|
28 |
+
type=str,
|
29 |
+
required=False,
|
30 |
+
default=os.path.expanduser("~/.cache/huggingface/hub"),
|
31 |
+
help="Cache directory to save the models, default is ~/.cache/huggingface/hub",
|
32 |
+
)
|
33 |
+
parser.add_argument(
|
34 |
+
"--assets_cache_dir",
|
35 |
+
type=str,
|
36 |
+
required=False,
|
37 |
+
default="results/app",
|
38 |
+
help="Cache directory to save the results, default is results/app",
|
39 |
+
)
|
40 |
+
parser.add_argument(
|
41 |
+
"--save_results",
|
42 |
+
action="store_true",
|
43 |
+
help="Save results",
|
44 |
+
)
|
45 |
+
parser.add_argument(
|
46 |
+
"--enable_ben2_for_mask_ref",
|
47 |
+
action=argparse.BooleanOptionalAction,
|
48 |
+
default=True,
|
49 |
+
help="Enable ben2 for mask reference (default: True)",
|
50 |
+
)
|
51 |
+
parser.add_argument(
|
52 |
+
"--enable_vlm_for_prompt",
|
53 |
+
action=argparse.BooleanOptionalAction,
|
54 |
+
default=False,
|
55 |
+
help="Enable vlm for prompt (default: True)",
|
56 |
+
)
|
57 |
+
|
58 |
+
return parser.parse_args()
|
59 |
+
|
60 |
+
|
61 |
+
def load_config(config_path):
|
62 |
+
"""Load configuration from file."""
|
63 |
+
return OmegaConf.load(config_path)
|
64 |
+
|
65 |
+
|
66 |
+
def setup_environment(args):
|
67 |
+
"""Setup environment variables from command line arguments."""
|
68 |
+
if args.hf_token is not None:
|
69 |
+
os.environ["HF_TOKEN"] = args.hf_token
|
70 |
+
|
71 |
+
if args.hf_cache_dir is not None:
|
72 |
+
os.environ["HF_HUB_CACHE"] = args.hf_cache_dir
|
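Taken together, these three helpers cover the app's startup configuration. A minimal usage sketch, assuming the module is imported as `config`; the actual call site lives in app.py and is not shown here:
```python
from config import parse_args, load_config, setup_environment

args = parse_args()              # e.g. python app.py --config configs/app/app.yaml --save_results
setup_environment(args)          # exports HF_TOKEN / HF_HUB_CACHE when provided
cfg = load_config(args.config)   # OmegaConf object with the app settings
print(cfg)
```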
app/constants.py
ADDED
@@ -0,0 +1,35 @@
1 |
+
#!/usr/bin/python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
"""
|
4 |
+
Constants and default values for IC-Custom application.
|
5 |
+
"""
|
6 |
+
from aspect_ratio_template import ASPECT_RATIO_TEMPLATE
|
7 |
+
|
8 |
+
# Aspect ratio constants
|
9 |
+
ASPECT_RATIO_LABELS = list(ASPECT_RATIO_TEMPLATE)
|
10 |
+
DEFAULT_ASPECT_RATIO = ASPECT_RATIO_LABELS[0]
|
11 |
+
|
12 |
+
# Colors and markers for segmentation
|
13 |
+
# OpenCV expects BGR colors; keep tuples as (R, G, B) for consistency across code.
|
14 |
+
SEGMENTATION_COLORS = [(255, 0, 0), (0, 255, 0)]
|
15 |
+
SEGMENTATION_MARKERS = [1, 5]
|
16 |
+
RGBA_COLORS = [(255, 0, 255, 255), (0, 255, 0, 255), (0, 0, 255, 255)]
|
17 |
+
|
18 |
+
# Magic-number constants
|
19 |
+
DEFAULT_BACKGROUND_BLEND_THRESHOLD = 0.5
|
20 |
+
DEFAULT_NUM_STEPS = 32
|
21 |
+
DEFAULT_GUIDANCE = 40
|
22 |
+
DEFAULT_TRUE_GS = 1
|
23 |
+
DEFAULT_NUM_IMAGES = 1
|
24 |
+
DEFAULT_SEED = -1 # -1 indicates random seed
|
25 |
+
DEFAULT_DILATION_KERNEL_SIZE = 7
|
26 |
+
|
27 |
+
# UI constants
|
28 |
+
DEFAULT_BRUSH_SIZE = 30
|
29 |
+
DEFAULT_MARKER_SIZE = 20
|
30 |
+
DEFAULT_MARKER_THICKNESS = 5
|
31 |
+
DEFAULT_MASK_ALPHA = 0.3
|
32 |
+
DEFAULT_COLOR_ALPHA = 0.7
|
33 |
+
|
34 |
+
# File naming
|
35 |
+
TIMESTAMP_FORMAT = "%Y%m%d_%H%M"
|
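As a small illustration of the file-naming constant above, the timestamp format yields per-run folder names such as `20250101_1230`:
```python
from datetime import datetime

TIMESTAMP_FORMAT = "%Y%m%d_%H%M"
save_name = datetime.now().strftime(TIMESTAMP_FORMAT)
print(save_name)  # e.g. "20250101_1230"
```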
app/event_handlers.py
ADDED
@@ -0,0 +1,155 @@
1 |
+
#!/usr/bin/python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
"""
|
4 |
+
Event handlers for IC-Custom application.
|
5 |
+
"""
|
6 |
+
import gradio as gr
|
7 |
+
|
8 |
+
|
9 |
+
def setup_event_handlers(
|
10 |
+
# UI components
|
11 |
+
input_mask_mode, image_target_1, image_target_2, undo_target_seg_button,
|
12 |
+
custmization_mode, dilate_button, erode_button, bounding_box_button,
|
13 |
+
mask_gallery, md_input_mask_mode, md_target_image, md_mask_operation,
|
14 |
+
md_prompt, md_submit, result_gallery, image_target_state, mask_target_state,
|
15 |
+
seg_ref_mode, image_reference_ori_state, move_to_center,
|
16 |
+
image_reference, image_reference_rmbg_state,
|
17 |
+
# Functions
|
18 |
+
change_input_mask_mode, change_custmization_mode, change_seg_ref_mode,
|
19 |
+
init_image_target_1, init_image_target_2, init_image_reference,
|
20 |
+
get_point, undo_seg_points, get_brush,
|
21 |
+
# VLM buttons (UI components)
|
22 |
+
vlm_generate_btn, vlm_polish_btn,
|
23 |
+
# VLM functions
|
24 |
+
vlm_auto_generate, vlm_auto_polish,
|
25 |
+
dilate_mask, erode_mask, bounding_box,
|
26 |
+
run_model,
|
27 |
+
# Other components
|
28 |
+
selected_points, prompt,
|
29 |
+
use_background_preservation, background_blend_threshold, seed,
|
30 |
+
num_images_per_prompt, guidance, true_gs, num_steps, aspect_ratio,
|
31 |
+
submit_button,
|
32 |
+
# extra state
|
33 |
+
eg_idx,
|
34 |
+
):
|
35 |
+
"""Setup all event handlers for the application."""
|
36 |
+
|
37 |
+
# Change input mask mode: precise mask or user-drawn mask
|
38 |
+
input_mask_mode.change(
|
39 |
+
change_input_mask_mode,
|
40 |
+
[input_mask_mode, custmization_mode],
|
41 |
+
[image_target_1, image_target_2, undo_target_seg_button]
|
42 |
+
)
|
43 |
+
|
44 |
+
# Change customization mode: pos-aware or pos-free
|
45 |
+
custmization_mode.change(
|
46 |
+
change_custmization_mode,
|
47 |
+
[custmization_mode, input_mask_mode],
|
48 |
+
[image_target_1, image_target_2, undo_target_seg_button, dilate_button,
|
49 |
+
erode_button, bounding_box_button, md_input_mask_mode,
|
50 |
+
md_target_image, md_mask_operation, md_prompt, md_submit, input_mask_mode, mask_gallery]
|
51 |
+
)
|
52 |
+
|
53 |
+
# Remove background for reference image
|
54 |
+
seg_ref_mode.change(
|
55 |
+
change_seg_ref_mode,
|
56 |
+
[seg_ref_mode, image_reference_ori_state, move_to_center],
|
57 |
+
[image_reference, image_reference_rmbg_state]
|
58 |
+
)
|
59 |
+
|
60 |
+
# Initialize components only on user upload (not programmatic updates)
|
61 |
+
image_target_1.upload(
|
62 |
+
init_image_target_1,
|
63 |
+
[image_target_1],
|
64 |
+
[image_target_state, selected_points, prompt, mask_target_state, mask_gallery,
|
65 |
+
result_gallery, use_background_preservation, background_blend_threshold, seed,
|
66 |
+
num_images_per_prompt, guidance, true_gs, num_steps, aspect_ratio]
|
67 |
+
)
|
68 |
+
|
69 |
+
image_target_2.upload(
|
70 |
+
init_image_target_2,
|
71 |
+
[image_target_2],
|
72 |
+
[image_target_state, selected_points, prompt, mask_target_state, mask_gallery,
|
73 |
+
result_gallery, use_background_preservation, background_blend_threshold, seed,
|
74 |
+
num_images_per_prompt, guidance, true_gs, num_steps, aspect_ratio]
|
75 |
+
)
|
76 |
+
|
77 |
+
image_reference.upload(
|
78 |
+
init_image_reference,
|
79 |
+
[image_reference],
|
80 |
+
[image_reference_ori_state, image_reference_rmbg_state, image_target_state,
|
81 |
+
mask_target_state, prompt, mask_gallery, result_gallery, image_target_1,
|
82 |
+
image_target_2, selected_points, input_mask_mode, seg_ref_mode, move_to_center,
|
83 |
+
use_background_preservation, background_blend_threshold, seed,
|
84 |
+
num_images_per_prompt, guidance, true_gs, num_steps, aspect_ratio]
|
85 |
+
)
|
86 |
+
|
87 |
+
# SAM for image_target_1
|
88 |
+
image_target_1.select(
|
89 |
+
get_point,
|
90 |
+
[image_target_state, selected_points],
|
91 |
+
[image_target_1, mask_target_state, mask_gallery],
|
92 |
+
)
|
93 |
+
|
94 |
+
undo_target_seg_button.click(
|
95 |
+
undo_seg_points,
|
96 |
+
[image_target_state, selected_points],
|
97 |
+
[image_target_1, mask_target_state, mask_gallery]
|
98 |
+
)
|
99 |
+
|
100 |
+
# Brush for image_target_2
|
101 |
+
image_target_2.change(
|
102 |
+
get_brush,
|
103 |
+
[image_target_2],
|
104 |
+
[mask_target_state, mask_gallery],
|
105 |
+
)
|
106 |
+
|
107 |
+
# VLM auto generate
|
108 |
+
vlm_generate_btn.click(
|
109 |
+
vlm_auto_generate,
|
110 |
+
[image_target_state, image_reference_ori_state, mask_target_state, custmization_mode],
|
111 |
+
[prompt]
|
112 |
+
)
|
113 |
+
|
114 |
+
# VLM auto polish
|
115 |
+
vlm_polish_btn.click(
|
116 |
+
vlm_auto_polish,
|
117 |
+
[prompt, custmization_mode],
|
118 |
+
[prompt]
|
119 |
+
)
|
120 |
+
|
121 |
+
# Mask operations
|
122 |
+
dilate_button.click(
|
123 |
+
dilate_mask,
|
124 |
+
[mask_target_state, image_target_state],
|
125 |
+
[mask_target_state, mask_gallery]
|
126 |
+
)
|
127 |
+
|
128 |
+
erode_button.click(
|
129 |
+
erode_mask,
|
130 |
+
[mask_target_state, image_target_state],
|
131 |
+
[mask_target_state, mask_gallery]
|
132 |
+
)
|
133 |
+
|
134 |
+
bounding_box_button.click(
|
135 |
+
bounding_box,
|
136 |
+
[mask_target_state, image_target_state],
|
137 |
+
[mask_target_state, mask_gallery]
|
138 |
+
)
|
139 |
+
|
140 |
+
# Run function
|
141 |
+
ips = [
|
142 |
+
image_target_state, mask_target_state, image_reference_ori_state,
|
143 |
+
image_reference_rmbg_state, prompt, seed, guidance, true_gs, num_steps,
|
144 |
+
num_images_per_prompt, use_background_preservation, background_blend_threshold,
|
145 |
+
aspect_ratio, custmization_mode, seg_ref_mode, input_mask_mode,
|
146 |
+
]
|
147 |
+
|
148 |
+
submit_button.click(
|
149 |
+
fn=run_model,
|
150 |
+
inputs=ips,
|
151 |
+
outputs=[result_gallery, seed, prompt],
|
152 |
+
show_progress=True,
|
153 |
+
)
|
154 |
+
|
155 |
+
|
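All of the bindings above follow the same Gradio pattern: an event on a component routes the listed input components into a handler, and the handler's return values are written back to the listed outputs. A generic, self-contained illustration of that pattern (not the app's actual components):
```python
import gradio as gr

def echo_upper(text):
    # Handler: receives the values of `inputs`, returns values for `outputs`.
    return text.upper()

with gr.Blocks() as demo:
    inp = gr.Textbox(label="prompt")
    out = gr.Textbox(label="result")
    btn = gr.Button("Run")
    btn.click(fn=echo_upper, inputs=[inp], outputs=[out])

# demo.launch()  # uncomment to serve locally
```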
app/examples.py
ADDED
@@ -0,0 +1,210 @@
1 |
+
from PIL import Image
|
2 |
+
|
3 |
+
make_dict = lambda x : { 'background': Image.open(x).convert("RGBA"), 'layers': [Image.new("RGBA", Image.open(x).size, (255, 255, 255, 0))], 'composite': Image.open(x).convert("RGBA") }
|
4 |
+
|
5 |
+
null_dict = {
|
6 |
+
'background': None,
|
7 |
+
'composite': None,
|
8 |
+
'layers': []
|
9 |
+
}
|
10 |
+
|
11 |
+
IMG_REF = [
|
12 |
+
## pos-aware: precise mask
|
13 |
+
"assets/gradio/pos_aware/001/img_ref.png",
|
14 |
+
"assets/gradio/pos_aware/002/img_ref.png",
|
15 |
+
## pos-aware: User-drawn mask
|
16 |
+
"assets/gradio/pos_aware/003/img_ref.png",
|
17 |
+
"assets/gradio/pos_aware/004/img_ref.png",
|
18 |
+
"assets/gradio/pos_aware/005/img_ref.png",
|
19 |
+
## pos-free
|
20 |
+
"assets/gradio/pos_free/001/img_ref.png",
|
21 |
+
"assets/gradio/pos_free/002/img_ref.png",
|
22 |
+
"assets/gradio/pos_free/003/img_ref.png",
|
23 |
+
"assets/gradio/pos_free/004/img_ref.png",
|
24 |
+
]
|
25 |
+
|
26 |
+
IMG_TGT1 = [
|
27 |
+
## pos-aware: precise mask
|
28 |
+
"assets/gradio/pos_aware/001/img_target.png",
|
29 |
+
"assets/gradio/pos_aware/002/img_target.png",
|
30 |
+
## pos-aware: User-drawn mask
|
31 |
+
None,
|
32 |
+
None,
|
33 |
+
None,
|
34 |
+
## pos-free
|
35 |
+
"assets/gradio/pos_free/001/img_target.png",
|
36 |
+
"assets/gradio/pos_free/002/img_target.png",
|
37 |
+
"assets/gradio/pos_free/003/img_target.png",
|
38 |
+
"assets/gradio/pos_free/004/img_target.png",
|
39 |
+
]
|
40 |
+
|
41 |
+
IMG_TGT2 = [
|
42 |
+
## pos-aware: precise mask
|
43 |
+
null_dict,
|
44 |
+
null_dict,
|
45 |
+
## pos-aware: User-drawn mask
|
46 |
+
make_dict("assets/gradio/pos_aware/003/img_target.png"),
|
47 |
+
make_dict("assets/gradio/pos_aware/004/img_target.png"),
|
48 |
+
make_dict("assets/gradio/pos_aware/005/img_target.png"),
|
49 |
+
## pos-free
|
50 |
+
null_dict,
|
51 |
+
null_dict,
|
52 |
+
null_dict,
|
53 |
+
null_dict,
|
54 |
+
]
|
55 |
+
|
56 |
+
MASK_TGT = [
|
57 |
+
## pos-aware: precise mask
|
58 |
+
"assets/gradio/pos_aware/001/mask_target.png",
|
59 |
+
"assets/gradio/pos_aware/002/mask_target.png",
|
60 |
+
## pos-aware: User-drawn mask
|
61 |
+
"assets/gradio/pos_aware/003/mask_target.png",
|
62 |
+
"assets/gradio/pos_aware/004/mask_target.png",
|
63 |
+
"assets/gradio/pos_aware/005/mask_target.png",
|
64 |
+
## pos-free
|
65 |
+
"assets/gradio/pos_free/001/mask_target.png",
|
66 |
+
"assets/gradio/pos_free/002/mask_target.png",
|
67 |
+
"assets/gradio/pos_free/003/mask_target.png",
|
68 |
+
"assets/gradio/pos_free/004/mask_target.png",
|
69 |
+
]
|
70 |
+
|
71 |
+
CUSTOM_MODE = [
|
72 |
+
## pos-aware
|
73 |
+
"Position-aware",
|
74 |
+
"Position-aware",
|
75 |
+
"Position-aware",
|
76 |
+
"Position-aware",
|
77 |
+
"Position-aware",
|
78 |
+
## pos-free
|
79 |
+
"Position-free",
|
80 |
+
"Position-free",
|
81 |
+
"Position-free",
|
82 |
+
"Position-free",
|
83 |
+
]
|
84 |
+
|
85 |
+
INPUT_MASK_MODE = [
|
86 |
+
## pos-aware: precise mask
|
87 |
+
"Precise mask",
|
88 |
+
"Precise mask",
|
89 |
+
## pos-aware: User-drawn mask
|
90 |
+
"User-drawn mask",
|
91 |
+
"User-drawn mask",
|
92 |
+
"User-drawn mask",
|
93 |
+
## pos-free
|
94 |
+
"Precise mask",
|
95 |
+
"Precise mask",
|
96 |
+
"Precise mask",
|
97 |
+
"Precise mask",
|
98 |
+
]
|
99 |
+
|
100 |
+
SEG_REF_MODE = [
|
101 |
+
## pos-aware
|
102 |
+
"Full Ref",
|
103 |
+
"Full Ref",
|
104 |
+
"Full Ref",
|
105 |
+
"Full Ref",
|
106 |
+
"Full Ref",
|
107 |
+
## pos-free
|
108 |
+
"Full Ref",
|
109 |
+
"Full Ref",
|
110 |
+
"Full Ref",
|
111 |
+
"Full Ref",
|
112 |
+
]
|
113 |
+
|
114 |
+
PROMPTS = [
|
115 |
+
## pos-aware: precise mask
|
116 |
+
"",
|
117 |
+
"",
|
118 |
+
## pos-aware: User-drawn mask
|
119 |
+
"A delicate necklace with a mother-of-pearl clover pendant hangs gracefully around the neck of a woman dressed in a black pinstripe blazer.",
|
120 |
+
"",
|
121 |
+
"",
|
122 |
+
## pos-free
|
123 |
+
"TThe charming, soft plush toy is joyfully wandering through a lush, dense jungle, surrounded by vibrant green foliage and towering trees.",
|
124 |
+
"A bright yellow alarm clock sits on a wooden desk next to a stack of books in a cozy, sunlit room.",
|
125 |
+
"A Lego figure dressed in a vibrant chicken costume, leaning against a wooden chair, surrounded by lush green grass and blooming flowers.",
|
126 |
+
"The crocheted gingerbread man is perched on a tree branch in a dense forest, with sunlight filtering through the leaves, casting dappled shadows around him."
|
127 |
+
]
|
128 |
+
|
129 |
+
IMG_GEN = [
|
130 |
+
## pos-aware: precise mask
|
131 |
+
"assets/gradio/pos_aware/001/img_gen.png",
|
132 |
+
"assets/gradio/pos_aware/002/img_gen.png",
|
133 |
+
## pos-aware: User-drawn mask
|
134 |
+
"assets/gradio/pos_aware/003/img_gen.png",
|
135 |
+
"assets/gradio/pos_aware/004/img_gen.png",
|
136 |
+
"assets/gradio/pos_aware/005/img_gen.png",
|
137 |
+
## pos-free
|
138 |
+
"assets/gradio/pos_free/001/img_gen.png",
|
139 |
+
"assets/gradio/pos_free/002/img_gen.png",
|
140 |
+
"assets/gradio/pos_free/003/img_gen.png",
|
141 |
+
"assets/gradio/pos_free/004/img_gen.png",
|
142 |
+
]
|
143 |
+
|
144 |
+
SEED = [
|
145 |
+
## pos-aware
|
146 |
+
97175498,
|
147 |
+
2126677963,
|
148 |
+
346969695,
|
149 |
+
1172525388,
|
150 |
+
268683460,
|
151 |
+
## pos-free
|
152 |
+
2126677963,
|
153 |
+
418898253,
|
154 |
+
2126677963,
|
155 |
+
2126677963
|
156 |
+
]
|
157 |
+
|
158 |
+
TRUE_GS = [
|
159 |
+
# pos-aware
|
160 |
+
1,
|
161 |
+
1,
|
162 |
+
1,
|
163 |
+
1,
|
164 |
+
1,
|
165 |
+
# pos-free
|
166 |
+
3,
|
167 |
+
3,
|
168 |
+
3,
|
169 |
+
3,
|
170 |
+
]
|
171 |
+
|
172 |
+
NUM_STEPS = [
|
173 |
+
## pos-aware
|
174 |
+
32,
|
175 |
+
32,
|
176 |
+
32,
|
177 |
+
32,
|
178 |
+
32,
|
179 |
+
## pos-free
|
180 |
+
20,
|
181 |
+
20,
|
182 |
+
20,
|
183 |
+
20,
|
184 |
+
]
|
185 |
+
|
186 |
+
GUIDANCE = [
|
187 |
+
## pos-aware
|
188 |
+
40,
|
189 |
+
48,
|
190 |
+
40,
|
191 |
+
48,
|
192 |
+
48,
|
193 |
+
## pos-free
|
194 |
+
40,
|
195 |
+
40,
|
196 |
+
40,
|
197 |
+
40,
|
198 |
+
]
|
199 |
+
|
200 |
+
GRADIO_EXAMPLES = [
|
201 |
+
[IMG_REF[0], IMG_TGT1[0], IMG_TGT2[0], CUSTOM_MODE[0], INPUT_MASK_MODE[0], SEG_REF_MODE[0], PROMPTS[0], SEED[0], TRUE_GS[0], '0', NUM_STEPS[0], GUIDANCE[0]],
|
202 |
+
[IMG_REF[1], IMG_TGT1[1], IMG_TGT2[1], CUSTOM_MODE[1], INPUT_MASK_MODE[1], SEG_REF_MODE[1], PROMPTS[1], SEED[1], TRUE_GS[1], '1', NUM_STEPS[1], GUIDANCE[1]],
|
203 |
+
[IMG_REF[2], IMG_TGT1[2], IMG_TGT2[2], CUSTOM_MODE[2], INPUT_MASK_MODE[2], SEG_REF_MODE[2], PROMPTS[2], SEED[2], TRUE_GS[2], '2', NUM_STEPS[2], GUIDANCE[2]],
|
204 |
+
[IMG_REF[3], IMG_TGT1[3], IMG_TGT2[3], CUSTOM_MODE[3], INPUT_MASK_MODE[3], SEG_REF_MODE[3], PROMPTS[3], SEED[3], TRUE_GS[3], '3', NUM_STEPS[3], GUIDANCE[3]],
|
205 |
+
[IMG_REF[4], IMG_TGT1[4], IMG_TGT2[4], CUSTOM_MODE[4], INPUT_MASK_MODE[4], SEG_REF_MODE[4], PROMPTS[4], SEED[4], TRUE_GS[4], '4', NUM_STEPS[4], GUIDANCE[4]],
|
206 |
+
[IMG_REF[5], IMG_TGT1[5], IMG_TGT2[5], CUSTOM_MODE[5], INPUT_MASK_MODE[5], SEG_REF_MODE[5], PROMPTS[5], SEED[5], TRUE_GS[5], '5', NUM_STEPS[5], GUIDANCE[5]],
|
207 |
+
[IMG_REF[6], IMG_TGT1[6], IMG_TGT2[6], CUSTOM_MODE[6], INPUT_MASK_MODE[6], SEG_REF_MODE[6], PROMPTS[6], SEED[6], TRUE_GS[6], '6', NUM_STEPS[6], GUIDANCE[6]],
|
208 |
+
[IMG_REF[7], IMG_TGT1[7], IMG_TGT2[7], CUSTOM_MODE[7], INPUT_MASK_MODE[7], SEG_REF_MODE[7], PROMPTS[7], SEED[7], TRUE_GS[7], '7', NUM_STEPS[7], GUIDANCE[7]],
|
209 |
+
[IMG_REF[8], IMG_TGT1[8], IMG_TGT2[8], CUSTOM_MODE[8], INPUT_MASK_MODE[8], SEG_REF_MODE[8], PROMPTS[8], SEED[8], TRUE_GS[8], '8', NUM_STEPS[8], GUIDANCE[8]],
|
210 |
+
]
|
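These parallel lists are combined row-wise into `GRADIO_EXAMPLES`, where each row must match the order of the UI components it is bound to. A hedged sketch of how such a table is typically attached with `gr.Examples`; the two components below are placeholders, and the real binding lives in the UI code rather than in this file:
```python
import gradio as gr
from examples import GRADIO_EXAMPLES

with gr.Blocks() as demo:
    img_ref = gr.Image(label="Reference image", type="filepath")
    prompt = gr.Textbox(label="Prompt")
    # Bind only two of the twelve columns in this toy example.
    gr.Examples(
        examples=[[row[0], row[6]] for row in GRADIO_EXAMPLES],
        inputs=[img_ref, prompt],
    )
```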
app/metainfo.py
ADDED
@@ -0,0 +1,131 @@
1 |
+
#### metainfo ####
|
2 |
+
head = r"""
|
3 |
+
<div class="elegant-header">
|
4 |
+
<div class="header-content">
|
5 |
+
<!-- Main title -->
|
6 |
+
<h1 class="main-title">
|
7 |
+
<span class="title-icon">🎨</span>
|
8 |
+
<span class="title-text">IC-Custom</span>
|
9 |
+
</h1>
|
10 |
+
|
11 |
+
<!-- Subtitle -->
|
12 |
+
<p class="subtitle">Transform your images with AI-powered customization</p>
|
13 |
+
|
14 |
+
<!-- Action badges -->
|
15 |
+
<div class="header-badges">
|
16 |
+
<a href="https://liyaowei-stu.github.io/project/IC_Custom/" class="badge-link">
|
17 |
+
<span class="badge-icon">🔗</span>
|
18 |
+
<span class="badge-text">Project</span>
|
19 |
+
</a>
|
20 |
+
<a href="https://arxiv.org/abs/2507.01926" class="badge-link">
|
21 |
+
<span class="badge-icon">📄</span>
|
22 |
+
<span class="badge-text">Paper</span>
|
23 |
+
</a>
|
24 |
+
<a href="https://github.com/TencentARC/IC-Custom" class="badge-link">
|
25 |
+
<span class="badge-icon">💻</span>
|
26 |
+
<span class="badge-text">Code</span>
|
27 |
+
</a>
|
28 |
+
</div>
|
29 |
+
</div>
|
30 |
+
</div>
|
31 |
+
"""
|
32 |
+
|
33 |
+
|
34 |
+
getting_started = r"""
|
35 |
+
<div class="getting-started-container">
|
36 |
+
<!-- Header -->
|
37 |
+
<div class="guide-header">
|
38 |
+
<h3 class="guide-title">🚀 Quick Start Guide</h3>
|
39 |
+
<p class="guide-subtitle">Follow these steps to customize your images with IC-Custom</p>
|
40 |
+
</div>
|
41 |
+
|
42 |
+
<!-- What is IC-Custom -->
|
43 |
+
<div class="info-card">
|
44 |
+
<div class="info-content">
|
45 |
+
<strong class="brand-name">IC-Custom</strong> offers two customization modes:
|
46 |
+
<span class="mode-badge position-aware">Position-aware</span>
|
47 |
+
(precise placement in masked areas) and
|
48 |
+
<span class="mode-badge position-free">Position-free</span>
|
49 |
+
(subject-driven generation).
|
50 |
+
</div>
|
51 |
+
</div>
|
52 |
+
|
53 |
+
<!-- Common Steps -->
|
54 |
+
<div class="step-card common-steps">
|
55 |
+
<div class="step-header">
|
56 |
+
<span class="step-number">1</span>
|
57 |
+
Initial Setup (Both Modes)
|
58 |
+
</div>
|
59 |
+
<ul class="step-list">
|
60 |
+
<li>Choose your <strong>customization mode</strong></li>
|
61 |
+
<li>Upload a <strong>reference image</strong> 📸</li>
|
62 |
+
</ul>
|
63 |
+
</div>
|
64 |
+
|
65 |
+
<!-- Position-aware Mode -->
|
66 |
+
<div class="step-card position-aware-steps">
|
67 |
+
<div class="step-header">
|
68 |
+
<span class="step-number">2A</span>
|
69 |
+
🎯 Position-aware Mode Steps
|
70 |
+
</div>
|
71 |
+
<ul class="step-list">
|
72 |
+
<li>Select <strong>input mask mode</strong> (precise mask or user-drawn mask)</li>
|
73 |
+
<li>Upload <strong>target image</strong> and create mask (click for SAM or brush directly)</li>
|
74 |
+
<li>Add <strong>text prompt</strong> (optional) - use VLM buttons for auto-generation</li>
|
75 |
+
<li>Review and refine your <strong>mask</strong> using mask tools if needed</li>
|
76 |
+
<li>Click <span class="run-button position-aware">Run</span> ✨</li>
|
77 |
+
</ul>
|
78 |
+
</div>
|
79 |
+
|
80 |
+
<!-- Position-free Mode -->
|
81 |
+
<div class="step-card position-free-steps">
|
82 |
+
<div class="step-header">
|
83 |
+
<span class="step-number">2B</span>
|
84 |
+
🎨 Position-free Mode Steps
|
85 |
+
</div>
|
86 |
+
<ul class="step-list">
|
87 |
+
<li>Write your <strong>text prompt</strong> (required) - describe the target scene</li>
|
88 |
+
<li>Use VLM buttons for prompt auto-generation or polishing (if enabled)</li>
|
89 |
+
<li>Click <span class="run-button position-free">Run</span> ✨</li>
|
90 |
+
</ul>
|
91 |
+
</div>
|
92 |
+
|
93 |
+
<!-- Quick Tips -->
|
94 |
+
<div class="tips-card">
|
95 |
+
<div class="tips-content">
|
96 |
+
<strong>💡 Quick Tips:</strong>
|
97 |
+
Use <kbd class="key-hint">Alt + "-"</kbd> or <kbd class="key-hint">⌘ + "-"</kbd> to zoom out for better operation •
|
98 |
+
Adjust settings in <kbd class="key-hint">Advanced Options</kbd> • Use mask operations (<kbd class="key-hint">dilate</kbd>/<kbd class="key-hint">erode</kbd>/<kbd class="key-hint">bbox</kbd>) for better results •
|
99 |
+
Try different <kbd class="key-hint">seeds</kbd> for varied outputs
|
100 |
+
</div>
|
101 |
+
</div>
|
102 |
+
|
103 |
+
<!-- Final Message -->
|
104 |
+
<div class="final-message">
|
105 |
+
<div class="final-text">
|
106 |
+
🎉 Ready to start? Collapse this guide and begin customizing!
|
107 |
+
</div>
|
108 |
+
</div>
|
109 |
+
</div>
|
110 |
+
"""
|
111 |
+
|
112 |
+
|
113 |
+
citation = r"""
|
114 |
+
If IC-Custom is helpful, please consider giving the <a href='https://github.com/TencentARC/IC-Custom' target='_blank'>Github Repo</a> a ⭐. Thanks!
|
115 |
+
[](https://github.com/TencentARC/IC-Custom)
|
116 |
+
---
|
117 |
+
📝 **Citation**
|
118 |
+
<br>
|
119 |
+
If our work is useful for your research, please consider citing:
|
120 |
+
```bibtex
|
121 |
+
@article{li2025ic,
|
122 |
+
title={IC-Custom: Diverse Image Customization via In-Context Learning},
|
123 |
+
author={Li, Yaowei and Li, Xiaoyu and Zhang, Zhaoyang and Bian, Yuxuan and Liu, Gan and Li, Xinyuan and Xu, Jiale and Hu, Wenbo and Liu, Yating and Li, Lingen and others},
|
124 |
+
journal={arXiv preprint arXiv:2507.01926},
|
125 |
+
year={2025}
|
126 |
+
}
|
127 |
+
```
|
128 |
+
📧 **Contact**
|
129 |
+
<br>
|
130 |
+
If you have any questions, please feel free to reach out to me at <b>[email protected]</b>.
|
131 |
+
"""
|
app/stylesheets.py
ADDED
@@ -0,0 +1,1679 @@
1 |
+
#!/usr/bin/python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
"""
|
4 |
+
Centralized CSS and JS for the IC-Custom app UI.
|
5 |
+
|
6 |
+
Expose helpers:
|
7 |
+
- get_css(): return a single CSS string for gradio Blocks(css=...)
|
8 |
+
- get_js(): return a JS snippet for gradio.
|
9 |
+
"""
|
10 |
+
|
11 |
+
|
12 |
+
def get_css() -> str:
|
13 |
+
return r"""
|
14 |
+
/* Global Optimization Effects - No Layout Changes */
|
15 |
+
|
16 |
+
/* Apple-style segmented control for radio buttons */
|
17 |
+
#customization_mode_radio .wrap, #input_mask_mode_radio .wrap, #seg_ref_mode_radio .wrap {
|
18 |
+
display: flex;
|
19 |
+
flex-wrap: nowrap;
|
20 |
+
justify-content: center;
|
21 |
+
align-items: center;
|
22 |
+
gap: 0;
|
23 |
+
background: rgba(255, 255, 255, 0.8);
|
24 |
+
border: 1px solid var(--neutral-200);
|
25 |
+
border-radius: 10px;
|
26 |
+
padding: 3px;
|
27 |
+
backdrop-filter: blur(12px);
|
28 |
+
-webkit-backdrop-filter: blur(12px);
|
29 |
+
box-shadow: 0 2px 8px rgba(15, 23, 42, 0.08);
|
30 |
+
}
|
31 |
+
|
32 |
+
#customization_mode_radio .wrap label, #input_mask_mode_radio .wrap label, #seg_ref_mode_radio .wrap label {
|
33 |
+
display: flex;
|
34 |
+
flex: 1;
|
35 |
+
justify-content: center;
|
36 |
+
align-items: center;
|
37 |
+
margin: 0;
|
38 |
+
padding: 10px 16px;
|
39 |
+
box-sizing: border-box;
|
40 |
+
border-radius: 7px;
|
41 |
+
transition: all 0.3s cubic-bezier(0.4, 0, 0.2, 1);
|
42 |
+
background: transparent;
|
43 |
+
border: none;
|
44 |
+
font-weight: 500;
|
45 |
+
font-size: 0.9rem;
|
46 |
+
color: var(--text-secondary);
|
47 |
+
cursor: pointer;
|
48 |
+
position: relative;
|
49 |
+
white-space: nowrap;
|
50 |
+
min-width: 0;
|
51 |
+
}
|
52 |
+
|
53 |
+
/* Hide the actual radio input */
|
54 |
+
#customization_mode_radio .wrap label input[type="radio"],
|
55 |
+
#input_mask_mode_radio .wrap label input[type="radio"],
|
56 |
+
#seg_ref_mode_radio .wrap label input[type="radio"] {
|
57 |
+
display: none;
|
58 |
+
}
|
59 |
+
|
60 |
+
/* Hover states */
|
61 |
+
#customization_mode_radio .wrap label:hover,
|
62 |
+
#input_mask_mode_radio .wrap label:hover,
|
63 |
+
#seg_ref_mode_radio .wrap label:hover {
|
64 |
+
background: rgba(14, 165, 233, 0.1);
|
65 |
+
color: var(--primary-blue);
|
66 |
+
}
|
67 |
+
|
68 |
+
/* Selected state with smooth background */
|
69 |
+
#customization_mode_radio .wrap label:has(input[type="radio"]:checked),
|
70 |
+
#input_mask_mode_radio .wrap label:has(input[type="radio"]:checked),
|
71 |
+
#seg_ref_mode_radio .wrap label:has(input[type="radio"]:checked) {
|
72 |
+
background: var(--primary-blue);
|
73 |
+
color: white;
|
74 |
+
font-weight: 600;
|
75 |
+
box-shadow: 0 2px 6px rgba(14, 165, 233, 0.25);
|
76 |
+
transform: none;
|
77 |
+
}
|
78 |
+
|
79 |
+
/* Fallback for browsers that don't support :has() */
|
80 |
+
#customization_mode_radio .wrap label input[type="radio"]:checked + *,
|
81 |
+
#input_mask_mode_radio .wrap label input[type="radio"]:checked + *,
|
82 |
+
#seg_ref_mode_radio .wrap label input[type="radio"]:checked + * {
|
83 |
+
color: white;
|
84 |
+
}
|
85 |
+
|
86 |
+
#customization_mode_radio .wrap:has(input[type="radio"]:checked) label:has(input[type="radio"]:checked),
|
87 |
+
#input_mask_mode_radio .wrap:has(input[type="radio"]:checked) label:has(input[type="radio"]:checked),
|
88 |
+
#seg_ref_mode_radio .wrap:has(input[type="radio"]:checked) label:has(input[type="radio"]:checked) {
|
89 |
+
background: var(--primary-blue);
|
90 |
+
}
|
91 |
+
|
92 |
+
/* Active state */
|
93 |
+
#customization_mode_radio .wrap label:active,
|
94 |
+
#input_mask_mode_radio .wrap label:active,
|
95 |
+
#seg_ref_mode_radio .wrap label:active {
|
96 |
+
transform: scale(0.98);
|
97 |
+
}
|
98 |
+
|
99 |
+
/* Elegant header styling */
.elegant-header {
    text-align: center;
    margin: 0 0 2rem 0;
    padding: 0;
}

.header-content {
    display: inline-block;
    padding: 1.8rem 2.5rem;
    background: linear-gradient(135deg,
        rgba(255, 255, 255, 0.1) 0%,
        rgba(255, 255, 255, 0.05) 100%);
    border: 1px solid rgba(255, 255, 255, 0.15);
    border-radius: 20px;
    backdrop-filter: blur(12px);
    -webkit-backdrop-filter: blur(12px);
    box-shadow:
        0 8px 32px rgba(15, 23, 42, 0.04),
        inset 0 1px 0 rgba(255, 255, 255, 0.2);
    transition: all 0.4s ease;
    position: relative;
    overflow: hidden;
}

.header-content::before {
    content: '';
    position: absolute;
    top: 0;
    left: -100%;
    width: 100%;
    height: 100%;
    background: linear-gradient(90deg,
        transparent,
        rgba(255, 255, 255, 0.1),
        transparent);
    transition: left 0.6s ease;
}

.header-content:hover::before {
    left: 100%;
}

.header-content:hover {
    transform: translateY(-2px);
    box-shadow:
        0 12px 40px rgba(15, 23, 42, 0.08),
        inset 0 1px 0 rgba(255, 255, 255, 0.3);
    border-color: rgba(14, 165, 233, 0.2);
}

/* Main title styling */
.main-title {
    margin: 0 0 0.8rem 0;
    font-size: 2.4rem;
    font-weight: 800;
    display: flex;
    align-items: center;
    justify-content: center;
    gap: 0.5rem;
}

.title-icon {
    font-size: 2.2rem;
    filter: drop-shadow(0 2px 4px rgba(0, 0, 0, 0.1));
}

.title-text {
    background: linear-gradient(135deg, #0ea5e9 0%, #06b6d4 50%, #10b981 100%);
    -webkit-background-clip: text;
    -webkit-text-fill-color: transparent;
    background-clip: text;
    text-shadow: none;
    position: relative;
}

/* Subtitle styling */
.subtitle {
    margin: 0 0 1.2rem 0;
    font-size: 1rem;
    color: #64748b;
    font-weight: 500;
    letter-spacing: 0.025em;
    opacity: 0.9;
}

/* Header badges container */
.header-badges {
    display: flex;
    justify-content: center;
    gap: 0.8rem;
    flex-wrap: wrap;
}

/* Individual badge links */
.badge-link {
    display: inline-flex;
    align-items: center;
    gap: 0.4rem;
    padding: 0.5rem 1rem;
    background: rgba(255, 255, 255, 0.15);
    border: 1px solid rgba(255, 255, 255, 0.2);
    border-radius: 12px;
    color: #475569;
    text-decoration: none;
    font-weight: 500;
    font-size: 0.9rem;
    transition: all 0.3s ease;
    backdrop-filter: blur(4px);
    -webkit-backdrop-filter: blur(4px);
    position: relative;
    overflow: hidden;
}

.badge-link::before {
    content: '';
    position: absolute;
    top: 0;
    left: -100%;
    width: 100%;
    height: 100%;
    background: var(--primary-gradient);
    transition: left 0.3s ease;
    z-index: -1;
}

.badge-link:hover::before {
    left: 0;
}

.badge-link:hover {
    transform: translateY(-2px);
    color: white;
    border-color: transparent;
    box-shadow: 0 4px 12px rgba(14, 165, 233, 0.3);
}

.badge-icon {
    font-size: 1rem;
    opacity: 0.8;
}

.badge-text {
    font-weight: 600;
}

/* Getting Started Guide Styling */
.getting-started-container {
    padding: 1.5rem;
    background: linear-gradient(135deg, rgba(255, 255, 255, 0.1), rgba(255, 255, 255, 0.05));
    border-radius: 12px;
    border: 1px solid rgba(255, 255, 255, 0.15);
    backdrop-filter: blur(8px);
    -webkit-backdrop-filter: blur(8px);
    box-shadow: 0 4px 16px rgba(15, 23, 42, 0.06);
}

.guide-header {
    text-align: center;
    margin-bottom: 1.5rem;
}

.guide-title {
    color: var(--text-primary);
    margin: 0 0 0.5rem 0;
    font-size: 1.2rem;
    font-weight: 700;
}

.guide-subtitle {
    color: var(--text-muted);
    margin: 0;
    font-size: 0.9rem;
    opacity: 0.9;
}

/* Info card */
.info-card {
    background: rgba(255, 255, 255, 0.4);
    border-radius: 8px;
    padding: 1rem;
    margin-bottom: 1.2rem;
    border-left: 3px solid var(--primary-blue);
    backdrop-filter: blur(4px);
    -webkit-backdrop-filter: blur(4px);
    transition: all 0.3s ease;
}

.info-card:hover {
    background: rgba(255, 255, 255, 0.5);
    transform: translateX(2px);
}

.info-content {
    color: var(--text-secondary);
    font-size: 0.9rem;
    line-height: 1.5;
}

.brand-name {
    color: var(--primary-blue);
    font-weight: 700;
}

/* Mode badges */
.mode-badge {
    padding: 0.2rem 0.4rem;
    border-radius: 4px;
    font-size: 0.8rem;
    font-weight: 600;
    margin: 0 0.2rem;
    transition: all 0.2s ease;
}

.mode-badge.position-aware {
    background: var(--badge-blue-bg);
    color: var(--badge-blue-text);
}

.mode-badge.position-free {
    background: var(--badge-green-bg);
    color: var(--badge-green-text);
}

/* Step cards */
.step-card {
    background: rgba(255, 255, 255, 0.4);
    border-radius: 8px;
    padding: 1rem;
    margin-bottom: 1.2rem;
    backdrop-filter: blur(4px);
    -webkit-backdrop-filter: blur(4px);
    transition: all 0.3s ease;
    position: relative;
    overflow: hidden;
}

.step-card::before {
    content: '';
    position: absolute;
    left: 0;
    top: 0;
    bottom: 0;
    width: 3px;
    transition: all 0.3s ease;
}

.step-card.common-steps::before {
    background: var(--neutral-500);
}

.step-card.position-aware-steps::before {
    background: var(--position-aware-blue);
}

.step-card.position-free-steps::before {
    background: var(--position-free-purple);
}

.step-card:hover {
    background: rgba(255, 255, 255, 0.5);
    transform: translateX(2px);
}

.step-header {
    font-weight: 600;
    color: var(--text-primary);
    margin-bottom: 0.75rem;
    display: flex;
    align-items: center;
    font-size: 0.95rem;
}

.step-number {
    color: white;
    border-radius: 50%;
    width: 24px;
    height: 24px;
    display: inline-flex;
    align-items: center;
    justify-content: center;
    margin-right: 0.5rem;
    font-size: 0.75rem;
    font-weight: 700;
}

.common-steps .step-number {
    background: var(--neutral-500);
}

.position-aware-steps .step-number {
    background: var(--position-aware-blue);
}

.position-free-steps .step-number {
    background: var(--position-free-purple);
}

.step-list {
    margin: 0;
    padding-left: 1.2rem;
    font-size: 0.85rem;
    color: var(--text-secondary);
    line-height: 1.6;
}

.step-list li {
    margin-bottom: 0.4rem;
    position: relative;
}

.step-list li:last-child {
    margin-bottom: 0;
}

/* Run buttons */
.run-button {
    padding: 0.2rem 0.5rem;
    border-radius: 4px;
    font-weight: 600;
    font-size: 0.8rem;
    color: white;
    transition: all 0.2s ease;
}

.run-button.position-aware {
    background: var(--position-aware-blue);
}

.run-button.position-free {
    background: var(--position-free-purple);
}

/* Tips card */
.tips-card {
    background: rgba(241, 245, 249, 0.6);
    border-radius: 8px;
    padding: 0.8rem;
    border-left: 3px solid var(--neutral-400);
    margin-bottom: 1rem;
    backdrop-filter: blur(4px);
    -webkit-backdrop-filter: blur(4px);
    transition: all 0.3s ease;
}

.tips-card:hover {
    background: rgba(241, 245, 249, 0.8);
    transform: translateX(2px);
}

.tips-content {
    font-size: 0.8rem;
    color: var(--text-tips);
    line-height: 1.5;
}

/* Key hints */
.key-hint {
    background: var(--kbd-bg);
    color: var(--kbd-text);
    padding: 0.1rem 0.3rem;
    border-radius: 3px;
    font-size: 0.75em;
    border: 1px solid var(--kbd-border);
    font-family: monospace;
    font-weight: 500;
    transition: all 0.2s ease;
}

.key-hint:hover {
    background: var(--primary-blue);
    color: white;
    border-color: var(--primary-blue);
}

/* Final message */
.final-message {
    padding: 0.8rem;
    background: var(--bg-final);
    border-radius: 8px;
    text-align: center;
    transition: all 0.3s ease;
}

.final-message:hover {
    transform: translateY(-1px);
    box-shadow: 0 4px 12px rgba(14, 165, 233, 0.1);
}

.final-text {
    color: var(--text-final);
    font-weight: 600;
    font-size: 0.85rem;
}

/* Legacy header badge styling for backward compatibility */
.header-badge {
    background: var(--primary-gradient);
    color: white;
    padding: 0.5rem 1rem;
    border-radius: 6px;
    font-weight: 500;
    font-size: 0.9rem;
    transition: all 0.3s ease;
    box-shadow: 0 1px 3px rgba(14, 165, 233, 0.2);
    display: inline-block;
}

.header-badge:hover {
    transform: translateY(-2px);
    box-shadow: 0 4px 8px rgba(14, 165, 233, 0.3);
    text-decoration: none;
}

/* Accordion styling matching getting_started */
.gradio-accordion {
    border: 1px solid rgba(14, 165, 233, 0.2);
    border-radius: 8px;
    overflow: visible !important; /* Allow dropdown to overflow */
    box-shadow: 0 2px 8px rgba(14, 165, 233, 0.1);
    background: linear-gradient(135deg, rgba(255, 255, 255, 0.98) 0%, rgba(240, 249, 255, 0.95) 100%);
}

/* Ensure accordion content area allows dropdown overflow */
.gradio-accordion .wrap {
    overflow: visible !important;
}

.gradio-accordion > .label-wrap {
    background: linear-gradient(135deg, rgba(255, 255, 255, 0.98) 0%, rgba(240, 249, 255, 0.95) 100%);
    border-bottom: 1px solid rgba(14, 165, 233, 0.2);
    padding: 1rem 1.5rem;
    font-weight: 600;
    color: var(--text-primary);
}

/* Minimal dropdown styling - let Gradio handle positioning naturally */
#aspect_ratio_dropdown {
    border-radius: 8px;
}

/* COMPLETELY REMOVE all dropdown styling - let Gradio handle everything */
/* This was causing the dropdown to display as a text block instead of options */

/* DO NOT style .gradio-dropdown globally - causes functionality issues */

/* Slider styling matching theme */
.gradio-slider {
    background: linear-gradient(135deg, rgba(255, 255, 255, 0.98) 0%, rgba(240, 249, 255, 0.95) 100%) !important;
    border: 1px solid rgba(14, 165, 233, 0.2) !important;
    border-radius: 8px !important;
    box-shadow: 0 2px 8px rgba(14, 165, 233, 0.1) !important;
    padding: 12px !important;
}

.gradio-slider:hover {
    border-color: rgba(14, 165, 233, 0.3) !important;
    box-shadow: 0 2px 8px rgba(14, 165, 233, 0.15) !important;
}

/* Slider input styling */
.gradio-slider input[type="range"] {
    background: transparent !important;
}

.gradio-slider input[type="range"]::-webkit-slider-track {
    background: rgba(14, 165, 233, 0.2) !important;
    border-radius: 4px !important;
}

.gradio-slider input[type="range"]::-webkit-slider-thumb {
    background: var(--primary-blue) !important;
    border: 2px solid white !important;
    box-shadow: 0 2px 8px rgba(14, 165, 233, 0.3) !important;
}

/* Checkbox styling matching theme */
.gradio-checkbox {
    background: linear-gradient(135deg, rgba(255, 255, 255, 0.98) 0%, rgba(240, 249, 255, 0.95) 100%) !important;
    border: 1px solid rgba(14, 165, 233, 0.2) !important;
    border-radius: 8px !important;
    box-shadow: 0 2px 8px rgba(14, 165, 233, 0.1) !important;
    padding: 8px 12px !important;
}

/* Specific styling for identified components */
#aspect_ratio_dropdown,
#text_prompt,
#move_to_center_checkbox,
#use_bg_preservation_checkbox {
    background: linear-gradient(135deg, rgba(255, 255, 255, 0.98) 0%, rgba(240, 249, 255, 0.95) 100%) !important;
    border: 1px solid rgba(14, 165, 233, 0.2) !important;
    border-radius: 8px !important;
    box-shadow: 0 2px 8px rgba(14, 165, 233, 0.1) !important;
}

/* Removed specific aspect_ratio_dropdown styling to avoid conflicts */

#aspect_ratio_dropdown:hover,
#text_prompt:hover,
#move_to_center_checkbox:hover,
#use_bg_preservation_checkbox:hover {
    border-color: rgba(14, 165, 233, 0.3) !important;
    box-shadow: 0 2px 8px rgba(14, 165, 233, 0.15) !important;
}

/* Textbox specific styling */
#text_prompt textarea {
    background: linear-gradient(135deg, rgba(255, 255, 255, 0.98) 0%, rgba(240, 249, 255, 0.95) 100%) !important;
    border: 1px solid rgba(14, 165, 233, 0.2) !important;
    border-radius: 8px !important;
    box-shadow: 0 2px 8px rgba(14, 165, 233, 0.1) !important;
}

/* Color variables matching getting_started section exactly */
:root {
    /* Primary colors from getting_started */
    --primary-blue: #0ea5e9;
    --primary-blue-secondary: #06b6d4;
    --primary-green: #10b981;
    --primary-gradient: linear-gradient(135deg, #0ea5e9 0%, #06b6d4 50%, #10b981 100%);

    /* Mode-specific colors from getting_started */
    --position-aware-blue: #3b82f6;
    --position-free-purple: #8b5cf6;

    /* Badge colors from getting_started */
    --badge-blue-bg: #dbeafe;
    --badge-blue-text: #1e40af;
    --badge-green-bg: #dcfce7;
    --badge-green-text: #166534;

    /* Neutral colors from getting_started */
    --neutral-50: #f8fafc;
    --neutral-100: #f1f5f9;
    --neutral-200: #e2e8f0;
    --neutral-300: #cbd5e1;
    --neutral-400: #94a3b8;
    --neutral-500: #64748b;
    --neutral-600: #475569;
    --neutral-700: #334155;
    --neutral-800: #1e293b;

    /* Text colors from getting_started */
    --text-primary: #1e293b;
    --text-secondary: #4b5563;
    --text-muted: #64748b;
    --text-tips: #475569;
    --text-final: #0c4a6e;

    /* Background colors from getting_started */
    --bg-primary: white;
    --bg-secondary: #f8fafc;
    --bg-tips: #f1f5f9;
    --bg-final: linear-gradient(135deg, #f0f9ff 0%, #e0f2fe 100%);

    /* Keyboard hint styles from getting_started */
    --kbd-bg: #e2e8f0;
    --kbd-text: #475569;
    --kbd-border: #cbd5e1;
}

/* Global smooth transitions - exclude dropdowns */
*:not(.gradio-dropdown):not(.gradio-dropdown *) {
    transition: all 0.2s ease;
}

/* Focus states using getting_started primary blue - exclude dropdowns */
button:focus,
input:not(.gradio-dropdown input):focus,
select:not(.gradio-dropdown select):focus,
textarea:not(.gradio-dropdown textarea):focus {
    outline: none;
    box-shadow: 0 0 0 2px rgba(14, 165, 233, 0.3);
}

/* Subtle hover effects for interactive elements - exclude dropdowns */
button:not(.gradio-dropdown button):hover {
    transform: translateY(-1px);
}

/* Global text styling matching getting_started */
body {
    font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
    line-height: 1.6;
    color: var(--text-secondary);
    background-color: var(--bg-secondary);
}

/* Enhanced form element styling - exclude dropdowns from global styling */
input:not(.gradio-dropdown input),
textarea:not(.gradio-dropdown textarea),
select:not(.gradio-dropdown select) {
    border-radius: 8px;
    border: 1px solid rgba(14, 165, 233, 0.2);
    background: linear-gradient(135deg, rgba(255, 255, 255, 0.98) 0%, rgba(240, 249, 255, 0.95) 100%);
    backdrop-filter: blur(8px);
    -webkit-backdrop-filter: blur(8px);
    color: var(--text-primary);
    transition: all 0.3s ease;
    padding: 12px 16px;
    font-size: 0.95rem;
    box-shadow: 0 2px 8px rgba(14, 165, 233, 0.1);
}

input:not(.gradio-dropdown input):focus,
textarea:not(.gradio-dropdown textarea):focus,
select:not(.gradio-dropdown select):focus {
    border-color: var(--primary-blue);
    box-shadow: 0 0 0 3px rgba(14, 165, 233, 0.1), 0 2px 8px rgba(14, 165, 233, 0.15);
    background: linear-gradient(135deg, rgba(255, 255, 255, 1) 0%, rgba(240, 249, 255, 0.98) 100%);
    outline: none;
    transform: translateY(-1px);
}

input:not(.gradio-dropdown input):hover,
textarea:not(.gradio-dropdown textarea):hover,
select:not(.gradio-dropdown select):hover {
    border-color: rgba(14, 165, 233, 0.3);
    box-shadow: 0 2px 8px rgba(14, 165, 233, 0.12);
}

/* Textbox specific styling */
.gradio-textbox {
    border-radius: 12px;
    overflow: hidden;
}

.gradio-textbox textarea {
    border-radius: 12px;
    resize: vertical;
    min-height: 44px;
}

/* Scrollbar styling matching getting_started */
::-webkit-scrollbar {
    width: 8px;
}

::-webkit-scrollbar-track {
    background: var(--neutral-100);
    border-radius: 4px;
}

::-webkit-scrollbar-thumb {
    background: var(--neutral-400);
    border-radius: 4px;
}

::-webkit-scrollbar-thumb:hover {
    background: var(--primary-blue);
}

/* Enhanced button styling with Apple-style refinement */
button {
    border-radius: 8px;
    font-weight: 500;
    cursor: pointer;
    transition: all 0.3s ease;
    border: 1px solid var(--neutral-200);
    position: relative;
    overflow: hidden;
    backdrop-filter: blur(8px);
    -webkit-backdrop-filter: blur(8px);
}

/* Button hover glow effect */
button::after {
    content: '';
    position: absolute;
    top: 50%;
    left: 50%;
    width: 0;
    height: 0;
    background: radial-gradient(circle, rgba(14, 165, 233, 0.1) 0%, transparent 70%);
    transition: all 0.4s ease;
    transform: translate(-50%, -50%);
    pointer-events: none;
}

button:hover::after {
    width: 200px;
    height: 200px;
}

/* Primary button using unified primary blue */
button[variant="primary"] {
    background: var(--primary-blue);
    border-color: var(--primary-blue);
    color: white;
    box-shadow: 0 2px 6px rgba(14, 165, 233, 0.2);
}

button[variant="primary"]:hover {
    background: #0284c7;
    border-color: #0284c7;
    box-shadow: 0 4px 12px rgba(14, 165, 233, 0.3);
    transform: translateY(-1px);
}

/* Secondary buttons */
button[variant="secondary"], .secondary-button {
    background: linear-gradient(135deg, rgba(255, 255, 255, 0.98) 0%, rgba(240, 249, 255, 0.95) 100%);
    border: 1px solid rgba(14, 165, 233, 0.2);
    color: var(--text-secondary);
    box-shadow: 0 2px 8px rgba(14, 165, 233, 0.1);
}

button[variant="secondary"]:hover, .secondary-button:hover {
    background: var(--primary-blue);
    border-color: var(--primary-blue);
    color: white;
    transform: translateY(-1px);
    box-shadow: 0 4px 12px rgba(14, 165, 233, 0.25);
}

/* VLM buttons with subtle, elegant styling */
#vlm_generate_btn, #vlm_polish_btn {
    background: linear-gradient(135deg, rgba(255, 255, 255, 0.98) 0%, rgba(240, 249, 255, 0.95) 100%) !important;
    border: 1px solid rgba(14, 165, 233, 0.2) !important;
    color: var(--text-secondary) !important;
    box-shadow: 0 2px 8px rgba(14, 165, 233, 0.1);
    font-weight: 500;
    border-radius: 8px;
    position: relative;
    overflow: hidden;
    transition: all 0.3s ease;
    backdrop-filter: blur(8px);
    -webkit-backdrop-filter: blur(8px);
}

#vlm_generate_btn::before, #vlm_polish_btn::before {
    content: '';
    position: absolute;
    top: 0;
    left: -100%;
    width: 100%;
    height: 100%;
    background: linear-gradient(90deg, transparent, rgba(14, 165, 233, 0.1), transparent);
    transition: left 0.5s ease;
}

#vlm_generate_btn:hover::before, #vlm_polish_btn:hover::before {
    left: 100%;
}

#vlm_generate_btn:hover, #vlm_polish_btn:hover {
    background: var(--primary-blue) !important;
    border-color: var(--primary-blue) !important;
    color: white !important;
    box-shadow: 0 4px 12px rgba(14, 165, 233, 0.25);
    transform: translateY(-1px);
}

#vlm_generate_btn:active, #vlm_polish_btn:active {
    transform: translateY(0px);
    box-shadow: 0 2px 6px rgba(14, 165, 233, 0.2);
}

/* Enhanced image styling with fixed dimensions for consistency */
.gradio-image, .gradio-imageeditor {
    height: 300px !important;
    width: 100% !important;
    padding: 0 !important;
    margin: 0 !important;
}

.gradio-image img,
.gradio-imageeditor img {
    height: 300px !important;
    width: 100% !important;
    object-fit: contain !important;
    border-radius: 8px;
    transition: all 0.3s ease;
    box-shadow: 0 2px 8px rgba(15, 23, 42, 0.1);
}

.gradio-image img:hover,
.gradio-imageeditor img:hover {
    transform: scale(1.02);
    box-shadow: 0 4px 16px rgba(15, 23, 42, 0.15);
}

/* Gallery CSS - contained adaptive layout with theme colors */
#mask_gallery, #result_gallery, .custom-gallery {
    overflow: visible !important; /* Allow progress indicator to show */
    position: relative !important;
    width: 100% !important;
    height: auto !important;
    max-height: 75vh !important;
    min-height: 300px !important;
    display: flex !important;
    flex-direction: column !important;
    padding: 6px !important;
    margin: 0 !important;
    border-radius: 12px !important;
    background: linear-gradient(135deg, rgba(255, 255, 255, 0.98) 0%, rgba(240, 249, 255, 0.95) 100%) !important;
    border: 1px solid rgba(14, 165, 233, 0.2) !important;
    box-shadow: 0 2px 8px rgba(14, 165, 233, 0.1) !important;
}

/* Gallery containers with contained but flexible display */
#mask_gallery .gradio-gallery, #result_gallery .gradio-gallery {
    width: 100% !important;
    height: auto !important;
    min-height: 280px !important;
    max-height: 70vh !important;
    padding: 8px !important;
    overflow: auto !important;
    display: flex !important;
    flex-direction: column !important;
    border-radius: 8px !important;
}

/* Only hide specific duplicate elements that cause the problem */
#mask_gallery > div > div:nth-child(n+2),
#result_gallery > div > div:nth-child(n+2) {
    display: none !important;
}

/* Alternative: hide duplicate grid structures only */
#mask_gallery .gradio-gallery:nth-child(n+2),
#result_gallery .gradio-gallery:nth-child(n+2) {
    display: none !important;
}

/* Ensure timing and status elements are NOT hidden by the above rules */
#result_gallery .status,
#result_gallery .timer,
#result_gallery [class*="time"],
#result_gallery [class*="status"],
#result_gallery [class*="duration"],
#result_gallery .gradio-status,
#result_gallery .gradio-timer,
#result_gallery .gradio-info,
#result_gallery [data-testid*="timer"],
#result_gallery [data-testid*="status"] {
    display: block !important;
    visibility: visible !important;
    opacity: 1 !important;
    position: relative !important;
    z-index: 1000 !important;
}

/* Gallery images - contained adaptive display */
#mask_gallery img, #result_gallery img {
    width: 100% !important;
    height: auto !important;
    max-width: 100% !important;
    max-height: 60vh !important;
    object-fit: contain !important;
    border-radius: 8px;
    box-shadow: 0 2px 8px rgba(15, 23, 42, 0.1);
    display: block !important;
    margin: 0 auto !important;
}

/* Main preview image styling - contained but responsive */
#mask_gallery .preview-image, #result_gallery .preview-image {
    width: 100% !important;
    height: auto !important;
    max-width: 100% !important;
    max-height: 55vh !important;
    border-radius: 12px;
    box-shadow: 0 4px 16px rgba(15, 23, 42, 0.15);
    object-fit: contain !important;
    display: block !important;
    margin: 0 auto !important;
}

/* Gallery content wrappers - ensure no height constraints */
#mask_gallery .gradio-gallery > div,
#result_gallery .gradio-gallery > div {
    width: 100% !important;
    height: auto !important;
    min-height: auto !important;
    max-height: none !important;
    overflow: visible !important;
}

/* Gallery image containers - remove any height limits */
#mask_gallery .image-container,
#result_gallery .image-container,
#mask_gallery [data-testid="image"],
#result_gallery [data-testid="image"] {
    width: 100% !important;
    height: auto !important;
    max-height: none !important;
    overflow: visible !important;
}

/* Controlled gallery wrapper elements */
#mask_gallery .image-wrapper,
#result_gallery .image-wrapper {
    max-height: 60vh !important;
    overflow: hidden !important;
}

/* Specific targeting for Gradio's internal gallery elements */
#mask_gallery .grid-wrap,
#result_gallery .grid-wrap,
#mask_gallery .preview-wrap,
#result_gallery .preview-wrap {
    height: auto !important;
    max-height: 65vh !important;
    overflow: auto !important;
    border-radius: 8px !important;
}

/* Ensure gallery grids are properly sized within container */
#mask_gallery .grid,
#result_gallery .grid {
    height: auto !important;
    max-height: 60vh !important;
    display: grid !important;
    grid-template-columns: repeat(auto-fit, minmax(120px, 1fr)) !important;
    gap: 8px !important;
    align-items: start !important;
    overflow: auto !important;
    padding: 4px !important;
}

/* Custom scrollbar for gallery */
#mask_gallery .gradio-gallery::-webkit-scrollbar,
#result_gallery .gradio-gallery::-webkit-scrollbar,
#mask_gallery .grid::-webkit-scrollbar,
#result_gallery .grid::-webkit-scrollbar {
    width: 6px;
    height: 6px;
}

#mask_gallery .gradio-gallery::-webkit-scrollbar-track,
#result_gallery .gradio-gallery::-webkit-scrollbar-track,
#mask_gallery .grid::-webkit-scrollbar-track,
#result_gallery .grid::-webkit-scrollbar-track {
    background: rgba(0, 0, 0, 0.1);
    border-radius: 3px;
}

#mask_gallery .gradio-gallery::-webkit-scrollbar-thumb,
#result_gallery .gradio-gallery::-webkit-scrollbar-thumb,
#mask_gallery .grid::-webkit-scrollbar-thumb,
#result_gallery .grid::-webkit-scrollbar-thumb {
    background: var(--neutral-400);
    border-radius: 3px;
}

#mask_gallery .gradio-gallery::-webkit-scrollbar-thumb:hover,
#result_gallery .gradio-gallery::-webkit-scrollbar-thumb:hover,
#mask_gallery .grid::-webkit-scrollbar-thumb:hover,
#result_gallery .grid::-webkit-scrollbar-thumb:hover {
    background: var(--primary-blue);
}

/* Thumbnail navigation styling in preview mode */
#mask_gallery .thumbnail, #result_gallery .thumbnail {
    opacity: 0.7;
    transition: opacity 0.3s ease;
    border-radius: 6px;
}

#mask_gallery .thumbnail:hover, #result_gallery .thumbnail:hover,
#mask_gallery .thumbnail.selected, #result_gallery .thumbnail.selected {
    opacity: 1;
    box-shadow: 0 2px 8px rgba(14, 165, 233, 0.3);
}

/* Improved layout spacing and organization */
#glass_card .gradio-row {
    gap: 16px !important;
    margin-bottom: 6px !important;
}

#glass_card .gradio-column {
    gap: 10px !important;
}

/* Better section spacing with theme colors */
#glass_card .gradio-group {
    margin-bottom: 6px !important;
    padding: 10px !important;
    border-radius: 8px !important;
    background: linear-gradient(135deg, rgba(255, 255, 255, 0.95) 0%, rgba(240, 249, 255, 0.9) 100%) !important;
    border: 1px solid rgba(14, 165, 233, 0.15) !important;
    box-shadow: 0 1px 3px rgba(14, 165, 233, 0.05) !important;
    transition: all 0.3s ease !important;
    overflow: visible !important; /* Allow dropdown to overflow */
}

#glass_card .gradio-group:hover {
    background: linear-gradient(135deg, rgba(255, 255, 255, 0.98) 0%, rgba(240, 249, 255, 0.95) 100%) !important;
    border-color: rgba(14, 165, 233, 0.25) !important;
    box-shadow: 0 2px 8px rgba(14, 165, 233, 0.12) !important;
    /* transform removed to prevent layout shift that hides dropdown */
}

/* Enhanced button styling for improved UX */
button[variant="secondary"] {
    background: linear-gradient(135deg, rgba(255, 255, 255, 0.98) 0%, rgba(240, 249, 255, 0.95) 100%) !important;
    border: 1px solid rgba(14, 165, 233, 0.2) !important;
    color: var(--text-secondary) !important;
    box-shadow: 0 2px 8px rgba(14, 165, 233, 0.1) !important;
    transition: all 0.3s ease !important;
}

button[variant="secondary"]:hover {
    background: var(--primary-blue) !important;
    border-color: var(--primary-blue) !important;
    color: white !important;
    transform: translateY(-1px) !important;
    box-shadow: 0 4px 12px rgba(14, 165, 233, 0.25) !important;
}

/* Markdown header improvements */
.gradio-markdown h1, .gradio-markdown h2, .gradio-markdown h3 {
    color: var(--text-primary) !important;
    font-weight: 600 !important;
    margin-bottom: 8px !important;
    margin-top: 4px !important;
}

/* Radio button container improvements */
#customization_mode_radio, #input_mask_mode_radio, #seg_ref_mode_radio {
    margin-bottom: 0px !important;
    margin-top: 0px !important;
}

/* Reduce space between markdown headers and subsequent components */
.gradio-markdown + .gradio-group {
    margin-top: 1px !important;
}

.gradio-markdown + .gradio-image,
.gradio-markdown + .gradio-imageeditor,
.gradio-markdown + .gradio-textbox,
.gradio-markdown + .gradio-gallery {
    margin-top: 1px !important;
}

/* Specific spacing adjustments for numbered sections */
.gradio-markdown:has(h1), .gradio-markdown:has(h2), .gradio-markdown:has(h3) {
    margin-bottom: 2px !important;
}

/* Remove padding from image and gallery containers */
.gradio-image, .gradio-imageeditor, .gradio-gallery {
    padding: 0 !important;
    margin: 0 !important;
}

/* Image container styling with theme colors */
.gradio-image, .gradio-imageeditor {
    border-radius: 12px;
    overflow: hidden;
    background: linear-gradient(135deg, rgba(255, 255, 255, 0.98) 0%, rgba(240, 249, 255, 0.95) 100%) !important;
    border: 1px solid rgba(14, 165, 233, 0.2) !important;
    box-shadow: 0 2px 8px rgba(14, 165, 233, 0.1) !important;
    transition: all 0.3s ease;
}

.gradio-image:hover, .gradio-imageeditor:hover {
    border-color: var(--primary-blue) !important;
    box-shadow: 0 4px 16px rgba(14, 165, 233, 0.15) !important;
    transform: translateY(-1px);
}

/* Image upload area styling */
.gradio-image .upload-container,
.gradio-imageeditor .upload-container,
.gradio-image > div,
.gradio-imageeditor > div {
    background: linear-gradient(135deg, rgba(255, 255, 255, 0.98) 0%, rgba(240, 249, 255, 0.95) 100%) !important;
    border: 1px solid rgba(14, 165, 233, 0.2) !important;
    border-radius: 12px !important;
}

/* Image upload placeholder styling */
.gradio-image .upload-text,
.gradio-imageeditor .upload-text,
.gradio-image [data-testid="upload-text"],
.gradio-imageeditor [data-testid="upload-text"] {
    background: linear-gradient(135deg, rgba(255, 255, 255, 0.98) 0%, rgba(240, 249, 255, 0.95) 100%) !important;
    color: var(--text-secondary) !important;
}

/* Image preview area */
.gradio-image .image-container,
.gradio-imageeditor .image-container,
.gradio-image .preview-container,
.gradio-imageeditor .preview-container {
    background: linear-gradient(135deg, rgba(255, 255, 255, 0.98) 0%, rgba(240, 249, 255, 0.95) 100%) !important;
    border-radius: 12px !important;
}

/* Specific targeting for image upload areas */
.gradio-image .wrap,
.gradio-imageeditor .wrap,
.gradio-image .block,
.gradio-imageeditor .block {
    background: linear-gradient(135deg, rgba(255, 255, 255, 0.98) 0%, rgba(240, 249, 255, 0.95) 100%) !important;
    border: 1px solid rgba(14, 165, 233, 0.2) !important;
    border-radius: 12px !important;
}

/* Image drop zone styling */
.gradio-image .drop-zone,
.gradio-imageeditor .drop-zone,
.gradio-image .upload-area,
.gradio-imageeditor .upload-area {
    background: linear-gradient(135deg, rgba(255, 255, 255, 0.98) 0%, rgba(240, 249, 255, 0.95) 100%) !important;
    border: 2px dashed rgba(14, 165, 233, 0.3) !important;
    border-radius: 12px !important;
}

/* Force override any white backgrounds in image components */
.gradio-image *,
.gradio-imageeditor * {
    background-color: transparent !important;
}

.gradio-image .gradio-image,
.gradio-imageeditor .gradio-imageeditor {
    background: linear-gradient(135deg, rgba(255, 255, 255, 0.98) 0%, rgba(240, 249, 255, 0.95) 100%) !important;
}

/* Specific styling for Reference Image and Target Images */
#reference_image,
#target_image_1,
#target_image_2 {
    background: linear-gradient(135deg, rgba(255, 255, 255, 0.98) 0%, rgba(240, 249, 255, 0.95) 100%) !important;
    border: 1px solid rgba(14, 165, 233, 0.2) !important;
    box-shadow: 0 2px 8px rgba(14, 165, 233, 0.1) !important;
    border-radius: 12px !important;
}

#reference_image *,
#target_image_1 *,
#target_image_2 * {
    background: linear-gradient(135deg, rgba(255, 255, 255, 0.98) 0%, rgba(240, 249, 255, 0.95) 100%) !important;
    border-radius: 12px !important;
}

/* Upload area for specific image components */
#reference_image .upload-container,
#target_image_1 .upload-container,
#target_image_2 .upload-container,
#reference_image .drop-zone,
#target_image_1 .drop-zone,
#target_image_2 .drop-zone {
    background: linear-gradient(135deg, rgba(255, 255, 255, 0.98) 0%, rgba(240, 249, 255, 0.95) 100%) !important;
    border: 2px dashed rgba(14, 165, 233, 0.3) !important;
    border-radius: 12px !important;
}

/* Hover effects for specific image components */
#reference_image:hover,
#target_image_1:hover,
#target_image_2:hover {
    border-color: var(--primary-blue) !important;
    box-shadow: 0 4px 16px rgba(14, 165, 233, 0.15) !important;
    transform: translateY(-1px);
}

/* Group styling matching getting_started white cards */
.group, .gradio-group {
    border-radius: 8px;
    background: var(--bg-primary);
    border: 1px solid var(--neutral-200);
    box-shadow: 0 1px 3px rgba(0, 0, 0, 0.05);
}

/* Subtle page background with theme colors */
body, .gradio-container {
    background: linear-gradient(135deg, #f8fafc 0%, #f0f9ff 50%, #f8fafc 100%);
    min-height: 100vh;
}

/* Global glass container with subtle Apple-style gradient */
#global_glass_container {
    position: relative;
    border-radius: 20px;
    padding: 16px;
    margin: 12px auto;
    max-width: 1400px;
    background: linear-gradient(145deg,
        rgba(248, 250, 252, 0.98),
        rgba(241, 245, 249, 0.95));
    box-shadow:
        0 20px 40px rgba(15, 23, 42, 0.08),
        0 8px 24px rgba(15, 23, 42, 0.06),
        inset 0 1px 0 rgba(255, 255, 255, 0.9);
    border: 1px solid rgba(226, 232, 240, 0.7);
    transition: all 0.3s ease;
    overflow: visible !important; /* Allow dropdown to overflow */
}

/* Subtle gradient overlay for Apple effect */
#global_glass_container::after {
    content: '';
    position: absolute;
    top: 0;
    left: 0;
    right: 0;
    height: 250px;
    background: linear-gradient(135deg,
        rgba(14, 165, 233, 0.08) 0%,
        rgba(6, 182, 212, 0.06) 25%,
        rgba(16, 185, 129, 0.08) 50%,
        rgba(139, 92, 246, 0.06) 75%,
        rgba(14, 165, 233, 0.08) 100%);
    background-size: 300% 300%;
    animation: subtleGradientShift 15s ease-in-out infinite;
    pointer-events: none;
    z-index: 0;
}

@keyframes subtleGradientShift {
    0%, 100% {
        background-position: 0% 50%;
        opacity: 0.8;
    }
    50% {
        background-position: 100% 50%;
        opacity: 1;
    }
}

/* Ensure content is above the gradient overlay */
#global_glass_container > * {
    position: relative;
    z-index: 1;
}

/* Hover effect for global container - transform disabled to avoid dropdown reposition */
#global_glass_container:hover {
    /* transform: translateY(-2px); */
    box-shadow:
        0 25px 50px rgba(15, 23, 42, 0.08),
        0 12px 30px rgba(15, 23, 42, 0.06),
        inset 0 1px 0 rgba(255, 255, 255, 0.9);
    border-color: rgba(226, 232, 240, 0.8);
}

/* Subtle border highlight for global container */
#global_glass_container::before {
    content: "";
    position: absolute;
    inset: 0;
    border-radius: 20px;
    padding: 1px;
    background: linear-gradient(135deg,
        rgba(255, 255, 255, 0.8),
        rgba(226, 232, 240, 0.4),
        rgba(255, 255, 255, 0.6),
        rgba(226, 232, 240, 0.3)
    );
    -webkit-mask:
        linear-gradient(#fff 0 0) content-box,
        linear-gradient(#fff 0 0);
    -webkit-mask-composite: xor;
    mask-composite: exclude;
    pointer-events: none;
    z-index: 0;
}

/* Inner glassmorphism container with theme colors */
#glass_card {
    position: relative;
    border-radius: 16px;
    padding: 16px;
    background: linear-gradient(135deg, rgba(255, 255, 255, 0.6) 0%, rgba(240, 249, 255, 0.5) 100%);
    box-shadow:
        0 8px 24px rgba(14, 165, 233, 0.08),
        inset 0 1px 0 rgba(255, 255, 255, 0.7);
    border: 1px solid rgba(14, 165, 233, 0.2);
    margin-bottom: 12px;
    transition: all 0.3s ease;
    overflow: visible !important; /* Allow dropdown to overflow */
}

#glass_card:hover {
    background: linear-gradient(135deg, rgba(255, 255, 255, 0.7) 0%, rgba(240, 249, 255, 0.6) 100%);
    border-color: rgba(14, 165, 233, 0.3);
    box-shadow:
        0 12px 32px rgba(14, 165, 233, 0.12),
        inset 0 1px 0 rgba(255, 255, 255, 0.8);
}

/* Subtle inner border gradient for liquid glass feel */
#glass_card::before {
    content: "";
    position: absolute;
    inset: 0;
    border-radius: 16px;
    padding: 1px;
    background: linear-gradient(135deg,
        rgba(255, 255, 255, 0.6),
        rgba(226, 232, 240, 0.2));
    -webkit-mask:
        linear-gradient(#fff 0 0) content-box,
        linear-gradient(#fff 0 0);
    -webkit-mask-composite: xor;
    mask-composite: exclude;
    pointer-events: none;
}

/* Preserve the airy layout inside the cards */
#global_glass_container .gradio-column { gap: 12px; }
#glass_card .gradio-row { gap: 16px; }
#glass_card .gradio-column { gap: 12px; }
#glass_card .gradio-group { margin: 8px 0; }

/* Text selection matching getting_started colors */
::selection {
    background: var(--badge-blue-bg);
    color: var(--badge-blue-text);
}

/* Placeholder styling */
::placeholder {
    color: var(--text-muted);
    opacity: 0.8;
}

/* Improved error state styling */
.error {
    border-color: #ef4444 !important;
    box-shadow: 0 0 0 2px rgba(239, 68, 68, 0.1) !important;
}

/* Success state using getting_started green */
.success-state {
    border-color: var(--primary-green) !important;
    box-shadow: 0 0 0 2px rgba(16, 185, 129, 0.1) !important;
}

/* Label styling */
.gradio-label {
    color: var(--text-primary);
    font-weight: 600;
}

/* Markdown content styling */
.markdown-body {
    color: var(--text-secondary);
    line-height: 1.6;
}

.markdown-body h1, .markdown-body h2, .markdown-body h3 {
    color: var(--text-primary);
}

/* Step indicators styling */
.gradio-markdown h1, .gradio-markdown h2, .gradio-markdown h3,
.gradio-markdown p {
    margin: 0.25rem 0;
}

/* Enhanced step indicators with numbers */
.gradio-markdown:contains("1."), .gradio-markdown:contains("2."),
.gradio-markdown:contains("3."), .gradio-markdown:contains("4."),
.gradio-markdown:contains("5."), .gradio-markdown:contains("6."),
.gradio-markdown:contains("7.") {
    position: relative;
    padding-left: 2.5rem;
    color: var(--text-primary);
    font-weight: 600;
}

/* Specific button styling */
#undo_btnSEG, #dilate_btn, #erode_btn, #bounding_box_btn {
    background: linear-gradient(135deg, rgba(255, 255, 255, 0.98) 0%, rgba(240, 249, 255, 0.95) 100%);
    border: 1px solid rgba(14, 165, 233, 0.2);
    color: var(--text-secondary);
    font-weight: 500;
    padding: 8px 16px;
    border-radius: 6px;
    box-shadow: 0 2px 8px rgba(14, 165, 233, 0.1);
    transition: all 0.3s ease;
}

#undo_btnSEG:hover, #dilate_btn:hover, #erode_btn:hover, #bounding_box_btn:hover {
    background: var(--primary-blue);
    border-color: var(--primary-blue);
    color: white;
    transform: translateY(-1px);
    box-shadow: 0 4px 12px rgba(14, 165, 233, 0.3);
}

/* Submit button enhanced styling - unified with primary blue */
button[variant="primary"], .gradio-button.primary {
    background: var(--primary-blue);
    border-color: var(--primary-blue);
    color: white;
    font-weight: 600;
    font-size: 1rem;
    padding: 12px 24px;
    box-shadow: 0 4px 16px rgba(14, 165, 233, 0.25);
    transition: all 0.3s ease;
}

button[variant="primary"]:hover, .gradio-button.primary:hover {
    background: #0284c7;
    border-color: #0284c7;
    box-shadow: 0 6px 20px rgba(14, 165, 233, 0.4);
    transform: translateY(-2px);
}

/* Improved button states */
button:disabled {
    opacity: 0.5;
    cursor: not-allowed;
    transform: none !important;
    box-shadow: none !important;
}

button:disabled::after {
    display: none;
}

button.processing {
    background: var(--neutral-400) !important;
    border-color: var(--neutral-400) !important;
    cursor: wait;
    animation: processingPulse 2s ease-in-out infinite;
}

@keyframes processingPulse {
    0%, 100% { opacity: 0.8; }
    50% { opacity: 1; }
}

/* Responsive improvements */
@media (max-width: 768px) {
    .header-content {
        padding: 1.2rem 1.8rem;
        margin: 0 1rem;
    }

    .main-title {
        font-size: 2rem;
    }

    .title-icon {
        font-size: 1.8rem;
    }

    .subtitle {
        font-size: 0.9rem;
    }

    .header-badges {
        gap: 0.6rem;
    }

    .badge-link {
        padding: 0.4rem 0.8rem;
        font-size: 0.85rem;
    }

    .header-badge {
        padding: 0.4rem 0.8rem;
        font-size: 0.85rem;
    }

    /* Getting Started responsive */
    .getting-started-container {
        padding: 1rem;
        margin: 0 0.5rem;
    }

    .guide-title {
        font-size: 1.1rem;
    }

    .guide-subtitle {
        font-size: 0.85rem;
    }

    .step-card {
        padding: 0.8rem;
        margin-bottom: 1rem;
    }

    .step-header {
        font-size: 0.9rem;
    }

    .step-number {
        width: 22px;
        height: 22px;
        font-size: 0.7rem;
    }

    .step-list {
        font-size: 0.8rem;
        padding-left: 1rem;
    }

    .tips-card {
        padding: 0.6rem;
    }

    .tips-content {
        font-size: 0.75rem;
    }

    .final-message {
        padding: 0.6rem;
    }

    .final-text {
        font-size: 0.8rem;
    }

    button {
        min-height: 44px;
    }

    input, textarea, select {
        min-height: 44px;
    }

    /* Mobile optimization for subtle effects */
    #global_glass_container {
        padding: 16px;
        margin: 8px;
        border-radius: 16px;
    }

    #global_glass_container::after {
        height: 180px;
        animation-duration: 18s;
    }

    #glass_card {
        padding: 20px;
        margin: 10px;
        border-radius: 12px;
    }

    #glass_card .gradio-row { gap: 12px; }
    #glass_card .gradio-column { gap: 12px; }
}

/* Ensure gallery works properly in all screen sizes */
@media (min-width: 1200px) {
    #mask_gallery .gradio-gallery, #result_gallery .gradio-gallery {
        min-height: 300px !important;
        max-height: 80vh !important;
    }

    .responsive-gallery .grid-container {
        grid-template-columns: repeat(auto-fit, minmax(150px, 1fr)) !important;
    }
}

/* Fix for gallery duplicate content issue - ensure clean display */
#mask_gallery > div > div:nth-child(n+2),
#result_gallery > div > div:nth-child(n+2) {
    display: none !important;
}

#mask_gallery .gradio-gallery:nth-child(n+2),
#result_gallery .gradio-gallery:nth-child(n+2) {
    display: none !important;
}

"""
app/ui_components.py
ADDED
@@ -0,0 +1,354 @@
#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""
UI components construction for IC-Custom application.
"""
import gradio as gr
from constants import (
    ASPECT_RATIO_LABELS,
    DEFAULT_ASPECT_RATIO,
    DEFAULT_BRUSH_SIZE
)


def create_theme():
    """Create and configure the Gradio theme."""
    theme = gr.themes.Ocean()
    theme.set(
        checkbox_label_background_fill_selected="*button_primary_background_fill",
        checkbox_label_text_color_selected="*button_primary_text_color",
    )
    return theme


def create_css():
    """Create custom CSS for the application."""
    from stylesheets import get_css
    return get_css()


def create_header_section():
    """Create the header section with title and description."""
    from metainfo import head, getting_started

    with gr.Row():
        gr.HTML(head)

    with gr.Accordion(label="🚀 Getting Started:", open=True, elem_id="accordion"):
        with gr.Row(equal_height=True):
            gr.HTML(getting_started)


def create_customization_section():
    """Create the customization mode selection section."""
    with gr.Row():
        # Add a note to remind users to click Clear before starting
        md_custmization_mode = gr.Markdown(
            "1. Select a Customization Mode\n\n*Tip: Please click the Clear button first to reset all states before starting a new task.*"
        )
    with gr.Row():
        custmization_mode = gr.Radio(
            ["Position-aware", "Position-free"],
            value="Position-aware",
            scale=1,
            elem_id="customization_mode_radio",
            show_label=False,
            label="Customization Mode",
        )
    return custmization_mode, md_custmization_mode


def create_image_input_section():
    """Create image input section optimized for left column layout."""
    # Reference image section
    md_image_reference = gr.Markdown("2. Input reference image")
    with gr.Group():
        image_reference = gr.Image(
            label="Reference Image",
            type="pil",
            interactive=True,
            height=320,
            container=True,
            elem_id="reference_image"
        )

    # Input mask mode selection
    md_input_mask_mode = gr.Markdown("3. Select input mask mode")
    with gr.Group():
        input_mask_mode = gr.Radio(
            ["Precise mask", "User-drawn mask"],
            value="Precise mask",
            elem_id="input_mask_mode_radio",
            show_label=False,
            label="Input Mask Mode",
        )

    # Target image section
    md_target_image = gr.Markdown("4. Input target image & mask (Iterate clicking or brushing until the target is covered)")

    # Precise mask mode
    with gr.Group():
        image_target_1 = gr.Image(
            type="pil",
            label="Target Image (precise mask)",
            interactive=True,
            visible=True,
            height=500,
            container=True,
            elem_id="target_image_1"
        )
        with gr.Row():
            undo_target_seg_button = gr.Button(
                'Undo seg',
                elem_id="undo_btnSEG",
                visible=True,
                size="sm",
                scale=1
            )

    # User-drawn mask mode
    with gr.Group():
        image_target_2 = gr.ImageEditor(
            label="Target Image (user-drawn mask)",
            type="pil",
            brush=gr.Brush(colors=["#FFFFFF"], default_size=DEFAULT_BRUSH_SIZE, color_mode="fixed"),
            layers=False,
            interactive=True,
            sources=["upload", "clipboard"],
            placeholder="Please click here or the icon to upload the image.",
            visible=False,
            height=500,
            container=True,
            elem_id="target_image_2",
            fixed_canvas=True,
        )

    return (image_reference, input_mask_mode, image_target_1, image_target_2,
            undo_target_seg_button, md_image_reference, md_input_mask_mode, md_target_image)


def create_prompt_section():
    """Create the text prompt input section with improved layout."""
    md_prompt = gr.Markdown("5. Input text prompt (optional)")
    with gr.Group():
        prompt = gr.Textbox(
            placeholder="Please input the description for the target scene.",
            value="",
            lines=2,
            show_label=False,
            label="Text Prompt",
            container=True,
            elem_id="text_prompt"
        )

        with gr.Row():
            vlm_generate_btn = gr.Button(
                "🤖 VLM Auto-generate",
                scale=1,
                elem_id="vlm_generate_btn",
                variant="secondary"
            )
            vlm_polish_btn = gr.Button(
                "✨ VLM Auto-polish",
                scale=1,
                elem_id="vlm_polish_btn",
                variant="secondary"
            )

    return prompt, vlm_generate_btn, vlm_polish_btn, md_prompt


def create_advanced_options_section():
    """Create the advanced options section."""
    with gr.Accordion("Advanced Options", open=False, elem_id="accordion1"):
        with gr.Group():
            aspect_ratio = gr.Dropdown(
                label="Output aspect ratio",
                choices=ASPECT_RATIO_LABELS,
                value=DEFAULT_ASPECT_RATIO,
                interactive=True,
                allow_custom_value=False,
                filterable=False,
                elem_id="aspect_ratio_dropdown"
            )

        with gr.Group():
            seg_ref_mode = gr.Radio(
|
177 |
+
label="Segmentation mode",
|
178 |
+
choices=["Full Ref", "Masked Ref"],
|
179 |
+
value="Full Ref",
|
180 |
+
elem_id="seg_ref_mode_radio"
|
181 |
+
)
|
182 |
+
move_to_center = gr.Checkbox(label="Move object to center", value=False, elem_id="move_to_center_checkbox")
|
183 |
+
|
184 |
+
with gr.Group():
|
185 |
+
with gr.Row():
|
186 |
+
use_background_preservation = gr.Checkbox(label="Use background preservation", value=False, elem_id="use_bg_preservation_checkbox")
|
187 |
+
background_blend_threshold = gr.Slider(
|
188 |
+
label="Background blend threshold",
|
189 |
+
minimum=0,
|
190 |
+
maximum=1,
|
191 |
+
step=0.1,
|
192 |
+
value=0.5
|
193 |
+
)
|
194 |
+
|
195 |
+
with gr.Group():
|
196 |
+
with gr.Row():
|
197 |
+
seed = gr.Slider(
|
198 |
+
label="Seed (-1 for random): ",
|
199 |
+
minimum=-1,
|
200 |
+
maximum=2147483647,
|
201 |
+
step=1,
|
202 |
+
value=-1,
|
203 |
+
scale=4
|
204 |
+
)
|
205 |
+
|
206 |
+
num_images_per_prompt = gr.Slider(
|
207 |
+
label="Num samples",
|
208 |
+
minimum=1,
|
209 |
+
maximum=4,
|
210 |
+
step=1,
|
211 |
+
value=1,
|
212 |
+
scale=1
|
213 |
+
)
|
214 |
+
|
215 |
+
with gr.Group():
|
216 |
+
with gr.Row():
|
217 |
+
guidance = gr.Slider(
|
218 |
+
label="Guidance scale",
|
219 |
+
minimum=10,
|
220 |
+
maximum=65,
|
221 |
+
step=1,
|
222 |
+
value=40,
|
223 |
+
)
|
224 |
+
num_steps = gr.Slider(
|
225 |
+
label="Number of inference steps",
|
226 |
+
minimum=1,
|
227 |
+
maximum=60,
|
228 |
+
step=1,
|
229 |
+
value=32,
|
230 |
+
)
|
231 |
+
with gr.Row():
|
232 |
+
true_gs = gr.Slider(
|
233 |
+
label="True GS",
|
234 |
+
minimum=1,
|
235 |
+
maximum=10,
|
236 |
+
step=1,
|
237 |
+
value=3,
|
238 |
+
)
|
239 |
+
|
240 |
+
return (aspect_ratio, seg_ref_mode, move_to_center, use_background_preservation,
|
241 |
+
background_blend_threshold, seed, num_images_per_prompt, guidance, num_steps, true_gs)
|
242 |
+
|
243 |
+
|
244 |
+
def create_mask_operation_section():
|
245 |
+
"""Create mask operation section optimized for right column (outputs)."""
|
246 |
+
md_mask_operation = gr.Markdown("6. View or modify the target mask")
|
247 |
+
|
248 |
+
with gr.Group():
|
249 |
+
# Mask gallery with responsive layout
|
250 |
+
mask_gallery = gr.Gallery(
|
251 |
+
label='Mask Preview',
|
252 |
+
show_label=False,
|
253 |
+
interactive=False,
|
254 |
+
columns=2,
|
255 |
+
rows=1,
|
256 |
+
height="auto",
|
257 |
+
object_fit="contain",
|
258 |
+
preview=True,
|
259 |
+
allow_preview=True,
|
260 |
+
selected_index=0,
|
261 |
+
elem_id="mask_gallery",
|
262 |
+
elem_classes=["custom-gallery", "responsive-gallery"],
|
263 |
+
container=True,
|
264 |
+
show_fullscreen_button=False
|
265 |
+
)
|
266 |
+
|
267 |
+
# Mask operation buttons - horizontal layout
|
268 |
+
with gr.Row():
|
269 |
+
dilate_button = gr.Button(
|
270 |
+
'🔍 Dilate',
|
271 |
+
elem_id="dilate_btn",
|
272 |
+
variant="secondary",
|
273 |
+
size="sm",
|
274 |
+
scale=1
|
275 |
+
)
|
276 |
+
erode_button = gr.Button(
|
277 |
+
'🔽 Erode',
|
278 |
+
elem_id="erode_btn",
|
279 |
+
variant="secondary",
|
280 |
+
size="sm",
|
281 |
+
scale=1
|
282 |
+
)
|
283 |
+
bounding_box_button = gr.Button(
|
284 |
+
'📦 Bounding box',
|
285 |
+
elem_id="bounding_box_btn",
|
286 |
+
variant="secondary",
|
287 |
+
size="sm",
|
288 |
+
scale=1
|
289 |
+
)
|
290 |
+
|
291 |
+
return mask_gallery, dilate_button, erode_button, bounding_box_button, md_mask_operation
|
292 |
+
|
293 |
+
|
294 |
+
def create_output_section():
|
295 |
+
"""Create the output section optimized for right column."""
|
296 |
+
md_submit = gr.Markdown("7. Submit and view the output")
|
297 |
+
|
298 |
+
# Generation controls at top for better workflow
|
299 |
+
with gr.Group():
|
300 |
+
with gr.Row():
|
301 |
+
submit_button = gr.Button(
|
302 |
+
"💫 Generate",
|
303 |
+
variant="primary",
|
304 |
+
scale=3,
|
305 |
+
size="lg"
|
306 |
+
)
|
307 |
+
clear_btn = gr.ClearButton(
|
308 |
+
scale=1,
|
309 |
+
variant="secondary",
|
310 |
+
value="🗑️ Clear"
|
311 |
+
)
|
312 |
+
|
313 |
+
# Results gallery with responsive layout
|
314 |
+
with gr.Group():
|
315 |
+
result_gallery = gr.Gallery(
|
316 |
+
label='Generated Results',
|
317 |
+
show_label=False,
|
318 |
+
interactive=False,
|
319 |
+
columns=1,
|
320 |
+
rows=1,
|
321 |
+
height="auto",
|
322 |
+
object_fit="contain",
|
323 |
+
preview=True,
|
324 |
+
allow_preview=True,
|
325 |
+
selected_index=0,
|
326 |
+
elem_id="result_gallery",
|
327 |
+
elem_classes=["custom-gallery", "responsive-gallery"],
|
328 |
+
container=True,
|
329 |
+
show_fullscreen_button=False
|
330 |
+
)
|
331 |
+
|
332 |
+
return result_gallery, submit_button, clear_btn, md_submit
|
333 |
+
|
334 |
+
|
335 |
+
def create_examples_section(examples_list, inputs, outputs, fn):
|
336 |
+
"""Create the examples section with required arguments."""
|
337 |
+
examples = gr.Examples(
|
338 |
+
examples=examples_list,
|
339 |
+
inputs=inputs,
|
340 |
+
outputs=outputs,
|
341 |
+
fn=fn,
|
342 |
+
cache_examples=False,
|
343 |
+
examples_per_page=10,
|
344 |
+
run_on_click=True,
|
345 |
+
)
|
346 |
+
return examples
|
347 |
+
|
348 |
+
|
349 |
+
def create_citation_section():
|
350 |
+
"""Create the citation section."""
|
351 |
+
from metainfo import citation
|
352 |
+
|
353 |
+
with gr.Row():
|
354 |
+
gr.Markdown(citation)
|
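
Note: the builders above only lay components out; the event wiring lives in app.py / event_handlers.py, which appear elsewhere in this commit. The following is a minimal sketch of how these section builders might be composed into a Blocks layout. The column arrangement and the launch call are assumptions for illustration, not the app's actual entry point.

# Sketch only: hypothetical assembly of the ui_components builders into a Blocks app.
import gradio as gr
from ui_components import (
    create_theme, create_css, create_header_section,
    create_customization_section, create_image_input_section,
    create_prompt_section, create_output_section,
)

with gr.Blocks(theme=create_theme(), css=create_css()) as demo:
    create_header_section()
    with gr.Row():
        with gr.Column():
            # Left column: inputs
            custmization_mode, _ = create_customization_section()
            (image_reference, input_mask_mode, image_target_1, image_target_2,
             undo_target_seg_button, *_) = create_image_input_section()
            prompt, vlm_generate_btn, vlm_polish_btn, _ = create_prompt_section()
        with gr.Column():
            # Right column: outputs
            result_gallery, submit_button, clear_btn, _ = create_output_section()
    # Event bindings (submit_button.click(...), clear_btn.add(...), etc.)
    # would be registered here by the real application code.

demo.launch()
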
app/utils.py
ADDED
@@ -0,0 +1,429 @@
import os
import sys
import base64
from io import BytesIO
from typing import Optional

from PIL import Image

import numpy as np
import torch

from segment_anything import SamPredictor, sam_model_registry
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
from huggingface_hub import hf_hub_download

sys.path.append(os.getcwd())
import BEN2


## Ordinary functions
def resize(image: Image.Image,
           target_width: int,
           target_height: int,
           interpolate: Image.Resampling = Image.Resampling.LANCZOS,
           return_type: str = "pil") -> Image.Image | np.ndarray:
    """
    Resize an image to the given width and height.

    Args:
        image (Image.Image): Input PIL image to be resized.
        target_width (int): Target width of the output image.
        target_height (int): Target height of the output image.
        interpolate (Image.Resampling): The interpolation method.
        return_type (str): The type of the output image ("pil" or "np").

    Returns:
        Image.Image | np.ndarray: Resized image.
    """
    resized_image = image.resize((target_width, target_height), interpolate)
    if return_type == "pil":
        return resized_image
    elif return_type == "np":
        return np.array(resized_image)
    else:
        raise ValueError(f"Invalid return type: {return_type}")


def resize_long_edge(
    image: Image.Image,
    long_edge_size: int,
    interpolate: Image.Resampling = Image.Resampling.LANCZOS,
    return_type: str = "pil"
) -> np.ndarray | Image.Image:
    """
    Resize the long edge of the image to long_edge_size, keeping the aspect ratio.

    Args:
        image (Image.Image): The image to resize.
        long_edge_size (int): The size of the long edge.
        interpolate (Image.Resampling): The interpolation method.
        return_type (str): The type of the output image ("pil" or "np").

    Returns:
        np.ndarray | Image.Image: The resized image.
    """
    w, h = image.size
    scale_ratio = long_edge_size / max(h, w)
    output_w = int(w * scale_ratio)
    output_h = int(h * scale_ratio)
    image = resize(image, target_width=output_w, target_height=output_h, interpolate=interpolate, return_type=return_type)
    return image


def ensure_divisible_by_value(
    image: Image.Image | np.ndarray,
    value: int = 8,
    interpolate: Image.Resampling = Image.Resampling.NEAREST,
    return_type: str = "np"
) -> np.ndarray | Image.Image:
    """
    Ensure the image dimensions are divisible by value.

    Args:
        image (Image.Image | np.ndarray): The image to resize.
        value (int): The value the output dimensions must be divisible by.
        interpolate (Image.Resampling): The interpolation method.
        return_type (str): The type of the output image ("pil" or "np").

    Returns:
        np.ndarray | Image.Image: The resized image.
    """
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)

    w, h = image.size

    w = (w // value) * value
    h = (h // value) * value
    image = resize(image, w, h, interpolate=interpolate, return_type=return_type)
    return image


def resize_paired_image(
    image_reference: np.ndarray,
    image_target: np.ndarray,
    mask_target: np.ndarray,
    force_resize_long_edge: bool = False,
    return_type: str = "np"
) -> tuple[np.ndarray | Image.Image, np.ndarray | Image.Image, np.ndarray | Image.Image]:
    """Resize the reference image to match the target image height, optionally resizing long edges to 1024 first."""
    if isinstance(image_reference, np.ndarray):
        image_reference = Image.fromarray(image_reference)
    if isinstance(image_target, np.ndarray):
        image_target = Image.fromarray(image_target)
    if isinstance(mask_target, np.ndarray):
        mask_target = Image.fromarray(mask_target)

    if force_resize_long_edge:
        image_reference = resize_long_edge(image_reference, 1024, interpolate=Image.Resampling.LANCZOS, return_type=return_type)
        image_target = resize_long_edge(image_target, 1024, interpolate=Image.Resampling.LANCZOS, return_type=return_type)
        mask_target = resize_long_edge(mask_target, 1024, interpolate=Image.Resampling.NEAREST, return_type=return_type)

    if isinstance(image_reference, Image.Image):
        ref_width, ref_height = image_reference.size
        target_width, target_height = image_target.size
    else:
        # np.ndarray shape is (height, width, channels)
        ref_height, ref_width = image_reference.shape[:2]
        target_height, target_width = image_target.shape[:2]

    # Resize the reference image to the same height as the target image, keeping its aspect ratio
    if ref_height != target_height:
        scale_ratio = target_height / ref_height
        if isinstance(image_reference, np.ndarray):
            image_reference = Image.fromarray(image_reference)
        image_reference = resize(image_reference, int(ref_width * scale_ratio), target_height, interpolate=Image.Resampling.LANCZOS, return_type=return_type)

    if return_type == "pil":
        image_reference = Image.fromarray(image_reference) if isinstance(image_reference, np.ndarray) else image_reference
        image_target = Image.fromarray(image_target) if isinstance(image_target, np.ndarray) else image_target
        mask_target = Image.fromarray(mask_target) if isinstance(mask_target, np.ndarray) else mask_target
        return image_reference, image_target, mask_target
    else:
        image_reference = np.array(image_reference) if isinstance(image_reference, Image.Image) else image_reference
        image_target = np.array(image_target) if isinstance(image_target, Image.Image) else image_target
        mask_target = np.array(mask_target) if isinstance(mask_target, Image.Image) else mask_target
        return image_reference, image_target, mask_target


def prepare_input_images(
    img_ref: np.ndarray,
    custmization_mode: str,
    img_target: Optional[np.ndarray] = None,
    mask_target: Optional[np.ndarray] = None,
    width: Optional[int] = None,
    height: Optional[int] = None,
    force_resize_long_edge: bool = False,
    return_type: str = "np"
) -> tuple[np.ndarray | Image.Image, np.ndarray | Image.Image, np.ndarray | Image.Image]:
    """Prepare the reference image, target image, and target mask for the model."""
    if custmization_mode.lower() == "position-free":
        # Position-free mode has no real target: use a white target and an empty mask
        img_target = np.ones_like(img_ref) * 255
        mask_target = np.zeros_like(img_ref)

    if isinstance(width, int) and isinstance(height, int):
        img_ref = resize(Image.fromarray(img_ref), width, height, interpolate=Image.Resampling.LANCZOS, return_type=return_type)
        img_target = resize(Image.fromarray(img_target), width, height, interpolate=Image.Resampling.LANCZOS, return_type=return_type)
        mask_target = resize(Image.fromarray(mask_target), width, height, interpolate=Image.Resampling.NEAREST, return_type=return_type)
    else:
        img_ref, img_target, mask_target = resize_paired_image(img_ref, img_target, mask_target, force_resize_long_edge, return_type=return_type)

    img_ref = ensure_divisible_by_value(img_ref, value=16, interpolate=Image.Resampling.LANCZOS, return_type=return_type)
    img_target = ensure_divisible_by_value(img_target, value=16, interpolate=Image.Resampling.LANCZOS, return_type=return_type)
    mask_target = ensure_divisible_by_value(mask_target, value=16, interpolate=Image.Resampling.NEAREST, return_type=return_type)

    return img_ref, img_target, mask_target


def get_mask_type_ids(custmization_mode: str, input_mask_mode: str) -> torch.Tensor:
    """Map the customization mode and input mask mode to a mask-type id tensor."""
    if custmization_mode.lower() == "position-free":
        return torch.tensor([0])
    elif custmization_mode.lower() == "position-aware":
        if "precise" in input_mask_mode.lower():
            return torch.tensor([1])
        else:
            return torch.tensor([2])
    else:
        raise ValueError(f"Invalid custmization mode: {custmization_mode}")


def scale_image(image_np, is_mask: bool = False):
    """
    Scale the image to the range [-1, 1] if it is not a mask, otherwise scale it to [0, 1].

    Args:
        image_np (np.ndarray): Input image.
        is_mask (bool): Whether the image is a mask.
    Returns:
        np.ndarray: Scaled image.
    """
    if is_mask:
        image_np = image_np / 255.0
    else:
        image_np = image_np / 255.0
        image_np = image_np * 2 - 1
    return image_np


def get_sam_predictor(sam_ckpt_path, device):
    """
    Get the SAM predictor.
    Args:
        sam_ckpt_path (str): The path to the SAM checkpoint.
        device (str): The device to load the model on.
    Returns:
        SamPredictor: The SAM predictor.
    """
    if not os.path.exists(sam_ckpt_path):
        sam_ckpt_path = hf_hub_download(repo_id="HCMUE-Research/SAM-vit-h", filename="sam_vit_h_4b8939.pth")

    sam = sam_model_registry['vit_h'](checkpoint=sam_ckpt_path).to(device)
    sam.eval()
    predictor = SamPredictor(sam)

    return predictor


def image_to_base64(img):
    """
    Convert an image to a base64 string.
    Args:
        img (PIL.Image.Image): The image to convert.
    Returns:
        str: The base64 string.
    """
    buffered = BytesIO()
    img.save(buffered, format="PNG")
    img_bytes = buffered.getvalue()
    return base64.b64encode(img_bytes).decode('utf-8')


def get_vlm(vlm_ckpt_path, device, torch_dtype):
    """
    Get the VLM pipeline.
    Args:
        vlm_ckpt_path (str): The path to the VLM checkpoint.
        device (str): The device to load the model on.
        torch_dtype (torch.dtype): The data type of the model.
    Returns:
        tuple: The processor and model.
    """
    if not os.path.exists(vlm_ckpt_path):
        vlm_ckpt_path = "Qwen/Qwen2.5-VL-7B-Instruct"

    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        vlm_ckpt_path, torch_dtype=torch_dtype).to(device)
    processor = AutoProcessor.from_pretrained(vlm_ckpt_path)

    return processor, model


def construct_vlm_gen_prompt(image_target, image_reference, target_mask, custmization_mode):
    """
    Construct the VLM generation prompt.
    Args:
        image_target (np.ndarray): The target image.
        image_reference (np.ndarray): The reference image.
        target_mask (np.ndarray): The target mask.
        custmization_mode (str): The customization mode.
    Returns:
        list: The messages.
    """
    if custmization_mode.lower() == "position-free":
        image_reference_pil = Image.fromarray(image_reference.astype(np.uint8))
        image_reference_base_64 = image_to_base64(image_reference_pil)
        messages = [
            {
                "role": "system",
                "content": "I will input a reference image. Please identify the main subject/object in this image and generate a new description by placing this subject in a completely different scene or context. For example, if the reference image shows a rabbit sitting in a garden surrounded by green leaves and roses, you could generate a description like 'The rabbit is sitting on a rocky cliff overlooking a serene ocean, with the sun setting behind it, casting a warm glow over the scene'. Please directly output the new description without explaining your thought process. The description should not exceed 256 tokens."
            },
            {
                "role": "user",
                "content": [
                    {
                        "type": "image",
                        "image": f"data:image;base64,{image_reference_base_64}"
                    },
                ],
            }
        ]
        return messages
    else:
        image_reference_pil = Image.fromarray(image_reference.astype(np.uint8))
        image_reference_base_64 = image_to_base64(image_reference_pil)

        target_mask_binary = target_mask > 127.5
        masked_image_target = image_target * target_mask_binary
        masked_image_target_pil = Image.fromarray(masked_image_target.astype(np.uint8))
        masked_image_target_base_64 = image_to_base64(masked_image_target_pil)

        messages = [
            {
                "role": "system",
                "content": "I will input a reference image and a target image with its main subject area masked (in black). Please directly describe the scene where the main subject/object from the reference image is placed into the masked area of the target image. Focus on describing the final combined scene, making sure to clearly describe both the object from the reference image and the background/environment from the target image. For example, if the reference shows a white cat with orange stripes on a beach and the target shows a masked area in a garden with blooming roses and tulips, directly describe 'A white cat with orange stripes sits elegantly among the vibrant red roses and yellow tulips in the lush garden, surrounded by green foliage.' The description should not exceed 256 tokens."
            },
            {
                "role": "user",
                "content": [
                    {
                        "type": "image",
                        "image": f"data:image;base64,{image_reference_base_64}"
                    },
                    {
                        "type": "image",
                        "image": f"data:image;base64,{masked_image_target_base_64}"
                    }
                ],
            }
        ]
        return messages


def construct_vlm_polish_prompt(prompt):
    """
    Construct the VLM polish prompt.
    Args:
        prompt (str): The prompt to polish.
    Returns:
        list: The messages.
    """
    messages = [
        {
            "role": "system",
            "content": "You are a helpful assistant that can polish the text prompt to make it more specific, detailed, and complete. Please directly output the polished prompt without explaining your thought process. The prompt should not exceed 256 tokens."
        },
        {
            "role": "user",
            "content": prompt
        }
    ]
    return messages


def run_vlm(vlm_processor, vlm_model, messages, device):
    """
    Run the VLM.
    Args:
        vlm_processor: The VLM processor.
        vlm_model (torch.nn.Module): The VLM model.
        messages (list): The messages.
        device (str): The device to run the model on.
    Returns:
        str: The output text.
    """
    text = vlm_processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True)

    image_inputs, video_inputs = process_vision_info(messages)
    inputs = vlm_processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
    inputs = inputs.to(device)
    # Inference
    generated_ids = vlm_model.generate(**inputs, do_sample=True, num_beams=4, temperature=1.5, max_new_tokens=128)
    generated_ids_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    output_text = vlm_processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )[0]
    return output_text


def get_ben2_model(ben2_model_path, device):
    """
    Get the BEN2 model.
    Args:
        ben2_model_path (str): The path to the BEN2 model.
        device (str): The device to load the model on.
    Returns:
        BEN2.BEN_Base: The BEN2 model.
    """
    if not os.path.exists(ben2_model_path):
        ben2_model_path = hf_hub_download(repo_id="PramaLLC/BEN2", filename="BEN2_Base.pth")

    ben2_model = BEN2.BEN_Base().to(device)
    ben2_model.loadcheckpoints(model_path=ben2_model_path)
    return ben2_model


def make_dict_img_mask(img_path, mask_path):
    """
    Make a dictionary of the image and mask for gr.ImageEditor.
    Kept for interface compatibility; not used in the gradio app.
    Args:
        img_path (str): The path to the image.
        mask_path (str): The path to the mask.
    Returns:
        dict: The dictionary of the image and mask.
    """
    from PIL import ImageOps
    background = Image.open(img_path).convert("RGBA")
    layers = [
        Image.merge("RGBA", (
            Image.new("L", Image.open(mask_path).size, 255),    # R channel
            Image.new("L", Image.open(mask_path).size, 255),    # G channel
            Image.new("L", Image.open(mask_path).size, 255),    # B channel
            ImageOps.invert(Image.open(mask_path).convert("L"))  # Inverted alpha channel
        ))
    ]
    # Combine layers with the background by replacing the alpha channel
    background = np.array(background.convert("RGB"))
    _, _, _, layer_alpha = layers[0].split()
    layer_alpha = np.array(layer_alpha)[:, :, np.newaxis]
    composite = background * (1 - (layer_alpha > 0)) + np.ones_like(background) * (layer_alpha > 0) * 255

    composite = Image.fromarray(composite.astype("uint8")).convert("RGBA")
    return {
        'background': background,
        'layers': layers,
        'composite': composite
    }
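
The preprocessing helpers above are meant to chain: pair-resize the reference/target, snap both to multiples of 16, then scale into the model's value range. A rough usage sketch follows; the module name `utils`, the stand-in array shapes, and the specific argument values are assumptions for illustration only.

# Sketch only: hypothetical use of the preprocessing helpers above.
import numpy as np
from utils import prepare_input_images, get_mask_type_ids, scale_image

img_ref = np.zeros((768, 512, 3), dtype=np.uint8)      # stand-in reference image
img_target = np.zeros((512, 512, 3), dtype=np.uint8)   # stand-in target image
mask_target = np.zeros((512, 512, 3), dtype=np.uint8)  # stand-in target mask

# Resize the pair to a common height, then snap both to multiples of 16.
img_ref, img_target, mask_target = prepare_input_images(
    img_ref, "Position-aware", img_target, mask_target,
    force_resize_long_edge=True, return_type="np",
)

mask_type_ids = get_mask_type_ids("Position-aware", "Precise mask")      # tensor([1])
model_image = scale_image(img_target.astype(np.float32))                 # scaled to [-1, 1]
model_mask = scale_image(mask_target.astype(np.float32), is_mask=True)   # scaled to [0, 1]
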
assets/gradio/pos_aware/001/hypher_params.txt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:7609eddfca9279636bdbbcaa25fc0c04b52816fa563caa220668a30526004e81
size 169
assets/gradio/pos_aware/001/img_gen.png
ADDED
(binary image, stored with Git LFS)
assets/gradio/pos_aware/001/img_ref.png
ADDED
(binary image, stored with Git LFS)
assets/gradio/pos_aware/001/img_target.png
ADDED
(binary image, stored with Git LFS)
assets/gradio/pos_aware/001/mask_target.png
ADDED
(binary image, stored with Git LFS)
assets/gradio/pos_aware/002/hypher_params.txt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:3e344b14750ee1f3ab780df3ce9871684f40cca81820ea0dcbc4896dd4b30a42
size 170
assets/gradio/pos_aware/002/img_gen.png
ADDED
(binary image, stored with Git LFS)
assets/gradio/pos_aware/002/img_ref.png
ADDED
(binary image, stored with Git LFS)
assets/gradio/pos_aware/002/img_target.png
ADDED
(binary image, stored with Git LFS)
assets/gradio/pos_aware/002/mask_target.png
ADDED
(binary image, stored with Git LFS)
assets/gradio/pos_aware/003/hypher_params.txt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:76e365ad0ac4b41937a6ff5d2817e73f3e662240b94834e844b59687c45c01f0
size 311
assets/gradio/pos_aware/003/img_gen.png
ADDED
(binary image, stored with Git LFS)
assets/gradio/pos_aware/003/img_ref.png
ADDED
(binary image, stored with Git LFS)
assets/gradio/pos_aware/003/img_target.png
ADDED
(binary image, stored with Git LFS)
assets/gradio/pos_aware/003/mask_target.png
ADDED
(binary image, stored with Git LFS)
assets/gradio/pos_aware/004/hypher_params.txt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:6c8a478290f495492d4361f0c7b429f1bf216d73eb89cb460acb210d7bb53901
size 174
assets/gradio/pos_aware/004/img_gen.png
ADDED
(binary image, stored with Git LFS)
assets/gradio/pos_aware/004/img_ref.png
ADDED
(binary image, stored with Git LFS)
assets/gradio/pos_aware/004/img_target.png
ADDED
(binary image, stored with Git LFS)
assets/gradio/pos_aware/004/mask_target.png
ADDED
(binary image, stored with Git LFS)
assets/gradio/pos_aware/005/hypher_params.txt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c0c8d875d0af12f57b70c596f70dfda334199990a819e96b67c8968538c7cdd6
size 173
assets/gradio/pos_aware/005/img_gen.png
ADDED
(binary image, stored with Git LFS)
assets/gradio/pos_aware/005/img_ref.png
ADDED
(binary image, stored with Git LFS)
assets/gradio/pos_aware/005/img_target.png
ADDED
(binary image, stored with Git LFS)
assets/gradio/pos_aware/005/mask_target.png
ADDED
(binary image, stored with Git LFS)
assets/gradio/pos_free/001/hyper_params.json
ADDED
@@ -0,0 +1 @@
{"prompt": "TThe charming, soft plush toy is joyfully wandering through a lush, dense jungle, surrounded by vibrant green foliage and towering trees.", "custmization_mode": "Position-free", "input_mask_mode": "Precise mask", "seg_ref_mode": "Full Ref", "seed": 2126677963, "guidance": 40, "num_steps": 20, "num_images_per_prompt": 1, "use_background_preservation": false, "background_blend_threshold": 0.5, "true_gs": 3}
assets/gradio/pos_free/001/img_gen.png
ADDED
(binary image, stored with Git LFS)
assets/gradio/pos_free/001/img_ref.png
ADDED
(binary image, stored with Git LFS)
assets/gradio/pos_free/001/img_target.png
ADDED
(binary image, stored with Git LFS)
assets/gradio/pos_free/001/mask_target.png
ADDED
(binary image, stored with Git LFS)
assets/gradio/pos_free/002/hyper_params.json
ADDED
@@ -0,0 +1 @@
{"prompt": "A bright yellow alarm clock sits on a wooden desk next to a stack of books in a cozy, sunlit room.", "custmization_mode": "Position-free", "input_mask_mode": "Precise mask", "seg_ref_mode": "Full Ref", "seed": 2126677963, "guidance": 40, "num_steps": 20, "num_images_per_prompt": 1, "use_background_preservation": false, "background_blend_threshold": 0.5, "true_gs": 3}
assets/gradio/pos_free/002/img_gen.png
ADDED
(binary image, stored with Git LFS)
assets/gradio/pos_free/002/img_ref.png
ADDED
(binary image, stored with Git LFS)
assets/gradio/pos_free/002/img_target.png
ADDED
(binary image, stored with Git LFS)