demo interface changes

#1
by niulx - opened
Files changed (9)
  1. .gitattributes +0 -1
  2. app.py +268 -221
  3. img2.png +0 -3
  4. img3.png +0 -0
  5. img4.png +0 -0
  6. main.py +6 -16
  7. requirements.txt +3 -9
  8. segment.py +11 -21
  9. utils.py +1 -0
.gitattributes DELETED
@@ -1 +0,0 @@
1
- img2.png filter=lfs diff=lfs merge=lfs -text
 
 
app.py CHANGED
@@ -1,7 +1,6 @@
 
1
  import os
2
  import copy
3
- #import spaces
4
- from main import run_main
5
  from PIL import Image
6
  import matplotlib
7
  import numpy as np
@@ -11,12 +10,10 @@ from utils_mask import process_mask_to_follow_priority, mask_union, visualize_ma
11
  from pathlib import Path
12
  from PIL import Image
13
  from functools import partial
14
- import time
15
-
16
  LENGTH=512 #length of the square area displaying/editing images
17
  TRANSPARENCY = 150 # transparency of the mask in display
18
 
19
-
20
  def add_mask(mask_np_list_updated, mask_label_list):
21
  mask_new = np.zeros_like(mask_np_list_updated[0])
22
  mask_np_list_updated.append(mask_new)
@@ -35,25 +32,89 @@ def create_segmentation(mask_np_list):
35
  segmentation = Image.fromarray(np.uint8(segmentation*255))
36
  return segmentation
37
38
 
39
- #@spaces.GPU
40
- def run_segmentation_wrapper(image):
41
  try:
42
- print(image.shape)
43
- image, mask_np_list,mask_label_list = run_segmentation(image)
44
- #image = image.convert('RGB')
 
 
45
  segmentation = create_segmentation(mask_np_list)
46
  print("!!", len(mask_np_list))
47
- max_val = len(mask_np_list)-1
48
- sliderup = gr.Slider(value = 0, minimum=0, maximum=max_val, step=1, visible=True)
49
- gr.Info('Segmentation finish. Select mask id and move to the next step.')
50
- return image, segmentation, mask_np_list, mask_label_list, image, sliderup, sliderup , 'Segmentation finish. Select mask id and move to the next step.'
51
- except Exception as e:
52
- print(e)
53
- sliderup = gr.Slider(value = 0, minimum=0, maximum=1, step=1, visible=False)
54
- gr.Warning('Please upload an image before proceeding.')
55
- return None,None,None,None,None, sliderup, sliderup , 'Please upload an image before proceeding.'
56
-
57
 
58
  def transparent_paste_with_mask(backimg, foreimg, mask_np,transparency = 128):
59
  backimg_solid_np = np.array(backimg)
@@ -65,8 +126,11 @@ def transparent_paste_with_mask(backimg, foreimg, mask_np,transparency = 128):
65
  bimg_np = np.array(bimg)
66
  mask_np = mask_np[:,:,np.newaxis]
67
 
68
- new_img_np = bimg_np*mask_np + (1-mask_np)* backimg_solid_np
69
- return Image.fromarray(np.uint8(new_img_np))
70
 
71
  def show_segmentation(image, segmentation, flag):
72
  if flag is False:
@@ -97,32 +161,17 @@ def edit_mask_add(canvas, image, idx, mask_np_list):
97
  return mask_np_list_updated, image_edit
98
 
99
  def slider_release(index, image, mask_np_list_updated, mask_label_list):
100
- if index > len(mask_np_list_updated)-1:
101
- return image, "out of range", ""
 
102
  else:
103
  mask_np = mask_np_list_updated[index]
104
  mask_label = mask_label_list[index]
105
- index = mask_label.rfind('-')
106
- mask_label = mask_label[:index]
107
- if mask_label == 'handbag':
108
- mask_prompt = "white handbag"
109
- elif mask_label == 'person':
110
- mask_prompt = "little boy"
111
- elif mask_label == 'wall-other-merged':
112
- mask_prompt = "white wall"
113
- elif mask_label == 'table-merged':
114
- mask_prompt = "table"
115
- else:
116
- mask_prompt = mask_label
117
  segmentation = create_segmentation(mask_np_list_updated)
118
  new_image = transparent_paste_with_mask(image, segmentation, mask_np, transparency = TRANSPARENCY)
119
- gr.Info('Edit '+ mask_label)
120
- return new_image, mask_label, mask_prompt
121
- def image_change():
122
- return gr.Slider(value = 0, minimum=0, maximum=1, step=1, visible=False)
123
 
124
  def save_as_orig_mask(mask_np_list_updated, mask_label_list, input_folder="example_tmp"):
125
- print(mask_np_list_updated)
126
  try:
127
  assert np.all(sum(mask_np_list_updated)==1)
128
  except:
@@ -137,7 +186,6 @@ def save_as_orig_mask(mask_np_list_updated, mask_label_list, input_folder="examp
137
  visualize_mask_list_clean(mask_np_list_updated, savepath)
138
 
139
  def save_as_edit_mask(mask_np_list_updated, mask_label_list, input_folder="example_tmp"):
140
- print(mask_np_list_updated)
141
  try:
142
  assert np.all(sum(mask_np_list_updated)==1)
143
  except:
@@ -149,30 +197,12 @@ def save_as_edit_mask(mask_np_list_updated, mask_label_list, input_folder="examp
149
  savepath = os.path.join(input_folder, "seg_edited.png")
150
  visualize_mask_list_clean(mask_np_list_updated, savepath)
151
 
152
-
153
-
154
- def button_clickable(is_clickable):
155
- return gr.Button(interactive=is_clickable)
156
-
157
-
158
-
159
- def load_pil_img():
160
- from PIL import Image
161
- return Image.open("example_tmp/text/out_text_0.png")
162
-
163
- def change_image(img):
164
- return None
165
-
166
-
167
  import shutil
168
  if os.path.isdir("./example_tmp"):
169
  shutil.rmtree("./example_tmp")
170
 
171
-
172
-
173
-
174
  from segment import run_segmentation
175
-
176
  with gr.Blocks() as demo:
177
  image = gr.State() # store mask
178
  image_loaded = gr.State()
@@ -188,186 +218,203 @@ with gr.Blocks() as demo:
188
  with gr.Row():
189
  gr.Markdown("""# D-Edit""")
190
 
191
-
192
- with gr.Row():
193
- with gr.Column():
194
- canvas = gr.Image(value = None, type="numpy", label="Show Mask", show_label=True, height=LENGTH, width=LENGTH, interactive=True)
195
- example_inps = [['./img.png'],['./img2.png'],['./img3.png'],['./img4.png']]
196
- gr.Examples(examples=example_inps, inputs=[canvas],
197
- label='examples', cache_examples='lazy', outputs=[],
198
- fn=change_image)
199
- gr.Markdown(f"Each image must first undergo segmentation. Afterwards, you can modify the \n mask ID and the prompt for image editing, then proceed with the editing process. \n The link of D-edit paper: [https://arxiv.org/abs/2403.04880v2](https://arxiv.org/abs/2403.04880v2), [https://huggingface.co/papers/2403.04880](https://huggingface.co/papers/2403.04880)")
200
-
201
- with gr.Column():
202
- result_info0 = gr.Text(label="Response")
203
- segment_button = gr.Button("Step 1. Run segmentation")
204
- flag = gr.State(False)
205
-
206
  # mask_np_list_updated.value = copy.deepcopy(mask_np_list.value) #!!
207
  mask_np_list_updated = mask_np_list
208
- gr.Markdown("""<p style="text-align: center; font-size: 20px">Edit Mask (Do not change it during the editing process)</p>""")
209
- slider = gr.Slider(0, 20, step=1, label = 'mask id', visible=False)
210
- label = gr.Text(label='label')
211
-
212
-
213
-
214
 
215
- result_info = gr.Text(label="Response")
216
-
217
- opt_flag = gr.State(0)
218
- gr.Markdown("""<p style="text-align: center; font-size: 20px">Optimization settings</p>""")
219
- with gr.Accordion(label="Advanced settings", open=False):
220
  num_tokens = gr.Number(value="5", label="num tokens to represent each object", interactive= True)
221
  num_tokens_global = num_tokens
222
- embedding_learning_rate = gr.Textbox(value="0.00025", label="Embedding optimization: Learning rate", interactive= True )
223
- max_emb_train_steps = gr.Number(value="6", label="embedding optimization: Training steps", interactive= True )
224
 
225
- diffusion_model_learning_rate = gr.Textbox(value="0.0002", label="UNet Optimization: Learning rate", interactive= True )
226
- max_diffusion_train_steps = gr.Number(value="28", label="UNet Optimization: Learning rate: Training steps", interactive= True )
227
 
228
- train_batch_size = gr.Number(value="20", label="Batch size", interactive= True )
229
- gradient_accumulation_steps=gr.Number(value="2", label="Gradient accumulation", interactive= True )
230
-
231
- def run_optimization_wrapper (
232
- mask_np_list,
233
- mask_label_list,
234
- image,
235
- opt_flag,
236
- num_tokens,
237
- embedding_learning_rate ,
238
- max_emb_train_steps ,
239
- diffusion_model_learning_rate ,
240
- max_diffusion_train_steps,
241
- train_batch_size,
242
- gradient_accumulation_steps,
243
- ):
244
- try:
245
  run_optimization = partial(
246
- run_main,
247
- mask_np_list=mask_np_list,
248
- mask_label_list=mask_label_list,
249
- image_gt=np.array(image),
250
  num_tokens=int(num_tokens),
251
  embedding_learning_rate = float(embedding_learning_rate),
252
- max_emb_train_steps = min(int(max_emb_train_steps),50),
253
  diffusion_model_learning_rate= float(diffusion_model_learning_rate),
254
- max_diffusion_train_steps = min(int(max_diffusion_train_steps),100),
255
  train_batch_size=int(train_batch_size),
256
  gradient_accumulation_steps=int(gradient_accumulation_steps)
257
  )
258
  run_optimization()
259
- gr.Info("Optimization Finished! Move to the next step.")
260
- return "Optimization finished! Move to the next step."#,gr.Button("Step 3. Run Editing",interactive = True)
261
- except Exception as e:
262
- print(e)
263
- gr.Error("e")
264
- return "Error: use a smaller batch size or try latter."#,gr.Button("Step 3. Run Editing",interactive = False)
265
-
266
-
267
-
268
- if 1:
269
- with gr.Row():
270
- with gr.Column():
271
- canvas_text_edit = gr.Image(value = None, type = "pil", label="Editing results", show_label=True,visible = True)
272
- # canvas_text_edit = gr.Gallery(label = "Edited results")
273
-
274
- with gr.Column():
275
- gr.Markdown("""<p style="text-align: center; font-size: 20px">Editing setting</p>""")
276
- tgt_prompt = gr.Textbox(value="text prompt", label="Editing: Text prompt", interactive= True )
277
- with gr.Accordion(label="Advanced settings", open=False):
278
- slider2 = gr.Slider(0, 20, step=1, label = 'mask id', visible=False)
279
- guidance_scale = gr.Textbox(value="5", label="Editing: CFG guidance scale", interactive= True )
280
- num_sampling_steps = gr.Number(value="20", label="Editing: Sampling steps", interactive= True )
281
  edge_thickness = gr.Number(value="10", label="Editing: Edge thickness", interactive= True )
282
  strength = gr.Textbox(value="0.5", label="Editing: Mask strength", interactive= True )
283
-
284
- add_button = gr.Button("Step 2. Run Editing",interactive = True)
285
- def run_edit_text_wrapper(
286
- mask_np_list,
287
- mask_label_list,
288
- image,
289
- num_tokens,
290
- guidance_scale,
291
- num_sampling_steps ,
292
- strength ,
293
- edge_thickness,
294
- tgt_prompt ,
295
- tgt_index
296
- ):
297
 
298
- run_edit_text = partial(
299
- run_main,
300
- mask_np_list=mask_np_list,
301
- mask_label_list=mask_label_list,
302
- image_gt=np.array(image),
303
- load_trained=True,
304
- text=True,
305
- num_tokens = int(num_tokens_global.value),
306
- guidance_scale = float(guidance_scale),
307
- num_sampling_steps = int(num_sampling_steps),
308
- strength = float(strength),
309
- edge_thickness = int(edge_thickness),
310
- num_imgs = 1,
311
- tgt_prompt = tgt_prompt,
312
- tgt_index = int(tgt_index)
313
  )
314
- run_edit_text()
315
- gr.Info('Image editing completed.')
316
- return load_pil_img()
317
-
318
-
319
-
320
- def run_total_wrapper(mask_np_list, mask_label_list, image_loaded, opt_flag, num_tokens, embedding_learning_rate, max_emb_train_steps, diffusion_model_learning_rate, max_diffusion_train_steps, train_batch_size, gradient_accumulation_steps, num_tokens_global, guidance_scale, num_sampling_steps, strength, edge_thickness, tgt_prompt, slider2):
321
- result_info = run_optimization_wrapper(mask_np_list, mask_label_list, image_loaded, opt_flag, num_tokens, embedding_learning_rate, max_emb_train_steps, diffusion_model_learning_rate, max_diffusion_train_steps, train_batch_size, gradient_accumulation_steps)
322
- canvas_text_edit = run_edit_text_wrapper(mask_np_list, mask_label_list, image_loaded, num_tokens_global, guidance_scale, num_sampling_steps, strength, edge_thickness, tgt_prompt, slider2)
323
- return result_info, canvas_text_edit
324
-
325
-
326
- add_button.click(
327
- run_total_wrapper,
328
- inputs=[
329
- mask_np_list,
330
- mask_label_list,
331
- image_loaded,
332
- opt_flag,
333
- num_tokens,
334
- embedding_learning_rate,
335
- max_emb_train_steps,
336
- diffusion_model_learning_rate,
337
- max_diffusion_train_steps,
338
- train_batch_size,
339
- gradient_accumulation_steps,
340
- num_tokens_global,
341
- guidance_scale,
342
- num_sampling_steps,
343
- strength,
344
- edge_thickness,
345
- tgt_prompt,
346
- slider2
347
- ],
348
- outputs=[result_info, canvas_text_edit],
349
- )
350
-
351
-
352
-
353
-
354
- canvas.upload(image_change, inputs=[], outputs=[slider])
355
-
356
- slider.release(slider_release,
357
- inputs = [slider, image_loaded, mask_np_list_updated, mask_label_list],
358
- outputs= [canvas, label,tgt_prompt])
359
-
360
- slider.change(
361
- lambda x: x,
362
- inputs=[slider],
363
- outputs=[slider2]
364
- )
365
-
366
 
367
- segment_button.click(run_segmentation_wrapper,
368
- [canvas] ,
369
- [image_loaded, segmentation, mask_np_list, mask_label_list, canvas, slider, slider2, result_info0] )
370
 
371
 
372
 
373
- demo.queue().launch(debug=True)
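For reference, the canvas overlay in both versions of app.py comes from transparent_paste_with_mask: the colored segmentation is blended over the photo with a fixed alpha (TRANSPARENCY = 150) and the blend is kept only inside the selected mask. The sketch below reproduces that arithmetic with synthetic inputs and plain NumPy blending in place of the PIL paste calls that the diff collapses, so it is an illustration rather than the exact implementation.

# Sketch of the mask-overlay compositing used for the canvas preview.
# Synthetic inputs; the alpha blend here stands in for the collapsed PIL
# paste lines inside transparent_paste_with_mask.
import numpy as np
from PIL import Image

TRANSPARENCY = 150  # same constant as at the top of app.py (alpha out of 255)

backimg = Image.new("RGB", (512, 512), "gray")   # base photo
foreimg = Image.new("RGB", (512, 512), "red")    # colored segmentation image
mask_np = np.zeros((512, 512), dtype=np.float32)
mask_np[128:384, 128:384] = 1.0                  # region whose overlay is shown

alpha = TRANSPARENCY / 255.0
back = np.array(backimg, dtype=np.float32)
fore = np.array(foreimg, dtype=np.float32)
m = mask_np[:, :, np.newaxis]

blended = alpha * fore + (1.0 - alpha) * back    # semi-transparent paste everywhere
new_img_np = blended * m + (1.0 - m) * back      # keep the blend only inside the mask
Image.fromarray(np.uint8(new_img_np)).save("overlay_preview.png")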
 
1
+
2
  import os
3
  import copy
 
 
4
  from PIL import Image
5
  import matplotlib
6
  import numpy as np
 
10
  from pathlib import Path
11
  from PIL import Image
12
  from functools import partial
13
+ from main import run_main
 
14
  LENGTH=512 #length of the square area displaying/editing images
15
  TRANSPARENCY = 150 # transparency of the mask in display
16
 
 
17
  def add_mask(mask_np_list_updated, mask_label_list):
18
  mask_new = np.zeros_like(mask_np_list_updated[0])
19
  mask_np_list_updated.append(mask_new)
 
32
  segmentation = Image.fromarray(np.uint8(segmentation*255))
33
  return segmentation
34
 
35
+ def load_mask_ui(input_folder="example_tmp",load_edit = False):
36
+ if not load_edit:
37
+ mask_list, mask_label_list = load_mask(input_folder)
38
+ else:
39
+ mask_list, mask_label_list = load_mask_edit(input_folder)
40
+
41
+ mask_np_list = []
42
+ for m in mask_list:
43
+ mask_np_list. append( m.cpu().numpy())
44
+
45
+ return mask_np_list, mask_label_list
46
 
47
+ def load_image_ui(load_edit, input_folder="example_tmp"):
 
48
  try:
49
+ for img_path in Path(input_folder).iterdir():
50
+ if img_path.name in ["img_512.png"]:
51
+ image = Image.open(img_path)
52
+ mask_np_list, mask_label_list = load_mask_ui(input_folder, load_edit = load_edit)
53
+ image = image.convert('RGB')
54
  segmentation = create_segmentation(mask_np_list)
55
  print("!!", len(mask_np_list))
56
+ return image, segmentation, mask_np_list, mask_label_list, image
57
+ except:
58
+ print("Image folder invalid: The folder should contain image.png")
59
+ return None, None, None, None, None
60
+
61
+ # def run_edit_text(
62
+ # num_tokens,
63
+ # num_sampling_steps,
64
+ # strength,
65
+ # edge_thickness,
66
+ # tgt_prompt,
67
+ # tgt_idx,
68
+ # guidance_scale,
69
+ # input_folder="example_tmp"
70
+ # ):
71
+ # subprocess.run(["python",
72
+ # "main.py" ,
73
+ # "--text=True",
74
+ # "--name={}".format(input_folder),
75
+ # "--dpm={}".format("sd"),
76
+ # "--resolution={}".format(512),
77
+ # "--load_trained",
78
+ # "--num_tokens={}".format(num_tokens),
79
+ # "--seed={}".format(2024),
80
+ # "--guidance_scale={}".format(guidance_scale),
81
+ # "--num_sampling_step={}".format(num_sampling_steps),
82
+ # "--strength={}".format(strength),
83
+ # "--edge_thickness={}".format(edge_thickness),
84
+ # "--num_imgs={}".format(2),
85
+ # "--tgt_prompt={}".format(tgt_prompt) ,
86
+ # "--tgt_index={}".format(tgt_idx)
87
+ # ])
88
+
89
+ # return Image.open(os.path.join(input_folder, "text", "out_text_0.png"))
90
+
91
+
92
+ # def run_optimization(
93
+ # num_tokens,
94
+ # embedding_learning_rate,
95
+ # max_emb_train_steps,
96
+ # diffusion_model_learning_rate,
97
+ # max_diffusion_train_steps,
98
+ # train_batch_size,
99
+ # gradient_accumulation_steps,
100
+ # input_folder = "example_tmp"
101
+ # ):
102
+ # subprocess.run(["python",
103
+ # "main.py" ,
104
+ # "--name={}".format(input_folder),
105
+ # "--dpm={}".format("sd"),
106
+ # "--resolution={}".format(512),
107
+ # "--num_tokens={}".format(num_tokens),
108
+ # "--embedding_learning_rate={}".format(embedding_learning_rate),
109
+ # "--diffusion_model_learning_rate={}".format(diffusion_model_learning_rate),
110
+ # "--max_emb_train_steps={}".format(max_emb_train_steps),
111
+ # "--max_diffusion_train_steps={}".format(max_diffusion_train_steps),
112
+ # "--train_batch_size={}".format(train_batch_size),
113
+ # "--gradient_accumulation_steps={}".format(gradient_accumulation_steps)
114
+
115
+ # ])
116
+ # return
117
+
118
 
119
  def transparent_paste_with_mask(backimg, foreimg, mask_np,transparency = 128):
120
  backimg_solid_np = np.array(backimg)
 
126
  bimg_np = np.array(bimg)
127
  mask_np = mask_np[:,:,np.newaxis]
128
 
129
+ try:
130
+ new_img_np = bimg_np*mask_np + (1-mask_np)* backimg_solid_np
131
+ return Image.fromarray(new_img_np)
132
+ except:
133
+ import pdb; pdb.set_trace()
134
 
135
  def show_segmentation(image, segmentation, flag):
136
  if flag is False:
 
161
  return mask_np_list_updated, image_edit
162
 
163
  def slider_release(index, image, mask_np_list_updated, mask_label_list):
164
+
165
+ if index > len(mask_np_list_updated):
166
+ return image, "out of range"
167
  else:
168
  mask_np = mask_np_list_updated[index]
169
  mask_label = mask_label_list[index]
170
  segmentation = create_segmentation(mask_np_list_updated)
171
  new_image = transparent_paste_with_mask(image, segmentation, mask_np, transparency = TRANSPARENCY)
172
+ return new_image, mask_label
173
 
174
  def save_as_orig_mask(mask_np_list_updated, mask_label_list, input_folder="example_tmp"):
 
175
  try:
176
  assert np.all(sum(mask_np_list_updated)==1)
177
  except:
 
186
  visualize_mask_list_clean(mask_np_list_updated, savepath)
187
 
188
  def save_as_edit_mask(mask_np_list_updated, mask_label_list, input_folder="example_tmp"):
 
189
  try:
190
  assert np.all(sum(mask_np_list_updated)==1)
191
  except:
 
197
  savepath = os.path.join(input_folder, "seg_edited.png")
198
  visualize_mask_list_clean(mask_np_list_updated, savepath)
199
 
200
+
201
  import shutil
202
  if os.path.isdir("./example_tmp"):
203
  shutil.rmtree("./example_tmp")
204
205
  from segment import run_segmentation
 
206
  with gr.Blocks() as demo:
207
  image = gr.State() # store mask
208
  image_loaded = gr.State()
 
218
  with gr.Row():
219
  gr.Markdown("""# D-Edit""")
220
 
221
+ with gr.Tab(label="1 Edit mask"):
222
+ with gr.Row():
223
+ with gr.Column():
224
+ canvas = gr.Image(value = "./img.png", type="numpy", label="Draw Mask", show_label=True, height=LENGTH, width=LENGTH, interactive=True)
225
+
226
+ segment_button = gr.Button("1.1 Run segmentation")
227
+ segment_button.click(run_segmentation,
228
+ [canvas, block_flag] ,
229
+ [block_flag] )
230
+
231
+ text_button = gr.Button("Waiting 1.1 to complete")
232
+ text_button.click(load_image_ui,
233
+ [ false] ,
234
+ [image_loaded, segmentation, mask_np_list, mask_label_list, canvas] )
235
+
236
+ load_edit_button = gr.Button("Waiting 1.1 to complete")
237
+ load_edit_button.click(load_image_ui,
238
+ [ true] ,
239
+ [image_loaded, segmentation, mask_np_list, mask_label_list, canvas] )
240
+
241
+ show_segment = gr.Checkbox(label = "Waiting 1.1 to complete")
242
+ flag = gr.State(False)
243
+ show_segment.select(show_segmentation,
244
+ [image_loaded, segmentation, flag],
245
+ [canvas, flag])
246
+ def show_more_buttons():
247
+ return gr.Button("1.2 Load original masks"), gr.Button("1.2 Load edited masks") , gr.Checkbox(label = "Show Segmentation")
248
+ block_flag.change(show_more_buttons, [], [text_button,load_edit_button,show_segment ])
249
+
250
+
251
  # mask_np_list_updated.value = copy.deepcopy(mask_np_list.value) #!!
252
  mask_np_list_updated = mask_np_list
253
+ with gr.Column():
254
+ gr.Markdown("""<p style="text-align: center; font-size: 20px">Edit Mask (Optional)</p>""")
255
+ slider = gr.Slider(0, 20, step=1, interactive=True)
256
+ label = gr.Textbox()
257
+ slider.release(slider_release,
258
+ inputs = [slider, image_loaded, mask_np_list_updated, mask_label_list],
259
+ outputs= [canvas, label]
260
+ )
261
+ add_button = gr.Button("Add")
262
+ add_button.click( edit_mask_add,
263
+ [canvas, image_loaded, slider, mask_np_list_updated] ,
264
+ [mask_np_list_updated, canvas]
265
+ )
266
 
267
+ save_button2 = gr.Button("Set and Save as edited masks")
268
+ save_button2.click( save_as_edit_mask,
269
+ [mask_np_list_updated, mask_label_list] ,
270
+ [] )
271
+
272
+ save_button = gr.Button("Set and Save as original masks")
273
+ save_button.click( save_as_orig_mask,
274
+ [mask_np_list_updated, mask_label_list] ,
275
+ [] )
276
+
277
+ back_button = gr.Button("Back to current seg")
278
+ back_button.click( load_mask_ui,
279
+ [] ,
280
+ [ mask_np_list_updated,mask_label_list] )
281
+
282
+ add_mask_button = gr.Button("Add new empty mask")
283
+ add_mask_button.click(add_mask,
284
+ [mask_np_list_updated, mask_label_list] ,
285
+ [mask_np_list_updated, mask_label_list] )
286
+
287
+ with gr.Tab(label="2 Optimization"):
288
+ with gr.Row():
289
+ with gr.Column():
290
+
291
+ txt_box = gr.Textbox("Click to start optimization...", interactive = False)
292
+
293
+ opt_flag = gr.State(0)
294
+ gr.Markdown("""<p style="text-align: center; font-size: 20px">Optimization settings (SD)</p>""")
295
  num_tokens = gr.Number(value="5", label="num tokens to represent each object", interactive= True)
296
  num_tokens_global = num_tokens
297
+ embedding_learning_rate = gr.Textbox(value="0.0001", label="Embedding optimization: Learning rate", interactive= True )
298
+ max_emb_train_steps = gr.Number(value="200", label="embedding optimization: Training steps", interactive= True )
299
+
300
+ diffusion_model_learning_rate = gr.Textbox(value="0.00005", label="UNet Optimization: Learning rate", interactive= True )
301
+ max_diffusion_train_steps = gr.Number(value="200", label="UNet Optimization: Learning rate: Training steps", interactive= True )
302
 
303
+ train_batch_size = gr.Number(value="5", label="Batch size", interactive= True )
304
+ gradient_accumulation_steps=gr.Number(value="5", label="Gradient accumulation", interactive= True )
305
 
306
+ add_button = gr.Button("Run optimization")
307
+ def run_optimization_wrapper (
308
+ opt_flag,
309
+ num_tokens,
310
+ embedding_learning_rate ,
311
+ max_emb_train_steps ,
312
+ diffusion_model_learning_rate ,
313
+ max_diffusion_train_steps,
314
+ train_batch_size,
315
+ gradient_accumulation_steps
316
+ ):
317
  run_optimization = partial(
318
+ run_main,
319
  num_tokens=int(num_tokens),
320
  embedding_learning_rate = float(embedding_learning_rate),
321
+ max_emb_train_steps = int(max_emb_train_steps),
322
  diffusion_model_learning_rate= float(diffusion_model_learning_rate),
323
+ max_diffusion_train_steps = int(max_diffusion_train_steps),
324
  train_batch_size=int(train_batch_size),
325
  gradient_accumulation_steps=int(gradient_accumulation_steps)
326
  )
327
  run_optimization()
328
+ return opt_flag+1
329
+
330
+ add_button.click(run_optimization_wrapper,
331
+ inputs = [
332
+ opt_flag,
333
+ num_tokens,
334
+ embedding_learning_rate ,
335
+ max_emb_train_steps ,
336
+ diffusion_model_learning_rate ,
337
+ max_diffusion_train_steps,
338
+ train_batch_size,
339
+ gradient_accumulation_steps
340
+ ],
341
+ outputs = [opt_flag]
342
+ )
343
+
344
+ def change_text(txt_box):
345
+ return gr.Textbox("Optimization Finished!", interactive = False)
346
+ def change_text2(txt_box):
347
+ return gr.Textbox("Start optimization, check logs for progress...", interactive = False)
348
+ add_button.click(change_text2, txt_box, txt_box)
349
+ opt_flag.change(change_text, txt_box, txt_box)
350
+
351
+ with gr.Tab(label="3 Editing"):
352
+ with gr.Tab(label="3.1 Text-based editing"):
353
+
354
+ with gr.Row():
355
+ with gr.Column():
356
+ canvas_text_edit = gr.Image(value = None, type = "pil", label="Editing results", show_label=True)
357
+ # canvas_text_edit = gr.Gallery(label = "Edited results")
358
+
359
+ with gr.Column():
360
+ gr.Markdown("""<p style="text-align: center; font-size: 20px">Editing setting (SD)</p>""")
361
+
362
+ tgt_prompt = gr.Textbox(value="White bag", label="Editing: Text prompt", interactive= True )
363
+ tgt_index = gr.Number(value="0", label="Editing: Object index", interactive= True )
364
+ guidance_scale = gr.Textbox(value="6", label="Editing: CFG guidance scale", interactive= True )
365
+ num_sampling_steps = gr.Number(value="50", label="Editing: Sampling steps", interactive= True )
366
  edge_thickness = gr.Number(value="10", label="Editing: Edge thickness", interactive= True )
367
  strength = gr.Textbox(value="0.5", label="Editing: Mask strength", interactive= True )
368
+
369
+ add_button = gr.Button("Run Editing")
370
+ def run_edit_text_wrapper(
371
+ num_tokens,
372
+ guidance_scale,
373
+ num_sampling_steps ,
374
+ strength ,
375
+ edge_thickness,
376
+ tgt_prompt ,
377
+ tgt_index
378
+ ):
379
+
380
+ run_edit_text = partial(
381
+ run_main,
382
+ load_trained=True,
383
+ text=True,
384
+ num_tokens = int(num_tokens_global.value),
385
+ guidance_scale = float(guidance_scale),
386
+ num_sampling_steps = int(num_sampling_steps),
387
+ strength = float(strength),
388
+ edge_thickness = int(edge_thickness),
389
+ num_imgs = 1,
390
+ tgt_prompt = tgt_prompt,
391
+ tgt_index = int(tgt_index)
392
+ )
393
+ return run_edit_text()
394
 
395
+ add_button.click(run_edit_text_wrapper,
396
+ inputs = [num_tokens_global,
397
+ guidance_scale,
398
+ num_sampling_steps,
399
+ strength ,
400
+ edge_thickness,
401
+ tgt_prompt ,
402
+ tgt_index
403
+ ],
404
+ outputs = [canvas_text_edit]
405
+ )
406
+
407
+ def load_pil_img():
408
+ from PIL import Image
409
+ return Image.open("example_tmp/text/out_text_0.png")
410
+
411
+ load_button = gr.Button("Load results")
412
+ load_button.click(load_pil_img,
413
+ inputs = [],
414
+ outputs = [canvas_text_edit]
415
   )
416
 
 
 
 
417
 
418
 
419
 
420
+ demo.queue().launch(share=True, debug=True)
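The tab-based app.py above sequences the steps with a block_flag value: run_segmentation returns an incremented block_flag, and a .change listener on that flag swaps the "Waiting 1.1 to complete" placeholders for the real buttons via show_more_buttons. Below is a minimal sketch of that gating pattern; block_flag's own definition is outside this diff, so it is modelled here as a hidden gr.Number, and slow_step is an illustrative stand-in for the segmentation call.

# Minimal sketch of the block_flag gating pattern (assumes gradio is installed;
# the hidden gr.Number stands in for block_flag, whose definition is not shown
# in this diff, and slow_step stands in for run_segmentation).
import gradio as gr

def slow_step(flag):
    # ... the long-running segmentation would happen here ...
    return flag + 1                      # bumping the counter signals completion

def reveal_next_buttons():
    # Mirrors show_more_buttons in app.py: relabel the placeholders once done.
    return gr.Button("1.2 Load original masks"), gr.Button("1.2 Load edited masks")

with gr.Blocks() as demo:
    block_flag = gr.Number(value=0, visible=False)       # hidden completion counter
    segment_button = gr.Button("1.1 Run segmentation")
    text_button = gr.Button("Waiting 1.1 to complete")
    load_edit_button = gr.Button("Waiting 1.1 to complete")

    segment_button.click(slow_step, [block_flag], [block_flag])
    block_flag.change(reveal_next_buttons, [], [text_button, load_edit_button])

if __name__ == "__main__":
    demo.queue().launch()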
img2.png DELETED

Git LFS Details

  • SHA256: f0d93d36051ad4f4ce9b371d4122830bdbbda01c2a27e23a538b13e5cb3715f6
  • Pointer size: 132 Bytes
  • Size of remote file: 1.4 MB
img3.png DELETED
Binary file (259 kB)
 
img4.png DELETED
Binary file (45.9 kB)
 
main.py CHANGED
@@ -1,5 +1,4 @@
1
  import os
2
- import spaces
3
  import torch
4
  import numpy as np
5
  import argparse
@@ -10,14 +9,9 @@ from utils import load_image, load_mask, load_mask_edit
10
  from utils_mask import process_mask_move_torch, process_mask_remove_torch, mask_union_torch, mask_substract_torch, create_outer_edge_mask_torch
11
  from utils_mask import check_mask_overlap_torch, check_cover_all_torch, visualize_mask_list, get_mask_difference_torch, save_mask_list_to_npys
12
 
13
- @spaces.GPU(duration=45)
14
-
15
  def run_main(
16
  name="example_tmp",
17
  name_2=None,
18
- mask_np_list=None,
19
- mask_label_list=None,
20
- image_gt=None,
21
  dpm="sd",
22
  resolution=512,
23
  seed=42,
@@ -77,17 +71,13 @@ def run_main(
77
  base_output_folder = "."
78
 
79
  input_folder = os.path.join(base_input_folder, name)
80
- mask_list = []
81
- for mask_np in mask_np_list:
82
- mask = torch.from_numpy(mask_np.astype(np.uint8))
83
- mask_list.append(mask)
84
-
85
- #mask_list, mask_label_list = load_mask(input_folder)
86
  assert mask_list[0].shape[0] == resolution, "Segmentation should be done on size {}".format(resolution)
87
- #try:
88
- # image_gt = load_image(os.path.join(input_folder, "img_{}.png".format(resolution) ), size = resolution)
89
- #except:
90
- # image_gt = load_image(os.path.join(input_folder, "img_{}.jpg".format(resolution) ), size = resolution)
91
 
92
  if image:
93
  input_folder_2 = os.path.join(base_input_folder, name_2)
 
1
  import os
 
2
  import torch
3
  import numpy as np
4
  import argparse
 
9
  from utils_mask import process_mask_move_torch, process_mask_remove_torch, mask_union_torch, mask_substract_torch, create_outer_edge_mask_torch
10
  from utils_mask import check_mask_overlap_torch, check_cover_all_torch, visualize_mask_list, get_mask_difference_torch, save_mask_list_to_npys
11
 
 
 
12
  def run_main(
13
  name="example_tmp",
14
  name_2=None,
 
 
 
15
  dpm="sd",
16
  resolution=512,
17
  seed=42,
 
71
  base_output_folder = "."
72
 
73
  input_folder = os.path.join(base_input_folder, name)
74
+
75
+ mask_list, mask_label_list = load_mask(input_folder)
 
 
 
 
76
  assert mask_list[0].shape[0] == resolution, "Segmentation should be done on size {}".format(resolution)
77
+ try:
78
+ image_gt = load_image(os.path.join(input_folder, "img_{}.png".format(resolution) ), size = resolution)
79
+ except:
80
+ image_gt = load_image(os.path.join(input_folder, "img_{}.jpg".format(resolution) ), size = resolution)
81
 
82
  if image:
83
  input_folder_2 = os.path.join(base_input_folder, name_2)
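With this version, run_main no longer receives the masks or the ground-truth image as arguments; it reads the mask .npy files through load_mask and img_512.png (falling back to .jpg) from the named working folder. A hedged sketch of a direct call after run_segmentation has populated ./example_tmp, using the keyword arguments that appear in app.py and the new UI defaults:

# Sketch of calling run_main once ./example_tmp has been populated by the
# segmentation step. Expected folder contents (per load_mask / load_image above):
#   example_tmp/img_512.png        image saved by run_segmentation
#   example_tmp/mask0_rest.npy     one boolean mask per segment,
#   example_tmp/mask1_....npy      named mask{id}_{label}.npy
from main import run_main

run_main(
    name="example_tmp",                    # working folder
    dpm="sd",
    resolution=512,
    num_tokens=5,
    embedding_learning_rate=1e-4,          # UI defaults shown in app.py
    max_emb_train_steps=200,
    diffusion_model_learning_rate=5e-5,
    max_diffusion_train_steps=200,
    train_batch_size=5,
    gradient_accumulation_steps=5,
)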
requirements.txt CHANGED
@@ -1,17 +1,11 @@
1
- gradio==4.36.0
2
- torch
3
- torchvision
4
- huggingface_hub
5
-
6
- accelerate==0.27.2
7
- diffusers==0.30.2
8
- numpy==1.26.4
9
  torch==2.2.0
10
  torchvision==0.17.0
11
  transformers==4.37.2
 
 
12
  xformers==0.0.24
 
13
  scipy
14
- setuptools
15
  tqdm
16
  numpy
17
  safetensors
1
  torch==2.2.0
2
  torchvision==0.17.0
3
  transformers==4.37.2
4
+ accelerate==0.23.0
5
+ gradio==3.41.1
6
  xformers==0.0.24
7
+ diffusers==0.26.3
8
  scipy
 
9
  tqdm
10
  numpy
11
  safetensors
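The new requirements pin gradio 3.41.1, accelerate 0.23.0 and diffusers 0.26.3 alongside torch 2.2.0, torchvision 0.17.0, transformers 4.37.2 and xformers 0.0.24. A small optional sketch (Python 3.8+, standard library only) for checking an environment against these pins before launching the Space:

# Quick sanity check that the environment matches the pins above.
from importlib.metadata import version, PackageNotFoundError

PINNED = {
    "torch": "2.2.0",
    "torchvision": "0.17.0",
    "transformers": "4.37.2",
    "accelerate": "0.23.0",
    "gradio": "3.41.1",
    "xformers": "0.0.24",
    "diffusers": "0.26.3",
}

for pkg, wanted in PINNED.items():
    try:
        found = version(pkg)
    except PackageNotFoundError:
        print(f"{pkg}: not installed (expected {wanted})")
        continue
    status = "ok" if found == wanted else f"mismatch (expected {wanted})"
    print(f"{pkg}: {found} {status}")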
segment.py CHANGED
@@ -1,7 +1,6 @@
1
 
2
  from transformers import AutoImageProcessor, Mask2FormerForUniversalSegmentation
3
  from PIL import Image
4
- import spaces
5
  import torch
6
  from collections import defaultdict
7
  import matplotlib.pyplot as plt
@@ -11,8 +10,6 @@ import os
11
  import numpy as np
12
  import argparse
13
  import matplotlib
14
- import gradio as gr
15
-
16
 
17
  def load_image(image_path, left=0, right=0, top=0, bottom=0, size = 512):
18
  if type(image_path) is str:
@@ -47,18 +44,14 @@ def draw_panoptic_segmentation(segmentation, segments_info,save_folder=None, nos
47
  instances_counter = defaultdict(int)
48
  handles = []
49
  label_list = []
50
-
51
- mask_np_list = []
52
-
53
  if not noseg:
54
  if torch.min(segmentation) == 0:
55
  mask = segmentation==0
56
  mask = mask.cpu().detach().numpy() # [512,512] bool
57
- print(mask.shape)
58
  segment_label = "rest"
 
59
  color = viridis(0)
60
  label = f"{segment_label}-{0}"
61
- mask_np_list.append(mask)
62
  handles.append(mpatches.Patch(color=color, label=label))
63
  label_list.append(label)
64
 
@@ -68,11 +61,10 @@ def draw_panoptic_segmentation(segmentation, segments_info,save_folder=None, nos
68
  if torch.min(segmentation) != 0:
69
  segment_id -= 1
70
  mask = mask.cpu().detach().numpy() # [512,512] bool
71
- print(mask.shape)
72
- mask_np_list.append(mask)
73
  segment_label = model.config.id2label[segment['label_id']]
74
  instances_counter[segment['label_id']] += 1
75
-
76
  color = viridis(segment_id)
77
 
78
  label = f"{segment_label}-{segment_id}"
@@ -80,10 +72,8 @@ def draw_panoptic_segmentation(segmentation, segments_info,save_folder=None, nos
80
  label_list.append(label)
81
  else:
82
  mask = np.full(segmentation.shape, True)
83
- print(mask.shape)
84
-
85
  segment_label = "all"
86
- mask_np_list.append(mask)
87
  color = viridis(0)
88
  label = f"{segment_label}-{0}"
89
  handles.append(mpatches.Patch(color=color, label=label))
@@ -95,11 +85,11 @@ def draw_panoptic_segmentation(segmentation, segments_info,save_folder=None, nos
95
  ax.legend(handles=handles)
96
  plt.savefig(os.path.join(save_folder, 'seg_init.png'), dpi=500 )
97
  print("; ".join(label_list))
98
- return mask_np_list,label_list
99
 
100
 
101
- @spaces.GPU(duration=10)
102
- def run_segmentation(image, name="example_tmp", size = 512, noseg=False):
 
103
 
104
  base_folder_path = "."
105
 
@@ -115,7 +105,7 @@ def run_segmentation(image, name="example_tmp", size = 512, noseg=False):
115
  image =Image.fromarray(image)
116
  image = image.resize((size, size))
117
  os.makedirs(name, exist_ok=True)
118
- #image.save(os.path.join(name,"img_{}.png".format(size)))
119
  inputs = processor(image, return_tensors="pt")
120
  with torch.no_grad():
121
  outputs = model(**inputs)
@@ -123,7 +113,7 @@ def run_segmentation(image, name="example_tmp", size = 512, noseg=False):
123
  panoptic_segmentation = processor.post_process_panoptic_segmentation(outputs, target_sizes=[image.size[::-1]])[0]
124
  save_folder = os.path.join(base_folder_path, name)
125
  os.makedirs(save_folder, exist_ok=True)
126
- mask_list,label_list = draw_panoptic_segmentation(**panoptic_segmentation, save_folder = save_folder, noseg = noseg, model = model)
127
  print("Finish segment")
128
- #block_flag += 1
129
- return image,mask_list,label_list#, gr.Button.update("1.2 Load edited masks",visible = True), gr.Checkbox.update(label = "Show Segmentation",visible = True)
 
1
 
2
  from transformers import AutoImageProcessor, Mask2FormerForUniversalSegmentation
3
  from PIL import Image
 
4
  import torch
5
  from collections import defaultdict
6
  import matplotlib.pyplot as plt
 
10
  import numpy as np
11
  import argparse
12
  import matplotlib
 
 
13
 
14
  def load_image(image_path, left=0, right=0, top=0, bottom=0, size = 512):
15
  if type(image_path) is str:
 
44
  instances_counter = defaultdict(int)
45
  handles = []
46
  label_list = []
 
 
 
47
  if not noseg:
48
  if torch.min(segmentation) == 0:
49
  mask = segmentation==0
50
  mask = mask.cpu().detach().numpy() # [512,512] bool
 
51
  segment_label = "rest"
52
+ np.save( os.path.join(save_folder, "mask{}_{}.npy".format(0,"rest")) , mask)
53
  color = viridis(0)
54
  label = f"{segment_label}-{0}"
 
55
  handles.append(mpatches.Patch(color=color, label=label))
56
  label_list.append(label)
57
 
 
61
  if torch.min(segmentation) != 0:
62
  segment_id -= 1
63
  mask = mask.cpu().detach().numpy() # [512,512] bool
64
+
 
65
  segment_label = model.config.id2label[segment['label_id']]
66
  instances_counter[segment['label_id']] += 1
67
+ np.save( os.path.join(save_folder, "mask{}_{}.npy".format(segment_id,segment_label)) , mask)
68
  color = viridis(segment_id)
69
 
70
  label = f"{segment_label}-{segment_id}"
 
72
  label_list.append(label)
73
  else:
74
  mask = np.full(segmentation.shape, True)
 
 
75
  segment_label = "all"
76
+ np.save( os.path.join(save_folder, "mask{}_{}.npy".format(0,"all")) , mask)
77
  color = viridis(0)
78
  label = f"{segment_label}-{0}"
79
  handles.append(mpatches.Patch(color=color, label=label))
 
85
  ax.legend(handles=handles)
86
  plt.savefig(os.path.join(save_folder, 'seg_init.png'), dpi=500 )
87
  print("; ".join(label_list))
 
88
 
89
 
90
+
91
+
92
+ def run_segmentation(image, block_flag, name="example_tmp", size = 512, noseg=False):
93
 
94
  base_folder_path = "."
95
 
 
105
  image =Image.fromarray(image)
106
  image = image.resize((size, size))
107
  os.makedirs(name, exist_ok=True)
108
+ image.save(os.path.join(name,"img_{}.png".format(size)))
109
  inputs = processor(image, return_tensors="pt")
110
  with torch.no_grad():
111
  outputs = model(**inputs)
 
113
  panoptic_segmentation = processor.post_process_panoptic_segmentation(outputs, target_sizes=[image.size[::-1]])[0]
114
  save_folder = os.path.join(base_folder_path, name)
115
  os.makedirs(save_folder, exist_ok=True)
116
+ draw_panoptic_segmentation(**panoptic_segmentation, save_folder = save_folder, noseg = noseg, model = model)
117
  print("Finish segment")
118
+ block_flag += 1
119
+ return block_flag
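run_segmentation now saves the resized input as img_512.png and writes one mask{id}_{label}.npy per segment into the working folder; utils.load_mask (not shown in this diff) reads them back for the app and for main.py. The reader below is an illustrative stand-in for that round-trip, not the project's actual load_mask:

# Illustrative reader for the mask{id}_{label}.npy files written by
# draw_panoptic_segmentation above. A stand-in sketch, not utils.load_mask.
import os
import re
import numpy as np

def load_saved_masks(folder="example_tmp"):
    masks, labels = [], []
    pattern = re.compile(r"mask(\d+)_(.+)\.npy$")
    for fname in sorted(os.listdir(folder)):
        match = pattern.match(fname)
        if match is None:
            continue
        seg_id, seg_label = int(match.group(1)), match.group(2)
        masks.append(np.load(os.path.join(folder, fname)))   # boolean [512, 512] array
        labels.append(f"{seg_label}-{seg_id}")                # matches the legend labels
    return masks, labels

if os.path.isdir("example_tmp"):
    masks, labels = load_saved_masks()
    print(f"loaded {len(masks)} masks: {labels}")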
utils.py CHANGED
@@ -249,6 +249,7 @@ def load_mask (input_folder):
249
  except:
250
  print("please check mask")
251
  # plt.imsave( "out_mask.png", mask_list_edit[0])
 
252
  return mask_list, mask_label_list
253
 
254
  def load_image(image_path, left=0, right=0, top=0, bottom=0, size = 512):
 
249
  except:
250
  print("please check mask")
251
  # plt.imsave( "out_mask.png", mask_list_edit[0])
252
+ import pdb; pdb.set_trace()
253
  return mask_list, mask_label_list
254
 
255
  def load_image(image_path, left=0, right=0, top=0, bottom=0, size = 512):