dragynir commited on
Commit
4240411
·
1 Parent(s): d91aa80

add flag from mask

Browse files
Files changed (2) hide show
  1. app.py +4 -1
  2. src/pipeline.py +6 -2
app.py CHANGED
@@ -17,6 +17,7 @@ def process(
17
  input_image: np.ndarray,
18
  prompt: str,
19
  negative_prompt: str,
 
20
  num_inference_steps: int,
21
  guidance_scale: float,
22
  conditioning_scale: float,
@@ -29,6 +30,7 @@ def process(
29
  # control_image=input_image,
30
  # prompt=prompt,
31
  # negative_prompt=negative_prompt,
 
32
  # num_inference_steps=num_inference_steps,
33
  # guidance_scale=guidance_scale,
34
  # conditioning_scale=conditioning_scale,
@@ -69,6 +71,7 @@ with block:
69
  input_image = gr.Image(type="numpy")
70
  prompt = gr.Textbox(label="Prompt")
71
  negative_prompt = gr.Textbox(label="Negative Prompt")
 
72
  run_button = gr.Button(value="Run")
73
  with gr.Accordion("Advanced options", open=False):
74
  target_image_size = gr.Slider(
@@ -107,7 +110,7 @@ with block:
107
  generated_output = gr.Image(label="Generated", type="numpy", elem_id="generated")
108
  mask_output = gr.Image(label="Mask", type="numpy", elem_id="mask")
109
 
110
- ips = [input_image, prompt, negative_prompt, num_inference_steps, guidance_scale, conditioning_scale, guess_mode, target_image_size, max_image_size, seed]
111
  run_button.click(fn=process, inputs=ips, outputs=[generated_output, mask_output])
112
 
113
 
 
17
  input_image: np.ndarray,
18
  prompt: str,
19
  negative_prompt: str,
20
+ generate_from_mask: bool,
21
  num_inference_steps: int,
22
  guidance_scale: float,
23
  conditioning_scale: float,
 
30
  # control_image=input_image,
31
  # prompt=prompt,
32
  # negative_prompt=negative_prompt,
33
+ # generate_from_mask=generate_from_mask,
34
  # num_inference_steps=num_inference_steps,
35
  # guidance_scale=guidance_scale,
36
  # conditioning_scale=conditioning_scale,
 
71
  input_image = gr.Image(type="numpy")
72
  prompt = gr.Textbox(label="Prompt")
73
  negative_prompt = gr.Textbox(label="Negative Prompt")
74
+ generate_from_mask = gr.Checkbox(label="Input image is already a control mask", value=False)
75
  run_button = gr.Button(value="Run")
76
  with gr.Accordion("Advanced options", open=False):
77
  target_image_size = gr.Slider(
 
110
  generated_output = gr.Image(label="Generated", type="numpy", elem_id="generated")
111
  mask_output = gr.Image(label="Mask", type="numpy", elem_id="mask")
112
 
113
+ ips = [input_image, prompt, negative_prompt, generate_from_mask, num_inference_steps, guidance_scale, conditioning_scale, guess_mode, target_image_size, max_image_size, seed]
114
  run_button.click(fn=process, inputs=ips, outputs=[generated_output, mask_output])
115
 
116
 
src/pipeline.py CHANGED
@@ -41,6 +41,7 @@ class FashionPipeline:
41
  control_image: np.ndarray,
42
  prompt: str,
43
  negative_prompt: str,
 
44
  num_inference_steps: int,
45
  guidance_scale: float,
46
  conditioning_scale: float,
@@ -54,8 +55,11 @@ class FashionPipeline:
54
  control_image = HWC3(control_image)
55
 
56
  # extract segmentation mask
57
- segm_mask = generate_mask(control_image, self.segmentation_model, device=self.device)
58
- control_mask = self.create_control_mask(segm_mask)
 
 
 
59
 
60
  control_mask = self.adaptive_resize(
61
  image=control_mask,
 
41
  control_image: np.ndarray,
42
  prompt: str,
43
  negative_prompt: str,
44
+ generate_from_mask: bool,
45
  num_inference_steps: int,
46
  guidance_scale: float,
47
  conditioning_scale: float,
 
55
  control_image = HWC3(control_image)
56
 
57
  # extract segmentation mask
58
+ if generate_from_mask:
59
+ control_mask = control_image
60
+ else:
61
+ segm_mask = generate_mask(control_image, self.segmentation_model, device=self.device)
62
+ control_mask = self.create_control_mask(segm_mask)
63
 
64
  control_mask = self.adaptive_resize(
65
  image=control_mask,