# winoground_vq2 / app.py
# Uploaded by yonatanbitton, commit 5abf748
from datasets import load_dataset
import pandas as pd
import gradio as gr
import os
import random
def list_to_string(lst):
    """Render *lst* as a numbered list: one ``N. item`` entry per line, counting from 1."""
    return '\n'.join(
        f'{number}. {entry}' for number, entry in enumerate(lst, start=1)
    )
# Read the Hugging Face access token from the environment instead of embedding
# it in the source. The token that used to be hard-coded on this line was
# committed together with the file, so it must be treated as leaked and revoked.
# When HF_TOKEN is unset this is None; presumably `datasets` then falls back to
# a locally cached login -- TODO confirm for the installed version.
auth_token = os.environ.get("HF_TOKEN")
winoground = load_dataset("facebook/winoground", use_auth_token=auth_token)["test"]

# Per-example model outputs. Only dataset rows whose id appears in this CSV
# are kept for browsing.
df = pd.read_csv('BLIP2_Q2_CM.csv')
winoground_pd = winoground.to_pandas()
filter_id_set = set(df['id'])
winoground_pd = winoground_pd[winoground_pd['id'].isin(filter_id_set)]

# Round every score column to one decimal place for display.
for c in ['c0_i0_blip2', 'c0_i1_blip2', 'c1_i0_blip2', 'c1_i1_blip2', 'c0_i0_q2', 'c1_i0_q2', 'c0_i1_q2', 'c1_i1_q2']:
    df[c] = df[c].apply(lambda x: round(x, 1))
print("Load wino")
import json
# qa_samples is stored as JSON text in the CSV; decode it once up front.
df['qa_samples'] = df['qa_samples'].apply(json.loads)
def func(index):
    """Slider callback: return the display values for the example at *index*.

    *index* is converted with ``int`` and used as a positional (``iloc``)
    index into ``winoground_pd``; the result is whatever
    ``get_instance_values`` produces for that row.
    """
    print(f"index: {index}")
    row_position = int(index)
    selected_row = winoground_pd.iloc[row_position]
    return get_instance_values(selected_row)
# Top-level Blocks app; its layout and the slider event are defined in the
# `with demo:` section further down, and it is started by `demo.launch(...)`.
demo = gr.Blocks()
import PIL
def process_key(item):
    """Turn one dataset cell into a displayable value.

    A dict that has a ``'path'`` key is treated as an image record and opened
    with PIL; any other value (e.g. a caption string) is returned unchanged.

    Args:
        item: A cell taken from an example row.

    Returns:
        A ``PIL.Image.Image`` for image records, otherwise *item* itself.
    """
    # Only dicts can be image records. Testing "'path' in item" directly on a
    # caption string is a substring check, so a caption that merely contains
    # the word "path" used to fall through to item['path'] and raise TypeError.
    if not isinstance(item, dict) or 'path' not in item:
        return item

    # Import the submodule explicitly: a bare `import PIL` does not guarantee
    # that `PIL.Image` has been loaded.
    from PIL import Image

    path = item['path']
    if path is None and item.get('bytes') is not None:
        # Record carries the encoded image inline rather than a file path.
        import io
        return Image.open(io.BytesIO(item['bytes']))
    return Image.open(path)
def get_vals(items):
    """Create one Gradio component per entry of *items*.

    The first entry becomes a ``gr.Image``; every later entry becomes a
    ``gr.Textbox``. Each component is initialised with the entry as its value.
    Returns the components as a list, in input order.
    """
    return [
        (gr.Textbox if position else gr.Image)(value=entry)
        for position, entry in enumerate(items)
    ]
