bmarci committed on
Commit
1feed0d
·
1 Parent(s): 35c1a87

More precise GPU allocation

Browse files
Files changed (1) hide show
  1. app.py +158 -165
app.py CHANGED
@@ -44,23 +44,11 @@ def _ensure_pil(x):
44
  raise TypeError("Unsupported image type returned by pipeline.")
45
 
46
 
47
- @spaces.GPU(duration=300)
48
- def infer(
49
- prompt=None,
50
- seed=0,
51
- width=512,
52
- height=512,
53
- num_inference_steps=28,
54
- cfg=DEFAULT_CFG,
55
- positive_prompt=DEFAULT_POSITIVE_PROMPT,
56
- negative_prompt=DEFAULT_NEGATIVE_PROMPT,
57
- progress=gr.Progress(track_tqdm=True),
58
- ):
59
- """Run inference at exactly (width, height)."""
60
  if prompt in [None, ""]:
61
  gr.Warning("⚠️ Please enter a prompt!")
62
  return None
63
-
64
  with autocast(device_type=("cuda" if device == "cuda" else "cpu"), dtype=torch.bfloat16):
65
  imgs = pipeline.generate_image(
66
  prompt,
@@ -77,15 +65,107 @@ def infer(
77
  seed=int(seed),
78
  progress=True,
79
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
 
81
- return _ensure_pil(imgs[0]) # Return raw output exactly as generated
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
 
84
  css = """
85
  #col-container {
86
  margin: 0 auto;
87
  max-width: 800px;
88
  }
 
 
 
 
89
  """
90
 
91
  with gr.Blocks(css=css) as demo:
@@ -93,178 +173,91 @@ with gr.Blocks(css=css) as demo:
93
  gr.Markdown("# NextStep-1-Large — Image generation")
94
 
95
  with gr.Row():
96
- prompt = gr.Text(
97
- label="Prompt",
98
- show_label=False,
99
- max_lines=2,
100
- placeholder="Enter your prompt",
101
- container=False,
102
- )
103
  run_button = gr.Button("Run", scale=0, variant="primary")
104
  cancel_button = gr.Button("Cancel", scale=0, variant="secondary")
105
 
106
  with gr.Row():
107
  with gr.Accordion("Advanced Settings", open=True):
108
- positive_prompt = gr.Text(
109
- label="Positive Prompt",
110
- show_label=True,
111
- max_lines=1,
112
- placeholder="Optional: add positives",
113
- container=True,
114
- )
115
- negative_prompt = gr.Text(
116
- label="Negative Prompt",
117
- show_label=True,
118
- max_lines=2,
119
- placeholder="Optional: add negatives",
120
- container=True,
121
- )
122
  with gr.Row():
123
- seed = gr.Slider(
124
- label="Seed",
125
- minimum=0,
126
- maximum=MAX_SEED,
127
- step=1,
128
- value=3407,
129
- )
130
- num_inference_steps = gr.Slider(
131
- label="Sampling steps",
132
- minimum=10,
133
- maximum=50,
134
- step=1,
135
- value=28,
136
- )
137
  with gr.Row():
138
- width = gr.Slider(
139
- label="Width",
140
- minimum=256,
141
- maximum=512,
142
- step=64,
143
- value=512,
144
- )
145
- height = gr.Slider(
146
- label="Height",
147
- minimum=256,
148
- maximum=512,
149
- step=64,
150
- value=512,
151
- )
152
- cfg = gr.Slider(
153
- label="CFG (guidance scale)",
154
- minimum=0.0,
155
- maximum=20.0,
156
- step=0.5,
157
- value=DEFAULT_CFG,
158
- info="Higher = closer to text, lower = more creative",
159
- )
160
 
161
  with gr.Row():
162
- result_1 = gr.Image(
163
- label="Result",
164
- show_label=True,
165
- container=True,
166
- interactive=False,
167
- format="png",
168
- )
 
 
169
 
170
  examples = [
171
  [
172
  "Studio portrait of an elderly sailor with a weathered face, dramatic Rembrandt lighting, shallow depth of field",
173
  101, 512, 512, 32, 7.5,
174
  "photorealistic, sharp eyes, detailed skin texture, soft rim light, 85mm lens",
175
- "over-smoothed skin, plastic look, extra limbs, watermark",
176
- ],
177
- [
178
- "Isometric cozy coffee shop interior with hanging plants and warm Edison bulbs",
179
- 202, 512, 384, 30, 8.5,
180
- "isometric view, clean lines, stylized, warm ambience, detailed furniture",
181
- "text, logo, watermark, perspective distortion",
182
- ],
183
- [
184
- "Ultra-wide desert canyon at golden hour with long shadows and dust in the air",
185
- 303, 512, 320, 28, 7.0,
186
- "cinematic, volumetric light, natural colors, high dynamic range",
187
- "over-saturated, haze artifacts, blown highlights",
188
- ],
189
- [
190
- "Cute red panda astronaut sticker, chibi style, white background",
191
- 404, 384, 384, 24, 9.0,
192
- "vector look, bold outlines, high contrast, die-cut silhouette",
193
- "background clutter, drop shadow, gradients, text",
194
- ],
195
- [
196
- "Product render of matte-black wireless headphones on reflective glass with soft studio lighting",
197
- 505, 512, 384, 28, 7.0,
198
- "clean backdrop, realistic reflections, subtle bloom, high detail",
199
- "noise, fingerprints, text, label",
200
- ],
201
- [
202
- "Graphic poster in Bauhaus style with geometric shapes and bold typography placeholders",
203
- 606, 512, 512, 22, 6.0,
204
- "flat colors, minimal palette, crisp edges, balanced composition",
205
- "photo realism, gradients, noisy texture",
206
- ],
207
- [
208
- "Oil painting of a stormy sea with a lighthouse, thick impasto brushwork",
209
- 707, 384, 512, 34, 7.0,
210
- "textured canvas, visible brush strokes, dramatic sky, moody lighting",
211
- "smooth digital look, airbrush, neon colors",
212
- ],
213
- [
214
- "Architectural concept art: glass pavilion in a pine forest at dawn, ground fog",
215
- 808, 512, 384, 30, 8.0,
216
- "physically-based rendering, soft fog, realistic materials, scale figures",
217
- "tilt, skew, warped geometry, chromatic aberration",
218
- ],
219
- [
220
- "Fantasy creature: bioluminescent jellyfish dragon swimming through a dark ocean trench",
221
- 909, 512, 512, 32, 8.5,
222
- "glowing tendrils, soft caustics, particles, high detail",
223
- "washed out, murky, low contrast, extra heads",
224
- ],
225
- [
226
- "Line art coloring page of a city skyline with hot air balloons",
227
- 111, 512, 512, 18, 5.5,
228
- "clean black outlines, uniform stroke weight, high contrast, no shading",
229
- "gray fill, gradients, cross-hatching, text",
230
- ],
231
  ]
232
 
233
  gr.Examples(
234
  examples=examples,
235
- inputs=[
236
- prompt,
237
- seed,
238
- width,
239
- height,
240
- num_inference_steps,
241
- cfg,
242
- positive_prompt,
243
- negative_prompt,
244
- ],
245
  label="Click & Fill Examples (Exact Size)",
246
  )
247
 
248
- def show_result():
249
- return gr.update(visible=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
250
 
251
- generation_event = gr.on(
252
- triggers=[run_button.click, prompt.submit],
253
- fn=infer,
254
- inputs=[
255
- prompt,
256
- seed,
257
- width,
258
- height,
259
- num_inference_steps,
260
- cfg,
261
- positive_prompt,
262
- negative_prompt,
263
- ],
264
- outputs=[result_1],
265
- )
266
-
267
- cancel_button.click(fn=None, inputs=None, outputs=None, cancels=[generation_event])
268
 
269
  if __name__ == "__main__":
270
- demo.launch()
 
44
  raise TypeError("Unsupported image type returned by pipeline.")
45
 
46
 
47
+ def infer_core(prompt, seed, width, height, num_inference_steps, cfg, positive_prompt, negative_prompt, progress):
48
+ """Core inference logic without GPU decorators."""
 
 
 
 
 
 
 
 
 
 
 
49
  if prompt in [None, ""]:
50
  gr.Warning("⚠️ Please enter a prompt!")
51
  return None
 
52
  with autocast(device_type=("cuda" if device == "cuda" else "cpu"), dtype=torch.bfloat16):
53
  imgs = pipeline.generate_image(
54
  prompt,
 
65
  seed=int(seed),
66
  progress=True,
67
  )
68
+ return _ensure_pil(imgs[0])
69
+
70
+
71
+ # Tier 1: Very small images with few steps
72
+ @spaces.GPU(duration=90)
73
+ def infer_tiny(prompt=None, seed=0, width=512, height=512, num_inference_steps=24, cfg=DEFAULT_CFG,
74
+ positive_prompt=DEFAULT_POSITIVE_PROMPT, negative_prompt=DEFAULT_NEGATIVE_PROMPT,
75
+ progress=gr.Progress(track_tqdm=True)):
76
+ return infer_core(prompt, seed, width, height, num_inference_steps, cfg, positive_prompt, negative_prompt, progress)
77
+
78
+
79
+ # Tier 2: Small to medium images with standard steps
80
+ @spaces.GPU(duration=150)
81
+ def infer_fast(prompt=None, seed=0, width=512, height=512, num_inference_steps=24, cfg=DEFAULT_CFG,
82
+ positive_prompt=DEFAULT_POSITIVE_PROMPT, negative_prompt=DEFAULT_NEGATIVE_PROMPT,
83
+ progress=gr.Progress(track_tqdm=True)):
84
+ return infer_core(prompt, seed, width, height, num_inference_steps, cfg, positive_prompt, negative_prompt, progress)
85
+
86
+
87
+ # Tier 3: Standard generation for most common cases
88
+ @spaces.GPU(duration=200)
89
+ def infer_std(prompt=None, seed=0, width=512, height=512, num_inference_steps=28, cfg=DEFAULT_CFG,
90
+ positive_prompt=DEFAULT_POSITIVE_PROMPT, negative_prompt=DEFAULT_NEGATIVE_PROMPT,
91
+ progress=gr.Progress(track_tqdm=True)):
92
+ return infer_core(prompt, seed, width, height, num_inference_steps, cfg, positive_prompt, negative_prompt, progress)
93
+
94
+
95
+ # Tier 4: Larger images or more steps
96
+ @spaces.GPU(duration=300)
97
+ def infer_long(prompt=None, seed=0, width=512, height=512, num_inference_steps=36, cfg=DEFAULT_CFG,
98
+ positive_prompt=DEFAULT_POSITIVE_PROMPT, negative_prompt=DEFAULT_NEGATIVE_PROMPT,
99
+ progress=gr.Progress(track_tqdm=True)):
100
+ return infer_core(prompt, seed, width, height, num_inference_steps, cfg, positive_prompt, negative_prompt, progress)
101
+
102
+
103
+ # Tier 5: Maximum quality with many steps
104
+ @spaces.GPU(duration=400)
105
+ def infer_max(prompt=None, seed=0, width=512, height=512, num_inference_steps=45, cfg=DEFAULT_CFG,
106
+ positive_prompt=DEFAULT_POSITIVE_PROMPT, negative_prompt=DEFAULT_NEGATIVE_PROMPT,
107
+ progress=gr.Progress(track_tqdm=True)):
108
+ return infer_core(prompt, seed, width, height, num_inference_steps, cfg, positive_prompt, negative_prompt, progress)
109
+
110
+
111
+ # JS dispatcher: scores the request from width, height and steps, then clicks the hidden button wired to the matching GPU-duration tier
112
+ js_dispatch = """
113
+ function(width, height, steps){
114
+ const w = Number(width);
115
+ const h = Number(height);
116
+ const s = Number(steps);
117
+
118
+ // Calculate total pixels and complexity score
119
+ const pixels = w * h;
120
+ const megapixels = pixels / 1000000;
121
+
122
+ // Complexity score combines image size and steps
123
+ // Base: ~0.5 seconds per megapixel per step
124
+ const complexity = megapixels * s;
125
+
126
+ let target = 'btn-std'; // Default
127
 
128
+ // Select appropriate tier based on complexity
129
+ if (pixels <= 256*256 && s <= 20) {
130
+ // Very small images with few steps
131
+ target = 'btn-tiny';
132
+ } else if (complexity < 5) {
133
+ // Small images or few steps (e.g., 384x384 @ 24 steps = 3.5)
134
+ target = 'btn-fast';
135
+ } else if (complexity < 8) {
136
+ // Standard generation (e.g., 512x512 @ 28 steps = 7.3)
137
+ target = 'btn-std';
138
+ } else if (complexity < 12) {
139
+ // Larger or more steps (e.g., 512x512 @ 40 steps = 10.5)
140
+ target = 'btn-long';
141
+ } else {
142
+ // Maximum complexity
143
+ target = 'btn-max';
144
+ }
145
 
146
+ // Special cases: override based on extreme values
147
+ if (s >= 45) {
148
+ target = 'btn-max'; // Many steps always need more time
149
+ } else if (pixels >= 512*512 && s >= 35) {
150
+ target = 'btn-long'; // Large images with many steps
151
+ }
152
+
153
+ console.log(`Resolution: ${w}x${h}, Steps: ${s}, Complexity: ${complexity.toFixed(2)}, Selected: ${target}`);
154
+
155
+ const b = document.getElementById(target);
156
+ if (b) b.click();
157
+ }
158
+ """
159
 
160
  css = """
161
  #col-container {
162
  margin: 0 auto;
163
  max-width: 800px;
164
  }
165
+ /* Hide the dispatcher buttons */
166
+ #btn-tiny, #btn-fast, #btn-std, #btn-long, #btn-max {
167
+ display: none !important;
168
+ }
169
  """
170
 
171
  with gr.Blocks(css=css) as demo:
 
173
  gr.Markdown("# NextStep-1-Large — Image generation")
174
 
175
  with gr.Row():
176
+ prompt = gr.Text(label="Prompt", show_label=False, max_lines=2, placeholder="Enter your prompt",
177
+ container=False)
 
 
 
 
 
178
  run_button = gr.Button("Run", scale=0, variant="primary")
179
  cancel_button = gr.Button("Cancel", scale=0, variant="secondary")
180
 
181
  with gr.Row():
182
  with gr.Accordion("Advanced Settings", open=True):
183
+ positive_prompt = gr.Text(label="Positive Prompt", show_label=True,
184
+ placeholder="Optional: add positives")
185
+ negative_prompt = gr.Text(label="Negative Prompt", show_label=True,
186
+ placeholder="Optional: add negatives")
 
 
 
 
 
 
 
 
 
 
187
  with gr.Row():
188
+ seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=3407)
189
+ num_inference_steps = gr.Slider(label="Sampling steps", minimum=10, maximum=50, step=1, value=28)
 
 
 
 
 
 
 
 
 
 
 
 
190
  with gr.Row():
191
+ width = gr.Slider(label="Width", minimum=256, maximum=512, step=64, value=512)
192
+ height = gr.Slider(label="Height", minimum=256, maximum=512, step=64, value=512)
193
+ cfg = gr.Slider(label="CFG (guidance scale)", minimum=0.0, maximum=20.0, step=0.5, value=DEFAULT_CFG,
194
+ info="Higher = closer to text, lower = more creative")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
195
 
196
  with gr.Row():
197
+ result_1 = gr.Image(label="Result", format="png", interactive=False)
198
+
199
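+ # Hidden dispatcher buttons, clicked from js_dispatch via their elem_id. NOTE(review): visible=False may keep them out of the DOM in some Gradio versions (getElementById would then return null and Run would do nothing) — confirm; the CSS rule above already hides them.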
200
+ with gr.Row(visible=False):
201
+ btn_tiny = gr.Button(visible=False, elem_id="btn-tiny")
202
+ btn_fast = gr.Button(visible=False, elem_id="btn-fast")
203
+ btn_std = gr.Button(visible=False, elem_id="btn-std")
204
+ btn_long = gr.Button(visible=False, elem_id="btn-long")
205
+ btn_max = gr.Button(visible=False, elem_id="btn-max")
206
 
207
  examples = [
208
  [
209
  "Studio portrait of an elderly sailor with a weathered face, dramatic Rembrandt lighting, shallow depth of field",
210
  101, 512, 512, 32, 7.5,
211
  "photorealistic, sharp eyes, detailed skin texture, soft rim light, 85mm lens",
212
+ "over-smoothed skin, plastic look, extra limbs, watermark"],
213
+ ["Isometric cozy coffee shop interior with hanging plants and warm Edison bulbs",
214
+ 202, 512, 384, 30, 8.5,
215
+ "isometric view, clean lines, stylized, warm ambience, detailed furniture",
216
+ "text, logo, watermark, perspective distortion"],
217
+ ["Ultra-wide desert canyon at golden hour with long shadows and dust in the air",
218
+ 303, 512, 320, 28, 7.0,
219
+ "cinematic, volumetric light, natural colors, high dynamic range",
220
+ "over-saturated, haze artifacts, blown highlights"],
221
+ ["Oil painting of a stormy sea with a lighthouse, thick impasto brushwork",
222
+ 707, 384, 512, 34, 7.0,
223
+ "textured canvas, visible brush strokes, dramatic sky, moody lighting",
224
+ "smooth digital look, airbrush, neon colors"],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
225
  ]
226
 
227
  gr.Examples(
228
  examples=examples,
229
+ inputs=[prompt, seed, width, height, num_inference_steps, cfg, positive_prompt, negative_prompt],
 
 
 
 
 
 
 
 
 
230
  label="Click & Fill Examples (Exact Size)",
231
  )
232
 
233
+ # Wire up the dispatcher buttons to their respective functions
234
+ ev_tiny = btn_tiny.click(infer_tiny,
235
+ inputs=[prompt, seed, width, height, num_inference_steps, cfg, positive_prompt,
236
+ negative_prompt],
237
+ outputs=[result_1])
238
+ ev_fast = btn_fast.click(infer_fast,
239
+ inputs=[prompt, seed, width, height, num_inference_steps, cfg, positive_prompt,
240
+ negative_prompt],
241
+ outputs=[result_1])
242
+ ev_std = btn_std.click(infer_std,
243
+ inputs=[prompt, seed, width, height, num_inference_steps, cfg, positive_prompt,
244
+ negative_prompt],
245
+ outputs=[result_1])
246
+ ev_long = btn_long.click(infer_long,
247
+ inputs=[prompt, seed, width, height, num_inference_steps, cfg, positive_prompt,
248
+ negative_prompt],
249
+ outputs=[result_1])
250
+ ev_max = btn_max.click(infer_max,
251
+ inputs=[prompt, seed, width, height, num_inference_steps, cfg, positive_prompt,
252
+ negative_prompt],
253
+ outputs=[result_1])
254
 
255
+ # Trigger JS dispatcher on run button or prompt submit
256
+ run_button.click(None, inputs=[width, height, num_inference_steps], outputs=[], js=js_dispatch)
257
+ prompt.submit(None, inputs=[width, height, num_inference_steps], outputs=[], js=js_dispatch)
258
+
259
+ # Cancel button cancels all possible events
260
+ cancel_button.click(fn=None, inputs=None, outputs=None, cancels=[ev_tiny, ev_fast, ev_std, ev_long, ev_max])
 
 
 
 
 
 
 
 
 
 
 
261
 
262
  if __name__ == "__main__":
263
+ demo.launch()