lfcc committed
Commit f35f7c9 · verified · 1 Parent(s): dd95897

Update app.py

Files changed (1)
  1. app.py +177 -1
app.py CHANGED
@@ -1 +1,177 @@
- test
+ import streamlit as st
+ from annotated_text import annotated_text
+
+ import torch
+ from transformers import pipeline
+ from transformers import AutoModelForTokenClassification, AutoTokenizer
+
+ import json
+
+
+ st.set_page_config(layout="wide")
+
+
+
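+ # Load the token-classification model and tokenizer from the local "models/lusa"
+ # directory, and build an NER pipeline that aggregates subword predictions into
+ # word-level entities.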
+ model = AutoModelForTokenClassification.from_pretrained("models/lusa")
+ tokenizer = AutoTokenizer.from_pretrained("models/lusa", model_max_length=512)
+ tagger = pipeline("ner", model=model, tokenizer=tokenizer, aggregation_strategy='first')  # alternative: aggregation_strategy='max'
+
+
+
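+ # Merge WordPiece subword tokens (marked with "##") back into whole words,
+ # keeping the label of each word's first subword.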
+ def aggregate_subwords(input_tokens, labels):
+     new_inputs = []
+     new_labels = []
+     current_word = ""
+     current_label = ""
+     for i, token in enumerate(input_tokens):
+         label = labels[i]
+         # Handle subwords
+         if token.startswith('##'):
+             current_word += token[2:]
+         else:
+             # Finish previous word
+             if current_word:
+                 new_inputs.append(current_word)
+                 new_labels.append(current_label)
+             # Start new word
+             current_word = token
+             current_label = label
+     new_inputs.append(current_word)
+     new_labels.append(current_label)
+     return new_inputs, new_labels
+
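+ # Tag the input text token by token and return (text, BIO tag, entity type)
+ # tuples; I- tokens are appended to the preceding trigger and the special
+ # [CLS]/[SEP] tokens are dropped.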
+ def annotateTriggers(line):
+     line = line.strip()
+     inputs = tokenizer(line, return_tensors="pt")
+     input_tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])
+
+     with torch.no_grad():
+         logits = model(**inputs).logits
+
+     predictions = torch.argmax(logits, dim=2)
+     predicted_token_class = [model.config.id2label[t.item()] for t in predictions[0]]
+     input_tokens, predicted_token_class = aggregate_subwords(input_tokens, predicted_token_class)
+     token_labels = []
+     current_entity = ''
+     for i, label in enumerate(predicted_token_class):
+         token = input_tokens[i]
+         if label == 'O':
+             token_labels.append((token, 'O', ''))
+             current_entity = ''
+         elif label.startswith('B-'):
+             current_entity = label[2:]
+             token_labels.append((token, 'B', current_entity))
+         elif label.startswith('I-'):
+             if current_entity == '':
+                 raise ValueError(f"Invalid label sequence: {predicted_token_class}")
+             token_labels[-1] = (token_labels[-1][0] + f" {token}", 'I', current_entity)
+         else:
+             raise ValueError(f"Invalid label: {label}")
+     return token_labels[1:-1]
+
+
+
+
+
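+ # Merge raw B-/I- pipeline outputs into single entity spans. Currently unused:
+ # the pipeline's aggregation_strategy already performs this grouping, and the
+ # call in annotateEvents is commented out.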
+ def joinEntities(entities):
+
+     joined_entities = []
+     i = 0
+     while i < len(entities):
+         curr_entity = entities[i]
+         if curr_entity['entity'][0] == 'B':
+             label = curr_entity['entity'][2:]
+             j = i + 1
+             while j < len(entities) and entities[j]['entity'][0] == 'I':
+                 j += 1
+             joined_entity = {
+                 'entity': label,
+                 'score': max(e['score'] for e in entities[i:j]),
+                 'index': min(e['index'] for e in entities[i:j]),
+                 'word': ' '.join(e['word'] for e in entities[i:j]),
+                 'start': entities[i]['start'],
+                 'end': entities[j-1]['end']
+             }
+             joined_entities.append(joined_entity)
+             i = j - 1
+         i += 1
+     return joined_entities
+
+
+
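+ # Sentence segmentation (pysbd), used to build a context window of neighbouring
+ # sentences around each detected trigger.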
+ import pysbd
+ seg = pysbd.Segmenter(language="es", clean=False)
+
+ def sent_tokenize(text):
+     return seg.segment(text)
+
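+ # Return the index of the sentence that contains the character offset `span`,
+ # by accumulating sentence lengths.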
+ def getSentenceIndex(lines, span):
+     i = 1
+     total = len(lines[0])
+     while total < span:
+         total += len(lines[i])
+         i = i + 1
+     return i - 1
+
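+ # Return the sentence containing `span` plus `window` sentences on either side,
+ # joined into a single context string.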
+ def generateContext(text, window, span):
+     lines = sent_tokenize(text)
+     index = getSentenceIndex(lines, span)
+     text = " ".join(lines[max(0, index - window):index + window + 1])
+     return text
+
+
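+ # Run the NER pipeline over the full text and turn each detected trigger into
+ # an event dict (trigger word, type, score, sentence context). The `squad`
+ # argument is currently unused.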
+ def annotateEvents(text, squad, window):
+     text = text.strip()
+     ner_results = tagger(text)
+     #print(ner_results)
+     #ner_results = joinEntities(ner_results)
+     i = 0
+     #exit()
+     while i < len(ner_results):
+         # Strip a leading "B-"/"I-" prefix, if present, from the aggregated label.
+         ner_results[i]["entity"] = ner_results[i]["entity_group"].removeprefix("B-").removeprefix("I-")
+         i = i + 1
+
+     events = []
+     for trigger in ner_results:
+         tipo = trigger["entity_group"]
+         context = generateContext(text, window, trigger["start"])
+         event = {
+             "trigger": trigger["word"],
+             "type": tipo,
+             "score": trigger["score"],
+             "context": context,
+         }
+         events.append(event)
+     return events
+
+
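+ # --- Streamlit UI: example selector, free-text input and Run button ---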
+ #"A Joana foi atacada pelo João nas ruas do Porto, com uma faca."
+
+ st.title('Extract Events')
+
+ options = ["O presidente da Federação Haitiana de Futebol, Yves Jean-Bart, foi banido para sempre de toda a atividade ligada ao futebol, por ter sido considerado culpado de abuso sexual sistemático de jogadoras, anunciou hoje a FIFA."]
+
+ option = st.selectbox(
+     'Select examples',
+     options)
+ #option = options [index]
+ line = st.text_area("Insert Text", option)
+
+ st.button('Run')
+
+
+ st.sidebar.write("## Hyperparameters :gear:")
+ window = 1
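+ # If there is input text, highlight the predicted triggers inline and print one
+ # line per event with its trigger, type and score.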
+ if line != "":
+     st.header("Triggers:")
+     triggerss = annotateTriggers(line)
+     annotated_text(*[word[0] + " " if word[1] == 'O' else (word[0] + " ", word[2]) for word in triggerss])
+
+     eventos_1 = annotateEvents(line, 1, window)
+     eventos_2 = annotateEvents(line, 2, window)
+
+     for mention1, mention2 in zip(eventos_1, eventos_2):
+         st.text(f"| Trigger: {mention1['trigger']:20} | Type: {mention1['type']:10} | Score: {str(round(mention1['score'], 3)):5} |")
+         st.markdown("""---""")
+