Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import json | |
| import numpy as np | |
| import pandas as pd | |
| from datasets import load_from_disk | |
| from itertools import chain | |
| import operator | |
# Route DataFrame.plot() through plotly so make_profession_plot returns a
# figure gr.Plot can display.
pd.options.plotting.backend = "plotly"

TITLE = "Diffusion Professions Cluster Explorer"

# Image dataset directory loaded with `datasets.load_from_disk`.
professions_dset = load_from_disk("professions")
# Tabular view of the same rows, used by get_image to find an image's row by
# (model, image_path). It must be built from the dataset object itself:
# `professions` (defined further down) is only a list of profession names and
# does not exist yet at this point.
professions_df = professions_dset.to_pandas()
def get_image(model, fname):
    """Return the stored image for a given generation model and image path.

    The row is located in ``professions_df`` by matching both the
    ``"image_path"`` and ``"model"`` columns. The ``"image"`` value of the
    first match is then read from ``professions_dset``. An empty match makes
    the final ``[0]`` lookup fail.
    """
    # Assumes professions_df keeps the default 0..n-1 index from to_pandas(),
    # so index labels double as dataset row positions.
    is_match = (professions_df["image_path"] == fname) & (
        professions_df["model"] == model
    )
    row_positions = professions_df[is_match].index
    selected = professions_dset.select(row_positions)
    return selected["image"][0]
def _load_json(path):
    """Parse the JSON file at *path*, closing the file handle afterwards."""
    with open(path, encoding="utf-8") as handle:
        return json.load(handle)


# Precomputed cluster statistics, keyed by the number of clusters offered in
# the UI radios. The functions below index each one as
# [model key][profession][stat name].
clusters_dicts = {
    num_cl: _load_json(f"clusters/professions_to_clusters_{num_cl}.json")
    for num_cl in [12, 24, 48]
}

prompts = pd.read_csv("promptsadjectives.csv")
# Alphabetical, lower-cased profession names used as dropdown choices.
professions = sorted(p.lower() for p in prompts["Occupation-Noun"].tolist())

# Short model key (as stored in clusters_dicts) -> display name.
models = {
    "All": "All Models",
    "SD_14": "Stable Diffusion 1.4",
    "SD_2": "Stable Diffusion 2",
    "DallE": "Dall-E 2",
}
# Display name (as chosen in the UI) -> short model key. Derived from `models`
# so the two mappings cannot drift apart.
df_models = {display: key for key, display in models.items()}
def describe_cluster(num_clusters, block="label", cl_dict=None):
    """Return an HTML summary of how labels are distributed within a cluster.

    Args:
        num_clusters: Key into ``clusters_dicts``. It is used only when
            ``cl_dict`` is not given.
        block: Noun naming what the labels are (e.g. ``"label"``). It is
            inserted into the summary sentence.
        cl_dict: Optional mapping of label -> count. When omitted,
            ``clusters_dicts[num_clusters]`` is used.

    Returns:
        An HTML string. It names the most frequent label and its share
        (rounded percent), then lists the remaining labels in descending
        order of count. An empty mapping, or one whose counts sum to zero,
        yields a short "no data" message instead of raising.
    """
    # NOTE(review): clusters_dicts[num_clusters] as loaded in this file maps
    # model -> profession -> stats dicts, not label -> count, so the default
    # path cannot be summed/sorted. Pass ``cl_dict`` explicitly; this function
    # is not currently wired into the UI.
    if cl_dict is None:
        cl_dict = clusters_dicts[num_clusters]

    total = float(sum(cl_dict.values()))
    if not cl_dict or total == 0:
        return "<span>No data available for this cluster.</span>"

    # Ascending sort followed by reverse() (rather than reverse=True) keeps
    # the established tie order: equal counts appear in reverse insertion order.
    labels_values = sorted(cl_dict.items(), key=operator.itemgetter(1))
    labels_values.reverse()
    lv_prcnt = [
        (label, round(count * 100 / total, 0)) for label, count in labels_values
    ]

    (top_label, top_pct), rest = lv_prcnt[0], lv_prcnt[1:]
    # "%s" already applies str() to each value, so no extra conversion helper
    # is needed.
    parts = [
        "<span>The most represented %s is <b>%s</b>, making up about <b>%d%%</b> of the cluster.</span>"
        % (block, top_label, top_pct),
        "<p>This is followed by: ",
    ]
    parts.extend("<BR/><b>%s:</b> %d%%" % (label, pct) for label, pct in rest)
    parts.append("</p>")
    return "".join(parts)
