import pandas as pd
import numpy as np
import os
import pickle
from prefixspan import PrefixSpan
import gradio as gr
import altair as alt
from plot import Plot
alt.data_transformers.enable("vegafusion")
# from dynabench.task_evaluator import *

# Root directory for all data files; every path below is built as
# f"{BASE_DIR}/<folder>/<partition>/<file>".
BASE_DIR = "db"
# Model identifiers offered in the model dropdowns for non-video domains.
MODELS = ['qwenvl-chat', 'qwenvl', 'llava15-7b', 'llava15-13b', 'instructblip-vicuna13b', 'instructblip-vicuna7b']
# Model identifiers offered in the model dropdowns when the domain folder name contains "video".
VIDEO_MODELS = ['video-chat2-7b','video-llama2-7b','video-llama2-13b','chat-univi-7b','chat-univi-13b','video-llava-7b','video-chatgpt-7b']
# Domain names; these are the non-None keys of domain2folder.
domains = ["imageqa-2d-sticker", "imageqa-3d-tabletop", "imageqa-scene-graph", "videoqa-3d-tabletop", "videoqa-scene-graph"]
# Maps a domain name to its folder under BASE_DIR. A None domain falls back to '2d'.
domain2folder = {"imageqa-2d-sticker": "2d", 
                 "imageqa-3d-tabletop": "3d", 
                 "imageqa-scene-graph": "sg", 
                 "videoqa-3d-tabletop": "video-3d", 
                 "videoqa-scene-graph": "video-sg", 
                 None: '2d'}

def find_frequent_patterns(k, df, scores=None):
    """Mine the top-k closed frequent patterns over the rows of ``df``.

    Every row is turned into a sequence of ``(column, value)`` pairs, with
    values stringified and empty/NaN cells dropped, and the sequences are
    handed to PrefixSpan.

    Args:
        k: number of patterns requested from ``PrefixSpan.topk``.
        df: DataFrame whose rows are mined.
        scores: optional per-row scores, positionally aligned with ``df``.

    Returns:
        ``[]`` when ``df`` has no rows; the PrefixSpan pattern list when
        ``scores`` is None; otherwise ``(patterns, mean_scores)`` where each
        mean is taken over the rows matching every pair of that pattern.
    """
    if not len(df):
        return []

    frame = df.reset_index(drop=True)
    column_names = frame.columns.to_list()
    frame = frame.fillna('').astype('str')

    # One sequence per row; falsy (empty-string) cells are skipped.
    sequences = []
    for row in frame.values.tolist():
        sequences.append([(name, cell) for name, cell in zip(column_names, row) if cell])

    top_patterns = PrefixSpan(sequences).topk(k, closed=True)
    if scores is None:
        return top_patterns

    score_array = np.asarray(scores)
    mean_scores = []
    for _count, items in top_patterns:
        conditions = [f"`{name}` == {repr(value)}" for name, value in items]
        matched = frame.query(' and '.join(conditions)).index.to_numpy()
        mean_scores.append(np.mean(score_array[matched]))
    return top_patterns, mean_scores

def update_partition_and_models(domain):
    """Build the partition and model dropdowns for the selected domain.

    Args:
        domain: a key of ``domain2folder`` (``None`` maps to the '2d' folder).

    Returns:
        ``[partition_dropdown, model_dropdown]``. Both dropdowns are empty
        when the domain folder does not exist. The partition dropdown has no
        preselected value when the folder contains no sub-directories.
    """
    folder = domain2folder[domain]
    path = f"{BASE_DIR}/{folder}"
    partition_label = "task space of the following task generator"

    if not os.path.exists(path):
        partition = gr.Dropdown([], value=None, label=partition_label)
        model = gr.Dropdown([], value=None, label="model")
        return [partition, model]

    partitions = list_directories(path)
    # Video domains expose a different set of models.
    choices = VIDEO_MODELS if "video" in folder else MODELS
    model = gr.Dropdown(choices, value=choices[0], label="model")

    # Guard against an existing but empty domain folder (previously IndexError).
    default_partition = partitions[0] if partitions else None
    partition = gr.Dropdown(partitions, value=default_partition, label=partition_label)
    return [partition, model]
    
def update_partition_and_models_and_baselines(domain):
    """Build the partition, model and baseline dropdowns for the selected domain.

    Args:
        domain: a key of ``domain2folder`` (``None`` maps to the '2d' folder).

    Returns:
        ``[partition_dropdown, model_dropdown, baseline_dropdown]``. All three
        are empty when the domain folder does not exist. The partition
        dropdown has no preselected value when the folder contains no
        sub-directories.
    """
    folder = domain2folder[domain]
    path = f"{BASE_DIR}/{folder}"
    partition_label = "task space of the following task generator"

    if os.path.exists(path):
        partitions = list_directories(path)
        # Video domains expose a different set of models.
        choices = VIDEO_MODELS if "video" in folder else MODELS
        model = gr.Dropdown(choices, value=choices[0], label="model")
        baseline = gr.Dropdown(choices, value=choices[0], label="baseline")

        # Guard against an existing but empty domain folder (previously IndexError).
        default_partition = partitions[0] if partitions else None
        partition = gr.Dropdown(partitions, value=default_partition, label=partition_label)
    else:
        partition = gr.Dropdown([], value=None, label=partition_label)
        model = gr.Dropdown([], value=None, label="model")
        baseline = gr.Dropdown([], value=None, label="baseline")
    return [partition, model, baseline]
    
def get_filtered_task_ids(domain, partition, models, rank, k, threshold, baseline):
    """Return the row index labels of tasks that pass the model filters.

    Args:
        domain: a key of ``domain2folder`` (not an already-mapped folder name).
        partition: sub-folder of the domain folder holding ``merged_data.csv``.
        models: column names in ``merged_data.csv`` to filter on.
        rank: ``"top"`` selects high-scoring tasks; anything else selects low.
        k: number of tasks kept when no baseline is given.
        threshold: per-model score cut-off used when no baseline is given.
        baseline: optional column name; when truthy, a task is kept only if
            every model scores at least as high as the baseline, and
            ``rank``/``k``/``threshold`` are ignored.

    Returns:
        List of index labels of the surviving rows, or ``[]`` when the CSV
        does not exist.
    """
    folder = domain2folder[domain]
    data_path = f"{BASE_DIR}/{folder}/{partition}/merged_data.csv"
    if not os.path.exists(data_path):
        return []

    df = pd.read_csv(data_path).rename(
        columns={'llavav1.5-7b': 'llava15-7b', 'llavav1.5-13b': 'llava15-13b'}
    )

    select_top = rank == "top"
    # Keep only tasks on which every requested model passes the filter.
    for model in models:
        if baseline:
            df = df[df[model] >= df[baseline]]
        elif select_top:
            df = df[df[model] >= threshold]
        else:
            df = df[df[model] <= threshold]

    if not baseline:
        # assign() returns a new frame, so no column is written onto a
        # filtered slice of the original data.
        df = df.assign(**{'mean score': df[models].mean(axis=1)})
        df = df.sort_values(by='mean score', ascending=False)
        df = df.iloc[:k, :] if select_top else df.iloc[-k:, :]

    return list(df.index)

def plot_patterns(domain, partition, models, rank, k, threshold, baseline, pattern, order):
    """Bar chart of mean model scores on the filtered tasks matching a pattern.

    Args:
        domain: a key of ``domain2folder``.
        partition: sub-folder holding ``expanded_data.csv``.
        models, rank, k, threshold, baseline: forwarded to
            ``get_filtered_task_ids`` to pick the tasks.
        pattern: string form of ``(freq, [(column, value), ...])`` as produced
            by ``str()`` on a mined pattern.
        order: sort order for the model axis (passed to Altair).

    Returns:
        An Altair chart, or ``None`` when the data file is missing or no
        pattern is selected.
    """
    import ast  # local import: only needed to parse the pattern string

    folder = domain2folder[domain]
    data_path = f"{BASE_DIR}/{folder}/{partition}/expanded_data.csv"
    if not os.path.exists(data_path) or not pattern:
        return None

    # Pass the ORIGINAL domain key: get_filtered_task_ids does its own
    # domain2folder lookup, and handing it the folder name raised KeyError.
    task_ids = get_filtered_task_ids(domain, partition, models, rank, k, threshold, baseline)
    expand_df = pd.read_csv(data_path)

    shown_models = (models + [baseline]) if baseline else models
    chart_df = expand_df[expand_df['model'].isin(shown_models)]
    chart_df = chart_df[chart_df['task id'].isin(task_ids)]
    print(pattern)
    # literal_eval parses the tuple/list/str/int literal without executing
    # arbitrary code; the string comes from a UI dropdown.
    freq, cols = ast.literal_eval(pattern)

    df = chart_df
    conditions = []
    for col_name, col_val in cols:
        # Pattern values are stored as strings; compare numerically when the
        # value is an integer literal.
        try:
            col_val = int(col_val)
        except (TypeError, ValueError):
            pass
        df = df[df[col_name] == col_val]
        conditions.append(f"{col_name} = {col_val}")
        print(len(df))
    pattern_str = ", ".join(conditions)

    if baseline:
        model_str = (', '.join(models) if len(models) > 1 else models[0])
        phrase = f'{model_str} perform' if len(models) > 1 else f'{model_str} performs'
        title = f"{phrase} better than {baseline} on {freq} tasks where {pattern_str}"
    else:
        title = f"Models are {'best' if rank == 'top' else 'worst'} at {freq} tasks where {pattern_str}"

    chart = alt.Chart(df).mark_bar().encode(
            alt.X('model:N',
                sort=alt.EncodingSortField(field='score', order=order, op="mean"),
                axis=alt.Axis(labels=False, tickSize=0)),
            alt.Y('mean(score):Q', scale=alt.Scale(zero=True)),
            alt.Color('model:N').legend(),
        ).properties(
            width=400,
            height=300,
            title=title
        )
    return chart

