File size: 4,161 Bytes
5f70455
 
 
 
 
 
 
 
 
 
1114f49
 
 
 
 
5f70455
 
1114f49
5f70455
 
1114f49
5f70455
 
1114f49
5f70455
 
 
 
1114f49
 
 
 
 
 
 
 
 
 
 
 
 
 
5f70455
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1114f49
5f70455
 
 
 
1114f49
 
5f70455
1114f49
 
 
5f70455
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1114f49
5f70455
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import matplotlib.pyplot as plt
import torch
from flamingo_mini_task.utils import load_url
from flamingo_mini_task import FlamingoModel, FlamingoProcessor
from datasets import load_dataset,concatenate_datasets
from PIL import Image

EXAMPLES_DIR = 'examples'
DEFAULT_PROMPT = "<image>"
MINI_MODEL = "flamingo-mini-bilbaocaptions-scienceQA[QA]"
TINY_MODEL = "flamingo-tiny-scienceQA[COT+QA]"
MEGATINY_MODEL = "flamingo-megatiny-opt-scienceQA[QA]"

# Display name (shown in the model dropdown) -> loaded checkpoint.
# Each display name must describe the checkpoint it loads: the "mini" entry is
# the Bilbao-captions checkpoint, the "tiny" entry is the ScienceQA COT+QA one.
# NOTE: every checkpoint is loaded eagerly when this module is imported.
flamingo_megatiny_captioning_models = {
    MINI_MODEL: {
        'model': FlamingoModel.from_pretrained('TheMrguiller/Flamingo-mini-Bilbao_Captions-task_BilbaoQA-ScienceQA'),
    },
    TINY_MODEL: {
        'model': FlamingoModel.from_pretrained('TheMrguiller/Flamingo-tiny_ScienceQA_COT-QA'),
    },
    MEGATINY_MODEL: {
        'model': FlamingoModel.from_pretrained('landersanmi/flamingo-megatiny-opt-QA'),
    },
}

# Example rows for the Gradio interface, one value per input component:
# [image path, question, option A, option B, option C, option D, COT flag, model name]
path = EXAMPLES_DIR + "/{}"
cot = False

examples = [
    [path.format("koala.png"), "What animal is this?", "Koala", "Elephant", "Cat", "Mouse", cot, MEGATINY_MODEL],
    # Bilbao-specific questions run on the Bilbao-captions checkpoint.
    [path.format("townhall.jpg"), "What building is this?", "Guggenheim museum", "San mames stadium", "Alhondiga", "Bilbao townhall", cot, MINI_MODEL],
    [path.format("muniain.jpeg"), "What team is IKer Muniain associated?", "Real Madrid", "Manchester United", "Athletic Bilbao", "Rayo Vallecano", cot, MINI_MODEL],
    [path.format("lasalve.jpeg"), "What is the name of this bridge?", "La Salve", "Zubizuri", "La Ribera", "San Anton", cot, MINI_MODEL],
    [path.format("athl.jpeg"), "Football fans hold flags with what team colors?", "Athletic", "Besiktas", "Udinese", "Real Madrid", cot, MINI_MODEL],
]


# Per-model FlamingoProcessor cache. The processor is built only from
# model.config, so construct it once per model instead of on every request.
_processor_cache = {}


def generate_text(image, question, option_a, option_b, option_c, option_d, cot_checkbox, model_name):
    """Answer a four-option multiple-choice question about an image.

    Args:
        image: Value of the Gradio image input; passed unchanged to
            ``model.generate_captions(images=...)``.
        question: Question text inserted after ``[QUESTION]``.
        option_a, option_b, option_c, option_d: Answer options inserted as
            ``(A)`` .. ``(D)`` after ``[OPTIONS]``.
        cot_checkbox: When truthy the prompt starts with ``[COT]``,
            otherwise with ``[QA]``.
        model_name: Key into ``flamingo_megatiny_captioning_models``.

    Returns:
        The text between the first ``[ANSWER]`` marker in the model output
        and the next one (or the end). If the output contains no
        ``[ANSWER]`` marker, the whole output is returned.
    """
    model = flamingo_megatiny_captioning_models[model_name]['model']
    processor = _processor_cache.get(model_name)
    if processor is None:
        processor = FlamingoProcessor(model.config)
        _processor_cache[model_name] = processor

    task_tag = "[COT]" if cot_checkbox else "[QA]"
    prompt = task_tag + "[CONTEXT]<image>[QUESTION]{} [OPTIONS] (A) {} (B) {} (C) {} (D) {} [ANSWER]".format(
        question, option_a, option_b, option_c, option_d)

    print(prompt)
    prediction = model.generate_captions(images=image,
                                         processor=processor,
                                         prompt=prompt,
                                         )

    output = prediction[0]
    segments = output.split('[ANSWER]')
    # Without the marker split() yields a single segment; indexing [1] would
    # raise IndexError, so fall back to the raw output instead.
    return segments[1] if len(segments) > 1 else output




# Input components, in the same order as generate_text's parameters.
image_input = gr.Image(path.format("giraffe.jpeg"))
question_input = gr.inputs.Textbox(default="What animal is this?")
opt_a_input = gr.inputs.Textbox(default="Dog")
opt_b_input = gr.inputs.Textbox(default="Giraffe")
opt_c_input = gr.inputs.Textbox(default="Elephant")
opt_d_input = gr.inputs.Textbox(default="Crocodile")
cot_checkbox = gr.inputs.Checkbox(label="Generate COT")
# Preselect a model: with no default, an untouched dropdown submits None,
# which is not a key of flamingo_megatiny_captioning_models and makes
# generate_text fail with KeyError.
select_model = gr.inputs.Dropdown(choices=list(flamingo_megatiny_captioning_models.keys()),
                                  default=MEGATINY_MODEL)

text_output = gr.outputs.Textbox()

# Create the Gradio interface
demo = gr.Interface(
    fn=generate_text,
    inputs=[image_input,
            question_input,
            opt_a_input,
            opt_b_input,
            opt_c_input,
            opt_d_input,
            cot_checkbox,
            select_model,
            ],
    examples=examples,
    outputs=text_output,
    title='Generate answers from MCQ',
    description='Generate answers from Multiple Choice Questions or generate a Chain of Thought about the question and the options given',
    theme='default',
)

if __name__ == "__main__":
    demo.launch()