def make_profession_plot(num_clusters, prof_name):
    """Build a grouped bar chart of one profession's cluster proportions.

    Each model in ``models`` becomes one bar series, labelled with its display
    name. The x-axis lists "Cluster <id>" for every cluster with a positive
    proportion under the "All" model, ordered from largest to smallest share.

    Returns:
        Whatever ``DataFrame.plot`` produces under the configured plotting
        backend.
    """
    per_model = clusters_dicts[num_clusters]
    ranked = sorted(
        per_model["All"][prof_name]["cluster_proportions"].items(),
        key=lambda item: item[1],
        reverse=True,
    )
    kept_ids = [cluster_id for cluster_id, share in ranked if share > 0]

    series = {}
    for short_name in models:
        # Proportions are read inside the comprehension, so a model's entry is
        # only touched when there is at least one cluster to show.
        series[models[short_name]] = {
            f"Cluster {cluster_id}": per_model[short_name][prof_name][
                "cluster_proportions"
            ][cluster_id]
            for cluster_id in kept_ids
        }

    frame = pd.DataFrame.from_dict(series)
    return frame.plot(kind="bar", barmode="group")
def make_profession_table(num_clusters, prof_names, mod_name, max_cols=8):
    """Render an HTML table of cluster proportions for several professions.

    Args:
        num_clusters: Key into ``clusters_dicts``; also the number of cluster
            ids (0 .. num_clusters-1) that are ranked.
        prof_names: Professions to show, one table row each.
        mod_name: Display name of the model, mapped through ``df_models``.
        max_cols: How many top-ranked clusters (by summed proportion across
            the selected professions) to consider as columns.

    Returns:
        A pair ``(cluster_ids, html)``. ``cluster_ids`` are the up-to
        ``max_cols`` top-ranked ids, including any whose total is zero.
        ``html`` is the styled table. Its columns are Profession, Entropy,
        Labor Women, a blank spacer, then "Cluster <id>" for each top-ranked
        id with a positive total.
    """

    def _profession_stats(name):
        # Resolved lazily, at the moment each value is needed.
        return clusters_dicts[num_clusters][df_models[mod_name]][name]

    proportions_by_prof = [
        (name, _profession_stats(name)["cluster_proportions"]) for name in prof_names
    ]

    cluster_sums = []
    for cluster_id in range(num_clusters):
        key = str(cluster_id)
        cluster_sums.append(
            (cluster_id, sum(props[key] for _, props in proportions_by_prof))
        )
    cluster_sums.sort(key=operator.itemgetter(1), reverse=True)
    top_clusters = cluster_sums[:max_cols]
    shown_ids = [cluster_id for cluster_id, total in top_clusters if total > 0]

    rows = []
    for name, props in proportions_by_prof:
        stats = _profession_stats(name)
        row = {
            "Profession": name,
            "Entropy": stats["entropy"],
            "Labor Women": stats["labor_fm"][0],
            "": "",  # empty spacer column before the cluster columns
        }
        row.update((f"Cluster {cid}", props[str(cid)]) for cid in shown_ids)
        rows.append(row)

    table_html = (
        pd.DataFrame.from_dict(rows)
        .style.background_gradient(axis=None, vmin=0, vmax=100, cmap="YlGnBu")
        .format(precision=1)
        .to_html()
    )
    return [cluster_id for cluster_id, _ in top_clusters], table_html
def show_examplars(num_clusters, prof_name, mod_name, cl_id):
    """Return the exemplar images stored for one cluster of one profession.

    ``cluster_examplars[str(cl_id)]`` is a mapping whose values are lists of
    3-item entries. The second and third items of each entry are passed to
    ``get_image`` as (model, image path). Images are returned in key order,
    then list order.
    """
    groups = clusters_dicts[num_clusters][df_models[mod_name]][prof_name][
        "cluster_examplars"
    ][str(cl_id)]
    images = []
    for group_key in groups:
        for _, source_model, image_path in groups[group_key]:
            images.append(get_image(source_model, image_path))
    return images
