File size: 491 Bytes
5e9bd47 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 |
import json
def merge_predictions(results):
    """Reassemble batched predictions into per-format lists ordered by index.

    Args:
        results: Sequence of ``(indices, batch_preds)`` pairs. ``indices`` is an
            iterable of sample indices for the batch. ``batch_preds`` maps each
            format name to a sequence of predictions aligned with ``indices``.
            The set of formats is taken from the first pair; every later pair
            must provide those same formats.

    Returns:
        A dict mapping each format name to a list of predictions sorted by
        ascending index. Returns ``{}`` when ``results`` is empty. If an index
        appears more than once, the prediction from the later batch wins.
        Gaps in the index range are not filled; entries are simply ordered.

    Raises:
        KeyError: if a batch lacks one of the formats present in the first batch.
    """
    if not results:
        return {}
    formats = list(results[0][1].keys())
    # format -> {index: prediction}; later batches overwrite earlier duplicates.
    by_index = {format_: {} for format_ in formats}
    for indices, batch_preds in results:
        # Materialize once so a one-shot iterator (e.g. a generator) is not
        # exhausted by the first format and silently empty for the rest.
        indices = list(indices)
        for format_ in formats:
            by_index[format_].update(zip(indices, batch_preds[format_]))
    # Sorting keys (instead of range(len(...))) supports non-contiguous index
    # sets and yields the same result when indices are exactly 0..n-1.
    return {
        format_: [preds[idx] for idx in sorted(preds)]
        for format_, preds in by_index.items()
    }
|