# Gradio demo: abstractive summarization with a T5-small model fine-tuned on CNN/DailyMail.
import string

import gradio as gr
import spacy
import torch
from transformers import T5ForConditionalGeneration, AutoTokenizer

# Download the spaCy English model used for sentence segmentation.
spacy.cli.download("en_core_web_sm")

# Use the GPU when available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Shared tokenizer / model settings; max_input_embedding_length is the per-block token budget.
configs = {
    'max_input_embedding_length': 512,
    'max_output_embedding_length': 128,
    'task_prefix': "summarize: ",
    'tokenizer': 't5-small',
    'ignore_ids': -100,
    'padding_ids': 0,
    'base_model': 'ndtran/t5-small_cnn-daily-mail'
}

# Load the fine-tuned checkpoint and its tokenizer, then move the model to the target device.
model = T5ForConditionalGeneration.from_pretrained(configs['base_model'])
model.eval()
tokenizer = AutoTokenizer.from_pretrained(configs['tokenizer'])

model = model.to(device)
nlp = spacy.load("en_core_web_sm")

def summarize(text):
    """Summarize one chunk of text that already fits the model's input budget."""
    input_ids = tokenizer(configs['task_prefix'] + text, return_tensors = 'pt').input_ids

    generated_ids = model.generate(
        input_ids.to(device),
        do_sample = True,
        max_length = 256,
        top_k = 1,          # with top_k = 1, sampling is effectively greedy decoding
        temperature = 0.8
    )

    # Split the generated summary back into sentences, strip stray leading
    # punctuation, and capitalize the first letter of each sentence.
    doc = nlp(tokenizer.decode(generated_ids.squeeze(), skip_special_tokens = True))
    sents = [str(sent).lstrip(string.punctuation + ' ').rstrip() for sent in doc.sents]

    for i, sent in enumerate(sents):
        if len(sent) > 0:
            sents[i] = sent[0].upper() + sent[1:]

    return " ".join(sents)

def multipart_summarize(text):
    """Split long input into sentence-aligned blocks of roughly
    configs['max_input_embedding_length'] tokens and summarize each block."""
    buffer, tokens_count = '', 0
    nlp_text = nlp(text)

    blocks = []
    max_tokens = configs['max_input_embedding_length']

    for sent in nlp_text.sents:
        tokens = tokenizer.tokenize(str(sent))

        # A single sentence longer than the budget becomes its own block.
        if len(tokens) > max_tokens:
            if tokens_count > 0:
                blocks.append(buffer)
                buffer, tokens_count = '', 0

            blocks.append(str(sent))
            continue

        buffer += (' ' if buffer else '') + str(sent)
        tokens_count += len(tokens)

        # Flush the buffer once it goes over the token budget.
        if tokens_count > max_tokens:
            blocks.append(buffer)
            buffer, tokens_count = '', 0

    if tokens_count > 0:
        blocks.append(buffer)

    return " ".join(summarize(e) for e in blocks)

# Expose the summarizer as a plain text-in / text-out Gradio app.
iface = gr.Interface(fn = multipart_summarize, inputs = "text", outputs = "text")
iface.launch()
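
# Optional local check (illustrative only, not part of the Space's runtime path): the sample
# text below is a made-up placeholder. Calling multipart_summarize directly exercises the
# chunking and summarization without the web UI; run it before iface.launch(), since launch()
# blocks the process.
#
#     sample = "The quick brown fox jumped over the lazy dog. " * 40
#     print(multipart_summarize(sample))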