Stylique committed on
Commit
90b8e48
·
verified ·
1 Parent(s): d9efec2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +143 -118
app.py CHANGED
@@ -12,7 +12,7 @@ from preprocess.openpose.run_openpose import OpenPose
12
 
13
  import gradio as gr
14
 
15
- # Download checkpoints
16
  snapshot_download(repo_id="franciszzj/Leffa", local_dir="./ckpts")
17
 
18
 
@@ -34,64 +34,66 @@ class LeffaPredictor:
34
  body_model_path="./ckpts/openpose/body_pose_model.pth",
35
  )
36
 
37
- vt_model_hd = LeffaModel(
 
38
  pretrained_model_name_or_path="./ckpts/stable-diffusion-inpainting",
39
  pretrained_model="./ckpts/virtual_tryon.pth",
40
  dtype="float16",
41
  )
42
- self.vt_inference_hd = LeffaInference(model=vt_model_hd)
43
 
44
- vt_model_dc = LeffaModel(
 
45
  pretrained_model_name_or_path="./ckpts/stable-diffusion-inpainting",
46
  pretrained_model="./ckpts/virtual_tryon_dc.pth",
47
  dtype="float16",
48
  )
49
- self.vt_inference_dc = LeffaInference(model=vt_model_dc)
50
 
51
- pt_model = LeffaModel(
 
52
  pretrained_model_name_or_path="./ckpts/stable-diffusion-xl-1.0-inpainting-0.1",
53
  pretrained_model="./ckpts/pose_transfer.pth",
54
  dtype="float16",
55
  )
56
- self.pt_inference = LeffaInference(model=pt_model)
57
-
58
- def leffa_predict(self, src_image_path, ref_image_path, control_type,
59
- ref_acceleration=False, step=50, scale=2.5, seed=42,
60
- vt_model_type="viton_hd", vt_garment_type="upper_body", vt_repaint=False):
61
- assert control_type in ["virtual_tryon", "pose_transfer"]
62
- src = Image.open(src_image_path)
63
- ref = Image.open(ref_image_path)
64
  src = resize_and_center(src, 768, 1024)
65
  ref = resize_and_center(ref, 768, 1024)
66
- arr = np.array(src)
67
-
68
- # Mask
69
- if control_type == "virtual_tryon":
70
- src_rgb = src.convert("RGB")
71
- parse, _ = self.parsing(src_rgb.resize((384, 512)))
72
- kpt = self.openpose(src_rgb.resize((384, 512)))
73
- if vt_model_type == "viton_hd":
74
- mask = get_agnostic_mask_hd(parse, kpt, vt_garment_type)
75
- else:
76
- mask = get_agnostic_mask_dc(parse, kpt, vt_garment_type)
77
- mask = mask.resize((768, 1024))
 
 
 
78
  else:
79
- mask = Image.fromarray(np.ones_like(arr) * 255)
80
-
81
- # DensePose
82
- if control_type == "virtual_tryon":
83
- if vt_model_type == "viton_hd":
84
- seg = self.densepose_predictor.predict_seg(arr)[:, :, ::-1]
85
- densepose = Image.fromarray(seg)
86
- else:
87
- iuv = self.densepose_predictor.predict_iuv(arr)
88
- seg = np.repeat(iuv[:, :, 0:1], 3, axis=-1)
89
- densepose = Image.fromarray(seg)
90
  else:
91
- iuv = self.densepose_predictor.predict_iuv(arr)[:, :, ::-1]
92
- densepose = Image.fromarray(iuv)
 
 
93
 
94
- # Inference
95
  data = {
96
  "src_image": [src],
97
  "ref_image": [ref],
@@ -99,100 +101,123 @@ class LeffaPredictor:
99
  "densepose": [densepose],
100
  }
101
  data = LeffaTransform()(data)
102
-
103
- if control_type == "virtual_tryon":
104
- inf = self.vt_inference_hd if vt_model_type == "viton_hd" else self.vt_inference_dc
105
- else:
106
- inf = self.pt_inference
107
-
108
  out = inf(
109
  data,
110
- ref_acceleration=ref_acceleration,
111
- num_inference_steps=step,
112
- guidance_scale=scale,
113
- seed=seed,
114
- repaint=vt_repaint,
115
  )
116
- img = out["generated_image"][0]
117
- return np.array(img), np.array(mask), np.array(densepose)
118
-
119
- def leffa_predict_vt(self, src, ref, accel, step, scale, seed, mtype, gtype, repaint):
120
- return self.leffa_predict(src, ref, "virtual_tryon", accel, step, scale, seed, mtype, gtype, repaint)
 
 
 
 
 
 
 
