File size: 6,606 Bytes
949981c
 
 
 
 
 
 
f6ad3fd
949981c
 
0686264
949981c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0686264
 
 
 
 
 
 
 
 
 
b74b343
0686264
 
 
 
b74b343
0686264
 
 
 
 
 
 
949981c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b74b343
949981c
 
 
 
 
 
 
 
b74b343
949981c
 
b74b343
949981c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2283c6e
949981c
 
b74b343
949981c
 
 
 
 
 
 
 
b74b343
949981c
 
 
 
b74b343
 
 
 
 
 
 
949981c
 
 
 
 
 
b74b343
949981c
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
import torch
import numpy as np
import math
import spaces
from diffusers import StableDiffusionXLPipeline
from transformers import AutoFeatureExtractor
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from ip_adapter import EasyRef
from huggingface_hub import hf_hub_download
import gradio as gr
import os
import cv2
import pillow_avif

def adaptive_resize(w, h, size=1024):
    """Choose output dimensions close to a ``size`` x ``size`` pixel budget.

    Args:
        w: Input width.
        h: Input height.
        size: Target side length; the target area is ``size ** 2``.

    Returns:
        A ``(width, height)`` tuple in which both values are multiples of 64.
    """
    # Linear scale factor between the input area and the target area.
    ratio = math.sqrt(h * w / (size**2))

    if w == h:
        # Square inputs snap straight to the target square.
        w = h = size
    elif ratio > 1.1 or ratio < 0.8:
        # Too large or too small: rescale both sides by the same factor so
        # the area lands near the target while keeping the aspect ratio.
        w = math.ceil(w / ratio)
        h = math.ceil(h / ratio)

    # Round each side up to the next multiple of 64.
    return 64 * math.ceil(w / 64), 64 * math.ceil(h / 64)

def res2string(w, h):
    """Format a resolution as ``"<w>x<h>"`` using ``str()`` of each value."""
    return f"{w!s}x{h!s}"

def get_image_path_list(folder_name):
    """Return the sorted paths of every entry directly inside ``folder_name``.

    Each entry name reported by ``os.listdir`` is joined onto ``folder_name``;
    no filtering by file type is performed.
    """
    return sorted(
        os.path.join(folder_name, entry) for entry in os.listdir(folder_name)
    )

def get_example():
    """Build the example rows shown in the demo.

    Returns:
        A list of ``[image_paths, prompt, negative_prompt]`` rows, one per
        bundled asset folder.
    """
    # Both examples use the same negative prompt.
    shared_negative = "A collage of images, naked, monochrome, lowres, bad anatomy, worst quality, low quality"

    identity_example = [
        get_image_path_list('./assets/aragaki_identity'),
        "An oil painting of a smiling woman.",
        shared_negative,
    ]
    style_example = [
        get_image_path_list('./assets/blindbox_style'),
        "Donald Trump",
        shared_negative,
    ]
    return [identity_example, style_example]

def upload_example_to_gallery(images, prompt, negative_prompt):
    """Produce the component updates applied when an example is selected.

    ``prompt`` and ``negative_prompt`` are accepted to match the example
    inputs but are not used here.

    Returns:
        Three ``gr.update`` objects: one setting ``images`` as the value and
        making the target visible, one making its target visible, and one
        hiding its target.
    """
    show_with_images = gr.update(value=images, visible=True)
    show_only = gr.update(visible=True)
    hide_only = gr.update(visible=False)
    return show_with_images, show_only, hide_only

# Model identifiers: the SDXL base checkpoint and the multimodal LLM handed to EasyRef.
base_model_path = "stabilityai/stable-diffusion-xl-base-1.0"
multimodal_llm_path = "Qwen/Qwen2-VL-2B-Instruct"
# Download the EasyRef weights from the Hugging Face Hub; the call returns a local path.
ip_ckpt = hf_hub_download(repo_id="zongzhuofan/EasyRef", filename="pytorch_model.bin", repo_type="model")

# Safety checker and its feature extractor, both loaded from the same model id.
safety_model_id = "CompVis/stable-diffusion-safety-checker"
safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id)
safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id)

# Device string used for both the pipeline and EasyRef below.
device = "cuda"

# Load the SDXL pipeline in half precision and move it to the device.
# NOTE(review): confirm that StableDiffusionXLPipeline actually consumes a
# `safety_checker` component; if it does not, that kwarg has no effect here.
pipe = StableDiffusionXLPipeline.from_pretrained(
    base_model_path,
    torch_dtype=torch.float16,
    feature_extractor=safety_feature_extractor,
    safety_checker=safety_checker,
    add_watermarker=False,
).to(device)

# Wrap the pipeline with EasyRef using the downloaded checkpoint and the
# multimodal LLM path; token count and LoRA settings are passed through as given.
easyref = EasyRef(pipe, multimodal_llm_path, ip_ckpt, device, num_tokens=64, use_lora=True, lora_rank=128)

# Limit OpenCV to a single thread (reason not stated in this file).
cv2.setNumThreads(1)

