# app.py — ndtran/t5-small_cnn-daily-mail Space (commit f692320)
import gradio as gr
from transformers import T5ForConditionalGeneration, AutoTokenizer
import string, spacy, torch
# Fetch the small English spaCy pipeline at startup so spacy.load() below
# works on a fresh container (e.g. a Hugging Face Space with no cache).
spacy.cli.download("en_core_web_sm")
# Prefer GPU when available; the model is moved to this device after loading.
device = torch.device(
"cuda" if torch.cuda.is_available() else "cpu"
)
# Pipeline configuration shared by the functions below.
# NOTE(review): 'max_output_embedding_length', 'ignore_ids' and 'padding_ids'
# are not used by the visible code — presumably leftovers from training.
configs = {
'max_input_embedding_length': 512,
'max_output_embedding_length': 128,
'task_prefix': "summarize: ",
'tokenizer': 't5-small',
'ignore_ids': -100,
'padding_ids': 0,
'base_model': 'ndtran/t5-small_cnn-daily-mail'
}
# Fine-tuned T5 summarizer; eval() disables dropout for inference.
model = T5ForConditionalGeneration.from_pretrained(configs['base_model'])
model.eval()
# The fine-tuned checkpoint reuses the stock t5-small tokenizer.
tokenizer = AutoTokenizer.from_pretrained(configs['tokenizer'])
model = model.to(device)
# spaCy pipeline used only for sentence segmentation.
nlp = spacy.load("en_core_web_sm")
def summarize(text):
    """Summarize *text* with the fine-tuned T5 model.

    The generated summary is sentence-split with spaCy; each sentence is
    stripped of leading punctuation/spaces and trailing whitespace, then
    capitalized, and the sentences are re-joined with single spaces.
    """
    global model, tokenizer, device, configs, nlp
    prompt = configs['task_prefix'] + text
    input_ids = tokenizer(prompt, return_tensors='pt').input_ids
    generated_ids = model.generate(
        input_ids.to(device),
        do_sample=True,  # NOTE: with top_k=1 sampling degenerates to greedy decoding
        max_length=256,
        top_k=1,
        temperature=0.8,
    )
    decoded = tokenizer.decode(generated_ids.squeeze(), skip_special_tokens=True)
    cleaned = []
    for sent in nlp(decoded).sents:
        piece = str(sent).lstrip(string.punctuation + ' ').rstrip()
        if len(piece) > 0:
            piece = piece[0].upper() + piece[1:]
        cleaned.append(piece)
    return " ".join(cleaned)
def multipart_summarize(text):
    """Summarize arbitrarily long *text* by splitting it into blocks that fit
    the model's input window, summarizing each block with summarize(), and
    joining the partial summaries with spaces.

    Whole sentences (via spaCy) are accumulated into a buffer; the buffer is
    flushed to a block once its token count exceeds the limit.  A single
    sentence longer than the limit becomes its own block.
    """
    global model, tokenizer, device, configs, nlp
    # Use the configured input budget instead of a hard-coded 512 so the
    # splitting stays consistent with the model configuration above.
    limit = configs['max_input_embedding_length']
    buffer, tokens_count = '', 0
    blocks = []
    for sent in nlp(text).sents:
        sent_text = str(sent)
        n_tokens = len(tokenizer.tokenize(sent_text))
        if n_tokens > limit:
            # Oversized single sentence: flush any buffered text, then emit
            # the sentence as its own block.  BUGFIX: the original fell
            # through here and also appended the sentence to the buffer,
            # which was immediately flushed — summarizing it twice.
            if tokens_count > 0:
                blocks.append(buffer)
                buffer, tokens_count = '', 0
            blocks.append(sent_text)
            continue
        buffer += sent_text
        tokens_count += n_tokens
        # Flush after exceeding the limit (original behavior: a block may
        # slightly overshoot the budget; the tokenizer truncates downstream).
        if tokens_count > limit:
            blocks.append(buffer)
            buffer, tokens_count = '', 0
    if tokens_count > 0:
        blocks.append(buffer)
    return " ".join(summarize(e) for e in blocks)
# Expose multipart_summarize as a minimal text-in/text-out Gradio web UI.
iface = gr.Interface(fn = multipart_summarize, inputs = "text", outputs = "text")
iface.launch()