Upload evaluation.py
evaluation/evaluation.py  +239 -0
ADDED
@@ -0,0 +1,239 @@
import os
import numpy as np
import jsonlines
from collections import defaultdict
from sklearn.metrics import classification_report


def get_links(sample_string, sample_index):
    """
    Takes a sample string and returns a list of attachment tuples,
    a list of (rel, x, y) tuples, and the [good, bad] counts of
    well-formed and malformed links.
    """
    # MINECRAFT labels
    labels = ['COM','CONTR','CORR','QAP','ACK','ELAB','CLARIFQ','COND','CONTIN',
              'RES','EXPL','QELAB','ALT','NARR','CONFQ','SEQ']

    split_list = [st.strip() for st in sample_string.split(' ')]

    rel_list = []
    attach_list = []
    bad = 0
    good = 0
    for a in split_list:
        s_tuple = None
        rel = None
        try:
            s = a.split('(')[1].split(')')[0].split(',')
            r = a.split('(')[0].strip()
        except IndexError:
            print('split error at ', sample_index)
        else:
            try:
                s_tuple = (int(s[0]), int(s[1]))
            except IndexError:
                print('split error at ', sample_index)
            except ValueError:
                print('value error at ', sample_index)
            if r in labels:
                # make sure the label is well-formed
                rel = r

        # if rel is not None and s_tuple is not None:  # if not using a DISTANCE cutoff
        if rel is not None and s_tuple is not None and (s_tuple[1] - s_tuple[0]) <= 15:  # if using a DISTANCE cutoff
            attach_list.append((int(s[0]), int(s[1])))
            rel_list.append(r)
            good += 1
        else:
            bad += 1

    # re-construct the full list: a list of tuples (rel, x, y),
    # but don't allow duplicate endpoints
    full_list = []
    endpoints = []
    for i, r in enumerate(attach_list):
        if r not in endpoints:
            endpoints.append(r)
            full_list.append((rel_list[i], r[0], r[1]))
    return endpoints, full_list, [good, bad]

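# Example (illustrative): a well-formed DS string looks like 'ELAB(3,4) QAP(4,5)', and
#   get_links('ELAB(3,4) QAP(4,5)', 0)
# returns
#   ([(3, 4), (4, 5)], [('ELAB', 3, 4), ('QAP', 4, 5)], [2, 0])
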
current_folder = os.getcwd()

gold_path = '/path/to/jsonl'
pred_path = '/path/to/llamipa_output.txt'
save_results = '/path/to/eval_.txt'  # to create

# get predicted
with open(pred_path, 'r') as txt:
    text = txt.read().split('\n')

pred_outputs = []

for t in text:
    if t.startswith(' ### DS:'):
        sample = t.split('### DS:')[1].strip()
        pred_outputs.append(sample)
print(len(pred_outputs))
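
# Each prediction is taken from a model output line of the form (illustrative)
#   ' ### DS: ELAB(3,4) QAP(4,5)'
# which yields the sample string 'ELAB(3,4) QAP(4,5)'.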

# get gold
gold_outputs = []

with jsonlines.open(gold_path) as reader:
    for obj in reader:
        if not obj['sample'].startswith('NEW DIALOGUE'):  # make sure to ignore incremental formatting
            gold_outputs.append(obj['PS'])
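
# Each gold record is assumed to carry a 'sample' field (the input dialogue text) and a
# 'PS' field holding the gold DS string, e.g. (illustrative)
#   {"sample": "...", "PS": "ELAB(3,4) QAP(4,5)"}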

att_f1_l = []
att_prec_l = []
att_rec_l = []

total_attach_tp = 0
total_attach_fp = 0
total_attach_fn = 0

type_f1_l = []
type_prec_l = []
type_rec_l = []

total_TP = []

matrix_list = []
bad_output = 0
good_output = 0

for i, s in enumerate(pred_outputs):

    pred_att, pred_all, malform = get_links(s, i)
    gold_att, gold_all, _ = get_links(gold_outputs[i], i)

    # track well-formed vs malformed links in the model predictions
    bad_output += malform[1]
    good_output += malform[0]

    # calculate the number of nulls there should be -- will use to check null count below
    common = len(set(pred_att).intersection(set(gold_att)))
    expected_nulls = (len(pred_att) - common) + (len(gold_att) - common)

    # calculate the precision, recall, and f1 for the sample FOR ATTACHMENTS
    if len(gold_att) > 0 and len(pred_att) > 0:
        prec = len([e for e in pred_att if e in gold_att])/len(pred_att)
        rec = len([e for e in pred_att if e in gold_att])/len(gold_att)
        total_attach_tp += len([e for e in pred_att if e in gold_att])
        total_attach_fp += len([e for e in pred_att if e not in gold_att])
        total_attach_fn += len([e for e in gold_att if e not in pred_att])
    else:
        prec = 0
        rec = 0
    att_prec_l.append(prec)
    att_rec_l.append(rec)
    if prec + rec == 0:
        att_f1_l.append(0)
    else:
        att_f1_l.append(2*prec*rec/(prec+rec))
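
    # Example (illustrative): if 2 of 3 predicted attachments are in a gold set of 4,
    # then prec = 2/3, rec = 2/4 = 0.5, and F1 = 2*prec*rec/(prec+rec) ≈ 0.571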

    # calculate the precision, recall, and f1 for the sample FOR ATTACHMENTS+RELATION TYPE
    if len(gold_all) > 0 and len(pred_all) > 0:
        prec = len([e for e in pred_all if e in gold_all])/len(pred_all)
        rec = len([e for e in pred_all if e in gold_all])/len(gold_all)
    else:
        prec = 0
        rec = 0
    type_prec_l.append(prec)
    type_rec_l.append(rec)
    if prec + rec == 0:
        type_f1_l.append(0)
    else:
        type_f1_l.append(2*prec*rec/(prec+rec))

    # create the relation comparisons by type
    TP = [e for e in pred_all if e in gold_all]
    leftover_pred = [p for p in pred_all if p not in TP]
    leftover_gold = [p for p in gold_all if p not in TP]

    # then process the TP, FP, FN for the matrix
    total_TP.extend(TP)

    rem_dict = defaultdict(list)
    for x in TP:
        matrix_list.append([x[0], x[0]])
    for x in leftover_pred:
        rem_dict[(x[1], x[2])].append(('p', x[0]))
    for x in leftover_gold:
        rem_dict[(x[1], x[2])].append(('g', x[0]))

    p_count = 0
    g_count = 0
    null_count = 0
    for k in rem_dict.keys():
        p = 'NULL'
        t = 'NULL'
        for re in rem_dict[k]:
            if re[0] == 'p':
                p = re[1]
                p_count += 1
            elif re[0] == 'g':
                t = re[1]
                g_count += 1
        matrix_list.append([t, p])
        if 'NULL' in [t, p]:
            null_count += 1

    assert len(TP) + p_count == len(pred_all)
    assert len(TP) + g_count == len(gold_all)
    assert null_count == expected_nulls
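
# Example (illustrative): each entry of matrix_list is a [gold_label, pred_label] pair.
# With pred_all = [('ELAB', 3, 4)] and gold_all = [('QAP', 3, 4)], the shared endpoint (3, 4)
# gets rem_dict[(3, 4)] = [('p', 'ELAB'), ('g', 'QAP')] and yields the pair ['QAP', 'ELAB'];
# an attachment present on only one side pairs its label with 'NULL'.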

# compute the label set appearing in gold and pred
gold = [m[0] for m in matrix_list]
pred = [m[1] for m in matrix_list]
gold.extend(pred)
labels = list(set(gold))

microf1 = total_attach_tp/(total_attach_tp + 0.5*(total_attach_fp + total_attach_fn))
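
# Note: TP/(TP + 0.5*(FP + FN)) is algebraically the usual micro-F1, 2*TP/(2*TP + FP + FN);
# e.g. (illustrative) TP=50, FP=10, FN=20 gives 50/65 = 100/130 ≈ 0.769.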

gold_list = [labels.index(m[0]) for m in matrix_list]
pred_list = [labels.index(m[1]) for m in matrix_list]

f = open(save_results, "w")
print("Attachment F1:", np.mean(att_f1_l), len(att_f1_l), file=f)
print("Attachment Average Precision:", np.mean(att_prec_l), file=f)
print("Attachment Average Recall:", np.mean(att_rec_l), file=f)
print('Micro F1: ', microf1, file=f)
print('--------------------------------', file=f)
print("Attachment + Rel F1:", np.mean(type_f1_l), len(type_f1_l), file=f)
print("Attachment + Rel Average Precision:", np.mean(type_prec_l), file=f)
print("Attachment + Rel Average Recall:", np.mean(type_rec_l), file=f)
print('---------------------------------------', file=f)
print(classification_report(gold_list, pred_list, target_names=labels), file=f)

# The F1-scores for the relation types displayed in the above table are correct.
# That is, while calculating F1 for label l, all the ["NULL", l] entries count towards false positives
# for label l and all the [l, "NULL"] entries count towards false negatives for label l.
# So the "NULL" type affects the precision/recall/F1 for label l (as it should).
# For the overall weighted average precision/recall/F1-score, however, we want the average
# to be over the actual relation label set (i.e. excluding the "NULL" class).
# For that, we do this:
d = classification_report(gold_list, pred_list, target_names=labels, output_dict=True)
prec = 0
rec = 0
f1 = 0
count = 0

for label in labels:
    if label != "NULL":
        prec += d[label]["precision"]*d[label]["support"]
        rec += d[label]["recall"]*d[label]["support"]
        f1 += d[label]["f1-score"]*d[label]["support"]
        count += d[label]["support"]
        # checking that support is the same as the number of ground-truth instances for the label
        # assert d[label]["support"] == Counter(g_label_l)[label]
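
# Example (illustrative): two labels with supports 10 and 30 and precisions 0.8 and 0.6
# give a weighted precision of (0.8*10 + 0.6*30)/40 = 0.65.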

print('--------------------------------', file=f)
print("Weighted Average Precision:", prec/count, file=f)
print("Weighted Average Recall:", rec/count, file=f)
print("Weighted Average F1 score:", f1/count, file=f)

f.close()
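
# To run (illustrative): set gold_path, pred_path, and save_results above, then
#   python evaluation/evaluation.py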