gaparmar committed
Commit 1930c69
1 Parent(s): fcb37cd
.gitignore ADDED
@@ -0,0 +1,211 @@
outputs
.gradio


# Byte-compiled / optimized / DLL files
__pycache__/
*.py[codz]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py.cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# UV
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
#uv.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
#poetry.toml

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
# pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python.
# https://pdm-project.org/en/latest/usage/project/#working-with-version-control
#pdm.lock
#pdm.toml
.pdm-python
.pdm-build/

# pixi
# Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control.
#pixi.lock
# Pixi creates a virtual environment in the .pixi directory, just like venv module creates one
# in the .venv directory. It is recommended not to include this directory in version control.
.pixi

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.envrc
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

# Abstra
# Abstra is an AI-powered process automation framework.
# Ignore directories containing user credentials, local state, and settings.
# Learn more at https://abstra.io/docs
.abstra/

# Visual Studio Code
# Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore
# that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore
# and can be added to the global gitignore or merged into this file. However, if you prefer,
# you could uncomment the following to ignore the entire vscode folder
# .vscode/

# Ruff stuff:
.ruff_cache/

# PyPI configuration file
.pypirc

# Cursor
# Cursor is an AI-powered code editor. `.cursorignore` specifies files/directories to
# exclude from AI features like autocomplete and code analysis. Recommended for sensitive data
# refer to https://docs.cursor.com/context/ignore-files
.cursorignore
.cursorindexingignore

# Marimo
marimo/_static/
marimo/_lsp/
__marimo__/
app.py ADDED
@@ -0,0 +1,294 @@
import os
import spaces
import gradio as gr
import torch
import functools
import numpy as np
import torch.nn.functional as F
from diffusers import FluxPipeline, AutoencoderTiny
from transformers import CLIPProcessor, CLIPModel, AutoModel
from transformers.models.clip.modeling_clip import _get_vector_norm
from my_utils.group_inference import run_group_inference
from my_utils.default_values import apply_defaults
import argparse

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell").to("cuda")
pipe.vae = AutoencoderTiny.from_pretrained("madebyollin/taef1").to("cuda")

m_clip = CLIPModel.from_pretrained("multimodalart/clip-vit-base-patch32").to("cuda")
prep_clip = CLIPProcessor.from_pretrained("multimodalart/clip-vit-base-patch32")
dino_model = AutoModel.from_pretrained('facebook/dinov2-base').to("cuda")

# Get default args for flux-schnell
default_args = argparse.Namespace(
    model_name="flux-schnell",
    prompt=None,
    starting_candidates=None,
    output_group_size=None,
    pruning_ratio=None,
    lambda_score=None,
    seed=None,
    unary_term="clip_text_img",
    binary_term="diversity_dino",
    guidance_scale=None,
    num_inference_steps=None,
    height=None,
    width=None,
)
default_args = apply_defaults(default_args)


# Scoring functions
@torch.no_grad()
def unary_clip_text_img_score(l_images, target_caption, device="cuda"):
    """Compute CLIP text-image similarity scores."""
    _img_std = torch.tensor([0.26862954, 0.26130258, 0.27577711]).view(1, 3, 1, 1).to(device)
    _img_mean = torch.tensor([0.48145466, 0.4578275, 0.40821073]).view(1, 3, 1, 1).to(device)

    b_images = torch.cat(l_images, dim=0)
    b_images = F.interpolate(b_images, size=(224, 224), mode="bilinear", align_corners=False)
    b_images = b_images * 0.5 + 0.5
    b_images = (b_images - _img_mean) / _img_std

    text_encoding = prep_clip.tokenizer(target_caption, return_tensors="pt", padding=True).to(device)
    output = m_clip(pixel_values=b_images, **text_encoding).logits_per_image / m_clip.logit_scale.exp()
    return output.view(-1).cpu().numpy()


@torch.no_grad()
def binary_dino_diversity_score(l_images, device="cuda"):
    """Compute pairwise diversity scores using DINO."""
    b_images = torch.cat(l_images, dim=0)
    _img_mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(device)
    _img_std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(device)

    b_images = F.interpolate(b_images, size=(256, 256), mode="bilinear", align_corners=False)
    b_images = b_images * 0.5 + 0.5
    b_images = (b_images - _img_mean) / _img_std
    all_features = dino_model(pixel_values=b_images).last_hidden_state[:, 1:, :].cpu()

    N = len(l_images)
    score_matrix = np.zeros((N, N))
    for i in range(N):
        f1 = all_features[i]
        for j in range(i+1, N):
            f2 = all_features[j]
            cos_sim = (1 - F.cosine_similarity(f1, f2, dim=1)).mean().item()
            score_matrix[i, j] = cos_sim
    return score_matrix


@torch.no_grad()
def binary_dino_cls_score(l_images, device="cuda"):
    """Compute pairwise diversity scores using DINO CLS tokens."""
    b_images = torch.cat(l_images, dim=0)
    _img_mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(device)
    _img_std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(device)

    b_images = F.interpolate(b_images, size=(256, 256), mode="bilinear", align_corners=False)
    b_images = b_images * 0.5 + 0.5
    b_images = (b_images - _img_mean) / _img_std
    all_features = dino_model(pixel_values=b_images).last_hidden_state[:, 0:1, :].cpu()

    N = len(l_images)
    score_matrix = np.zeros((N, N))
    for i in range(N):
        f1 = all_features[i]
        for j in range(i+1, N):
            f2 = all_features[j]
            cos_sim = (1 - F.cosine_similarity(f1, f2, dim=1)).mean().item()
            score_matrix[i, j] = cos_sim
    return score_matrix


@torch.no_grad()
def binary_clip_diversity_score(l_images, device="cuda"):
    """Compute pairwise diversity scores using CLIP."""
    _img_std = torch.tensor([0.26862954, 0.26130258, 0.27577711]).view(1, 3, 1, 1).to(device)
    _img_mean = torch.tensor([0.48145466, 0.4578275, 0.40821073]).view(1, 3, 1, 1).to(device)

    b_images = torch.cat(l_images, dim=0)
    b_images = F.interpolate(b_images, size=(224, 224), mode="bilinear", align_corners=False)
    b_images = b_images * 0.5 + 0.5
    b_images = (b_images - _img_mean) / _img_std

    vision_outputs = m_clip.vision_model(
        pixel_values=b_images,
        output_attentions=False,
        output_hidden_states=False,
        interpolate_pos_encoding=False,
        return_dict=True
    )
    image_embeds = m_clip.visual_projection(vision_outputs[1])
    image_embeds = image_embeds / _get_vector_norm(image_embeds)

    N = len(l_images)
    score_matrix = np.zeros((N, N))
    for i in range(N):
        f1 = image_embeds[i]
        for j in range(i+1, N):
            f2 = image_embeds[j]
            cos_sim = (1 - torch.dot(f1, f2)).item()
            score_matrix[i, j] = cos_sim
    return score_matrix


def get_score_functions(unary_term, binary_term, prompt):
    """Get the appropriate scoring functions based on selected terms."""
    # Unary score function (always CLIP for flux-schnell) - bind the prompt
    unary_score_fn = functools.partial(unary_clip_text_img_score, target_caption=prompt, device="cuda")
    # Binary score function
    if binary_term == "diversity_dino":
        binary_score_fn = functools.partial(binary_dino_diversity_score, device="cuda")
    elif binary_term == "dino_cls_pairwise":
        binary_score_fn = functools.partial(binary_dino_cls_score, device="cuda")
    elif binary_term == "diversity_clip":
        binary_score_fn = functools.partial(binary_clip_diversity_score, device="cuda")
    else:
        raise ValueError(f"Invalid binary term: {binary_term}")

    return unary_score_fn, binary_score_fn


@spaces.GPU(duration=300)
def generate_images(prompt, starting_candidates, output_group_size, pruning_ratio,
                    lambda_score, seed, unary_term, binary_term, progress=gr.Progress(track_tqdm=True)):
    """Generate images using group inference with progressive pruning."""

    # Get scoring functions with prompt bound to unary function
    unary_score_fn, binary_score_fn = get_score_functions(unary_term, binary_term, prompt)

    # Create inference args
    inference_args = {
        "model_name": "flux-schnell",
        "prompt": prompt,
        "guidance_scale": default_args.guidance_scale,
        "num_inference_steps": default_args.num_inference_steps,
        "max_sequence_length": 256,
        "height": default_args.height,
        "width": default_args.width,
        "unary_score_fn": unary_score_fn,
        "binary_score_fn": binary_score_fn,
        "output_group_size": output_group_size,
        "pruning_ratio": pruning_ratio,
        "lambda_score": lambda_score,
        "l_generator": [torch.Generator("cpu").manual_seed(seed + i) for i in range(starting_candidates)],
        "starting_candidates": starting_candidates,
        "skip_first_cfg": True,
    }
    print(f"pruning ratio is: {pruning_ratio}")
    # Run group inference
    output_group = run_group_inference(pipe, **inference_args)
    return output_group


# Load custom CSS
css_path = os.path.join(os.path.dirname(__file__), "styles.css")
with open(css_path, "r") as f:
    custom_css = f.read()

# JavaScript to force light mode
js_func = """
function refresh() {
    const url = new URL(window.location);
    if (url.searchParams.get('__theme') !== 'light') {
        url.searchParams.set('__theme', 'light');
        window.location.href = url.href;
    }
}
"""

# Create Gradio interface
with gr.Blocks(css=custom_css, js=js_func, theme=gr.themes.Soft(), elem_id="main-container") as demo:

    # Title and header
    gr.HTML(
        """
        <div class="title_left">
            <h1>Scaling Group Inference for Diverse and High-Quality Generation</h1>
            <div class="author-container">
                <div class="grid-item cmu"><a href="https://gauravparmar.com/">Gaurav Parmar</a></div>
                <div class="grid-item snap"><a href="https://orpatashnik.github.io/">Or Patashnik</a></div>
                <div class="grid-item snap"><a href="https://scholar.google.com/citations?user=uD79u6oAAAAJ&hl=en">Daniil Ostashev</a></div>
                <div class="grid-item snap"><a href="https://wangkua1.github.io/">Kuan-Chieh (Jackson) Wang</a></div>
                <div class="grid-item snap"><a href="https://kfiraberman.github.io/">Kfir Aberman</a></div>
            </div>
            <div class="author-container">
                <div class="grid-item cmu"><a href="https://www.cs.cmu.edu/~srinivas/">Srinivasa Narasimhan</a></div>
                <div class="grid-item cmu"><a href="https://www.cs.cmu.edu/~junyanz/">Jun-Yan Zhu</a></div>
            </div>
            <br>
            <div class="affiliation-container">
                <div class="grid-item cmu"> <p>Carnegie Mellon University</p></div>
                <div class="grid-item snap"> <p>Snap Research</p></div>
            </div>

            <br>
            <h2>DEMO: Text-to-Image Group Inference with FLUX.1-Schnell</h2>
        </div>
        """
    )

    with gr.Row(scale=1):
        with gr.Column(scale=1.0):
            prompt = gr.Textbox(label="Prompt", placeholder="A photo of a dog", lines=4, value="A photo of a dog")

        with gr.Column(scale=1.0):
            with gr.Row(elem_id="starting-candidates-row"):
                gr.Text("Starting Candidates:", container=False, interactive=False, scale=5)
                starting_candidates = gr.Number(value=default_args.starting_candidates, precision=0, container=False, show_label=False, scale=1)

            with gr.Row(elem_id="output-group-size-row"):
                gr.Text("Output Group Size:", container=False, interactive=False, scale=5)
                output_group_size = gr.Number(value=default_args.output_group_size, precision=0, container=False, show_label=False, scale=1)

        with gr.Column(scale=1.0):
            with gr.Accordion("Advanced Options", open=False, elem_id="advanced-options-accordion"):
                with gr.Row():
                    gr.Text("Pruning Ratio:", container=False, interactive=False, elem_id="pruning-ratio-label", scale=3)
                    pruning_ratio = gr.Number(value=default_args.pruning_ratio, precision=2, container=False, show_label=False, scale=1)

                with gr.Row():
                    gr.Text("Lambda:", container=False, interactive=False, elem_id="lambda-label", scale=5)
                    lambda_score = gr.Number(value=default_args.lambda_score, precision=1, container=False, show_label=False, scale=1)

                with gr.Row():
                    gr.Text("Seed:", container=False, interactive=False, elem_id="seed-label", scale=5)
                    seed = gr.Number(value=42, precision=0, container=False, show_label=False, scale=1)

                with gr.Row():
                    gr.Text("Unary:", container=False, interactive=False, elem_id="unary-term-label", scale=2)
                    unary_term = gr.Dropdown(choices=["clip_text_img"], value=default_args.unary_term, container=False, show_label=False, scale=3)

                with gr.Row():
                    gr.Text("Binary:", container=False, interactive=False, elem_id="binary-term-label", scale=2)
                    binary_term = gr.Dropdown(choices=["diversity_dino", "diversity_clip", "dino_cls_pairwise"], value=default_args.binary_term,
                                              container=False, show_label=False, scale=3)

    with gr.Row(scale=1):
        generate_btn = gr.Button("Generate", variant="primary")

    with gr.Row(scale=1):
        output_gallery_group = gr.Gallery(label="Group Inference", show_label=True, elem_id="gallery", columns=4, height="auto")

    gr.Examples(
        examples=[
            ["A photo of a dog", 64, 4, 0.5, 1.0, 42, "clip_text_img", "diversity_dino"],
            ["A mountain landscape", 64, 4, 0.5, 1.0, 123, "clip_text_img", "diversity_dino"],
            ["A cat sleeping", 64, 4, 0.5, 1.0, 456, "clip_text_img", "diversity_dino"],
            ["A sunset at the beach", 64, 4, 0.5, 1.0, 789, "clip_text_img", "diversity_dino"],
        ],
        inputs=[prompt, starting_candidates, output_group_size, pruning_ratio, lambda_score, seed, unary_term, binary_term],
        outputs=[output_gallery_group],
        fn=generate_images,
        cache_examples="lazy",
        label="Examples"
    )

    generate_btn.click(
        fn=generate_images,
        inputs=[prompt, starting_candidates, output_group_size, pruning_ratio, lambda_score, seed, unary_term, binary_term],
        outputs=[output_gallery_group]
    )

demo.launch()
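Note on the scoring functions in app.py: all of them share one preprocessing contract. Decoded images arrive as tensors in [-1, 1], are resized to the backbone's input resolution, mapped back to [0, 1], and then standardized with the backbone's own mean/std (CLIP or DINOv2). The minimal sketch below reproduces just that renormalization step; the random tensor merely stands in for a decoded image batch, and the constants are the CLIP statistics copied from the file above.

import torch
import torch.nn.functional as F

# Stand-in for a decoded image batch in [-1, 1] with shape (B, C, H, W).
fake_images = torch.rand(2, 3, 768, 768) * 2.0 - 1.0

# CLIP preprocessing statistics, as used in unary_clip_text_img_score above.
clip_mean = torch.tensor([0.48145466, 0.4578275, 0.40821073]).view(1, 3, 1, 1)
clip_std = torch.tensor([0.26862954, 0.26130258, 0.27577711]).view(1, 3, 1, 1)

# Resize, shift [-1, 1] -> [0, 1], then standardize with the backbone statistics.
x = F.interpolate(fake_images, size=(224, 224), mode="bilinear", align_corners=False)
x = x * 0.5 + 0.5
x = (x - clip_mean) / clip_std

print(x.shape)  # torch.Size([2, 3, 224, 224]) -- ready for the CLIP/DINO forward pass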
my_utils/default_values.py ADDED
@@ -0,0 +1,81 @@
DEFAULT_VALUES = {
    "flux-schnell": {
        "num_inference_steps": 4,
        "guidance_scale": 0.0,
        "starting_candidates": 32,
        "output_group_size": 4,
        "pruning_ratio": 0.9,
        "lambda_score": 1.5,
        "output_dir": "outputs/flux-schnell",
        "height": 768,
        "width": 768,
        "unary_term": "clip_text_img",
        "binary_term": "diversity_dino"
    },
    "flux-dev": {
        "num_inference_steps": 20,
        "guidance_scale": 3.5,
        "starting_candidates": 128,
        "output_group_size": 4,
        "pruning_ratio": 0.5,
        "lambda_score": 1.5,
        "output_dir": "outputs/flux-dev",
        "height": 768,
        "width": 768,
        "unary_term": "clip_text_img",
        "binary_term": "diversity_dino"
    },
    "flux-depth": {
        "num_inference_steps": 20,
        "guidance_scale": 3.5,
        "starting_candidates": 128,
        "output_group_size": 4,
        "pruning_ratio": 0.5,
        "lambda_score": 1.5,
        "output_dir": "outputs/flux-depth",
        "height": 768,
        "width": 768,
        "unary_term": "clip_text_img",
        "binary_term": "diversity_dino"
    },
    "flux-canny": {
        "num_inference_steps": 20,
        "guidance_scale": 3.5,
        "starting_candidates": 128,
        "output_group_size": 4,
        "pruning_ratio": 0.5,
        "lambda_score": 1.5,
        "output_dir": "outputs/flux-canny",
        "height": 768,
        "width": 768,
        "unary_term": "clip_text_img",
        "binary_term": "diversity_dino"
    },
    "flux-kontext": {
        "num_inference_steps": 28,
        "guidance_scale": 3.5,
        "starting_candidates": 128,
        "output_group_size": 4,
        "pruning_ratio": 0.5,
        "lambda_score": 1.0,
        "output_dir": "outputs/flux-kontext",
        "height": 1024,
        "width": 1024,
        "unary_term": "clip_text_img",
        "binary_term": "diversity_dino"
    }
}

def apply_defaults(args):
    model_name = args.model_name

    if model_name not in DEFAULT_VALUES:
        raise ValueError(f"Unknown model name: {model_name}. Available models: {list(DEFAULT_VALUES.keys())}")

    defaults = DEFAULT_VALUES[model_name]

    for param_name, default_value in defaults.items():
        if hasattr(args, param_name) and getattr(args, param_name) is None:
            setattr(args, param_name, default_value)

    return args
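A minimal usage sketch for apply_defaults, assuming the repository root is on PYTHONPATH: fields left as None are filled from DEFAULT_VALUES for the chosen model, while fields set explicitly (or missing from the namespace entirely) are left alone.

import argparse
from my_utils.default_values import apply_defaults

args = argparse.Namespace(
    model_name="flux-schnell",
    prompt="A photo of a dog",
    starting_candidates=None,   # filled from defaults -> 32
    output_group_size=8,        # explicitly set, kept as 8
    pruning_ratio=None,         # filled from defaults -> 0.9
    lambda_score=None,          # filled from defaults -> 1.5
    num_inference_steps=None,   # filled from defaults -> 4
    guidance_scale=None,        # filled from defaults -> 0.0
    height=None, width=None,    # filled from defaults -> 768 x 768
)
args = apply_defaults(args)
print(args.starting_candidates, args.output_group_size)  # 32 8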
my_utils/group_inference.py ADDED
@@ -0,0 +1,257 @@
import os, sys
import math
import torch
import spaces
import numpy as np
from diffusers.schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
from diffusers.pipelines.flux.pipeline_flux import calculate_shift, retrieve_timesteps

from my_utils.solvers import gurobi_solver


def get_next_size(curr_size, final_size, keep_ratio):
    """Calculate next size for progressive pruning during denoising.

    Args:
        curr_size: Current number of candidates
        final_size: Target final size
        keep_ratio: Fraction of candidates to keep at each step
    """
    if curr_size < final_size:
        raise ValueError("Current size is less than the final size!")
    elif curr_size == final_size:
        return curr_size
    else:
        next_size = math.ceil(curr_size * keep_ratio)
        return max(next_size, final_size)


@torch.no_grad()
def decode_latent(z, pipe, height, width):
    """Decode latent tensor to image using VAE decoder.

    Args:
        z: Latent tensor to decode
        pipe: Diffusion pipeline with VAE
        height: Image height
        width: Image width
    """
    z = pipe._unpack_latents(z, height, width, pipe.vae_scale_factor)
    z = (z / pipe.vae.config.scaling_factor) + pipe.vae.config.shift_factor
    z = pipe.vae.decode(z, return_dict=False)[0].clamp(-1, 1)
    return z


@torch.no_grad()
@spaces.GPU(duration=300)
def run_group_inference(pipe, model_name=None, prompt=None, prompt_2=None, negative_prompt=None, negative_prompt_2=None,
                        true_cfg_scale=1.0, height=None, width=None, num_inference_steps=28, sigmas=None, guidance_scale=3.5,
                        l_generator=None, max_sequence_length=512,
                        # group inference arguments
                        unary_score_fn=None, binary_score_fn=None,
                        starting_candidates=None, output_group_size=None, pruning_ratio=None, lambda_score=None,
                        # control arguments
                        control_image=None,
                        # input image for flux-kontext
                        input_image=None,
                        skip_first_cfg=True
                        ):
    """Run group inference with progressive pruning for diverse, high-quality image generation.

    Args:
        pipe: Diffusion pipeline
        model_name: Model type (flux-schnell, flux-dev, flux-depth, flux-canny, flux-kontext)
        prompt: Text prompt for generation
        unary_score_fn: Function to compute image quality scores
        binary_score_fn: Function to compute pairwise diversity scores
        starting_candidates: Initial number of noise samples
        output_group_size: Final number of images to generate
        pruning_ratio: Fraction to prune at each denoising step
        lambda_score: Weight between quality and diversity terms
        control_image: Control image for depth/canny models
        input_image: Input image for flux-kontext editing
    """
    if l_generator is None:
        l_generator = [torch.Generator("cpu").manual_seed(42 + _seed) for _seed in range(starting_candidates)]

    # use the default height and width if not provided
    height = height or pipe.default_sample_size * pipe.vae_scale_factor
    width = width or pipe.default_sample_size * pipe.vae_scale_factor

    pipe._guidance_scale = guidance_scale
    pipe._current_timestep = None
    pipe._interrupt = False
    pipe._joint_attention_kwargs = {}

    device = pipe._execution_device

    lora_scale = None
    has_neg_prompt = negative_prompt is not None
    do_true_cfg = true_cfg_scale > 1 and has_neg_prompt

    # 3. Encode prompts
    prompt_embeds, pooled_prompt_embeds, text_ids = pipe.encode_prompt(prompt=prompt, prompt_2=prompt_2, prompt_embeds=None, pooled_prompt_embeds=None, device=device, max_sequence_length=max_sequence_length, lora_scale=lora_scale)

    if do_true_cfg:
        negative_prompt_embeds, negative_pooled_prompt_embeds, _ = pipe.encode_prompt(prompt=negative_prompt, prompt_2=negative_prompt_2, prompt_embeds=None, pooled_prompt_embeds=None, device=device, max_sequence_length=max_sequence_length, lora_scale=lora_scale)

    # 4. Prepare latent variables
    if model_name in ["flux-depth", "flux-canny"]:
        # for control models, the pipe.transformer.config.in_channels is doubled
        num_channels_latents = pipe.transformer.config.in_channels // 8
    else:
        num_channels_latents = pipe.transformer.config.in_channels // 4

    # Handle different model types
    image_latents = None
    image_ids = None
    if model_name == "flux-kontext":
        processed_image = pipe.image_processor.preprocess(input_image, height=height, width=width)
        l_latents = []
        for _gen in l_generator:
            latents, img_latents, latent_ids, img_ids = pipe.prepare_latents(
                processed_image, 1, num_channels_latents, height, width,
                prompt_embeds.dtype, device, _gen
            )
            l_latents.append(latents)
        # Use the image_latents and image_ids from the first generator
        _, image_latents, latent_image_ids, image_ids = pipe.prepare_latents(
            processed_image, 1, num_channels_latents, height, width,
            prompt_embeds.dtype, device, l_generator[0]
        )
        # Combine latent_ids with image_ids
        if image_ids is not None:
            latent_image_ids = torch.cat([latent_image_ids, image_ids], dim=0)
    else:
        # For other models (flux-schnell, flux-dev, flux-depth, flux-canny)
        l_latents = [pipe.prepare_latents(1, num_channels_latents, height, width, prompt_embeds.dtype, device, _gen)[0] for _gen in l_generator]
        _, latent_image_ids = pipe.prepare_latents(1, num_channels_latents, height, width, prompt_embeds.dtype, device, l_generator[0])

    # 4.5. Prepare control image if provided
    control_latents = None
    if model_name in ["flux-depth", "flux-canny"]:
        control_image_processed = pipe.prepare_image(image=control_image, width=width, height=height, batch_size=1, num_images_per_prompt=1, device=device, dtype=pipe.vae.dtype,)
        if control_image_processed.ndim == 4:
            control_latents = pipe.vae.encode(control_image_processed).latents
            control_latents = (control_latents - pipe.vae.config.shift_factor) * pipe.vae.config.scaling_factor
            height_control_image, width_control_image = control_latents.shape[2:]
            control_latents = pipe._pack_latents(control_latents, 1, num_channels_latents, height_control_image, width_control_image)

    # 5. Prepare timesteps
    sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
    image_seq_len = latent_image_ids.shape[0]
    mu = calculate_shift(image_seq_len, pipe.scheduler.config.get("base_image_seq_len", 256), pipe.scheduler.config.get("max_image_seq_len", 4096), pipe.scheduler.config.get("base_shift", 0.5), pipe.scheduler.config.get("max_shift", 1.15))
    timesteps, num_inference_steps = retrieve_timesteps(pipe.scheduler, num_inference_steps, device, sigmas=sigmas, mu=mu)
    num_warmup_steps = max(len(timesteps) - num_inference_steps * pipe.scheduler.order, 0)
    pipe._num_timesteps = len(timesteps)
    _dtype = l_latents[0].dtype

    # handle guidance
    if pipe.transformer.config.guidance_embeds:
        guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32).expand(1)
    else:
        guidance = None
    guidance_1 = torch.full([1], 1.0, device=device, dtype=torch.float32).expand(1)

    # 6. Denoising loop
    with pipe.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            if pipe.interrupt:
                continue
            if guidance is not None and skip_first_cfg and i == 0:
                curr_guidance = guidance_1
            else:
                curr_guidance = guidance

            pipe._current_timestep = t
            timestep = t.expand(1).to(_dtype)
            next_latents = []
            x0_preds = []
            # do 1 denoising step
            for _latent in l_latents:
                # prepare input for transformer based on model type
                if model_name in ["flux-depth", "flux-canny"]:
                    # Control models: concatenate control latents along dim=2
                    latent_model_input = torch.cat([_latent, control_latents], dim=2)
                elif model_name == "flux-kontext":
                    # Kontext model: concatenate image latents along dim=1
                    latent_model_input = torch.cat([_latent, image_latents], dim=1)
                else:
                    # Standard models (flux-schnell, flux-dev): use latents as is
                    latent_model_input = _latent

                noise_pred = pipe.transformer(hidden_states=latent_model_input, timestep=timestep / 1000, guidance=curr_guidance, pooled_projections=pooled_prompt_embeds, encoder_hidden_states=prompt_embeds, txt_ids=text_ids, img_ids=latent_image_ids, joint_attention_kwargs=pipe.joint_attention_kwargs, return_dict=False)[0]

                # For flux-kontext, we need to slice the noise_pred to match the latents size
                if model_name == "flux-kontext":
                    noise_pred = noise_pred[:, : _latent.size(1)]

                if do_true_cfg:
                    neg_noise_pred = pipe.transformer(hidden_states=latent_model_input, timestep=timestep / 1000, guidance=curr_guidance, pooled_projections=negative_pooled_prompt_embeds, encoder_hidden_states=negative_prompt_embeds, txt_ids=text_ids, img_ids=latent_image_ids, joint_attention_kwargs=pipe.joint_attention_kwargs, return_dict=False)[0]
                    if model_name == "flux-kontext":
                        neg_noise_pred = neg_noise_pred[:, : _latent.size(1)]
                    noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
                # compute the previous noisy sample x_t -> x_t-1
                _latent = pipe.scheduler.step(noise_pred, t, _latent, return_dict=False)[0]
                # the scheduler is not state-less, it maintains a step index that is incremented by one after each step,
                # so we need to decrease it here
                if hasattr(pipe.scheduler, "_step_index"):
                    pipe.scheduler._step_index -= 1

                if type(pipe.scheduler) == FlowMatchEulerDiscreteScheduler:
                    dt = 0.0 - pipe.scheduler.sigmas[i]
                    x0_pred = _latent + dt * noise_pred
                else:
                    raise NotImplementedError("Only Flow Scheduler is supported for now! For other schedulers, you need to manually implement the x0 prediction step.")

                x0_preds.append(x0_pred)
                next_latents.append(_latent)

            if hasattr(pipe.scheduler, "_step_index"):
                pipe.scheduler._step_index += 1

            # if the size of next_latents > output_group_size, prune the latents
            if len(next_latents) > output_group_size:
                next_size = get_next_size(len(next_latents), output_group_size, 1 - pruning_ratio)
                print(f"Pruning from {len(next_latents)} to {next_size}")
                # decode the latents to pixels with tiny-vae
                l_x0_decoded = [decode_latent(_latent, pipe, height, width) for _latent in x0_preds]
                # compute the unary and binary scores
                l_unary_scores = unary_score_fn(l_x0_decoded, target_caption=prompt)
                M_binary_scores = binary_score_fn(l_x0_decoded)  # upper triangular matrix
                # run with Quadratic Integer Programming solver
                selected_indices = gurobi_solver(l_unary_scores, M_binary_scores, next_size, lam=lambda_score)
                l_latents = [next_latents[_i] for _i in selected_indices]
            else:
                l_latents = next_latents

            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % pipe.scheduler.order == 0):
                progress_bar.update()

    pipe._current_timestep = None

    l_images = [pipe._unpack_latents(_latent, height, width, pipe.vae_scale_factor) for _latent in l_latents]
    l_images = [(latents / pipe.vae.config.scaling_factor) + pipe.vae.config.shift_factor for latents in l_images]
    l_images = [pipe.vae.decode(_image, return_dict=False)[0] for _image in l_images]
    l_images_tensor = [image.clamp(-1, 1) for image in l_images]  # Keep tensor version for scoring
    l_images = [pipe.image_processor.postprocess(image, output_type="pil")[0] for image in l_images]

    # Compute and print final scores
    print(f"\n=== Final Scores for {len(l_images)} generated images ===")

    # Compute unary scores
    final_unary_scores = unary_score_fn(l_images_tensor, target_caption=prompt)
    print(f"Unary scores (quality): {final_unary_scores}")
    print(f"Mean unary score: {np.mean(final_unary_scores):.4f}")

    # Compute binary scores
    final_binary_scores = binary_score_fn(l_images_tensor)
    print("Binary score matrix (diversity):")
    print(final_binary_scores)

    print("=" * 50)

    pipe.maybe_free_model_hooks()
    return l_images
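The pruning schedule is driven entirely by get_next_size: after each denoising step the candidate pool is multiplied by keep_ratio = 1 - pruning_ratio (rounded up) and clamped at output_group_size. The sketch below replays that schedule for the two default configurations; it re-implements the same two-line rule inline rather than importing the module, so it runs without the diffusers/gurobipy dependencies.

import math

def next_size(curr, final, keep_ratio):
    # Same rule as get_next_size above: shrink by keep_ratio, never below final.
    return curr if curr <= final else max(math.ceil(curr * keep_ratio), final)

def schedule(start, final, pruning_ratio, steps):
    sizes, curr = [start], start
    for _ in range(steps):
        curr = next_size(curr, final, 1.0 - pruning_ratio)
        sizes.append(curr)
    return sizes

# flux-schnell defaults: 32 candidates, pruning_ratio 0.9, 4 steps -> [32, 4, 4, 4, 4]
print(schedule(32, 4, 0.9, 4))
# flux-dev defaults: 128 candidates, pruning_ratio 0.5, 20 steps
# -> 128, 64, 32, 16, 8, 4, then stays at 4 for the remaining steps
print(schedule(128, 4, 0.5, 20))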
my_utils/scores.py ADDED
@@ -0,0 +1,221 @@
import functools
import torch
import numpy as np
import torch.nn.functional as F
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
import cv2

from transformers import CLIPProcessor, CLIPModel, AutoModel
from transformers.models.clip.modeling_clip import _get_vector_norm


def validate_tensor_list(tensor_list, expected_type=torch.Tensor, min_dims=None, value_range=None, tolerance=0.1):
    """
    Validates a list of tensors with specified requirements.

    Args:
        tensor_list: List to validate
        expected_type: Expected type of each element (default: torch.Tensor)
        min_dims: Minimum number of dimensions each tensor should have
        value_range: Tuple of (min_val, max_val) for tensor values
        tolerance: Tolerance for value range checking (default: 0.1)
    """
    if not isinstance(tensor_list, list):
        raise TypeError(f"Input must be a list, got {type(tensor_list)}")

    if len(tensor_list) == 0:
        raise ValueError("Input list cannot be empty")

    for i, item in enumerate(tensor_list):
        if not isinstance(item, expected_type):
            raise TypeError(f"List element [{i}] must be {expected_type}, got {type(item)}")

        if min_dims is not None and len(item.shape) < min_dims:
            raise ValueError(f"List element [{i}] must have at least {min_dims} dimensions, got shape {item.shape}")

        if value_range is not None:
            min_val, max_val = value_range
            item_min, item_max = item.min().item(), item.max().item()
            if item_min < (min_val - tolerance) or item_max > (max_val + tolerance):
                raise ValueError(f"List element [{i}] values must be in range [{min_val}, {max_val}], got range [{item_min:.3f}, {item_max:.3f}]")


def build_score_fn(name, device="cuda"):
    """Build scoring functions for image quality and diversity assessment.

    Args:
        name: Score function name (clip_text_img, diversity_dino, dino_cls_pairwise, diversity_clip)
        device: Device to load models on (default: cuda)
    """
    d_score_nets = {}

    if name == "clip_text_img":
        m_clip = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
        prep_clip = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
        score_fn = functools.partial(unary_clip_text_img_t, device=device, m_clip=m_clip, preprocess_clip=prep_clip)
        d_score_nets["m_clip"] = m_clip
        d_score_nets["prep_clip"] = prep_clip

    elif name == "diversity_dino":
        dino_model = AutoModel.from_pretrained('facebook/dinov2-base').to(device)
        score_fn = functools.partial(binary_dino_pairwise_t, device=device, dino_model=dino_model)
        d_score_nets["dino_model"] = dino_model

    elif name == "dino_cls_pairwise":
        dino_model = AutoModel.from_pretrained('facebook/dinov2-base').to(device)
        score_fn = functools.partial(binary_dino_cls_pairwise_t, device=device, dino_model=dino_model)
        d_score_nets["dino_model"] = dino_model

    elif name == "diversity_clip":
        m_clip = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
        prep_clip = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
        score_fn = functools.partial(binary_clip_pairwise_t, device=device, m_clip=m_clip, preprocess_clip=prep_clip)
        d_score_nets["m_clip"] = m_clip
        d_score_nets["prep_clip"] = prep_clip

    else:
        raise ValueError(f"Invalid score function name: {name}")

    return score_fn, d_score_nets


@torch.no_grad()
def unary_clip_text_img_t(l_images, device, m_clip, preprocess_clip, target_caption, d_cache=None):
    """Compute CLIP text-image similarity scores for a list of images.

    Args:
        l_images: List of image tensors in range [-1, 1]
        device: Device for computation
        m_clip: CLIP model
        preprocess_clip: CLIP processor
        target_caption: Text prompt for similarity comparison
        d_cache: Optional cached text embeddings
    """
    # validate input images, l_images should be a list of torch tensors with range [-1, 1]
    validate_tensor_list(l_images, expected_type=torch.Tensor, min_dims=3, value_range=(-1, 1))

    _img_std = torch.tensor([0.26862954, 0.26130258, 0.27577711]).view(1, 3, 1, 1).to(device)
    _img_mean = torch.tensor([0.48145466, 0.4578275, 0.40821073]).view(1, 3, 1, 1).to(device)
    b_images = torch.cat(l_images, dim=0)
    b_images = F.interpolate(b_images, size=(224, 224), mode="bilinear", align_corners=False)
    # re-normalize with clip mean and std
    b_images = b_images * 0.5 + 0.5
    b_images = (b_images - _img_mean) / _img_std

    if d_cache is None:
        text_encoding = preprocess_clip.tokenizer(target_caption, return_tensors="pt", padding=True).to(device)
        output = m_clip(pixel_values=b_images, **text_encoding).logits_per_image / m_clip.logit_scale.exp()
        _score = output.view(-1).cpu().numpy()
    else:
        # compute with cached text embeddings
        vision_outputs = m_clip.vision_model(pixel_values=b_images, output_attentions=False, output_hidden_states=False,
                                             interpolate_pos_encoding=False, return_dict=True,)
        image_embeds = m_clip.visual_projection(vision_outputs[1])
        image_embeds = image_embeds / _get_vector_norm(image_embeds)
        text_embeds = d_cache["text_embeds"]
        _score = torch.matmul(text_embeds, image_embeds.t().to(text_embeds.device)).t().view(-1).cpu().numpy()

    return _score


@torch.no_grad()
def binary_dino_pairwise_t(l_images, device, dino_model):
    """Compute pairwise diversity scores using DINO patch features.

    Args:
        l_images: List of image tensors in range [-1, 1]
        device: Device for computation
        dino_model: DINO model for feature extraction
    """
    # validate input images, l_images should be a list of torch tensors with range [-1, 1]
    validate_tensor_list(l_images, expected_type=torch.Tensor, min_dims=3, value_range=(-1, 1))

    b_images = torch.cat(l_images, dim=0)
    _img_mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(device)
    _img_std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(device)

    b_images = F.interpolate(b_images, size=(256, 256), mode="bilinear", align_corners=False)
    b_images = b_images * 0.5 + 0.5
    b_images = (b_images - _img_mean) / _img_std
    all_features = dino_model(pixel_values=b_images).last_hidden_state[:, 1:, :].cpu()  # B, 324, 768

    N = len(l_images)
    score_matrix = np.zeros((N, N))
    for i in range(N):
        f1 = all_features[i]
        for j in range(i+1, N):
            f2 = all_features[j]
            cos_sim = (1 - F.cosine_similarity(f1, f2, dim=1)).mean().item()
            score_matrix[i, j] = cos_sim
    return score_matrix


@torch.no_grad()
def binary_dino_cls_pairwise_t(l_images, device, dino_model):
    """Compute pairwise diversity scores using DINO CLS token features.

    Args:
        l_images: List of image tensors in range [-1, 1]
        device: Device for computation
        dino_model: DINO model for feature extraction
    """
    # validate input images, l_images should be a list of torch tensors with range [-1, 1]
    validate_tensor_list(l_images, expected_type=torch.Tensor, min_dims=3, value_range=(-1, 1))

    b_images = torch.cat(l_images, dim=0)
    _img_mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(device)
    _img_std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(device)

    b_images = F.interpolate(b_images, size=(256, 256), mode="bilinear", align_corners=False)
    b_images = b_images * 0.5 + 0.5
    b_images = (b_images - _img_mean) / _img_std
    all_features = dino_model(pixel_values=b_images).last_hidden_state[:, 0:1, :].cpu()  # B, 1, 768

    N = len(l_images)
    score_matrix = np.zeros((N, N))
    for i in range(N):
        f1 = all_features[i]
        for j in range(i+1, N):
            f2 = all_features[j]
            cos_sim = (1 - F.cosine_similarity(f1, f2, dim=1)).mean().item()
            score_matrix[i, j] = cos_sim
    return score_matrix


@torch.no_grad()
def binary_clip_pairwise_t(l_images, device, m_clip, preprocess_clip):
    """Compute pairwise diversity scores using CLIP image embeddings.

    Args:
        l_images: List of image tensors in range [-1, 1]
        device: Device for computation
        m_clip: CLIP model
        preprocess_clip: CLIP processor
    """
    # validate input images, l_images should be a list of torch tensors with range [-1, 1]
    validate_tensor_list(l_images, expected_type=torch.Tensor, min_dims=3, value_range=(-1, 1))

    _img_std = torch.tensor([0.26862954, 0.26130258, 0.27577711]).view(1, 3, 1, 1).to(device)
    _img_mean = torch.tensor([0.48145466, 0.4578275, 0.40821073]).view(1, 3, 1, 1).to(device)
    b_images = torch.cat(l_images, dim=0)
    b_images = F.interpolate(b_images, size=(224, 224), mode="bilinear", align_corners=False)
    # re-normalize with clip mean and std
    b_images = b_images * 0.5 + 0.5
    b_images = (b_images - _img_mean) / _img_std

    vision_outputs = m_clip.vision_model(pixel_values=b_images, output_attentions=False, output_hidden_states=False,
                                         interpolate_pos_encoding=False, return_dict=True,)
    image_embeds = m_clip.visual_projection(vision_outputs[1])
    image_embeds = image_embeds / _get_vector_norm(image_embeds)

    N = len(l_images)
    score_matrix = np.zeros((N, N))
    for i in range(N):
        f1 = image_embeds[i]
        for j in range(i+1, N):
            f2 = image_embeds[j]
            cos_sim = (1 - torch.dot(f1, f2)).item()
            score_matrix[i, j] = cos_sim
    return score_matrix
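All binary scoring functions in this file return an N x N upper-triangular matrix: entry (i, j) for i < j is the mean cosine distance between the features of images i and j, while the diagonal and lower triangle stay zero and are ignored by the solver. A small sketch of that convention, with random tensors standing in for the DINO patch features:

import numpy as np
import torch
import torch.nn.functional as F

N, tokens, dim = 4, 16, 768
# Random features standing in for dino_model(...).last_hidden_state[:, 1:, :]
feats = torch.randn(N, tokens, dim)

score_matrix = np.zeros((N, N))
for i in range(N):
    for j in range(i + 1, N):
        # Mean (1 - cosine similarity) over matching token positions,
        # the same quantity computed in binary_dino_pairwise_t above.
        score_matrix[i, j] = (1 - F.cosine_similarity(feats[i], feats[j], dim=1)).mean().item()

print(np.round(score_matrix, 3))  # only the upper triangle holds pairwise diversity scores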
my_utils/solvers.py ADDED
@@ -0,0 +1,33 @@
from gurobipy import Model, GRB, quicksum


def gurobi_solver(u, D, n_select, lam=1.0, time_limit=5.0):
    """Solve quadratic integer programming problem for subset selection with unary and pairwise terms.

    Args:
        u: Unary scores for each item
        D: Pairwise similarity matrix (upper triangular)
        n_select: Number of items to select
        lam: Weight for pairwise term (default: 1.0)
        time_limit: Solver time limit in seconds (default: 5.0)
    """
    n = len(u)
    model = Model()
    model.Params.LogToConsole = 0
    model.Params.TimeLimit = time_limit
    model.Params.OutputFlag = 0

    # Variables: x[i] in {0,1}
    x = model.addVars(n, vtype=GRB.BINARY, name="x")
    # Constraint: exactly k items selected
    model.addConstr(quicksum(x[i] for i in range(n)) == n_select, name="select_k")

    # Objective: sum of unary + lambda * pairwise
    linear_part = quicksum(u[i] * x[i] for i in range(n))
    quadratic_part = quicksum(lam * D[i, j] * x[i] * x[j] for i in range(n) for j in range(i + 1, n))

    model.setObjective(linear_part + quadratic_part, GRB.MAXIMIZE)

    model.optimize()
    selected_indices = [i for i in range(n) if x[i].X > 0.5]
    return selected_indices
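For tiny instances the same objective can be checked by brute force, which is handy when a Gurobi license is unavailable. The helper below is not part of the repository; it enumerates all subsets of size n_select and maximizes the identical unary + lambda * pairwise objective, so on small inputs it should agree with gurobi_solver (up to ties).

from itertools import combinations
import numpy as np

def brute_force_solver(u, D, n_select, lam=1.0):
    """Hypothetical reference implementation of the same subset-selection objective."""
    n = len(u)
    best_score, best_subset = -np.inf, None
    for subset in combinations(range(n), n_select):
        unary = sum(u[i] for i in subset)
        # D is upper triangular, and combinations yields (i, j) with i < j.
        pairwise = sum(D[i, j] for i, j in combinations(subset, 2))
        score = unary + lam * pairwise
        if score > best_score:
            best_score, best_subset = score, list(subset)
    return best_subset

# Toy example: 5 candidates, pick 2.
u = np.array([0.9, 0.8, 0.7, 0.6, 0.5])      # unary (quality) scores
D = np.triu(np.random.rand(5, 5), k=1)        # pairwise (diversity) scores
print(brute_force_solver(u, D, n_select=2, lam=1.0))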
requirements.txt ADDED
@@ -0,0 +1,14 @@
torch==2.7.1
torchvision==0.22.1
torchaudio==2.7.1
opencv-python
transformers
sentencepiece
protobuf
accelerate
diffusers==0.35.1
gurobipy
bitsandbytes
git+https://github.com/openai/CLIP.git
ipdb
https://github.com/nunchaku-tech/nunchaku/releases/download/v0.3.1/nunchaku-0.3.1+torch2.7-cp310-cp310-linux_x86_64.whl
styles.css ADDED
@@ -0,0 +1,160 @@
@import url('https://fonts.googleapis.com/css2?family=Varela+Round&display=swap');

::selection {
    background: rgba(255, 251, 35, 0.58);
}

body {
    max-width: 1000px !important;
    margin: 0 auto !important;
}

.title_left {
    padding-top: 0vw !important;
    filter: none !important;
}
.title_left > h1 {
    color: #2f2f2f !important;
    font-family: "Gelasio",Georgia,serif !important;
    font-weight: normal !important;
    font-size: 2.0vw !important;
    text-align: center !important;
}

.title_left > h2 {
    color: #2f2f2f !important;
    font-family: "Gelasio",Georgia,serif !important;
    font-weight: normal !important;
    font-size: 1.5vw !important;
    text-align: center !important;
}

.author-container {
    color: #2f2f2f;
    font-family: Gelasio,"Avenir Next",Helvetica,sans-serif;
    font-weight: normal;
    font-size: 1vw;
    padding-top: 0.2vw;
    justify-items: center;
    justify-content: center;
    display: grid;
    grid-template-columns: auto auto auto auto auto;
}

.affiliation-container {
    color: #2f2f2f;
    font-family: Gelasio,"Avenir Next",Helvetica,sans-serif;
    font-weight: normal;
    font-size: 1vw;
    padding-top: 0.2vw;
    justify-items: center;
    justify-content: center;
    display: grid;
    grid-template-columns: auto auto auto auto auto;
}

.grid-item {
    text-align: center;
    padding-right: 0.7vw;
    padding-left: 0.7vw;
}

.grid-item > a {
    color: #2f2f2f;
    text-decoration: underline;
    text-underline-offset: 3px;
}

.grid-item.cmu > a {
    text-decoration-color: rgba(196, 18, 48, 0.2)
}

.grid-item.cmu > p::before {
    content: "";
    display: inline-block;
    width: 12px;
    height: 12px;
    background-color: rgba(196, 18, 48, 0.6);
    margin-right: 8px;
    vertical-align: middle;
}

.grid-item.snap > a {
    text-decoration-color: rgba(255,252,0, 0.4)
}

.grid-item.snap > p::before {
    content: "";
    display: inline-block;
    width: 12px;
    height: 12px;
    background-color: rgba(255, 252, 0, 0.6);
    margin-right: 8px;
    vertical-align: middle;
}

.grid-item.cmu > p,
.grid-item.snap > p {
    color: #2f2f2f;
}

.column {
    min-width: min(100px, 100%) !important;
}

.block {
    min-width: min(100px, 100%) !important;
    background: transparent !important;
    border: none !important;
}

.gr-box, .gr-form, .gr-panel {
    background: transparent !important;
    border: none !important;
}

.gr-row, .gr-column {
    background: transparent !important;
}

.gr-textbox, .gr-number, .gr-slider, .gr-dropdown {
    background: rgba(255, 255, 255, 0.1) !important;
    border: 1px solid rgba(255, 255, 255, 0.2) !important;
    backdrop-filter: blur(10px) !important;
}

.gr-button {
    background: rgba(255, 255, 255, 0.15) !important;
    border: 1px solid rgba(255, 255, 255, 0.3) !important;
    backdrop-filter: blur(10px) !important;
}

.gr-accordion {
    background: rgba(255, 255, 255, 0.05) !important;
    border: 1px solid rgba(255, 255, 255, 0.1) !important;
    backdrop-filter: blur(5px) !important;
}

.gr-gallery {
    background: rgba(255, 255, 255, 0.05) !important;
    border: 1px solid rgba(255, 255, 255, 0.1) !important;
    backdrop-filter: blur(5px) !important;
}

#starting-candidates-row > #component-7 {
    border: none !important;
    /* font-family: "Varela Round" !important;
    font-weight: 500 !important; */
}

#output-group-size-row > #component-10 {
    border: none !important;
    /* font-family: "Varela Round" !important; */
    /* font-weight: 500 !important; */
}

#pruning-ratio-label, #lambda-label, #seed-label, #unary-term-label, #binary-term-label {
    border: none !important;
    /* font-family: "Varela Round" !important;
    font-weight: 500 !important; */
}