Spaces:
Build error
Build error
Upload folder using huggingface_hub
Browse files
README.md
CHANGED
|
@@ -5,7 +5,6 @@ sdk: gradio
|
|
| 5 |
sdk_version: 4.19.2
|
| 6 |
---
|
| 7 |
# ImprovedTokenMerge
|
| 8 |
-

|
| 9 |

|
| 10 |
|
| 11 |
twitter thread explanation: https://twitter.com/Ethan_smith_20/status/1750533558509433137
|
|
|
|
| 5 |
sdk_version: 4.19.2
|
| 6 |
---
|
| 7 |
# ImprovedTokenMerge
|
|
|
|
| 8 |

|
| 9 |
|
| 10 |
twitter thread explanation: https://twitter.com/Ethan_smith_20/status/1750533558509433137
|
app.py
CHANGED
|
@@ -12,6 +12,64 @@ pipe = diffusers.StableDiffusionPipeline.from_pretrained("Lykon/DreamShaper").to
|
|
| 12 |
pipe.scheduler = diffusers.EulerDiscreteScheduler.from_config(pipe.scheduler.config)
|
| 13 |
pipe.safety_checker = None
|
| 14 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
with gr.Blocks() as demo:
|
| 16 |
prompt = gr.Textbox(interactive=True, label="prompt")
|
| 17 |
negative_prompt = gr.Textbox(interactive=True, label="negative_prompt")
|
|
@@ -24,97 +82,13 @@ with gr.Blocks() as demo:
|
|
| 24 |
seed = gr.Number(label="seed", value=1, precision=0)
|
| 25 |
result = gr.Textbox(label="Result")
|
| 26 |
|
| 27 |
-
|
|
|
|
|
|
|
| 28 |
|
| 29 |
gen = gr.Button("generate")
|
| 30 |
|
| 31 |
-
def which_image(img, target_val=253, width=1024):
|
| 32 |
-
npimg = np.array(img)
|
| 33 |
-
loc = np.where(npimg[:, :, 3] == target_val)[1].item()
|
| 34 |
-
if loc > width:
|
| 35 |
-
print("Right Image is merged!")
|
| 36 |
-
else:
|
| 37 |
-
print("Left Image is merged!")
|
| 38 |
-
|
| 39 |
-
@spaces.GPU
|
| 40 |
-
def generate(prompt, seed, steps, height_width, negative_prompt, guidance_scale, method):
|
| 41 |
-
|
| 42 |
-
downsample_factor = 2
|
| 43 |
-
ratio = 0.38
|
| 44 |
-
merge_method = "downsample" if method == "todo" else "similarity"
|
| 45 |
-
merge_tokens = "keys/values" if method == "todo" else "all"
|
| 46 |
-
|
| 47 |
-
if height_width == 1024:
|
| 48 |
-
downsample_factor = 2
|
| 49 |
-
ratio = 0.75
|
| 50 |
-
downsample_factor_level_2 = 1
|
| 51 |
-
ratio_level_2 = 0.0
|
| 52 |
-
elif height_width == 1536:
|
| 53 |
-
downsample_factor = 3
|
| 54 |
-
ratio = 0.89
|
| 55 |
-
downsample_factor_level_2 = 1
|
| 56 |
-
ratio_level_2 = 0.0
|
| 57 |
-
elif height_width == 2048:
|
| 58 |
-
downsample_factor = 4
|
| 59 |
-
ratio = 0.9375
|
| 60 |
-
downsample_factor_level_2 = 2
|
| 61 |
-
ratio_level_2 = 0.75
|
| 62 |
-
|
| 63 |
-
token_merge_args = {"ratio": ratio,
|
| 64 |
-
"merge_tokens": merge_tokens,
|
| 65 |
-
"merge_method": merge_method,
|
| 66 |
-
"downsample_method": "nearest",
|
| 67 |
-
"downsample_factor": downsample_factor,
|
| 68 |
-
"timestep_threshold_switch": 0.0,
|
| 69 |
-
"timestep_threshold_stop": 0.0,
|
| 70 |
-
"downsample_factor_level_2": downsample_factor_level_2,
|
| 71 |
-
"ratio_level_2": ratio_level_2
|
| 72 |
-
}
|
| 73 |
-
|
| 74 |
-
l_r = torch.rand(1).item()
|
| 75 |
-
torch.manual_seed(seed)
|
| 76 |
-
start_time_base = time.time()
|
| 77 |
-
base_img = pipe(prompt,
|
| 78 |
-
num_inference_steps=steps, height=height_width, width=height_width,
|
| 79 |
-
negative_prompt=negative_prompt,
|
| 80 |
-
guidance_scale=guidance_scale).images[0]
|
| 81 |
-
end_time_base = time.time()
|
| 82 |
-
|
| 83 |
-
patch_attention_proc(pipe.unet, token_merge_args=token_merge_args)
|
| 84 |
-
|
| 85 |
-
torch.manual_seed(seed)
|
| 86 |
-
start_time_merge = time.time()
|
| 87 |
-
merged_img = pipe(prompt,
|
| 88 |
-
num_inference_steps=steps, height=height_width, width=height_width,
|
| 89 |
-
negative_prompt=negative_prompt,
|
| 90 |
-
guidance_scale=guidance_scale).images[0]
|
| 91 |
-
end_time_merge = time.time()
|
| 92 |
-
|
| 93 |
-
base_img = base_img.convert("RGBA")
|
| 94 |
-
merged_img = merged_img.convert("RGBA")
|
| 95 |
-
merged_img = np.array(merged_img)
|
| 96 |
-
halfh, halfw = height_width // 2, height_width // 2
|
| 97 |
-
merged_img[halfh, halfw, 3] = 253 # set the center pixel of the merged image to be ever so slightly below 255 in alpha channel
|
| 98 |
-
merged_img = Image.fromarray(merged_img)
|
| 99 |
-
final_img = Image.new(size=(height_width * 2, height_width), mode="RGBA")
|
| 100 |
-
|
| 101 |
-
if l_r > 0.5:
|
| 102 |
-
left_img = base_img
|
| 103 |
-
right_img = merged_img
|
| 104 |
-
else:
|
| 105 |
-
left_img = merged_img
|
| 106 |
-
right_img = base_img
|
| 107 |
-
|
| 108 |
-
final_img.paste(left_img, (0, 0))
|
| 109 |
-
final_img.paste(right_img, (height_width, 0))
|
| 110 |
-
|
| 111 |
-
which_image(final_img, width=height_width)
|
| 112 |
-
|
| 113 |
-
result = f"Baseline image: {end_time_base-start_time_base:.2f} sec | {'ToDo' if method == 'todo' else 'ToMe'} image: {end_time_merge-start_time_merge:.2f} sec"
|
| 114 |
-
|
| 115 |
-
return final_img, result
|
| 116 |
-
|
| 117 |
gen.click(generate, inputs=[prompt, seed, steps, height_width, negative_prompt,
|
| 118 |
-
guidance_scale, method], outputs=[output_image, result])
|
| 119 |
|
| 120 |
demo.launch(share=True)
|
|
|
|
| 12 |
pipe.scheduler = diffusers.EulerDiscreteScheduler.from_config(pipe.scheduler.config)
|
| 13 |
pipe.safety_checker = None
|
| 14 |
|
| 15 |
+
@spaces.GPU
|
| 16 |
+
def generate(prompt, seed, steps, height_width, negative_prompt, guidance_scale, method):
|
| 17 |
+
|
| 18 |
+
downsample_factor = 2
|
| 19 |
+
ratio = 0.38
|
| 20 |
+
merge_method = "downsample" if method == "todo" else "similarity"
|
| 21 |
+
merge_tokens = "keys/values" if method == "todo" else "all"
|
| 22 |
+
|
| 23 |
+
if height_width == 1024:
|
| 24 |
+
downsample_factor = 2
|
| 25 |
+
ratio = 0.75
|
| 26 |
+
downsample_factor_level_2 = 1
|
| 27 |
+
ratio_level_2 = 0.0
|
| 28 |
+
elif height_width == 1536:
|
| 29 |
+
downsample_factor = 3
|
| 30 |
+
ratio = 0.89
|
| 31 |
+
downsample_factor_level_2 = 1
|
| 32 |
+
ratio_level_2 = 0.0
|
| 33 |
+
elif height_width == 2048:
|
| 34 |
+
downsample_factor = 4
|
| 35 |
+
ratio = 0.9375
|
| 36 |
+
downsample_factor_level_2 = 2
|
| 37 |
+
ratio_level_2 = 0.75
|
| 38 |
+
|
| 39 |
+
token_merge_args = {"ratio": ratio,
|
| 40 |
+
"merge_tokens": merge_tokens,
|
| 41 |
+
"merge_method": merge_method,
|
| 42 |
+
"downsample_method": "nearest",
|
| 43 |
+
"downsample_factor": downsample_factor,
|
| 44 |
+
"timestep_threshold_switch": 0.0,
|
| 45 |
+
"timestep_threshold_stop": 0.0,
|
| 46 |
+
"downsample_factor_level_2": downsample_factor_level_2,
|
| 47 |
+
"ratio_level_2": ratio_level_2
|
| 48 |
+
}
|
| 49 |
+
|
| 50 |
+
l_r = torch.rand(1).item()
|
| 51 |
+
torch.manual_seed(seed)
|
| 52 |
+
start_time_base = time.time()
|
| 53 |
+
base_img = pipe(prompt,
|
| 54 |
+
num_inference_steps=steps, height=height_width, width=height_width,
|
| 55 |
+
negative_prompt=negative_prompt,
|
| 56 |
+
guidance_scale=guidance_scale).images[0]
|
| 57 |
+
end_time_base = time.time()
|
| 58 |
+
|
| 59 |
+
patch_attention_proc(pipe.unet, token_merge_args=token_merge_args)
|
| 60 |
+
|
| 61 |
+
torch.manual_seed(seed)
|
| 62 |
+
start_time_merge = time.time()
|
| 63 |
+
merged_img = pipe(prompt,
|
| 64 |
+
num_inference_steps=steps, height=height_width, width=height_width,
|
| 65 |
+
negative_prompt=negative_prompt,
|
| 66 |
+
guidance_scale=guidance_scale).images[0]
|
| 67 |
+
end_time_merge = time.time()
|
| 68 |
+
|
| 69 |
+
result = f"Baseline image: {end_time_base-start_time_base:.2f} sec | {'ToDo' if method == 'todo' else 'ToMe'} image: {end_time_merge-start_time_merge:.2f} sec"
|
| 70 |
+
|
| 71 |
+
return base_img, merged_img, result
|
| 72 |
+
|
| 73 |
with gr.Blocks() as demo:
|
| 74 |
prompt = gr.Textbox(interactive=True, label="prompt")
|
| 75 |
negative_prompt = gr.Textbox(interactive=True, label="negative_prompt")
|
|
|
|
| 82 |
seed = gr.Number(label="seed", value=1, precision=0)
|
| 83 |
result = gr.Textbox(label="Result")
|
| 84 |
|
| 85 |
+
with gr.Row():
|
| 86 |
+
base_image = gr.Image(label=f"baseline_image", type="pil", interactive=False)
|
| 87 |
+
output_image = gr.Image(label=f"output_image", type="pil", interactive=False)
|
| 88 |
|
| 89 |
gen = gr.Button("generate")
|
| 90 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 91 |
gen.click(generate, inputs=[prompt, seed, steps, height_width, negative_prompt,
|
| 92 |
+
guidance_scale, method], outputs=[base_image, output_image, result])
|
| 93 |
|
| 94 |
demo.launch(share=True)
|