package org.nlpub.watset.eval;

import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.function.Function;
import java.util.stream.Collectors;

/* loaded from: input_file:org/nlpub/watset/eval/NormalizedModifiedPurity.class */
/**
 * Computes the purity of a set of weighted clusters against another set of weighted clusters,
 * optionally in its "normalized" (weight-aware) and/or "modified" (singletons ignored) variants.
 * <p>
 * Each cluster is represented as a map from element to weight. {@link #transform(List)} builds
 * such maps from plain collections, and {@link #normalize(Collection)} rescales the weights.
 *
 * @param <V> the element type
 */
public class NormalizedModifiedPurity<V> {
    /**
     * When {@code true}, elements are counted by their weights (the map values); otherwise every
     * element counts as one.
     */
    final boolean normalized;

    /**
     * When {@code true}, clusters with at most one element contribute nothing to the score.
     */
    final boolean modified;

    /**
     * Converts every collection of elements into a map from each distinct element to the number
     * of times it occurs in that collection, expressed as a {@code double}.
     *
     * @param clusters the clusters as plain collections (duplicates are allowed)
     * @param <V>      the element type
     * @return one element-to-count map per input collection, in the same order
     */
    public static <V> List<Map<V, Double>> transform(List<? extends Collection<V>> clusters) {
        return clusters.stream()
                .map(cluster -> cluster.stream().collect(Collectors.groupingBy(
                        Function.identity(),
                        Collectors.reducing(0d, element -> 1d, Double::sum))))
                .collect(Collectors.toList());
    }

    /**
     * Rescales the weights so that, for every element, its weights across all clusters sum to one.
     * Each weight is divided by the total weight of that element over the whole collection.
     * <p>
     * NOTE(review): an element whose total weight is zero yields {@code NaN} weights (0 / 0).
     * This cannot happen for the output of {@link #transform(List)}, where every weight is at
     * least one.
     *
     * @param clusters the weighted clusters
     * @param <V>      the element type
     * @return the normalized clusters, in the iteration order of {@code clusters}
     * @throws IllegalArgumentException if a cluster or the collection unexpectedly changes size
     *                                  (defensive checks; they should be unreachable)
     */
    public static <V> List<Map<V, Double>> normalize(Collection<Map<V, Double>> clusters) {
        // Total weight of every element over all clusters.
        final Map<V, Double> totals = new HashMap<>();
        for (final Map<V, Double> cluster : clusters) {
            for (final Map.Entry<V, Double> entry : cluster.entrySet()) {
                totals.merge(entry.getKey(), entry.getValue(), Double::sum);
            }
        }

        final List<Map<V, Double>> result = clusters.stream().map(cluster -> {
            final Map<V, Double> weights = cluster.entrySet().stream().collect(Collectors.toMap(
                    Map.Entry::getKey,
                    entry -> entry.getValue() / totals.get(entry.getKey())));
            // Compare against the source cluster; the normalized map must keep every element.
            if (cluster.size() != weights.size()) {
                throw new IllegalArgumentException("Cluster size changed");
            }
            return weights;
        }).collect(Collectors.toList());

        if (clusters.size() != result.size()) {
            throw new IllegalArgumentException("Collection size changed");
        }
        return result;
    }

    /**
     * Evaluates {@code clusters} against {@code classes} in both directions.
     * <p>
     * The first purity is computed as {@code precision.purity(clusters, classes)} and the second
     * as {@code recall.purity(classes, clusters)}. They are passed in that order to the
     * {@code PrecisionRecall} constructor (presumably as precision, then recall; confirm
     * against that class).
     *
     * @param precision the purity measure applied from {@code clusters} to {@code classes}
     * @param recall    the purity measure applied from {@code classes} to {@code clusters}
     * @param clusters  the clusters under evaluation
     * @param classes   the reference clusters
     * @param <V>       the element type
     * @return both purity values wrapped in a {@code PrecisionRecall}
     * @throws NullPointerException if {@code clusters} or {@code classes} is {@code null}
     */
    public static <V> PrecisionRecall evaluate(NormalizedModifiedPurity<V> precision,
                                               NormalizedModifiedPurity<V> recall,
                                               Collection<Map<V, Double>> clusters,
                                               Collection<Map<V, Double>> classes) {
        return new PrecisionRecall(
                precision.purity(Objects.requireNonNull(clusters, "clusters"),
                        Objects.requireNonNull(classes, "classes")),
                recall.purity(classes, clusters));
    }

    /**
     * Creates a measure with both the normalized and the modified variants enabled.
     */
    public NormalizedModifiedPurity() {
        this(true, true);
    }

    /**
     * Creates a measure with the given variants.
     *
     * @param normalized whether elements are counted by their weights instead of as one each
     * @param modified   whether clusters with at most one element score zero
     */
    public NormalizedModifiedPurity(boolean normalized, boolean modified) {
        this.normalized = normalized;
        this.modified = modified;
    }

    /**
     * Computes the purity of {@code clusters} with respect to {@code classes}.
     * <p>
     * This is the sum of {@link #score(Map, Collection)} over all clusters, divided by the number
     * of elements in {@code clusters}. When {@link #normalized} is set, the divisor is instead
     * the total weight of all elements.
     *
     * @param clusters the clusters to score
     * @param classes  the reference clusters each cluster is matched against
     * @return the purity, or {@code 0} when the divisor is zero (e.g. no elements)
     */
    public double purity(Collection<Map<V, Double>> clusters, Collection<Map<V, Double>> classes) {
        // Only compute the denominator that is actually used.
        final double denominator = normalized
                ? clusters.parallelStream().mapToDouble(NormalizedModifiedPurity::totalWeight).sum()
                : clusters.stream().mapToInt(Map::size).sum();

        if (denominator == 0d) {
            return 0d;
        }

        return clusters.parallelStream().mapToDouble(cluster -> score(cluster, classes)).sum() / denominator;
    }

    /**
     * Scores a single cluster as its best {@link #delta(Map, Map)} over all reference clusters.
     *
     * @param cluster the cluster to score
     * @param classes the reference clusters
     * @return the maximal overlap, or {@code 0} when {@code classes} is empty
     */
    public double score(Map<V, Double> cluster, Collection<Map<V, Double>> classes) {
        return classes.stream().mapToDouble(klass -> delta(cluster, klass)).max().orElse(0d);
    }

    /**
     * Measures the overlap between a cluster and a reference cluster.
     * <p>
     * The overlap is taken over the elements of {@code cluster} whose keys also occur in
     * {@code klass}. It is their number, or, when {@link #normalized} is set, the sum of their
     * weights in {@code cluster}.
     *
     * @param cluster the cluster whose weights are used
     * @param klass   the reference cluster; only its key set is consulted
     * @return the overlap, or {@code 0} when there is none or when {@link #modified} is set and
     *         {@code cluster} has at most one element
     */
    public double delta(Map<V, Double> cluster, Map<V, Double> klass) {
        // The modified variant does not reward singleton (or empty) clusters.
        if (modified && cluster.size() <= 1) {
            return 0d;
        }

        final Map<V, Double> intersection = new HashMap<>(cluster);
        intersection.keySet().retainAll(klass.keySet());

        if (intersection.isEmpty()) {
            return 0d;
        }

        return normalized ? totalWeight(intersection) : intersection.size();
    }

    /**
     * Sums the weights stored as the values of the given map.
     */
    private static double totalWeight(Map<?, Double> weights) {
        return weights.values().stream().mapToDouble(Double::doubleValue).sum();
    }
}
