import gradio as gr
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Load the fine-tuned T5 question-answering checkpoint bundled with the Space.
model_checkpoint = "t5_history_qa"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, use_fast=True, local_files_only=True)
model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint, local_files_only=True)

def gen_answer(question, context):
    # T5 expects a single prompt string in "question: ... context: ..." format.
    input_text = f"question: {question} context: {context}"
    encoded_input = tokenizer(
        [input_text],
        max_length=512,
        return_tensors="pt",
        truncation=True,
        padding="max_length",
    )
    # Generate a short answer (up to 32 tokens) and decode it back to text.
    output = model.generate(
        input_ids=encoded_input.input_ids,
        attention_mask=encoded_input.attention_mask,
        max_length=32,
    )
    return tokenizer.decode(output[0], skip_special_tokens=True)

with gr.Blocks() as demo:
    context = gr.Textbox(label="Context", placeholder="Please provide history-related context", lines=5)
    question = gr.Textbox(label="Question", placeholder="Please ask a related question", lines=1)
    answer_btn = gr.Button("Generate answer")
    output = gr.Textbox(label="Answer", lines=1)
    answer_btn.click(fn=gen_answer, inputs=[question, context], outputs=output, api_name="gen_answer")

if __name__ == "__main__":
    demo.launch(debug=True)
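Because the click handler registers api_name="gen_answer", the endpoint can also be called programmatically once the Space is running. Below is a minimal sketch using the gradio_client library; the Space ID "your-username/t5-history-qa" and the example question/context strings are hypothetical placeholders, not part of the original app.

from gradio_client import Client

# Hypothetical Space ID; replace with the actual owner/space-name once deployed.
client = Client("your-username/t5-history-qa")

# Positional inputs follow the order declared in answer_btn.click: question, then context.
answer = client.predict(
    "Who crossed the Rubicon in 49 BC?",  # question (illustrative)
    "In 49 BC, Julius Caesar led his army across the Rubicon, defying the Roman Senate.",  # context (illustrative)
    api_name="/gen_answer",
)
print(answer)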