aningineer committed · verified
Commit cedad44 · 1 Parent(s): 10d5e0b

Upload folder using huggingface_hub
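This is the stock commit message produced by a `huggingface_hub` folder upload. A minimal sketch of the call that would yield such a commit (the repo id below is a hypothetical placeholder, not taken from this page):

```python
from huggingface_hub import HfApi

api = HfApi()  # picks up the token saved by `huggingface-cli login`
api.upload_folder(
    folder_path=".",            # local folder containing app.py and README.md
    repo_id="user/space-name",  # hypothetical placeholder
    repo_type="space",          # target a Space rather than a model repo
)
```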

Files changed (2)
  1. README.md +0 -1
  2. app.py +62 -88
README.md CHANGED
@@ -5,7 +5,6 @@ sdk: gradio
 sdk_version: 4.19.2
 ---
 # ImprovedTokenMerge
-![compare3.png](compare3.png)
 ![GEuoFn1bMAABQqD](https://github.com/ethansmith2000/ImprovedTokenMerge/assets/98723285/82e03423-81e6-47da-afa4-9c1b2c1c4aeb)
 
 twitter thread explanation: https://twitter.com/Ethan_smith_20/status/1750533558509433137
app.py CHANGED
@@ -12,6 +12,64 @@ pipe = diffusers.StableDiffusionPipeline.from_pretrained("Lykon/DreamShaper").to
 pipe.scheduler = diffusers.EulerDiscreteScheduler.from_config(pipe.scheduler.config)
 pipe.safety_checker = None
 
+@spaces.GPU
+def generate(prompt, seed, steps, height_width, negative_prompt, guidance_scale, method):
+
+    downsample_factor = 2
+    ratio = 0.38
+    merge_method = "downsample" if method == "todo" else "similarity"
+    merge_tokens = "keys/values" if method == "todo" else "all"
+
+    if height_width == 1024:
+        downsample_factor = 2
+        ratio = 0.75
+        downsample_factor_level_2 = 1
+        ratio_level_2 = 0.0
+    elif height_width == 1536:
+        downsample_factor = 3
+        ratio = 0.89
+        downsample_factor_level_2 = 1
+        ratio_level_2 = 0.0
+    elif height_width == 2048:
+        downsample_factor = 4
+        ratio = 0.9375
+        downsample_factor_level_2 = 2
+        ratio_level_2 = 0.75
+
+    token_merge_args = {"ratio": ratio,
+                        "merge_tokens": merge_tokens,
+                        "merge_method": merge_method,
+                        "downsample_method": "nearest",
+                        "downsample_factor": downsample_factor,
+                        "timestep_threshold_switch": 0.0,
+                        "timestep_threshold_stop": 0.0,
+                        "downsample_factor_level_2": downsample_factor_level_2,
+                        "ratio_level_2": ratio_level_2
+                        }
+
+    l_r = torch.rand(1).item()
+    torch.manual_seed(seed)
+    start_time_base = time.time()
+    base_img = pipe(prompt,
+                    num_inference_steps=steps, height=height_width, width=height_width,
+                    negative_prompt=negative_prompt,
+                    guidance_scale=guidance_scale).images[0]
+    end_time_base = time.time()
+
+    patch_attention_proc(pipe.unet, token_merge_args=token_merge_args)
+
+    torch.manual_seed(seed)
+    start_time_merge = time.time()
+    merged_img = pipe(prompt,
+                      num_inference_steps=steps, height=height_width, width=height_width,
+                      negative_prompt=negative_prompt,
+                      guidance_scale=guidance_scale).images[0]
+    end_time_merge = time.time()
+
+    result = f"Baseline image: {end_time_base-start_time_base:.2f} sec | {'ToDo' if method == 'todo' else 'ToMe'} image: {end_time_merge-start_time_merge:.2f} sec"
+
+    return base_img, merged_img, result
+
 with gr.Blocks() as demo:
     prompt = gr.Textbox(interactive=True, label="prompt")
     negative_prompt = gr.Textbox(interactive=True, label="negative_prompt")
@@ -24,97 +82,13 @@ with gr.Blocks() as demo:
     seed = gr.Number(label="seed", value=1, precision=0)
     result = gr.Textbox(label="Result")
 
-    output_image = gr.Image(label=f"output_image", type="pil", interactive=False)
+    with gr.Row():
+        base_image = gr.Image(label=f"baseline_image", type="pil", interactive=False)
+        output_image = gr.Image(label=f"output_image", type="pil", interactive=False)
 
     gen = gr.Button("generate")
 
-    def which_image(img, target_val=253, width=1024):
-        npimg = np.array(img)
-        loc = np.where(npimg[:, :, 3] == target_val)[1].item()
-        if loc > width:
-            print("Right Image is merged!")
-        else:
-            print("Left Image is merged!")
-
-    @spaces.GPU
-    def generate(prompt, seed, steps, height_width, negative_prompt, guidance_scale, method):
-
-        downsample_factor = 2
-        ratio = 0.38
-        merge_method = "downsample" if method == "todo" else "similarity"
-        merge_tokens = "keys/values" if method == "todo" else "all"
-
-        if height_width == 1024:
-            downsample_factor = 2
-            ratio = 0.75
-            downsample_factor_level_2 = 1
-            ratio_level_2 = 0.0
-        elif height_width == 1536:
-            downsample_factor = 3
-            ratio = 0.89
-            downsample_factor_level_2 = 1
-            ratio_level_2 = 0.0
-        elif height_width == 2048:
-            downsample_factor = 4
-            ratio = 0.9375
-            downsample_factor_level_2 = 2
-            ratio_level_2 = 0.75
-
-        token_merge_args = {"ratio": ratio,
-                            "merge_tokens": merge_tokens,
-                            "merge_method": merge_method,
-                            "downsample_method": "nearest",
-                            "downsample_factor": downsample_factor,
-                            "timestep_threshold_switch": 0.0,
-                            "timestep_threshold_stop": 0.0,
-                            "downsample_factor_level_2": downsample_factor_level_2,
-                            "ratio_level_2": ratio_level_2
-                            }
-
-        l_r = torch.rand(1).item()
-        torch.manual_seed(seed)
-        start_time_base = time.time()
-        base_img = pipe(prompt,
-                        num_inference_steps=steps, height=height_width, width=height_width,
-                        negative_prompt=negative_prompt,
-                        guidance_scale=guidance_scale).images[0]
-        end_time_base = time.time()
-
-        patch_attention_proc(pipe.unet, token_merge_args=token_merge_args)
-
-        torch.manual_seed(seed)
-        start_time_merge = time.time()
-        merged_img = pipe(prompt,
-                          num_inference_steps=steps, height=height_width, width=height_width,
-                          negative_prompt=negative_prompt,
-                          guidance_scale=guidance_scale).images[0]
-        end_time_merge = time.time()
-
-        base_img = base_img.convert("RGBA")
-        merged_img = merged_img.convert("RGBA")
-        merged_img = np.array(merged_img)
-        halfh, halfw = height_width // 2, height_width // 2
-        merged_img[halfh, halfw, 3] = 253  # set the center pixel of the merged image to be ever so slightly below 255 in alpha channel
-        merged_img = Image.fromarray(merged_img)
-        final_img = Image.new(size=(height_width * 2, height_width), mode="RGBA")
-
-        if l_r > 0.5:
-            left_img = base_img
-            right_img = merged_img
-        else:
-            left_img = merged_img
-            right_img = base_img
-
-        final_img.paste(left_img, (0, 0))
-        final_img.paste(right_img, (height_width, 0))
-
-        which_image(final_img, width=height_width)
-
-        result = f"Baseline image: {end_time_base-start_time_base:.2f} sec | {'ToDo' if method == 'todo' else 'ToMe'} image: {end_time_merge-start_time_merge:.2f} sec"
-
-        return final_img, result
-
     gen.click(generate, inputs=[prompt, seed, steps, height_width, negative_prompt,
-                                guidance_scale, method], outputs=[output_image, result])
+                                guidance_scale, method], outputs=[base_image, output_image, result])
 
 demo.launch(share=True)
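A side note on the hard-coded settings in `generate`: each resolution branch's `ratio` matches 1 − 1/`downsample_factor`², consistent with merging away the fraction of tokens that a factor-f downsample of both spatial sides removes. A quick check (hypothetical helper, not part of the commit):

```python
# ratio = 1 - 1/factor**2: fraction of tokens removed when both
# spatial sides are downsampled by `factor` (hypothetical helper).
def merge_ratio(factor: int) -> float:
    return 1.0 - 1.0 / factor**2

assert merge_ratio(2) == 0.75            # 1024px branch
assert round(merge_ratio(3), 2) == 0.89  # 1536px branch
assert merge_ratio(4) == 0.9375          # 2048px branch
```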