121
 
122
- def leffa_predict_pt(self, src, ref, accel, step, scale, seed):
123
- return self.leffa_predict(src, ref, "pose_transfer", accel, step, scale, seed)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
124
 
125
 
126
  if __name__ == "__main__":
127
  lp = LeffaPredictor()
128
- ex = "./ckpts/examples"
129
- p1 = list_dir(f"{ex}/person1")
130
- p2 = list_dir(f"{ex}/person2")
131
- g = list_dir(f"{ex}/garment")
132
-
133
- with gr.Blocks(
134
- theme=gr.themes.Default(primary_hue=gr.themes.colors.pink, secondary_hue=gr.themes.colors.red)
135
- ).queue() as demo:
136
-
137
- gr.Markdown("## Leffa: Learning Flow Fields in Attention …")
138
- gr.Markdown(
139
- "[📚 Paper](https://arxiv.org/abs/2412.08486) "
140
- "[🤖 Code](https://github.com/franciszzj/Leffa) • "
141
- "[🤗 Model](https://huggingface.co/franciszzj/Leffa)"
142
- )
143
 
144
- with gr.Tab("Virtual Try-On"):
 
 
145
  with gr.Row():
146
  with gr.Column():
147
- vt_src = gr.Image(type="filepath", label="Person Image")
148
- gr.Examples(inputs=vt_src, examples_per_page=6, examples=p1)
 
149
  with gr.Column():
150
- vt_ref = gr.Image(type="filepath", label="Garment Image")
151
- gr.Examples(inputs=vt_ref, examples_per_page=6, examples=g)
 
152
  with gr.Column():
153
- vt_out = gr.Image(label="Generated")
154
- btn = gr.Button("Generate")
155
- with gr.Accordion("Advanced"):
156
- vt_model_type = gr.Radio(
157
- ["viton_hd", "dress_code"], label="Model Type", value="viton_hd"
158
- )
159
- vt_garment_type = gr.Radio(
160
- ["upper_body", "lower_body", "dresses"],
161
- label="Garment Type",
162
- value="upper_body",
163
- )
164
- vt_accel = gr.Checkbox(label="Accelerate UNet", value=False)
165
- vt_repaint = gr.Checkbox(label="Repaint Mode", value=False)
166
- vt_steps = gr.Number(label="Steps", value=30)
167
- vt_scale = gr.Number(label="Scale", value=2.5)
168
- vt_seed = gr.Number(label="Seed", value=42)
169
- btn.click(
170
- lp.leffa_predict_vt,
171
- inputs=[vt_src, vt_ref, vt_accel, vt_steps, vt_scale, vt_seed, vt_model_type, vt_garment_type, vt_repaint],
172
- outputs=[vt_out],
173
- )
174
 
175
  with gr.Tab("Pose Transfer"):
176
  with gr.Row():
177
  with gr.Column():
178
- pt_ref = gr.Image(type="filepath", label="Person Image")
179
- gr.Examples(inputs=pt_ref, examples_per_page=6, examples=p1)
 
180
  with gr.Column():
181
- pt_src = gr.Image(type="filepath", label="Target Pose Image")
182
- gr.Examples(inputs=pt_src, examples_per_page=6, examples=p2)
 
183
  with gr.Column():
184
- pt_out = gr.Image(label="Generated")
185
- btn2 = gr.Button("Generate")
186
- with gr.Accordion("Advanced"):
187
- pt_accel = gr.Checkbox(label="Accelerate UNet", value=False)
188
- pt_steps = gr.Number(label="Steps", value=30)
189
- pt_scale = gr.Number(label="Scale", value=2.5)
190
- pt_seed = gr.Number(label="Seed", value=42)
191
- btn2.click(
192
- lp.leffa_predict_pt,
193
- inputs=[pt_src, pt_ref, pt_accel, pt_steps, pt_scale, pt_seed],
194
- outputs=[pt_out],
195
- )
196
-
197
- gr.Markdown("Note: Virtual try-on uses VITON-HD/DressCode; pose transfer uses DeepFashion.")
198
- demo.launch(server_port=7860, allowed_paths=["./ckpts/examples"])
 
 
 
 
 
 
 
 
12
 
13
  import gradio as gr
14
 
15
+ # Download checkpoints once at startup
16
  snapshot_download(repo_id="franciszzj/Leffa", local_dir="./ckpts")
