Kqte committed · Commit 65debfe · verified · 1 Parent(s): fe11772

Upload evaluation.py

Files changed (1)
  1. evaluation/evaluation.py +239 -0
evaluation/evaluation.py ADDED
@@ -0,0 +1,239 @@
+ import os
+ import numpy as np
+ import jsonlines
+ from collections import defaultdict
+ from sklearn.metrics import classification_report
+
+
+ def get_links(sample_string, sample_index):
+     """
+     Parses a sample string of discourse links and returns the list of attachment
+     tuples, the list of (relation type, x, y) triples, and a [good, bad] count
+     of well-formed vs. malformed links.
+     """
+     #MINECRAFT labels
+     labels = ['COM','CONTR','CORR','QAP','ACK','ELAB','CLARIFQ','COND','CONTIN',
+               'RES','EXPL','QELAB','ALT','NARR','CONFQ','SEQ']
+
+     split_list = [st.strip() for st in sample_string.split(' ')]
+
+     rel_list = []
+     attach_list = []
+     bad = 0
+     good = 0
+     for a in split_list:
+         s_tuple = None
+         rel = None
+         try:
+             s = a.split('(')[1].split(')')[0].split(',')
+             r = a.split('(')[0].strip()
+         except IndexError:
+             print('split error at ', sample_index)
+         else:
+             try:
+                 s_tuple = (int(s[0]), int(s[1]))
+             except IndexError:
+                 print('split error at ', sample_index)
+             except ValueError:
+                 print('value error at ', sample_index)
+             if r in labels:
+                 #make sure the label is well-formed
+                 rel = r
+
+         if rel != None and s_tuple != None and (s_tuple[1] - s_tuple[0]) <= 15: #if using a DISTANCE cutoff
+         # if rel != None and s_tuple != None: #if not using a DISTANCE cutoff
+             attach_list.append((int(s[0]), int(s[1])))
+             rel_list.append(r)
+             good += 1
+         else:
+             bad += 1
+
+     #re-construct the full list
+     #a list of tuples (rel, x, y)
+     #but don't allow doubles!!
+     full_list = []
+     endpoints = []
+     for i, r in enumerate(attach_list):
+         if r not in endpoints:
+             endpoints.append(r)
+             full_list.append((rel_list[i], r[0], r[1]))
+     return endpoints, full_list, [good, bad]
60
+
61
+
62
+ current_folder=os.getcwd()
63
+
64
+ gold_path = '/path/to/jsonl'
65
+ pred_path = '/path/to/llamipa_output.txt'
66
+ save_results = '/path/to/eval_.txt' #to create
67
+
+ #get predicted
+ with open(pred_path, 'r') as txt:
+     text = txt.read().split('\n')
+
+ pred_outputs = []
+
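+ # Prediction lines are assumed to look like " ### DS: ELAB(3,4) QAP(4,5) ..."
+ # (inferred from the startswith check below); everything after "### DS:" is
+ # taken as the predicted parse string for one sample.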
+ for t in text:
+     if t.startswith(' ### DS:'):
+         sample = t.split('### DS:')[1].strip()
+         pred_outputs.append(sample)
+ print(len(pred_outputs))
+
+ #get gold
+ gold_outputs = []
+
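+ # Each gold jsonl object is expected to carry the sample text under 'sample'
+ # and its gold parse string under 'PS'.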
+ with jsonlines.open(gold_path) as reader:
+     for obj in reader:
+         if not obj['sample'].startswith('NEW DIALOGUE'): #make sure to ignore incremental formatting
+             gold_outputs.append(obj['PS'])
+
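+ # Per-sample attachment and attachment+relation scores are collected in the
+ # lists below and macro-averaged over samples in the final report; the
+ # total_attach_* counters feed the pooled (micro-averaged) attachment F1.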
+ att_f1_l = []
+ att_prec_l = []
+ att_rec_l = []
+
+ total_attach_tp = 0
+ total_attach_fp = 0
+ total_attach_fn = 0
+
+ type_f1_l = []
+ type_prec_l = []
+ type_rec_l = []
+
+ total_TP = []
+
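+ # matrix_list will hold one [gold_label, pred_label] pair per endpoint pair,
+ # with 'NULL' standing in for a missing prediction or missing gold link; it is
+ # what feeds the per-relation classification report below.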
+ matrix_list = []
+ bad_output = 0
+ good_output = 0
+
+ for i, s in enumerate(pred_outputs):
+
+     pred_att, pred_all, malform = get_links(s, i)
+     gold_att, gold_all, malform = get_links(gold_outputs[i], i)
+
+     bad_output += malform[1]
+     good_output += malform[0]
+
+     #calculate number of nulls there should be -- will use to check null count below
+     common = len(set(pred_att).intersection(set(gold_att)))
+     expected_nulls = (len(pred_att) - common) + (len(gold_att) - common)
+
+
+     #calculate the precision, recall, and f1 for the sample FOR ATTACHMENTS
+     if len(gold_att) > 0 and len(pred_att) > 0:
+         prec = len([e for e in pred_att if e in gold_att])/len(pred_att)
+         rec = len([e for e in pred_att if e in gold_att])/len(gold_att)
+         total_attach_tp += len([e for e in pred_att if e in gold_att])
+         total_attach_fp += len([e for e in pred_att if e not in gold_att])
+         total_attach_fn += len([e for e in gold_att if e not in pred_att])
+     else:
+         prec = 0
+         rec = 0
+     att_prec_l.append(prec)
+     att_rec_l.append(rec)
+     if prec+rec==0:
+         att_f1_l.append(0)
+     else:
+         att_f1_l.append(2*prec*rec/(prec+rec))
+
+     #calculate the precision, recall, and f1 for the sample FOR ATTACHMENTS+RELATION TYPE
+     if len(gold_all) > 0 and len(pred_all) > 0:
+         prec = len([e for e in pred_all if e in gold_all])/len(pred_all)
+         rec = len([e for e in pred_all if e in gold_all])/len(gold_all)
+     else:
+         prec = 0
+         rec = 0
+     type_prec_l.append(prec)
+     type_rec_l.append(rec)
+     if prec+rec==0:
+         type_f1_l.append(0)
+     else:
+         type_f1_l.append(2*prec*rec/(prec+rec))
+
+     #create the relation comparisons by type
+     TP = [e for e in pred_all if e in gold_all]
+     leftover_pred = [p for p in pred_all if p not in TP]
+     leftover_gold = [p for p in gold_all if p not in TP]
+
+     #then process the TP, FP, FN for matrix
+     total_TP.extend(TP)
+
+     rem_dict = defaultdict(list)
+     for x in TP:
+         matrix_list.append([x[0], x[0]])
+     for x in leftover_pred:
+         rem_dict[(x[1], x[2])].append(('p', x[0]))
+     for x in leftover_gold:
+         rem_dict[(x[1], x[2])].append(('g', x[0]))
+
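+     # For every endpoint pair that is not an exact (label, x, y) match, pair the
+     # predicted label with the gold label for that span (a labelling confusion),
+     # or with 'NULL' if the other side has no link there at all.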
+     p_count = 0
+     g_count = 0
+     null_count = 0
+     for k in rem_dict.keys():
+         p = 'NULL'
+         t = 'NULL'
+         for re in rem_dict[k]:
+             if re[0] == 'p':
+                 p = re[1]
+                 p_count += 1
+             elif re[0] == 'g':
+                 t = re[1]
+                 g_count += 1
+         matrix_list.append([t,p])
+         if 'NULL' in [t,p]:
+             null_count += 1
+
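+     # Sanity checks: every predicted and every gold link is accounted for either
+     # as an exact match (TP) or as an entry in the confusion pairs, and the number
+     # of 'NULL' pairings matches the count computed above.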
+     assert(len(TP) + p_count == len(pred_all))
+     assert(len(TP) + g_count == len(gold_all))
+     assert null_count == expected_nulls
+
+ #compute labels in gold and pred
+ gold = [m[0] for m in matrix_list]
+ pred = [m[1] for m in matrix_list]
+ gold.extend(pred)
+ labels = list(set(gold))
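+ # 'NULL' ends up in this label set whenever some link was missing on one side,
+ # so it appears as its own row in the classification report below.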
+
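+ # Micro-averaged attachment F1 from the pooled counts:
+ # TP/(TP + 0.5*(FP + FN)) is equivalent to 2*TP/(2*TP + FP + FN).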
+ microf1 = total_attach_tp/(total_attach_tp + 0.5*(total_attach_fp + total_attach_fn))
+
+ gold_list = [labels.index(m[0]) for m in matrix_list]
+ pred_list = [labels.index(m[1]) for m in matrix_list]
+
+ f = open(save_results,"w")
+ print("Attachment F1:",np.mean(att_f1_l),len(att_f1_l), file=f)
+ print("Attachment Average Precision:",np.mean(att_prec_l), file=f)
+ print("Attachment Average Recall:",np.mean(att_rec_l), file=f)
+ print('Micro F1: ', microf1, file=f)
+ print('--------------------------------', file=f)
+ print("Attachment + Rel F1:",np.mean(type_f1_l),len(type_f1_l))
+ print("Attachment + Rel Average Precision:",np.mean(type_prec_l))
+ print("Attachment + Rel Average Recall:",np.mean(type_rec_l))
+ print('---------------------------------------')
+ print(classification_report(gold_list,pred_list,target_names=labels), file=f)
+
+ # The F1-scores for the relation types displayed in the above table are correct.
+ # That is, while calculating F1 for label l, all the ["NULL", l] entries count
+ # towards false positives for label l, and all the [l, "NULL"] entries count
+ # towards false negatives for label l. So the "NULL" type does affect the
+ # precision/recall/F1 for label l (as it should).
+ # For the overall weighted average precision/recall/F1-score, however, we want
+ # the average to be over the actual relation label set (i.e. excluding the
+ # "NULL" class). For that, we do the following:
+ d = classification_report(gold_list,pred_list,target_names=labels,output_dict=True)
+ prec = 0
+ rec = 0
+ f1 = 0
+ count = 0
+
+ for label in labels:
+     if label!="NULL":
+         prec+=d[label]["precision"]*d[label]["support"]
+         rec+=d[label]["recall"]*d[label]["support"]
+         f1+=d[label]["f1-score"]*d[label]["support"]
+         count+=d[label]["support"]
+         # checking that the support is the same as the number of ground-truth instances for the label
+         # assert d[label]["support"] == Counter(g_label_l)[label]
+
+ print('--------------------------------', file=f)
+ print("Weighted Average Precision:", prec/count, file=f)
+ print("Weighted Average Recall:", rec/count, file=f)
+ print("Weighted Average F1 score:", f1/count, file=f)
+
+ f.close()
+
+