import gradio as gr
from transformers import AutoTokenizer, TFAutoModelForSeq2SeqLM

# Load the fine-tuned T5 summarization checkpoint (TensorFlow weights) from the Hub.
tokenizer = AutoTokenizer.from_pretrained("MahmoudH/t5-v1_1-base-finetuned-sci_summ")
model = TFAutoModelForSeq2SeqLM.from_pretrained("MahmoudH/t5-v1_1-base-finetuned-sci_summ")
def predict(text):
    # T5 checkpoints expect a task prefix on the input.
    text = "summarize: " + text
    # return_tensors="tf" is required so generate() receives TF tensors rather
    # than Python lists; truncation guards against over-long inputs.
    tokenized_inputs = tokenizer([text], return_tensors="tf", truncation=True)
    output = model.generate(
        input_ids=tokenized_inputs["input_ids"],
        attention_mask=tokenized_inputs["attention_mask"],
        max_new_tokens=256,
        length_penalty=0.5,  # < 1.0 biases beam search toward shorter summaries
        num_beams=4,
        do_sample=True,  # with num_beams > 1 this is beam-search multinomial sampling
    )
    summary = tokenizer.batch_decode(output, skip_special_tokens=True)[0]
    return summary
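
# A quick local sanity check (a sketch, not part of the original app; the
# sample abstract is illustrative). Uncomment to test without the UI:
#   print(predict("We present a method for summarizing scientific papers ..."))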

# Minimal Gradio UI: one text box in, one summary box out.
input_box = gr.Textbox(label="Input")
output_box = gr.Textbox(label="Summary")
gr.Interface(fn=predict, inputs=input_box, outputs=output_box).launch()