File size: 1,334 Bytes
ea8662a
bf09ad3
dcdcf39
bf09ad3
dcdcf39
e65c2a1
bf09ad3
dcdcf39
3252fb2
dcdcf39
3252fb2
dcdcf39
3252fb2
dcdcf39
ac182d6
 
 
 
 
 
3252fb2
 
 
dcdcf39
3252fb2
dcdcf39
 
ac182d6
3252fb2
ac182d6
 
e65c2a1
 
ac182d6
e65c2a1
cec1ada
bf09ad3
3252fb2
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import gradio as gr
from transformers import pipeline, AutoModelForSeq2SeqLM, AutoTokenizer

model_checkpoint = "t5_history_qa"

# Tokenizer and seq2seq model are loaded from the same checkpoint.
# local_files_only=True stops from_pretrained from reaching out to the
# Hugging Face Hub, so the checkpoint must already be available locally.
tokenizer = AutoTokenizer.from_pretrained(
    model_checkpoint,
    local_files_only=True,
    use_fast=True,
)
model = AutoModelForSeq2SeqLM.from_pretrained(
    model_checkpoint,
    local_files_only=True,
)


def gen_answer(question, context, *, max_input_length=512, max_answer_length=32):
    """Generate an answer to *question* from *context* with the T5 model.

    The two strings are combined into a ``"question: ... context: ..."``
    prompt, encoded with the module-level ``tokenizer`` and passed to the
    module-level ``model`` for generation.

    Args:
        question: The user's question.
        context: The passage the answer should be drawn from.
        max_input_length: Maximum number of prompt tokens; longer prompts
            are truncated. Defaults to 512.
        max_answer_length: ``max_length`` passed to ``model.generate``.
            Defaults to 32.

    Returns:
        The decoded answer text with special tokens removed.
    """
    # Keep the prompt in its own name rather than overwriting ``context``.
    prompt = f"question: {question} context: {context}"

    # A single prompt needs no padding. Padding to max_length only forced
    # the encoder to process up to max_input_length masked positions.
    encoded_input = tokenizer(
        [prompt],
        max_length=max_input_length,
        truncation=True,
        return_tensors="pt",
    ).to(model.device)  # keep inputs on the same device as the model

    output_ids = model.generate(
        input_ids=encoded_input.input_ids,
        attention_mask=encoded_input.attention_mask,
        max_length=max_answer_length,
    )
    # Batch of one: decode the first (only) generated sequence.
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)


with gr.Blocks() as demo:
    # Creation order is kept as-is: it determines the page layout.
    context = gr.Textbox(
        lines=5,
        label="Context",
        placeholder="Please provide history related context",
    )
    question = gr.Textbox(
        lines=1,
        label="Question",
        placeholder="Please ask related question",
    )
    answer_btn = gr.Button("Generate answer")
    output = gr.Textbox(lines=1, label="Answer")

    # gen_answer expects (question, context), which is the reverse of the
    # on-screen order of the two textboxes.
    answer_btn.click(
        gen_answer,
        [question, context],
        output,
        api_name="gen_answer",
    )

if __name__ == "__main__":
    demo.launch(debug=True)