17
 
18
 
 
34
  body_model_path="./ckpts/openpose/body_pose_model.pth",
35
  )
36
 
37
+ # Virtual try‑on HD
38
+ vt_hd = LeffaModel(
39
  pretrained_model_name_or_path="./ckpts/stable-diffusion-inpainting",
40
  pretrained_model="./ckpts/virtual_tryon.pth",
41
  dtype="float16",
42
  )
43
+ self.vt_hd_inf = LeffaInference(model=vt_hd)
44
 
45
+ # Virtual try‑on DressCode
46
+ vt_dc = LeffaModel(
47
  pretrained_model_name_or_path="./ckpts/stable-diffusion-inpainting",
48
  pretrained_model="./ckpts/virtual_tryon_dc.pth",
49
  dtype="float16",
50
  )
51
+ self.vt_dc_inf = LeffaInference(model=vt_dc)
52
 
53
+ # Pose transfer
54
+ pt = LeffaModel(
55
  pretrained_model_name_or_path="./ckpts/stable-diffusion-xl-1.0-inpainting-0.1",
56
  pretrained_model="./ckpts/pose_transfer.pth",
57
  dtype="float16",
58
  )
59
+ self.pt_inf = LeffaInference(model=pt)
60
+
61
+ def _prepare(self, src_path, ref_path):
62
+ src = Image.open(src_path)
63
+ ref = Image.open(ref_path)
 
 
 
64
  src = resize_and_center(src, 768, 1024)
65
  ref = resize_and_center(ref, 768, 1024)
66
+ return src, ref
67
+
68
+ def predict_virtual_tryon(
69
+ self, src_path, ref_path,
70
+ accelerate_ref, steps, scale, seed,
71
+ model_type, garment_type, repaint
72
+ ):
73
+ src, ref = self._prepare(src_path, ref_path)
74
+ src_arr = np.array(src.convert("RGB"))
75
+
76
+ # 1) parsing + keypoints → agnostic mask
77
+ parse, _ = self.parsing(src.resize((384, 512)))
78
+ kpts = self.openpose(src.resize((384, 512)))
79
+ if model_type == "viton_hd":
80
+ mask = get_agnostic_mask_hd(parse, kpts, garment_type)
81
  else:
82
+ mask = get_agnostic_mask_dc(parse, kpts, garment_type)
83
+ mask = mask.resize((768, 1024))
84
+
85
+ # 2) DensePose → seg or IUV
86
+ if model_type == "viton_hd":
87
+ seg = self.densepose_predictor.predict_seg(src_arr)[:, :, ::-1]
88
+ densepose = Image.fromarray(seg)
89
+ inf = self.vt_hd_inf
 
 
 
90
  else:
91
+ iuv = self.densepose_predictor.predict_iuv(src_arr)
92
+ seg = np.concatenate([iuv[:, :, :1]] * 3, axis=-1)
93
+ densepose = Image.fromarray(seg)
94
+ inf = self.vt_dc_inf
95
 
96
+ # 3) run Leffa
97
  data = {
98
  "src_image": [src],
99
  "ref_image": [ref],
 
101
  "densepose": [densepose],
102
  }
103
  data = LeffaTransform()(data)
 
 
 
 
 
 
104
  out = inf(
105
  data,
106
+ ref_acceleration=accelerate_ref,
107
+ num_inference_steps=int(steps),
108
+ guidance_scale=float(scale),
109
+ seed=int(seed),
110
+ repaint=repaint,
111
  )
112
+ gen = out["generated_image"][0]
113
+ return np.array(gen), np.array(mask), np.array(densepose)
114
+
115
+ def predict_pose_transfer(
116
+ self, src_path, ref_path,
117
+ accelerate_ref, steps, scale, seed
118
+ ):
119
+ src, ref = self._prepare(src_path, ref_path)
120
+ src_arr = np.array(src)
121
+ mask = Image.fromarray(np.ones_like(src_arr) * 255)
122
+ iuv = self.densepose_predictor.predict_iuv(src_arr)[:, :, ::-1]
123
+ densepose = Image.fromarray(iuv)
124
 
125
+ data = {
126
+ "src_image": [src],
127
+ "ref_image": [ref],
128
+ "mask": [mask],
129
+ "densepose": [densepose],
130
+ }
131
+ data = LeffaTransform()(data)
132
+ out = self.pt_inf(
133
+ data,
134
+ ref_acceleration=accelerate_ref,
135
+ num_inference_steps=int(steps),
136
+ guidance_scale=float(scale),
137
+ seed=int(seed),
138
+ )
139
+ gen = out["generated_image"][0]
140
+ return np.array(gen), np.array(mask), np.array(densepose)
141
 
