Spaces:
Runtime error
Runtime error
import gradio as gr | |
from transformers import T5ForConditionalGeneration, AutoTokenizer | |
import string, spacy, torch | |
# Download the small English spaCy pipeline at startup so that the
# spacy.load() call further down can find it.
spacy.cli.download("en_core_web_sm")

# Use the GPU when torch reports one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Settings for the summarization model and its tokenizer.
# NOTE(review): only 'task_prefix', 'tokenizer' and 'base_model' are read by
# the code visible in this file; the remaining keys look like leftovers from
# training -- confirm before removing them.
configs = {
    'max_input_embedding_length': 512,
    'max_output_embedding_length': 128,
    'task_prefix': "summarize: ",
    'tokenizer': 't5-small',
    'ignore_ids': -100,
    'padding_ids': 0,
    'base_model': 'ndtran/t5-small_cnn-daily-mail',
}

# Tokenizer and sentence splitter shared by every request.
tokenizer = AutoTokenizer.from_pretrained(configs['tokenizer'])
nlp = spacy.load("en_core_web_sm")

# Load the fine-tuned checkpoint, move it to the chosen device and switch it
# to evaluation mode for inference.
model = T5ForConditionalGeneration.from_pretrained(configs['base_model']).to(device)
model.eval()
def summarize(text):
    """Generate a summary for a single piece of text.

    The task prefix from ``configs`` is prepended to ``text``, the result is
    tokenized and passed to ``model.generate``, and the decoded output is
    split into sentences with spaCy.  Each sentence has leading punctuation
    and spaces and trailing whitespace stripped, gets its first character
    upper-cased, and the sentences are joined with single spaces.

    Args:
        text: The text to summarize.

    Returns:
        The cleaned-up summary as one string.
    """
    prompt = configs['task_prefix'] + text
    encoded = tokenizer(prompt, return_tensors='pt')

    output_ids = model.generate(
        encoded.input_ids.to(device),
        do_sample=True,
        max_length=256,
        top_k=1,
        temperature=0.8,
    )

    # generate() returns a batch; squeeze() drops the batch axis before decoding.
    decoded = tokenizer.decode(output_ids.squeeze(), skip_special_tokens=True)

    leading_junk = string.punctuation + ' '
    cleaned = []
    for span in nlp(decoded).sents:
        sentence = str(span).lstrip(leading_junk).rstrip()
        if sentence:
            sentence = sentence[0].upper() + sentence[1:]
        # Sentences that end up empty are kept, as before, so the joined
        # output is unchanged.
        cleaned.append(sentence)

    return " ".join(cleaned)
def multipart_summarize(text):
    """Summarize text of any length by summarizing it block by block.

    The text is split into sentences with spaCy, and consecutive sentences
    are packed into blocks whose tokenizer token count stays within
    ``configs['max_input_embedding_length']``.  Each block is passed to
    ``summarize`` and the partial summaries are joined with single spaces.

    A single sentence that is longer than the limit on its own cannot be
    packed with anything else, so it becomes a block by itself (exactly
    once).

    Args:
        text: The text to summarize.

    Returns:
        The concatenated block summaries; an empty string when ``text``
        contains no sentences.
    """
    max_tokens = configs['max_input_embedding_length']

    blocks = []
    pending, pending_tokens = [], 0

    for sent in nlp(text).sents:
        sentence = str(sent)
        if not sentence.strip():
            # Whitespace-only spans add nothing to summarize.
            continue

        n_tokens = len(tokenizer.tokenize(sentence))

        # Close the current block *before* it would grow past the limit, so
        # finished blocks respect the limit instead of overshooting it.
        if pending and pending_tokens + n_tokens > max_tokens:
            blocks.append(" ".join(pending))
            pending, pending_tokens = [], 0

        # An over-long sentence lands here alone and is flushed by the next
        # sentence (or after the loop), so it is never emitted twice.
        pending.append(sentence)
        pending_tokens += n_tokens

    if pending:
        blocks.append(" ".join(pending))

    # Sentence text carries no trailing whitespace, hence the explicit " "
    # separators above; without them adjacent sentences would be glued.
    return " ".join(summarize(block) for block in blocks)
# Expose multipart_summarize through a text-in / text-out Gradio interface
# and start serving it.
iface = gr.Interface(
    fn=multipart_summarize,
    inputs="text",
    outputs="text",
)
iface.launch()