with gr.Blocks(title=TITLE) as demo:
    gr.Markdown("# 🤗 Diffusion Cluster Explorer")
    gr.Markdown("description will go here")

    # Tab 1: compare cluster assignments across several professions.
    with gr.Tab("Professions Overview"):
        gr.Markdown("TODO")
        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("Select the parameters here:")
                num_clusters = gr.Radio(
                    [12, 24, 48],
                    value=12,
                    label="How many clusters do you want to use to represent identities?",
                )
                # Display names; make_profession_table maps them back to the
                # short model keys through df_models.
                model_choices = gr.Dropdown(
                    list(df_models),
                    value="All Models",
                    label="Which models do you want to compare?",
                    interactive=True,
                )
                # NOTE(review): `professions` is lower-cased when it is built,
                # so the default "CEO" is not one of the dropdown choices.
                # Confirm which casing the profession keys in clusters_dicts use.
                profession_choices_overview = gr.Dropdown(
                    professions,
                    value=["CEO", "director", "social assistant", "social worker"],
                    label="Which professions do you want to compare?",
                    multiselect=True,
                    interactive=True,
                )
            with gr.Column(scale=3):
                with gr.Row():
                    # NOTE(review): `wrap` is not a documented gr.HTML argument;
                    # confirm the installed Gradio version accepts it.
                    table = gr.HTML(
                        label="Profession assignment per cluster", wrap=True
                    )
                    # Hidden sink for the cluster ids that make_profession_table
                    # returns alongside the HTML table.
                    clusters = gr.Textbox(label="clusters", visible=False)
        overview_inputs = [num_clusters, profession_choices_overview, model_choices]
        # Render the table on page load and again whenever any control changes.
        demo.load(
            make_profession_table,
            overview_inputs,
            [clusters, table],
            queue=False,
        )
        for var in [num_clusters, model_choices, profession_choices_overview]:
            var.change(
                make_profession_table,
                overview_inputs,
                [clusters, table],
                queue=False,
            )

    # Tab 2: inspect one profession's cluster makeup and exemplar images.
    with gr.Tab("Profession Focus"):
        with gr.Row():
            with gr.Column():
                gr.Markdown("Select profession to visualize here:")
                num_clusters_focus = gr.Radio(
                    [12, 24, 48],
                    value=12,
                    label="How many clusters do you want to use to represent identities?",
                )
                profession_choice_focus = gr.Dropdown(
                    choices=professions,
                    value="social worker",
                    label="Select profession:",
                )
                gr.Markdown(
                    "You can show examples of profession images assigned to each cluster:"
                )
                model_choices_focus = gr.Dropdown(
                    list(df_models),
                    value="All Models",
                    label="Select generation model:",
                    interactive=True,
                )
                # NOTE(review): choices are fixed from the radio's initial value
                # (12), so cluster ids >= 12 cannot be picked after switching to
                # 24 or 48 clusters. TODO: refresh these choices when
                # num_clusters_focus changes.
                cluster_id_focus = gr.Dropdown(
                    choices=list(range(num_clusters_focus.value)),
                    value=0,
                    label="Select cluster to visualize:",
                )
            with gr.Column():
                # Static label: interpolating `profession_choice_focus` here
                # would insert the Dropdown component object itself, not the
                # profession the user selected.
                plot = gr.Plot(
                    label="Makeup of the cluster assignments for the selected profession"
                )
        focus_plot_inputs = [num_clusters_focus, profession_choice_focus]
        demo.load(
            make_profession_plot,
            focus_plot_inputs,
            plot,
            queue=False,
        )
        for var in focus_plot_inputs:
            var.change(
                make_profession_plot,
                focus_plot_inputs,
                plot,
                queue=False,
            )
        with gr.Row():
            # Gallery of the exemplar images returned by show_examplars.
            examplars_plot = gr.Gallery()
        examplar_inputs = [
            num_clusters_focus,
            profession_choice_focus,
            model_choices_focus,
            cluster_id_focus,
        ]
        demo.load(
            show_examplars,
            examplar_inputs,
            examplars_plot,
            queue=False,
        )
        # Refresh the gallery when ANY of its inputs changes, so that switching
        # the profession or the cluster count does not leave stale exemplars
        # from the previous selection on screen.
        for var in examplar_inputs:
            var.change(
                show_examplars,
                examplar_inputs,
                examplars_plot,
                queue=False,
            )
if __name__ == "__main__":
    # Enable Gradio's request queue on the app, then start serving it.
    queued_demo = demo.queue()
    queued_demo.launch(debug=True)