lfcc committed
Commit dd95897 · verified
1 Parent(s): cdf8937

Create app.py

Files changed (1)
  1. app.py +1 -177
app.py CHANGED
@@ -1,177 +1 @@
- import streamlit as st
- from annotated_text import annotated_text
-
- import torch
- from transformers import pipeline
- from transformers import AutoModelForTokenClassification, AutoTokenizer
-
- import json
-
-
- st.set_page_config(layout="wide")
-
-
- model = AutoModelForTokenClassification.from_pretrained("models/lusa")
- tokenizer = AutoTokenizer.from_pretrained("models/lusa", model_max_length=512)
- tagger = pipeline("ner", model=model, tokenizer=tokenizer, aggregation_strategy='first')  # alternative: aggregation_strategy='max'
-
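For orientation, a sketch of what `tagger` returns. With an aggregation strategy set, the Hugging Face NER pipeline emits one dict per merged entity span; the keys shown are the pipeline's standard ones, but the sentence, label, and score are made up for illustration and depend on the models/lusa checkpoint:

    # Hypothetical output, not taken from the actual model:
    tagger("A Joana foi atacada pelo João no Porto.")
    # -> [{'entity_group': 'EVENT', 'score': 0.97, 'word': 'atacada',
    #      'start': 12, 'end': 19}]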
-
- def aggregate_subwords(input_tokens, labels):
-     new_inputs = []
-     new_labels = []
-     current_word = ""
-     current_label = ""
-     for i, token in enumerate(input_tokens):
-         label = labels[i]
-         # Glue '##' continuation pieces onto the word being built
-         if token.startswith('##'):
-             current_word += token[2:]
-         else:
-             # Finish the previous word
-             if current_word:
-                 new_inputs.append(current_word)
-                 new_labels.append(current_label)
-             # Start a new word; it keeps the label of its first piece
-             current_word = token
-             current_label = label
-     # Flush the final word
-     new_inputs.append(current_word)
-     new_labels.append(current_label)
-     return new_inputs, new_labels
-
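A tiny worked example of the merging above (WordPiece-style tokens; the EVENT label is hypothetical). Each '##' piece is glued onto the previous token, and the merged word keeps the label of its first piece:

    tokens = ['[CLS]', 'A', 'Joana', 'foi', 'ata', '##cada', '[SEP]']
    labels = ['O', 'O', 'O', 'O', 'B-EVENT', 'I-EVENT', 'O']
    words, word_labels = aggregate_subwords(tokens, labels)
    # words       == ['[CLS]', 'A', 'Joana', 'foi', 'atacada', '[SEP]']
    # word_labels == ['O', 'O', 'O', 'O', 'B-EVENT', 'O']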
-
- def annotateTriggers(line):
-     line = line.strip()
-     inputs = tokenizer(line, return_tensors="pt")
-     input_tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])
-
-     with torch.no_grad():
-         logits = model(**inputs).logits
-
-     predictions = torch.argmax(logits, dim=2)
-     predicted_token_class = [model.config.id2label[t.item()] for t in predictions[0]]
-     input_tokens, predicted_token_class = aggregate_subwords(input_tokens, predicted_token_class)
-     token_labels = []
-     current_entity = ''
-     for i, label in enumerate(predicted_token_class):
-         token = input_tokens[i]
-         if label == 'O':
-             token_labels.append((token, 'O', ''))
-             current_entity = ''
-         elif label.startswith('B-'):
-             current_entity = label[2:]
-             token_labels.append((token, 'B', current_entity))
-         elif label.startswith('I-'):
-             if current_entity == '':
-                 raise ValueError(f"Invalid label sequence: {predicted_token_class}")
-             # Fold the continuation token into the preceding entry
-             token_labels[-1] = (token_labels[-1][0] + f" {token}", 'I', current_entity)
-         else:
-             raise ValueError(f"Invalid label: {label}")
-     # Drop the [CLS] and [SEP] entries
-     return token_labels[1:-1]
-
-
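So annotateTriggers yields one (word, BIO-flag, type) triple per word, with the [CLS]/[SEP] entries sliced off and I- continuations folded into their B- word. A hypothetical result, assuming the model tags 'atacada' as an EVENT trigger:

    annotateTriggers("A Joana foi atacada.")
    # -> [('A', 'O', ''), ('Joana', 'O', ''), ('foi', 'O', ''),
    #     ('atacada', 'B', 'EVENT'), ('.', 'O', '')]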
-
- def joinEntities(entities):
-     joined_entities = []
-     i = 0
-     while i < len(entities):
-         curr_entity = entities[i]
-         if curr_entity['entity'][0] == 'B':
-             label = curr_entity['entity'][2:]
-             j = i + 1
-             while j < len(entities) and entities[j]['entity'][0] == 'I':
-                 j += 1
-             joined_entity = {
-                 'entity': label,
-                 'score': max(e['score'] for e in entities[i:j]),
-                 'index': min(e['index'] for e in entities[i:j]),
-                 'word': ' '.join(e['word'] for e in entities[i:j]),
-                 'start': entities[i]['start'],
-                 'end': entities[j-1]['end']
-             }
-             joined_entities.append(joined_entity)
-             i = j - 1
-         i += 1
-     return joined_entities
-
-
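joinEntities appears to target un-aggregated pipeline output (one dict per token, with 'entity' labels such as 'B-EVENT'/'I-EVENT' and a token 'index'); it collapses each B-/I- run into a single span. Its call site further down stays commented out because aggregation_strategy='first' already performs this merge. An illustrative input/output pair, with made-up values:

    raw = [
        {'entity': 'B-EVENT', 'score': 0.9, 'index': 4, 'word': 'ata',    'start': 12, 'end': 15},
        {'entity': 'I-EVENT', 'score': 0.8, 'index': 5, 'word': '##cada', 'start': 15, 'end': 19},
    ]
    joinEntities(raw)
    # -> [{'entity': 'EVENT', 'score': 0.9, 'index': 4,
    #      'word': 'ata ##cada', 'start': 12, 'end': 19}]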
- import pysbd
- seg = pysbd.Segmenter(language="es", clean=False)
-
- def sent_tokenize(text):
-     return seg.segment(text)
-
- def getSentenceIndex(lines, span):
-     # Walk the sentences until the running character count passes `span`
-     i = 1
-     total = len(lines[0])
-     while total < span:
-         total += len(lines[i])
-         i = i + 1
-     return i - 1
-
- def generateContext(text, window, span):
-     lines = sent_tokenize(text)
-     index = getSentenceIndex(lines, span)
-     text = " ".join(lines[max(0, index - window):index + window + 1])
-     return text
-
-
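generateContext locates the sentence containing character offset `span` and joins it with up to `window` sentences on each side. A worked example, assuming pysbd splits this text at the periods (its Spanish rules handle this simple case, although the app's inputs are Portuguese):

    text = "Primeira frase. Segunda frase. Terceira frase."
    generateContext(text, 1, 20)
    # offset 20 falls in the second sentence (index 1), so with window=1
    # the result is all three sentences joined together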
- def annotateEvents(text, window):
-     text = text.strip()
-     ner_results = tagger(text)
-     #ner_results = joinEntities(ner_results)
-     # Remove a leading "B-"/"I-" prefix, if present (note that lstrip("B-")
-     # would strip the characters 'B' and '-', not the prefix)
-     for result in ner_results:
-         entity = result["entity_group"]
-         if entity.startswith(("B-", "I-")):
-             entity = entity[2:]
-         result["entity"] = entity
-
-     events = []
-     for trigger in ner_results:
-         tipo = trigger["entity"]
-         context = generateContext(text, window, trigger["start"])
-         event = {
-             "trigger": trigger["word"],
-             "type": tipo,
-             "score": trigger["score"],
-             "context": context,
-         }
-         events.append(event)
-     return events
-
-
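Each event pairs a trigger with its label, confidence, and surrounding context. A hypothetical result for the FIFA example used below, assuming the model marks 'banido' as a trigger:

    annotateEvents(line, window=1)
    # -> [{'trigger': 'banido', 'type': 'EVENT', 'score': 0.95,
    #      'context': 'the sentence containing "banido", plus one sentence on each side'}]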
-
- # Example input: "A Joana foi atacada pelo João nas ruas do Porto, com uma faca."
- # ("Joana was attacked by João in the streets of Porto, with a knife.")
-
- st.title('Extract Events')
-
- options = ["O presidente da Federação Haitiana de Futebol, Yves Jean-Bart, foi banido para sempre de toda a atividade ligada ao futebol, por ter sido considerado culpado de abuso sexual sistemático de jogadoras, anunciou hoje a FIFA."]
-
- option = st.selectbox(
-     'Select examples',
-     options)
- line = st.text_area("Insert Text", option)
-
- st.button('Run')  # pressing any widget reruns the script, re-annotating `line`
-
- st.sidebar.write("## Hyperparameters :gear:")
- window = 1  # sentences of context kept on each side of a trigger
- if line != "":
-     st.header("Triggers:")
-     triggers = annotateTriggers(line)
-     annotated_text(*[word[0] + " " if word[1] == 'O' else (word[0] + " ", word[2]) for word in triggers])
-
-     events = annotateEvents(line, window)
-
-     for mention in events:
-         st.text(f"| Trigger: {mention['trigger']:20} | Type: {mention['type']:10} | Score: {str(round(mention['score'], 3)):5} |")
-         st.markdown("""---""")
 
+ test