def plot_embedding(domain, partition, category):
    """Scatter plot of the ``x``/``y`` columns of ``merged_data.csv``, coloured by ``category``.

    Returns ``None`` when the CSV for the given domain/partition is missing.
    The tooltip shows question and answer, plus the image column when the
    file has one.
    """
    folder = domain2folder[domain]
    data_path = f"{BASE_DIR}/{folder}/{partition}/merged_data.csv"
    if not os.path.exists(data_path):
        return None

    merged_df = pd.read_csv(data_path)
    tooltip_fields = ['question', 'answer']
    if 'image' in merged_df:
        tooltip_fields = tooltip_fields + ['image']

    points = alt.Chart(merged_df).mark_point(size=30, filled=True)
    encoded = points.encode(
        alt.OpacityValue(0.5),
        alt.X('x:Q', title="UMAP Component 1"),
        alt.Y('y:Q', title="UMAP Component 2"),
        alt.Color(f'{category}:N'),
        tooltip=tooltip_fields,
    )
    sized = encoded.properties(
        width=800,
        height=800,
        title="UMAP Projected Task Embeddings",
    )
    styled = sized.configure_axis(
        labelFontSize=25,
        titleFontSize=25,
    ).configure_title(
        fontSize=40,
    ).configure_legend(
        labelFontSize=25,
        titleFontSize=25,
    )
    return styled.interactive()



def plot_multi_models(domain, partition, category, cat_options, models, order, pattern, aggregate="mean"):
    """Bar chart of mean model scores, either for a pattern or across a category.

    Args:
        domain: a key of ``domain2folder``.
        partition: sub-folder holding ``expanded_data.csv``.
        category: metadata column used for faceting/grouping (no-pattern mode).
        cat_options: category values to keep (no-pattern mode).
        models: models to include (no-pattern mode).
        order: sort order passed to Altair (no-pattern mode).
        pattern: ``None``, or the string form of
            ``(freq, [(column, value), ...])``; when given, all models are
            plotted on the rows matching every pair.
        aggregate: accepted for interface compatibility; not used here.

    Returns:
        An Altair chart, or ``None`` when the data file is missing.
    """
    import ast  # local import: only needed to parse the pattern string

    folder = domain2folder[domain]
    data_path = f"{BASE_DIR}/{folder}/{partition}/expanded_data.csv"
    if not os.path.exists(data_path):
        return None
    expand_df = pd.read_csv(data_path)
    print(pattern)

    if pattern is not None:
        df = expand_df
        # literal_eval parses the literal without executing arbitrary code;
        # the string comes from a UI dropdown.
        freq, cols = ast.literal_eval(pattern)
        conditions = []
        for col_name, col_val in cols:
            # Pattern values are stored as strings; compare numerically when
            # the value is an integer literal.
            try:
                col_val = int(col_val)
            except (TypeError, ValueError):
                pass
            df = df[df[col_name] == col_val]
            conditions.append(f"{col_name} = {col_val}")
        pattern_str = ", ".join(conditions)
        chart = alt.Chart(df).mark_bar().encode(
            alt.X('model:N',
                sort=alt.EncodingSortField(field='score', order='ascending', op="mean"),
                axis=alt.Axis(labels=False, tickSize=0)),
            alt.Y('mean(score):Q', scale=alt.Scale(zero=True)),
            alt.Color('model:N').legend(),
        ).properties(
            width=200,
            height=100,
            title=f"How do models perform on tasks where {pattern_str} (N={freq})?"
        )
        return chart

    df = expand_df[(expand_df['model'].isin(models)) & (expand_df[category].isin(cat_options))]
    if len(models) > 1:
        # One facet per category value, one bar per model.
        chart = alt.Chart(df).mark_bar().encode(
            alt.X('model:N',
                sort=alt.EncodingSortField(field='score', order=order, op="mean"),
                axis=alt.Axis(labels=False, tickSize=0, title=None)),
            alt.Y('mean(score):Q', scale=alt.Scale(zero=True)),
            alt.Color('model:N').legend(),
            alt.Column(f'{category}:N', header=alt.Header(titleOrient='bottom', labelOrient='bottom'))
        ).properties(
            width=200,
            height=100,
            title=f"How do models perform across {category}?"
        )
    else:
        # Single model: one bar per category value.
        chart = alt.Chart(df).mark_bar().encode(
            alt.X(f'{category}:N', sort=alt.EncodingSortField(field='score', order=order, op="mean")),
            alt.Y('mean(score):Q', scale=alt.Scale(zero=True)),
            alt.Color(f'{category}:N').legend(None),
        ).properties(
            width=200,
            height=100,
            title=f"How does {models[0]} perform across {category}?"
        )
    chart = chart.configure_title(fontSize=15, offset=5, orient='top', anchor='middle')
    return chart


def plot(domain, partition, models, rank, k, threshold, baseline, order, category, cat_options):
    """Bar chart of mean scores by category on the tasks selected by the filters.

    Tasks are selected from ``merged_data.csv`` exactly as in
    ``get_filtered_task_ids`` and then plotted from ``expanded_data.csv``.

    Args:
        domain: a key of ``domain2folder``.
        partition: sub-folder holding both CSV files.
        models: model column names to filter on and plot. Not modified.
        rank: ``"top"`` selects high-scoring tasks; anything else selects low.
        k: number of tasks kept when no baseline is given.
        threshold: per-model score cut-off when no baseline is given.
        baseline: optional baseline model; when truthy, tasks are kept where
            every model scores at least as high as the baseline, and the
            baseline is plotted alongside the models.
        order: sort order passed to Altair.
        category: metadata column to group by.
        cat_options: category values to keep; falsy keeps all.

    Returns:
        An Altair chart, or ``None`` when either data file is missing.
    """
    folder = domain2folder[domain]
    data_path = f"{BASE_DIR}/{folder}/{partition}/merged_data.csv"
    expand_data_path = f"{BASE_DIR}/{folder}/{partition}/expanded_data.csv"
    if not os.path.exists(data_path) or not os.path.exists(expand_data_path):
        return None

    merged_df = pd.read_csv(data_path).rename(
        columns={'llavav1.5-7b': 'llava15-7b', 'llavav1.5-13b': 'llava15-13b'}
    )
    expand_df = pd.read_csv(expand_data_path)

    df = merged_df
    select_top = rank == "top"
    # Keep only tasks on which every requested model passes the filter.
    for model in models:
        if baseline:
            df = df[df[model] >= df[baseline]]
        elif select_top:
            df = df[df[model] >= threshold]
        else:
            df = df[df[model] <= threshold]
    if not baseline:
        # assign() returns a new frame instead of writing onto a filtered slice.
        df = df.assign(**{'mean score': df[models].mean(axis=1)})
        df = df.sort_values(by='mean score', ascending=False)
        df = df.iloc[:k, :] if select_top else df.iloc[-k:, :]

    task_ids = list(df.index)

    # Build a separate list for plotting: the caller's `models` list is left
    # untouched and the baseline is kept out of the title's model list.
    chart_models = list(models) + [baseline] if baseline else list(models)

    chart_df = expand_df[expand_df['model'].isin(chart_models)]
    chart_df = chart_df[chart_df['task id'].isin(task_ids)]

    if cat_options:
        df = chart_df[chart_df[category].isin(cat_options)]
    else:
        df = chart_df

    if baseline:
        model_str = ', '.join(models)
        phrase = f'{model_str} perform' if len(models) > 1 else f'{model_str} performs'
        title = f"Are there any tasks where {phrase} better than {baseline} (by {category})?"
    else:
        title = f"What tasks are models {'best' if select_top else 'worst'} at by {category}?"

    if len(chart_models) > 1:
        # One facet per category value, one bar per model.
        chart = alt.Chart(df).mark_bar().encode(
            alt.X('model:N',
                sort=alt.EncodingSortField(field='score', order=order, op="mean"),
                axis=alt.Axis(labels=False, tickSize=0, title=None)),
            alt.Y('mean(score):Q', scale=alt.Scale(zero=True)),
            alt.Color('model:N').legend(),
            alt.Column(f'{category}:N', header=alt.Header(titleOrient='bottom', labelOrient='bottom'))
        ).properties(
            width=200,
            height=100,
            title=title
        )
    else:
        # Single model: one bar per category value.
        chart = alt.Chart(df).mark_bar().encode(
            alt.X(f'{category}:N', sort=alt.EncodingSortField(field='score', order=order, op="mean")),
            alt.Y('mean(score):Q', scale=alt.Scale(zero=True)),
            alt.Color(f'{category}:N').legend(None),
        ).properties(
            width=200,
            height=100,
            title=f"What tasks is model {chart_models[0]} {'best' if select_top else 'worst'} at by {category}?"
        )
    chart = chart.configure_title(fontSize=15, offset=5, orient='top', anchor='middle')
    return chart
   