@spaces.GPU(enable_queue=True)
def generate_image(images, prompt, negative_prompt, scale, num_inference_steps, seed, progress=gr.Progress(track_tqdm=True)):
    """Run EasyRef generation for the given reference images and prompts.

    Args:
        images: Reference images, forwarded to EasyRef as ``pil_image``.
        prompt: Text prompt; also appended to the instruction template.
        negative_prompt: Negative text prompt.
        scale: Forwarded as ``scale``.
        num_inference_steps: Forwarded as ``num_inference_steps``.
        seed: Forwarded as ``seed``.
        progress: Gradio progress tracker; not referenced in the body.

    Returns:
        Whatever ``easyref.generate`` returns for a single sample.
    """
    print("Generating")

    instruction = "Visualize a scene that closely resembles the provided images, capturing the essence and details described in this prompt:\n"
    # Two system prompts: the instruction with the user prompt, and the bare
    # instruction (presumably a conditional/unconditional pair — confirm
    # against the EasyRef implementation).
    generation_kwargs = dict(
        pil_image=images,
        system_prompt=[instruction + prompt, instruction],
        prompt=prompt,
        negative_prompt=negative_prompt,
        scale=scale,
        num_samples=1,
        num_inference_steps=num_inference_steps,
        seed=seed,
    )
    result = easyref.generate(**generation_kwargs)

    print(result)
    return result

# def change_style(style):
#     if style == "Photorealistic":
#         return(gr.update(value=True), gr.update(value=1.3), gr.update(value=1.0))
#     else:
#         return(gr.update(value=True), gr.update(value=0.1), gr.update(value=0.8))

def swap_to_gallery(images):
    """Produce the component updates applied after files are uploaded.

    Returns:
        Three ``gr.update`` objects: one setting ``images`` as the value and
        making the target visible, one making its target visible, and one
        hiding its target.
    """
    gallery_update = gr.update(value=images, visible=True)
    clear_button_update = gr.update(visible=True)
    file_input_update = gr.update(visible=False)
    return gallery_update, clear_button_update, file_input_update

def remove_back_to_files():
    """Produce the component updates that undo the gallery view.

    Returns:
        Three ``gr.update`` objects: the first two hide their targets and the
        third makes its target visible.
    """
    hidden_first = gr.update(visible=False)
    hidden_second = gr.update(visible=False)
    shown_third = gr.update(visible=True)
    return hidden_first, hidden_second, shown_third

# Upper bound for the seed slider: the largest 32-bit signed integer.
MAX_SEED = np.iinfo(np.int32).max
# Custom CSS passed to gr.Blocks; removes the bottom margin under <h1>.
css = '''
h1{margin-bottom: 0 !important}
'''
# Build the demo UI: inputs in the left column, results and examples in the right.
with gr.Blocks(css=css) as demo:
    gr.Markdown("# EasyRef demo")
    gr.Markdown("Demo for the [zongzhuofan/EasyRef model](https://huggingface.co/zongzhuofan/EasyRef)")
    with gr.Row():
        with gr.Column():
            # File picker for the reference images.
            files = gr.Files(
                        label="Multiple reference images",
                        file_types=["image"]
                    )
            # Gallery preview of the uploaded images; hidden until swap_to_gallery
            # or upload_example_to_gallery makes it visible.
            uploaded_files = gr.Gallery(label="Your images", visible=False, columns=6, rows=1, height=125)
            # Hidden container holding the button that clears `files`.
            with gr.Column(visible=False) as clear_button:
                remove_and_reupload = gr.ClearButton(value="Remove and upload new ones", components=files, size="sm")
            prompt = gr.Textbox(label="Prompt",
                       placeholder="An oil painting of a [man/woman/person]...")
            negative_prompt = gr.Textbox(label="Negative Prompt", placeholder="A collage of images, naked, monochrome, lowres, bad anatomy, worst quality, low quality")
            # style = gr.Radio(label="Generation type", info="For stylized try prompts like 'a watercolor painting of a woman'", choices=["Photorealistic", "Stylized"], value="Photorealistic")
            submit = gr.Button("Submit")
            # Generation parameters forwarded to generate_image.
            with gr.Accordion(open=False, label="Advanced Options"):
                scale = gr.Slider(label="Scale", info="Scale for image reference", value=1.0, step=0.1, minimum=0.5, maximum=1.5)
                num_inference_steps = gr.Slider(label="Number of inference steps", value=30, step=1, minimum=1, maximum=60)
                seed = gr.Slider(label="Seed", value=24, step=1, minimum=0, maximum=MAX_SEED)
        with gr.Column():
            gallery = gr.Gallery(label="Generated Images")
            # Bundled examples fill files/prompt/negative_prompt; with
            # run_on_click=True, upload_example_to_gallery is also run and
            # updates the gallery preview, clear button and file picker.
            gr.Examples(
                examples=get_example(),
                inputs=[files, prompt, negative_prompt],
                run_on_click=True,
                fn=upload_example_to_gallery,
                outputs=[uploaded_files, clear_button, files],
            )            
        # style.change(fn=change_style,
        #             inputs=style,
        #             outputs=[preserve, face_strength, likeness_strength])
        # Event wiring: an upload swaps the picker for the gallery preview, the
        # clear button swaps it back, and submit runs generation into `gallery`.
        files.upload(fn=swap_to_gallery, inputs=files, outputs=[uploaded_files, clear_button, files])
        remove_and_reupload.click(fn=remove_back_to_files, outputs=[uploaded_files, clear_button, files])
        submit.click(fn=generate_image,
                    inputs=[files, prompt, negative_prompt, scale, num_inference_steps, seed],
                    outputs=gallery)
            
    gr.Markdown("We release our checkpoints for research purposes only. Users are granted the freedom to create images using this tool, but they are expected to comply with local laws and utilize it in a responsible manner. The developers do not assume any responsibility for potential misuse by users.")
    
# Start the Gradio app at import time (no __main__ guard).
demo.launch()