def get_instance_values(example):
    """Collect everything shown in the UI for one example row.

    Looks up the row of ``df`` with the same ``id`` as *example* and returns a
    flat list:

    1. the ``*_blip2`` scores as a single string,
    2. the ``*_q2`` scores as a single string,
    3. the QA samples whose ``img_caption`` mentions ``image_0``, as text,
    4. the QA samples whose ``img_caption`` mentions ``image_1``, as text,

    followed by the processed values of every *example* field whose name
    contains ``_0`` and then of every field whose name contains ``_1``.
    """
    scores_row = df[df['id'] == example['id']].iloc[0].to_dict()

    left = [process_key(example[name]) for name in example.keys() if '_0' in name]
    right = [process_key(example[name]) for name in example.keys() if '_1' in name]

    # Split the score columns per model, keyed by the column name with the
    # model suffix stripped. A column may match both filters, hence two ifs.
    blip_preds = {}
    q2_preds = {}
    for column, score in scores_row.items():
        if "_blip" in column:
            blip_preds[column.split("_blip2")[0]] = score
        if "_q2" in column:
            q2_preds[column.split("_q2")[0]] = score
    # The entry that reduces to the key 'model' is not shown with the scores.
    del blip_preds['model']
    del q2_preds['model']

    qa_frame = pd.DataFrame(scores_row['qa_samples'])

    def qa_text_for(image_tag):
        # Keep only the QA rows whose img_caption mentions the given image.
        keep = qa_frame['img_caption'].apply(lambda caption: image_tag in caption)
        return qa_frame[keep].to_string()

    summary = [
        dict2string(blip_preds),
        dict2string(q2_preds),
        qa_text_for('image_0'),
        qa_text_for('image_1'),
    ]
    return summary + left + right
def dict2string(d):
    """Format mapping *d* as ``key: value`` pairs joined by a tab-pipe-tab separator."""
    pairs = (f"{key}: {value}" for key, value in d.items())
    return '\t|\t'.join(pairs)
# Build the page: a slider that selects an example, the two models' score
# summaries, the QA pairs for each image, and the two image/caption columns.
with demo:
    gr.Markdown("# Slide across the slider to see various examples from WinoGround")
    with gr.Column():
        # Show a random example when the page first loads.
        index = random.randrange(len(winoground_pd))
        example = winoground_pd.iloc[index]
        example_values = get_instance_values(example)
        # get_instance_values returns four text summaries followed by the
        # left-hand ("_0") and right-hand ("_1") items.
        first_item, second_item, third_item, fourth_item = example_values[:4]
        rest_values = example_values[4:]

        # The slider value is used as an iloc position in func, so the last
        # valid stop is len - 1 (a maximum of len raised IndexError there).
        # Start the slider on the example that is actually displayed.
        slider = gr.Slider(
            minimum=0,
            maximum=len(winoground_pd) - 1,
            step=1,
            value=index,
        )
        preds_blip2 = gr.Textbox(value=first_item, label='BLIP2')
        preds_q2 = gr.Textbox(value=second_item, label='Visual Q^2')
        with gr.Row():
            with gr.Column():
                qas_image_0 = gr.Textbox(value=third_item, label='QA Pairs (Image 0)')
            with gr.Column():
                qas_image_1 = gr.Textbox(value=fourth_item, label='QA Pairs (Image 1)')
        with gr.Row():
            # First half of the remaining values goes left, second half right.
            half = len(rest_values) // 2
            items_left = rest_values[:half]
            items_right = rest_values[half:]
            with gr.Column():
                gr_values_left = get_vals(items_left)
            with gr.Column():
                gr_values_right = get_vals(items_right)
    # The output order must match the list returned by func.
    slider.change(
        func,
        inputs=[slider],
        outputs=[preds_blip2, preds_q2, qas_image_0, qas_image_1] + gr_values_left + gr_values_right,
    )
# Login credentials come from the environment rather than the source file.
# The password that used to be hard-coded here was committed with the file
# and should be rotated. DEMO_USERNAME defaults to "admin"; DEMO_PASSWORD
# has no default so the app never starts with a publicly known password.
_demo_username = os.environ.get("DEMO_USERNAME", "admin")
_demo_password = os.environ.get("DEMO_PASSWORD")
if not _demo_password:
    raise RuntimeError(
        "Set the DEMO_PASSWORD environment variable before launching the demo."
    )
demo.launch(auth=(_demo_username, _demo_password))