Spaces:
Build error
Build error
Upload folder using huggingface_hub
Browse files
README.md
CHANGED
@@ -5,7 +5,6 @@ sdk: gradio
|
|
5 |
sdk_version: 4.19.2
|
6 |
---
|
7 |
# ImprovedTokenMerge
|
8 |
-

|
9 |

|
10 |
|
11 |
twitter thread explanation: https://twitter.com/Ethan_smith_20/status/1750533558509433137
|
|
|
5 |
sdk_version: 4.19.2
|
6 |
---
|
7 |
# ImprovedTokenMerge
|
|
|
8 |

|
9 |
|
10 |
twitter thread explanation: https://twitter.com/Ethan_smith_20/status/1750533558509433137
|
app.py
CHANGED
@@ -12,6 +12,64 @@ pipe = diffusers.StableDiffusionPipeline.from_pretrained("Lykon/DreamShaper").to
|
|
12 |
pipe.scheduler = diffusers.EulerDiscreteScheduler.from_config(pipe.scheduler.config)
|
13 |
pipe.safety_checker = None
|
14 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
15 |
with gr.Blocks() as demo:
|
16 |
prompt = gr.Textbox(interactive=True, label="prompt")
|
17 |
negative_prompt = gr.Textbox(interactive=True, label="negative_prompt")
|
@@ -24,97 +82,13 @@ with gr.Blocks() as demo:
|
|
24 |
seed = gr.Number(label="seed", value=1, precision=0)
|
25 |
result = gr.Textbox(label="Result")
|
26 |
|
27 |
-
|
|
|
|
|
28 |
|
29 |
gen = gr.Button("generate")
|
30 |
|
31 |
-
def which_image(img, target_val=253, width=1024):
|
32 |
-
npimg = np.array(img)
|
33 |
-
loc = np.where(npimg[:, :, 3] == target_val)[1].item()
|
34 |
-
if loc > width:
|
35 |
-
print("Right Image is merged!")
|
36 |
-
else:
|
37 |
-
print("Left Image is merged!")
|
38 |
-
|
39 |
-
@spaces.GPU
|
40 |
-
def generate(prompt, seed, steps, height_width, negative_prompt, guidance_scale, method):
|
41 |
-
|
42 |
-
downsample_factor = 2
|
43 |
-
ratio = 0.38
|
44 |
-
merge_method = "downsample" if method == "todo" else "similarity"
|
45 |
-
merge_tokens = "keys/values" if method == "todo" else "all"
|
46 |
-
|
47 |
-
if height_width == 1024:
|
48 |
-
downsample_factor = 2
|
49 |
-
ratio = 0.75
|
50 |
-
downsample_factor_level_2 = 1
|
51 |
-
ratio_level_2 = 0.0
|
52 |
-
elif height_width == 1536:
|
53 |
-
downsample_factor = 3
|
54 |
-
ratio = 0.89
|
55 |
-
downsample_factor_level_2 = 1
|
56 |
-
ratio_level_2 = 0.0
|
57 |
-
elif height_width == 2048:
|
58 |
-
downsample_factor = 4
|
59 |
-
ratio = 0.9375
|
60 |
-
downsample_factor_level_2 = 2
|
61 |
-
ratio_level_2 = 0.75
|
62 |
-
|
63 |
-
token_merge_args = {"ratio": ratio,
|
64 |
-
"merge_tokens": merge_tokens,
|
65 |
-
"merge_method": merge_method,
|
66 |
-
"downsample_method": "nearest",
|
67 |
-
"downsample_factor": downsample_factor,
|
68 |
-
"timestep_threshold_switch": 0.0,
|
69 |
-
"timestep_threshold_stop": 0.0,
|
70 |
-
"downsample_factor_level_2": downsample_factor_level_2,
|
71 |
-
"ratio_level_2": ratio_level_2
|
72 |
-
}
|
73 |
-
|
74 |
-
l_r = torch.rand(1).item()
|
75 |
-
torch.manual_seed(seed)
|
76 |
-
start_time_base = time.time()
|
77 |
-
base_img = pipe(prompt,
|
78 |
-
num_inference_steps=steps, height=height_width, width=height_width,
|
79 |
-
negative_prompt=negative_prompt,
|
80 |
-
guidance_scale=guidance_scale).images[0]
|
81 |
-
end_time_base = time.time()
|
82 |
-
|
83 |
-
patch_attention_proc(pipe.unet, token_merge_args=token_merge_args)
|
84 |
-
|
85 |
-
torch.manual_seed(seed)
|
86 |
-
start_time_merge = time.time()
|
87 |
-
merged_img = pipe(prompt,
|
88 |
-
num_inference_steps=steps, height=height_width, width=height_width,
|
89 |
-
negative_prompt=negative_prompt,
|
90 |
-
guidance_scale=guidance_scale).images[0]
|
91 |
-
end_time_merge = time.time()
|
92 |
-
|
93 |
-
base_img = base_img.convert("RGBA")
|
94 |
-
merged_img = merged_img.convert("RGBA")
|
95 |
-
merged_img = np.array(merged_img)
|
96 |
-
halfh, halfw = height_width // 2, height_width // 2
|
97 |
-
merged_img[halfh, halfw, 3] = 253 # set the center pixel of the merged image to be ever so slightly below 255 in alpha channel
|
98 |
-
merged_img = Image.fromarray(merged_img)
|
99 |
-
final_img = Image.new(size=(height_width * 2, height_width), mode="RGBA")
|
100 |
-
|
101 |
-
if l_r > 0.5:
|
102 |
-
left_img = base_img
|
103 |
-
right_img = merged_img
|
104 |
-
else:
|
105 |
-
left_img = merged_img
|
106 |
-
right_img = base_img
|
107 |
-
|
108 |
-
final_img.paste(left_img, (0, 0))
|
109 |
-
final_img.paste(right_img, (height_width, 0))
|
110 |
-
|
111 |
-
which_image(final_img, width=height_width)
|
112 |
-
|
113 |
-
result = f"Baseline image: {end_time_base-start_time_base:.2f} sec | {'ToDo' if method == 'todo' else 'ToMe'} image: {end_time_merge-start_time_merge:.2f} sec"
|
114 |
-
|
115 |
-
return final_img, result
|
116 |
-
|
117 |
gen.click(generate, inputs=[prompt, seed, steps, height_width, negative_prompt,
|
118 |
-
guidance_scale, method], outputs=[output_image, result])
|
119 |
|
120 |
demo.launch(share=True)
|
|
|
12 |
pipe.scheduler = diffusers.EulerDiscreteScheduler.from_config(pipe.scheduler.config)
|
13 |
pipe.safety_checker = None
|
14 |
|
15 |
+
@spaces.GPU
|
16 |
+
def generate(prompt, seed, steps, height_width, negative_prompt, guidance_scale, method):
|
17 |
+
|
18 |
+
downsample_factor = 2
|
19 |
+
ratio = 0.38
|
20 |
+
merge_method = "downsample" if method == "todo" else "similarity"
|
21 |
+
merge_tokens = "keys/values" if method == "todo" else "all"
|
22 |
+
|
23 |
+
if height_width == 1024:
|
24 |
+
downsample_factor = 2
|
25 |
+
ratio = 0.75
|
26 |
+
downsample_factor_level_2 = 1
|
27 |
+
ratio_level_2 = 0.0
|
28 |
+
elif height_width == 1536:
|
29 |
+
downsample_factor = 3
|
30 |
+
ratio = 0.89
|
31 |
+
downsample_factor_level_2 = 1
|
32 |
+
ratio_level_2 = 0.0
|
33 |
+
elif height_width == 2048:
|
34 |
+
downsample_factor = 4
|
35 |
+
ratio = 0.9375
|
36 |
+
downsample_factor_level_2 = 2
|
37 |
+
ratio_level_2 = 0.75
|
38 |
+
|
39 |
+
token_merge_args = {"ratio": ratio,
|
40 |
+
"merge_tokens": merge_tokens,
|
41 |
+
"merge_method": merge_method,
|
42 |
+
"downsample_method": "nearest",
|
43 |
+
"downsample_factor": downsample_factor,
|
44 |
+
"timestep_threshold_switch": 0.0,
|
45 |
+
"timestep_threshold_stop": 0.0,
|
46 |
+
"downsample_factor_level_2": downsample_factor_level_2,
|
47 |
+
"ratio_level_2": ratio_level_2
|
48 |
+
}
|
49 |
+
|
50 |
+
l_r = torch.rand(1).item()
|
51 |
+
torch.manual_seed(seed)
|
52 |
+
start_time_base = time.time()
|
53 |
+
base_img = pipe(prompt,
|
54 |
+
num_inference_steps=steps, height=height_width, width=height_width,
|
55 |
+
negative_prompt=negative_prompt,
|
56 |
+
guidance_scale=guidance_scale).images[0]
|
57 |
+
end_time_base = time.time()
|
58 |
+
|
59 |
+
patch_attention_proc(pipe.unet, token_merge_args=token_merge_args)
|
60 |
+
|
61 |
+
torch.manual_seed(seed)
|
62 |
+
start_time_merge = time.time()
|
63 |
+
merged_img = pipe(prompt,
|
64 |
+
num_inference_steps=steps, height=height_width, width=height_width,
|
65 |
+
negative_prompt=negative_prompt,
|
66 |
+
guidance_scale=guidance_scale).images[0]
|
67 |
+
end_time_merge = time.time()
|
68 |
+
|
69 |
+
result = f"Baseline image: {end_time_base-start_time_base:.2f} sec | {'ToDo' if method == 'todo' else 'ToMe'} image: {end_time_merge-start_time_merge:.2f} sec"
|
70 |
+
|
71 |
+
return base_img, merged_img, result
|
72 |
+
|
73 |
with gr.Blocks() as demo:
|
74 |
prompt = gr.Textbox(interactive=True, label="prompt")
|
75 |
negative_prompt = gr.Textbox(interactive=True, label="negative_prompt")
|
|
|
82 |
seed = gr.Number(label="seed", value=1, precision=0)
|
83 |
result = gr.Textbox(label="Result")
|
84 |
|
85 |
+
with gr.Row():
|
86 |
+
base_image = gr.Image(label=f"baseline_image", type="pil", interactive=False)
|
87 |
+
output_image = gr.Image(label=f"output_image", type="pil", interactive=False)
|
88 |
|
89 |
gen = gr.Button("generate")
|
90 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
91 |
gen.click(generate, inputs=[prompt, seed, steps, height_width, negative_prompt,
|
92 |
+
guidance_scale, method], outputs=[base_image, output_image, result])
|
93 |
|
94 |
demo.launch(share=True)
|