def get_frequent_patterns(task_plan, scores):
    """Return the top-10 frequent patterns of ``task_plan`` with their mean scores.

    Thin wrapper around ``find_frequent_patterns`` with ``k=10``. The result
    was previously computed and discarded; it is now returned.
    """
    return find_frequent_patterns(k=10, df=task_plan, scores=scores)

def list_directories(path):
    """Return the names of the entries directly under ``path`` that are directories."""
    subdirs = []
    for entry in os.listdir(path):
        if os.path.isdir(os.path.join(path, entry)):
            subdirs.append(entry)
    return subdirs


def update_category(domain, partition):
    """Build the "task metadata" dropdown from the columns of ``task_plan.pkl``.

    The choices are the pickled object's columns plus ``"task id"``. An empty
    dropdown is returned when the pickle file does not exist.
    """
    folder = domain2folder[domain]
    data_path = f"{BASE_DIR}/{folder}/{partition}/task_plan.pkl"
    if not os.path.exists(data_path):
        return gr.Dropdown([], value=None, label="task metadata")

    # NOTE(review): pickle.load can execute arbitrary code; this assumes the
    # files under BASE_DIR are trusted local artifacts.
    with open(data_path, 'rb') as f:  # context manager closes the handle
        data = pickle.load(f)
    categories = list(data.columns)
    return gr.Dropdown(categories + ["task id"], value=None, label="task metadata", interactive=True)
    
def update_category2(domain, partition, existing_category):
    """Build the optional second metadata dropdown, excluding the first choice.

    The choices are the columns of ``task_plan.pkl`` minus
    ``existing_category`` when that is one of them. An empty dropdown is
    returned when the pickle file does not exist.
    """
    folder = domain2folder[domain]
    data_path = f"{BASE_DIR}/{folder}/{partition}/task_plan.pkl"
    if not os.path.exists(data_path):
        return gr.Dropdown([], value=None, label="task metadata")

    # NOTE(review): pickle.load can execute arbitrary code; this assumes the
    # files under BASE_DIR are trusted local artifacts.
    with open(data_path, 'rb') as f:  # context manager closes the handle
        data = pickle.load(f)
    categories = list(data.columns)
    if existing_category and existing_category in categories:
        categories.remove(existing_category)
    return gr.Dropdown(categories, value=None, label="Optional: second task metadata", interactive=True)

def update_partition(domain):
    """Build the partition dropdown listing the sub-directories of the domain folder.

    The first partition is preselected when there is one. The dropdown is
    empty when the domain folder is missing or has no sub-directories.
    """
    folder = domain2folder[domain]
    path = f"{BASE_DIR}/{folder}"
    label = "task space of the following task generator"
    if not os.path.exists(path):
        return gr.Dropdown([], value=None, label=label)

    partitions = list_directories(path)
    # Guard against an existing but empty domain folder (previously IndexError).
    default_partition = partitions[0] if partitions else None
    return gr.Dropdown(partitions, value=default_partition, label=label)

