Spaces:
Runtime error
Runtime error
yjernite
commited on
Commit
·
8df4211
1
Parent(s):
427730e
template for examplars
Browse files
app.py
CHANGED
|
@@ -133,7 +133,10 @@ def make_profession_table(num_clusters, prof_names, mod_name, max_cols=8):
|
|
| 133 |
.to_html()
|
| 134 |
)
|
| 135 |
|
| 136 |
-
|
|
|
|
|
|
|
|
|
|
| 137 |
|
| 138 |
with gr.Blocks() as demo:
|
| 139 |
gr.Markdown("# 🤗 Diffusion Cluster Explorer")
|
|
@@ -183,31 +186,56 @@ with gr.Blocks() as demo:
|
|
| 183 |
# with gr.Accordion("Tag Frequencies", open=False):
|
| 184 |
|
| 185 |
with gr.Tab("Profession Focus"):
|
| 186 |
-
with gr.Row():
|
| 187 |
-
num_clusters = gr.Radio(
|
| 188 |
-
[12, 24, 48],
|
| 189 |
-
value=12,
|
| 190 |
-
label="How many clusters do you want to use to represent identities?",
|
| 191 |
-
)
|
| 192 |
with gr.Row():
|
| 193 |
with gr.Column():
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 194 |
profession_choice_focus = gr.Dropdown(
|
| 195 |
choices=professions,
|
| 196 |
value="social worker",
|
| 197 |
label="Select profession:",
|
| 198 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 199 |
with gr.Column():
|
| 200 |
plot = gr.Plot(
|
| 201 |
label=f"Makeup of the cluster assignments for profession {profession_choice_focus}"
|
| 202 |
)
|
| 203 |
-
profession_choice_focus
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 207 |
queue=False,
|
| 208 |
)
|
| 209 |
-
with gr.Row():
|
| 210 |
-
gr.Markdown("TODO: show examplars for cluster")
|
| 211 |
|
| 212 |
|
| 213 |
demo.launch()
|
|
|
|
| 133 |
.to_html()
|
| 134 |
)
|
| 135 |
|
| 136 |
+
def show_examplars(num_clusters, prof_name, mod_name, cl_id):
|
| 137 |
+
# TODO: show the actual images
|
| 138 |
+
examplars_dict = clusters_dicts[num_clusters][df_models[mod_name]][prof_name]["cluster_examplars"][str(cl_id)]
|
| 139 |
+
return json.dumps(examplars_dict)
|
| 140 |
|
| 141 |
with gr.Blocks() as demo:
|
| 142 |
gr.Markdown("# 🤗 Diffusion Cluster Explorer")
|
|
|
|
| 186 |
# with gr.Accordion("Tag Frequencies", open=False):
|
| 187 |
|
| 188 |
with gr.Tab("Profession Focus"):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 189 |
with gr.Row():
|
| 190 |
with gr.Column():
|
| 191 |
+
gr.Markdown("Select profession to visualize here:")
|
| 192 |
+
num_clusters_focus = gr.Radio(
|
| 193 |
+
[12, 24, 48],
|
| 194 |
+
value=12,
|
| 195 |
+
label="How many clusters do you want to use to represent identities?",
|
| 196 |
+
)
|
| 197 |
profession_choice_focus = gr.Dropdown(
|
| 198 |
choices=professions,
|
| 199 |
value="social worker",
|
| 200 |
label="Select profession:",
|
| 201 |
)
|
| 202 |
+
gr.Markdown("You can show examples of profession images assigned to each cluster:")
|
| 203 |
+
model_choices_focus = gr.Dropdown(
|
| 204 |
+
[
|
| 205 |
+
"All Models",
|
| 206 |
+
"Stable Diffusion 1.4",
|
| 207 |
+
"Stable Diffusion 2",
|
| 208 |
+
"Dall-E 2",
|
| 209 |
+
],
|
| 210 |
+
value="All Models",
|
| 211 |
+
label="Select generation model:",
|
| 212 |
+
interactive=True,
|
| 213 |
+
)
|
| 214 |
+
cluster_id_focus = gr.Dropdown(
|
| 215 |
+
choices=[i for i in range(num_clusters_focus.value)],
|
| 216 |
+
value=0,
|
| 217 |
+
label="Select cluster to visualize:",
|
| 218 |
+
)
|
| 219 |
with gr.Column():
|
| 220 |
plot = gr.Plot(
|
| 221 |
label=f"Makeup of the cluster assignments for profession {profession_choice_focus}"
|
| 222 |
)
|
| 223 |
+
for var in [num_clusters_focus, profession_choice_focus]:
|
| 224 |
+
var.change(
|
| 225 |
+
make_profession_plot,
|
| 226 |
+
[num_clusters_focus, profession_choice_focus],
|
| 227 |
+
plot,
|
| 228 |
+
queue=False,
|
| 229 |
+
)
|
| 230 |
+
with gr.Row():
|
| 231 |
+
examplars_plot = gr.JSON() # TODO: turn this into a plot with the actual images
|
| 232 |
+
for var in [model_choices_focus, cluster_id_focus]:
|
| 233 |
+
var.change(
|
| 234 |
+
show_examplars,
|
| 235 |
+
[num_clusters_focus, profession_choice_focus, model_choices_focus, cluster_id_focus],
|
| 236 |
+
examplars_plot,
|
| 237 |
queue=False,
|
| 238 |
)
|
|
|
|
|
|
|
| 239 |
|
| 240 |
|
| 241 |
demo.launch()
|