import pandas as pd
import gradio as gr
from gradio_rich_textbox import RichTextbox

from demo import VideoMIRModel


def load_v2t_samples(data_root):
    """Collect the sample query GIFs and map each example index to its segment id."""
    sample_videos = []
    # sel_v2t.csv is headerless; its first column holds the selected segment ids.
    df = pd.read_csv("meta/ek100_mir/sel_v2t.csv", header=None)
    idx2sid = {}
    for i, x in enumerate(df[0].values):
        sample_videos.append(f'{data_root}/video/gif/{x}.gif')
        idx2sid[i] = x

    return sample_videos, idx2sid
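
# Illustrative only (hypothetical ids, not the actual CSV contents): if the first
# column of sel_v2t.csv were [2119, 1730], load_v2t_samples('data/ek100_mir') would
# return (['data/ek100_mir/video/gif/2119.gif', 'data/ek100_mir/video/gif/1730.gif'],
# {0: 2119, 1: 1730}).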


def load_t2v_samples(data_root):
    """Return the sample text queries and map each example index to its segment id."""
    # data_root is unused here; the signature mirrors load_v2t_samples.
    sample_text = ['cut the sausage', 'rinse cutting board']
    idx2sid = {0: 2119, 1: 1730, 2: 1276}
    return sample_text, idx2sid


def format_pred(pred, gt):
    """Color a prediction green if it appears in the ground truth, red otherwise."""
    tp = '[color=green]{}[/color]'
    fp = '[color=red]{}[/color]'
    fmt_pred = []
    for x in pred:
        if x in gt:
            fmt_pred.append(tp.format(x))
        else:
            fmt_pred.append(fp.format(x))

    return ', '.join(fmt_pred)
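
# Illustrative only (hypothetical predictions):
#   format_pred(['cut sausage', 'wash pan'], ['cut sausage'])
# returns '[color=green]cut sausage[/color], [color=red]wash pan[/color]',
# BBCode-style tags that RichTextbox renders as colored text.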


def main():
    # Two retrieval models behind the same interface: a zero-shot LaViLa baseline
    # and the Ego-VPA model under comparison.
    lavila = VideoMIRModel("configs/ek100_mir/zeroshot.yml")
    egovpa = VideoMIRModel("configs/ek100_mir/egovpa.yml")
    v2t_samples, idx2sid_v2t = load_v2t_samples('data/ek100_mir')
    t2v_samples, idx2sid_t2v = load_t2v_samples('data/ek100_mir')

    def predict_v2t(idx):
        if idx == 1:
            idx = 2  # quirk of the bundled demo data: the second example is backed by sample index 2
        sid = idx2sid_v2t[idx]
        # Both calls return the same ground truth for the selected sample.
        zeroshot_action, gt_action = lavila.predict_v2t(idx, sid)
        egovpa_action, gt_action = egovpa.predict_v2t(idx, sid)
        zeroshot_action = format_pred(zeroshot_action, gt_action)
        egovpa_action = format_pred(egovpa_action, gt_action)

        return gt_action, zeroshot_action, egovpa_action

    def predict_t2v(idx):
        sid = idx2sid_t2v[idx]
        egovpa_video, _gt_video = egovpa.predict_t2v(idx, sid)  # ground-truth video is not displayed
        # Retrieved segment ids are resolved to the GIFs rendered in the gallery.
        egovpa_video = [f'data/ek100_mir/video/gif/{x}.gif' for x in egovpa_video]

        return egovpa_video

    with gr.Blocks() as demo:
        with gr.Tab("Video-to-text retrieval"):
            gr.Markdown(
                """
                # Ego-VPA Demo
                Choose a sample video and click Predict to view the text retrieved for the selected video
                (<span style="color:green">correct</span>/<span style="color:red">incorrect</span>).
                """
            )

            with gr.Row():
                with gr.Column():
                    video = gr.Image(label="video query", height='300px', interactive=False)
                with gr.Column():
                    idx = gr.Number(label="Idx", visible=False)
                    label = RichTextbox(label="Ground Truth", visible=False)
                    zeroshot = RichTextbox(label="LaViLa (zero-shot) prediction")
                    ours = RichTextbox(label="Ego-VPA prediction")
            btn = gr.Button("Predict", variant="primary")
            btn.click(predict_v2t, inputs=[idx], outputs=[label, zeroshot, ours])
            # Selecting an example populates both the hidden index and the video preview.
            gr.Examples(examples=[[i, x] for i, x in enumerate(v2t_samples)], inputs=[idx, video])

        with gr.Tab("Text-to-video retrieval"):
            gr.Markdown(
                """
                # Ego-VPA Demo
                Choose a sample narration and click Predict to view the videos retrieved for the selected text.
                """
            )

            with gr.Row():
                with gr.Column():
                    text = gr.Text(label="text query")
                with gr.Column():
                    idx = gr.Number(label="Idx", visible=False)
                    ours = gr.Gallery(label="Ego-VPA prediction", columns=[1], rows=[1], object_fit="contain", height="auto")
            btn = gr.Button("Predict", variant="primary")
            btn.click(predict_t2v, inputs=[idx], outputs=[ours])
            gr.Examples(examples=[[i, x] for i, x in enumerate(t2v_samples)], inputs=[idx, text])

    # share=True also serves the app through a temporary public Gradio link.
    demo.launch(share=True)
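
# Run the script directly to start the demo (script name assumed), e.g.:
#   python app.py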

if __name__ == "__main__":
    main()