import unittest
import torch
import numpy as np
from tensorboardX import SummaryWriter
from tensorboardX import summary
from .expect_reader import compare_proto

# Seed NumPy's global RNG at import time so the random labels and scores drawn
# in PRCurveTest.test_smoke are the same on every run.
np.random.seed(0)
# Hand-written PR-curve fixture shared by the "raw" tests: one entry per
# threshold (five thresholds). At every threshold tp + fn == 75 and
# fp + tn == 150, i.e. the lists describe 75 positive and 150 negative samples.
true_positive_counts = [75, 64, 21, 5, 0]
false_positive_counts = [150, 105, 18, 0, 0]
true_negative_counts = [0, 45, 132, 150, 150]
false_negative_counts = [0, 11, 54, 70, 75]
# precision[i] == tp / (tp + fp); the last entry (0 / 0) is recorded as 0.0.
precision = [0.3333333, 0.3786982, 0.5384616, 1.0, 0.0]
# recall[i] == tp / (tp + fn), i.e. tp / 75.
recall = [1.0, 0.8533334, 0.28, 0.0666667, 0.0]


class PRCurveTest(unittest.TestCase):
    """Precision-recall curve tests for ``SummaryWriter`` and ``summary``.

    NOTE(review): ``compare_proto`` is handed the test-case instance; it
    presumably locates the stored expectation from the test id, so the method
    names (including the ``purve`` spelling) should not be changed without
    checking ``expect_reader`` -- confirm.
    """

    def test_smoke(self):
        """Write one computed and one raw PR curve; passes if nothing raises."""
        with SummaryWriter() as writer:
            # Drawn from NumPy's global RNG: labels first, then scores.
            labels = np.random.randint(2, size=100)
            scores = np.random.rand(100)
            writer.add_pr_curve('xoxo', labels, scores, 1)
            writer.add_pr_curve_raw(
                'prcurve with raw data',
                true_positive_counts,
                false_positive_counts,
                true_negative_counts,
                false_negative_counts,
                precision,
                recall,
                1,
            )

    def test_pr_purve(self):
        """Check ``summary.pr_curve`` on 100 fixed labels/scores via ``compare_proto``."""
        # 100 hard-coded binary labels.
        labels = np.array([
            0, 1, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 1,
            1, 0, 1, 1, 1, 0, 1, 0, 1, 1, 0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0,
            0, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1,
            1, 1, 0, 0, 1, 0, 1, 1, 1, 1, 0, 1, 1, 0, 0, 1, 1, 1, 1, 0, 1, 0,
            1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0,
        ])
        # 100 hard-coded scores in [0, 1), stored as float16.
        scores = np.array([
            0.33327776, 0.30032885, 0.79012837, 0.04306813, 0.65221544,
            0.58481968, 0.28305522, 0.53795795, 0.00729739, 0.52266951,
            0.22464247, 0.11262435, 0.41573075, 0.92493992, 0.73066758,
            0.43867735, 0.27955449, 0.56975382, 0.53933028, 0.34392824,
            0.30312509, 0.81732807, 0.55408544, 0.3969487, 0.31768033,
            0.24353266, 0.47198005, 0.19999122, 0.05788022, 0.24046305,
            0.04651082, 0.30061738, 0.78321545, 0.82670207, 0.49200517,
            0.80904619, 0.96711993, 0.3160946, 0.01049424, 0.60108337,
            0.56508792, 0.83729429, 0.9717386, 0.46306053, 0.80232138,
            0.24166823, 0.7393237, 0.50820418, 0.04944932, 0.53854157,
            0.10765172, 0.84723855, 0.20518299, 0.3143431, 0.51299074,
            0.47065695, 0.54267833, 0.1812676, 0.06265177, 0.34110327,
            0.30915171, 0.91870169, 0.91309447, 0.31395817, 0.36780571,
            0.98297986, 0.00594547, 0.52839042, 0.70229202, 0.37779588,
            0.15207045, 0.59759632, 0.72397032, 0.71502195, 0.90135725,
            0.43970107, 0.17123532, 0.08785938, 0.04986818, 0.62702444,
            0.69171023, 0.30537792, 0.30285433, 0.27124347, 0.27693729,
            0.7136039, 0.48022489, 0.20916285, 0.2018599, 0.92401008,
            0.30189681, 0.46862626, 0.96353024, 0.30468533, 0.68281294,
            0.30623562, 0.40795975, 0.76824531, 0.89824215, 0.69845035,
        ], dtype=np.float16)
        proto = summary.pr_curve('tag', labels, scores, 1)
        compare_proto(proto, self)

    def test_pr_purve_raw(self):
        """Check ``summary.pr_curve_raw`` on the module-level fixture lists."""
        proto = summary.pr_curve_raw(
            'prcurve with raw data',
            true_positive_counts,
            false_positive_counts,
            true_negative_counts,
            false_negative_counts,
            precision,
            recall,
            1,
        )
        compare_proto(proto, self)