# Gradio demo: history question answering with a fine-tuned T5 model.
import gradio as gr
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
# Load the fine-tuned T5 checkpoint bundled with this Space (local files only,
# so nothing is fetched from the Hub at startup).
model_checkpoint = "t5_history_qa"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, use_fast=True, local_files_only=True)
model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint, local_files_only=True)
def gen_answer(question, context):
    # T5 QA models expect a single "question: ... context: ..." prompt.
    input_text = f"question: {question} context: {context}"
    encoded_input = tokenizer(
        [input_text],
        max_length=512,
        return_tensors="pt",
        truncation=True,
        padding="max_length",
    )
    # Generate a short answer (up to 32 tokens) and decode it back to text.
    output = model.generate(
        input_ids=encoded_input.input_ids,
        attention_mask=encoded_input.attention_mask,
        max_length=32,
    )
    return tokenizer.decode(output[0], skip_special_tokens=True)
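
# For example (illustrative; actual output depends on the fine-tuned checkpoint):
#   gen_answer("Who was the first Roman emperor?",
#              "Augustus became the first Roman emperor in 27 BC.")
#   # -> a short span such as "Augustus"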
# Minimal Blocks UI: context and question in, generated answer out.
with gr.Blocks() as demo:
    context = gr.Textbox(label="Context", placeholder="Please provide history-related context", lines=5)
    question = gr.Textbox(label="Question", placeholder="Please ask a related question", lines=1)
    answer_btn = gr.Button("Generate answer")
    output = gr.Textbox(label="Answer", lines=1)
    answer_btn.click(fn=gen_answer, inputs=[question, context], outputs=output, api_name="gen_answer")
if __name__ == "__main__":
    demo.launch(debug=True)
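
# The click handler is exposed as an API endpoint via api_name="gen_answer".
# Sketch of a client call (assumes the separate `gradio_client` package and
# that the Space is reachable at the URL below):
#
#   from gradio_client import Client
#   client = Client("http://127.0.0.1:7860/")
#   answer = client.predict(
#       "Who was the first Roman emperor?",
#       "Augustus became the first Roman emperor in 27 BC.",
#       api_name="/gen_answer",
#   )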