Yaowei222 committed · Commit 12edc27 · 0 parents

Initial commit

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .gitattributes +91 -0
  2. .gitignore +177 -0
  3. README.md +14 -0
  4. app.py +414 -0
  5. app/APP.md +103 -0
  6. app/BEN2.py +1394 -0
  7. app/aspect_ratio_template.py +88 -0
  8. app/business_logic.py +556 -0
  9. app/config.py +72 -0
  10. app/constants.py +35 -0
  11. app/event_handlers.py +155 -0
  12. app/examples.py +210 -0
  13. app/metainfo.py +131 -0
  14. app/stylesheets.py +1679 -0
  15. app/ui_components.py +354 -0
  16. app/utils.py +429 -0
  17. assets/gradio/pos_aware/001/hypher_params.txt +3 -0
  18. assets/gradio/pos_aware/001/img_gen.png +3 -0
  19. assets/gradio/pos_aware/001/img_ref.png +3 -0
  20. assets/gradio/pos_aware/001/img_target.png +3 -0
  21. assets/gradio/pos_aware/001/mask_target.png +3 -0
  22. assets/gradio/pos_aware/002/hypher_params.txt +3 -0
  23. assets/gradio/pos_aware/002/img_gen.png +3 -0
  24. assets/gradio/pos_aware/002/img_ref.png +3 -0
  25. assets/gradio/pos_aware/002/img_target.png +3 -0
  26. assets/gradio/pos_aware/002/mask_target.png +3 -0
  27. assets/gradio/pos_aware/003/hypher_params.txt +3 -0
  28. assets/gradio/pos_aware/003/img_gen.png +3 -0
  29. assets/gradio/pos_aware/003/img_ref.png +3 -0
  30. assets/gradio/pos_aware/003/img_target.png +3 -0
  31. assets/gradio/pos_aware/003/mask_target.png +3 -0
  32. assets/gradio/pos_aware/004/hypher_params.txt +3 -0
  33. assets/gradio/pos_aware/004/img_gen.png +3 -0
  34. assets/gradio/pos_aware/004/img_ref.png +3 -0
  35. assets/gradio/pos_aware/004/img_target.png +3 -0
  36. assets/gradio/pos_aware/004/mask_target.png +3 -0
  37. assets/gradio/pos_aware/005/hypher_params.txt +3 -0
  38. assets/gradio/pos_aware/005/img_gen.png +3 -0
  39. assets/gradio/pos_aware/005/img_ref.png +3 -0
  40. assets/gradio/pos_aware/005/img_target.png +3 -0
  41. assets/gradio/pos_aware/005/mask_target.png +3 -0
  42. assets/gradio/pos_free/001/hyper_params.json +1 -0
  43. assets/gradio/pos_free/001/img_gen.png +3 -0
  44. assets/gradio/pos_free/001/img_ref.png +3 -0
  45. assets/gradio/pos_free/001/img_target.png +3 -0
  46. assets/gradio/pos_free/001/mask_target.png +3 -0
  47. assets/gradio/pos_free/002/hyper_params.json +1 -0
  48. assets/gradio/pos_free/002/img_gen.png +3 -0
  49. assets/gradio/pos_free/002/img_ref.png +3 -0
  50. assets/gradio/pos_free/002/img_target.png +3 -0
.gitattributes ADDED
@@ -0,0 +1,91 @@
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ assets/gradio/pos_aware/005 filter=lfs diff=lfs merge=lfs -text
37
+ assets/gradio/pos_free filter=lfs diff=lfs merge=lfs -text
38
+ assets/gradio/pos_free/001/img_gen.png filter=lfs diff=lfs merge=lfs -text
39
+ assets/gradio/pos_free/001/img_ref.png filter=lfs diff=lfs merge=lfs -text
40
+ assets/gradio/pos_free/001/mask_target.png filter=lfs diff=lfs merge=lfs -text
41
+ assets/gradio/pos_free/003/hypher_params.txt filter=lfs diff=lfs merge=lfs -text
42
+ assets/gradio/pos_aware/002/hypher_params.txt filter=lfs diff=lfs merge=lfs -text
43
+ assets/gradio/pos_aware/003/hypher_params.txt filter=lfs diff=lfs merge=lfs -text
44
+ assets/gradio/pos_aware/004/mask_target.png filter=lfs diff=lfs merge=lfs -text
45
+ assets/gradio/pos_free/003/img_ref.png filter=lfs diff=lfs merge=lfs -text
46
+ assets/gradio/pos_free/004 filter=lfs diff=lfs merge=lfs -text
47
+ assets/gradio/pos_free/004/img_ref.png filter=lfs diff=lfs merge=lfs -text
48
+ assets/gradio/pos_aware/001/img_gen.png filter=lfs diff=lfs merge=lfs -text
49
+ assets/gradio/pos_aware/001/img_ref.png filter=lfs diff=lfs merge=lfs -text
50
+ assets/gradio/pos_aware/001/img_target.png filter=lfs diff=lfs merge=lfs -text
51
+ assets/gradio/pos_aware/002/img_ref.png filter=lfs diff=lfs merge=lfs -text
52
+ assets/gradio/pos_aware/003/img_gen.png filter=lfs diff=lfs merge=lfs -text
53
+ assets/gradio/pos_aware/003/img_ref.png filter=lfs diff=lfs merge=lfs -text
54
+ assets/gradio/pos_aware/004/img_gen.png filter=lfs diff=lfs merge=lfs -text
55
+ assets/gradio/pos_aware/004/img_ref.png filter=lfs diff=lfs merge=lfs -text
56
+ assets/gradio/pos_aware/004 filter=lfs diff=lfs merge=lfs -text
57
+ assets/gradio/pos_aware/004/hypher_params.txt filter=lfs diff=lfs merge=lfs -text
58
+ assets/gradio/pos_aware/005/img_gen.png filter=lfs diff=lfs merge=lfs -text
59
+ assets/gradio/pos_free/001/img_target.png filter=lfs diff=lfs merge=lfs -text
60
+ assets/gradio/pos_free/002 filter=lfs diff=lfs merge=lfs -text
61
+ assets/gradio/pos_free/002/img_ref.png filter=lfs diff=lfs merge=lfs -text
62
+ assets/gradio/pos_free/003/img_target.png filter=lfs diff=lfs merge=lfs -text
63
+ assets/gradio/pos_free/003/mask_target.png filter=lfs diff=lfs merge=lfs -text
64
+ assets/gradio/pos_aware filter=lfs diff=lfs merge=lfs -text
65
+ assets/gradio/pos_aware/001 filter=lfs diff=lfs merge=lfs -text
66
+ assets/gradio/pos_aware/005/img_ref.png filter=lfs diff=lfs merge=lfs -text
67
+ assets/gradio/pos_free/002/hypher_params.txt filter=lfs diff=lfs merge=lfs -text
68
+ assets/gradio/pos_free/002/img_gen.png filter=lfs diff=lfs merge=lfs -text
69
+ assets/gradio/pos_free/004/img_gen.png filter=lfs diff=lfs merge=lfs -text
70
+ assets/gradio/pos_aware/002/mask_target.png filter=lfs diff=lfs merge=lfs -text
71
+ assets/gradio/pos_aware/003 filter=lfs diff=lfs merge=lfs -text
72
+ assets/gradio/pos_aware/003/img_target.png filter=lfs diff=lfs merge=lfs -text
73
+ assets/gradio/pos_aware/003/mask_target.png filter=lfs diff=lfs merge=lfs -text
74
+ assets/gradio/pos_aware/005/img_target.png filter=lfs diff=lfs merge=lfs -text
75
+ assets/gradio/pos_free/001 filter=lfs diff=lfs merge=lfs -text
76
+ assets/gradio/pos_free/001/hypher_params.txt filter=lfs diff=lfs merge=lfs -text
77
+ assets/gradio/pos_free/002/mask_target.png filter=lfs diff=lfs merge=lfs -text
78
+ assets/gradio/pos_aware/005/hypher_params.txt filter=lfs diff=lfs merge=lfs -text
79
+ assets/gradio/pos_aware/005/mask_target.png filter=lfs diff=lfs merge=lfs -text
80
+ assets/gradio/pos_free/002/img_target.png filter=lfs diff=lfs merge=lfs -text
81
+ assets/gradio/pos_free/003 filter=lfs diff=lfs merge=lfs -text
82
+ assets/gradio/pos_free/003/img_gen.png filter=lfs diff=lfs merge=lfs -text
83
+ assets/gradio/pos_free/004/img_target.png filter=lfs diff=lfs merge=lfs -text
84
+ assets/gradio/pos_aware/001/hypher_params.txt filter=lfs diff=lfs merge=lfs -text
85
+ assets/gradio/pos_aware/001/mask_target.png filter=lfs diff=lfs merge=lfs -text
86
+ assets/gradio/pos_aware/002 filter=lfs diff=lfs merge=lfs -text
87
+ assets/gradio/pos_aware/004/img_target.png filter=lfs diff=lfs merge=lfs -text
88
+ assets/gradio/pos_free/004/hypher_params.txt filter=lfs diff=lfs merge=lfs -text
89
+ assets/gradio/pos_free/004/mask_target.png filter=lfs diff=lfs merge=lfs -text
90
+ assets/gradio/pos_aware/002/img_gen.png filter=lfs diff=lfs merge=lfs -text
91
+ assets/gradio/pos_aware/002/img_target.png filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,177 @@
1
+ # Initially taken from GitHub's Python gitignore file
2
+
3
+ # Byte-compiled / optimized / DLL files
4
+ __pycache__/
5
+ *.py[cod]
6
+ *$py.class
7
+
8
+ # C extensions
9
+ *.so
10
+
11
+ # tests and logs
12
+ tests/fixtures/cached_*_text.txt
13
+ logs/
14
+ lightning_logs/
15
+ lang_code_data/
16
+
17
+ # Distribution / packaging
18
+ .Python
19
+ build/
20
+ develop-eggs/
21
+ dist/
22
+ downloads/
23
+ eggs/
24
+ .eggs/
25
+ lib/
26
+ lib64/
27
+ parts/
28
+ sdist/
29
+ var/
30
+ wheels/
31
+ *.egg-info/
32
+ .installed.cfg
33
+ *.egg
34
+ MANIFEST
35
+
36
+ # PyInstaller
37
+ # Usually these files are written by a Python script from a template
38
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
39
+ *.manifest
40
+ *.spec
41
+
42
+ # Installer logs
43
+ pip-log.txt
44
+ pip-delete-this-directory.txt
45
+
46
+ # Unit test / coverage reports
47
+ htmlcov/
48
+ .tox/
49
+ .nox/
50
+ .coverage
51
+ .coverage.*
52
+ .cache
53
+ nosetests.xml
54
+ coverage.xml
55
+ *.cover
56
+ .hypothesis/
57
+ .pytest_cache/
58
+
59
+ # Translations
60
+ *.mo
61
+ *.pot
62
+
63
+ # Django stuff:
64
+ *.log
65
+ local_settings.py
66
+ db.sqlite3
67
+
68
+ # Flask stuff:
69
+ instance/
70
+ .webassets-cache
71
+
72
+ # Scrapy stuff:
73
+ .scrapy
74
+
75
+ # Sphinx documentation
76
+ docs/_build/
77
+
78
+ # PyBuilder
79
+ target/
80
+
81
+ # Jupyter Notebook
82
+ .ipynb_checkpoints
83
+
84
+ # IPython
85
+ profile_default/
86
+ ipython_config.py
87
+
88
+ # pyenv
89
+ .python-version
90
+
91
+ # celery beat schedule file
92
+ celerybeat-schedule
93
+
94
+ # SageMath parsed files
95
+ *.sage.py
96
+
97
+ # Environments
98
+ .env
99
+ .venv
100
+ env/
101
+ venv/
102
+ ENV/
103
+ env.bak/
104
+ venv.bak/
105
+
106
+ # Spyder project settings
107
+ .spyderproject
108
+ .spyproject
109
+
110
+ # Rope project settings
111
+ .ropeproject
112
+
113
+ # mkdocs documentation
114
+ /site
115
+
116
+ # mypy
117
+ .mypy_cache/
118
+ .dmypy.json
119
+ dmypy.json
120
+
121
+ # Pyre type checker
122
+ .pyre/
123
+
124
+ # vscode
125
+ .vs
126
+ .vscode
127
+
128
+ # Pycharm
129
+ .idea
130
+
131
+ # TF code
132
+ tensorflow_code
133
+
134
+ # Models
135
+ proc_data
136
+ /models
137
+
138
+ # examples
139
+ runs
140
+ /runs_old
141
+ /wandb
142
+ /examples/runs
143
+ /examples/**/*.args
144
+ /examples/rag/sweep
145
+
146
+ # data
147
+ data
148
+ /data
149
+ serialization_dir
150
+
151
+ # emacs
152
+ *.*~
153
+ debug.env
154
+
155
+ # vim
156
+ .*.swp
157
+
158
+ # ctags
159
+ tags
160
+
161
+ # pre-commit
162
+ .pre-commit*
163
+
164
+ # .lock
165
+ *.lock
166
+
167
+ # DS_Store (MacOS)
168
+ .DS_Store
169
+
170
+ # RL pipelines may produce mp4 outputs
171
+ *.mp4
172
+
173
+ # dependencies
174
+ /transformers
175
+
176
+ # ruff
177
+ .ruff_cache
README.md ADDED
@@ -0,0 +1,14 @@
1
+ ---
2
+ title: IC Custom
3
+ emoji: 🎨
4
+ colorFrom: purple
5
+ colorTo: blue
6
+ sdk: gradio
7
+ sdk_version: 5.43.1
8
+ app_file: app.py
9
+ pinned: false
10
+ license: other
11
+ short_description: IC-Custom is designed for diverse image customization.
12
+ ---
13
+
14
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,414 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ IC-Custom Gradio Application
5
+
6
+ This module defines the UI and glue logic to run the IC-Custom pipeline
7
+ via Gradio. The code aims to keep UI text user-friendly while keeping the
8
+ implementation readable and maintainable.
9
+ """
10
+ import os
11
+ import sys
12
+ import numpy as np
13
+ import torch
14
+ import gradio as gr
15
+ import spaces
16
+
17
+ from PIL import Image
18
+ import time
19
+
20
+ # Add current directory to path for imports
21
+ sys.path.append(os.getcwd() + '/app')
22
+
23
+ # Import modular components
24
+ from config import parse_args, load_config, setup_environment
25
+ from ui_components import (
26
+ create_theme, create_css, create_header_section, create_customization_section,
27
+ create_image_input_section, create_prompt_section, create_advanced_options_section,
28
+ create_mask_operation_section, create_output_section, create_examples_section,
29
+ create_citation_section
30
+ )
31
+ from event_handlers import setup_event_handlers
32
+ from business_logic import (
33
+ init_image_target_1, init_image_target_2, init_image_reference,
34
+ undo_seg_points, segmentation, get_point, get_brush,
35
+ dilate_mask, erode_mask, bounding_box,
36
+ change_input_mask_mode, change_custmization_mode, change_seg_ref_mode,
37
+ vlm_auto_generate, vlm_auto_polish, save_results, set_mobile_predictor,
38
+ set_ben2_model, set_vlm_processor, set_vlm_model,
39
+ )
40
+
41
+ # Import other dependencies
42
+ from utils import (
43
+ get_sam_predictor, get_vlm, get_ben2_model,
44
+ prepare_input_images, get_mask_type_ids
45
+ )
46
+ from examples import GRADIO_EXAMPLES, MASK_TGT, IMG_GEN
47
+ from ic_custom.pipelines.ic_custom_pipeline import ICCustomPipeline
48
+
49
+ # Global variables for pipeline and assets cache directory
50
+ PIPELINE = None
51
+ ASSETS_CACHE_DIR = None
52
+
53
+ # Force Hugging Face to re-download models and clear cache
54
+ os.environ["HF_HUB_FORCE_DOWNLOAD"] = "1"
55
+ os.environ["HF_HUB_DISABLE_TELEMETRY"] = "1"
56
+ os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf_cache" # Use temp directory for Spaces
57
+ os.environ["HF_HOME"] = "/tmp/hf_home" # Use temp directory for Spaces
58
+
59
+
60
+ os.environ["GRADIO_TEMP_DIR"] = os.path.abspath(os.path.join(os.path.dirname(__file__), "gradio_cache"))
61
+
62
+
63
+ def set_pipeline(pipeline):
64
+ """Inject pipeline into this module without changing function signatures."""
65
+ global PIPELINE
66
+ PIPELINE = pipeline
67
+
68
+ def set_assets_cache_dir(assets_cache_dir):
69
+ """Inject assets cache dir into this module without changing function signatures."""
70
+ global ASSETS_CACHE_DIR
71
+ ASSETS_CACHE_DIR = assets_cache_dir
72
+
73
+
74
+ def initialize_models(args, cfg, device, weight_dtype):
75
+ """Initialize all required models."""
76
+ # Load IC-Custom pipeline
77
+ pipeline = ICCustomPipeline(
78
+ clip_path=cfg.checkpoint_config.clip_path if os.path.exists(cfg.checkpoint_config.clip_path) else "clip-vit-large-patch14",
79
+ t5_path=cfg.checkpoint_config.t5_path if os.path.exists(cfg.checkpoint_config.t5_path) else "t5-v1_1-xxl",
80
+ siglip_path=cfg.checkpoint_config.siglip_path if os.path.exists(cfg.checkpoint_config.siglip_path) else "siglip-so400m-patch14-384",
81
+ ae_path=cfg.checkpoint_config.ae_path if os.path.exists(cfg.checkpoint_config.ae_path) else "flux-fill-dev-ae",
82
+ dit_path=cfg.checkpoint_config.dit_path if os.path.exists(cfg.checkpoint_config.dit_path) else "flux-fill-dev-dit",
83
+ redux_path=cfg.checkpoint_config.redux_path if os.path.exists(cfg.checkpoint_config.redux_path) else "flux1-redux-dev",
84
+ lora_path=cfg.checkpoint_config.lora_path if os.path.exists(cfg.checkpoint_config.lora_path) else "dit_lora_0x1561",
85
+ img_txt_in_path=cfg.checkpoint_config.img_txt_in_path if os.path.exists(cfg.checkpoint_config.img_txt_in_path) else "dit_txt_img_in_0x1561",
86
+ boundary_embeddings_path=cfg.checkpoint_config.boundary_embeddings_path if os.path.exists(cfg.checkpoint_config.boundary_embeddings_path) else "dit_boundary_embeddings_0x1561",
87
+ task_register_embeddings_path=cfg.checkpoint_config.task_register_embeddings_path if os.path.exists(cfg.checkpoint_config.task_register_embeddings_path) else "dit_task_register_embeddings_0x1561",
88
+ network_alpha=cfg.model_config.network_alpha,
89
+ double_blocks_idx=cfg.model_config.double_blocks,
90
+ single_blocks_idx=cfg.model_config.single_blocks,
91
+ device=device,
92
+ weight_dtype=weight_dtype,
93
+ offload=True,
94
+ )
95
+ pipeline.set_pipeline_offload(True)
96
+ # pipeline.set_show_progress(True)
97
+
98
+ # Load SAM predictor
99
+ mobile_predictor = get_sam_predictor(cfg.checkpoint_config.sam_path, device)
100
+
101
+ # Load VLM if enabled
102
+ vlm_processor, vlm_model = None, None
103
+ if args.enable_vlm_for_prompt:
104
+ vlm_processor, vlm_model = get_vlm(
105
+ cfg.checkpoint_config.vlm_path,
106
+ device=device,
107
+ torch_dtype=weight_dtype,
108
+ )
109
+
110
+ # Load BEN2 model if enabled
111
+ ben2_model = None
112
+ if args.enable_ben2_for_mask_ref:
113
+ ben2_model = get_ben2_model(cfg.checkpoint_config.ben2_path, device)
114
+
115
+ return pipeline, mobile_predictor, vlm_processor, vlm_model, ben2_model
116
+
117
+ @spaces.GPU(duration=140)
118
+ def run_model(
119
+ image_target_state, mask_target_state, image_reference_ori_state,
120
+ image_reference_rmbg_state, prompt, seed, guidance, true_gs, num_steps,
121
+ num_images_per_prompt, use_background_preservation, background_blend_threshold,
122
+ aspect_ratio, custmization_mode, seg_ref_mode, input_mask_mode,
123
+ progress=gr.Progress()
124
+ ):
125
+ """Run IC-Custom pipeline with current UI state and return images."""
126
+ start_ts = time.time()
127
+ progress(0, desc="Starting generation...")
128
+ # Select reference image and check inputs
129
+ if seg_ref_mode == "Masked Ref":
130
+ image_reference_state = image_reference_rmbg_state
131
+ else:
132
+ image_reference_state = image_reference_ori_state
133
+
134
+ if image_reference_state is None:
135
+ gr.Warning('Please upload the reference image')
136
+ return None, seed, gr.update(placeholder="Last Input: " + prompt, value="")
137
+
138
+ if image_target_state is None and custmization_mode != "Position-free":
139
+ gr.Warning('Please upload the target image and mask it')
140
+ return None, seed, gr.update(placeholder="Last Input: " + prompt, value="")
141
+
142
+ if custmization_mode == "Position-aware" and mask_target_state is None:
143
+ gr.Warning('Please select/draw the target mask')
144
+ return None, seed, gr.update(placeholder=prompt, value="")
145
+
146
+
147
+ mask_type_ids = get_mask_type_ids(custmization_mode, input_mask_mode)
148
+
149
+ from constants import ASPECT_RATIO_TEMPLATE
150
+ output_w, output_h = ASPECT_RATIO_TEMPLATE[aspect_ratio]
151
+ image_reference, image_target, mask_target = prepare_input_images(
152
+ image_reference_state, custmization_mode, image_target_state, mask_target_state,
153
+ width=output_w, height=output_h,
154
+ force_resize_long_edge="long edge" in aspect_ratio,
155
+ return_type="pil"
156
+ )
157
+
158
+ gr.Info(f"Output WH resolution: {image_target.size[0]}px x {image_target.size[1]}px")
159
+ # Run the model
160
+ if seed == -1:
161
+ seed = torch.randint(0, 2147483647, (1,)).item()
162
+
163
+ width, height = image_target.size[0] + image_reference.size[0], image_target.size[1]
164
+
165
+
166
+ with torch.no_grad():
167
+ output_img = PIPELINE(
168
+ prompt=prompt, width=width, height=height, guidance=guidance,
169
+ num_steps=num_steps, seed=seed, img_ref=image_reference,
170
+ img_target=image_target, mask_target=mask_target, img_ip=image_reference,
171
+ cond_w_regions=[image_reference.size[0]], mask_type_ids=mask_type_ids,
172
+ use_background_preservation=use_background_preservation,
173
+ background_blend_threshold=background_blend_threshold, true_gs=true_gs,
174
+ neg_prompt="worst quality, normal quality, low quality, low res, blurry,",
175
+ num_images_per_prompt=num_images_per_prompt,
176
+ gradio_progress=progress,
177
+ )
178
+
179
+
180
+ elapsed = time.time() - start_ts
181
+ progress(1.0, desc=f"Completed in {elapsed:.2f}s!")
182
+ gr.Info(f"Finished in {elapsed:.2f}s")
183
+
184
+ return output_img, -1, gr.update(placeholder=f"Last Input ({elapsed:.2f}s): " + prompt, value="")
185
+
186
+
187
+ def example_pipeline(
188
+ image_reference, image_target_1, image_target_2, custmization_mode,
189
+ input_mask_mode, seg_ref_mode, prompt, seed, true_gs, eg_idx,
190
+ num_steps, guidance
191
+ ):
192
+ """Handle example loading in the UI."""
193
+
194
+ if seg_ref_mode == "Full Ref":
195
+ image_reference_ori_state = np.array(image_reference.convert("RGB"))
196
+ image_reference_rmbg_state = None
197
+ image_reference_state = image_reference_ori_state
198
+ else:
199
+ image_reference_rmbg_state = np.array(image_reference.convert("RGB"))
200
+ image_reference_ori_state = None
201
+ image_reference_state = image_reference_rmbg_state
202
+
203
+ if custmization_mode == "Position-aware":
204
+ if input_mask_mode == "Precise mask":
205
+ image_target_state = np.array(image_target_1.convert("RGB"))
206
+ else:
207
+ image_target_state = np.array(image_target_2['composite'].convert("RGB"))
208
+ mask_target_state = np.array(Image.open(MASK_TGT[int(eg_idx)]))
209
+ else: # Position-free mode
210
+ # For Position-free, use the target image from IMG_TGT1 and corresponding mask
211
+ image_target_state = np.array(image_target_1.convert("RGB"))
212
+ mask_target_state = np.array(Image.open(MASK_TGT[int(eg_idx)]))
213
+
214
+ mask_target_binary = mask_target_state / 255
215
+ masked_img = image_target_state * mask_target_binary
216
+ masked_img_pil = Image.fromarray(masked_img.astype("uint8"))
217
+ output_mask_pil = Image.fromarray(mask_target_state.astype("uint8"))
218
+
219
+ if custmization_mode == "Position-aware":
220
+ mask_gallery = [masked_img_pil, output_mask_pil]
221
+ else:
222
+ mask_gallery = gr.skip()
223
+
224
+ result_gallery = [Image.open(IMG_GEN[int(eg_idx)]).convert("RGB")]
225
+
226
+ if custmization_mode == "Position-free":
227
+ return (image_reference_ori_state, image_reference_rmbg_state, image_target_state,
228
+ mask_target_state, mask_gallery, result_gallery,
229
+ gr.update(visible=False), gr.update(visible=False))
230
+
231
+ if input_mask_mode == "Precise mask":
232
+ return (image_reference_ori_state, image_reference_rmbg_state, image_target_state,
233
+ mask_target_state, mask_gallery, result_gallery,
234
+ gr.update(visible=True), gr.update(visible=False))
235
+ else:
236
+ # Ensure ImageEditor has a proper background so brush + undo work
237
+ try:
238
+ bg_img = image_target_2.get('background') or image_target_2.get('composite')
239
+ except Exception:
240
+ bg_img = image_target_2
241
+
242
+ return (
243
+ image_reference_ori_state, image_reference_rmbg_state, image_target_state,
244
+ mask_target_state, mask_gallery, result_gallery,
245
+ gr.update(visible=False),
246
+ gr.update(visible=True, value={"background": bg_img, "layers": [], "composite": bg_img}),
247
+ )
248
+
249
+
250
+ def create_application():
251
+ """Create the main Gradio application."""
252
+ # Create theme and CSS
253
+ theme = create_theme()
254
+ css = create_css()
255
+
256
+ with gr.Blocks(theme=theme, css=css) as demo:
257
+
258
+ with gr.Column(elem_id="global_glass_container"):
259
+
260
+ # Create UI sections
261
+ create_header_section()
262
+
263
+ # Hidden components
264
+ eg_idx = gr.Textbox(label="eg_idx", visible=False, value="-1")
265
+
266
+ # State variables
267
+ image_target_state = gr.State(value=None)
268
+ mask_target_state = gr.State(value=None)
269
+ image_reference_ori_state = gr.State(value=None)
270
+ image_reference_rmbg_state = gr.State(value=None)
271
+ selected_points = gr.State(value=[])
272
+
273
+
274
+ # Main UI content with optimized left-right layout
275
+ with gr.Column(elem_id="glass_card"):
276
+ # Top section - Mode selection (full width)
277
+ custmization_mode, md_custmization_mode = create_customization_section()
278
+
279
+ # Main layout: Left for inputs, Right for outputs
280
+ with gr.Row(equal_height=False):
281
+ # LEFT COLUMN - ALL INPUTS
282
+ with gr.Column(scale=3, min_width=400):
283
+ # Image input section
284
+ (image_reference, input_mask_mode, image_target_1, image_target_2,
285
+ undo_target_seg_button, md_image_reference, md_input_mask_mode,
286
+ md_target_image) = create_image_input_section()
287
+
288
+ # Text prompt section
289
+ prompt, vlm_generate_btn, vlm_polish_btn, md_prompt = create_prompt_section()
290
+
291
+ # Advanced options (collapsible)
292
+ (aspect_ratio, seg_ref_mode, move_to_center, use_background_preservation,
293
+ background_blend_threshold, seed, num_images_per_prompt, guidance,
294
+ num_steps, true_gs) = create_advanced_options_section()
295
+
296
+ # RIGHT COLUMN - ALL OUTPUTS
297
+ with gr.Column(scale=2, min_width=350):
298
+ # Mask preview and operations
299
+ (mask_gallery, dilate_button, erode_button, bounding_box_button,
300
+ md_mask_operation) = create_mask_operation_section()
301
+
302
+ # Generation controls and results
303
+ result_gallery, submit_button, clear_btn, md_submit = create_output_section()
304
+
305
+ with gr.Row(elem_id="glass_card"):
306
+ # Examples section
307
+ examples = create_examples_section(
308
+ GRADIO_EXAMPLES,
309
+ inputs=[
310
+ image_reference,
311
+ image_target_1,
312
+ image_target_2,
313
+ custmization_mode,
314
+ input_mask_mode,
315
+ seg_ref_mode,
316
+ prompt,
317
+ seed,
318
+ true_gs,
319
+ eg_idx,
320
+ num_steps,
321
+ guidance
322
+ ],
323
+ outputs=[
324
+ image_reference_ori_state,
325
+ image_reference_rmbg_state,
326
+ image_target_state,
327
+ mask_target_state,
328
+ mask_gallery,
329
+ result_gallery,
330
+ image_target_1,
331
+ image_target_2,
332
+ ],
333
+ fn=example_pipeline,
334
+ )
335
+
336
+ with gr.Row(elem_id="glass_card"):
337
+ # Citation section
338
+ create_citation_section()
339
+
340
+ # Setup event handlers
341
+ setup_event_handlers(
342
+ ## UI components
343
+ input_mask_mode, image_target_1, image_target_2, undo_target_seg_button,
344
+ custmization_mode, dilate_button, erode_button, bounding_box_button,
345
+ mask_gallery, md_input_mask_mode, md_target_image, md_mask_operation,
346
+ md_prompt, md_submit, result_gallery, image_target_state, mask_target_state,
347
+ seg_ref_mode, image_reference_ori_state, move_to_center,
348
+ image_reference, image_reference_rmbg_state,
349
+ ## Functions
350
+ change_input_mask_mode, change_custmization_mode,
351
+ change_seg_ref_mode,
352
+ init_image_target_1, init_image_target_2, init_image_reference,
353
+ get_point, undo_seg_points,
354
+ get_brush,
355
+ # VLM buttons
356
+ vlm_generate_btn, vlm_polish_btn,
357
+ # VLM functions
358
+ vlm_auto_generate,
359
+ vlm_auto_polish,
360
+ dilate_mask, erode_mask, bounding_box,
361
+ run_model,
362
+ ## Other components
363
+ selected_points, prompt,
364
+ use_background_preservation, background_blend_threshold, seed,
365
+ num_images_per_prompt, guidance, true_gs, num_steps, aspect_ratio,
366
+ submit_button,
367
+ eg_idx,
368
+ )
369
+
370
+ # Setup clear button
371
+ clear_btn.add(
372
+ [image_reference, image_target_1, image_target_2, mask_gallery, result_gallery,
373
+ selected_points, image_target_state, mask_target_state, prompt,
374
+ image_reference_ori_state, image_reference_rmbg_state]
375
+ )
376
+
377
+ return demo
378
+
379
+
380
+ def main():
381
+ """Main entry point for the application."""
382
+ # Parse arguments and load config
383
+ args = parse_args()
384
+ cfg = load_config(args.config)
385
+ setup_environment(args)
386
+
387
+ # Initialize device and models
388
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
389
+ weight_dtype = torch.bfloat16
390
+
391
+ pipeline, mobile_predictor, vlm_processor, vlm_model, ben2_model = initialize_models(
392
+ args, cfg, device, weight_dtype
393
+ )
394
+
395
+ set_pipeline(pipeline)
396
+ set_assets_cache_dir(args.assets_cache_dir)
397
+
398
+ # Inject mobile predictor into business logic module so get_point can access it without lambdas
399
+ set_mobile_predictor(mobile_predictor)
400
+ set_ben2_model(ben2_model)
401
+ set_vlm_processor(vlm_processor)
402
+ set_vlm_model(vlm_model)
403
+
404
+ # Create and launch the application
405
+ demo = create_application()
406
+
407
+ # Launch the demo
408
+ demo.launch(server_port=7860, server_name="0.0.0.0",
409
+ allowed_paths=[os.path.abspath(os.path.join(os.path.dirname(__file__), "gradio_cache")),
410
+ os.path.abspath(os.path.join(os.path.dirname(__file__), "results"))])
411
+
412
+
413
+ if __name__ == "__main__":
414
+ main()
app/APP.md ADDED
@@ -0,0 +1,103 @@
1
+ # IC-Custom Application
2
+
3
+ A sophisticated image customization tool powered by advanced AI models.
4
+
5
+ > 📺 **App Guide:**
6
+ > For a fast overview of how to use the app, watch this video:
7
+ > [IC-Custom App Usage Guide (YouTube)](https://www.youtube.com/watch?v=uaiZA3H5RV)
8
+
9
+ ---
10
+
11
+ ## 🚀 Quick Start
12
+
13
+ ```bash
14
+ python src/app/app.py \
15
+ --config configs/app/app.yaml \
16
+ --hf_token $HF_TOKEN \
17
+ --hf_cache_dir $HF_CACHE_DIR \
18
+ --assets_cache_dir results/app \
19
+ --enable_ben2_for_mask_ref False \
20
+ --enable_vlm_for_prompt False \
21
+ --save_results True
22
+ ```
23
+
24
+ ---
25
+
26
+ ## ⚙️ Configuration & CLI Arguments
27
+
28
+ | Argument | Type | Required | Default | Description |
29
+ |----------|------|----------|---------|-------------|
30
+ | `--config` | str | ✅ | - | Path to app YAML config file |
31
+ | `--hf_token` | str | ❌ | - | Hugging Face access token |
32
+ | `--hf_cache_dir` | str | ❌ | `~/.cache/huggingface/hub` | HF assets cache directory |
33
+ | `--assets_cache_dir` | str | ❌ | `results/app` | Output images & metadata directory |
34
+ | `--save_results` | bool | ❌ | `False` | Save generated results |
35
+ | `--enable_ben2_for_mask_ref` | bool | ❌ | `False` | Enable BEN2 background removal |
36
+ | `--enable_vlm_for_prompt` | bool | ❌ | `False` | Enable VLM prompt generation |
37
+
38
+ ### Environment Variables
39
+
40
+ - `HF_TOKEN` ← `--hf_token`
41
+ - `HF_HUB_CACHE` ← `--hf_cache_dir`
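As a point of reference, a minimal Python sketch of this flag-to-environment-variable mapping (illustrative only; the actual logic lives in `setup_environment` in `app/config.py`, whose diff is not shown in this 50-file view, and the helper name below is hypothetical):

```python
import os

def export_hf_env(hf_token=None, hf_cache_dir=None):
    # Hypothetical helper: mirror --hf_token / --hf_cache_dir into the env vars listed above.
    if hf_token:
        os.environ["HF_TOKEN"] = hf_token
    if hf_cache_dir:
        os.environ["HF_HUB_CACHE"] = hf_cache_dir
```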
42
+
43
+ ---
44
+
45
+ ## 📥 Model Downloads
46
+
47
+ > **Model checkpoints are required before running the app.**
48
+ > All required models will be automatically downloaded when you run the app, or you can manually download them and specify paths in `configs/app/app.yaml`.
49
+
50
+ ### Required Models
51
+
52
+ The following models are **automatically downloaded** when running the app:
53
+
54
+ | Model | Purpose | Source |
55
+ |-------|---------|--------|
56
+ | **IC-Custom** | Our customization model | [TencentARC/IC-Custom](https://huggingface.co/TencentARC/IC-Custom) |
57
+ | **CLIP** | Vision-language understanding | [openai/clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) |
58
+ | **T5** | Text processing | [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) |
59
+ | **SigLIP** | Image understanding | [google/siglip-so400m-patch14-384](https://huggingface.co/google/siglip-so400m-patch14-384) |
60
+ | **Autoencoder** | Image encoding/decoding | [black-forest-labs/FLUX.1-Fill-dev](https://huggingface.co/black-forest-labs/FLUX.1-Fill-dev/blob/main/ae.safetensors) |
61
+ | **DIT** | Diffusion model | [black-forest-labs/FLUX.1-Fill-dev](https://huggingface.co/black-forest-labs/FLUX.1-Fill-dev/blob/main/flux1-fill-dev.safetensors) |
62
+ | **Redux** | Image processing | [black-forest-labs/FLUX.1-Redux-dev](https://huggingface.co/black-forest-labs/FLUX.1-Redux-dev) |
63
+ | **SAM-vit-h** | Image segmentation | [HCMUE-Research/SAM-vit-h](https://huggingface.co/HCMUE-Research/SAM-vit-h/blob/main/sam_vit_h_4b8939.pth) |
64
+
65
+ ### Optional Models (Selective Download)
66
+
67
+ **BEN2 and Qwen2.5-VL models are disabled by default** and only downloaded when explicitly enabled:
68
+
69
+ | Model | Flag | Source | Purpose |
70
+ |-------|------|--------|---------|
71
+ | **BEN2** | `--enable_ben2_for_mask_ref True` | [PramaLLC/BEN2](https://huggingface.co/PramaLLC/BEN2/blob/main/BEN2_Base.pth) | Background removal |
72
+ | **Qwen2.5-VL** | `--enable_vlm_for_prompt True` | [Qwen/Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct) | Prompt generation |
73
+
74
+ ### Manual Configuration
75
+
76
+ **Alternative**: Manually download all models and specify paths in `configs/app/app.yaml`:
77
+
78
+ ```yaml
79
+ checkpoint_config:
80
+ # Required models
81
+ dit_path: "/path/to/flux1-fill-dev.safetensors"
82
+ ae_path: "/path/to/ae.safetensors"
83
+ t5_path: "/path/to/t5-v1_1-xxl"
84
+ clip_path: "/path/to/clip-vit-large-patch14"
85
+ siglip_path: "/path/to/siglip-so400m-patch14-384"
86
+ redux_path: "/path/to/flux1-redux-dev.safetensors"
87
+ # IC-Custom models
88
+ lora_path: "/path/to/dit_lora_0x1561.safetensors"
89
+ img_txt_in_path: "/path/to/dit_txt_img_in_0x1561.safetensors"
90
+ boundary_embeddings_path: "/path/to/dit_boundary_embeddings_0x1561.safetensors"
91
+ task_register_embeddings_path: "/path/to/dit_task_register_embeddings_0x1561.safetensors"
92
+ # APP interactive models
93
+ sam_path: "/path/to/sam_vit_h_4b8939.pth"
94
+ # Optional models
95
+ ben2_path: "/path/to/BEN2_Base.pth"
96
+ vlm_path: "/path/to/Qwen2.5-VL-7B-Instruct"
97
+ ```
98
+
99
+ ### APP Overview
100
+
101
+ <p align="center">
102
+ <img src="../../assets/gradio_ui.png" alt="IC-Custom APP" width="80%">
103
+ </p>
app/BEN2.py ADDED
@@ -0,0 +1,1394 @@
1
+ # Copyright (c) 2025 Prama LLC
2
+ # SPDX-License-Identifier: MIT
3
+
4
+ import math
5
+ import os
6
+ import random
7
+ import subprocess
8
+ import tempfile
9
+ import time
10
+
11
+ import cv2
12
+ import numpy as np
13
+ import torch
14
+ import torch.nn as nn
15
+ import torch.nn.functional as F
16
+ import torch.utils.checkpoint as checkpoint
17
+ from einops import rearrange
18
+ from PIL import Image, ImageOps
19
+ from timm.models.layers import DropPath, to_2tuple, trunc_normal_
20
+ from torchvision import transforms
21
+
22
+
23
+ def set_random_seed(seed):
24
+ random.seed(seed)
25
+ np.random.seed(seed)
26
+ torch.manual_seed(seed)
27
+ torch.cuda.manual_seed(seed)
28
+ torch.cuda.manual_seed_all(seed)
29
+ torch.backends.cudnn.deterministic = True
30
+ torch.backends.cudnn.benchmark = False
31
+
32
+
33
+ # set_random_seed(9)
34
+
35
+ torch.set_float32_matmul_precision('highest')
36
+
37
+
38
+ class Mlp(nn.Module):
39
+ """ Multilayer perceptron."""
40
+
41
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
42
+ super().__init__()
43
+ out_features = out_features or in_features
44
+ hidden_features = hidden_features or in_features
45
+ self.fc1 = nn.Linear(in_features, hidden_features)
46
+ self.act = act_layer()
47
+ self.fc2 = nn.Linear(hidden_features, out_features)
48
+ self.drop = nn.Dropout(drop)
49
+
50
+ def forward(self, x):
51
+ x = self.fc1(x)
52
+ x = self.act(x)
53
+ x = self.drop(x)
54
+ x = self.fc2(x)
55
+ x = self.drop(x)
56
+ return x
57
+
58
+
59
+ def window_partition(x, window_size):
60
+ """
61
+ Args:
62
+ x: (B, H, W, C)
63
+ window_size (int): window size
64
+ Returns:
65
+ windows: (num_windows*B, window_size, window_size, C)
66
+ """
67
+ B, H, W, C = x.shape
68
+ x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
69
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
70
+ return windows
71
+
72
+
73
+ def window_reverse(windows, window_size, H, W):
74
+ """
75
+ Args:
76
+ windows: (num_windows*B, window_size, window_size, C)
77
+ window_size (int): Window size
78
+ H (int): Height of image
79
+ W (int): Width of image
80
+ Returns:
81
+ x: (B, H, W, C)
82
+ """
83
+ B = int(windows.shape[0] / (H * W / window_size / window_size))
84
+ x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
85
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
86
+ return x
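A quick round-trip sanity check for the two window helpers above (illustrative only, not part of the commit, and assuming the definitions above are in scope): partitioning and then reversing reproduces the input exactly whenever H and W are multiples of the window size, which the Swin blocks below guarantee via padding.

```python
import torch

B, H, W, C, ws = 2, 14, 14, 96, 7
x = torch.randn(B, H, W, C)
windows = window_partition(x, ws)           # (B * H//ws * W//ws, ws, ws, C) == (8, 7, 7, 96)
x_back = window_reverse(windows, ws, H, W)  # exact inverse of window_partition
assert windows.shape == (8, 7, 7, 96)
assert torch.equal(x, x_back)
```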
87
+
88
+
89
+ class WindowAttention(nn.Module):
90
+ """ Window based multi-head self attention (W-MSA) module with relative position bias.
91
+ It supports both of shifted and non-shifted window.
92
+ Args:
93
+ dim (int): Number of input channels.
94
+ window_size (tuple[int]): The height and width of the window.
95
+ num_heads (int): Number of attention heads.
96
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
97
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
98
+ attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
99
+ proj_drop (float, optional): Dropout ratio of output. Default: 0.0
100
+ """
101
+
102
+ def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
103
+
104
+ super().__init__()
105
+ self.dim = dim
106
+ self.window_size = window_size # Wh, Ww
107
+ self.num_heads = num_heads
108
+ head_dim = dim // num_heads
109
+ self.scale = qk_scale or head_dim ** -0.5
110
+
111
+ # define a parameter table of relative position bias
112
+ self.relative_position_bias_table = nn.Parameter(
113
+ torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
114
+
115
+ # get pair-wise relative position index for each token inside the window
116
+ coords_h = torch.arange(self.window_size[0])
117
+ coords_w = torch.arange(self.window_size[1])
118
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
119
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
120
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
121
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
122
+ relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
123
+ relative_coords[:, :, 1] += self.window_size[1] - 1
124
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
125
+ relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
126
+ self.register_buffer("relative_position_index", relative_position_index)
127
+
128
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
129
+ self.attn_drop = nn.Dropout(attn_drop)
130
+ self.proj = nn.Linear(dim, dim)
131
+ self.proj_drop = nn.Dropout(proj_drop)
132
+
133
+ trunc_normal_(self.relative_position_bias_table, std=.02)
134
+ self.softmax = nn.Softmax(dim=-1)
135
+
136
+ def forward(self, x, mask=None):
137
+ """ Forward function.
138
+ Args:
139
+ x: input features with shape of (num_windows*B, N, C)
140
+ mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
141
+ """
142
+ B_, N, C = x.shape
143
+ qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
144
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
145
+
146
+ q = q * self.scale
147
+ attn = (q @ k.transpose(-2, -1))
148
+
149
+ relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
150
+ self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
151
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
152
+ attn = attn + relative_position_bias.unsqueeze(0)
153
+
154
+ if mask is not None:
155
+ nW = mask.shape[0]
156
+ attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
157
+ attn = attn.view(-1, self.num_heads, N, N)
158
+ attn = self.softmax(attn)
159
+ else:
160
+ attn = self.softmax(attn)
161
+
162
+ attn = self.attn_drop(attn)
163
+
164
+ x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
165
+ x = self.proj(x)
166
+ x = self.proj_drop(x)
167
+ return x
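For orientation, a minimal shape check for `WindowAttention` (illustrative only, not part of the commit, and assuming the class above is in scope): with no mask, the module maps `(num_windows*B, N, C)` to the same shape, adding the relative position bias internally.

```python
import torch

attn = WindowAttention(dim=96, window_size=(7, 7), num_heads=3)
x = torch.randn(8, 7 * 7, 96)   # (num_windows*B, N, C)
y = attn(x)                     # mask defaults to None
assert y.shape == x.shape       # (8, 49, 96)
```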
168
+
169
+
170
+ class SwinTransformerBlock(nn.Module):
171
+ """ Swin Transformer Block.
172
+ Args:
173
+ dim (int): Number of input channels.
174
+ num_heads (int): Number of attention heads.
175
+ window_size (int): Window size.
176
+ shift_size (int): Shift size for SW-MSA.
177
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
178
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
179
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
180
+ drop (float, optional): Dropout rate. Default: 0.0
181
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
182
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
183
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
184
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
185
+ """
186
+
187
+ def __init__(self, dim, num_heads, window_size=7, shift_size=0,
188
+ mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
189
+ act_layer=nn.GELU, norm_layer=nn.LayerNorm):
190
+ super().__init__()
191
+ self.dim = dim
192
+ self.num_heads = num_heads
193
+ self.window_size = window_size
194
+ self.shift_size = shift_size
195
+ self.mlp_ratio = mlp_ratio
196
+ assert 0 <= self.shift_size < self.window_size, "shift_size must be in 0-window_size"
197
+
198
+ self.norm1 = norm_layer(dim)
199
+ self.attn = WindowAttention(
200
+ dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
201
+ qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
202
+
203
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
204
+ self.norm2 = norm_layer(dim)
205
+ mlp_hidden_dim = int(dim * mlp_ratio)
206
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
207
+
208
+ self.H = None
209
+ self.W = None
210
+
211
+ def forward(self, x, mask_matrix):
212
+ """ Forward function.
213
+ Args:
214
+ x: Input feature, tensor size (B, H*W, C).
215
+ H, W: Spatial resolution of the input feature.
216
+ mask_matrix: Attention mask for cyclic shift.
217
+ """
218
+ B, L, C = x.shape
219
+ H, W = self.H, self.W
220
+ assert L == H * W, "input feature has wrong size"
221
+
222
+ shortcut = x
223
+ x = self.norm1(x)
224
+ x = x.view(B, H, W, C)
225
+
226
+ # pad feature maps to multiples of window size
227
+ pad_l = pad_t = 0
228
+ pad_r = (self.window_size - W % self.window_size) % self.window_size
229
+ pad_b = (self.window_size - H % self.window_size) % self.window_size
230
+ x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
231
+ _, Hp, Wp, _ = x.shape
232
+
233
+ # cyclic shift
234
+ if self.shift_size > 0:
235
+ shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
236
+ attn_mask = mask_matrix
237
+ else:
238
+ shifted_x = x
239
+ attn_mask = None
240
+
241
+ # partition windows
242
+ x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
243
+ x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
244
+
245
+ # W-MSA/SW-MSA
246
+ attn_windows = self.attn(x_windows, mask=attn_mask) # nW*B, window_size*window_size, C
247
+
248
+ # merge windows
249
+ attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
250
+ shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp) # B H' W' C
251
+
252
+ # reverse cyclic shift
253
+ if self.shift_size > 0:
254
+ x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
255
+ else:
256
+ x = shifted_x
257
+
258
+ if pad_r > 0 or pad_b > 0:
259
+ x = x[:, :H, :W, :].contiguous()
260
+
261
+ x = x.view(B, H * W, C)
262
+
263
+ # FFN
264
+ x = shortcut + self.drop_path(x)
265
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
266
+
267
+ return x
268
+
269
+
270
+ class PatchMerging(nn.Module):
271
+ """ Patch Merging Layer
272
+ Args:
273
+ dim (int): Number of input channels.
274
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
275
+ """
276
+
277
+ def __init__(self, dim, norm_layer=nn.LayerNorm):
278
+ super().__init__()
279
+ self.dim = dim
280
+ self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
281
+ self.norm = norm_layer(4 * dim)
282
+
283
+ def forward(self, x, H, W):
284
+ """ Forward function.
285
+ Args:
286
+ x: Input feature, tensor size (B, H*W, C).
287
+ H, W: Spatial resolution of the input feature.
288
+ """
289
+ B, L, C = x.shape
290
+ assert L == H * W, "input feature has wrong size"
291
+
292
+ x = x.view(B, H, W, C)
293
+
294
+ # padding
295
+ pad_input = (H % 2 == 1) or (W % 2 == 1)
296
+ if pad_input:
297
+ x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
298
+
299
+ x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
300
+ x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
301
+ x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
302
+ x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
303
+ x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
304
+ x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
305
+
306
+ x = self.norm(x)
307
+ x = self.reduction(x)
308
+
309
+ return x
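A brief shape check for `PatchMerging` (illustrative only, not part of the commit, and assuming the class above is in scope): the layer halves the spatial resolution and doubles the channel count.

```python
import torch

merge = PatchMerging(dim=96)
x = torch.randn(2, 56 * 56, 96)        # (B, H*W, C)
y = merge(x, 56, 56)
assert y.shape == (2, 28 * 28, 192)    # (B, (H/2)*(W/2), 2*C)
```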
310
+
311
+
312
+ class BasicLayer(nn.Module):
313
+ """ A basic Swin Transformer layer for one stage.
314
+ Args:
315
+ dim (int): Number of feature channels
316
+ depth (int): Depths of this stage.
317
+ num_heads (int): Number of attention head.
318
+ window_size (int): Local window size. Default: 7.
319
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
320
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
321
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
322
+ drop (float, optional): Dropout rate. Default: 0.0
323
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
324
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
325
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
326
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
327
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
328
+ """
329
+
330
+ def __init__(self,
331
+ dim,
332
+ depth,
333
+ num_heads,
334
+ window_size=7,
335
+ mlp_ratio=4.,
336
+ qkv_bias=True,
337
+ qk_scale=None,
338
+ drop=0.,
339
+ attn_drop=0.,
340
+ drop_path=0.,
341
+ norm_layer=nn.LayerNorm,
342
+ downsample=None,
343
+ use_checkpoint=False):
344
+ super().__init__()
345
+ self.window_size = window_size
346
+ self.shift_size = window_size // 2
347
+ self.depth = depth
348
+ self.use_checkpoint = use_checkpoint
349
+
350
+ # build blocks
351
+ self.blocks = nn.ModuleList([
352
+ SwinTransformerBlock(
353
+ dim=dim,
354
+ num_heads=num_heads,
355
+ window_size=window_size,
356
+ shift_size=0 if (i % 2 == 0) else window_size // 2,
357
+ mlp_ratio=mlp_ratio,
358
+ qkv_bias=qkv_bias,
359
+ qk_scale=qk_scale,
360
+ drop=drop,
361
+ attn_drop=attn_drop,
362
+ drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
363
+ norm_layer=norm_layer)
364
+ for i in range(depth)])
365
+
366
+ # patch merging layer
367
+ if downsample is not None:
368
+ self.downsample = downsample(dim=dim, norm_layer=norm_layer)
369
+ else:
370
+ self.downsample = None
371
+
372
+ def forward(self, x, H, W):
373
+ """ Forward function.
374
+ Args:
375
+ x: Input feature, tensor size (B, H*W, C).
376
+ H, W: Spatial resolution of the input feature.
377
+ """
378
+
379
+ # calculate attention mask for SW-MSA
380
+ Hp = int(np.ceil(H / self.window_size)) * self.window_size
381
+ Wp = int(np.ceil(W / self.window_size)) * self.window_size
382
+ img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1
383
+ h_slices = (slice(0, -self.window_size),
384
+ slice(-self.window_size, -self.shift_size),
385
+ slice(-self.shift_size, None))
386
+ w_slices = (slice(0, -self.window_size),
387
+ slice(-self.window_size, -self.shift_size),
388
+ slice(-self.shift_size, None))
389
+ cnt = 0
390
+ for h in h_slices:
391
+ for w in w_slices:
392
+ img_mask[:, h, w, :] = cnt
393
+ cnt += 1
394
+
395
+ mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
396
+ mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
397
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
398
+ attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
399
+
400
+ for blk in self.blocks:
401
+ blk.H, blk.W = H, W
402
+ if self.use_checkpoint:
403
+ x = checkpoint.checkpoint(blk, x, attn_mask)
404
+ else:
405
+ x = blk(x, attn_mask)
406
+ if self.downsample is not None:
407
+ x_down = self.downsample(x, H, W)
408
+ Wh, Ww = (H + 1) // 2, (W + 1) // 2
409
+ return x, H, W, x_down, Wh, Ww
410
+ else:
411
+ return x, H, W, x, H, W
412
+
413
+
414
+ class PatchEmbed(nn.Module):
415
+ """ Image to Patch Embedding
416
+ Args:
417
+ patch_size (int): Patch token size. Default: 4.
418
+ in_chans (int): Number of input image channels. Default: 3.
419
+ embed_dim (int): Number of linear projection output channels. Default: 96.
420
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
421
+ """
422
+
423
+ def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
424
+ super().__init__()
425
+ patch_size = to_2tuple(patch_size)
426
+ self.patch_size = patch_size
427
+
428
+ self.in_chans = in_chans
429
+ self.embed_dim = embed_dim
430
+
431
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
432
+ if norm_layer is not None:
433
+ self.norm = norm_layer(embed_dim)
434
+ else:
435
+ self.norm = None
436
+
437
+ def forward(self, x):
438
+ """Forward function."""
439
+ # padding
440
+ _, _, H, W = x.size()
441
+ if W % self.patch_size[1] != 0:
442
+ x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
443
+ if H % self.patch_size[0] != 0:
444
+ x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
445
+
446
+ x = self.proj(x) # B C Wh Ww
447
+ if self.norm is not None:
448
+ Wh, Ww = x.size(2), x.size(3)
449
+ x = x.flatten(2).transpose(1, 2)
450
+ x = self.norm(x)
451
+ x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)
452
+
453
+ return x
454
+
455
+
456
+ class SwinTransformer(nn.Module):
457
+ """ Swin Transformer backbone.
458
+ A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
459
+ https://arxiv.org/pdf/2103.14030
460
+ Args:
461
+ pretrain_img_size (int): Input image size for training the pretrained model,
462
+ used in absolute position embedding. Default 224.
463
+ patch_size (int | tuple(int)): Patch size. Default: 4.
464
+ in_chans (int): Number of input image channels. Default: 3.
465
+ embed_dim (int): Number of linear projection output channels. Default: 96.
466
+ depths (tuple[int]): Depths of each Swin Transformer stage.
467
+ num_heads (tuple[int]): Number of attention head of each stage.
468
+ window_size (int): Window size. Default: 7.
469
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
470
+ qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
471
+ qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
472
+ drop_rate (float): Dropout rate.
473
+ attn_drop_rate (float): Attention dropout rate. Default: 0.
474
+ drop_path_rate (float): Stochastic depth rate. Default: 0.2.
475
+ norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
476
+ ape (bool): If True, add absolute position embedding to the patch embedding. Default: False.
477
+ patch_norm (bool): If True, add normalization after patch embedding. Default: True.
478
+ out_indices (Sequence[int]): Output from which stages.
479
+ frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
480
+ -1 means not freezing any parameters.
481
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
482
+ """
483
+
484
+ def __init__(self,
485
+ pretrain_img_size=224,
486
+ patch_size=4,
487
+ in_chans=3,
488
+ embed_dim=96,
489
+ depths=[2, 2, 6, 2],
490
+ num_heads=[3, 6, 12, 24],
491
+ window_size=7,
492
+ mlp_ratio=4.,
493
+ qkv_bias=True,
494
+ qk_scale=None,
495
+ drop_rate=0.,
496
+ attn_drop_rate=0.,
497
+ drop_path_rate=0.2,
498
+ norm_layer=nn.LayerNorm,
499
+ ape=False,
500
+ patch_norm=True,
501
+ out_indices=(0, 1, 2, 3),
502
+ frozen_stages=-1,
503
+ use_checkpoint=False):
504
+ super().__init__()
505
+
506
+ self.pretrain_img_size = pretrain_img_size
507
+ self.num_layers = len(depths)
508
+ self.embed_dim = embed_dim
509
+ self.ape = ape
510
+ self.patch_norm = patch_norm
511
+ self.out_indices = out_indices
512
+ self.frozen_stages = frozen_stages
513
+
514
+ # split image into non-overlapping patches
515
+ self.patch_embed = PatchEmbed(
516
+ patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
517
+ norm_layer=norm_layer if self.patch_norm else None)
518
+
519
+ # absolute position embedding
520
+ if self.ape:
521
+ pretrain_img_size = to_2tuple(pretrain_img_size)
522
+ patch_size = to_2tuple(patch_size)
523
+ patches_resolution = [pretrain_img_size[0] // patch_size[0], pretrain_img_size[1] // patch_size[1]]
524
+
525
+ self.absolute_pos_embed = nn.Parameter(
526
+ torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1]))
527
+ trunc_normal_(self.absolute_pos_embed, std=.02)
528
+
529
+ self.pos_drop = nn.Dropout(p=drop_rate)
530
+
531
+ # stochastic depth
532
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
533
+
534
+ # build layers
535
+ self.layers = nn.ModuleList()
536
+ for i_layer in range(self.num_layers):
537
+ layer = BasicLayer(
538
+ dim=int(embed_dim * 2 ** i_layer),
539
+ depth=depths[i_layer],
540
+ num_heads=num_heads[i_layer],
541
+ window_size=window_size,
542
+ mlp_ratio=mlp_ratio,
543
+ qkv_bias=qkv_bias,
544
+ qk_scale=qk_scale,
545
+ drop=drop_rate,
546
+ attn_drop=attn_drop_rate,
547
+ drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
548
+ norm_layer=norm_layer,
549
+ downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
550
+ use_checkpoint=use_checkpoint)
551
+ self.layers.append(layer)
552
+
553
+ num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)]
554
+ self.num_features = num_features
555
+
556
+ # add a norm layer for each output
557
+ for i_layer in out_indices:
558
+ layer = norm_layer(num_features[i_layer])
559
+ layer_name = f'norm{i_layer}'
560
+ self.add_module(layer_name, layer)
561
+
562
+ self._freeze_stages()
563
+
564
+ def _freeze_stages(self):
565
+ if self.frozen_stages >= 0:
566
+ self.patch_embed.eval()
567
+ for param in self.patch_embed.parameters():
568
+ param.requires_grad = False
569
+
570
+ if self.frozen_stages >= 1 and self.ape:
571
+ self.absolute_pos_embed.requires_grad = False
572
+
573
+ if self.frozen_stages >= 2:
574
+ self.pos_drop.eval()
575
+ for i in range(0, self.frozen_stages - 1):
576
+ m = self.layers[i]
577
+ m.eval()
578
+ for param in m.parameters():
579
+ param.requires_grad = False
580
+
581
+ def forward(self, x):
582
+
583
+ x = self.patch_embed(x)
584
+
585
+ Wh, Ww = x.size(2), x.size(3)
586
+ if self.ape:
587
+ # interpolate the position embedding to the corresponding size
588
+ absolute_pos_embed = F.interpolate(self.absolute_pos_embed, size=(Wh, Ww), mode='bicubic')
589
+ x = (x + absolute_pos_embed) # B Wh*Ww C
590
+
591
+ outs = [x.contiguous()]
592
+ x = x.flatten(2).transpose(1, 2)
593
+ x = self.pos_drop(x)
594
+
595
+ for i in range(self.num_layers):
596
+ layer = self.layers[i]
597
+ x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)
598
+
599
+ if i in self.out_indices:
600
+ norm_layer = getattr(self, f'norm{i}')
601
+ x_out = norm_layer(x_out)
602
+
603
+ out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous()
604
+ outs.append(out)
605
+
606
+ return tuple(outs)
607
+
608
+
609
+ def get_activation_fn(activation):
610
+ """Return an activation function given a string"""
611
+ if activation == "gelu":
612
+ return F.gelu
613
+
614
+ raise RuntimeError(f"activation should be gelu, not {activation}.")
615
+
616
+
617
+ def make_cbr(in_dim, out_dim):
618
+ return nn.Sequential(nn.Conv2d(in_dim, out_dim, kernel_size=3, padding=1), nn.InstanceNorm2d(out_dim), nn.GELU())
619
+
620
+
621
+ def make_cbg(in_dim, out_dim):
622
+ return nn.Sequential(nn.Conv2d(in_dim, out_dim, kernel_size=3, padding=1), nn.InstanceNorm2d(out_dim), nn.GELU())
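+ # Note: make_cbr and make_cbg are currently identical (3x3 Conv -> InstanceNorm -> GELU).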
623
+
624
+
625
+ def rescale_to(x, scale_factor: float = 2, interpolation='nearest'):
626
+ return F.interpolate(x, scale_factor=scale_factor, mode=interpolation)
627
+
628
+
629
+ def resize_as(x, y, interpolation='bilinear'):
630
+ return F.interpolate(x, size=y.shape[-2:], mode=interpolation)
631
+
632
+
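+ # Split each image into a 2x2 grid of crops stacked along the batch dimension; patches2image reverses it.
+ # Used to build the four local views that accompany the half-resolution global view.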
633
+ def image2patches(x):
634
+ """b c (hg h) (wg w) -> (hg wg b) c h w"""
635
+ x = rearrange(x, 'b c (hg h) (wg w) -> (hg wg b) c h w', hg=2, wg=2)
636
+ return x
637
+
638
+
639
+ def patches2image(x):
640
+ """(hg wg b) c h w -> b c (hg h) (wg w)"""
641
+ x = rearrange(x, '(hg wg b) c h w -> b c (hg h) (wg w)', hg=2, wg=2)
642
+ return x
643
+
644
+
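+ # DETR-style fixed sine/cosine 2D positional encoding; __call__ returns a (b, 2*num_pos_feats, h, w) tensor.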
645
+ class PositionEmbeddingSine:
646
+ def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
647
+ super().__init__()
648
+ self.num_pos_feats = num_pos_feats
649
+ self.temperature = temperature
650
+ self.normalize = normalize
651
+ if scale is not None and normalize is False:
652
+ raise ValueError("normalize should be True if scale is passed")
653
+ if scale is None:
654
+ scale = 2 * math.pi
655
+ self.scale = scale
656
+ self.dim_t = torch.arange(0, self.num_pos_feats, dtype=torch.float32)
657
+
658
+ def __call__(self, b, h, w):
659
+ device = self.dim_t.device
660
+ mask = torch.zeros([b, h, w], dtype=torch.bool, device=device)
661
+ assert mask is not None
662
+ not_mask = ~mask
663
+ y_embed = not_mask.cumsum(dim=1, dtype=torch.float32)
664
+ x_embed = not_mask.cumsum(dim=2, dtype=torch.float32)
665
+ if self.normalize:
666
+ eps = 1e-6
667
+ y_embed = (y_embed - 0.5) / (y_embed[:, -1:, :] + eps) * self.scale
668
+ x_embed = (x_embed - 0.5) / (x_embed[:, :, -1:] + eps) * self.scale
669
+
670
+ dim_t = self.temperature ** (2 * (self.dim_t.to(device) // 2) / self.num_pos_feats)
671
+ pos_x = x_embed[:, :, :, None] / dim_t
672
+ pos_y = y_embed[:, :, :, None] / dim_t
673
+
674
+ pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
675
+ pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
676
+
677
+ return torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
678
+
679
+
714
+
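+ # MCLM: cross-attends the half-resolution global view with multi-scale pooled versions of the four local crops,
+ # then refreshes each local crop by attending to its quarter of the updated global view; output is (5, c, h, w).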
715
+ class MCLM(nn.Module):
716
+ def __init__(self, d_model, num_heads, pool_ratios=[1, 4, 8]):
717
+ super(MCLM, self).__init__()
718
+ self.attention = nn.ModuleList([
719
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
720
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
721
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
722
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
723
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1)
724
+ ])
725
+
726
+ self.linear1 = nn.Linear(d_model, d_model * 2)
727
+ self.linear2 = nn.Linear(d_model * 2, d_model)
728
+ self.linear3 = nn.Linear(d_model, d_model * 2)
729
+ self.linear4 = nn.Linear(d_model * 2, d_model)
730
+ self.norm1 = nn.LayerNorm(d_model)
731
+ self.norm2 = nn.LayerNorm(d_model)
732
+ self.dropout = nn.Dropout(0.1)
733
+ self.dropout1 = nn.Dropout(0.1)
734
+ self.dropout2 = nn.Dropout(0.1)
735
+ self.activation = get_activation_fn('gelu')
736
+ self.pool_ratios = pool_ratios
737
+ self.p_poses = []
738
+ self.g_pos = None
739
+ self.positional_encoding = PositionEmbeddingSine(num_pos_feats=d_model // 2, normalize=True)
740
+
741
+ def forward(self, l, g):
742
+ """
743
+ l: 4,c,h,w
744
+ g: 1,c,h,w
745
+ """
746
+ self.p_poses = []
747
+ self.g_pos = None
748
+ b, c, h, w = l.size()
749
+ # 4,c,h,w -> 1,c,2h,2w
750
+ concated_locs = rearrange(l, '(hg wg b) c h w -> b c (hg h) (wg w)', hg=2, wg=2)
751
+
752
+ pools = []
753
+ for pool_ratio in self.pool_ratios:
754
+ # b,c,h,w
755
+ tgt_hw = (round(h / pool_ratio), round(w / pool_ratio))
756
+ pool = F.adaptive_avg_pool2d(concated_locs, tgt_hw)
757
+ pools.append(rearrange(pool, 'b c h w -> (h w) b c'))
758
+ if self.g_pos is None:
759
+ pos_emb = self.positional_encoding(pool.shape[0], pool.shape[2], pool.shape[3])
760
+ pos_emb = rearrange(pos_emb, 'b c h w -> (h w) b c')
761
+ self.p_poses.append(pos_emb)
762
+ pools = torch.cat(pools, 0)
763
+ if self.g_pos is None:
764
+ self.p_poses = torch.cat(self.p_poses, dim=0)
765
+ pos_emb = self.positional_encoding(g.shape[0], g.shape[2], g.shape[3])
766
+ self.g_pos = rearrange(pos_emb, 'b c h w -> (h w) b c')
767
+
768
+ device = pools.device
769
+ self.p_poses = self.p_poses.to(device)
770
+ self.g_pos = self.g_pos.to(device)
771
+
772
+ # attention between glb (q) and multi-scale pooled local features (k, v)
773
+ g_hw_b_c = rearrange(g, 'b c h w -> (h w) b c')
774
+
775
+ g_hw_b_c = g_hw_b_c + self.dropout1(self.attention[0](g_hw_b_c + self.g_pos, pools + self.p_poses, pools)[0])
776
+ g_hw_b_c = self.norm1(g_hw_b_c)
777
+ g_hw_b_c = g_hw_b_c + self.dropout2(self.linear2(self.dropout(self.activation(self.linear1(g_hw_b_c)).clone())))
778
+ g_hw_b_c = self.norm2(g_hw_b_c)
779
+
780
+ # attention between original locs (q) and refreshed glb (k, v)
781
+ l_hw_b_c = rearrange(l, "b c h w -> (h w) b c")
782
+ _g_hw_b_c = rearrange(g_hw_b_c, '(h w) b c -> h w b c', h=h, w=w)
783
+ _g_hw_b_c = rearrange(_g_hw_b_c, "(ng h) (nw w) b c -> (h w) (ng nw b) c", ng=2, nw=2)
784
+ outputs_re = []
785
+ for i, (_l, _g) in enumerate(zip(l_hw_b_c.chunk(4, dim=1), _g_hw_b_c.chunk(4, dim=1))):
786
+ outputs_re.append(self.attention[i + 1](_l, _g, _g)[0]) # (h w) 1 c
787
+ outputs_re = torch.cat(outputs_re, 1) # (h w) 4 c
788
+
789
+ l_hw_b_c = l_hw_b_c + self.dropout1(outputs_re)
790
+ l_hw_b_c = self.norm1(l_hw_b_c)
791
+ l_hw_b_c = l_hw_b_c + self.dropout2(self.linear4(self.dropout(self.activation(self.linear3(l_hw_b_c)).clone())))
792
+ l_hw_b_c = self.norm2(l_hw_b_c)
793
+
794
+ l = torch.cat((l_hw_b_c, g_hw_b_c), 1) # hw,b(5),c
795
+ return rearrange(l, "(h w) b c -> b c h w", h=h, w=w) # (5, c, h, w)
796
+
797
+
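+ # MCRM: gates each local crop with a saliency map predicted from the global view, refines it by attending to
+ # multi-scale pooled features of the corresponding global region, and merges the refined locals back into the global view.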
798
+ class MCRM(nn.Module):
799
+ def __init__(self, d_model, num_heads, pool_ratios=[4, 8, 16], h=None):
800
+ super(MCRM, self).__init__()
801
+ self.attention = nn.ModuleList([
802
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
803
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
804
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
805
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1)
806
+ ])
807
+ self.linear3 = nn.Linear(d_model, d_model * 2)
808
+ self.linear4 = nn.Linear(d_model * 2, d_model)
809
+ self.norm1 = nn.LayerNorm(d_model)
810
+ self.norm2 = nn.LayerNorm(d_model)
811
+ self.dropout = nn.Dropout(0.1)
812
+ self.dropout1 = nn.Dropout(0.1)
813
+ self.dropout2 = nn.Dropout(0.1)
814
+ self.sigmoid = nn.Sigmoid()
815
+ self.activation = get_activation_fn('gelu')
816
+ self.sal_conv = nn.Conv2d(d_model, 1, 1)
817
+ self.pool_ratios = pool_ratios
818
+
819
+ def forward(self, x):
820
+ device = x.device
821
+ b, c, h, w = x.size()
822
+ loc, glb = x.split([4, 1], dim=0) # 4,c,h,w; 1,c,h,w
823
+
824
+ patched_glb = rearrange(glb, 'b c (hg h) (wg w) -> (hg wg b) c h w', hg=2, wg=2)
825
+
826
+ token_attention_map = self.sigmoid(self.sal_conv(glb))
827
+ token_attention_map = F.interpolate(token_attention_map, size=patches2image(loc).shape[-2:], mode='nearest')
828
+ loc = loc * rearrange(token_attention_map, 'b c (hg h) (wg w) -> (hg wg b) c h w', hg=2, wg=2)
829
+
830
+ pools = []
831
+ for pool_ratio in self.pool_ratios:
832
+ tgt_hw = (round(h / pool_ratio), round(w / pool_ratio))
833
+ pool = F.adaptive_avg_pool2d(patched_glb, tgt_hw)
834
+ pools.append(rearrange(pool, 'nl c h w -> nl c (h w)')) # nl(4),c,hw
835
+
836
+ pools = rearrange(torch.cat(pools, 2), "nl c nphw -> nl nphw 1 c")
837
+ loc_ = rearrange(loc, 'nl c h w -> nl (h w) 1 c')
838
+
839
+ outputs = []
840
+ for i, q in enumerate(loc_.unbind(dim=0)): # traverse all local patches
841
+ v = pools[i]
842
+ k = v
843
+ outputs.append(self.attention[i](q, k, v)[0])
844
+
845
+ outputs = torch.cat(outputs, 1)
846
+ src = loc.view(4, c, -1).permute(2, 0, 1) + self.dropout1(outputs)
847
+ src = self.norm1(src)
848
+ src = src + self.dropout2(self.linear4(self.dropout(self.activation(self.linear3(src)).clone())))
849
+ src = self.norm2(src)
850
+ src = src.permute(1, 2, 0).reshape(4, c, h, w) # refreshed loc
851
+ glb = glb + F.interpolate(patches2image(src), size=glb.shape[-2:], mode='nearest') # refreshed glb
852
+
853
+ return torch.cat((src, glb), 0), token_attention_map
854
+
855
+
856
+ class BEN_Base(nn.Module):
857
+ def __init__(self):
858
+ super().__init__()
859
+
860
+ self.backbone = SwinTransformer(embed_dim=128, depths=[2, 2, 18, 2], num_heads=[4, 8, 16, 32], window_size=12)
861
+ emb_dim = 128
862
+ self.sideout5 = nn.Sequential(nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
863
+ self.sideout4 = nn.Sequential(nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
864
+ self.sideout3 = nn.Sequential(nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
865
+ self.sideout2 = nn.Sequential(nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
866
+ self.sideout1 = nn.Sequential(nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
867
+
868
+ self.output5 = make_cbr(1024, emb_dim)
869
+ self.output4 = make_cbr(512, emb_dim)
870
+ self.output3 = make_cbr(256, emb_dim)
871
+ self.output2 = make_cbr(128, emb_dim)
872
+ self.output1 = make_cbr(128, emb_dim)
873
+
874
+ self.multifieldcrossatt = MCLM(emb_dim, 1, [1, 4, 8])
875
+ self.conv1 = make_cbr(emb_dim, emb_dim)
876
+ self.conv2 = make_cbr(emb_dim, emb_dim)
877
+ self.conv3 = make_cbr(emb_dim, emb_dim)
878
+ self.conv4 = make_cbr(emb_dim, emb_dim)
879
+ self.dec_blk1 = MCRM(emb_dim, 1, [2, 4, 8])
880
+ self.dec_blk2 = MCRM(emb_dim, 1, [2, 4, 8])
881
+ self.dec_blk3 = MCRM(emb_dim, 1, [2, 4, 8])
882
+ self.dec_blk4 = MCRM(emb_dim, 1, [2, 4, 8])
883
+
884
+ self.insmask_head = nn.Sequential(
885
+ nn.Conv2d(emb_dim, 384, kernel_size=3, padding=1),
886
+ nn.InstanceNorm2d(384),
887
+ nn.GELU(),
888
+ nn.Conv2d(384, 384, kernel_size=3, padding=1),
889
+ nn.InstanceNorm2d(384),
890
+ nn.GELU(),
891
+ nn.Conv2d(384, emb_dim, kernel_size=3, padding=1)
892
+ )
893
+
894
+ self.shallow = nn.Sequential(nn.Conv2d(3, emb_dim, kernel_size=3, padding=1))
895
+ self.upsample1 = make_cbg(emb_dim, emb_dim)
896
+ self.upsample2 = make_cbg(emb_dim, emb_dim)
897
+ self.output = nn.Sequential(nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
898
+
899
+ for m in self.modules():
900
+ if isinstance(m, nn.GELU) or isinstance(m, nn.Dropout):
901
+ m.inplace = True
902
+
903
+ @torch.inference_mode()
904
+ @torch.autocast(device_type="cuda", dtype=torch.float16)
905
+ def forward(self, x):
906
+ real_batch = x.size(0)
907
+
908
+ shallow_batch = self.shallow(x)
909
+ glb_batch = rescale_to(x, scale_factor=0.5, interpolation='bilinear')
910
+
911
+ final_input = None
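+ # Build the backbone input: for each image, four 2x2 local crops followed by its half-resolution global view,
+ # so the backbone batch is grouped in blocks of 5.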
912
+ for i in range(real_batch):
913
+ start = i * 4
914
+ end = (i + 1) * 4
915
+ loc_batch = image2patches(x[i, :, :, :].unsqueeze(dim=0))
916
+ input_ = torch.cat((loc_batch, glb_batch[i, :, :, :].unsqueeze(dim=0)), dim=0)
917
+
918
+ if final_input is None:
919
+ final_input = input_
920
+ else:
921
+ final_input = torch.cat((final_input, input_), dim=0)
922
+
923
+ features = self.backbone(final_input)
924
+ outputs = []
925
+
926
+ for i in range(real_batch):
927
+ start = i * 5
928
+ end = (i + 1) * 5
929
+
930
+ f4 = features[4][start:end, :, :, :] # shape: [5, C, H, W]
931
+ f3 = features[3][start:end, :, :, :]
932
+ f2 = features[2][start:end, :, :, :]
933
+ f1 = features[1][start:end, :, :, :]
934
+ f0 = features[0][start:end, :, :, :]
935
+ e5 = self.output5(f4)
936
+ e4 = self.output4(f3)
937
+ e3 = self.output3(f2)
938
+ e2 = self.output2(f1)
939
+ e1 = self.output1(f0)
940
+ loc_e5, glb_e5 = e5.split([4, 1], dim=0)
941
+ e5 = self.multifieldcrossatt(loc_e5, glb_e5) # (5, emb_dim, h, w): 4 refreshed locals + 1 refreshed global
942
+
943
+ e4, tokenattmap4 = self.dec_blk4(e4 + resize_as(e5, e4))
944
+ e4 = self.conv4(e4)
945
+ e3, tokenattmap3 = self.dec_blk3(e3 + resize_as(e4, e3))
946
+ e3 = self.conv3(e3)
947
+ e2, tokenattmap2 = self.dec_blk2(e2 + resize_as(e3, e2))
948
+ e2 = self.conv2(e2)
949
+ e1, tokenattmap1 = self.dec_blk1(e1 + resize_as(e2, e1))
950
+ e1 = self.conv1(e1)
951
+
952
+ loc_e1, glb_e1 = e1.split([4, 1], dim=0)
953
+
954
+ output1_cat = patches2image(loc_e1) # (1,128,256,256)
955
+
956
+ # add glb feat in
957
+ output1_cat = output1_cat + resize_as(glb_e1, output1_cat)
958
+ # merge
959
+ final_output = self.insmask_head(output1_cat) # (1,128,256,256)
960
+ # shallow feature merge
961
+ shallow = shallow_batch[i, :, :, :].unsqueeze(dim=0)
962
+ final_output = final_output + resize_as(shallow, final_output)
963
+ final_output = self.upsample1(rescale_to(final_output))
964
+ final_output = rescale_to(final_output + resize_as(shallow, final_output))
965
+ final_output = self.upsample2(final_output)
966
+ final_output = self.output(final_output)
967
+ mask = final_output.sigmoid()
968
+ outputs.append(mask)
969
+
970
+ return torch.cat(outputs, dim=0)
971
+
972
+ def loadcheckpoints(self, model_path):
973
+ model_dict = torch.load(model_path, map_location="cpu", weights_only=True)
974
+ self.load_state_dict(model_dict['model_state_dict'], strict=True)
975
+ del model_path
976
+
977
+ def inference(self, image, refine_foreground=False, move_to_center=False):
978
+
979
+ # set_random_seed(9)
980
+ # image = ImageOps.exif_transpose(image)
981
+ if isinstance(image, Image.Image):
982
+ image, h, w, original_image = rgb_loader_refiner(image)
983
+ if torch.cuda.is_available():
984
+
985
+ img_tensor = img_transform(image).unsqueeze(0).to(next(self.parameters()).device)
986
+ else:
987
+ img_tensor = img_transform32(image).unsqueeze(0).to(next(self.parameters()).device)
988
+
989
+ with torch.no_grad():
990
+ res = self.forward(img_tensor)
991
+
992
+
993
+ # Show Results
994
+ if refine_foreground:
995
+
996
+
997
+ pred_pil = transforms.ToPILImage()(res.squeeze())
998
+ image_masked = refine_foreground_process(original_image, pred_pil)
999
+
1000
+ image_masked.putalpha(pred_pil.resize(original_image.size))
1001
+ return image_masked
1002
+
1003
+ else:
1004
+ alpha = postprocess_image(res, im_size=[w, h])
1005
+ pred_pil = transforms.ToPILImage()(alpha)
1006
+ mask = pred_pil.resize(original_image.size)
1007
+ original_image.putalpha(mask)
1008
+ # mask = Image.fromarray(alpha)
1009
+
1010
+ # Composite the cut-out onto a white background
1011
+ white_background = Image.new('RGB', original_image.size, (255, 255, 255))
1012
+ white_background.paste(original_image, mask=original_image.split()[3])
1013
+
1014
+
1015
+ if move_to_center:
1016
+ # Get the bounding box of non-transparent pixels
1017
+ # Get alpha channel and convert to numpy array for processing
1018
+ alpha_mask = np.array(mask)
1019
+
1020
+ # Find coordinates where the mask is at least half intensity (foreground)
1021
+ non_zero_coords = np.where(alpha_mask >= 127.5)
1022
+ if len(non_zero_coords[0]) > 0:
1023
+ # Get bounding box from non-zero coordinates
1024
+ min_y, max_y = non_zero_coords[0].min(), non_zero_coords[0].max()
1025
+ min_x, max_x = non_zero_coords[1].min(), non_zero_coords[1].max()
1026
+
1027
+ # Extract the object region
1028
+ obj_width = max_x - min_x
1029
+ obj_height = max_y - min_y
1030
+ bbox = (min_x, min_y, max_x, max_y)
1031
+
1032
+ # Calculate center position
1033
+ img_width, img_height = white_background.size
1034
+ center_x = (img_width - obj_width) // 2
1035
+ center_y = (img_height - obj_height) // 2
1036
+
1037
+ # Create new white background
1038
+ new_background = Image.new('RGB', white_background.size, (255, 255, 255))
1039
+
1040
+ # Paste the object at center position
1041
+ new_background.paste(white_background.crop(bbox), (center_x, center_y))
1042
+ original_image = new_background
1043
+ else:
1044
+ original_image = white_background
1045
+ else:
1046
+ original_image = white_background
1047
+
1048
+ return original_image
1049
+
1050
+
1051
+ else:
1052
+ foregrounds = []
1053
+ for batch in image:
1054
+ image, h, w, original_image = rgb_loader_refiner(batch)
1055
+ if torch.cuda.is_available():
1056
+
1057
+ img_tensor = img_transform(image).unsqueeze(0).to(next(self.parameters()).device)
1058
+ else:
1059
+ img_tensor = img_transform32(image).unsqueeze(0).to(next(self.parameters()).device)
1060
+
1061
+ with torch.no_grad():
1062
+ res = self.forward(img_tensor)
1063
+
1064
+ if refine_foreground:
1065
+
1066
+ pred_pil = transforms.ToPILImage()(res.squeeze())
1067
+ image_masked = refine_foreground_process(original_image, pred_pil)
1068
+
1069
+ image_masked.putalpha(pred_pil.resize(original_image.size))
1070
+
1071
+ foregrounds.append(image_masked)
1072
+ else:
1073
+ alpha = postprocess_image(res, im_size=[w, h])
1074
+ pred_pil = transforms.ToPILImage()(alpha)
1075
+ mask = pred_pil.resize(original_image.size)
1076
+ original_image.putalpha(mask)
1077
+ # mask = Image.fromarray(alpha)
1078
+ foregrounds.append(original_image)
1079
+
1080
+ return foregrounds
1081
+
1082
+ def segment_video(self, video_path, output_path="./", fps=0, refine_foreground=False, batch=1,
1083
+ print_frames_processed=True, webm=False, rgb_value=(0, 255, 0)):
1084
+
1085
+ """
1086
+ Segments the given video to extract the foreground (with alpha) from each frame
1087
+ and saves the result as either a WebM video (with alpha channel) or MP4 (with a
1088
+ color background).
1089
+
1090
+ Args:
1091
+ video_path (str):
1092
+ Path to the input video file.
1093
+
1094
+ output_path (str, optional):
1095
+ Directory (or full path) where the output video and/or files will be saved.
1096
+ Defaults to "./".
1097
+
1098
+ fps (int, optional):
1099
+ The frames per second (FPS) to use for the output video. If 0 (default), the
1100
+ original FPS of the input video is used. Otherwise, overrides it.
1101
+
1102
+ refine_foreground (bool, optional):
1103
+ Whether to run an additional “refine foreground” process on each frame.
1104
+ Defaults to False.
1105
+
1106
+ batch (int, optional):
1107
+ Number of frames to process at once (inference batch size). Large batch sizes
1108
+ may require more GPU memory. Defaults to 1.
1109
+
1110
+ print_frames_processed (bool, optional):
1111
+ If True (default), prints progress (how many frames have been processed) to
1112
+ the console.
1113
+
1114
+ webm (bool, optional):
1115
+ If True, exports a WebM video with alpha channel (VP9 / yuva420p).
1116
+ If False (default), exports an MP4 video composited over a solid color background.
1117
+
1118
+ rgb_value (tuple, optional):
1119
+ The RGB background color (e.g., green screen) used to composite frames when
1120
+ saving to MP4. Defaults to (0, 255, 0).
1121
+
1122
+ Returns:
1123
+ None. Writes the output video(s) to disk in the specified format.
1124
+ """
1125
+
1126
+ cap = cv2.VideoCapture(video_path)
1127
+ if not cap.isOpened():
1128
+ raise IOError(f"Cannot open video: {video_path}")
1129
+
1130
+ original_fps = cap.get(cv2.CAP_PROP_FPS)
1131
+ original_fps = 30 if original_fps == 0 else original_fps
1132
+ fps = original_fps if fps == 0 else fps
1133
+
1134
+ ret, first_frame = cap.read()
1135
+ if not ret:
1136
+ raise ValueError("No frames found in the video.")
1137
+ height, width = first_frame.shape[:2]
1138
+ cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
1139
+
1140
+ foregrounds = []
1141
+ frame_idx = 0
1142
+ processed_count = 0
1143
+ batch_frames = []
1144
+ total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
1145
+
1146
+ while True:
1147
+ ret, frame = cap.read()
1148
+ if not ret:
1149
+ if batch_frames:
1150
+ batch_results = self.inference(batch_frames, refine_foreground)
1151
+ if isinstance(batch_results, Image.Image):
1152
+ foregrounds.append(batch_results)
1153
+ else:
1154
+ foregrounds.extend(batch_results)
1155
+ if print_frames_processed:
1156
+ print(f"Processed frames {frame_idx - len(batch_frames) + 1} to {frame_idx} of {total_frames}")
1157
+ break
1158
+
1159
+ # Process every frame instead of using intervals
1160
+ frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
1161
+ pil_frame = Image.fromarray(frame_rgb)
1162
+ batch_frames.append(pil_frame)
1163
+
1164
+ if len(batch_frames) == batch:
1165
+ batch_results = self.inference(batch_frames, refine_foreground)
1166
+ if isinstance(batch_results, Image.Image):
1167
+ foregrounds.append(batch_results)
1168
+ else:
1169
+ foregrounds.extend(batch_results)
1170
+ if print_frames_processed:
1171
+ print(f"Processed frames {frame_idx - batch + 1} to {frame_idx} of {total_frames}")
1172
+ batch_frames = []
1173
+ processed_count += batch
1174
+
1175
+ frame_idx += 1
1176
+
1177
+ if webm:
1178
+ alpha_webm_path = os.path.join(output_path, "foreground.webm")
1179
+ pil_images_to_webm_alpha(foregrounds, alpha_webm_path, fps=original_fps)
+ cap.release()
1180
+
1181
+ else:
1182
+ cap.release()
1183
+ fg_output = os.path.join(output_path, 'foreground.mp4')
1184
+
1185
+ pil_images_to_mp4(foregrounds, fg_output, fps=original_fps, rgb_value=rgb_value)
1186
+ cv2.destroyAllWindows()
1187
+
1188
+ try:
1189
+ fg_audio_output = os.path.join(output_path, 'foreground_output_with_audio.mp4')
1190
+ add_audio_to_video(fg_output, video_path, fg_audio_output)
1191
+ except Exception as e:
1192
+ print("No audio found in the original video")
1193
+ print(e)
1194
+
1195
+
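+ # Note: rgb_loader_refiner is redefined near the end of this module; that later definition is the one in effect at runtime.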
1196
+ def rgb_loader_refiner(original_image):
1197
+ h, w = original_image.size
1198
+
1199
+ image = original_image
1200
+ # Convert to RGB if necessary
1201
+ if image.mode != 'RGB':
1202
+ image = image.convert('RGB')
1203
+
1204
+ # Resize the image
1205
+ image = image.resize((1024, 1024), resample=Image.LANCZOS)
1206
+
1207
+ return image.convert('RGB'), h, w, original_image
1208
+
1209
+
1210
+ # Image transforms: float16 when running on GPU (matches the autocast forward), float32 on CPU
1211
+ img_transform = transforms.Compose([
1212
+ transforms.ToTensor(),
1213
+ transforms.ConvertImageDtype(torch.float16),
1214
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
1215
+ ])
1216
+
1217
+ img_transform32 = transforms.Compose([
1218
+ transforms.ToTensor(),
1219
+ transforms.ConvertImageDtype(torch.float32),
1220
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
1221
+ ])
1222
+
1223
+
1224
+ def pil_images_to_mp4(images, output_path, fps=24, rgb_value=(0, 255, 0)):
1225
+ """
1226
+ Converts an array of PIL images to an MP4 video.
1227
+
1228
+ Args:
1229
+ images: List of PIL images
1230
+ output_path: Path to save the MP4 file
1231
+ fps: Frames per second (default: 24)
1232
+ rgb_value: Background RGB color tuple (default: green (0, 255, 0))
1233
+ """
1234
+ if not images:
1235
+ raise ValueError("No images provided to convert to MP4.")
1236
+
1237
+ width, height = images[0].size
1238
+ fourcc = cv2.VideoWriter_fourcc(*'mp4v')
1239
+ video_writer = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
1240
+
1241
+ for image in images:
1242
+ # If image has alpha channel, composite onto the specified background color
1243
+ if image.mode == 'RGBA':
1244
+ # Create background image with specified RGB color
1245
+ background = Image.new('RGB', image.size, rgb_value)
1246
+ background = background.convert('RGBA')
1247
+ # Composite the image onto the background
1248
+ image = Image.alpha_composite(background, image)
1249
+ image = image.convert('RGB')
1250
+ else:
1251
+ # Ensure RGB format for non-alpha images
1252
+ image = image.convert('RGB')
1253
+
1254
+ # Convert to OpenCV format and write
1255
+ open_cv_image = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
1256
+ video_writer.write(open_cv_image)
1257
+
1258
+ video_writer.release()
1259
+
1260
+
1261
+ def pil_images_to_webm_alpha(images, output_path, fps=30):
1262
+ """
1263
+ Converts a list of PIL RGBA images to a VP9 .webm video with alpha channel.
1264
+
1265
+ NOTE: Not all players will display alpha in WebM.
1266
+ Browsers like Chrome/Firefox typically do support VP9 alpha.
1267
+ """
1268
+ if not images:
1269
+ raise ValueError("No images provided for WebM with alpha.")
1270
+
1271
+ # Ensure output directory exists
1272
+ os.makedirs(os.path.dirname(output_path), exist_ok=True)
1273
+
1274
+ with tempfile.TemporaryDirectory() as tmpdir:
1275
+ # Save frames as PNG (with alpha)
1276
+ for idx, img in enumerate(images):
1277
+ if img.mode != "RGBA":
1278
+ img = img.convert("RGBA")
1279
+ out_path = os.path.join(tmpdir, f"{idx:06d}.png")
1280
+ img.save(out_path, "PNG")
1281
+
1282
+ # Construct ffmpeg command
1283
+ # -c:v libvpx-vp9 => VP9 encoder
1284
+ # -pix_fmt yuva420p => alpha-enabled pixel format
1285
+ # -auto-alt-ref 0 => helps preserve alpha frames (libvpx quirk)
1286
+ ffmpeg_cmd = [
1287
+ "ffmpeg", "-y",
1288
+ "-framerate", str(fps),
1289
+ "-i", os.path.join(tmpdir, "%06d.png"),
1290
+ "-c:v", "libvpx-vp9",
1291
+ "-pix_fmt", "yuva420p",
1292
+ "-auto-alt-ref", "0",
1293
+ output_path
1294
+ ]
1295
+
1296
+ subprocess.run(ffmpeg_cmd, check=True)
1297
+
1298
+ print(f"WebM with alpha saved to {output_path}")
1299
+
1300
+
1301
+ def add_audio_to_video(video_without_audio_path, original_video_path, output_path):
1302
+ """
1303
+ Check if the original video has an audio stream. If yes, add it. If not, skip.
1304
+ """
1305
+ # 1) Probe original video for audio streams
1306
+ probe_command = [
1307
+ 'ffprobe', '-v', 'error',
1308
+ '-select_streams', 'a:0',
1309
+ '-show_entries', 'stream=index',
1310
+ '-of', 'csv=p=0',
1311
+ original_video_path
1312
+ ]
1313
+ result = subprocess.run(probe_command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
1314
+
1315
+ # result.stdout is empty if no audio stream found
1316
+ if not result.stdout.strip():
1317
+ print("No audio track found in original video, skipping audio addition.")
1318
+ return
1319
+
1320
+ print("Audio track detected; proceeding to mux audio.")
1321
+ # 2) If audio found, run ffmpeg to add it
1322
+ command = [
1323
+ 'ffmpeg', '-y',
1324
+ '-i', video_without_audio_path,
1325
+ '-i', original_video_path,
1326
+ '-c', 'copy',
1327
+ '-map', '0:v:0',
1328
+ '-map', '1:a:0', # we know there's an audio track now
1329
+ output_path
1330
+ ]
1331
+ subprocess.run(command, check=True)
1332
+ print(f"Audio added successfully => {output_path}")
1333
+
1334
+
1335
+ ### Thanks to the source: https://huggingface.co/ZhengPeng7/BiRefNet/blob/main/handler.py
1336
+ def refine_foreground_process(image, mask, r=90):
1337
+ if mask.size != image.size:
1338
+ mask = mask.resize(image.size)
1339
+ image = np.array(image) / 255.0
1340
+ mask = np.array(mask) / 255.0
1341
+ estimated_foreground = FB_blur_fusion_foreground_estimator_2(image, mask, r=r)
1342
+ image_masked = Image.fromarray((estimated_foreground * 255.0).astype(np.uint8))
1343
+ return image_masked
1344
+
1345
+
1346
+ def FB_blur_fusion_foreground_estimator_2(image, alpha, r=90):
1347
+ # Thanks to the source: https://github.com/Photoroom/fast-foreground-estimation
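+ # Two-pass blur fusion: a coarse pass with a large blur radius, then a refinement pass with r=6.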
1348
+ alpha = alpha[:, :, None]
1349
+ F, blur_B = FB_blur_fusion_foreground_estimator(image, image, image, alpha, r)
1350
+ return FB_blur_fusion_foreground_estimator(image, F, blur_B, alpha, r=6)[0]
1351
+
1352
+
1353
+ def FB_blur_fusion_foreground_estimator(image, F, B, alpha, r=90):
1354
+ if isinstance(image, Image.Image):
1355
+ image = np.array(image) / 255.0
1356
+ blurred_alpha = cv2.blur(alpha, (r, r))[:, :, None]
1357
+
1358
+ blurred_FA = cv2.blur(F * alpha, (r, r))
1359
+ blurred_F = blurred_FA / (blurred_alpha + 1e-5)
1360
+
1361
+ blurred_B1A = cv2.blur(B * (1 - alpha), (r, r))
1362
+ blurred_B = blurred_B1A / ((1 - blurred_alpha) + 1e-5)
1363
+ F = blurred_F + alpha * \
1364
+ (image - alpha * blurred_F - (1 - alpha) * blurred_B)
1365
+ F = np.clip(F, 0, 1)
1366
+ return F, blurred_B
1367
+
1368
+
1369
+ def postprocess_image(result: torch.Tensor, im_size: list) -> np.ndarray:
1370
+ result = torch.squeeze(F.interpolate(result, size=im_size, mode='bilinear'), 0)
1371
+ ma = torch.max(result)
1372
+ mi = torch.min(result)
1373
+ result = (result - mi) / (ma - mi)
1374
+ im_array = (result * 255).permute(1, 2, 0).cpu().data.numpy().astype(np.uint8)
1375
+ im_array = np.squeeze(im_array)
1376
+ return im_array
1377
+
1378
+
1379
+ def rgb_loader_refiner(original_image):
1380
+ # Apply EXIF orientation first so the reported size matches the oriented image
1381
+ original_image = ImageOps.exif_transpose(original_image)
1382
+
1383
+ h, w = original_image.size # note: PIL .size is (width, height)
1384
+
1385
+ if original_image.mode != 'RGB':
1386
+ original_image = original_image.convert('RGB')
1387
+
1388
+ image = original_image
1389
+ # Convert to RGB if necessary
1390
+
1391
+ # Resize the image
1392
+ image = image.resize((1024, 1024), resample=Image.LANCZOS)
1393
+
1394
+ return image, h, w, original_image
app/aspect_ratio_template.py ADDED
@@ -0,0 +1,88 @@
1
+ # From https://github.com/TencentARC/PhotoMaker/pull/120 written by https://github.com/DiscoNova
2
+ # Note: Since output width & height need to be divisible by 8, the w & h -values do
3
+ # not exactly match the stated aspect ratios... but they are "close enough":)
4
+
5
+ ASPECT_RATIO_TEMPLATE = [
6
+ {
7
+ "name": "Custom (long edge to 1024px)",
8
+ "w": "",
9
+ "h": "",
10
+ },
11
+ {
12
+ "name": "Custom",
13
+ "w": "",
14
+ "h": "",
15
+ },
16
+ {
17
+ "name": "Instagram (1:1)",
18
+ "w": 1024,
19
+ "h": 1024,
20
+ },
21
+ {
22
+ "name": "35mm film / Landscape (3:2)",
23
+ "w": 1024,
24
+ "h": 680,
25
+ },
26
+ {
27
+ "name": "35mm film / Portrait (2:3)",
28
+ "w": 680,
29
+ "h": 1024,
30
+ },
31
+ {
32
+ "name": "CRT Monitor / Landscape (4:3)",
33
+ "w": 1024,
34
+ "h": 768,
35
+ },
36
+ {
37
+ "name": "CRT Monitor / Portrait (3:4)",
38
+ "w": 768,
39
+ "h": 1024,
40
+ },
41
+ {
42
+ "name": "Widescreen TV / Landscape (16:9)",
43
+ "w": 1024,
44
+ "h": 576,
45
+ },
46
+ {
47
+ "name": "Widescreen TV / Portrait (9:16)",
48
+ "w": 576,
49
+ "h": 1024,
50
+ },
51
+ {
52
+ "name": "Widescreen Monitor / Landscape (16:10)",
53
+ "w": 1024,
54
+ "h": 640,
55
+ },
56
+ {
57
+ "name": "Widescreen Monitor / Portrait (10:16)",
58
+ "w": 640,
59
+ "h": 1024,
60
+ },
61
+ {
62
+ "name": "Cinemascope (2.39:1)",
63
+ "w": 1024,
64
+ "h": 424,
65
+ },
66
+ {
67
+ "name": "Widescreen Movie (1.85:1)",
68
+ "w": 1024,
69
+ "h": 552,
70
+ },
71
+ {
72
+ "name": "Academy Movie (1.37:1)",
73
+ "w": 1024,
74
+ "h": 744,
75
+ },
76
+ {
77
+ "name": "Sheet-print (A-series) / Landscape (297:210)",
78
+ "w": 1024,
79
+ "h": 720,
80
+ },
81
+ {
82
+ "name": "Sheet-print (A-series) / Portrait (210:297)",
83
+ "w": 720,
84
+ "h": 1024,
85
+ },
86
+ ]
87
+
88
+ ASPECT_RATIO_TEMPLATE = {k["name"]: (k["w"], k["h"]) for k in ASPECT_RATIO_TEMPLATE}
app/business_logic.py ADDED
@@ -0,0 +1,556 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ Business logic functions for IC-Custom application.
5
+ """
6
+ import numpy as np
7
+ import torch
8
+ import cv2
9
+ import gradio as gr
10
+ from PIL import Image
11
+ from datetime import datetime
12
+ import json
13
+ import os
14
+ from scipy.ndimage import binary_dilation, binary_erosion
15
+
16
+ from constants import (
17
+ DEFAULT_BACKGROUND_BLEND_THRESHOLD, DEFAULT_SEED, DEFAULT_NUM_IMAGES,
18
+ DEFAULT_GUIDANCE, DEFAULT_TRUE_GS, DEFAULT_NUM_STEPS, DEFAULT_ASPECT_RATIO,
19
+ DEFAULT_DILATION_KERNEL_SIZE, DEFAULT_MARKER_SIZE, DEFAULT_MARKER_THICKNESS,
20
+ DEFAULT_MASK_ALPHA, DEFAULT_COLOR_ALPHA, TIMESTAMP_FORMAT, SEGMENTATION_COLORS, SEGMENTATION_MARKERS
21
+ )
22
+
23
+ from utils import run_vlm, construct_vlm_gen_prompt, construct_vlm_polish_prompt
24
+
25
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
26
+
27
+ # Global holder for SAM mobile predictor injected from the app layer
28
+ MOBILE_PREDICTOR = None
29
+ BEN2_MODEL = None # ben2 model injected from the app layer
30
+
31
+ def set_mobile_predictor(predictor):
32
+ """Inject SAM mobile predictor into this module without changing function signatures."""
33
+ global MOBILE_PREDICTOR
34
+ MOBILE_PREDICTOR = predictor
35
+
36
+ def set_ben2_model(ben2_model):
37
+ """Inject ben2 model into this module without changing function signatures."""
38
+ global BEN2_MODEL
39
+ BEN2_MODEL = ben2_model
40
+
41
+ def set_vlm_processor(vlm_processor):
42
+ """Inject vlm processor into this module without changing function signatures."""
43
+ global VLM_PROCESSOR
44
+ VLM_PROCESSOR = vlm_processor
45
+
46
+ def set_vlm_model(vlm_model):
47
+ """Inject vlm model into this module without changing function signatures."""
48
+ global VLM_MODEL
49
+ VLM_MODEL = vlm_model
50
+
51
+
52
+ def init_image_target_1(target_image):
53
+ """Initialize UI state when a target image is uploaded."""
54
+
55
+ # Handle both PIL Image (image_target_1) and ImageEditor dict (image_target_2)
56
+ try:
57
+ if isinstance(target_image, dict) and 'composite' in target_image:
58
+ # ImageEditor format (user-drawn mask)
59
+ image_target_state = np.array(target_image['composite'].convert("RGB"))
60
+ else:
61
+ # PIL Image format (precise mask)
62
+ image_target_state = np.array(target_image.convert("RGB"))
63
+ except Exception as e:
64
+ # If there's an error processing the image, skip initialization
65
+ return (
66
+ gr.skip(), gr.skip(), gr.skip(), gr.skip(), gr.skip(), gr.skip(),
67
+ gr.skip(), gr.skip(), gr.skip(), gr.skip(), gr.skip(), gr.skip(),
68
+ gr.skip(), gr.skip(), gr.update(value="-1")
69
+ )
70
+
71
+ selected_points = []
72
+ mask_target_state = None
73
+ prompt = None
74
+ mask_gallery = []
75
+ result_gallery = []
76
+ use_background_preservation = False
77
+ background_blend_threshold = DEFAULT_BACKGROUND_BLEND_THRESHOLD
78
+ seed = DEFAULT_SEED
79
+ num_images_per_prompt = DEFAULT_NUM_IMAGES
80
+ guidance = DEFAULT_GUIDANCE
81
+ true_gs = DEFAULT_TRUE_GS
82
+ num_steps = DEFAULT_NUM_STEPS
83
+ aspect_ratio_val = gr.update(value=DEFAULT_ASPECT_RATIO)
84
+
85
+ return (image_target_state, selected_points, mask_target_state, prompt,
86
+ mask_gallery, result_gallery, use_background_preservation,
87
+ background_blend_threshold, seed, num_images_per_prompt, guidance,
88
+ true_gs, num_steps, aspect_ratio_val)
89
+
90
+
91
+ def init_image_target_2(target_image):
92
+ """Initialize UI state when a target image is uploaded."""
93
+ # Handle both PIL Image (image_target_1) and ImageEditor dict (image_target_2)
94
+ try:
95
+ if isinstance(target_image, dict) and 'composite' in target_image:
96
+ # ImageEditor format (user-drawn mask)
97
+ image_target_state = np.array(target_image['composite'].convert("RGB"))
98
+ else:
99
+ # PIL Image format (precise mask)
100
+ image_target_state = np.array(target_image.convert("RGB"))
101
+ except Exception as e:
102
+ # If there's an error processing the image, skip initialization
103
+ return (
104
+ gr.skip(), gr.skip(), gr.skip(), gr.skip(), gr.skip(), gr.skip(),
105
+ gr.skip(), gr.skip(), gr.skip(), gr.skip(), gr.skip(), gr.skip(),
106
+ gr.skip(), gr.skip(), gr.update(value="-1")
107
+ )
108
+
109
+ selected_points = gr.skip()
110
+ mask_target_state = gr.skip()
111
+ prompt = gr.skip()
112
+ mask_gallery = gr.skip()
113
+ result_gallery = gr.skip()
114
+ use_background_preservation = gr.skip()
115
+ background_blend_threshold = gr.skip()
116
+ seed = gr.skip()
117
+ num_images_per_prompt = gr.skip()
118
+ guidance = gr.skip()
119
+ true_gs = gr.skip()
120
+ num_steps = gr.skip()
121
+ aspect_ratio_val = gr.skip()
122
+
123
+ return (image_target_state, selected_points, mask_target_state, prompt,
124
+ mask_gallery, result_gallery, use_background_preservation,
125
+ background_blend_threshold, seed, num_images_per_prompt, guidance,
126
+ true_gs, num_steps, aspect_ratio_val)
127
+
128
+
129
+ def init_image_reference(image_reference):
130
+ """Initialize all UI states when a reference image is uploaded."""
131
+ image_reference_state = np.array(image_reference.convert("RGB"))
132
+ image_reference_ori_state = image_reference_state
133
+ image_reference_rmbg_state = None
134
+ image_target_state = None
135
+ mask_target_state = None
136
+ prompt = None
137
+ mask_gallery = []
138
+ result_gallery = []
139
+ image_target_1_val = None
140
+ image_target_2_val = None
141
+ selected_points = []
142
+ input_mask_mode_val = gr.update(value="Precise mask")
143
+ seg_ref_mode_val = gr.update(value="Full Ref")
144
+ move_to_center = False
145
+ use_background_preservation = False
146
+ background_blend_threshold = DEFAULT_BACKGROUND_BLEND_THRESHOLD
147
+ seed = DEFAULT_SEED
148
+ num_images_per_prompt = DEFAULT_NUM_IMAGES
149
+ guidance = DEFAULT_GUIDANCE
150
+ true_gs = DEFAULT_TRUE_GS
151
+ num_steps = DEFAULT_NUM_STEPS
152
+ aspect_ratio_val = gr.update(value=DEFAULT_ASPECT_RATIO)
153
+
154
+ return (
155
+ image_reference_ori_state, image_reference_rmbg_state, image_target_state,
156
+ mask_target_state, prompt, mask_gallery, result_gallery, image_target_1_val,
157
+ image_target_2_val, selected_points, input_mask_mode_val, seg_ref_mode_val,
158
+ move_to_center, use_background_preservation, background_blend_threshold,
159
+ seed, num_images_per_prompt, guidance, true_gs, num_steps, aspect_ratio_val,
160
+ )
161
+
162
+
163
+ def undo_seg_points(orig_img, sel_pix):
164
+ """Remove the latest segmentation point and recompute the preview mask."""
165
+ if len(sel_pix) != 0:
166
+ temp = orig_img.copy()
167
+ sel_pix.pop()
168
+ # Update the segmentation mask preview
169
+ if len(sel_pix) != 0:
170
+ temp, output_mask = segmentation(temp, sel_pix, MOBILE_PREDICTOR, SEGMENTATION_COLORS, SEGMENTATION_MARKERS)
171
+ output_mask_pil = Image.fromarray(output_mask.astype("uint8"))
172
+ masked_img_pil = Image.fromarray(np.where(output_mask > 0, orig_img, 0).astype("uint8"))
173
+ mask_gallery = [masked_img_pil, output_mask_pil]
174
+ else:
175
+ output_mask = None
176
+ mask_gallery = []
177
+ return temp.astype(np.uint8), output_mask, mask_gallery
178
+ else:
179
+ gr.Warning("Nothing to Undo")
180
+ return orig_img, None, []
181
+
182
+
183
+ def segmentation(img, sel_pix, mobile_predictor, colors, markers):
184
+ """Run SAM-based segmentation given selected points and return previews."""
185
+ points = []
186
+ labels = []
187
+ for p, l in sel_pix:
188
+ points.append(p)
189
+ labels.append(l)
190
+
191
+ mobile_predictor.set_image(img if isinstance(img, np.ndarray) else np.array(img))
192
+ with torch.no_grad():
193
+ masks, _, _ = mobile_predictor.predict(
194
+ point_coords=np.array(points),
195
+ point_labels=np.array(labels),
196
+ multimask_output=False
197
+ )
198
+
199
+ output_mask = np.ones((masks.shape[1], masks.shape[2], 3)) * 255
200
+ for i in range(3):
201
+ output_mask[masks[0] == True, i] = 0.0
202
+
203
+ mask_all = np.ones((masks.shape[1], masks.shape[2], 3))
204
+ color_mask = np.random.random((1, 3)).tolist()[0]
205
+ for i in range(3):
206
+ mask_all[masks[0] == True, i] = color_mask[i]
207
+
208
+ masked_img = img / 255 * DEFAULT_MASK_ALPHA + mask_all * DEFAULT_COLOR_ALPHA
209
+ masked_img = masked_img * 255
210
+
211
+ # Draw points
212
+ for point, label in sel_pix:
213
+ cv2.drawMarker(
214
+ masked_img, point, colors[label],
215
+ markerType=markers[label],
216
+ markerSize=DEFAULT_MARKER_SIZE,
217
+ thickness=DEFAULT_MARKER_THICKNESS
218
+ )
219
+
220
+ return masked_img, output_mask
221
+
222
+
223
+ def get_point(img, sel_pix, evt: gr.SelectData):
224
+ """Handle a user click on the target image to add a foreground point."""
225
+ if evt is None or not hasattr(evt, 'index'):
226
+ gr.Warning(f"Event object missing index attribute. Event type: {type(evt)}")
227
+ return img, None, []
228
+
229
+ sel_pix.append((evt.index, 1)) # append the foreground_point
230
+ # Update the segmentation mask preview
231
+ global MOBILE_PREDICTOR
232
+ masked_img_seg, output_mask = segmentation(img, sel_pix, MOBILE_PREDICTOR, SEGMENTATION_COLORS, SEGMENTATION_MARKERS)
233
+
234
+ # Dilate the mask slightly for robustness: the selected region is 0, so invert, dilate, then invert back
235
+ output_mask = 1 - output_mask
236
+ kernel = np.ones((DEFAULT_DILATION_KERNEL_SIZE, DEFAULT_DILATION_KERNEL_SIZE), np.uint8)
237
+ output_mask = cv2.dilate(output_mask, kernel, iterations=1)
238
+ output_mask = 1 - output_mask
239
+
240
+ output_mask_binary = output_mask / 255
241
+
242
+ masked_img_seg = masked_img_seg.astype("uint8")
243
+ output_mask = output_mask.astype("uint8")
244
+
245
+ masked_img = img * output_mask_binary
246
+ masked_img_pil = Image.fromarray(masked_img.astype("uint8"))
247
+ output_mask_pil = Image.fromarray(output_mask.astype("uint8"))
248
+ outputs_gallery = [masked_img_pil, output_mask_pil]
249
+
250
+ return masked_img_seg, output_mask, outputs_gallery
251
+
252
+
253
+ def get_brush(img):
254
+ """Extract a mask from ImageEditor brush layers or composite/background diff."""
255
+ if img is None or not isinstance(img, dict):
256
+ return gr.skip(), gr.skip()
257
+
258
+ layers = img.get("layers", [])
259
+ background = img.get('background', None)
260
+ composite = img.get('composite', None)
261
+
262
+ output_mask = None
263
+ if layers and layers[0] is not None and background is not None:
264
+ output_mask = 255 - np.array(layers[0].convert("RGB")).astype(np.uint8)
265
+ elif composite is not None and background is not None:
266
+ comp_rgb = np.array(composite.convert("RGB")).astype(np.int16)
267
+ bg_rgb = np.array(background.convert("RGB")).astype(np.int16)
268
+ diff = np.abs(comp_rgb - bg_rgb)
269
+ painted = (diff.sum(axis=2) > 0).astype(np.uint8)
270
+ output_mask = (1 - painted) * 255
271
+ output_mask = np.repeat(output_mask[:, :, None], 3, axis=2).astype(np.uint8)
272
+ else:
273
+ return gr.skip(), gr.skip()
274
+
275
+ if len(np.unique(output_mask)) == 1:
276
+ return gr.skip(), gr.skip()
277
+
278
+ img = np.array(background.convert("RGB")).astype(np.uint8)
279
+
280
+ output_mask_binary = output_mask / 255
281
+ masked_img = img * output_mask_binary
282
+ masked_img_pil = Image.fromarray(masked_img.astype("uint8"))
283
+ output_mask_pil = Image.fromarray(output_mask.astype("uint8"))
284
+ mask_gallery = [masked_img_pil, output_mask_pil]
285
+
286
+ return output_mask, mask_gallery
287
+
288
+
289
+ def random_mask_func(mask, dilation_type='square_dilation', dilation_size=20):
290
+ """Utility to dilate/erode/box/ellipse expand a binary mask."""
291
+ binary_mask = mask[:,:,0] < 128
292
+
293
+ if dilation_type == 'square_dilation':
294
+ structure = np.ones((dilation_size, dilation_size), dtype=bool)
295
+ dilated_mask = binary_dilation(binary_mask, structure=structure)
296
+ elif dilation_type == 'square_erosion':
297
+ structure = np.ones((dilation_size, dilation_size), dtype=bool)
298
+ dilated_mask = binary_erosion(binary_mask, structure=structure)
299
+ elif dilation_type == 'bounding_box':
300
+ # Find the bounding box of the selected region
301
+ rows, cols = np.where(binary_mask)
302
+ if len(rows) == 0 or len(cols) == 0:
303
+ return mask # return original mask if no valid points
304
+
305
+ min_row, max_row = np.min(rows), np.max(rows)
306
+ min_col, max_col = np.min(cols), np.max(cols)
307
+
308
+ # Create a bounding box
309
+ dilated_mask = np.zeros_like(binary_mask, dtype=bool)
310
+ dilated_mask[min_row:max_row + 1, min_col:max_col + 1] = True
311
+
312
+ elif dilation_type == 'bounding_ellipse':
313
+ # Find the most left top and left bottom point
314
+ rows, cols = np.where(binary_mask)
315
+ if len(rows) == 0 or len(cols) == 0:
316
+ return mask # return original mask if no valid points
317
+
318
+ min_row, max_row = np.min(rows), np.max(rows)
319
+ min_col, max_col = np.min(cols), np.max(cols)
320
+
321
+ # Calculate the center and axis length of the ellipse
322
+ center = ((min_col + max_col) // 2, (min_row + max_row) // 2)
323
+ a = (max_col - min_col) // 2 # half long axis
324
+ b = (max_row - min_row) // 2 # half short axis
325
+
326
+ # Create a bounding ellipse
327
+ y, x = np.ogrid[:mask.shape[0], :mask.shape[1]]
328
+ ellipse_mask = ((x - center[0])**2 / a**2 + (y - center[1])**2 / b**2) <= 1
329
+ dilated_mask = np.zeros_like(binary_mask, dtype=bool)
330
+ dilated_mask[ellipse_mask] = True
331
+ else:
332
+ raise ValueError("dilation_type must be 'square', 'ellipse', 'bounding_box', or 'bounding_ellipse'")
333
+
334
+ # Invert back to the mask convention (selected region = 0) and stack to 3 channels
335
+ dilated_mask = 1 - dilated_mask
336
+ dilated_mask = np.uint8(dilated_mask[:,:,np.newaxis]) * 255
337
+ dilated_mask = np.concatenate([dilated_mask, dilated_mask, dilated_mask], axis=2)
338
+ return dilated_mask
339
+
340
+
341
+ def dilate_mask(mask, image):
342
+ """Dilate the target mask for robustness and preview the result."""
343
+ if mask is None:
344
+ gr.Warning("Please input the target mask first")
345
+ return None, None
346
+
347
+ mask = random_mask_func(mask, dilation_type='square_dilation', dilation_size=DEFAULT_DILATION_KERNEL_SIZE)
348
+ masked_img = image * (mask > 0)
349
+ return mask, [masked_img, mask]
350
+
351
+
352
+ def erode_mask(mask, image):
353
+ """Erode the target mask and preview the result."""
354
+ if mask is None:
355
+ gr.Warning("Please input the target mask first")
356
+ return None, None
357
+
358
+ mask = random_mask_func(mask, dilation_type='square_erosion', dilation_size=DEFAULT_DILATION_KERNEL_SIZE)
359
+ masked_img = image * (mask > 0)
360
+ return mask, [masked_img, mask]
361
+
362
+
363
+ def bounding_box(mask, image):
364
+ """Create bounding box mask and preview the result."""
365
+ if mask is None:
366
+ gr.Warning("Please input the target mask first")
367
+ return None, None
368
+
369
+ mask = random_mask_func(mask, dilation_type='bounding_box', dilation_size=DEFAULT_DILATION_KERNEL_SIZE)
370
+ masked_img = image * (mask > 0)
371
+ return mask, [masked_img, mask]
372
+
373
+
374
+ def change_input_mask_mode(input_mask_mode, custmization_mode):
375
+ """Change visibility of input mask mode components."""
376
+
377
+ if custmization_mode == "Position-free":
378
+ return (
379
+ gr.update(visible=False),
380
+ gr.update(visible=False),
381
+ gr.update(visible=False),
382
+ )
383
+ elif input_mask_mode.lower() == "precise mask":
384
+ return (
385
+ gr.update(visible=True),
386
+ gr.update(visible=False),
387
+ gr.update(visible=True),
388
+ )
389
+ elif input_mask_mode.lower() == "user-drawn mask":
390
+ return (
391
+ gr.update(visible=False),
392
+ gr.update(visible=True),
393
+ gr.update(visible=False),
394
+ )
395
+ else:
396
+ gr.Warning("Invalid input mask mode")
397
+ return (
398
+ gr.skip(), gr.skip(), gr.skip()
399
+ )
400
+
401
+ def change_custmization_mode(custmization_mode, input_mask_mode):
402
+ """Change visibility and interactivity based on customization mode."""
403
+
404
+
405
+ if custmization_mode.lower() == "position-free":
406
+ return (gr.update(interactive=False, visible=False),
407
+ gr.update(interactive=False, visible=False),
408
+ gr.update(interactive=False, visible=False),
409
+ gr.update(interactive=False, visible=False),
410
+ gr.update(interactive=False, visible=False),
411
+ gr.update(interactive=False, visible=False),
412
+ gr.update(value="<s>Select a input mask mode</s>", visible=False),
413
+ gr.update(value="<s>Input target image & mask (Iterate clicking or brushing until the target is covered)</s>", visible=False),
414
+ gr.update(value="<s>View or modify the target mask</s>", visible=False),
415
+ gr.update(value="3. Input text prompt (necessary)"),
416
+ gr.update(value="4. Submit and view the output"),
417
+ gr.update(visible=False),
418
+ gr.update(visible=False),
419
+
420
+ )
421
+ else:
422
+ if input_mask_mode.lower() == "precise mask":
423
+ return (gr.update(interactive=True, visible=True),
424
+ gr.update(interactive=True, visible=False),
425
+ gr.update(interactive=True, visible=True),
426
+ gr.update(interactive=True, visible=True),
427
+ gr.update(interactive=True, visible=True),
428
+ gr.update(interactive=True, visible=True),
429
+ gr.update(value="3. Select a input mask mode", visible=True),
430
+ gr.update(value="4. Input target image & mask (Iterate clicking or brushing until the target is covered)", visible=True),
431
+ gr.update(value="6. View or modify the target mask", visible=True),
432
+ gr.update(value="5. Input text prompt (optional)", visible=True),
433
+ gr.update(value="7. Submit and view the output", visible=True),
434
+ gr.update(visible=True, value="Precise mask"),
435
+ gr.update(visible=True),
436
+ )
437
+ elif input_mask_mode.lower() == "user-drawn mask":
438
+ return (gr.update(interactive=True, visible=False),
439
+ gr.update(interactive=True, visible=True),
440
+ gr.update(interactive=False, visible=False),
441
+ gr.update(interactive=True, visible=True),
442
+ gr.update(interactive=True, visible=True),
443
+ gr.update(interactive=True, visible=True),
444
+ gr.update(value="3. Select a input mask mode", visible=True),
445
+ gr.update(value="4. Input target image & mask (Iterate clicking or brushing until the target is covered)", visible=True),
446
+ gr.update(value="6. View or modify the target mask", visible=True),
447
+ gr.update(value="5. Input text prompt (optional)", visible=True),
448
+ gr.update(value="7. Submit and view the output", visible=True),
449
+ gr.update(visible=True, value="User-drawn mask"),
450
+ gr.update(visible=True),
451
+ )
452
+
453
+
454
+ def change_seg_ref_mode(seg_ref_mode, image_reference_state, move_to_center):
455
+ """Change segmentation reference mode and handle background removal."""
456
+ if image_reference_state is None:
457
+ gr.Warning("Please upload the reference image first")
458
+ return None, None
459
+
460
+ global BEN2_MODEL
461
+
462
+ if seg_ref_mode == "Full Ref":
463
+ return image_reference_state, None
464
+ else:
465
+ if BEN2_MODEL is None:
466
+ gr.Warning("Please enable ben2 for mask reference first")
467
+ return gr.skip(), gr.skip()
468
+
469
+ image_reference_pil = Image.fromarray(image_reference_state)
470
+ image_reference_pil_rmbg = BEN2_MODEL.inference(image_reference_pil, move_to_center=move_to_center)
471
+ image_reference_rmbg = np.array(image_reference_pil_rmbg)
472
+ return image_reference_rmbg, image_reference_rmbg
473
+
474
+
475
+ def vlm_auto_generate(image_target_state, image_reference_state, mask_target_state,
476
+ custmization_mode):
477
+ """Auto-generate prompt using VLM."""
478
+
479
+ global VLM_PROCESSOR, VLM_MODEL
480
+
481
+ if custmization_mode == "Position-aware":
482
+ if image_target_state is None or mask_target_state is None:
483
+ gr.Warning("Please upload the target image and get mask first")
484
+ return None
485
+
486
+ if image_reference_state is None:
487
+ gr.Warning("Please upload the reference image first")
488
+ return None
489
+
490
+ if VLM_PROCESSOR is None or VLM_MODEL is None:
491
+ gr.Warning("Please enable vlm for prompt first")
492
+ return None  # no prompt variable exists in this function; nothing to fall back to when the VLM is disabled
493
+
494
+ messages = construct_vlm_gen_prompt(image_target_state, image_reference_state, mask_target_state, custmization_mode)
495
+ output_text = run_vlm(VLM_PROCESSOR, VLM_MODEL, messages, device=device)
496
+ return output_text
497
+
498
+
499
+ def vlm_auto_polish(prompt, custmization_mode):
500
+ """Auto-polish prompt using VLM."""
501
+
502
+ global VLM_PROCESSOR, VLM_MODEL
503
+
504
+ if prompt is None:
505
+ gr.Warning("Please input the text prompt first")
506
+ return None
507
+
508
+ if custmization_mode == "Position-aware":
509
+ gr.Warning("Polishing only works in position-free mode")
510
+ return prompt
511
+
512
+
513
+ if VLM_PROCESSOR is None or VLM_MODEL is None:
514
+ gr.Warning("Please enable vlm for prompt first")
515
+ return prompt
516
+
517
+ messages = construct_vlm_polish_prompt(prompt)
518
+ output_text = run_vlm(VLM_PROCESSOR, VLM_MODEL, messages, device=device)
519
+ return output_text
520
+
521
+
522
+ def save_results(output_img, image_reference, image_target, mask_target, prompt,
523
+ custmization_mode, input_mask_mode, seg_ref_mode, seed, guidance,
524
+ num_steps, num_images_per_prompt, use_background_preservation,
525
+ background_blend_threshold, true_gs, assets_cache_dir):
526
+ """Save generated results and metadata."""
527
+ save_name = datetime.now().strftime(TIMESTAMP_FORMAT)
528
+ results = []
529
+
530
+ for i in range(num_images_per_prompt):
531
+ save_dir = os.path.join(assets_cache_dir, save_name)
532
+ os.makedirs(save_dir, exist_ok=True)
533
+
534
+ output_img[i].save(os.path.join(save_dir, f"img_gen_{i}.png"))
535
+ image_reference.save(os.path.join(save_dir, f"img_ref_{i}.png"))
536
+ image_target.save(os.path.join(save_dir, f"img_target_{i}.png"))
537
+ mask_target.save(os.path.join(save_dir, f"mask_target_{i}.png"))
538
+
539
+ with open(os.path.join(save_dir, f"hyper_params_{i}.json"), "w") as f:
540
+ json.dump({
541
+ "prompt": prompt,
542
+ "custmization_mode": custmization_mode,
543
+ "input_mask_mode": input_mask_mode,
544
+ "seg_ref_mode": seg_ref_mode,
545
+ "seed": seed,
546
+ "guidance": guidance,
547
+ "num_steps": num_steps,
548
+ "num_images_per_prompt": num_images_per_prompt,
549
+ "use_background_preservation": use_background_preservation,
550
+ "background_blend_threshold": background_blend_threshold,
551
+ "true_gs": true_gs,
552
+ }, f)
553
+
554
+ results.append(output_img[i])
555
+
556
+ return results
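For orientation, here is a minimal sketch (not part of this commit) of how a folder written by `save_results` could be read back; the directory path and index are placeholders:

```python
# Hypothetical reader for a save_results folder (sketch only; paths are placeholders).
import json
import os

from PIL import Image


def load_result(save_dir: str, i: int = 0):
    """Return the i-th generated image and its hyper-parameter dict."""
    img = Image.open(os.path.join(save_dir, f"img_gen_{i}.png"))
    with open(os.path.join(save_dir, f"hyper_params_{i}.json")) as f:
        params = json.load(f)
    return img, params
```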
app/config.py ADDED
@@ -0,0 +1,72 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ Configuration management for IC-Custom application.
5
+ """
6
+ import os
7
+ import argparse
8
+ from omegaconf import OmegaConf
9
+
10
+
11
+ def parse_args():
12
+ """Parse command line arguments."""
13
+ parser = argparse.ArgumentParser(description="IC-Custom App.")
14
+ parser.add_argument(
15
+ "--config",
16
+ type=str,
17
+ default="configs/app/app.yaml",
18
+ help="path to config",
19
+ )
20
+ parser.add_argument(
21
+ "--hf_token",
22
+ type=str,
23
+ required=False,
24
+ help="Hugging Face token",
25
+ )
26
+ parser.add_argument(
27
+ "--hf_cache_dir",
28
+ type=str,
29
+ required=False,
30
+ default=os.path.expanduser("~/.cache/huggingface/hub"),
31
+ help="Cache directory to save the models, default is ~/.cache/huggingface/hub",
32
+ )
33
+ parser.add_argument(
34
+ "--assets_cache_dir",
35
+ type=str,
36
+ required=False,
37
+ default="results/app",
38
+ help="Cache directory to save the results, default is results/app",
39
+ )
40
+ parser.add_argument(
41
+ "--save_results",
42
+ action="store_true",
43
+ help="Save results",
44
+ )
45
+ parser.add_argument(
46
+ "--enable_ben2_for_mask_ref",
47
+ action=argparse.BooleanOptionalAction,
48
+ default=True,
49
+ help="Enable ben2 for mask reference (default: True)",
50
+ )
51
+ parser.add_argument(
52
+ "--enable_vlm_for_prompt",
53
+ action=argparse.BooleanOptionalAction,
54
+ default=False,
55
+ help="Enable vlm for prompt (default: True)",
56
+ )
57
+
58
+ return parser.parse_args()
59
+
60
+
61
+ def load_config(config_path):
62
+ """Load configuration from file."""
63
+ return OmegaConf.load(config_path)
64
+
65
+
66
+ def setup_environment(args):
67
+ """Setup environment variables from command line arguments."""
68
+ if args.hf_token is not None:
69
+ os.environ["HF_TOKEN"] = args.hf_token
70
+
71
+ if args.hf_cache_dir is not None:
72
+ os.environ["HF_HUB_CACHE"] = args.hf_cache_dir
app/constants.py ADDED
@@ -0,0 +1,35 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ Constants and default values for IC-Custom application.
5
+ """
6
+ from aspect_ratio_template import ASPECT_RATIO_TEMPLATE
7
+
8
+ # Aspect ratio constants
9
+ ASPECT_RATIO_LABELS = list(ASPECT_RATIO_TEMPLATE)
10
+ DEFAULT_ASPECT_RATIO = ASPECT_RATIO_LABELS[0]
11
+
12
+ # Colors and markers for segmentation
13
+ # OpenCV expects BGR colors; keep tuples as (R, G, B) for consistency across code.
14
+ SEGMENTATION_COLORS = [(255, 0, 0), (0, 255, 0)]
15
+ SEGMENTATION_MARKERS = [1, 5]
16
+ RGBA_COLORS = [(255, 0, 255, 255), (0, 255, 0, 255), (0, 0, 255, 255)]
17
+
18
+ # Magic-number constants
19
+ DEFAULT_BACKGROUND_BLEND_THRESHOLD = 0.5
20
+ DEFAULT_NUM_STEPS = 32
21
+ DEFAULT_GUIDANCE = 40
22
+ DEFAULT_TRUE_GS = 1
23
+ DEFAULT_NUM_IMAGES = 1
24
+ DEFAULT_SEED = -1 # -1 indicates random seed
25
+ DEFAULT_DILATION_KERNEL_SIZE = 7
26
+
27
+ # UI constants
28
+ DEFAULT_BRUSH_SIZE = 30
29
+ DEFAULT_MARKER_SIZE = 20
30
+ DEFAULT_MARKER_THICKNESS = 5
31
+ DEFAULT_MASK_ALPHA = 0.3
32
+ DEFAULT_COLOR_ALPHA = 0.7
33
+
34
+ # File naming
35
+ TIMESTAMP_FORMAT = "%Y%m%d_%H%M"
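A small sketch (not part of this commit) of how constants such as DEFAULT_DILATION_KERNEL_SIZE and TIMESTAMP_FORMAT are typically consumed; the helper name is hypothetical:

```python
# Hypothetical illustration of consuming the constants above (sketch only).
from datetime import datetime

import cv2
import numpy as np

from constants import DEFAULT_DILATION_KERNEL_SIZE, TIMESTAMP_FORMAT


def dilate_mask(mask: np.ndarray) -> np.ndarray:
    """Dilate a binary (uint8) mask with the default square kernel."""
    kernel = np.ones((DEFAULT_DILATION_KERNEL_SIZE, DEFAULT_DILATION_KERNEL_SIZE), np.uint8)
    return cv2.dilate(mask, kernel, iterations=1)


save_name = datetime.now().strftime(TIMESTAMP_FORMAT)  # e.g. "20250101_1230"
```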
app/event_handlers.py ADDED
@@ -0,0 +1,155 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ Event handlers for IC-Custom application.
5
+ """
6
+ import gradio as gr
7
+
8
+
9
+ def setup_event_handlers(
10
+ # UI components
11
+ input_mask_mode, image_target_1, image_target_2, undo_target_seg_button,
12
+ custmization_mode, dilate_button, erode_button, bounding_box_button,
13
+ mask_gallery, md_input_mask_mode, md_target_image, md_mask_operation,
14
+ md_prompt, md_submit, result_gallery, image_target_state, mask_target_state,
15
+ seg_ref_mode, image_reference_ori_state, move_to_center,
16
+ image_reference, image_reference_rmbg_state,
17
+ # Functions
18
+ change_input_mask_mode, change_custmization_mode, change_seg_ref_mode,
19
+ init_image_target_1, init_image_target_2, init_image_reference,
20
+ get_point, undo_seg_points, get_brush,
21
+ # VLM buttons (UI components)
22
+ vlm_generate_btn, vlm_polish_btn,
23
+ # VLM functions
24
+ vlm_auto_generate, vlm_auto_polish,
25
+ dilate_mask, erode_mask, bounding_box,
26
+ run_model,
27
+ # Other components
28
+ selected_points, prompt,
29
+ use_background_preservation, background_blend_threshold, seed,
30
+ num_images_per_prompt, guidance, true_gs, num_steps, aspect_ratio,
31
+ submit_button,
32
+ # extra state
33
+ eg_idx,
34
+ ):
35
+ """Setup all event handlers for the application."""
36
+
37
+ # Change input mask mode: precise mask or user-drawn mask
38
+ input_mask_mode.change(
39
+ change_input_mask_mode,
40
+ [input_mask_mode, custmization_mode],
41
+ [image_target_1, image_target_2, undo_target_seg_button]
42
+ )
43
+
44
+ # Change customization mode: pos-aware or pos-free
45
+ custmization_mode.change(
46
+ change_custmization_mode,
47
+ [custmization_mode, input_mask_mode],
48
+ [image_target_1, image_target_2, undo_target_seg_button, dilate_button,
49
+ erode_button, bounding_box_button, md_input_mask_mode,
50
+ md_target_image, md_mask_operation, md_prompt, md_submit, input_mask_mode, mask_gallery]
51
+ )
52
+
53
+ # Remove background for reference image
54
+ seg_ref_mode.change(
55
+ change_seg_ref_mode,
56
+ [seg_ref_mode, image_reference_ori_state, move_to_center],
57
+ [image_reference, image_reference_rmbg_state]
58
+ )
59
+
60
+ # Initialize components only on user upload (not programmatic updates)
61
+ image_target_1.upload(
62
+ init_image_target_1,
63
+ [image_target_1],
64
+ [image_target_state, selected_points, prompt, mask_target_state, mask_gallery,
65
+ result_gallery, use_background_preservation, background_blend_threshold, seed,
66
+ num_images_per_prompt, guidance, true_gs, num_steps, aspect_ratio]
67
+ )
68
+
69
+ image_target_2.upload(
70
+ init_image_target_2,
71
+ [image_target_2],
72
+ [image_target_state, selected_points, prompt, mask_target_state, mask_gallery,
73
+ result_gallery, use_background_preservation, background_blend_threshold, seed,
74
+ num_images_per_prompt, guidance, true_gs, num_steps, aspect_ratio]
75
+ )
76
+
77
+ image_reference.upload(
78
+ init_image_reference,
79
+ [image_reference],
80
+ [image_reference_ori_state, image_reference_rmbg_state, image_target_state,
81
+ mask_target_state, prompt, mask_gallery, result_gallery, image_target_1,
82
+ image_target_2, selected_points, input_mask_mode, seg_ref_mode, move_to_center,
83
+ use_background_preservation, background_blend_threshold, seed,
84
+ num_images_per_prompt, guidance, true_gs, num_steps, aspect_ratio]
85
+ )
86
+
87
+ # SAM for image_target_1
88
+ image_target_1.select(
89
+ get_point,
90
+ [image_target_state, selected_points],
91
+ [image_target_1, mask_target_state, mask_gallery],
92
+ )
93
+
94
+ undo_target_seg_button.click(
95
+ undo_seg_points,
96
+ [image_target_state, selected_points],
97
+ [image_target_1, mask_target_state, mask_gallery]
98
+ )
99
+
100
+ # Brush for image_target_2
101
+ image_target_2.change(
102
+ get_brush,
103
+ [image_target_2],
104
+ [mask_target_state, mask_gallery],
105
+ )
106
+
107
+ # VLM auto generate
108
+ vlm_generate_btn.click(
109
+ vlm_auto_generate,
110
+ [image_target_state, image_reference_ori_state, mask_target_state, custmization_mode],
111
+ [prompt]
112
+ )
113
+
114
+ # VLM auto polish
115
+ vlm_polish_btn.click(
116
+ vlm_auto_polish,
117
+ [prompt, custmization_mode],
118
+ [prompt]
119
+ )
120
+
121
+ # Mask operations
122
+ dilate_button.click(
123
+ dilate_mask,
124
+ [mask_target_state, image_target_state],
125
+ [mask_target_state, mask_gallery]
126
+ )
127
+
128
+ erode_button.click(
129
+ erode_mask,
130
+ [mask_target_state, image_target_state],
131
+ [mask_target_state, mask_gallery]
132
+ )
133
+
134
+ bounding_box_button.click(
135
+ bounding_box,
136
+ [mask_target_state, image_target_state],
137
+ [mask_target_state, mask_gallery]
138
+ )
139
+
140
+ # Run function
141
+ ips = [
142
+ image_target_state, mask_target_state, image_reference_ori_state,
143
+ image_reference_rmbg_state, prompt, seed, guidance, true_gs, num_steps,
144
+ num_images_per_prompt, use_background_preservation, background_blend_threshold,
145
+ aspect_ratio, custmization_mode, seg_ref_mode, input_mask_mode,
146
+ ]
147
+
148
+ submit_button.click(
149
+ fn=run_model,
150
+ inputs=ips,
151
+ outputs=[result_gallery, seed, prompt],
152
+ show_progress=True,
153
+ )
154
+
155
+
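The function above only wires pre-built components to pre-built callbacks. A condensed sketch (not from this repo) of the same `component.event(fn, inputs, outputs)` pattern with placeholder names:

```python
# Minimal Gradio wiring pattern mirrored by setup_event_handlers (placeholder names).
import gradio as gr


def echo(text: str) -> str:
    return text


with gr.Blocks() as demo:
    prompt = gr.Textbox(label="Prompt")
    output = gr.Textbox(label="Output")
    submit = gr.Button("Run")
    # Same shape as submit_button.click(fn=run_model, inputs=ips, outputs=[...]) above.
    submit.click(fn=echo, inputs=[prompt], outputs=[output])
```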
app/examples.py ADDED
@@ -0,0 +1,210 @@
1
+ from PIL import Image
2
+
3
+ make_dict = lambda x : { 'background': Image.open(x).convert("RGBA"), 'layers': [Image.new("RGBA", Image.open(x).size, (255, 255, 255, 0))], 'composite': Image.open(x).convert("RGBA") }
4
+
5
+ null_dict = {
6
+ 'background': None,
7
+ 'composite': None,
8
+ 'layers': []
9
+ }
10
+
11
+ IMG_REF = [
12
+ ## pos-aware: precise mask
13
+ "assets/gradio/pos_aware/001/img_ref.png",
14
+ "assets/gradio/pos_aware/002/img_ref.png",
15
+ ## pos-aware: User-drawn mask
16
+ "assets/gradio/pos_aware/003/img_ref.png",
17
+ "assets/gradio/pos_aware/004/img_ref.png",
18
+ "assets/gradio/pos_aware/005/img_ref.png",
19
+ ## pos-free
20
+ "assets/gradio/pos_free/001/img_ref.png",
21
+ "assets/gradio/pos_free/002/img_ref.png",
22
+ "assets/gradio/pos_free/003/img_ref.png",
23
+ "assets/gradio/pos_free/004/img_ref.png",
24
+ ]
25
+
26
+ IMG_TGT1 = [
27
+ ## pos-aware: precise mask
28
+ "assets/gradio/pos_aware/001/img_target.png",
29
+ "assets/gradio/pos_aware/002/img_target.png",
30
+ ## pos-aware: User-drawn mask
31
+ None,
32
+ None,
33
+ None,
34
+ ## pos-free
35
+ "assets/gradio/pos_free/001/img_target.png",
36
+ "assets/gradio/pos_free/002/img_target.png",
37
+ "assets/gradio/pos_free/003/img_target.png",
38
+ "assets/gradio/pos_free/004/img_target.png",
39
+ ]
40
+
41
+ IMG_TGT2 = [
42
+ ## pos-aware: precise mask
43
+ null_dict,
44
+ null_dict,
45
+ ## pos-aware: User-drawn mask
46
+ make_dict("assets/gradio/pos_aware/003/img_target.png"),
47
+ make_dict("assets/gradio/pos_aware/004/img_target.png"),
48
+ make_dict("assets/gradio/pos_aware/005/img_target.png"),
49
+ ## pos-free
50
+ null_dict,
51
+ null_dict,
52
+ null_dict,
53
+ null_dict,
54
+ ]
55
+
56
+ MASK_TGT = [
57
+ ## pos-aware: precise mask
58
+ "assets/gradio/pos_aware/001/mask_target.png",
59
+ "assets/gradio/pos_aware/002/mask_target.png",
60
+ ## pos-aware: User-drawn mask
61
+ "assets/gradio/pos_aware/003/mask_target.png",
62
+ "assets/gradio/pos_aware/004/mask_target.png",
63
+ "assets/gradio/pos_aware/005/mask_target.png",
64
+ ## pos-free
65
+ "assets/gradio/pos_free/001/mask_target.png",
66
+ "assets/gradio/pos_free/002/mask_target.png",
67
+ "assets/gradio/pos_free/003/mask_target.png",
68
+ "assets/gradio/pos_free/004/mask_target.png",
69
+ ]
70
+
71
+ CUSTOM_MODE = [
72
+ ## pos-aware
73
+ "Position-aware",
74
+ "Position-aware",
75
+ "Position-aware",
76
+ "Position-aware",
77
+ "Position-aware",
78
+ ## pos-free
79
+ "Position-free",
80
+ "Position-free",
81
+ "Position-free",
82
+ "Position-free",
83
+ ]
84
+
85
+ INPUT_MASK_MODE = [
86
+ ## pos-aware: precise mask
87
+ "Precise mask",
88
+ "Precise mask",
89
+ ## pos-aware: User-drawn mask
90
+ "User-drawn mask",
91
+ "User-drawn mask",
92
+ "User-drawn mask",
93
+ ## pos-free
94
+ "Precise mask",
95
+ "Precise mask",
96
+ "Precise mask",
97
+ "Precise mask",
98
+ ]
99
+
100
+ SEG_REF_MODE = [
101
+ ## pos-aware
102
+ "Full Ref",
103
+ "Full Ref",
104
+ "Full Ref",
105
+ "Full Ref",
106
+ "Full Ref",
107
+ ## pos-free
108
+ "Full Ref",
109
+ "Full Ref",
110
+ "Full Ref",
111
+ "Full Ref",
112
+ ]
113
+
114
+ PROMPTS = [
115
+ ## pos-aware: precise mask
116
+ "",
117
+ "",
118
+ ## pos-aware: User-drawn mask
119
+ "A delicate necklace with a mother-of-pearl clover pendant hangs gracefully around the neck of a woman dressed in a black pinstripe blazer.",
120
+ "",
121
+ "",
122
+ ## pos-free
123
+ "TThe charming, soft plush toy is joyfully wandering through a lush, dense jungle, surrounded by vibrant green foliage and towering trees.",
124
+ "A bright yellow alarm clock sits on a wooden desk next to a stack of books in a cozy, sunlit room.",
125
+ "A Lego figure dressed in a vibrant chicken costume, leaning against a wooden chair, surrounded by lush green grass and blooming flowers.",
126
+ "The crocheted gingerbread man is perched on a tree branch in a dense forest, with sunlight filtering through the leaves, casting dappled shadows around him."
127
+ ]
128
+
129
+ IMG_GEN = [
130
+ ## pos-aware: precise mask
131
+ "assets/gradio/pos_aware/001/img_gen.png",
132
+ "assets/gradio/pos_aware/002/img_gen.png",
133
+ ## pos-aware: User-drawn mask
134
+ "assets/gradio/pos_aware/003/img_gen.png",
135
+ "assets/gradio/pos_aware/004/img_gen.png",
136
+ "assets/gradio/pos_aware/005/img_gen.png",
137
+ ## pos-free
138
+ "assets/gradio/pos_free/001/img_gen.png",
139
+ "assets/gradio/pos_free/002/img_gen.png",
140
+ "assets/gradio/pos_free/003/img_gen.png",
141
+ "assets/gradio/pos_free/004/img_gen.png",
142
+ ]
143
+
144
+ SEED = [
145
+ ## pos-aware
146
+ 97175498,
147
+ 2126677963,
148
+ 346969695,
149
+ 1172525388,
150
+ 268683460,
151
+ ## pos-free
152
+ 2126677963,
153
+ 418898253,
154
+ 2126677963,
155
+ 2126677963
156
+ ]
157
+
158
+ TRUE_GS = [
159
+ # pos-aware
160
+ 1,
161
+ 1,
162
+ 1,
163
+ 1,
164
+ 1,
165
+ # pos-free
166
+ 3,
167
+ 3,
168
+ 3,
169
+ 3,
170
+ ]
171
+
172
+ NUM_STEPS = [
173
+ ## pos-aware
174
+ 32,
175
+ 32,
176
+ 32,
177
+ 32,
178
+ 32,
179
+ ## pos-free
180
+ 20,
181
+ 20,
182
+ 20,
183
+ 20,
184
+ ]
185
+
186
+ GUIDANCE = [
187
+ ## pos-aware
188
+ 40,
189
+ 48,
190
+ 40,
191
+ 48,
192
+ 48,
193
+ ## pos-free
194
+ 40,
195
+ 40,
196
+ 40,
197
+ 40,
198
+ ]
199
+
200
+ GRADIO_EXAMPLES = [
201
+ [IMG_REF[0], IMG_TGT1[0], IMG_TGT2[0], CUSTOM_MODE[0], INPUT_MASK_MODE[0], SEG_REF_MODE[0], PROMPTS[0], SEED[0], TRUE_GS[0], '0', NUM_STEPS[0], GUIDANCE[0]],
202
+ [IMG_REF[1], IMG_TGT1[1], IMG_TGT2[1], CUSTOM_MODE[1], INPUT_MASK_MODE[1], SEG_REF_MODE[1], PROMPTS[1], SEED[1], TRUE_GS[1], '1', NUM_STEPS[1], GUIDANCE[1]],
203
+ [IMG_REF[2], IMG_TGT1[2], IMG_TGT2[2], CUSTOM_MODE[2], INPUT_MASK_MODE[2], SEG_REF_MODE[2], PROMPTS[2], SEED[2], TRUE_GS[2], '2', NUM_STEPS[2], GUIDANCE[2]],
204
+ [IMG_REF[3], IMG_TGT1[3], IMG_TGT2[3], CUSTOM_MODE[3], INPUT_MASK_MODE[3], SEG_REF_MODE[3], PROMPTS[3], SEED[3], TRUE_GS[3], '3', NUM_STEPS[3], GUIDANCE[3]],
205
+ [IMG_REF[4], IMG_TGT1[4], IMG_TGT2[4], CUSTOM_MODE[4], INPUT_MASK_MODE[4], SEG_REF_MODE[4], PROMPTS[4], SEED[4], TRUE_GS[4], '4', NUM_STEPS[4], GUIDANCE[4]],
206
+ [IMG_REF[5], IMG_TGT1[5], IMG_TGT2[5], CUSTOM_MODE[5], INPUT_MASK_MODE[5], SEG_REF_MODE[5], PROMPTS[5], SEED[5], TRUE_GS[5], '5', NUM_STEPS[5], GUIDANCE[5]],
207
+ [IMG_REF[6], IMG_TGT1[6], IMG_TGT2[6], CUSTOM_MODE[6], INPUT_MASK_MODE[6], SEG_REF_MODE[6], PROMPTS[6], SEED[6], TRUE_GS[6], '6', NUM_STEPS[6], GUIDANCE[6]],
208
+ [IMG_REF[7], IMG_TGT1[7], IMG_TGT2[7], CUSTOM_MODE[7], INPUT_MASK_MODE[7], SEG_REF_MODE[7], PROMPTS[7], SEED[7], TRUE_GS[7], '7', NUM_STEPS[7], GUIDANCE[7]],
209
+ [IMG_REF[8], IMG_TGT1[8], IMG_TGT2[8], CUSTOM_MODE[8], INPUT_MASK_MODE[8], SEG_REF_MODE[8], PROMPTS[8], SEED[8], TRUE_GS[8], '8', NUM_STEPS[8], GUIDANCE[8]],
210
+ ]
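GRADIO_EXAMPLES is a list of rows with one value per example input. A hedged sketch (not part of this commit) of feeding it to gr.Examples; the component list here is truncated and assumed:

```python
# Hypothetical hookup of GRADIO_EXAMPLES to gr.Examples (sketch only; two columns shown).
import gradio as gr

from examples import GRADIO_EXAMPLES

with gr.Blocks() as demo:
    img_ref = gr.Image(label="Reference image", type="filepath")
    prompt = gr.Textbox(label="Prompt")
    # The real app passes one component per column; here only IMG_REF (col 0) and PROMPTS (col 6).
    gr.Examples(
        examples=[[row[0], row[6]] for row in GRADIO_EXAMPLES],
        inputs=[img_ref, prompt],
    )
```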
app/metainfo.py ADDED
@@ -0,0 +1,131 @@
1
+ #### metainfo ####
2
+ head = r"""
3
+ <div class="elegant-header">
4
+ <div class="header-content">
5
+ <!-- Main title -->
6
+ <h1 class="main-title">
7
+ <span class="title-icon">🎨</span>
8
+ <span class="title-text">IC-Custom</span>
9
+ </h1>
10
+
11
+ <!-- Subtitle -->
12
+ <p class="subtitle">Transform your images with AI-powered customization</p>
13
+
14
+ <!-- Action badges -->
15
+ <div class="header-badges">
16
+ <a href="https://liyaowei-stu.github.io/project/IC_Custom/" class="badge-link">
17
+ <span class="badge-icon">🔗</span>
18
+ <span class="badge-text">Project</span>
19
+ </a>
20
+ <a href="https://arxiv.org/abs/2507.01926" class="badge-link">
21
+ <span class="badge-icon">📄</span>
22
+ <span class="badge-text">Paper</span>
23
+ </a>
24
+ <a href="https://github.com/TencentARC/IC-Custom" class="badge-link">
25
+ <span class="badge-icon">💻</span>
26
+ <span class="badge-text">Code</span>
27
+ </a>
28
+ </div>
29
+ </div>
30
+ </div>
31
+ """
32
+
33
+
34
+ getting_started = r"""
35
+ <div class="getting-started-container">
36
+ <!-- Header -->
37
+ <div class="guide-header">
38
+ <h3 class="guide-title">🚀 Quick Start Guide</h3>
39
+ <p class="guide-subtitle">Follow these steps to customize your images with IC-Custom</p>
40
+ </div>
41
+
42
+ <!-- What is IC-Custom -->
43
+ <div class="info-card">
44
+ <div class="info-content">
45
+ <strong class="brand-name">IC-Custom</strong> offers two customization modes:
46
+ <span class="mode-badge position-aware">Position-aware</span>
47
+ (precise placement in masked areas) and
48
+ <span class="mode-badge position-free">Position-free</span>
49
+ (subject-driven generation).
50
+ </div>
51
+ </div>
52
+
53
+ <!-- Common Steps -->
54
+ <div class="step-card common-steps">
55
+ <div class="step-header">
56
+ <span class="step-number">1</span>
57
+ Initial Setup (Both Modes)
58
+ </div>
59
+ <ul class="step-list">
60
+ <li>Choose your <strong>customization mode</strong></li>
61
+ <li>Upload a <strong>reference image</strong> 📸</li>
62
+ </ul>
63
+ </div>
64
+
65
+ <!-- Position-aware Mode -->
66
+ <div class="step-card position-aware-steps">
67
+ <div class="step-header">
68
+ <span class="step-number">2A</span>
69
+ 🎯 Position-aware Mode Steps
70
+ </div>
71
+ <ul class="step-list">
72
+ <li>Select <strong>input mask mode</strong> (precise mask or user-drawn mask)</li>
73
+ <li>Upload <strong>target image</strong> and create mask (click for SAM or brush directly)</li>
74
+ <li>Add <strong>text prompt</strong> (optional) - use VLM buttons for auto-generation</li>
75
+ <li>Review and refine your <strong>mask</strong> using mask tools if needed</li>
76
+ <li>Click <span class="run-button position-aware">Run</span> ✨</li>
77
+ </ul>
78
+ </div>
79
+
80
+ <!-- Position-free Mode -->
81
+ <div class="step-card position-free-steps">
82
+ <div class="step-header">
83
+ <span class="step-number">2B</span>
84
+ 🎨 Position-free Mode Steps
85
+ </div>
86
+ <ul class="step-list">
87
+ <li>Write your <strong>text prompt</strong> (required) - describe the target scene</li>
88
+ <li>Use VLM buttons for prompt auto-generation or polishing (if enabled)</li>
89
+ <li>Click <span class="run-button position-free">Run</span> ✨</li>
90
+ </ul>
91
+ </div>
92
+
93
+ <!-- Quick Tips -->
94
+ <div class="tips-card">
95
+ <div class="tips-content">
96
+ <strong>💡 Quick Tips:</strong>
97
+ Use <kbd class="key-hint">Alt + "-"</kbd> or <kbd class="key-hint">⌘ + "-"</kbd> to zoom out for better operation •
98
+ Adjust settings in <kbd class="key-hint">Advanced Options</kbd> • Use mask operations (<kbd class="key-hint">dilate</kbd>/<kbd class="key-hint">erode</kbd>/<kbd class="key-hint">bbox</kbd>) for better results •
99
+ Try different <kbd class="key-hint">seeds</kbd> for varied outputs
100
+ </div>
101
+ </div>
102
+
103
+ <!-- Final Message -->
104
+ <div class="final-message">
105
+ <div class="final-text">
106
+ 🎉 Ready to start? Collapse this guide and begin customizing!
107
+ </div>
108
+ </div>
109
+ </div>
110
+ """
111
+
112
+
113
+ citation = r"""
114
+ If IC-Custom is helpful, please help to ⭐ the <a href='https://github.com/TencentARC/IC-Custom' target='_blank'>Github Repo</a>. Thanks!
115
+ [![GitHub Stars](https://img.shields.io/github/stars/TencentARC/IC-Custom?style=social)](https://github.com/TencentARC/IC-Custom)
116
+ ---
117
+ 📝 **Citation**
118
+ <br>
119
+ If our work is useful for your research, please consider citing:
120
+ ```bibtex
121
+ @article{li2025ic,
122
+ title={IC-Custom: Diverse Image Customization via In-Context Learning},
123
+ author={Li, Yaowei and Li, Xiaoyu and Zhang, Zhaoyang and Bian, Yuxuan and Liu, Gan and Li, Xinyuan and Xu, Jiale and Hu, Wenbo and Liu, Yating and Li, Lingen and others},
124
+ journal={arXiv preprint arXiv:2507.01926},
125
+ year={2025}
126
+ }
127
+ ```
128
+ 📧 **Contact**
129
+ <br>
130
+ If you have any questions, please feel free to reach out to me at <b>[email protected]</b>.
131
+ """
app/stylesheets.py ADDED
@@ -0,0 +1,1679 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ Centralized CSS and JS for the IC-Custom app UI.
5
+
6
+ Expose helpers:
7
+ - get_css(): return a single CSS string for gradio Blocks(css=...)
8
+ - get_js(): return a JS snippet for gradio.
9
+ """
10
+
11
+
12
+ def get_css() -> str:
13
+ return r"""
14
+ /* Global Optimization Effects - No Layout Changes */
15
+
16
+ /* Apple-style segmented control for radio buttons */
17
+ #customization_mode_radio .wrap, #input_mask_mode_radio .wrap, #seg_ref_mode_radio .wrap {
18
+ display: flex;
19
+ flex-wrap: nowrap;
20
+ justify-content: center;
21
+ align-items: center;
22
+ gap: 0;
23
+ background: rgba(255, 255, 255, 0.8);
24
+ border: 1px solid var(--neutral-200);
25
+ border-radius: 10px;
26
+ padding: 3px;
27
+ backdrop-filter: blur(12px);
28
+ -webkit-backdrop-filter: blur(12px);
29
+ box-shadow: 0 2px 8px rgba(15, 23, 42, 0.08);
30
+ }
31
+
32
+ #customization_mode_radio .wrap label, #input_mask_mode_radio .wrap label, #seg_ref_mode_radio .wrap label {
33
+ display: flex;
34
+ flex: 1;
35
+ justify-content: center;
36
+ align-items: center;
37
+ margin: 0;
38
+ padding: 10px 16px;
39
+ box-sizing: border-box;
40
+ border-radius: 7px;
41
+ transition: all 0.3s cubic-bezier(0.4, 0, 0.2, 1);
42
+ background: transparent;
43
+ border: none;
44
+ font-weight: 500;
45
+ font-size: 0.9rem;
46
+ color: var(--text-secondary);
47
+ cursor: pointer;
48
+ position: relative;
49
+ white-space: nowrap;
50
+ min-width: 0;
51
+ }
52
+
53
+ /* Hide the actual radio input */
54
+ #customization_mode_radio .wrap label input[type="radio"],
55
+ #input_mask_mode_radio .wrap label input[type="radio"],
56
+ #seg_ref_mode_radio .wrap label input[type="radio"] {
57
+ display: none;
58
+ }
59
+
60
+ /* Hover states */
61
+ #customization_mode_radio .wrap label:hover,
62
+ #input_mask_mode_radio .wrap label:hover,
63
+ #seg_ref_mode_radio .wrap label:hover {
64
+ background: rgba(14, 165, 233, 0.1);
65
+ color: var(--primary-blue);
66
+ }
67
+
68
+ /* Selected state with smooth background */
69
+ #customization_mode_radio .wrap label:has(input[type="radio"]:checked),
70
+ #input_mask_mode_radio .wrap label:has(input[type="radio"]:checked),
71
+ #seg_ref_mode_radio .wrap label:has(input[type="radio"]:checked) {
72
+ background: var(--primary-blue);
73
+ color: white;
74
+ font-weight: 600;
75
+ box-shadow: 0 2px 6px rgba(14, 165, 233, 0.25);
76
+ transform: none;
77
+ }
78
+
79
+ /* Fallback for browsers that don't support :has() */
80
+ #customization_mode_radio .wrap label input[type="radio"]:checked + *,
81
+ #input_mask_mode_radio .wrap label input[type="radio"]:checked + *,
82
+ #seg_ref_mode_radio .wrap label input[type="radio"]:checked + * {
83
+ color: white;
84
+ }
85
+
86
+ #customization_mode_radio .wrap:has(input[type="radio"]:checked) label:has(input[type="radio"]:checked),
87
+ #input_mask_mode_radio .wrap:has(input[type="radio"]:checked) label:has(input[type="radio"]:checked),
88
+ #seg_ref_mode_radio .wrap:has(input[type="radio"]:checked) label:has(input[type="radio"]:checked) {
89
+ background: var(--primary-blue);
90
+ }
91
+
92
+ /* Active state */
93
+ #customization_mode_radio .wrap label:active,
94
+ #input_mask_mode_radio .wrap label:active,
95
+ #seg_ref_mode_radio .wrap label:active {
96
+ transform: scale(0.98);
97
+ }
98
+
99
+ /* Elegant header styling */
100
+ .elegant-header {
101
+ text-align: center;
102
+ margin: 0 0 2rem 0;
103
+ padding: 0;
104
+ }
105
+
106
+ .header-content {
107
+ display: inline-block;
108
+ padding: 1.8rem 2.5rem;
109
+ background: linear-gradient(135deg,
110
+ rgba(255, 255, 255, 0.1) 0%,
111
+ rgba(255, 255, 255, 0.05) 100%);
112
+ border: 1px solid rgba(255, 255, 255, 0.15);
113
+ border-radius: 20px;
114
+ backdrop-filter: blur(12px);
115
+ -webkit-backdrop-filter: blur(12px);
116
+ box-shadow:
117
+ 0 8px 32px rgba(15, 23, 42, 0.04),
118
+ inset 0 1px 0 rgba(255, 255, 255, 0.2);
119
+ transition: all 0.4s ease;
120
+ position: relative;
121
+ overflow: hidden;
122
+ }
123
+
124
+ .header-content::before {
125
+ content: '';
126
+ position: absolute;
127
+ top: 0;
128
+ left: -100%;
129
+ width: 100%;
130
+ height: 100%;
131
+ background: linear-gradient(90deg,
132
+ transparent,
133
+ rgba(255, 255, 255, 0.1),
134
+ transparent);
135
+ transition: left 0.6s ease;
136
+ }
137
+
138
+ .header-content:hover::before {
139
+ left: 100%;
140
+ }
141
+
142
+ .header-content:hover {
143
+ transform: translateY(-2px);
144
+ box-shadow:
145
+ 0 12px 40px rgba(15, 23, 42, 0.08),
146
+ inset 0 1px 0 rgba(255, 255, 255, 0.3);
147
+ border-color: rgba(14, 165, 233, 0.2);
148
+ }
149
+
150
+ /* Main title styling */
151
+ .main-title {
152
+ margin: 0 0 0.8rem 0;
153
+ font-size: 2.4rem;
154
+ font-weight: 800;
155
+ display: flex;
156
+ align-items: center;
157
+ justify-content: center;
158
+ gap: 0.5rem;
159
+ }
160
+
161
+ .title-icon {
162
+ font-size: 2.2rem;
163
+ filter: drop-shadow(0 2px 4px rgba(0, 0, 0, 0.1));
164
+ }
165
+
166
+ .title-text {
167
+ background: linear-gradient(135deg, #0ea5e9 0%, #06b6d4 50%, #10b981 100%);
168
+ -webkit-background-clip: text;
169
+ -webkit-text-fill-color: transparent;
170
+ background-clip: text;
171
+ text-shadow: none;
172
+ position: relative;
173
+ }
174
+
175
+ /* Subtitle styling */
176
+ .subtitle {
177
+ margin: 0 0 1.2rem 0;
178
+ font-size: 1rem;
179
+ color: #64748b;
180
+ font-weight: 500;
181
+ letter-spacing: 0.025em;
182
+ opacity: 0.9;
183
+ }
184
+
185
+ /* Header badges container */
186
+ .header-badges {
187
+ display: flex;
188
+ justify-content: center;
189
+ gap: 0.8rem;
190
+ flex-wrap: wrap;
191
+ }
192
+
193
+ /* Individual badge links */
194
+ .badge-link {
195
+ display: inline-flex;
196
+ align-items: center;
197
+ gap: 0.4rem;
198
+ padding: 0.5rem 1rem;
199
+ background: rgba(255, 255, 255, 0.15);
200
+ border: 1px solid rgba(255, 255, 255, 0.2);
201
+ border-radius: 12px;
202
+ color: #475569;
203
+ text-decoration: none;
204
+ font-weight: 500;
205
+ font-size: 0.9rem;
206
+ transition: all 0.3s ease;
207
+ backdrop-filter: blur(4px);
208
+ -webkit-backdrop-filter: blur(4px);
209
+ position: relative;
210
+ overflow: hidden;
211
+ }
212
+
213
+ .badge-link::before {
214
+ content: '';
215
+ position: absolute;
216
+ top: 0;
217
+ left: -100%;
218
+ width: 100%;
219
+ height: 100%;
220
+ background: var(--primary-gradient);
221
+ transition: left 0.3s ease;
222
+ z-index: -1;
223
+ }
224
+
225
+ .badge-link:hover::before {
226
+ left: 0;
227
+ }
228
+
229
+ .badge-link:hover {
230
+ transform: translateY(-2px);
231
+ color: white;
232
+ border-color: transparent;
233
+ box-shadow: 0 4px 12px rgba(14, 165, 233, 0.3);
234
+ }
235
+
236
+ .badge-icon {
237
+ font-size: 1rem;
238
+ opacity: 0.8;
239
+ }
240
+
241
+ .badge-text {
242
+ font-weight: 600;
243
+ }
244
+
245
+ /* Getting Started Guide Styling */
246
+ .getting-started-container {
247
+ padding: 1.5rem;
248
+ background: linear-gradient(135deg, rgba(255, 255, 255, 0.1), rgba(255, 255, 255, 0.05));
249
+ border-radius: 12px;
250
+ border: 1px solid rgba(255, 255, 255, 0.15);
251
+ backdrop-filter: blur(8px);
252
+ -webkit-backdrop-filter: blur(8px);
253
+ box-shadow: 0 4px 16px rgba(15, 23, 42, 0.06);
254
+ }
255
+
256
+ .guide-header {
257
+ text-align: center;
258
+ margin-bottom: 1.5rem;
259
+ }
260
+
261
+ .guide-title {
262
+ color: var(--text-primary);
263
+ margin: 0 0 0.5rem 0;
264
+ font-size: 1.2rem;
265
+ font-weight: 700;
266
+ }
267
+
268
+ .guide-subtitle {
269
+ color: var(--text-muted);
270
+ margin: 0;
271
+ font-size: 0.9rem;
272
+ opacity: 0.9;
273
+ }
274
+
275
+ /* Info card */
276
+ .info-card {
277
+ background: rgba(255, 255, 255, 0.4);
278
+ border-radius: 8px;
279
+ padding: 1rem;
280
+ margin-bottom: 1.2rem;
281
+ border-left: 3px solid var(--primary-blue);
282
+ backdrop-filter: blur(4px);
283
+ -webkit-backdrop-filter: blur(4px);
284
+ transition: all 0.3s ease;
285
+ }
286
+
287
+ .info-card:hover {
288
+ background: rgba(255, 255, 255, 0.5);
289
+ transform: translateX(2px);
290
+ }
291
+
292
+ .info-content {
293
+ color: var(--text-secondary);
294
+ font-size: 0.9rem;
295
+ line-height: 1.5;
296
+ }
297
+
298
+ .brand-name {
299
+ color: var(--primary-blue);
300
+ font-weight: 700;
301
+ }
302
+
303
+ /* Mode badges */
304
+ .mode-badge {
305
+ padding: 0.2rem 0.4rem;
306
+ border-radius: 4px;
307
+ font-size: 0.8rem;
308
+ font-weight: 600;
309
+ margin: 0 0.2rem;
310
+ transition: all 0.2s ease;
311
+ }
312
+
313
+ .mode-badge.position-aware {
314
+ background: var(--badge-blue-bg);
315
+ color: var(--badge-blue-text);
316
+ }
317
+
318
+ .mode-badge.position-free {
319
+ background: var(--badge-green-bg);
320
+ color: var(--badge-green-text);
321
+ }
322
+
323
+ /* Step cards */
324
+ .step-card {
325
+ background: rgba(255, 255, 255, 0.4);
326
+ border-radius: 8px;
327
+ padding: 1rem;
328
+ margin-bottom: 1.2rem;
329
+ backdrop-filter: blur(4px);
330
+ -webkit-backdrop-filter: blur(4px);
331
+ transition: all 0.3s ease;
332
+ position: relative;
333
+ overflow: hidden;
334
+ }
335
+
336
+ .step-card::before {
337
+ content: '';
338
+ position: absolute;
339
+ left: 0;
340
+ top: 0;
341
+ bottom: 0;
342
+ width: 3px;
343
+ transition: all 0.3s ease;
344
+ }
345
+
346
+ .step-card.common-steps::before {
347
+ background: var(--neutral-500);
348
+ }
349
+
350
+ .step-card.position-aware-steps::before {
351
+ background: var(--position-aware-blue);
352
+ }
353
+
354
+ .step-card.position-free-steps::before {
355
+ background: var(--position-free-purple);
356
+ }
357
+
358
+ .step-card:hover {
359
+ background: rgba(255, 255, 255, 0.5);
360
+ transform: translateX(2px);
361
+ }
362
+
363
+ .step-header {
364
+ font-weight: 600;
365
+ color: var(--text-primary);
366
+ margin-bottom: 0.75rem;
367
+ display: flex;
368
+ align-items: center;
369
+ font-size: 0.95rem;
370
+ }
371
+
372
+ .step-number {
373
+ color: white;
374
+ border-radius: 50%;
375
+ width: 24px;
376
+ height: 24px;
377
+ display: inline-flex;
378
+ align-items: center;
379
+ justify-content: center;
380
+ margin-right: 0.5rem;
381
+ font-size: 0.75rem;
382
+ font-weight: 700;
383
+ }
384
+
385
+ .common-steps .step-number {
386
+ background: var(--neutral-500);
387
+ }
388
+
389
+ .position-aware-steps .step-number {
390
+ background: var(--position-aware-blue);
391
+ }
392
+
393
+ .position-free-steps .step-number {
394
+ background: var(--position-free-purple);
395
+ }
396
+
397
+ .step-list {
398
+ margin: 0;
399
+ padding-left: 1.2rem;
400
+ font-size: 0.85rem;
401
+ color: var(--text-secondary);
402
+ line-height: 1.6;
403
+ }
404
+
405
+ .step-list li {
406
+ margin-bottom: 0.4rem;
407
+ position: relative;
408
+ }
409
+
410
+ .step-list li:last-child {
411
+ margin-bottom: 0;
412
+ }
413
+
414
+ /* Run buttons */
415
+ .run-button {
416
+ padding: 0.2rem 0.5rem;
417
+ border-radius: 4px;
418
+ font-weight: 600;
419
+ font-size: 0.8rem;
420
+ color: white;
421
+ transition: all 0.2s ease;
422
+ }
423
+
424
+ .run-button.position-aware {
425
+ background: var(--position-aware-blue);
426
+ }
427
+
428
+ .run-button.position-free {
429
+ background: var(--position-free-purple);
430
+ }
431
+
432
+ /* Tips card */
433
+ .tips-card {
434
+ background: rgba(241, 245, 249, 0.6);
435
+ border-radius: 8px;
436
+ padding: 0.8rem;
437
+ border-left: 3px solid var(--neutral-400);
438
+ margin-bottom: 1rem;
439
+ backdrop-filter: blur(4px);
440
+ -webkit-backdrop-filter: blur(4px);
441
+ transition: all 0.3s ease;
442
+ }
443
+
444
+ .tips-card:hover {
445
+ background: rgba(241, 245, 249, 0.8);
446
+ transform: translateX(2px);
447
+ }
448
+
449
+ .tips-content {
450
+ font-size: 0.8rem;
451
+ color: var(--text-tips);
452
+ line-height: 1.5;
453
+ }
454
+
455
+ /* Key hints */
456
+ .key-hint {
457
+ background: var(--kbd-bg);
458
+ color: var(--kbd-text);
459
+ padding: 0.1rem 0.3rem;
460
+ border-radius: 3px;
461
+ font-size: 0.75em;
462
+ border: 1px solid var(--kbd-border);
463
+ font-family: monospace;
464
+ font-weight: 500;
465
+ transition: all 0.2s ease;
466
+ }
467
+
468
+ .key-hint:hover {
469
+ background: var(--primary-blue);
470
+ color: white;
471
+ border-color: var(--primary-blue);
472
+ }
473
+
474
+ /* Final message */
475
+ .final-message {
476
+ padding: 0.8rem;
477
+ background: var(--bg-final);
478
+ border-radius: 8px;
479
+ text-align: center;
480
+ transition: all 0.3s ease;
481
+ }
482
+
483
+ .final-message:hover {
484
+ transform: translateY(-1px);
485
+ box-shadow: 0 4px 12px rgba(14, 165, 233, 0.1);
486
+ }
487
+
488
+ .final-text {
489
+ color: var(--text-final);
490
+ font-weight: 600;
491
+ font-size: 0.85rem;
492
+ }
493
+
494
+ /* Legacy header badge styling for backward compatibility */
495
+ .header-badge {
496
+ background: var(--primary-gradient);
497
+ color: white;
498
+ padding: 0.5rem 1rem;
499
+ border-radius: 6px;
500
+ font-weight: 500;
501
+ font-size: 0.9rem;
502
+ transition: all 0.3s ease;
503
+ box-shadow: 0 1px 3px rgba(14, 165, 233, 0.2);
504
+ display: inline-block;
505
+ }
506
+
507
+ .header-badge:hover {
508
+ transform: translateY(-2px);
509
+ box-shadow: 0 4px 8px rgba(14, 165, 233, 0.3);
510
+ text-decoration: none;
511
+ }
512
+
513
+ /* Accordion styling matching getting_started */
514
+ .gradio-accordion {
515
+ border: 1px solid rgba(14, 165, 233, 0.2);
516
+ border-radius: 8px;
517
+ overflow: visible !important; /* Allow dropdown to overflow */
518
+ box-shadow: 0 2px 8px rgba(14, 165, 233, 0.1);
519
+ background: linear-gradient(135deg, rgba(255, 255, 255, 0.98) 0%, rgba(240, 249, 255, 0.95) 100%);
520
+ }
521
+
522
+ /* Ensure accordion content area allows dropdown overflow */
523
+ .gradio-accordion .wrap {
524
+ overflow: visible !important;
525
+ }
526
+
527
+ .gradio-accordion > .label-wrap {
528
+ background: linear-gradient(135deg, rgba(255, 255, 255, 0.98) 0%, rgba(240, 249, 255, 0.95) 100%);
529
+ border-bottom: 1px solid rgba(14, 165, 233, 0.2);
530
+ padding: 1rem 1.5rem;
531
+ font-weight: 600;
532
+ color: var(--text-primary);
533
+ }
534
+
535
+ /* Minimal dropdown styling - let Gradio handle positioning naturally */
536
+ #aspect_ratio_dropdown {
537
+ border-radius: 8px;
538
+ }
539
+
540
+ /* COMPLETELY REMOVE all dropdown styling - let Gradio handle everything */
541
+ /* This was causing the dropdown to display as a text block instead of options */
542
+
543
+ /* DO NOT style .gradio-dropdown globally - causes functionality issues */
544
+
545
+ /* Slider styling matching theme */
546
+ .gradio-slider {
547
+ background: linear-gradient(135deg, rgba(255, 255, 255, 0.98) 0%, rgba(240, 249, 255, 0.95) 100%) !important;
548
+ border: 1px solid rgba(14, 165, 233, 0.2) !important;
549
+ border-radius: 8px !important;
550
+ box-shadow: 0 2px 8px rgba(14, 165, 233, 0.1) !important;
551
+ padding: 12px !important;
552
+ }
553
+
554
+ .gradio-slider:hover {
555
+ border-color: rgba(14, 165, 233, 0.3) !important;
556
+ box-shadow: 0 2px 8px rgba(14, 165, 233, 0.15) !important;
557
+ }
558
+
559
+ /* Slider input styling */
560
+ .gradio-slider input[type="range"] {
561
+ background: transparent !important;
562
+ }
563
+
564
+ .gradio-slider input[type="range"]::-webkit-slider-track {
565
+ background: rgba(14, 165, 233, 0.2) !important;
566
+ border-radius: 4px !important;
567
+ }
568
+
569
+ .gradio-slider input[type="range"]::-webkit-slider-thumb {
570
+ background: var(--primary-blue) !important;
571
+ border: 2px solid white !important;
572
+ box-shadow: 0 2px 8px rgba(14, 165, 233, 0.3) !important;
573
+ }
574
+
575
+ /* Checkbox styling matching theme */
576
+ .gradio-checkbox {
577
+ background: linear-gradient(135deg, rgba(255, 255, 255, 0.98) 0%, rgba(240, 249, 255, 0.95) 100%) !important;
578
+ border: 1px solid rgba(14, 165, 233, 0.2) !important;
579
+ border-radius: 8px !important;
580
+ box-shadow: 0 2px 8px rgba(14, 165, 233, 0.1) !important;
581
+ padding: 8px 12px !important;
582
+ }
583
+
584
+ /* Specific styling for identified components */
585
+ #aspect_ratio_dropdown,
586
+ #text_prompt,
587
+ #move_to_center_checkbox,
588
+ #use_bg_preservation_checkbox {
589
+ background: linear-gradient(135deg, rgba(255, 255, 255, 0.98) 0%, rgba(240, 249, 255, 0.95) 100%) !important;
590
+ border: 1px solid rgba(14, 165, 233, 0.2) !important;
591
+ border-radius: 8px !important;
592
+ box-shadow: 0 2px 8px rgba(14, 165, 233, 0.1) !important;
593
+ }
594
+
595
+ /* Removed specific aspect_ratio_dropdown styling to avoid conflicts */
596
+
597
+ #aspect_ratio_dropdown:hover,
598
+ #text_prompt:hover,
599
+ #move_to_center_checkbox:hover,
600
+ #use_bg_preservation_checkbox:hover {
601
+ border-color: rgba(14, 165, 233, 0.3) !important;
602
+ box-shadow: 0 2px 8px rgba(14, 165, 233, 0.15) !important;
603
+ }
604
+
605
+ /* Textbox specific styling */
606
+ #text_prompt textarea {
607
+ background: linear-gradient(135deg, rgba(255, 255, 255, 0.98) 0%, rgba(240, 249, 255, 0.95) 100%) !important;
608
+ border: 1px solid rgba(14, 165, 233, 0.2) !important;
609
+ border-radius: 8px !important;
610
+ box-shadow: 0 2px 8px rgba(14, 165, 233, 0.1) !important;
611
+ }
612
+
613
+ /* Color variables matching getting_started section exactly */
614
+ :root {
615
+ /* Primary colors from getting_started */
616
+ --primary-blue: #0ea5e9;
617
+ --primary-blue-secondary: #06b6d4;
618
+ --primary-green: #10b981;
619
+ --primary-gradient: linear-gradient(135deg, #0ea5e9 0%, #06b6d4 50%, #10b981 100%);
620
+
621
+ /* Mode-specific colors from getting_started */
622
+ --position-aware-blue: #3b82f6;
623
+ --position-free-purple: #8b5cf6;
624
+
625
+ /* Badge colors from getting_started */
626
+ --badge-blue-bg: #dbeafe;
627
+ --badge-blue-text: #1e40af;
628
+ --badge-green-bg: #dcfce7;
629
+ --badge-green-text: #166534;
630
+
631
+ /* Neutral colors from getting_started */
632
+ --neutral-50: #f8fafc;
633
+ --neutral-100: #f1f5f9;
634
+ --neutral-200: #e2e8f0;
635
+ --neutral-300: #cbd5e1;
636
+ --neutral-400: #94a3b8;
637
+ --neutral-500: #64748b;
638
+ --neutral-600: #475569;
639
+ --neutral-700: #334155;
640
+ --neutral-800: #1e293b;
641
+
642
+ /* Text colors from getting_started */
643
+ --text-primary: #1e293b;
644
+ --text-secondary: #4b5563;
645
+ --text-muted: #64748b;
646
+ --text-tips: #475569;
647
+ --text-final: #0c4a6e;
648
+
649
+ /* Background colors from getting_started */
650
+ --bg-primary: white;
651
+ --bg-secondary: #f8fafc;
652
+ --bg-tips: #f1f5f9;
653
+ --bg-final: linear-gradient(135deg, #f0f9ff 0%, #e0f2fe 100%);
654
+
655
+ /* Keyboard hint styles from getting_started */
656
+ --kbd-bg: #e2e8f0;
657
+ --kbd-text: #475569;
658
+ --kbd-border: #cbd5e1;
659
+ }
660
+
661
+ /* Global smooth transitions - exclude dropdowns */
662
+ *:not(.gradio-dropdown):not(.gradio-dropdown *) {
663
+ transition: all 0.2s ease;
664
+ }
665
+
666
+ /* Focus states using getting_started primary blue - exclude dropdowns */
667
+ button:focus,
668
+ input:not(.gradio-dropdown input):focus,
669
+ select:not(.gradio-dropdown select):focus,
670
+ textarea:not(.gradio-dropdown textarea):focus {
671
+ outline: none;
672
+ box-shadow: 0 0 0 2px rgba(14, 165, 233, 0.3);
673
+ }
674
+
675
+ /* Subtle hover effects for interactive elements - exclude dropdowns */
676
+ button:not(.gradio-dropdown button):hover {
677
+ transform: translateY(-1px);
678
+ }
679
+
680
+ /* Global text styling matching getting_started */
681
+ body {
682
+ font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
683
+ line-height: 1.6;
684
+ color: var(--text-secondary);
685
+ background-color: var(--bg-secondary);
686
+ }
687
+
688
+ /* Enhanced form element styling - exclude dropdowns from global styling */
689
+ input:not(.gradio-dropdown input),
690
+ textarea:not(.gradio-dropdown textarea),
691
+ select:not(.gradio-dropdown select) {
692
+ border-radius: 8px;
693
+ border: 1px solid rgba(14, 165, 233, 0.2);
694
+ background: linear-gradient(135deg, rgba(255, 255, 255, 0.98) 0%, rgba(240, 249, 255, 0.95) 100%);
695
+ backdrop-filter: blur(8px);
696
+ -webkit-backdrop-filter: blur(8px);
697
+ color: var(--text-primary);
698
+ transition: all 0.3s ease;
699
+ padding: 12px 16px;
700
+ font-size: 0.95rem;
701
+ box-shadow: 0 2px 8px rgba(14, 165, 233, 0.1);
702
+ }
703
+
704
+ input:not(.gradio-dropdown input):focus,
705
+ textarea:not(.gradio-dropdown textarea):focus,
706
+ select:not(.gradio-dropdown select):focus {
707
+ border-color: var(--primary-blue);
708
+ box-shadow: 0 0 0 3px rgba(14, 165, 233, 0.1), 0 2px 8px rgba(14, 165, 233, 0.15);
709
+ background: linear-gradient(135deg, rgba(255, 255, 255, 1) 0%, rgba(240, 249, 255, 0.98) 100%);
710
+ outline: none;
711
+ transform: translateY(-1px);
712
+ }
713
+
714
+ input:not(.gradio-dropdown input):hover,
715
+ textarea:not(.gradio-dropdown textarea):hover,
716
+ select:not(.gradio-dropdown select):hover {
717
+ border-color: rgba(14, 165, 233, 0.3);
718
+ box-shadow: 0 2px 8px rgba(14, 165, 233, 0.12);
719
+ }
720
+
721
+ /* Textbox specific styling */
722
+ .gradio-textbox {
723
+ border-radius: 12px;
724
+ overflow: hidden;
725
+ }
726
+
727
+ .gradio-textbox textarea {
728
+ border-radius: 12px;
729
+ resize: vertical;
730
+ min-height: 44px;
731
+ }
732
+
733
+ /* Scrollbar styling matching getting_started */
734
+ ::-webkit-scrollbar {
735
+ width: 8px;
736
+ }
737
+
738
+ ::-webkit-scrollbar-track {
739
+ background: var(--neutral-100);
740
+ border-radius: 4px;
741
+ }
742
+
743
+ ::-webkit-scrollbar-thumb {
744
+ background: var(--neutral-400);
745
+ border-radius: 4px;
746
+ }
747
+
748
+ ::-webkit-scrollbar-thumb:hover {
749
+ background: var(--primary-blue);
750
+ }
751
+
752
+ /* Enhanced button styling with Apple-style refinement */
753
+ button {
754
+ border-radius: 8px;
755
+ font-weight: 500;
756
+ cursor: pointer;
757
+ transition: all 0.3s ease;
758
+ border: 1px solid var(--neutral-200);
759
+ position: relative;
760
+ overflow: hidden;
761
+ backdrop-filter: blur(8px);
762
+ -webkit-backdrop-filter: blur(8px);
763
+ }
764
+
765
+ /* Button hover glow effect */
766
+ button::after {
767
+ content: '';
768
+ position: absolute;
769
+ top: 50%;
770
+ left: 50%;
771
+ width: 0;
772
+ height: 0;
773
+ background: radial-gradient(circle, rgba(14, 165, 233, 0.1) 0%, transparent 70%);
774
+ transition: all 0.4s ease;
775
+ transform: translate(-50%, -50%);
776
+ pointer-events: none;
777
+ }
778
+
779
+ button:hover::after {
780
+ width: 200px;
781
+ height: 200px;
782
+ }
783
+
784
+ /* Primary button using unified primary blue */
785
+ button[variant="primary"] {
786
+ background: var(--primary-blue);
787
+ border-color: var(--primary-blue);
788
+ color: white;
789
+ box-shadow: 0 2px 6px rgba(14, 165, 233, 0.2);
790
+ }
791
+
792
+ button[variant="primary"]:hover {
793
+ background: #0284c7;
794
+ border-color: #0284c7;
795
+ box-shadow: 0 4px 12px rgba(14, 165, 233, 0.3);
796
+ transform: translateY(-1px);
797
+ }
798
+
799
+ /* Secondary buttons */
800
+ button[variant="secondary"], .secondary-button {
801
+ background: linear-gradient(135deg, rgba(255, 255, 255, 0.98) 0%, rgba(240, 249, 255, 0.95) 100%);
802
+ border: 1px solid rgba(14, 165, 233, 0.2);
803
+ color: var(--text-secondary);
804
+ box-shadow: 0 2px 8px rgba(14, 165, 233, 0.1);
805
+ }
806
+
807
+ button[variant="secondary"]:hover, .secondary-button:hover {
808
+ background: var(--primary-blue);
809
+ border-color: var(--primary-blue);
810
+ color: white;
811
+ transform: translateY(-1px);
812
+ box-shadow: 0 4px 12px rgba(14, 165, 233, 0.25);
813
+ }
814
+
815
+ /* VLM buttons with subtle, elegant styling */
816
+ #vlm_generate_btn, #vlm_polish_btn {
817
+ background: linear-gradient(135deg, rgba(255, 255, 255, 0.98) 0%, rgba(240, 249, 255, 0.95) 100%) !important;
818
+ border: 1px solid rgba(14, 165, 233, 0.2) !important;
819
+ color: var(--text-secondary) !important;
820
+ box-shadow: 0 2px 8px rgba(14, 165, 233, 0.1);
821
+ font-weight: 500;
822
+ border-radius: 8px;
823
+ position: relative;
824
+ overflow: hidden;
825
+ transition: all 0.3s ease;
826
+ backdrop-filter: blur(8px);
827
+ -webkit-backdrop-filter: blur(8px);
828
+ }
829
+
830
+ #vlm_generate_btn::before, #vlm_polish_btn::before {
831
+ content: '';
832
+ position: absolute;
833
+ top: 0;
834
+ left: -100%;
835
+ width: 100%;
836
+ height: 100%;
837
+ background: linear-gradient(90deg, transparent, rgba(14, 165, 233, 0.1), transparent);
838
+ transition: left 0.5s ease;
839
+ }
840
+
841
+ #vlm_generate_btn:hover::before, #vlm_polish_btn:hover::before {
842
+ left: 100%;
843
+ }
844
+
845
+ #vlm_generate_btn:hover, #vlm_polish_btn:hover {
846
+ background: var(--primary-blue) !important;
847
+ border-color: var(--primary-blue) !important;
848
+ color: white !important;
849
+ box-shadow: 0 4px 12px rgba(14, 165, 233, 0.25);
850
+ transform: translateY(-1px);
851
+ }
852
+
853
+ #vlm_generate_btn:active, #vlm_polish_btn:active {
854
+ transform: translateY(0px);
855
+ box-shadow: 0 2px 6px rgba(14, 165, 233, 0.2);
856
+ }
857
+
858
+ /* Enhanced image styling with fixed dimensions for consistency */
859
+ .gradio-image, .gradio-imageeditor {
860
+ height: 300px !important;
861
+ width: 100% !important;
862
+ padding: 0 !important;
863
+ margin: 0 !important;
864
+ }
865
+
866
+ .gradio-image img,
867
+ .gradio-imageeditor img {
868
+ height: 300px !important;
869
+ width: 100% !important;
870
+ object-fit: contain !important;
871
+ border-radius: 8px;
872
+ transition: all 0.3s ease;
873
+ box-shadow: 0 2px 8px rgba(15, 23, 42, 0.1);
874
+ }
875
+
876
+ .gradio-image img:hover,
877
+ .gradio-imageeditor img:hover {
878
+ transform: scale(1.02);
879
+ box-shadow: 0 4px 16px rgba(15, 23, 42, 0.15);
880
+ }
881
+
882
+ /* Gallery CSS - contained adaptive layout with theme colors */
883
+ #mask_gallery, #result_gallery, .custom-gallery {
884
+ overflow: visible !important; /* Allow progress indicator to show */
885
+ position: relative !important;
886
+ width: 100% !important;
887
+ height: auto !important;
888
+ max-height: 75vh !important;
889
+ min-height: 300px !important;
890
+ display: flex !important;
891
+ flex-direction: column !important;
892
+ padding: 6px !important;
893
+ margin: 0 !important;
894
+ border-radius: 12px !important;
895
+ background: linear-gradient(135deg, rgba(255, 255, 255, 0.98) 0%, rgba(240, 249, 255, 0.95) 100%) !important;
896
+ border: 1px solid rgba(14, 165, 233, 0.2) !important;
897
+ box-shadow: 0 2px 8px rgba(14, 165, 233, 0.1) !important;
898
+ }
899
+
900
+ /* Gallery containers with contained but flexible display */
901
+ #mask_gallery .gradio-gallery, #result_gallery .gradio-gallery {
902
+ width: 100% !important;
903
+ height: auto !important;
904
+ min-height: 280px !important;
905
+ max-height: 70vh !important;
906
+ padding: 8px !important;
907
+ overflow: auto !important;
908
+ display: flex !important;
909
+ flex-direction: column !important;
910
+ border-radius: 8px !important;
911
+ }
912
+
913
+ /* Only hide specific duplicate elements that cause the problem */
914
+ #mask_gallery > div > div:nth-child(n+2),
915
+ #result_gallery > div > div:nth-child(n+2) {
916
+ display: none !important;
917
+ }
918
+
919
+ /* Alternative: hide duplicate grid structures only */
920
+ #mask_gallery .gradio-gallery:nth-child(n+2),
921
+ #result_gallery .gradio-gallery:nth-child(n+2) {
922
+ display: none !important;
923
+ }
924
+
925
+ /* Ensure timing and status elements are NOT hidden by the above rules */
926
+ #result_gallery .status,
927
+ #result_gallery .timer,
928
+ #result_gallery [class*="time"],
929
+ #result_gallery [class*="status"],
930
+ #result_gallery [class*="duration"],
931
+ #result_gallery .gradio-status,
932
+ #result_gallery .gradio-timer,
933
+ #result_gallery .gradio-info,
934
+ #result_gallery [data-testid*="timer"],
935
+ #result_gallery [data-testid*="status"] {
936
+ display: block !important;
937
+ visibility: visible !important;
938
+ opacity: 1 !important;
939
+ position: relative !important;
940
+ z-index: 1000 !important;
941
+ }
942
+
943
+ /* Gallery images - contained adaptive display */
944
+ #mask_gallery img, #result_gallery img {
945
+ width: 100% !important;
946
+ height: auto !important;
947
+ max-width: 100% !important;
948
+ max-height: 60vh !important;
949
+ object-fit: contain !important;
950
+ border-radius: 8px;
951
+ box-shadow: 0 2px 8px rgba(15, 23, 42, 0.1);
952
+ display: block !important;
953
+ margin: 0 auto !important;
954
+ }
955
+
956
+ /* Main preview image styling - contained but responsive */
957
+ #mask_gallery .preview-image, #result_gallery .preview-image {
958
+ width: 100% !important;
959
+ height: auto !important;
960
+ max-width: 100% !important;
961
+ max-height: 55vh !important;
962
+ border-radius: 12px;
963
+ box-shadow: 0 4px 16px rgba(15, 23, 42, 0.15);
964
+ object-fit: contain !important;
965
+ display: block !important;
966
+ margin: 0 auto !important;
967
+ }
968
+
969
+ /* Gallery content wrappers - ensure no height constraints */
970
+ #mask_gallery .gradio-gallery > div,
971
+ #result_gallery .gradio-gallery > div {
972
+ width: 100% !important;
973
+ height: auto !important;
974
+ min-height: auto !important;
975
+ max-height: none !important;
976
+ overflow: visible !important;
977
+ }
978
+
979
+ /* Gallery image containers - remove any height limits */
980
+ #mask_gallery .image-container,
981
+ #result_gallery .image-container,
982
+ #mask_gallery [data-testid="image"],
983
+ #result_gallery [data-testid="image"] {
984
+ width: 100% !important;
985
+ height: auto !important;
986
+ max-height: none !important;
987
+ overflow: visible !important;
988
+ }
989
+
990
+ /* Controlled gallery wrapper elements */
991
+ #mask_gallery .image-wrapper,
992
+ #result_gallery .image-wrapper {
993
+ max-height: 60vh !important;
994
+ overflow: hidden !important;
995
+ }
996
+
997
+ /* Specific targeting for Gradio's internal gallery elements */
998
+ #mask_gallery .grid-wrap,
999
+ #result_gallery .grid-wrap,
1000
+ #mask_gallery .preview-wrap,
1001
+ #result_gallery .preview-wrap {
1002
+ height: auto !important;
1003
+ max-height: 65vh !important;
1004
+ overflow: auto !important;
1005
+ border-radius: 8px !important;
1006
+ }
1007
+
1008
+ /* Ensure gallery grids are properly sized within container */
1009
+ #mask_gallery .grid,
1010
+ #result_gallery .grid {
1011
+ height: auto !important;
1012
+ max-height: 60vh !important;
1013
+ display: grid !important;
1014
+ grid-template-columns: repeat(auto-fit, minmax(120px, 1fr)) !important;
1015
+ gap: 8px !important;
1016
+ align-items: start !important;
1017
+ overflow: auto !important;
1018
+ padding: 4px !important;
1019
+ }
1020
+
1021
+ /* Custom scrollbar for gallery */
1022
+ #mask_gallery .gradio-gallery::-webkit-scrollbar,
1023
+ #result_gallery .gradio-gallery::-webkit-scrollbar,
1024
+ #mask_gallery .grid::-webkit-scrollbar,
1025
+ #result_gallery .grid::-webkit-scrollbar {
1026
+ width: 6px;
1027
+ height: 6px;
1028
+ }
1029
+
1030
+ #mask_gallery .gradio-gallery::-webkit-scrollbar-track,
1031
+ #result_gallery .gradio-gallery::-webkit-scrollbar-track,
1032
+ #mask_gallery .grid::-webkit-scrollbar-track,
1033
+ #result_gallery .grid::-webkit-scrollbar-track {
1034
+ background: rgba(0, 0, 0, 0.1);
1035
+ border-radius: 3px;
1036
+ }
1037
+
1038
+ #mask_gallery .gradio-gallery::-webkit-scrollbar-thumb,
1039
+ #result_gallery .gradio-gallery::-webkit-scrollbar-thumb,
1040
+ #mask_gallery .grid::-webkit-scrollbar-thumb,
1041
+ #result_gallery .grid::-webkit-scrollbar-thumb {
1042
+ background: var(--neutral-400);
1043
+ border-radius: 3px;
1044
+ }
1045
+
1046
+ #mask_gallery .gradio-gallery::-webkit-scrollbar-thumb:hover,
1047
+ #result_gallery .gradio-gallery::-webkit-scrollbar-thumb:hover,
1048
+ #mask_gallery .grid::-webkit-scrollbar-thumb:hover,
1049
+ #result_gallery .grid::-webkit-scrollbar-thumb:hover {
1050
+ background: var(--primary-blue);
1051
+ }
1052
+
1053
+ /* Thumbnail navigation styling in preview mode */
1054
+ #mask_gallery .thumbnail, #result_gallery .thumbnail {
1055
+ opacity: 0.7;
1056
+ transition: opacity 0.3s ease;
1057
+ border-radius: 6px;
1058
+ }
1059
+
1060
+ #mask_gallery .thumbnail:hover, #result_gallery .thumbnail:hover,
1061
+ #mask_gallery .thumbnail.selected, #result_gallery .thumbnail.selected {
1062
+ opacity: 1;
1063
+ box-shadow: 0 2px 8px rgba(14, 165, 233, 0.3);
1064
+ }
1065
+
1066
+ /* Improved layout spacing and organization */
1067
+ #glass_card .gradio-row {
1068
+ gap: 16px !important;
1069
+ margin-bottom: 6px !important;
1070
+ }
1071
+
1072
+ #glass_card .gradio-column {
1073
+ gap: 10px !important;
1074
+ }
1075
+
1076
+ /* Better section spacing with theme colors */
1077
+ #glass_card .gradio-group {
1078
+ margin-bottom: 6px !important;
1079
+ padding: 10px !important;
1080
+ border-radius: 8px !important;
1081
+ background: linear-gradient(135deg, rgba(255, 255, 255, 0.95) 0%, rgba(240, 249, 255, 0.9) 100%) !important;
1082
+ border: 1px solid rgba(14, 165, 233, 0.15) !important;
1083
+ box-shadow: 0 1px 3px rgba(14, 165, 233, 0.05) !important;
1084
+ transition: all 0.3s ease !important;
1085
+ overflow: visible !important; /* Allow dropdown to overflow */
1086
+ }
1087
+
1088
+ #glass_card .gradio-group:hover {
1089
+ background: linear-gradient(135deg, rgba(255, 255, 255, 0.98) 0%, rgba(240, 249, 255, 0.95) 100%) !important;
1090
+ border-color: rgba(14, 165, 233, 0.25) !important;
1091
+ box-shadow: 0 2px 8px rgba(14, 165, 233, 0.12) !important;
1092
+ /* transform removed to prevent layout shift that hides dropdown */
1093
+ }
1094
+
1095
+ /* Enhanced button styling for improved UX */
1096
+ button[variant="secondary"] {
1097
+ background: linear-gradient(135deg, rgba(255, 255, 255, 0.98) 0%, rgba(240, 249, 255, 0.95) 100%) !important;
1098
+ border: 1px solid rgba(14, 165, 233, 0.2) !important;
1099
+ color: var(--text-secondary) !important;
1100
+ box-shadow: 0 2px 8px rgba(14, 165, 233, 0.1) !important;
1101
+ transition: all 0.3s ease !important;
1102
+ }
1103
+
1104
+ button[variant="secondary"]:hover {
1105
+ background: var(--primary-blue) !important;
1106
+ border-color: var(--primary-blue) !important;
1107
+ color: white !important;
1108
+ transform: translateY(-1px) !important;
1109
+ box-shadow: 0 4px 12px rgba(14, 165, 233, 0.25) !important;
1110
+ }
1111
+
1112
+ /* Markdown header improvements */
1113
+ .gradio-markdown h1, .gradio-markdown h2, .gradio-markdown h3 {
1114
+ color: var(--text-primary) !important;
1115
+ font-weight: 600 !important;
1116
+ margin-bottom: 8px !important;
1117
+ margin-top: 4px !important;
1118
+ }
1119
+
1120
+ /* Radio button container improvements */
1121
+ #customization_mode_radio, #input_mask_mode_radio, #seg_ref_mode_radio {
1122
+ margin-bottom: 0px !important;
1123
+ margin-top: 0px !important;
1124
+ }
1125
+
1126
+ /* Reduce space between markdown headers and subsequent components */
1127
+ .gradio-markdown + .gradio-group {
1128
+ margin-top: 1px !important;
1129
+ }
1130
+
1131
+ .gradio-markdown + .gradio-image,
1132
+ .gradio-markdown + .gradio-imageeditor,
1133
+ .gradio-markdown + .gradio-textbox,
1134
+ .gradio-markdown + .gradio-gallery {
1135
+ margin-top: 1px !important;
1136
+ }
1137
+
1138
+ /* Specific spacing adjustments for numbered sections */
1139
+ .gradio-markdown:has(h1), .gradio-markdown:has(h2), .gradio-markdown:has(h3) {
1140
+ margin-bottom: 2px !important;
1141
+ }
1142
+
1143
+ /* Remove padding from image and gallery containers */
1144
+ .gradio-image, .gradio-imageeditor, .gradio-gallery {
1145
+ padding: 0 !important;
1146
+ margin: 0 !important;
1147
+ }
1148
+
1149
+ /* Image container styling with theme colors */
1150
+ .gradio-image, .gradio-imageeditor {
1151
+ border-radius: 12px;
1152
+ overflow: hidden;
1153
+ background: linear-gradient(135deg, rgba(255, 255, 255, 0.98) 0%, rgba(240, 249, 255, 0.95) 100%) !important;
1154
+ border: 1px solid rgba(14, 165, 233, 0.2) !important;
1155
+ box-shadow: 0 2px 8px rgba(14, 165, 233, 0.1) !important;
1156
+ transition: all 0.3s ease;
1157
+ }
1158
+
1159
+ .gradio-image:hover, .gradio-imageeditor:hover {
1160
+ border-color: var(--primary-blue) !important;
1161
+ box-shadow: 0 4px 16px rgba(14, 165, 233, 0.15) !important;
1162
+ transform: translateY(-1px);
1163
+ }
1164
+
1165
+ /* Image upload area styling */
1166
+ .gradio-image .upload-container,
1167
+ .gradio-imageeditor .upload-container,
1168
+ .gradio-image > div,
1169
+ .gradio-imageeditor > div {
1170
+ background: linear-gradient(135deg, rgba(255, 255, 255, 0.98) 0%, rgba(240, 249, 255, 0.95) 100%) !important;
1171
+ border: 1px solid rgba(14, 165, 233, 0.2) !important;
1172
+ border-radius: 12px !important;
1173
+ }
1174
+
1175
+ /* Image upload placeholder styling */
1176
+ .gradio-image .upload-text,
1177
+ .gradio-imageeditor .upload-text,
1178
+ .gradio-image [data-testid="upload-text"],
1179
+ .gradio-imageeditor [data-testid="upload-text"] {
1180
+ background: linear-gradient(135deg, rgba(255, 255, 255, 0.98) 0%, rgba(240, 249, 255, 0.95) 100%) !important;
1181
+ color: var(--text-secondary) !important;
1182
+ }
1183
+
1184
+ /* Image preview area */
1185
+ .gradio-image .image-container,
1186
+ .gradio-imageeditor .image-container,
1187
+ .gradio-image .preview-container,
1188
+ .gradio-imageeditor .preview-container {
1189
+ background: linear-gradient(135deg, rgba(255, 255, 255, 0.98) 0%, rgba(240, 249, 255, 0.95) 100%) !important;
1190
+ border-radius: 12px !important;
1191
+ }
1192
+
1193
+ /* Specific targeting for image upload areas */
1194
+ .gradio-image .wrap,
1195
+ .gradio-imageeditor .wrap,
1196
+ .gradio-image .block,
1197
+ .gradio-imageeditor .block {
1198
+ background: linear-gradient(135deg, rgba(255, 255, 255, 0.98) 0%, rgba(240, 249, 255, 0.95) 100%) !important;
1199
+ border: 1px solid rgba(14, 165, 233, 0.2) !important;
1200
+ border-radius: 12px !important;
1201
+ }
1202
+
1203
+ /* Image drop zone styling */
1204
+ .gradio-image .drop-zone,
1205
+ .gradio-imageeditor .drop-zone,
1206
+ .gradio-image .upload-area,
1207
+ .gradio-imageeditor .upload-area {
1208
+ background: linear-gradient(135deg, rgba(255, 255, 255, 0.98) 0%, rgba(240, 249, 255, 0.95) 100%) !important;
1209
+ border: 2px dashed rgba(14, 165, 233, 0.3) !important;
1210
+ border-radius: 12px !important;
1211
+ }
1212
+
1213
+ /* Force override any white backgrounds in image components */
1214
+ .gradio-image *,
1215
+ .gradio-imageeditor * {
1216
+ background-color: transparent !important;
1217
+ }
1218
+
1219
+ .gradio-image .gradio-image,
1220
+ .gradio-imageeditor .gradio-imageeditor {
1221
+ background: linear-gradient(135deg, rgba(255, 255, 255, 0.98) 0%, rgba(240, 249, 255, 0.95) 100%) !important;
1222
+ }
1223
+
1224
+ /* Specific styling for Reference Image and Target Images */
1225
+ #reference_image,
1226
+ #target_image_1,
1227
+ #target_image_2 {
1228
+ background: linear-gradient(135deg, rgba(255, 255, 255, 0.98) 0%, rgba(240, 249, 255, 0.95) 100%) !important;
1229
+ border: 1px solid rgba(14, 165, 233, 0.2) !important;
1230
+ box-shadow: 0 2px 8px rgba(14, 165, 233, 0.1) !important;
1231
+ border-radius: 12px !important;
1232
+ }
1233
+
1234
+ #reference_image *,
1235
+ #target_image_1 *,
1236
+ #target_image_2 * {
1237
+ background: linear-gradient(135deg, rgba(255, 255, 255, 0.98) 0%, rgba(240, 249, 255, 0.95) 100%) !important;
1238
+ border-radius: 12px !important;
1239
+ }
1240
+
1241
+ /* Upload area for specific image components */
1242
+ #reference_image .upload-container,
1243
+ #target_image_1 .upload-container,
1244
+ #target_image_2 .upload-container,
1245
+ #reference_image .drop-zone,
1246
+ #target_image_1 .drop-zone,
1247
+ #target_image_2 .drop-zone {
1248
+ background: linear-gradient(135deg, rgba(255, 255, 255, 0.98) 0%, rgba(240, 249, 255, 0.95) 100%) !important;
1249
+ border: 2px dashed rgba(14, 165, 233, 0.3) !important;
1250
+ border-radius: 12px !important;
1251
+ }
1252
+
1253
+ /* Hover effects for specific image components */
1254
+ #reference_image:hover,
1255
+ #target_image_1:hover,
1256
+ #target_image_2:hover {
1257
+ border-color: var(--primary-blue) !important;
1258
+ box-shadow: 0 4px 16px rgba(14, 165, 233, 0.15) !important;
1259
+ transform: translateY(-1px);
1260
+ }
1261
+
1262
+ /* Group styling matching getting_started white cards */
1263
+ .group, .gradio-group {
1264
+ border-radius: 8px;
1265
+ background: var(--bg-primary);
1266
+ border: 1px solid var(--neutral-200);
1267
+ box-shadow: 0 1px 3px rgba(0, 0, 0, 0.05);
1268
+ }
1269
+
1270
+ /* Subtle page background with theme colors */
1271
+ body, .gradio-container {
1272
+ background: linear-gradient(135deg, #f8fafc 0%, #f0f9ff 50%, #f8fafc 100%);
1273
+ min-height: 100vh;
1274
+ }
1275
+
1276
+ /* Global glass container with subtle Apple-style gradient */
1277
+ #global_glass_container {
1278
+ position: relative;
1279
+ border-radius: 20px;
1280
+ padding: 16px;
1281
+ margin: 12px auto;
1282
+ max-width: 1400px;
1283
+ background: linear-gradient(145deg,
1284
+ rgba(248, 250, 252, 0.98),
1285
+ rgba(241, 245, 249, 0.95));
1286
+ box-shadow:
1287
+ 0 20px 40px rgba(15, 23, 42, 0.08),
1288
+ 0 8px 24px rgba(15, 23, 42, 0.06),
1289
+ inset 0 1px 0 rgba(255, 255, 255, 0.9);
1290
+ border: 1px solid rgba(226, 232, 240, 0.7);
1291
+ transition: all 0.3s ease;
1292
+ overflow: visible !important; /* Allow dropdown to overflow */
1293
+ }
1294
+
1295
+ /* Subtle gradient overlay for Apple effect */
1296
+ #global_glass_container::after {
1297
+ content: '';
1298
+ position: absolute;
1299
+ top: 0;
1300
+ left: 0;
1301
+ right: 0;
1302
+ height: 250px;
1303
+ background: linear-gradient(135deg,
1304
+ rgba(14, 165, 233, 0.08) 0%,
1305
+ rgba(6, 182, 212, 0.06) 25%,
1306
+ rgba(16, 185, 129, 0.08) 50%,
1307
+ rgba(139, 92, 246, 0.06) 75%,
1308
+ rgba(14, 165, 233, 0.08) 100%);
1309
+ background-size: 300% 300%;
1310
+ animation: subtleGradientShift 15s ease-in-out infinite;
1311
+ pointer-events: none;
1312
+ z-index: 0;
1313
+ }
1314
+
1315
+ @keyframes subtleGradientShift {
1316
+ 0%, 100% {
1317
+ background-position: 0% 50%;
1318
+ opacity: 0.8;
1319
+ }
1320
+ 50% {
1321
+ background-position: 100% 50%;
1322
+ opacity: 1;
1323
+ }
1324
+ }
1325
+
1326
+ /* Ensure content is above the gradient overlay */
1327
+ #global_glass_container > * {
1328
+ position: relative;
1329
+ z-index: 1;
1330
+ }
1331
+
1332
+ /* Hover effect for global container - transform disabled to avoid dropdown reposition */
1333
+ #global_glass_container:hover {
1334
+ /* transform: translateY(-2px); */
1335
+ box-shadow:
1336
+ 0 25px 50px rgba(15, 23, 42, 0.08),
1337
+ 0 12px 30px rgba(15, 23, 42, 0.06),
1338
+ inset 0 1px 0 rgba(255, 255, 255, 0.9);
1339
+ border-color: rgba(226, 232, 240, 0.8);
1340
+ }
1341
+
1342
+ /* Subtle border highlight for global container */
1343
+ #global_glass_container::before {
1344
+ content: "";
1345
+ position: absolute;
1346
+ inset: 0;
1347
+ border-radius: 20px;
1348
+ padding: 1px;
1349
+ background: linear-gradient(135deg,
1350
+ rgba(255, 255, 255, 0.8),
1351
+ rgba(226, 232, 240, 0.4),
1352
+ rgba(255, 255, 255, 0.6),
1353
+ rgba(226, 232, 240, 0.3)
1354
+ );
1355
+ -webkit-mask:
1356
+ linear-gradient(#fff 0 0) content-box,
1357
+ linear-gradient(#fff 0 0);
1358
+ -webkit-mask-composite: xor;
1359
+ mask-composite: exclude;
1360
+ pointer-events: none;
1361
+ z-index: 0;
1362
+ }
1363
+
1364
+ /* Inner glassmorphism container with theme colors */
1365
+ #glass_card {
1366
+ position: relative;
1367
+ border-radius: 16px;
1368
+ padding: 16px;
1369
+ background: linear-gradient(135deg, rgba(255, 255, 255, 0.6) 0%, rgba(240, 249, 255, 0.5) 100%);
1370
+ box-shadow:
1371
+ 0 8px 24px rgba(14, 165, 233, 0.08),
1372
+ inset 0 1px 0 rgba(255, 255, 255, 0.7);
1373
+ border: 1px solid rgba(14, 165, 233, 0.2);
1374
+ margin-bottom: 12px;
1375
+ transition: all 0.3s ease;
1376
+ overflow: visible !important; /* Allow dropdown to overflow */
1377
+ }
1378
+
1379
+ #glass_card:hover {
1380
+ background: linear-gradient(135deg, rgba(255, 255, 255, 0.7) 0%, rgba(240, 249, 255, 0.6) 100%);
1381
+ border-color: rgba(14, 165, 233, 0.3);
1382
+ box-shadow:
1383
+ 0 12px 32px rgba(14, 165, 233, 0.12),
1384
+ inset 0 1px 0 rgba(255, 255, 255, 0.8);
1385
+ }
1386
+
1387
+ /* Subtle inner border gradient for liquid glass feel */
1388
+ #glass_card::before {
1389
+ content: "";
1390
+ position: absolute;
1391
+ inset: 0;
1392
+ border-radius: 16px;
1393
+ padding: 1px;
1394
+ background: linear-gradient(135deg,
1395
+ rgba(255, 255, 255, 0.6),
1396
+ rgba(226, 232, 240, 0.2));
1397
+ -webkit-mask:
1398
+ linear-gradient(#fff 0 0) content-box,
1399
+ linear-gradient(#fff 0 0);
1400
+ -webkit-mask-composite: xor;
1401
+ mask-composite: exclude;
1402
+ pointer-events: none;
1403
+ }
1404
+
1405
+ /* Preserve the airy layout inside the cards */
1406
+ #global_glass_container .gradio-column { gap: 12px; }
1407
+ #glass_card .gradio-row { gap: 16px; }
1408
+ #glass_card .gradio-column { gap: 12px; }
1409
+ #glass_card .gradio-group { margin: 8px 0; }
1410
+
1411
+ /* Text selection matching getting_started colors */
1412
+ ::selection {
1413
+ background: var(--badge-blue-bg);
1414
+ color: var(--badge-blue-text);
1415
+ }
1416
+
1417
+ /* Placeholder styling */
1418
+ ::placeholder {
1419
+ color: var(--text-muted);
1420
+ opacity: 0.8;
1421
+ }
1422
+
1423
+ /* Improved error state styling */
1424
+ .error {
1425
+ border-color: #ef4444 !important;
1426
+ box-shadow: 0 0 0 2px rgba(239, 68, 68, 0.1) !important;
1427
+ }
1428
+
1429
+ /* Success state using getting_started green */
1430
+ .success-state {
1431
+ border-color: var(--primary-green) !important;
1432
+ box-shadow: 0 0 0 2px rgba(16, 185, 129, 0.1) !important;
1433
+ }
1434
+
1435
+
1436
+
1437
+
1438
+
1439
+
1440
+
1441
+
1442
+
1443
+
1444
+
1445
+ /* Label styling */
1446
+ .gradio-label {
1447
+ color: var(--text-primary);
1448
+ font-weight: 600;
1449
+ }
1450
+
1451
+ /* Markdown content styling */
1452
+ .markdown-body {
1453
+ color: var(--text-secondary);
1454
+ line-height: 1.6;
1455
+ }
1456
+
1457
+ .markdown-body h1, .markdown-body h2, .markdown-body h3 {
1458
+ color: var(--text-primary);
1459
+ }
1460
+
1461
+ /* Step indicators styling */
1462
+ .gradio-markdown h1, .gradio-markdown h2, .gradio-markdown h3,
1463
+ .gradio-markdown p {
1464
+ margin: 0.25rem 0;
1465
+ }
1466
+
1467
+ /* Enhanced step indicators with numbers */
1468
+ .gradio-markdown:contains("1."), .gradio-markdown:contains("2."),
1469
+ .gradio-markdown:contains("3."), .gradio-markdown:contains("4."),
1470
+ .gradio-markdown:contains("5."), .gradio-markdown:contains("6."),
1471
+ .gradio-markdown:contains("7.") {
1472
+ position: relative;
1473
+ padding-left: 2.5rem;
1474
+ color: var(--text-primary);
1475
+ font-weight: 600;
1476
+ }
1477
+
1478
+ /* Specific button styling */
1479
+ #undo_btnSEG, #dilate_btn, #erode_btn, #bounding_box_btn {
1480
+ background: linear-gradient(135deg, rgba(255, 255, 255, 0.98) 0%, rgba(240, 249, 255, 0.95) 100%);
1481
+ border: 1px solid rgba(14, 165, 233, 0.2);
1482
+ color: var(--text-secondary);
1483
+ font-weight: 500;
1484
+ padding: 8px 16px;
1485
+ border-radius: 6px;
1486
+ box-shadow: 0 2px 8px rgba(14, 165, 233, 0.1);
1487
+ transition: all 0.3s ease;
1488
+ }
1489
+
1490
+ #undo_btnSEG:hover, #dilate_btn:hover, #erode_btn:hover, #bounding_box_btn:hover {
1491
+ background: var(--primary-blue);
1492
+ border-color: var(--primary-blue);
1493
+ color: white;
1494
+ transform: translateY(-1px);
1495
+ box-shadow: 0 4px 12px rgba(14, 165, 233, 0.3);
1496
+ }
1497
+
1498
+ /* Submit button enhanced styling - unified with primary blue */
1499
+ button[variant="primary"], .gradio-button.primary {
1500
+ background: var(--primary-blue);
1501
+ border-color: var(--primary-blue);
1502
+ color: white;
1503
+ font-weight: 600;
1504
+ font-size: 1rem;
1505
+ padding: 12px 24px;
1506
+ box-shadow: 0 4px 16px rgba(14, 165, 233, 0.25);
1507
+ transition: all 0.3s ease;
1508
+ }
1509
+
1510
+ button[variant="primary"]:hover, .gradio-button.primary:hover {
1511
+ background: #0284c7;
1512
+ border-color: #0284c7;
1513
+ box-shadow: 0 6px 20px rgba(14, 165, 233, 0.4);
1514
+ transform: translateY(-2px);
1515
+ }
1516
+
1517
+ /* Improved button states */
1518
+ button:disabled {
1519
+ opacity: 0.5;
1520
+ cursor: not-allowed;
1521
+ transform: none !important;
1522
+ box-shadow: none !important;
1523
+ }
1524
+
1525
+ button:disabled::after {
1526
+ display: none;
1527
+ }
1528
+
1529
+ button.processing {
1530
+ background: var(--neutral-400) !important;
1531
+ border-color: var(--neutral-400) !important;
1532
+ cursor: wait;
1533
+ animation: processingPulse 2s ease-in-out infinite;
1534
+ }
1535
+
1536
+ @keyframes processingPulse {
1537
+ 0%, 100% { opacity: 0.8; }
1538
+ 50% { opacity: 1; }
1539
+ }
1540
+
1541
+ /* Responsive improvements */
1542
+ @media (max-width: 768px) {
1543
+ .header-content {
1544
+ padding: 1.2rem 1.8rem;
1545
+ margin: 0 1rem;
1546
+ }
1547
+
1548
+ .main-title {
1549
+ font-size: 2rem;
1550
+ }
1551
+
1552
+ .title-icon {
1553
+ font-size: 1.8rem;
1554
+ }
1555
+
1556
+ .subtitle {
1557
+ font-size: 0.9rem;
1558
+ }
1559
+
1560
+ .header-badges {
1561
+ gap: 0.6rem;
1562
+ }
1563
+
1564
+ .badge-link {
1565
+ padding: 0.4rem 0.8rem;
1566
+ font-size: 0.85rem;
1567
+ }
1568
+
1569
+ .header-badge {
1570
+ padding: 0.4rem 0.8rem;
1571
+ font-size: 0.85rem;
1572
+ }
1573
+
1574
+ /* Getting Started responsive */
1575
+ .getting-started-container {
1576
+ padding: 1rem;
1577
+ margin: 0 0.5rem;
1578
+ }
1579
+
1580
+ .guide-title {
1581
+ font-size: 1.1rem;
1582
+ }
1583
+
1584
+ .guide-subtitle {
1585
+ font-size: 0.85rem;
1586
+ }
1587
+
1588
+ .step-card {
1589
+ padding: 0.8rem;
1590
+ margin-bottom: 1rem;
1591
+ }
1592
+
1593
+ .step-header {
1594
+ font-size: 0.9rem;
1595
+ }
1596
+
1597
+ .step-number {
1598
+ width: 22px;
1599
+ height: 22px;
1600
+ font-size: 0.7rem;
1601
+ }
1602
+
1603
+ .step-list {
1604
+ font-size: 0.8rem;
1605
+ padding-left: 1rem;
1606
+ }
1607
+
1608
+ .tips-card {
1609
+ padding: 0.6rem;
1610
+ }
1611
+
1612
+ .tips-content {
1613
+ font-size: 0.75rem;
1614
+ }
1615
+
1616
+ .final-message {
1617
+ padding: 0.6rem;
1618
+ }
1619
+
1620
+ .final-text {
1621
+ font-size: 0.8rem;
1622
+ }
1623
+
1624
+ button {
1625
+ min-height: 44px;
1626
+ }
1627
+
1628
+ input, textarea, select {
1629
+ min-height: 44px;
1630
+ }
1631
+
1632
+ /* Mobile optimization for subtle effects */
1633
+ #global_glass_container {
1634
+ padding: 16px;
1635
+ margin: 8px;
1636
+ border-radius: 16px;
1637
+ }
1638
+
1639
+ #global_glass_container::after {
1640
+ height: 180px;
1641
+ animation-duration: 18s;
1642
+ }
1643
+
1644
+ #glass_card {
1645
+ padding: 20px;
1646
+ margin: 10px;
1647
+ border-radius: 12px;
1648
+ }
1649
+
1650
+ #glass_card .gradio-row { gap: 12px; }
1651
+ #glass_card .gradio-column { gap: 12px; }
1652
+
1653
+
1654
+ }
1655
+
1656
+ /* Ensure gallery works properly in all screen sizes */
1657
+ @media (min-width: 1200px) {
1658
+ #mask_gallery .gradio-gallery, #result_gallery .gradio-gallery {
1659
+ min-height: 300px !important;
1660
+ max-height: 80vh !important;
1661
+ }
1662
+
1663
+ .responsive-gallery .grid-container {
1664
+ grid-template-columns: repeat(auto-fit, minmax(150px, 1fr)) !important;
1665
+ }
1666
+ }
1667
+
1668
+ /* Fix for gallery duplicate content issue - ensure clean display */
1669
+ #mask_gallery > div > div:nth-child(n+2),
1670
+ #result_gallery > div > div:nth-child(n+2) {
1671
+ display: none !important;
1672
+ }
1673
+
1674
+ #mask_gallery .gradio-gallery:nth-child(n+2),
1675
+ #result_gallery .gradio-gallery:nth-child(n+2) {
1676
+ display: none !important;
1677
+ }
1678
+
1679
+ """
app/ui_components.py ADDED
@@ -0,0 +1,354 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ UI components construction for IC-Custom application.
5
+ """
6
+ import gradio as gr
7
+ from constants import (
8
+ ASPECT_RATIO_LABELS,
9
+ DEFAULT_ASPECT_RATIO,
10
+ DEFAULT_BRUSH_SIZE
11
+ )
12
+
13
+
14
+ def create_theme():
15
+ """Create and configure the Gradio theme."""
16
+ theme = gr.themes.Ocean()
17
+ theme.set(
18
+ checkbox_label_background_fill_selected="*button_primary_background_fill",
19
+ checkbox_label_text_color_selected="*button_primary_text_color",
20
+ )
21
+ return theme
22
+
23
+
24
+ def create_css():
25
+ """Create custom CSS for the application."""
26
+ from stylesheets import get_css
27
+ return get_css()
28
+
29
+
30
+ def create_header_section():
31
+ """Create the header section with title and description."""
32
+ from metainfo import head, getting_started
33
+
34
+ with gr.Row():
35
+ gr.HTML(head)
36
+
37
+ with gr.Accordion(label="🚀 Getting Started:", open=True, elem_id="accordion"):
38
+ with gr.Row(equal_height=True):
39
+ gr.HTML(getting_started)
40
+
41
+
42
+ def create_customization_section():
43
+ """Create the customization mode selection section."""
44
+ with gr.Row():
45
+ # Add a note to remind users to click Clear before starting
46
+ md_custmization_mode = gr.Markdown(
47
+ "1. Select a Customization Mode\n\n*Tip: Please click the Clear button first to reset all states before starting a new task.*"
48
+ )
49
+ with gr.Row():
50
+ custmization_mode = gr.Radio(
51
+ ["Position-aware", "Position-free"],
52
+ value="Position-aware",
53
+ scale=1,
54
+ elem_id="customization_mode_radio",
55
+ show_label=False,
56
+ label="Customization Mode",
57
+ )
58
+ return custmization_mode, md_custmization_mode
59
+
60
+
61
+ def create_image_input_section():
62
+ """Create image input section optimized for left column layout."""
63
+ # Reference image section
64
+ md_image_reference = gr.Markdown("2. Input reference image")
65
+ with gr.Group():
66
+ image_reference = gr.Image(
67
+ label="Reference Image",
68
+ type="pil",
69
+ interactive=True,
70
+ height=320,
71
+ container=True,
72
+ elem_id="reference_image"
73
+ )
74
+
75
+ # Input mask mode selection
76
+ md_input_mask_mode = gr.Markdown("3. Select input mask mode")
77
+ with gr.Group():
78
+ input_mask_mode = gr.Radio(
79
+ ["Precise mask", "User-drawn mask"],
80
+ value="Precise mask",
81
+ elem_id="input_mask_mode_radio",
82
+ show_label=False,
83
+ label="Input Mask Mode",
84
+ )
85
+
86
+ # Target image section
87
+ md_target_image = gr.Markdown("4. Input target image & mask (Iterate clicking or brushing until the target is covered)")
88
+
89
+ # Precise mask mode
90
+ with gr.Group():
91
+ image_target_1 = gr.Image(
92
+ type="pil",
93
+ label="Target Image (precise mask)",
94
+ interactive=True,
95
+ visible=True,
96
+ height=500,
97
+ container=True,
98
+ elem_id="target_image_1"
99
+ )
100
+ with gr.Row():
101
+ undo_target_seg_button = gr.Button(
102
+ 'Undo seg',
103
+ elem_id="undo_btnSEG",
104
+ visible=True,
105
+ size="sm",
106
+ scale=1
107
+ )
108
+
109
+ # User-drawn mask mode
110
+ with gr.Group():
111
+ image_target_2 = gr.ImageEditor(
112
+ label="Target Image (user-drawn mask)",
113
+ type="pil",
114
+ brush=gr.Brush(colors=["#FFFFFF"], default_size=DEFAULT_BRUSH_SIZE, color_mode="fixed"),
115
+ layers=False,
116
+ interactive=True,
117
+ sources=["upload", "clipboard"],
118
+ placeholder="Please click here or the icon to upload the image.",
119
+ visible=False,
120
+ height=500,
121
+ container=True,
122
+ elem_id="target_image_2",
123
+ fixed_canvas=True,
124
+ )
125
+
126
+ return (image_reference, input_mask_mode, image_target_1, image_target_2,
127
+ undo_target_seg_button, md_image_reference, md_input_mask_mode, md_target_image)
128
+
129
+
130
+ def create_prompt_section():
131
+ """Create the text prompt input section with improved layout."""
132
+ md_prompt = gr.Markdown("5. Input text prompt (optional)")
133
+ with gr.Group():
134
+ prompt = gr.Textbox(
135
+ placeholder="Please input the description for the target scene.",
136
+ value="",
137
+ lines=2,
138
+ show_label=False,
139
+ label="Text Prompt",
140
+ container=True,
141
+ elem_id="text_prompt"
142
+ )
143
+
144
+ with gr.Row():
145
+ vlm_generate_btn = gr.Button(
146
+ "🤖 VLM Auto-generate",
147
+ scale=1,
148
+ elem_id="vlm_generate_btn",
149
+ variant="secondary"
150
+ )
151
+ vlm_polish_btn = gr.Button(
152
+ "✨ VLM Auto-polish",
153
+ scale=1,
154
+ elem_id="vlm_polish_btn",
155
+ variant="secondary"
156
+ )
157
+
158
+ return prompt, vlm_generate_btn, vlm_polish_btn, md_prompt
159
+
160
+
161
+ def create_advanced_options_section():
162
+ """Create the advanced options section."""
163
+ with gr.Accordion("Advanced Options", open=False, elem_id="accordion1"):
164
+ with gr.Group():
165
+ aspect_ratio = gr.Dropdown(
166
+ label="Output aspect ratio",
167
+ choices=ASPECT_RATIO_LABELS,
168
+ value=DEFAULT_ASPECT_RATIO,
169
+ interactive=True,
170
+ allow_custom_value=False,
171
+ filterable=False,
172
+ elem_id="aspect_ratio_dropdown"
173
+ )
174
+
175
+ with gr.Group():
176
+ seg_ref_mode = gr.Radio(
177
+ label="Segmentation mode",
178
+ choices=["Full Ref", "Masked Ref"],
179
+ value="Full Ref",
180
+ elem_id="seg_ref_mode_radio"
181
+ )
182
+ move_to_center = gr.Checkbox(label="Move object to center", value=False, elem_id="move_to_center_checkbox")
183
+
184
+ with gr.Group():
185
+ with gr.Row():
186
+ use_background_preservation = gr.Checkbox(label="Use background preservation", value=False, elem_id="use_bg_preservation_checkbox")
187
+ background_blend_threshold = gr.Slider(
188
+ label="Background blend threshold",
189
+ minimum=0,
190
+ maximum=1,
191
+ step=0.1,
192
+ value=0.5
193
+ )
194
+
195
+ with gr.Group():
196
+ with gr.Row():
197
+ seed = gr.Slider(
198
+ label="Seed (-1 for random): ",
199
+ minimum=-1,
200
+ maximum=2147483647,
201
+ step=1,
202
+ value=-1,
203
+ scale=4
204
+ )
205
+
206
+ num_images_per_prompt = gr.Slider(
207
+ label="Num samples",
208
+ minimum=1,
209
+ maximum=4,
210
+ step=1,
211
+ value=1,
212
+ scale=1
213
+ )
214
+
215
+ with gr.Group():
216
+ with gr.Row():
217
+ guidance = gr.Slider(
218
+ label="Guidance scale",
219
+ minimum=10,
220
+ maximum=65,
221
+ step=1,
222
+ value=40,
223
+ )
224
+ num_steps = gr.Slider(
225
+ label="Number of inference steps",
226
+ minimum=1,
227
+ maximum=60,
228
+ step=1,
229
+ value=32,
230
+ )
231
+ with gr.Row():
232
+ true_gs = gr.Slider(
233
+ label="True GS",
234
+ minimum=1,
235
+ maximum=10,
236
+ step=1,
237
+ value=3,
238
+ )
239
+
240
+ return (aspect_ratio, seg_ref_mode, move_to_center, use_background_preservation,
241
+ background_blend_threshold, seed, num_images_per_prompt, guidance, num_steps, true_gs)
242
+
243
+
244
+ def create_mask_operation_section():
245
+ """Create mask operation section optimized for right column (outputs)."""
246
+ md_mask_operation = gr.Markdown("6. View or modify the target mask")
247
+
248
+ with gr.Group():
249
+ # Mask gallery with responsive layout
250
+ mask_gallery = gr.Gallery(
251
+ label='Mask Preview',
252
+ show_label=False,
253
+ interactive=False,
254
+ columns=2,
255
+ rows=1,
256
+ height="auto",
257
+ object_fit="contain",
258
+ preview=True,
259
+ allow_preview=True,
260
+ selected_index=0,
261
+ elem_id="mask_gallery",
262
+ elem_classes=["custom-gallery", "responsive-gallery"],
263
+ container=True,
264
+ show_fullscreen_button=False
265
+ )
266
+
267
+ # Mask operation buttons - horizontal layout
268
+ with gr.Row():
269
+ dilate_button = gr.Button(
270
+ '🔍 Dilate',
271
+ elem_id="dilate_btn",
272
+ variant="secondary",
273
+ size="sm",
274
+ scale=1
275
+ )
276
+ erode_button = gr.Button(
277
+ '🔽 Erode',
278
+ elem_id="erode_btn",
279
+ variant="secondary",
280
+ size="sm",
281
+ scale=1
282
+ )
283
+ bounding_box_button = gr.Button(
284
+ '📦 Bounding box',
285
+ elem_id="bounding_box_btn",
286
+ variant="secondary",
287
+ size="sm",
288
+ scale=1
289
+ )
290
+
291
+ return mask_gallery, dilate_button, erode_button, bounding_box_button, md_mask_operation
292
+
293
+
294
+ def create_output_section():
295
+ """Create the output section optimized for right column."""
296
+ md_submit = gr.Markdown("7. Submit and view the output")
297
+
298
+ # Generation controls at top for better workflow
299
+ with gr.Group():
300
+ with gr.Row():
301
+ submit_button = gr.Button(
302
+ "💫 Generate",
303
+ variant="primary",
304
+ scale=3,
305
+ size="lg"
306
+ )
307
+ clear_btn = gr.ClearButton(
308
+ scale=1,
309
+ variant="secondary",
310
+ value="🗑️ Clear"
311
+ )
312
+
313
+ # Results gallery with responsive layout
314
+ with gr.Group():
315
+ result_gallery = gr.Gallery(
316
+ label='Generated Results',
317
+ show_label=False,
318
+ interactive=False,
319
+ columns=1,
320
+ rows=1,
321
+ height="auto",
322
+ object_fit="contain",
323
+ preview=True,
324
+ allow_preview=True,
325
+ selected_index=0,
326
+ elem_id="result_gallery",
327
+ elem_classes=["custom-gallery", "responsive-gallery"],
328
+ container=True,
329
+ show_fullscreen_button=False
330
+ )
331
+
332
+ return result_gallery, submit_button, clear_btn, md_submit
333
+
334
+
335
+ def create_examples_section(examples_list, inputs, outputs, fn):
336
+ """Create the examples section with required arguments."""
337
+ examples = gr.Examples(
338
+ examples=examples_list,
339
+ inputs=inputs,
340
+ outputs=outputs,
341
+ fn=fn,
342
+ cache_examples=False,
343
+ examples_per_page=10,
344
+ run_on_click=True,
345
+ )
346
+ return examples
347
+
348
+
349
+ def create_citation_section():
350
+ """Create the citation section."""
351
+ from metainfo import citation
352
+
353
+ with gr.Row():
354
+ gr.Markdown(citation)
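Each builder above returns the components it creates, so the caller can keep references for event wiring. A rough composition sketch follows (assumed arrangement; the actual layout and event hookup live in app.py and event_handlers.py).

    # Composition sketch (assumption): assembling the section builders in a Blocks context.
    import gradio as gr
    from ui_components import (
        create_theme, create_css, create_header_section,
        create_customization_section, create_image_input_section,
        create_prompt_section, create_output_section,
    )

    with gr.Blocks(theme=create_theme(), css=create_css()) as demo:
        create_header_section()
        custmization_mode, _ = create_customization_section()
        with gr.Row():
            with gr.Column():
                image_inputs = create_image_input_section()   # tuple of 8 components
                prompt, vlm_generate_btn, vlm_polish_btn, _ = create_prompt_section()
            with gr.Column():
                result_gallery, submit_button, clear_btn, _ = create_output_section()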
app/utils.py ADDED
@@ -0,0 +1,429 @@
1
+ import os
2
+ import sys
3
+ import base64
4
+ from io import BytesIO
5
+ from typing import Optional
6
+
7
+ from PIL import Image
8
+
9
+ import numpy as np
10
+ import torch
11
+
12
+ from segment_anything import SamPredictor, sam_model_registry
13
+ from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
14
+ from qwen_vl_utils import process_vision_info
15
+ from huggingface_hub import hf_hub_download
16
+
17
+ sys.path.append(os.getcwd())
18
+ import BEN2
19
+
20
+
21
+ ## Ordinary function
22
+ def resize(image: Image.Image,
23
+ target_width: int,
24
+ target_height: int,
25
+ interpolate: Image.Resampling = Image.Resampling.LANCZOS,
26
+ return_type: str = "pil") -> Image.Image | np.ndarray:
27
+ """
28
+ Resizes an image to the given width and height (no cropping is performed; the aspect ratio may change).
29
+
30
+ Args:
31
+ image (Image.Image): Input PIL image to be cropped and resized.
32
+ target_width (int): Target width of the output image.
33
+ target_height (int): Target height of the output image.
34
+ interpolate (Image.Resampling): The interpolation method.
35
+ return_type (str): The type of the output image.
36
+
37
+ Returns:
38
+ Image.Image | np.ndarray: The resized image, as a PIL image or NumPy array depending on return_type.
39
+ """
40
+ # Resize directly to the target dimensions
41
+ resized_image = image.resize((target_width, target_height), interpolate)
42
+ if return_type == "pil":
43
+ return resized_image
44
+ elif return_type == "np":
45
+ return np.array(resized_image)
46
+ else:
47
+ raise ValueError(f"Invalid return type: {return_type}")
48
+
49
+
50
+ def resize_long_edge(
51
+ image: Image.Image,
52
+ long_edge_size: int,
53
+ interpolate: Image.Resampling = Image.Resampling.LANCZOS,
54
+ return_type: str = "pil"
55
+ ) -> np.ndarray | Image.Image:
56
+ """
57
+ Resize the long edge of the image to the long_edge_size.
58
+
59
+ Args:
60
+ image (Image.Image): The image to resize.
61
+ long_edge_size (int): The size of the long edge.
62
+ interpolate (Image.Resampling): The interpolation method.
63
+
64
+ Returns:
65
+ np.ndarray | Image.Image: The resized image, matching return_type.
66
+ """
67
+ w, h = image.size
68
+ scale_ratio = long_edge_size / max(h, w)
69
+ output_w = int(w * scale_ratio)
70
+ output_h = int(h * scale_ratio)
71
+ image = resize(image, target_width=int(output_w), target_height=int(output_h), interpolate=interpolate, return_type=return_type)
72
+ return image
73
+
74
+
75
+ def ensure_divisible_by_value(
76
+ image: Image.Image | np.ndarray,
77
+ value: int = 8,
78
+ interpolate: Image.Resampling = Image.Resampling.NEAREST,
79
+ return_type: str = "np"
80
+ ) -> np.ndarray | Image.Image:
81
+ """
82
+ Ensure the image dimensions are divisible by value.
83
+
84
+ Args:
85
+ image (Image.Image | np.ndarray): The image whose dimensions should be made divisible by value.
86
+ value (int): The value to ensure divisible by.
87
+ interpolate (Image.Resampling): The interpolation method.
88
+ return_type (str): The type of the output image.
89
+
90
+ Returns:
91
+ np.ndarray | Image.Image: The resized image.
92
+ """
93
+
94
+ if isinstance(image, np.ndarray):
95
+ image = Image.fromarray(image)
96
+
97
+ w, h = image.size
98
+
99
+ w = (w // value) * value
100
+ h = (h // value) * value
101
+ image = resize(image, w, h, interpolate=interpolate, return_type=return_type)
102
+ return image
103
+
104
+
105
+ def resize_paired_image(
106
+ image_reference: np.ndarray,
107
+ image_target: np.ndarray,
108
+ mask_target: np.ndarray,
109
+ force_resize_long_edge: bool = False,
110
+ return_type: str = "np"
111
+ ) -> tuple[np.ndarray | Image.Image, np.ndarray | Image.Image, np.ndarray | Image.Image]:
112
+
113
+ if isinstance(image_reference, np.ndarray):
114
+ image_reference = Image.fromarray(image_reference)
115
+ if isinstance(image_target, np.ndarray):
116
+ image_target = Image.fromarray(image_target)
117
+ if isinstance(mask_target, np.ndarray):
118
+ mask_target = Image.fromarray(mask_target)
119
+
120
+ if force_resize_long_edge:
121
+ image_reference = resize_long_edge(image_reference, 1024, interpolate=Image.Resampling.LANCZOS, return_type=return_type)
122
+ image_target = resize_long_edge(image_target, 1024, interpolate=Image.Resampling.LANCZOS, return_type=return_type)
123
+ mask_target = resize_long_edge(mask_target, 1024, interpolate=Image.Resampling.NEAREST, return_type=return_type)
124
+
125
+ if isinstance(image_reference, Image.Image):
126
+ ref_width, ref_height = image_reference.size
127
+ target_width, target_height = image_target.size
128
+ else:
129
+ ref_height, ref_width = image_reference.shape[:2]
130
+ target_height, target_width = image_target.shape[:2]
131
+
132
+ # resize the ref image to the same height as the target image and ensure the ratio remains the same
133
+ if ref_height != target_height:
134
+ scale_ratio = target_height / ref_height
135
+ image_reference = resize(image_reference, int(ref_width * scale_ratio), target_height, interpolate=Image.Resampling.LANCZOS, return_type=return_type)
136
+
137
+ if return_type == "pil":
138
+ image_reference = Image.fromarray(image_reference) if isinstance(image_reference, np.ndarray) else image_reference
139
+ image_target = Image.fromarray(image_target) if isinstance(image_target, np.ndarray) else image_target
140
+ mask_target = Image.fromarray(mask_target) if isinstance(mask_target, np.ndarray) else mask_target
141
+ return image_reference, image_target, mask_target
142
+ else:
143
+ image_reference = np.array(image_reference) if isinstance(image_reference, Image.Image) else image_reference
144
+ image_target = np.array(image_target) if isinstance(image_target, Image.Image) else image_target
145
+ mask_target = np.array(mask_target) if isinstance(mask_target, Image.Image) else mask_target
146
+ return image_reference, image_target, mask_target
147
+
148
+
149
+ def prepare_input_images(
150
+ img_ref: np.ndarray,
151
+ custmization_mode: str,
152
+ img_target: Optional[np.ndarray] = None,
153
+ mask_target: Optional[np.ndarray] = None,
154
+ width: Optional[int] = None,
155
+ height: Optional[int] = None,
156
+ force_resize_long_edge: bool = False,
157
+ return_type: str = "np"
158
+ ) -> tuple[np.ndarray | Image.Image, np.ndarray | Image.Image, np.ndarray | Image.Image]:
159
+
160
+
161
+ if custmization_mode.lower() == "position-free":
162
+ img_target = np.ones_like(img_ref) * 255
163
+ mask_target = np.zeros_like(img_ref)
164
+
165
+ if isinstance(width, int) and isinstance(height, int):
166
+ img_ref = resize(Image.fromarray(img_ref), width, height, interpolate=Image.Resampling.LANCZOS, return_type=return_type)
167
+ img_target = resize(Image.fromarray(img_target), width, height, interpolate=Image.Resampling.LANCZOS, return_type=return_type)
168
+ mask_target = resize(Image.fromarray(mask_target), width, height, interpolate=Image.Resampling.NEAREST, return_type=return_type)
169
+ else:
170
+ img_ref, img_target, mask_target = resize_paired_image(img_ref, img_target, mask_target, force_resize_long_edge, return_type=return_type)
171
+
172
+ img_ref = ensure_divisible_by_value(img_ref, value=16, interpolate=Image.Resampling.LANCZOS, return_type=return_type)
173
+ img_target = ensure_divisible_by_value(img_target, value=16, interpolate=Image.Resampling.LANCZOS, return_type=return_type)
174
+ mask_target = ensure_divisible_by_value(mask_target, value=16, interpolate=Image.Resampling.NEAREST, return_type=return_type)
175
+
176
+ return img_ref, img_target, mask_target
177
+
178
+
179
+ def get_mask_type_ids(custmization_mode: str, input_mask_mode: str) -> torch.Tensor:
180
+ if custmization_mode.lower() == "position-free":
181
+ return torch.tensor([0])
182
+ elif custmization_mode.lower() == "position-aware":
183
+ if "precise" in input_mask_mode.lower():
184
+ return torch.tensor([1])
185
+ else:
186
+ return torch.tensor([2])
187
+ else:
188
+ raise ValueError(f"Invalid custmization mode: {custmization_mode}")
189
+
190
+
191
+ def scale_image(image_np, is_mask: bool = False):
192
+ """
193
+ Scale the image to the range of [-1, 1] if not a mask, otherwise scale to [0, 1].
194
+
195
+ Args:
196
+ image_np (np.ndarray): Input image.
197
+ is_mask (bool): Whether the image is a mask.
198
+ Returns:
199
+ np.ndarray: Scaled image.
200
+ """
201
+ if is_mask:
202
+ image_np = image_np / 255.0
203
+ else:
204
+ image_np = image_np / 255.0
205
+ image_np = image_np * 2 - 1
206
+ return image_np
207
+
208
+
209
+ def get_sam_predictor(sam_ckpt_path, device):
210
+ """
211
+ Get the SAM predictor.
212
+ Args:
213
+ sam_ckpt_path (str): The path to the SAM checkpoint.
214
+ device (str): The device to load the model on.
215
+ Returns:
216
+ SamPredictor: The SAM predictor.
217
+ """
218
+ if not os.path.exists(sam_ckpt_path):
219
+ sam_ckpt_path = hf_hub_download(repo_id="HCMUE-Research/SAM-vit-h", filename="sam_vit_h_4b8939.pth")
220
+
221
+ sam = sam_model_registry['vit_h'](checkpoint=sam_ckpt_path).to(device)
222
+ sam.eval()
223
+ predictor = SamPredictor(sam)
224
+
225
+ return predictor
226
+
227
+
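The object returned by get_sam_predictor is the standard SamPredictor from segment_anything; below is a short sketch of the usual click-to-mask call pattern it supports (the checkpoint path, device, and point coordinates are placeholder assumptions, not values from this repo):

    import numpy as np
    from utils import get_sam_predictor

    predictor = get_sam_predictor("sam_vit_h_4b8939.pth", device="cuda")
    predictor.set_image(np.zeros((512, 512, 3), dtype=np.uint8))  # RGB uint8 image
    masks, scores, logits = predictor.predict(
        point_coords=np.array([[256, 256]]),  # one foreground click, (x, y)
        point_labels=np.array([1]),           # 1 = foreground, 0 = background
        multimask_output=True,
    )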
228
+ def image_to_base64(img):
229
+ """
230
+ Convert an image to a base64 string.
231
+ Args:
232
+ img (PIL.Image.Image): The image to convert.
233
+ Returns:
234
+ str: The base64 string.
235
+ """
236
+ buffered = BytesIO()
237
+ img.save(buffered, format="PNG")
238
+ img_bytes = buffered.getvalue()
239
+ return base64.b64encode(img_bytes).decode('utf-8')
240
+
241
+
242
+ def get_vlm(vlm_ckpt_path, device, torch_dtype):
243
+ """
244
+ Get the VLM pipeline.
245
+ Args:
246
+ vlm_ckpt_path (str): The path to the VLM checkpoint.
247
+ device (str): The device to load the model on.
248
+ torch_dtype (torch.dtype): The data type of the model.
249
+ Returns:
250
+ tuple: The processor and model.
251
+ """
252
+ if not os.path.exists(vlm_ckpt_path):
253
+ vlm_ckpt_path = "Qwen/Qwen2.5-VL-7B-Instruct"
254
+
255
+ model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
256
+ vlm_ckpt_path, torch_dtype=torch_dtype).to(device)
257
+ processor = AutoProcessor.from_pretrained(vlm_ckpt_path)
258
+
259
+
260
+ return processor, model
261
+
262
+
263
+ def construct_vlm_gen_prompt(image_target, image_reference, target_mask, custmization_mode):
264
+ """
265
+ Construct the VLM generation prompt.
266
+ Args:
267
+ image_target (np.ndarray): The target image.
268
+ image_reference (np.ndarray): The reference image.
269
+ target_mask (np.ndarray): The target mask.
270
+ custmization_mode (str): The customization mode.
271
+ Returns:
272
+ list: The messages.
273
+ """
274
+ if custmization_mode.lower() == "position-free":
275
+ image_reference_pil = Image.fromarray(image_reference.astype(np.uint8))
276
+ image_reference_base_64 = image_to_base64(image_reference_pil)
277
+ messages = [
278
+ {
279
+ "role": "system",
280
+ "content": "I will input a reference image. Please identify the main subject/object in this image and generate a new description by placing this subject in a completely different scene or context. For example, if the reference image shows a rabbit sitting in a garden surrounded by green leaves and roses, you could generate a description like 'The rabbit is sitting on a rocky cliff overlooking a serene ocean, with the sun setting behind it, casting a warm glow over the scene'. Please directly output the new description without explaining your thought process. The description should not exceed 256 tokens."
281
+ },
282
+ {
283
+ "role": "user",
284
+ "content": [
285
+ {
286
+ "type": "image",
287
+ "image": f"data:image;base64,{image_reference_base_64}"
288
+ },
289
+ ],
290
+ }
291
+ ]
292
+ return messages
293
+ else:
294
+ image_reference_pil = Image.fromarray(image_reference.astype(np.uint8))
295
+ image_reference_base_64 = image_to_base64(image_reference_pil)
296
+
297
+ target_mask_binary = target_mask > 127.5
298
+ masked_image_target = image_target * target_mask_binary
299
+ masked_image_target_pil = Image.fromarray(masked_image_target.astype(np.uint8))
300
+ masked_image_target_base_64 = image_to_base64(masked_image_target_pil)
301
+
302
+
303
+ messages = [
304
+ {
305
+ "role": "system",
306
+ "content": "I will input a reference image and a target image with its main subject area masked (in black). Please directly describe the scene where the main subject/object from the reference image is placed into the masked area of the target image. Focus on describing the final combined scene, making sure to clearly describe both the object from the reference image and the background/environment from the target image. For example, if the reference shows a white cat with orange stripes on a beach and the target shows a masked area in a garden with blooming roses and tulips, directly describe 'A white cat with orange stripes sits elegantly among the vibrant red roses and yellow tulips in the lush garden, surrounded by green foliage.' The description should not exceed 256 tokens."
307
+ },
308
+ {
309
+ "role": "user",
310
+ "content": [
311
+ {
312
+ "type": "image",
313
+ "image": f"data:image;base64,{image_reference_base_64}"
314
+ },
315
+ {
316
+ "type": "image",
317
+ "image": f"data:image;base64,{masked_image_target_base_64}"
318
+ }
319
+ ],
320
+ }
321
+ ]
322
+ return messages
323
+
324
+
325
+ def construct_vlm_polish_prompt(prompt):
326
+ """
327
+ Construct the VLM polish prompt.
328
+ Args:
329
+ prompt (str): The prompt to polish.
330
+ Returns:
331
+ list: The messages.
332
+ """
333
+ messages = [
334
+ {
335
+ "role": "system",
336
+ "content": "You are a helpful assistant that can polish the text prompt to make it more specific, detailed, and complete. Please directly output the polished prompt without explaining your thought process. The prompt should not exceed 256 tokens."
337
+ },
338
+ {
339
+ "role": "user",
340
+ "content": prompt
341
+ }
342
+ ]
343
+ return messages
344
+
345
+
346
+ def run_vlm(vlm_processor, vlm_model, messages, device):
347
+ """
348
+ Run the VLM.
349
+ Args:
350
+ vlm_processor (torch.nn.Module): The VLM processor.
351
+ vlm_model (torch.nn.Module): The VLM model.
352
+ messages (list): The messages.
353
+ device (str): The device to run the model on.
354
+ Returns:
355
+ str: The output text.
356
+ """
357
+ text = vlm_processor.apply_chat_template(
358
+ messages, tokenize=False, add_generation_prompt=True)
359
+
360
+ image_inputs, video_inputs = process_vision_info(messages)
361
+ inputs = vlm_processor(
362
+ text=[text],
363
+ images=image_inputs,
364
+ videos=video_inputs,
365
+ padding=True,
366
+ return_tensors="pt",
367
+ )
368
+ inputs = inputs.to(device)
369
+ # Inference
370
+ generated_ids = vlm_model.generate(**inputs, do_sample=True, num_beams=4, temperature=1.5, max_new_tokens=128)
371
+ generated_ids_trimmed = [
372
+ out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
373
+ ]
374
+ output_text = vlm_processor.batch_decode(
375
+ generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
376
+ )[0]
377
+ return output_text
378
+
379
+
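A small sketch of how the helpers above can be chained to polish a user prompt; the checkpoint name, device, and dtype are placeholder assumptions:

    import torch
    from utils import get_vlm, construct_vlm_polish_prompt, run_vlm

    processor, model = get_vlm("Qwen/Qwen2.5-VL-7B-Instruct", device="cuda", torch_dtype=torch.bfloat16)
    messages = construct_vlm_polish_prompt("a cat on a sofa")
    polished = run_vlm(processor, model, messages, device="cuda")
    print(polished)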
380
+ def get_ben2_model(ben2_model_path, device):
381
+ """
382
+ Get the BEN2 model.
383
+ Args:
384
+ ben2_model_path (str): The path to the BEN2 model.
385
+ device (str): The device to load the model on.
386
+ Returns:
387
+ BEN2: The BEN2 model.
388
+ """
389
+ if not os.path.exists(ben2_model_path):
390
+ ben2_model_path = hf_hub_download(repo_id="PramaLLC/BEN2", filename="BEN2_Base.pth")
391
+
392
+ ben2_model = BEN2.BEN_Base().to(device)
393
+ ben2_model.loadcheckpoints(model_path=ben2_model_path)
394
+ return ben2_model
395
+
396
+
397
+ def make_dict_img_mask(img_path, mask_path):
398
+ """
399
+ Make a dictionary of the image and mask for gr.ImageEditor.
400
+ Kept for interface compatibility; not currently used in the Gradio app.
401
+ Args:
402
+ img_path (str): The path to the image.
403
+ mask_path (str): The path to the mask.
404
+ Returns:
405
+ dict: The dictionary of the image and mask.
406
+ """
407
+ from PIL import ImageOps
408
+ background = Image.open(img_path).convert("RGBA")
409
+ layers = [
410
+ Image.merge("RGBA", (
411
+ Image.new("L", Image.open(mask_path).size, 255), # R channel
412
+ Image.new("L", Image.open(mask_path).size, 255), # G channel
413
+ Image.new("L", Image.open(mask_path).size, 255), # B channel
414
+ ImageOps.invert(Image.open(mask_path).convert("L")) # Inverted alpha channel
415
+ ))
416
+ ]
417
+ # Combine layers with background by replacing the alpha channel
418
+ background = np.array(background.convert("RGB"))
419
+ _, _, _, layer_alpha = layers[0].split()
420
+ layer_alpha = np.array(layer_alpha)[:,:,np.newaxis]
421
+ composite = background * (1 - (layer_alpha > 0)) + np.ones_like(background) * (layer_alpha > 0) * 255
422
+
423
+
424
+ composite = Image.fromarray(composite.astype("uint8")).convert("RGBA")
425
+ return {
426
+ 'background': background,
427
+ 'layers': layers,
428
+ 'composite': composite
429
+ }
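Taken together, the helpers above implement the image preprocessing used before inference: resize the long edge, match the reference height to the target, snap dimensions to a multiple of 16, then scale pixel values. A compact sketch of that flow (the input arrays are placeholders):

    import numpy as np
    from utils import prepare_input_images, scale_image, get_mask_type_ids

    img_ref = np.zeros((768, 512, 3), dtype=np.uint8)
    img_target = np.zeros((1024, 768, 3), dtype=np.uint8)
    mask_target = np.zeros((1024, 768, 3), dtype=np.uint8)

    ref, tgt, mask = prepare_input_images(
        img_ref, "Position-aware", img_target, mask_target,
        force_resize_long_edge=True, return_type="np",
    )
    tgt_scaled = scale_image(tgt)                   # image scaled to [-1, 1]
    mask_scaled = scale_image(mask, is_mask=True)   # mask scaled to [0, 1]
    mask_type = get_mask_type_ids("Position-aware", "Precise mask")  # tensor([1])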
assets/gradio/pos_aware/001/hypher_params.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7609eddfca9279636bdbbcaa25fc0c04b52816fa563caa220668a30526004e81
3
+ size 169
assets/gradio/pos_aware/001/img_gen.png ADDED

Git LFS Details

  • SHA256: c1bb90c2c7f1c0f3bdda840169e5800d52ac229deb6f2adcd6e2959e6db17b94
  • Pointer size: 131 Bytes
  • Size of remote file: 904 kB
assets/gradio/pos_aware/001/img_ref.png ADDED

Git LFS Details

  • SHA256: 4a0e009aaad8333c39d28148df6a4ce08efab636e6fc2a5302dc9ac7eb9c8260
  • Pointer size: 132 Bytes
  • Size of remote file: 1.27 MB
assets/gradio/pos_aware/001/img_target.png ADDED

Git LFS Details

  • SHA256: 32c3acbb178fe2eea27b1ab46e25470f8e972cfd832ee93a5dde44ab9311a773
  • Pointer size: 131 Bytes
  • Size of remote file: 666 kB
assets/gradio/pos_aware/001/mask_target.png ADDED

Git LFS Details

  • SHA256: 14f001bfb9ac31893fe3dbb76e8eb3640cec124a28f2c22e859fa532b7286e8b
  • Pointer size: 129 Bytes
  • Size of remote file: 7.26 kB
assets/gradio/pos_aware/002/hypher_params.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3e344b14750ee1f3ab780df3ce9871684f40cca81820ea0dcbc4896dd4b30a42
3
+ size 170
assets/gradio/pos_aware/002/img_gen.png ADDED

Git LFS Details

  • SHA256: ef5d260e6b9d25252f260e6c870ec028582835e285a31140d1d0ca7ae2fd0ab2
  • Pointer size: 132 Bytes
  • Size of remote file: 1.09 MB
assets/gradio/pos_aware/002/img_ref.png ADDED

Git LFS Details

  • SHA256: 1a58a77c7c149353b857536b60dddd8c6853c51a60a13f290c670b5f1718d897
  • Pointer size: 131 Bytes
  • Size of remote file: 236 kB
assets/gradio/pos_aware/002/img_target.png ADDED

Git LFS Details

  • SHA256: 25dd856cf5c114f440d7b1ea64df8135240c9bd29c6da072ec5c72d579b605b5
  • Pointer size: 131 Bytes
  • Size of remote file: 949 kB
assets/gradio/pos_aware/002/mask_target.png ADDED

Git LFS Details

  • SHA256: 651a59bf5c9d841540c84b3a28b54d7798cf1b5e2be10fbaf7cbf22321d3f14e
  • Pointer size: 129 Bytes
  • Size of remote file: 5.71 kB
assets/gradio/pos_aware/003/hypher_params.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:76e365ad0ac4b41937a6ff5d2817e73f3e662240b94834e844b59687c45c01f0
3
+ size 311
assets/gradio/pos_aware/003/img_gen.png ADDED

Git LFS Details

  • SHA256: 1caee153aaa03be649e33110092b2b3e63b6de8f5bd41115ccd315536f27ae5d
  • Pointer size: 132 Bytes
  • Size of remote file: 1.33 MB
assets/gradio/pos_aware/003/img_ref.png ADDED

Git LFS Details

  • SHA256: f295720c85f2d35de8e07f123a73a44a8338b6d8dce9617effbd639654afffa1
  • Pointer size: 131 Bytes
  • Size of remote file: 122 kB
assets/gradio/pos_aware/003/img_target.png ADDED

Git LFS Details

  • SHA256: 7de6458b1ed2cf0ce280f9662c39dcf4dd609c4f2e3991289c9b3de7ca04aa28
  • Pointer size: 131 Bytes
  • Size of remote file: 996 kB
assets/gradio/pos_aware/003/mask_target.png ADDED

Git LFS Details

  • SHA256: f8dd8b499cabec5615c745f87c4d604059d6e1c1933535a6e2c67039e720fc48
  • Pointer size: 129 Bytes
  • Size of remote file: 7.01 kB
assets/gradio/pos_aware/004/hypher_params.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6c8a478290f495492d4361f0c7b429f1bf216d73eb89cb460acb210d7bb53901
3
+ size 174
assets/gradio/pos_aware/004/img_gen.png ADDED

Git LFS Details

  • SHA256: 9ebe14a1bef8f7a92fb3f16d2119bf9d249624fcd54f49e547bd9c843b813877
  • Pointer size: 132 Bytes
  • Size of remote file: 1.45 MB
assets/gradio/pos_aware/004/img_ref.png ADDED

Git LFS Details

  • SHA256: 804a4b381965b46d8ccfc3c2f5167eec97a96104b6324e58afbf01bb995d1b23
  • Pointer size: 132 Bytes
  • Size of remote file: 1.04 MB
assets/gradio/pos_aware/004/img_target.png ADDED

Git LFS Details

  • SHA256: d2469e2c5b4bc736ece560438656376180b2de9c6c4f7d446112608ac6f3d0f1
  • Pointer size: 132 Bytes
  • Size of remote file: 1.32 MB
assets/gradio/pos_aware/004/mask_target.png ADDED

Git LFS Details

  • SHA256: 1d3f31df4d08aea1a689ac3cab96100d1d9d84cd6c73549e7f87c5915bbe4a19
  • Pointer size: 129 Bytes
  • Size of remote file: 8.43 kB
assets/gradio/pos_aware/005/hypher_params.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c0c8d875d0af12f57b70c596f70dfda334199990a819e96b67c8968538c7cdd6
3
+ size 173
assets/gradio/pos_aware/005/img_gen.png ADDED

Git LFS Details

  • SHA256: 776179662774adfc458f988a58b55e8cb2ee29a92e70cada52a4e4d8ead2e643
  • Pointer size: 131 Bytes
  • Size of remote file: 970 kB
assets/gradio/pos_aware/005/img_ref.png ADDED

Git LFS Details

  • SHA256: eb6cf9308603da24b1bd4b63eee9989024024cefdc3dff0fc23415762aaa1833
  • Pointer size: 131 Bytes
  • Size of remote file: 305 kB
assets/gradio/pos_aware/005/img_target.png ADDED

Git LFS Details

  • SHA256: f61b04b8a6174a25c68f34aca03f9fafb46e9b6b6ca69e9457ff6ffebc4493f3
  • Pointer size: 131 Bytes
  • Size of remote file: 937 kB
assets/gradio/pos_aware/005/mask_target.png ADDED

Git LFS Details

  • SHA256: 3ce455fb1a94b698a7ae656aaaafce1dcbb9e170cf40103a60a1936e58053ba0
  • Pointer size: 129 Bytes
  • Size of remote file: 4.44 kB
assets/gradio/pos_free/001/hyper_params.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"prompt": "TThe charming, soft plush toy is joyfully wandering through a lush, dense jungle, surrounded by vibrant green foliage and towering trees.", "custmization_mode": "Position-free", "input_mask_mode": "Precise mask", "seg_ref_mode": "Full Ref", "seed": 2126677963, "guidance": 40, "num_steps": 20, "num_images_per_prompt": 1, "use_background_preservation": false, "background_blend_threshold": 0.5, "true_gs": 3}
assets/gradio/pos_free/001/img_gen.png ADDED

Git LFS Details

  • SHA256: d131b70aab831c75a7f8d65bffeef7a7263e00b5e9767dbaa1b04b49549a3a93
  • Pointer size: 132 Bytes
  • Size of remote file: 1.48 MB
assets/gradio/pos_free/001/img_ref.png ADDED

Git LFS Details

  • SHA256: 31eb98b779029afcee3e0be48eeb6d5df3d2e8a76b60142fba5bc7632f1a083e
  • Pointer size: 131 Bytes
  • Size of remote file: 423 kB
assets/gradio/pos_free/001/img_target.png ADDED

Git LFS Details

  • SHA256: d9bd147ced77bca4a875af12714949cd17ddd3c11cc47218b4de30185bc0b4e9
  • Pointer size: 129 Bytes
  • Size of remote file: 5.33 kB
assets/gradio/pos_free/001/mask_target.png ADDED

Git LFS Details

  • SHA256: 79238174f0b3e2441720c46339b8cce2c8be2c19a1507831e726b51b8bbe3b82
  • Pointer size: 129 Bytes
  • Size of remote file: 3.13 kB
assets/gradio/pos_free/002/hyper_params.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"prompt": "A bright yellow alarm clock sits on a wooden desk next to a stack of books in a cozy, sunlit room.", "custmization_mode": "Position-free", "input_mask_mode": "Precise mask", "seg_ref_mode": "Full Ref", "seed": 2126677963, "guidance": 40, "num_steps": 20, "num_images_per_prompt": 1, "use_background_preservation": false, "background_blend_threshold": 0.5, "true_gs": 3}
assets/gradio/pos_free/002/img_gen.png ADDED

Git LFS Details

  • SHA256: 6acf6529520813f6b40dbe953267555e84a373c6d9532ee795adbc5538ffe14e
  • Pointer size: 131 Bytes
  • Size of remote file: 901 kB
assets/gradio/pos_free/002/img_ref.png ADDED

Git LFS Details

  • SHA256: 8e11cfbb8300d6191af71e5b8fb040f98fe7ce59b62002e5852c5b2c2044455f
  • Pointer size: 132 Bytes
  • Size of remote file: 1.12 MB
assets/gradio/pos_free/002/img_target.png ADDED

Git LFS Details

  • SHA256: d9bd147ced77bca4a875af12714949cd17ddd3c11cc47218b4de30185bc0b4e9
  • Pointer size: 129 Bytes
  • Size of remote file: 5.33 kB