def update_k(domain, partition, category=None):
    """Build the "k" slider sized to the data in ``merged_data.csv``.

    The upper bound is the number of distinct values of ``category`` when a
    category other than ``"task id"`` is given, otherwise the number of rows.
    The initial value is half the upper bound. A degenerate 1..1 slider is
    returned when the file is missing.
    """
    folder = domain2folder[domain]
    data_path = f"{BASE_DIR}/{folder}/{partition}/merged_data.csv"
    if not os.path.exists(data_path):
        return gr.Slider(1, 1, 1, step=1.0, label="k")

    data = pd.read_csv(data_path)
    if category and category != "task id":
        upper = len(data[category].unique())
    else:
        upper = len(data)
    return gr.Slider(1, upper, upper // 2, step=1.0, label="k")

# def update_category_values(domain, partition, category):
#     data_path = f"{BASE_DIR}/{domain}/{partition}/merged_data.csv"
#     if os.path.exists(data_path) and category is not None:
#         data = pd.read_csv(data_path)
#         uni_cats = list(data[category].unique())
#         return gr.Dropdown(uni_cats, multiselect=True, value=None, interactive=True, label="category values")
#     else:
#         return gr.Dropdown([], multiselect=True, value=None, interactive=False, label="category values")
    
# def update_category_values(domain, partition, models, rank, k, threshold, baseline, category):
#     data_path = f"{BASE_DIR}/{domain}/{partition}/merged_data.csv"

#     if not os.path.exists(data_path):
#         return gr.Dropdown([], multiselect=True, value=None, interactive=False, label="category values")
#     else:
#         merged_df = pd.read_csv(data_path)
#         merged_df.rename(columns={'llavav1.5-7b': 'llava15-7b', 'llavav1.5-13b': 'llava15-13b'}, inplace=True)
        
#         df = merged_df

#         select_top = rank == "top"
#         # Model X is good / bad at
#         for model in models:
#             if baseline:
#                 df = df[df[model] >= df[baseline]]
#             else:
#                 if select_top:
#                     df = df[df[model] >= threshold]
#                 else:
#                     df = df[df[model] <= threshold]
#         if not baseline: 
#             df['mean score'] = df[models].mean(axis=1)
#             df = df.sort_values(by='mean score', ascending=False)
#             df = df.iloc[:k, :] if select_top else df.iloc[-k:, :]
#         uni_cats = list(df[category].unique())
#         return gr.Dropdown(uni_cats, multiselect=True, value=None, interactive=True, label="category values")

    
def update_tasks(domain, partition, find_pattern):
    """Toggle the k / pattern / metadata widgets for pattern-mining mode.

    When ``find_pattern == "yes"`` the k slider and pattern dropdown are
    interactive and the metadata dropdown is empty and disabled. Otherwise
    the first two are disabled and the metadata dropdown lists the columns of
    ``merged_data.csv`` that are neither in ``MODELS`` nor ``question`` /
    ``answer``.

    Returns:
        ``[k_slider, pattern_dropdown, category_dropdown]``.
    """
    folder = domain2folder[domain]
    pattern_mode = find_pattern == "yes"

    k1 = gr.Slider(1, 10000, 10, step=1.0, label="k", interactive=pattern_mode)
    pattern = gr.Dropdown([], value=None, interactive=pattern_mode, label="pattern")

    if pattern_mode:
        category1 = gr.Dropdown([], value=None, interactive=False, label="task metadata")
        return [k1, pattern, category1]

    data_path = f"{BASE_DIR}/{folder}/{partition}/merged_data.csv"
    if not os.path.exists(data_path):
        return [k1, pattern, gr.Dropdown([], value=None, label="task metadata")]

    data = pd.read_csv(data_path)
    excluded = MODELS + ['question', 'answer']
    categories = []
    for column in list(data.columns):
        if column not in excluded:
            categories.append(column)
    category1 = gr.Dropdown(categories, value=categories[0], interactive=True, label="task metadata")
    return [k1, pattern, category1]


def update_pattern(domain, partition, k):
    """Build the pattern dropdown from the first ``k`` entries of ``patterns.pkl``.

    The pickle is indexed as ``results[0]`` to obtain the pattern list; each
    pattern is offered in its ``str()`` form. A disabled, empty dropdown is
    returned when the file does not exist.
    """
    folder = domain2folder[domain]
    data_path = f"{BASE_DIR}/{folder}/{partition}/patterns.pkl"
    if not os.path.exists(data_path):
        return gr.Dropdown([], value=None, interactive=False, label="pattern")

    # NOTE(review): pickle.load can execute arbitrary code; this assumes the
    # files under BASE_DIR are trusted local artifacts.
    with open(data_path, 'rb') as f:  # context manager closes the handle
        results = pickle.load(f)
    patterns = [str(p) for p in results[0]]
    print(patterns)
    return gr.Dropdown(patterns[:k], value=None, interactive=True, label="pattern")

def update_threshold(domain, partition, baseline):
    """Build the rank / k / threshold widgets, disabled when a baseline is set.

    With a baseline all three widgets are non-interactive. Without one, the
    k slider's upper bound is the row count of ``merged_data.csv``, or a
    degenerate 1..1 slider if the file is missing.

    Returns:
        ``[rank_radio, k_slider, threshold_slider]``.
    """
    folder = domain2folder[domain]
    print(baseline)
    editable = not baseline

    if not editable:
        k = gr.Slider(1, 10000, 10, step=1.0, label="k", interactive=False)
    else:
        data_path = f"{BASE_DIR}/{folder}/{partition}/merged_data.csv"
        if os.path.exists(data_path):
            max_k = len(pd.read_csv(data_path))
            print(max_k)
            k = gr.Slider(1, max_k, 10, step=1.0, label="k", interactive=True)
        else:
            k = gr.Slider(1, 1, 1, step=1.0, label="k")

    rank = gr.Radio(['top', 'bottom'], value='top', label="rank", interactive=editable)
    threshold = gr.Slider(0, 1, 0.0, label="threshold", interactive=editable)
    return [rank, k, threshold]

def calc_surprisingness(model, scores, embeddings, k):
    """Compare each task's score with the scores of its k most similar tasks.

    Similarity is the dot product of embedding rows. For every task the k
    highest-similarity tasks are taken, and the score difference to each is
    weighted by that similarity.

    Args:
        model: key into ``scores`` selecting one model's per-task scores.
        scores: mapping/DataFrame whose ``scores[model]`` has ``to_numpy()``.
        embeddings: 2-D array with one row per task.
        k: number of neighbours kept per task.

    Returns:
        Dict with per-task arrays under the keys 'similarity', 'task index',
        'score difference', 'all surprisingness' and 'mean surprisingness'.
    """
    model_scores = scores[model].to_numpy()
    similarity = embeddings @ embeddings.T

    # Column indices of the k largest similarities in every row.
    neighbors = np.argsort(-similarity)[:, :k]
    row_ids = np.arange(len(model_scores))[:, None]
    neighbor_sim = similarity[row_ids, neighbors]

    score_gap = model_scores[:, None] - model_scores[neighbors]
    surprisingness = score_gap * neighbor_sim

    return {
        'similarity': neighbor_sim,
        'task index': neighbors,
        'score difference': score_gap,
        'all surprisingness': surprisingness,
        'mean surprisingness': np.mean(score_gap * neighbor_sim, axis=1),
    }


def plot_surprisingness(domain, partition, model, rank, k, num_neighbors):
    """Chart the tasks a model is most (or least) surprising on, with their neighbours.

    Loads the precomputed ``<model>_surprise.pkl`` (dashes in the model name
    replaced by underscores), which is read through the keys 'task index',
    'score difference', 'all surprisingness' and 'similarity'. One row is
    built per (task, neighbour) pair, and the ``k * num_neighbors`` rows with
    the highest (``rank == "top"``) or lowest mean surprisingness are kept.
    The chart is a bar chart of mean surprisingness per task next to a
    scatter of the neighbours, linked by a point selection on the bars.

    Args:
        domain: a key of ``domain2folder``.
        partition: sub-folder holding the pickle and ``merged_data.csv``.
        model: model name; also a column of ``merged_data.csv``.
        rank: ``"top"`` for the most surprising tasks, otherwise the least.
        k: number of tasks to show.
        num_neighbors: number of neighbours per task taken from the pickle.

    Returns:
        An Altair chart, or ``None`` when either input file is missing.
    """
    folder = domain2folder[domain]
    model_str = model.replace("-", "_")

    sp_pkl = f"{BASE_DIR}/{folder}/{partition}/{model_str}_surprise.pkl"
    merged_path = f"{BASE_DIR}/{folder}/{partition}/merged_data.csv"
    if not (os.path.exists(sp_pkl) and os.path.exists(merged_path)):
        print(sp_pkl, merged_path)
        return None

    # NOTE(review): pickle.load can execute arbitrary code; this assumes the
    # files under BASE_DIR are trusted local artifacts.
    with open(sp_pkl, 'rb') as f:  # context manager closes the handle
        res = pickle.load(f)

    total_num_task = res['task index'].shape[0]
    all_records = []
    for i in range(total_num_task):
        # Mean over the first num_neighbors neighbours only.
        mean_surprisingness = np.mean(res['all surprisingness'][i, :num_neighbors])
        for j in range(num_neighbors):
            all_records.append({
                "task id": i,
                "neighbor rank": j,
                "neighbor id": res['task index'][i, j],
                "score difference": res['score difference'][i, j],
                "surprisingness": res['all surprisingness'][i, j],
                "mean surprisingness": mean_surprisingness,
                "similarity": res['similarity'][i, j],
            })
    sp_df = pd.DataFrame.from_records(all_records)
    sp_df = sp_df.sort_values(by="mean surprisingness", ascending=False)

    num_rows = k * num_neighbors
    selected = sp_df.iloc[:num_rows, :] if rank == "top" else sp_df.iloc[-num_rows:, :]
    # Work on an explicit copy so the column assignments below do not write
    # onto a slice of sp_df.
    df = selected.copy()
    print(len(df))

    # 1 where the neighbour is the task itself, else 0.
    df['is target'] = (df['task id'] == df['neighbor id']).astype(int)

    # Attach every merged_data column for each neighbour, looked up by row
    # position. Vectorised replacement for a per-row apply per column.
    merged_df = pd.read_csv(merged_path)
    neighbor_pos = df['neighbor id'].astype(int).to_numpy()
    for col in merged_df.columns:
        df[col] = merged_df[col].to_numpy()[neighbor_pos]

    tooltips = ['neighbor id'] + ['image', 'question', 'answer', model]

    print(df.head())
    pts = alt.selection_point(encodings=['x'])
    embeds = alt.Chart(df).mark_point(size=30, filled=True).encode(
        alt.OpacityValue(0.5),
        alt.X('x:Q', scale=alt.Scale(zero=False)),
        alt.Y('y:Q', scale=alt.Scale(zero=False)),
        alt.Color(f'{model}:Q'),
        alt.Size("is target:N", legend=None, scale=alt.Scale(domain=[0, 1], range=[300, 500])),
        alt.Shape("is target:N", legend=None, scale=alt.Scale(domain=[0, 1], range=['circle', 'triangle'])),
        alt.Order("is target:N"),
        tooltip=tooltips,
    ).properties(
        width=400,
        height=400,
        title=f"What are the tasks {model} is surprisingly {'good' if rank == 'top' else 'bad'} at compared to {num_neighbors} similar tasks?"
    ).transform_filter(
        pts
    )

    bar = alt.Chart(df).mark_bar().encode(
        alt.Y('mean(mean surprisingness):Q'),
        alt.X('task id:N', sort=alt.EncodingSortField(field='mean surprisingness', order='descending')),
        color=alt.condition(pts, alt.ColorValue("steelblue"), alt.ColorValue("grey")),
    ).add_params(pts).properties(
        width=400,
        height=200,
    )

    chart = alt.hconcat(
        bar,
        embeds
    ).resolve_legend(
        color="independent",
        size="independent"
    ).configure_title(
        fontSize=20
    ).configure_legend(
        labelFontSize=10,
        titleFontSize=10,
    )
    return chart
    


def plot_task_distribution(domain, partition, category):
    """Pie chart of how many tasks in ``task_plan.pkl`` fall under each ``category`` value.

    Args:
        domain: a key of ``domain2folder``.
        partition: sub-folder holding ``task_plan.pkl``.
        category: column of the task plan to count by.

    Returns:
        An Altair arc chart with count and percentage frequency in the tooltip.

    Raises:
        FileNotFoundError: if ``task_plan.pkl`` does not exist.
    """
    folder = domain2folder[domain]
    # NOTE(review): pickle.load can execute arbitrary code; this assumes the
    # files under BASE_DIR are trusted local artifacts.
    with open(f"{BASE_DIR}/{folder}/{partition}/task_plan.pkl", "rb") as f:
        task_plan = pickle.load(f)
    # Materialise the index as an 'index' column so it can be counted per group.
    task_plan.reset_index(inplace=True)
    col_name = category
    task_plan_cnt = task_plan.groupby(col_name)['index'].count().reset_index()
    task_plan_cnt.rename(columns={'index': 'count'}, inplace=True)
    task_plan_cnt['frequency (%)'] = round(task_plan_cnt['count'] / len(task_plan) * 100, 2)

    base = alt.Chart(task_plan_cnt).encode(
        alt.Theta("count:Q").stack(True),
        alt.Color(f"{col_name}:N").legend(),
        tooltip=[col_name, 'count', 'frequency (%)']
    )
    pie = base.mark_arc(outerRadius=120)
    return pie

def plot_all(domain, partition, models, category1, category2, agg):
    """Chart aggregated scores of the chosen models over one or two metadata columns.

    With ``category2`` set, a heatmap of ``agg(score)`` over
    ``category1`` x ``category2`` is returned. Otherwise a bar chart over
    ``category1`` is returned (with ``"task id"`` mapped to the ``index``
    field): faceted per category with one bar per model when several models
    are given, or one bar per category value for a single model.

    Returns ``None`` when ``expanded_data.csv`` is missing.
    """
    folder = domain2folder[domain]
    data_path = f"{BASE_DIR}/{folder}/{partition}/expanded_data.csv"
    if not os.path.exists(data_path):
        return None

    scores_df = pd.read_csv(data_path)
    scores_df = scores_df[scores_df['model'].isin(models)]
    agg_score = f'{agg}(score):Q'

    if category2:
        heatmap = alt.Chart(scores_df).mark_rect().encode(
            alt.X(f'{category1}:N', sort=alt.EncodingSortField(field='score', order='ascending', op=agg)),
            alt.Y(f'{category2}:N', sort=alt.EncodingSortField(field='score', order='descending', op=agg)),
            alt.Color(agg_score),
            alt.Tooltip('score', aggregate=agg, title=f"{agg} score"),
        )
        return heatmap.properties(width=800, height=200)

    category = category1
    if category1 == "task id":
        category = "index"
    by_score = alt.EncodingSortField(field='score', order='ascending', op=agg)
    bars = alt.Chart(scores_df).mark_bar()

    if len(models) > 1:
        # One facet per category value, one bar per model.
        chart = bars.encode(
            alt.X('model:N', sort=by_score, axis=alt.Axis(labels=False, tickSize=0, title=None)),
            alt.Y(agg_score, scale=alt.Scale(zero=True)),
            alt.Color('model:N').legend(),
            alt.Column(f'{category}:N', header=alt.Header(titleOrient='bottom', labelOrient='bottom')),
        ).properties(
            width=200,
            height=100,
            title=f"How do models perform across {category}?",
        )
    else:
        # Single model: one bar per category value.
        chart = bars.encode(
            alt.X(f'{category}:N', sort=by_score),
            alt.Y(agg_score, scale=alt.Scale(zero=True)),
            alt.Color(f'{category}:N').legend(None),
        ).properties(
            width=200,
            height=100,
            title=f"How does {models[0]} perform across {category}?",
        )

    titled = chart.configure_title(fontSize=20, offset=5, orient='top', anchor='middle')
    with_axes = titled.configure_axis(labelFontSize=20, titleFontSize=20)
    return with_axes.configure_legend(labelFontSize=15, titleFontSize=15)

def update_widgets(domain, partition, category, query_type):
    """Build the eleven query widgets, showing only those relevant to ``query_type``.

    Args:
        domain: a key of ``domain2folder``.
        partition: sub-folder holding ``expanded_data.csv``.
        category: metadata column; when set and not ``"task id"``, the k
            slider's upper bound is its number of distinct values, otherwise
            the row count.
        query_type: one of ``"top k"``, ``"threshold"``,
            ``"model comparison"`` or ``"model debugging"``.

    Returns:
        ``[rank, k, direction, threshold, model, model_aggregate, baseline,
        baseline_aggregate, md1, md2, md3]``, or ``[None] * 11`` when the data
        file is missing or ``query_type`` is not recognised.
    """
    folder = domain2folder[domain]
    data_path = f"{BASE_DIR}/{folder}/{partition}/expanded_data.csv"
    if not os.path.exists(data_path):
        print("here?")
        return [None] * 11
    df = pd.read_csv(data_path)
    max_k = len(df[category].unique()) if category and category != "task id" else len(df)

    if query_type == "top k":
        rank = gr.Radio(['top', 'bottom'], value='top', label=" ", interactive=True, visible=True)
        k = gr.Slider(1, max_k, max_k // 2, step=1.0, label="k", interactive=True, visible=True)
        model = gr.Dropdown(MODELS, value=MODELS, label="of model(s)'", multiselect=True, interactive=True, visible=True)
        model_aggregate = gr.Dropdown(['mean', 'median', 'min', 'max'], value="mean", label=" ", interactive=True, visible=True)

        baseline = gr.Dropdown(MODELS, value=None, label="baseline", visible=False)
        direction = gr.Radio(['above', 'below'], value='above', label=" ", visible=False)
        threshold = gr.Slider(0, 1, 0.0, label="threshold", visible=False)
        baseline_aggregate = gr.Radio(['mean', 'median', 'min', 'max'], value="mean", label="baseline aggregate", visible=False)
        md1 = gr.Markdown(r"<h2>ranked by the </h2>")
        md2 = gr.Markdown(r"<h2>accuracy</h2>")
        md3 = gr.Markdown(r"")

    elif query_type == "threshold":
        model = gr.Dropdown(MODELS, value=MODELS[0], label="of model(s)'", multiselect=True, interactive=True, visible=True)
        direction = gr.Radio(['above', 'below'], value='above', label=" ", interactive=True, visible=True)
        threshold = gr.Slider(0, 1, 0.0, label="threshold", interactive=True, visible=True)
        model_aggregate = gr.Dropdown(['mean', 'median', 'min', 'max'], value="mean", label=" ", interactive=True, visible=True)

        rank = gr.Radio(['top', 'bottom'], value='top', label=" ", visible=False)
        k = gr.Slider(1, max_k, max_k // 2, step=1.0, label="k", visible=False)
        baseline = gr.Dropdown(MODELS, value=None, label="baseline", visible=False)
        baseline_aggregate = gr.Radio(['mean', 'median', 'min', 'max'], value="mean", label="baseline aggregate", visible=False)
        md1 = gr.Markdown(r"<h2>where the</h2>")
        md2 = gr.Markdown(r"<h2>accuracy is</h2>")
        md3 = gr.Markdown(r"")

    elif query_type == "model comparison":
        model = gr.Dropdown(MODELS, value=MODELS[0], label="of model(s)' accuracy", multiselect=True, interactive=True, visible=True)
        baseline = gr.Dropdown(MODELS, value=None, label="of baseline(s)' accuracy", multiselect=True, interactive=True, visible=True)
        direction = gr.Radio(['above', 'below'], value='above', label=" ", interactive=True, visible=True)
        threshold = gr.Slider(0, 1, 0.0, label="threshold", interactive=True, visible=True)
        model_aggregate = gr.Dropdown(['mean', 'median', 'min', 'max'], value="mean", label=" ", interactive=True, visible=True)
        baseline_aggregate = gr.Dropdown(['mean', 'median', 'min', 'max'], value="mean", label=" ", interactive=True, visible=True)

        rank = gr.Radio(['top', 'bottom'], value='top', label=" ", visible=False)
        k = gr.Slider(1, max_k, max_k // 2, step=1.0, label="k", visible=False)
        md1 = gr.Markdown(r"<h2>where the difference between the </h2>")
        md2 = gr.Markdown(r"<h2>is </h2>")
        md3 = gr.Markdown(r"<h2>and the</h2>")

    elif query_type == "model debugging":
        model = gr.Dropdown(MODELS, value=MODELS[0], label="model's", multiselect=False, interactive=True, visible=True)

        baseline = gr.Dropdown(MODELS, value=None, label="baseline", visible=False)
        direction = gr.Radio(['above', 'below'], value='above', label=" ", visible=False)
        threshold = gr.Slider(0, 1, 0.0, label="threshold", visible=False)
        rank = gr.Radio(['top', 'bottom'], value='top', label=" ", visible=False)
        k = gr.Slider(1, max_k, max_k // 2, step=1.0, label="k", visible=False)
        model_aggregate = gr.Radio(['mean', 'median', 'min', 'max'], value="mean", label="task category aggregate (over models)", visible=False)
        baseline_aggregate = gr.Radio(['mean', 'median', 'min', 'max'], value="mean", label="baseline aggregate", visible=False)
        md1 = gr.Markdown(r"<h2>where </h2>")
        md2 = gr.Markdown(r"<h2>mean accuracy is below its overall mean accuracy by one standard deviation</h2>")
        md3 = gr.Markdown(r"")
    else:
        # Unknown query type: return the placeholders directly. Falling
        # through to the list below raised UnboundLocalError because none of
        # the widget names were bound.
        return [None] * 11

    return [rank, k, direction, threshold, model, model_aggregate, baseline, baseline_aggregate, md1, md2, md3]
    
def _aggregate_category_scores(df, category, model_names, task_agg, group_agg):
    """Collapse the rows of `df` to one score per `category` value.

    Rows are restricted to `model_names`, aggregated with `task_agg` within
    each (category, model) pair, and then aggregated with `group_agg` across
    the models of each category value.
    """
    subset = df[df['model'].isin(model_names)]
    per_model = subset.groupby([category, 'model'])['score'].agg(task_agg).reset_index()
    return per_model.groupby([category])['score'].agg(group_agg).reset_index()


def select_tasks(domain, partition, category, query_type, task_agg, models, model_agg, rank, k, direction, threshold, baselines, baseline_agg):
    """Answer a query over the scores stored for `domain`/`partition`.

    Args:
        domain: scenario name, translated to a folder via `domain2folder`.
        partition: sub-folder of the scenario to read from.
        category: column to group by; empty or "task id" groups by the
            'index' column (individual tasks).
        query_type: "top k", "threshold", "model comparison" or
            "model debugging".
        task_agg: pandas aggregation applied within each (category, model).
        models: model names (a single name for "model debugging").
        model_agg: aggregation applied across `models`.
        rank: "top" or "bottom" (used by "top k").
        k: number of rows to keep (used by "top k").
        direction: "above" or "below" (used by "threshold").
        threshold: score cut-off ("threshold"), or minimum margin over the
            baselines ("model comparison").
        baselines: baseline model names (used by "model comparison").
        baseline_agg: aggregation applied across `baselines`.

    Returns:
        A `gr.DataFrame` holding either the selected task-plan rows with their
        score column(s), or the selected category values with their scores.
        An empty `gr.DataFrame` is returned when a data file is missing or
        nothing matches.
    """
    domain = domain2folder[domain]
    partition_dir = f"{BASE_DIR}/{domain}/{partition}"
    data_path = f"{partition_dir}/expanded_data.csv"
    merged_path = f"{partition_dir}/merged_data.csv"
    plan_path = f"{partition_dir}/task_plan.pkl"

    if not all(os.path.exists(p) for p in (data_path, merged_path, plan_path)):
        return gr.DataFrame(None)
    df = pd.read_csv(data_path)
    merged_df = pd.read_csv(merged_path)
    # NOTE: pickle executes arbitrary code on load; this is only acceptable
    # because the file is a local artifact under BASE_DIR.
    with open(plan_path, 'rb') as f:
        task_plan = pickle.load(f)
    task_plan = task_plan.reset_index()
    if not category or category == "task id":
        category = 'index'

    if query_type == "top k":
        df = _aggregate_category_scores(df, category, models, task_agg, model_agg)
        df = df.sort_values(by='score', ascending=False)
        k = int(k)  # sliders may deliver floats; iloc needs an integer
        df = df.iloc[-k:, :] if rank == "bottom" else df.iloc[:k, :]
    elif query_type == "threshold":
        df = _aggregate_category_scores(df, category, models, task_agg, model_agg)
        if direction == "below":
            df = df[df['score'] <= threshold]
        else:
            df = df[df['score'] >= threshold]
    elif query_type == "model comparison":
        # Both sides are derived from the same unmodified frame; the filtering
        # inside the helper already yields new objects, so no copy is needed.
        raw_df = df

        df = _aggregate_category_scores(raw_df, category, models, task_agg, model_agg)
        model_str = ', '.join(models)
        exp_score_id = f'{model_agg}({model_str})' if len(models) > 1 else model_str
        df = df.sort_values(by=category)

        df_baseline = _aggregate_category_scores(raw_df, category, baselines, task_agg, baseline_agg)
        model_str = ', '.join(baselines)
        baseline_score_id = f'{baseline_agg}({model_str})' if len(baselines) > 1 else model_str
        df_baseline = df_baseline.sort_values(by=category)

        df = df.rename(columns={'score': exp_score_id})
        df_baseline = df_baseline.rename(columns={'score': baseline_score_id})
        df = pd.merge(df, df_baseline, on=category)
        # Keep only the rows where the models beat the baselines by more than the margin.
        df = df[(df[exp_score_id] - df[baseline_score_id] > threshold)]
    elif query_type == "model debugging":
        model = models
        print(models)
        avg_acc = merged_df[model].mean()
        std = merged_df[model].std()
        # Flag everything scoring more than one standard deviation below the mean.
        t = avg_acc - std
        df = df[df['model'] == model]
        df = df.groupby(['model', category])['score'].agg(task_agg).reset_index()
        df = df[df['score'] < t].copy()  # copy so the column assignments below hit a real frame
        df['mean'] = round(avg_acc, 4)
        df['std'] = round(std, 4)

    print(df.head())
    if category != 'index':
        if len(df) == 0:
            return gr.DataFrame(None, label=f"There is no such {category}.")
        return gr.DataFrame(df, label=f"The total number of such {category} is {len(df)}.")

    task_attrs = list(df[category])
    selected_tasks = task_plan[task_plan[category].isin(task_attrs)].copy()
    if len(selected_tasks) == 0:
        return gr.DataFrame(None, label="There is no such task.")

    if query_type == "model comparison" and (models and baselines):
        score_cols = [exp_score_id, baseline_score_id]
    else:
        score_cols = ['score']
    # One keyed lookup instead of re-scanning `df` for every task row; the
    # first occurrence wins, matching the former `.values[0]` behaviour.
    lookup = df.drop_duplicates(subset='index').set_index('index')
    for col in score_cols:
        selected_tasks[col] = selected_tasks['index'].map(lookup[col])

    print(selected_tasks.head())
    return gr.DataFrame(selected_tasks, label=f"There are {len(selected_tasks)} (out of {len(task_plan)}) tasks in total.")


def find_patterns(selected_tasks, num_patterns, models, baselines, model_agg="mean", baseline_agg="mean"):
    """Mine frequent attribute patterns among the selected tasks.

    Args:
        selected_tasks: DataFrame of selected tasks, optionally carrying score
            columns ('score' or the aggregated model / baseline columns).
        num_patterns: number of patterns requested from the miner.
        models: model name(s) used to recognise the model score column.
        baselines: baseline name(s) used to recognise the baseline score column.
        model_agg: aggregation name that is part of the model score column name.
        baseline_agg: aggregation name that is part of the baseline score
            column name.

    Returns:
        A `gr.DataFrame` with one row per pattern ('pattern', 'count' and,
        when a 'score' column was present, 'score'); an empty `gr.DataFrame`
        when no tasks are selected.
    """
    if len(selected_tasks) == 0:
        return gr.DataFrame(None)
    print(selected_tasks.head())
    scores = selected_tasks['score'] if 'score' in selected_tasks else None
    print(scores)

    # Score columns are results, not task attributes, so they must not take
    # part in the pattern mining.
    score_cols = {'score'}
    for names, agg in ((models, model_agg), (baselines, baseline_agg)):
        if not names:
            continue
        if isinstance(names, str):
            # A single-select dropdown delivers a bare string; joining it
            # directly would split the name into characters.
            names = [names]
        joined = ', '.join(names)
        score_cols.add(f'{agg}({joined})' if len(names) > 1 else joined)

    tasks_only = selected_tasks.drop(
        columns=[col for col in selected_tasks.columns if col in score_cols]
    )
    results = find_frequent_patterns(k=int(num_patterns), df=tasks_only, scores=scores)

    def describe(items):
        # items is a sequence of (column, value) pairs
        return ', '.join(f"{col_name} = {col_val}" for col_name, col_val in items)

    if scores is not None:
        patterns, pattern_scores = results[0], results[1]
        records = [
            {'pattern': describe(items), 'count': count, 'score': score}
            for (count, items), score in zip(patterns, pattern_scores)
        ]
    else:
        records = [
            {'pattern': describe(items), 'count': count}
            for count, items in results
        ]

    df = pd.DataFrame.from_records(records)
    return gr.DataFrame(df)

def visualize_task_distribution(selected_tasks, col_name, model1, model2):
    """Build a pie chart of how the selected tasks distribute over `col_name`.

    Returns None when no column is chosen. `model1` and `model2` are accepted
    but not used.
    """
    if not col_name:
        return None

    # Count tasks per value of the chosen column via the 'index' column.
    counts = (
        selected_tasks.groupby(col_name)['index']
        .count()
        .reset_index()
        .rename(columns={'index': 'count'})
    )
    counts['frequency (%)'] = round(counts['count'] / len(selected_tasks) * 100, 2)
    print(counts.head())

    return (
        alt.Chart(counts)
        .encode(
            alt.Theta("count:Q").stack(True),
            alt.Color(f"{col_name}:N").legend(),
            tooltip=[col_name, 'count', 'frequency (%)'],
        )
        .mark_arc(outerRadius=120)
    )

def plot_performance_for_selected_tasks(domain, partition, df, query_type, models, baselines, select_category, vis_category, task_agg, model_agg, baseline_agg, rank, direction, threshold):
    """Plot model accuracy for the rows produced by a query.

    Args:
        domain: scenario name, translated to a folder via `domain2folder`.
        partition: sub-folder of the scenario to read from.
        df: DataFrame holding the query result.
        query_type: "top k", "threshold", "model comparison" or
            "model debugging".
        models: model names (a single name for "model debugging").
        baselines: baseline names (used by "model comparison").
        select_category: column the query grouped by, or "task id".
        vis_category: column used to group the bars.
        task_agg: ignored; the per-task aggregate is always the mean.
        model_agg: aggregation name used in the model column name and titles.
        baseline_agg: aggregation name used in the baseline column name.
        rank: "top" or "bottom" (title text only).
        direction: "above" or "below" (title text only).
        threshold: threshold value (title text only).

    Returns:
        A configured Altair bar chart, or None when a data file is missing,
        `df` is empty, or no `vis_category` is chosen.
    """
    domain = domain2folder[domain]
    task_agg = "mean"  # the caller-supplied value is deliberately overridden
    data_path = f"{BASE_DIR}/{domain}/{partition}/expanded_data.csv"
    merged_data_path = f"{BASE_DIR}/{domain}/{partition}/merged_data.csv"

    if not os.path.exists(data_path) or not os.path.exists(merged_data_path) or len(df) == 0:
        return None
    if not vis_category:
        # Nothing to group the bars by; previously this surfaced as a KeyError.
        return None

    # True when the query selected individual tasks rather than task categories.
    by_task = select_category == "task id"
    y_val = f'{task_agg}(score):Q' if by_task else 'score:Q'

    if by_task:
        select_category = "index"
    print(df.head())
    if query_type == "model comparison":
        # Reshape to long form: one row per (category value, score column).
        model_str = ', '.join(models)
        exp_score_id = f'{model_agg}({model_str})' if len(models) > 1 else model_str
        baseline_str = ', '.join(baselines)
        baseline_score_id = f'{baseline_agg}({baseline_str})' if len(baselines) > 1 else baseline_str
        print(exp_score_id, baseline_score_id)
        df = df.melt(id_vars=[select_category], value_vars=[exp_score_id, baseline_score_id])
        df = df.rename(columns={'variable': 'model', 'value': 'score'})
        print(df.head())

        if by_task:
            merged_df = pd.read_csv(merged_data_path)
            # Look the metadata up by merged_df's row label in one pass
            # instead of scanning merged_df once per row.
            df[vis_category] = df['index'].map(merged_df[vis_category])

        num_columns = len(df['model'].unique()) * len(df[f'{vis_category}'].unique())
        chart = alt.Chart(df).mark_bar().encode(
            alt.X('model:N',
                sort=alt.EncodingSortField(field='score', order='descending', op=task_agg),
                axis=alt.Axis(labels=False, tickSize=0, title=None)),
            alt.Y(y_val, scale=alt.Scale(zero=True), title="accuracy"),
            alt.Color('model:N').legend(),
            alt.Column(f'{vis_category}:N', header=alt.Header(titleOrient='bottom', labelOrient='bottom', labelFontSize=20, titleFontSize=20,))
        ).properties(
            width=num_columns * 30,
            height=200,
            title=f"How do models perform by {vis_category}?"
        )
        print(num_columns * 50)
    else:
        if query_type == "model debugging":
            y_title = "accuracy"
            plot_title = f"{models} performs worse than its (mean - std) on these {vis_category}s"
            models = [models]
        else:
            model_str = ', '.join(models)
            y_title = f"{model_agg} accuracy" if len(models) > 0 else "accuracy"
            suffix = f"on these tasks (by {vis_category})" if select_category == "index" else f"on these {vis_category}s"
            if query_type == "top k":
                plot_title = f"The {model_agg} accuracy of {model_str} is the {'highest' if rank == 'top' else 'lowest'} " + suffix
            elif query_type == "threshold":
                plot_title = f"The {model_agg} accuracy of {model_str} is {direction} {threshold} " + suffix
            else:
                # Fallback so an unexpected query type does not leave the title unbound.
                plot_title = f"The {model_agg} accuracy of {model_str} " + suffix

        if by_task:
            expand_df = pd.read_csv(data_path)
            task_ids = list(df['index'].unique())
            # NOTE(review): this filters on a 'task id' column while the rest of
            # the file keys tasks by 'index' - confirm the CSV carries both.
            df = expand_df[(expand_df['model'].isin(models)) & (expand_df['task id'].isin(task_ids))]

        num_columns = len(df[f'{vis_category}'].unique())
        chart = alt.Chart(df).mark_bar().encode(
            alt.X(f'{vis_category}:N', sort=alt.EncodingSortField(field='score', order='ascending', op=task_agg), axis=alt.Axis(labelAngle=-45)),
            alt.Y(y_val, scale=alt.Scale(zero=True), title=y_title),
            alt.Color(f'{vis_category}:N').legend(None),
        ).properties(
            width=num_columns * 30,
            height=200,
            title=plot_title
        )

    chart = chart.configure_title(fontSize=20, offset=5, orient='top', anchor='middle').configure_axis(
            labelFontSize=20,
            titleFontSize=20,
        ).configure_legend(
            labelFontSize=20,
            titleFontSize=20,
            labelLimit=200,
        )
    return chart

def sync_vis_category(domain, partition, category):
    """Return updates for the two "by task metadata" dropdowns.

    When a concrete task-metadata `category` is selected, both dropdowns are
    locked to it. Otherwise (no category, or "task id") they offer every column
    of the partition's task plan. Returns `[None, None]` when the task plan
    file does not exist.
    """
    domain = domain2folder[domain]
    if category and category != "task id":
        choices, value, interactive = [category], category, False
    else:
        data_path = f"{BASE_DIR}/{domain}/{partition}/task_plan.pkl"
        if not os.path.exists(data_path):
            return [None, None]
        # NOTE: pickle executes arbitrary code on load; acceptable only because
        # the file is a local artifact under BASE_DIR.
        with open(data_path, 'rb') as f:
            data = pickle.load(f)
        choices = list(data.columns)
        value = choices[0] if choices else None  # tolerate a plan with no columns
        interactive = True

    # Two separate component instances, one per output dropdown.
    return [
        gr.Dropdown(choices, value=value, label="by task metadata", interactive=interactive)
        for _ in range(2)
    ]

def hide_fpm_and_dist_components(domain, partition, category):
    """Show or hide the pattern-mining and task-distribution widgets.

    The widgets are hidden when a concrete task-metadata `category` is
    selected, and shown when the selection is empty or "task id".

    Returns:
        `[num_patterns, btn_pattern, table, col_name, btn_dist, dist_chart]`
        component updates, in that order.
    """
    domain = domain2folder[domain]
    print(category)
    if category and category != "task id":
        num_patterns = gr.Slider(1, 100, 50, step=1.0, label="number of patterns", visible=False)
        btn_pattern = gr.Button(value="Find patterns among tasks", visible=False)
        table = gr.DataFrame({}, height=250, visible=False)
        col_name = gr.Dropdown([], value=None, label="by task metadata", visible=False)
        btn_dist = gr.Button(value="Visualize task distribution", visible=False)
        dist_chart = Plot(visible=False)
    else:
        data_path = f"{BASE_DIR}/{domain}/{partition}/task_plan.pkl"
        categories = []
        if os.path.exists(data_path):
            # NOTE: pickle executes arbitrary code on load; acceptable only
            # because the file is a local artifact under BASE_DIR.
            with open(data_path, 'rb') as f:
                data = pickle.load(f)
            categories = list(data.columns)
        # Falls back to an empty dropdown when the plan is missing or has no columns.
        col_name = gr.Dropdown(
            categories,
            value=categories[0] if categories else None,
            label="by task metadata",
            interactive=True,
            visible=True,
        )

        num_patterns = gr.Slider(1, 100, 50, step=1.0, label="number of patterns", interactive=True, visible=True)
        btn_pattern = gr.Button(value="Find patterns among tasks", interactive=True, visible=True)
        table = gr.DataFrame({}, height=250, interactive=True, visible=True)
        btn_dist = gr.Button(value="Visualize task distribution", interactive=True, visible=True)
        dist_chart = Plot(visible=True)
    return [num_patterns, btn_pattern, table, col_name, btn_dist, dist_chart]



# ---------------------------------------------------------------------------
# UI definition: theme, tabs, widgets and event wiring.
# ---------------------------------------------------------------------------
theme = gr.Theme.from_hub('sudeepshouche/minimalist')
theme.font = [gr.themes.GoogleFont("Inconsolata"), "Arial", "sans-serif"]
theme.text_size = gr.themes.sizes.text_lg

demo = gr.Blocks(theme=theme, title="TaskMeAnything-UI")
with demo:
    # Header row: the title sits in the middle of three equal columns.
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown(
                r""
            )
        with gr.Column(scale=1):
            gr.Markdown(
                r"<h1>Welcome to TaskMeAnything-UI! </h1>"
            )
        with gr.Column(scale=1):
            gr.Markdown(
                r""
            )

    with gr.Tab("📊 Overview"):
        gr.Markdown(
            r"<h2>📊 Visualize the overall task distribution and model performance </h2>"
        )

        with gr.Row():
            domain = gr.Radio(domains, label="scenario", scale=2)
            partition = gr.Dropdown([], value=None, label="task space of the following task generator", scale=1)

        gr.Markdown(
            r"<h2>Overall task metadata distribution</h2>"
        )

        with gr.Row():
            category = gr.Dropdown([], value=None, label="task metadata")
            partition.change(fn=update_category, inputs=[domain, partition], outputs=category)
        with gr.Row():
            output = Plot()
        with gr.Row():
            btn = gr.Button(value="Plot")
            btn.click(plot_task_distribution, [domain, partition, category], output)

        gr.Markdown(
            r"<h2>Models' overall performance by task metadata</h2>"
        )
        with gr.Row():
            with gr.Column(scale=2):
                models = gr.CheckboxGroup(MODELS, label="model(s)", value=MODELS)
            with gr.Column(scale=1):
                # NOTE: this radio is also passed as the task aggregate to the
                # Query tab's callbacks further down.
                aggregate = gr.Radio(['mean', 'median', 'min', 'max'], value="mean", label="aggregate models' accuracy by")
        with gr.Row():
            category1 = gr.Dropdown([], value=None, label="task metadata", interactive=True)
            category2 = gr.Dropdown([], value=None, label="Optional: second task metadata", interactive=True)
            partition.change(fn=update_category, inputs=[domain, partition], outputs=category1)
            category1.change(fn=update_category2, inputs=[domain, partition, category1], outputs=category2)
        domain.change(fn=update_partition_and_models, inputs=domain, outputs=[partition, models])
        with gr.Row():
            output = Plot()
        with gr.Row():
            btn = gr.Button(value="Plot")
            btn.click(plot_all, [domain, partition, models, category1, category2, aggregate], output)

    with gr.Tab("✨ Embedding"):
        gr.Markdown(
            r"<h2>✨ Visualize the tasks' embeddings in the 2D space </h2>"
        )
        with gr.Row():
            domain2 = gr.Radio(domains, label="scenario", scale=2)
            partition2 = gr.Dropdown([], value=None, label="task space of the following task generator", scale=1)
            category2 = gr.Dropdown([], value=None, label="colored by task metadata", scale=1)
            domain2.change(fn=update_partition, inputs=domain2, outputs=partition2)
            partition2.change(fn=update_category, inputs=[domain2, partition2], outputs=category2)

        with gr.Row():
            output2 = Plot()
        with gr.Row():
            btn = gr.Button(value="Run")
            btn.click(plot_embedding, [domain2, partition2, category2], output2)

    with gr.Tab("❓ Query"):
        gr.Markdown(
            r"<h2>❓ Find out the answers to your queries by finding and visualizing the relevant tasks and models' performance </h2>"
        )
        with gr.Row(equal_height=True):
            domain = gr.Radio(domains, label="scenario", scale=2)
            partition = gr.Dropdown([], value=None, label="task space of the following task generator", scale=1)
        with gr.Row():
            query1 = "top k"
            query2 = "threshold"
            query3 = "model debugging"
            query4 = "model comparison"
            query_type = gr.Radio([query1, query2, query3, query4], value="top k", label=r"query type")
        with gr.Row():
            with gr.Accordion("See more details about the query type"):
                gr.Markdown(
                r"<ul><li>Top k: Find the k tasks or task metadata that the model(s) perform the best or worst on</li><li>Threshold: Find the tasks or task metadata where the model(s)' performance is greater or lower than a given threshold t</li><li>Model debugging: Find the tasks or task metadata where a model performs significantly worse than its average performance (by one standard deviation)</li><li>Model comparison: Find the tasks or task metadata where some model(s) perform better or worse than the baseline(s) by a given threshold t</li></ul>"
            )

        # The widgets below read as a sentence ("Help me find the top k ...
        # ranked by the mean of model(s) accuracy") and are swapped by
        # update_widgets whenever the query type changes.
        with gr.Row():
            gr.Markdown(r"<h2>Help me find the</h2>")
        with gr.Row(equal_height=True):
            rank = gr.Radio(['top', 'bottom'], value='top', label=" ", interactive=True, visible=True)
            k = gr.Slider(1, 10, 5 // 2, step=1.0, label="k", interactive=True, visible=True)
            category = gr.Dropdown([], value=None, label="tasks / task metadata", interactive=True)

        with gr.Row():
            md1 = gr.Markdown(r"<h2>ranked by the </h2>")

        with gr.Row(equal_height=True):
            model_aggregate = gr.Dropdown(['mean', 'median', 'min', 'max'], value="mean", label=" ", interactive=True, visible=True, scale=1)
            model = gr.Dropdown(MODELS, value=MODELS, label="of model(s)", multiselect=True, interactive=True, visible=True, scale=2)
        with gr.Row():
            md3 = gr.Markdown(r"")
        with gr.Row(equal_height=True):
            baseline_aggregate = gr.Dropdown(['mean', 'median', 'min', 'max'], value="mean", label=" ", interactive=True, visible=False, scale=1)
            baseline = gr.Dropdown(MODELS, value=None, label="of baseline(s)'", visible=False, scale=2)
        with gr.Row():
            md2 = gr.Markdown(r"<h2>accuracy</h2>")

        with gr.Row():
            direction = gr.Radio(['above', 'below'], value='above', label=" ", visible=False)
            threshold = gr.Slider(0, 1, 0.0, label="threshold", visible=False)

        widgets = [rank, k, direction, threshold, model, model_aggregate, baseline, baseline_aggregate, md1, md2, md3]
        partition.change(fn=update_category, inputs=[domain, partition], outputs=category)
        query_type.change(update_widgets, [domain, partition, category, query_type], widgets)
        domain.change(fn=update_partition_and_models_and_baselines, inputs=domain, outputs=[partition, model, baseline])
        with gr.Row():
            df = gr.DataFrame({}, height=200)
        btn = gr.Button(value="Find tasks / task metadata")
        btn.click(select_tasks, [domain, partition, category, query_type, aggregate, model, model_aggregate, rank, k, direction, threshold, baseline, baseline_aggregate], df)

        with gr.Row():
            plot = Plot()
        with gr.Row():
            col_name2 = gr.Dropdown([], value=None, label="by task metadata", interactive=True)
            partition.change(fn=update_category, inputs=[domain, partition], outputs=col_name2)
            btn_plot = gr.Button(value="Plot model performance", interactive=True)
            btn_plot.click(plot_performance_for_selected_tasks, [domain, partition, df, query_type, model, baseline, category, col_name2, aggregate, model_aggregate, baseline_aggregate, rank, direction, threshold], plot)

        with gr.Row():
            dist_chart = Plot()
        with gr.Row():
            col_name = gr.Dropdown([], value=None, label="by task metadata", interactive=True)
            partition.change(fn=update_category, inputs=[domain, partition], outputs=col_name)
            btn_dist = gr.Button(value="Visualize task distribution", interactive=True)
            btn_dist.click(visualize_task_distribution, [df, col_name, model, baseline], dist_chart)

        with gr.Row():
            table = gr.DataFrame({}, height=250)
        with gr.Row():
            num_patterns = gr.Slider(1, 100, 50, step=1.0, label="number of patterns")
            btn_pattern = gr.Button(value="Find patterns among tasks")
            # find_patterns takes six arguments; the two aggregate selectors
            # must be passed so it can recognise the aggregated score columns.
            btn_pattern.click(find_patterns, [df, num_patterns, model, baseline, model_aggregate, baseline_aggregate], table)

        category.change(fn=hide_fpm_and_dist_components, inputs=[domain, partition, category], outputs=[num_patterns, btn_pattern, table, col_name, btn_dist, dist_chart])
        category.change(fn=sync_vis_category, inputs=[domain, partition, category], outputs=[col_name, col_name2])
        category.change(fn=update_k, inputs=[domain, partition, category], outputs=k)

    with gr.Tab("😮 Surprisingness"):
        gr.Markdown(r"<h2>😮 Find out the tasks a model is surprisingly good or bad at compared to similar tasks</h2>")
        with gr.Row():
            domain3 = gr.Radio(domains, label="scenario", scale=2)
            partition3 = gr.Dropdown([], value=None, label="task space of the following task generator", scale=1)
        with gr.Row():
            model3 = gr.Dropdown(MODELS, value=MODELS[0], label="model", interactive=True, visible=True)
            k3 = gr.Slider(1, 100, 50, step=1.0, label="number of surprising tasks", interactive=True)
            num_neighbors = gr.Slider(1, 100, 50, step=1.0, label="number of neighbors", interactive=True)
            rank3 = gr.Radio(['top', 'bottom'], value='top', label=" ", interactive=True, visible=True)
        domain3.change(fn=update_partition_and_models, inputs=domain3, outputs=[partition3, model3])
        with gr.Row():
            output3 = Plot()
        with gr.Row():
            btn = gr.Button(value="Plot")
            btn.click(plot_surprisingness, [domain3, partition3, model3, rank3, k3, num_neighbors], output3)


# Launched unconditionally at import time (no __main__ guard), with a public share link.
demo.launch(share=True)