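# Gradio Space: abstractive summarization with a T5-small model fine-tuned on
# CNN/DailyMail (ndtran/t5-small_cnn-daily-mail). Long inputs are split into
# sentence-aligned chunks of roughly 512 tokens, each chunk is summarized
# separately, and the partial summaries are joined into one output.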
import string
import gradio as gr
import spacy
import torch
from transformers import T5ForConditionalGeneration, AutoTokenizer

# Ensure the small English spaCy pipeline (used for sentence segmentation) is available.
spacy.cli.download("en_core_web_sm")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
configs = {
    'max_input_embedding_length': 512,
    'max_output_embedding_length': 128,
    'task_prefix': "summarize: ",
    'tokenizer': 't5-small',
    'ignore_ids': -100,
    'padding_ids': 0,
    'base_model': 'ndtran/t5-small_cnn-daily-mail',
}
# Load the fine-tuned summarization model and the base t5-small tokenizer.
model = T5ForConditionalGeneration.from_pretrained(configs['base_model'])
model.eval()
model = model.to(device)
tokenizer = AutoTokenizer.from_pretrained(configs['tokenizer'])

# spaCy pipeline used for sentence segmentation.
nlp = spacy.load("en_core_web_sm")
def summarize(text):
    """Summarize a single chunk of text that fits the model's input budget."""
    inputs = tokenizer(
        configs['task_prefix'] + text,
        return_tensors='pt',
        truncation=True,
        max_length=configs['max_input_embedding_length'],
    )
    generated_ids = model.generate(
        inputs.input_ids.to(device),
        do_sample=True,
        max_length=256,
        top_k=1,  # sampling with top_k=1 is effectively greedy decoding
        temperature=0.8,
    )
    # Decode, re-segment into sentences, strip stray leading punctuation,
    # and capitalize the first letter of each sentence.
    doc = nlp(tokenizer.decode(generated_ids.squeeze(), skip_special_tokens=True))
    sents = [str(sent).lstrip(string.punctuation + ' ').rstrip() for sent in doc.sents]
    for i, sent in enumerate(sents):
        if len(sent) > 0:
            sents[i] = sent[0].upper() + sent[1:]
    return " ".join(sents)
def multipart_summarize(text):
    """Split long input into sentence-aligned chunks within the token budget,
    summarize each chunk, and join the partial summaries."""
    limit = configs['max_input_embedding_length']
    buffer, tokens_count = '', 0
    blocks = []
    for sent in nlp(text).sents:
        tokens = tokenizer.tokenize(str(sent))
        if len(tokens) > limit:
            # A single sentence longer than the budget becomes its own block;
            # summarize() truncates it to the model's maximum input length.
            if tokens_count > 0:
                blocks.append(buffer)
                buffer, tokens_count = '', 0
            blocks.append(str(sent))
            continue
        buffer += str(sent) + ' '
        tokens_count += len(tokens)
        if tokens_count > limit:
            # Flush the current chunk once it exceeds the budget.
            blocks.append(buffer)
            buffer, tokens_count = '', 0
    if tokens_count > 0:
        blocks.append(buffer)
    return " ".join(summarize(block) for block in blocks)
# Plain text-in / text-out Gradio UI.
iface = gr.Interface(fn=multipart_summarize, inputs="text", outputs="text")
iface.launch()
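# Note: when running outside Hugging Face Spaces, iface.launch(share=True)
# (a standard Gradio launch option) can expose a temporary public link.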