|
import unittest |
|
import torch |
|
import numpy as np |
|
from tensorboardX import SummaryWriter |
|
from tensorboardX import summary |
|
from .expect_reader import compare_proto |
|
|
|
# Seed NumPy's global RNG so the randomly generated inputs used by the tests
# below come out the same on every run.
np.random.seed(0)

# Confusion-matrix counts at five thresholds, fed to the raw PR-curve tests.
# Each column sums to 225 (tp + fp + tn + fn), consistent with one pool of
# examples evaluated at successively stricter thresholds.
(true_positive_counts,
 false_positive_counts,
 true_negative_counts,
 false_negative_counts) = (
    [75, 64, 21, 5, 0],
    [150, 105, 18, 0, 0],
    [0, 45, 132, 150, 150],
    [0, 11, 54, 70, 75],
)

# Per-column precision = tp / (tp + fp) and recall = tp / (tp + fn) for the
# counts above, written out to ~7 significant digits. Precision is given as
# 0.0 in the last column, where tp + fp == 0.
precision, recall = (
    [0.3333333, 0.3786982, 0.5384616, 1.0, 0.0],
    [1.0, 0.8533334, 0.28, 0.0666667, 0.0],
)
|
|
|
|
|
class PRCurveTest(unittest.TestCase):
    """Tests for tensorboardX precision-recall curve summaries.

    ``test_smoke`` only checks that writing PR curves through a
    ``SummaryWriter`` does not raise. The other two tests build summaries
    with ``summary.pr_curve`` / ``summary.pr_curve_raw`` and hand them to
    ``compare_proto`` together with the test case itself.
    """

    def test_smoke(self):
        """Write a PR curve and a raw PR curve without raising.

        The writer logs into a temporary directory (instead of the default
        location relative to the working directory) so that running the test
        suite does not leave event files behind.
        """
        import tempfile

        # Nested (not combined) so the writer is closed and its files are
        # flushed before the temporary directory is deleted.
        with tempfile.TemporaryDirectory() as logdir:
            with SummaryWriter(logdir) as writer:
                # Labels and predictions come from the module-level seeded RNG,
                # so their values depend on which tests ran before this one.
                writer.add_pr_curve('xoxo', np.random.randint(2, size=100),
                                    np.random.rand(100), 1)

                writer.add_pr_curve_raw('prcurve with raw data',
                                        true_positive_counts,
                                        false_positive_counts,
                                        true_negative_counts,
                                        false_negative_counts,
                                        precision,
                                        recall,
                                        1)

    # NOTE(review): "purve" is a typo for "curve", but compare_proto receives
    # `self` and may locate its expected-output file from the test's name;
    # confirm before renaming this method or the one below.
    def test_pr_purve(self):
        """Compare ``summary.pr_curve`` output for fixed inputs to the stored expectation.

        Uses 100 hard-coded binary labels and 100 float16 probabilities, so
        the result does not depend on the RNG state.
        """
        random_labels = np.array([
            0, 1, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 1,
            1, 0, 1, 1, 1, 0, 1, 0, 1, 1, 0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0,
            0, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1,
            1, 1, 0, 0, 1, 0, 1, 1, 1, 1, 0, 1, 1, 0, 0, 1, 1, 1, 1, 0, 1, 0,
            1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0])

        # float16 on purpose: the expected summary was generated from
        # half-precision predictions.
        random_probs = np.array([
            0.33327776, 0.30032885, 0.79012837, 0.04306813, 0.65221544,
            0.58481968, 0.28305522, 0.53795795, 0.00729739, 0.52266951,
            0.22464247, 0.11262435, 0.41573075, 0.92493992, 0.73066758,
            0.43867735, 0.27955449, 0.56975382, 0.53933028, 0.34392824,
            0.30312509, 0.81732807, 0.55408544, 0.3969487, 0.31768033,
            0.24353266, 0.47198005, 0.19999122, 0.05788022, 0.24046305,
            0.04651082, 0.30061738, 0.78321545, 0.82670207, 0.49200517,
            0.80904619, 0.96711993, 0.3160946, 0.01049424, 0.60108337,
            0.56508792, 0.83729429, 0.9717386, 0.46306053, 0.80232138,
            0.24166823, 0.7393237, 0.50820418, 0.04944932, 0.53854157,
            0.10765172, 0.84723855, 0.20518299, 0.3143431, 0.51299074,
            0.47065695, 0.54267833, 0.1812676, 0.06265177, 0.34110327,
            0.30915171, 0.91870169, 0.91309447, 0.31395817, 0.36780571,
            0.98297986, 0.00594547, 0.52839042, 0.70229202, 0.37779588,
            0.15207045, 0.59759632, 0.72397032, 0.71502195, 0.90135725,
            0.43970107, 0.17123532, 0.08785938, 0.04986818, 0.62702444,
            0.69171023, 0.30537792, 0.30285433, 0.27124347, 0.27693729,
            0.7136039, 0.48022489, 0.20916285, 0.2018599, 0.92401008,
            0.30189681, 0.46862626, 0.96353024, 0.30468533, 0.68281294,
            0.30623562, 0.40795975, 0.76824531, 0.89824215, 0.69845035],
            dtype=np.float16)

        compare_proto(summary.pr_curve('tag', random_labels, random_probs, 1), self)

    def test_pr_purve_raw(self):
        """Compare ``summary.pr_curve_raw`` output for the module-level counts to the stored expectation."""
        compare_proto(summary.pr_curve_raw('prcurve with raw data',
                                           true_positive_counts,
                                           false_positive_counts,
                                           true_negative_counts,
                                           false_negative_counts,
                                           precision,
                                           recall,
                                           1),
                      self)
|
|