import os
import spaces
import gradio as gr
import numpy as np
import torch
from PIL import Image

from ip_adapter import (
    ConceptrolIPAdapterPlus,
    ConceptrolIPAdapterPlusXL,
)
from ip_adapter.custom_pipelines import (
    StableDiffusionCustomPipeline,
    StableDiffusionXLCustomPipeline,
)
from omini_control.conceptrol import Conceptrol
from omini_control.flux_conceptrol_pipeline import FluxConceptrolPipeline


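# Disable tokenizer parallelism to avoid fork-related warnings from the tokenizers library.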
os.environ["TOKENIZERS_PARALLELISM"] = "false"

title = r"""
<h1 align="center">Conceptrol: Concept Control of Zero-shot Personalized Image Generation</h1>
"""

description = r"""
<b>Official 🤗 Gradio demo</b> for <a href='https://github.com/QY-H00/Conceptrol/tree/public' target='_blank'><b>Conceptrol: Concept Control of Zero-shot Personalized Image Generation</b></a>.<br>
How to use:<br>
1. Enter the text prompt, the reference image (visual specification), and the textual concept of the personalization target.
2. Choose your preferred base model; the first switch to a new model can take longer while its weights download.
3. Each inference takes about 20s for the SD series, 50s for the SDXL series, and 100s for FLUX.
4. Click the <b>Generate</b> button to enjoy! 😊
"""

article = r"""
---
✒️ **Citation**
<br>
If you find this demo or our paper useful, please consider citing:
```bibtex
@article{he2025conceptrol,
  title={Conceptrol: Concept Control of Zero-shot Personalized Image Generation},
  author={He, Qiyuan and Yao, Angela},
  journal={arXiv preprint arXiv:2403.17924},
  year={2024}
}
```
📧 **Contact**
<br>
If you have any questions, please feel free to open an issue in our <a href='https://github.com/QY-H00/Conceptrol/tree/public' target='_blank'><b>GitHub repo</b></a> or reach out to us directly at <b>[email protected]</b>.
"""

MAX_SEED = np.iinfo(np.int32).max
CACHE_EXAMPLES = False
USE_TORCH_COMPILE = False
ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD") == "1"
PREVIEW_IMAGES = False

# Default settings
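# Load SD1.5 (RealVis v5.1) with IP-Adapter Plus as the default pipeline;
# change_model_fn swaps it out when the user selects another base model.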
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
adapter_name = "h94/IP-Adapter/models/ip-adapter-plus_sd15.bin"
pipe = StableDiffusionCustomPipeline.from_pretrained(
    "SG161222/Realistic_Vision_V5.1_noVAE",
    torch_dtype=torch.bfloat16,
    feature_extractor=None,
    safety_checker=None
)
pipeline = ConceptrolIPAdapterPlus(pipe, "", adapter_name, device, num_tokens=16)

@spaces.GPU
def change_model_fn(model_name: str) -> None:
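    """Replace the global `pipeline` with the requested base model, freeing the old one first."""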
    global device, pipeline
    
    # Clear GPU memory
    if torch.cuda.is_available():
        if pipeline is not None:
            del pipeline
        torch.cuda.empty_cache()
        
    name_mapping = {
        "SD1.5-512": "stable-diffusion-v1-5/stable-diffusion-v1-5", 
        "AOM3 (SD-based)": "hogiahien/aom3",
        "RealVis-v5.1 (SD-based)": "SG161222/Realistic_Vision_V5.1_noVAE",
        "SDXL-1024": "stabilityai/stable-diffusion-xl-base-1.0",
        "RealVisXL-v5.0 (SDXL-based)": "SG161222/RealVisXL_V5.0",
        "Playground-XL-v2 (SDXL-based)": "playgroundai/playground-v2.5-1024px-aesthetic",
        "Animagine-XL-v4.0 (SDXL-based)": "cagliostrolab/animagine-xl-4.0",
        "FLUX-schnell": "black-forest-labs/FLUX.1-schnell"
    }
    if "XL" in model_name:
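        # SDXL family: pair the base model with the ViT-H IP-Adapter Plus weights for SDXL.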
        adapter_name = "h94/IP-Adapter/sdxl_models/ip-adapter-plus_sdxl_vit-h.safetensors"
        pipe = StableDiffusionXLCustomPipeline.from_pretrained(
            name_mapping[model_name],
            # variant="fp16",
            torch_dtype=torch.bfloat16,
            feature_extractor=None
        ).to(device)
        pipeline = ConceptrolIPAdapterPlusXL(pipe, "", adapter_name, device, num_tokens=16)
    
    elif "FLUX" in model_name:
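        # FLUX: subject conditioning comes from the OminiControl subject LoRA,
        # with Conceptrol loaded on top for concept control.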
        adapter_name = "Yuanshi/OminiControl"
        pipeline = FluxConceptrolPipeline.from_pretrained(
            name_mapping[model_name], torch_dtype=torch.bfloat16
        ).to(device)
        pipeline.load_lora_weights(
            adapter_name,
            weight_name="omini/subject_512.safetensors",
            adapter_name="subject",
        )
        config = {"name": "conceptrol"}
        conceptrol = Conceptrol(config)
        pipeline.load_conceptrol(conceptrol)
    
    elif model_name in name_mapping:
        # SD1.5 family: use the SD1.5 IP-Adapter Plus weights.
        adapter_name = "h94/IP-Adapter/models/ip-adapter-plus_sd15.bin"
        pipe = StableDiffusionCustomPipeline.from_pretrained(
            name_mapping[model_name],
            torch_dtype=torch.bfloat16,
            feature_extractor=None,
            safety_checker=None
        ).to(device)
        pipeline = ConceptrolIPAdapterPlus(pipe, "", adapter_name, device, num_tokens=16)
    else:
        raise KeyError(f"Unsupported model name: {model_name}")


def save_image(img, index):
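    """Save a numpy image array to disk as <index>.png and return the filename."""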
    unique_name = f"{index}.png"
    img = Image.fromarray(img)
    img.save(unique_name)
    return unique_name


def get_example() -> list[list]:
    """Example rows; the order must match the `inputs` list passed to gr.Examples."""
    case = [
        [
            "A high-resolution photograph of a serene cat, its fur softly illuminated by dappled sunlight, sitting amidst a lush, vibrant forest with rays of light filtering through the trees.",
            "cat",
            "",
            Image.open("demo/cat.jpg"),
            20,
            3.5,
            1.0,
            0.0,
            42,
            "FLUX-schnell"
        ],
        [
            "A statue is reading the book in the cafe, best quality, high quality",
            "statue",
            "deformed, ugly, wrong proportion, low res, bad anatomy, worst quality, low quality",
            Image.open("demo/statue.jpg"),
            50,
            6.0,
            1.0,
            0.2,
            42,
            "RealVis-v5.1 (SD-based)"
        ],
        [
            "A hyper-realistic, high-resolution photograph of an astronaut in a meticulously detailed space suit riding a majestic horse across an otherworldly landscape. The image features dynamic lighting, rich textures, and a cinematic atmosphere, capturing every intricate detail in stunning clarity.",
            "horse",
            "deformed, ugly, wrong proportion, low res, bad anatomy, worst quality, low quality",
            Image.open("demo/horse.jpg"),
            50,
            6.0,
            1.0,
            0.2,
            42,
            "RealVisXL-v5.0 (SDXL-based)"
        ],
        [
            "A man wearing a T-shirt walking on the street",
            "T-shirt",
            "",
            Image.open("demo/t-shirt.jpg"),
            20,
            3.5,
            1.0,
            0.0,
            42,
            "FLUX-schnell"
        ]
    ]
    return case


def change_generate_button_fn(enable: int) -> gr.Button:
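    """Return a disabled "Switching Model..." button when enable == 0, otherwise an active "Generate" button."""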
    if enable == 0:
        return gr.Button(interactive=False, value="Switching Model...")
    else:
        return gr.Button(interactive=True, value="Generate")


def dynamic_gallery_fn():
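    """Return a fresh result image component so the previous output is cleared before generation."""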
    return gr.Image(label="Result", show_label=False)

@spaces.GPU(duration=110)
@torch.no_grad()
def generate(
    prompt="a statue is reading the book in the cafe",
    subject="cat",
    negative_prompt="",
    image=None,
    num_inference_steps=20,
    guidance_scale=3.5,
    condition_scale=1.0,
    control_guidance_start=0.0,
    seed=0,
    model_name="RealVis-v5.1 (SD-based)"
) -> Image.Image:
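    """(Re)load the pipeline for `model_name`, then dispatch to the FLUX or IP-Adapter path."""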
    global pipeline
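    # Reload the pipeline on every request so the run always uses the model chosen in the dropdown.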
    change_model_fn(model_name)
    if isinstance(pipeline, FluxConceptrolPipeline):
        images = pipeline(
            prompt=prompt,
            image=image,
            subject=subject,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            condition_scale=condition_scale,
            control_guidance_start=control_guidance_start,
            height=512,
            width=512,
            seed=seed,
        ).images[0]
    elif isinstance(pipeline, (ConceptrolIPAdapterPlus, ConceptrolIPAdapterPlusXL)):
        with torch.autocast(device_type="cuda"):
            images = pipeline.generate(
                prompt=prompt,
                pil_images=[image],
                subjects=[subject],
                num_samples=1,
                num_inference_steps=num_inference_steps,
                scale=condition_scale,
                negative_prompt=negative_prompt,
                control_guidance_start=control_guidance_start,
                seed=seed,
            )[0]
    else:
        raise TypeError("Unsupported Pipeline")

    return images

with gr.Blocks(css="style.css") as demo:
    gr.Markdown(title)
    gr.Markdown(description)
    with gr.Row(elem_classes="grid-container"):
        with gr.Group():
            with gr.Row(elem_classes="flex-grow"):
                with gr.Column(elem_classes="grid-item"):  # Left column
                    prompt = gr.Text(
                        label="Prompt",
                        max_lines=3,
                        placeholder="Enter the Descriptive Prompt",
                        interactive=True,
                        value="A statue is reading the book in the cafe, best quality, high quality",
                    )
                    textual_concept = gr.Text(
                        label="Textual Concept",
                        max_lines=3,
                        placeholder="Enter the Textual Concept required customization",
                        interactive=True,
                        value="statue",
                    )
                    negative_prompt = gr.Text(
                        label="Negative prompt",
                        max_lines=3,
                        placeholder="Enter a Negative Prompt",
                        interactive=True,
                        value="deformed, ugly, wrong proportion, low res, bad anatomy, worst quality, low quality"
                    )
        
        with gr.Row(elem_classes="flex-grow"):
            image_prompt = gr.Image(
                label="Reference Image for customization",
                interactive=True,
                height=280,
                type="pil",
            )

        with gr.Group():
            with gr.Column(elem_classes="grid-item"):  # Right column
                with gr.Row(elem_classes="flex-grow"):
                    
                    with gr.Group():
                        # result = gr.Gallery(label="Result", show_label=False, rows=1, columns=1)
                        result = gr.Image(label="Result", show_label=False, height=238, width=256)
                        generate_button = gr.Button(value="Generate", variant="primary")

    with gr.Accordion("Advanced options", open=True):
        with gr.Row():
            with gr.Column():
                # with gr.Row(elem_classes="flex-grow"):
                model_choice = gr.Dropdown(
                    [
                        "AOM3 (SD-based)",
                        "SD1.5-512",
                        "RealVis-v5.1 (SD-based)",
                        "SDXL-1024",
                        "RealVisXL-v5.0 (SDXL-based)",
                        "Playground-XL-v2 (SDXL-based)",
                        "Animagine-XL-v4.0 (SDXL-based)",
                        "FLUX-schnell"
                    ],
                    label="Model",
                    value="RealVis-v5.1 (SD-based)",
                    interactive=True,
                    info="XL-series models take longer, and FLUX takes the longest",
                )
                condition_scale = gr.Slider(
                    label="Condition Scale of Reference Image",
                    minimum=0.4,
                    maximum=1.5,
                    step=0.05,
                    value=1.0,
                    interactive=True,
                )
                warmup_ratio = gr.Slider(
                    label="Warmup Ratio",
                    minimum=0.0,
                    maximum=1,
                    step=0.05,
                    value=0.2,
                    interactive=True,
                )
                guidance_scale = gr.Slider(
                    label="Guidance Scale",
                    minimum=0,
                    maximum=10,
                    step=0.1,
                    value=5.0,
                    interactive=True,
                )
        num_inference_steps = gr.Slider(
            label="Inference Steps",
            minimum=10,
            maximum=50,
            step=1,
            value=50,
            interactive=True,
        )
        with gr.Column():
            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=0,
            )

    gr.Examples(
        examples=get_example(),
        inputs=[
            prompt,
            textual_concept,
            negative_prompt,
            image_prompt,
            num_inference_steps,
            guidance_scale,
            condition_scale,
            warmup_ratio,
            seed,
            model_choice
        ],
        cache_examples=CACHE_EXAMPLES,
    )

    # model_choice.change(
    #     fn=change_generate_button_fn,
    #     inputs=gr.Number(0, visible=False),
    #     outputs=generate_button,
    # )
    
    # .then(fn=change_model_fn, inputs=model_choice).then(
    #     fn=change_generate_button_fn,
    #     inputs=gr.Number(1, visible=False),
    #     outputs=generate_button,
    # )
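    # Wiring: clear the result slot first, then run generation with the collected inputs.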
    inputs = [
        prompt,
        textual_concept,
        negative_prompt,
        image_prompt,
        num_inference_steps,
        guidance_scale,
        condition_scale,
        warmup_ratio,
        seed,
        model_choice
    ]
    generate_button.click(
        fn=dynamic_gallery_fn,
        outputs=result,
    ).then(
        fn=generate,
        inputs=inputs,
        outputs=result,
    )
    gr.Markdown(article)

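# Start the Gradio server when the script is run directly.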
if __name__ == "__main__":
    demo.launch()