142
 
143
  if __name__ == "__main__":
144
  lp = LeffaPredictor()
145
+ examples = "./ckpts/examples"
146
+ person1 = list_dir(f"{examples}/person1")
147
+ person2 = list_dir(f"{examples}/person2")
148
+ garments = list_dir(f"{examples}/garment")
149
+
150
+ title = "## Leffa: Controllable Person Image Generation"
151
+ note = "Note: Virtual Try‑On uses VITON‑HD/DressCode; Pose Transfer uses DeepFashion."
152
+
153
+ with gr.Blocks(theme=gr.themes.Default(
154
+ primary_hue=gr.themes.colors.pink,
155
+ secondary_hue=gr.themes.colors.red
156
+ )).queue() as demo:
 
 
 
157
 
158
+ gr.Markdown(title)
159
+
160
+ with gr.Tab("Virtual Try‑On"):
161
  with gr.Row():
162
  with gr.Column():
163
+ vt_src = gr.Image(source="upload", type="filepath", label="Person")
164
+ gr.Examples(examples=person1, inputs=vt_src)
165
+
166
  with gr.Column():
167
+ vt_ref = gr.Image(source="upload", type="filepath", label="Garment")
168
+ gr.Examples(examples=garments, inputs=vt_ref)
169
+
170
  with gr.Column():
171
+ vt_out = gr.Image(label="Result")
172
+ vt_mask = gr.Image(label="Mask")
173
+ vt_dp = gr.Image(label="DensePose")
174
+ vt_btn = gr.Button("Generate")
175
+
176
+ with gr.Accordion("Advanced Options", open=False):
177
+ vt_model = gr.Radio(["viton_hd","dress_code"], value="viton_hd", label="Model")
178
+ vt_garment = gr.Radio(["upper_body","lower_body","dresses"], value="upper_body", label="Garment Type")
179
+ vt_accel_ref = gr.Checkbox(label="Accelerate Reference UNet")
180
+ vt_repaint = gr.Checkbox(label="Repaint Mode")
181
+ vt_steps = gr.Slider(30,100,value=30,step=1,label="Steps")
182
+ vt_scale = gr.Slider(0.1,5.0,value=2.5,step=0.1,label="Guidance Scale")
183
+ vt_seed = gr.Number(value=42, label="Seed")
184
+
185
+ vt_btn.click(
186
+ fn=lp.predict_virtual_tryon,
187
+ inputs=[vt_src, vt_ref, vt_accel_ref, vt_steps, vt_scale, vt_seed, vt_model, vt_garment, vt_repaint],
188
+ outputs=[vt_out, vt_mask, vt_dp],
189
+ )
 
 
190
 
191
  with gr.Tab("Pose Transfer"):
192
  with gr.Row():
193
  with gr.Column():
194
+ pt_src = gr.Image(source="upload", type="filepath", label="Source Pose")
195
+ gr.Examples(examples=person2, inputs=pt_src)
196
+
197
  with gr.Column():
198
+ pt_ref = gr.Image(source="upload", type="filepath", label="Target Person")
199
+ gr.Examples(examples=person1, inputs=pt_ref)
200
+
201
  with gr.Column():
202
+ pt_out = gr.Image(label="Result")
203
+ pt_mask = gr.Image(label="Mask")
204
+ pt_dp = gr.Image(label="DensePose")
205
+ pt_btn = gr.Button("Generate")
206
+
207
+ with gr.Accordion("Advanced Options", open=False):
208
+ pt_accel_ref = gr.Checkbox(label="Accelerate Reference UNet")
209
+ pt_steps = gr.Slider(30,100,value=30,step=1,label="Steps")
210
+ pt_scale = gr.Slider(0.1,5.0,value=2.5,step=0.1,label="Guidance Scale")
211
+ pt_seed = gr.Number(value=42, label="Seed")
212
+
213
+ pt_btn.click(
214
+ fn=lp.predict_pose_transfer,
215
+ inputs=[pt_src, pt_ref, pt_accel_ref, pt_steps, pt_scale, pt_seed],
216
+ outputs=[pt_out, pt_mask, pt_dp],
217
+ )
218
+
219
+ gr.Markdown(note)
220
+
221
+ # expose publicly
222
+ demo.launch(share=True, server_port=7860,
223
+ allowed_paths=["./ckpts/examples"])