|
|
|
import json |
|
import os |
|
import tempfile |
|
import unittest |
|
|
|
from detectron2.utils.events import CommonMetricPrinter, EventStorage, JSONWriter |
|
|
|
|
|
class TestEventWriter(unittest.TestCase):
    """Tests for JSONWriter and CommonMetricPrinter from detectron2.utils.events."""

    @staticmethod
    def _read_json_lines(path):
        """Return the list of JSON objects stored one per line in ``path``."""
        with open(path) as f:
            return [json.loads(line) for line in f]

    def testScalar(self):
        """Each write() produces one record holding the latest value of "key"."""
        with tempfile.TemporaryDirectory(
            prefix="detectron2_tests"
        ) as tmp_dir, EventStorage() as storage:
            json_file = os.path.join(tmp_dir, "test.json")
            writer = JSONWriter(json_file)
            try:
                for k in range(60):
                    storage.put_scalar("key", k, smoothing_hint=False)
                    if (k + 1) % 20 == 0:
                        writer.write()
                    storage.step()
            finally:
                # Close the writer even if the loop raises, so its file handle
                # does not outlive the test or block temp-dir cleanup.
                writer.close()

            data = self._read_json_lines(json_file)
            self.assertEqual([int(k["key"]) for k in data], [19, 39, 59])

    def testScalarMismatchedPeriod(self):
        """Scalars logged with different periods end up in separate per-iteration records.

        "key2" is logged every 17 iterations and "key" every iteration. The writer
        is flushed every 20 iterations. The expected output interleaves records for
        iterations 17/34/51 ("key2" only) and 19/39/59 ("key" only). The "key2"
        value logged at iteration 0 is not expected to appear.
        """
        with tempfile.TemporaryDirectory(
            prefix="detectron2_tests"
        ) as tmp_dir, EventStorage() as storage:
            json_file = os.path.join(tmp_dir, "test.json")

            writer = JSONWriter(json_file)
            try:
                for k in range(60):
                    if k % 17 == 0:
                        storage.put_scalar("key2", k, smoothing_hint=False)
                    storage.put_scalar("key", k, smoothing_hint=False)
                    if (k + 1) % 20 == 0:
                        writer.write()
                    storage.step()
            finally:
                # See testScalar: always release the writer's file handle.
                writer.close()

            data = self._read_json_lines(json_file)
            self.assertEqual(
                [int(k.get("key2", 0)) for k in data], [17, 0, 34, 0, 51, 0]
            )
            self.assertEqual(
                [int(k.get("key", 0)) for k in data], [0, 19, 0, 39, 0, 59]
            )
            self.assertEqual(
                [int(k["iteration"]) for k in data], [17, 19, 34, 39, 51, 59]
            )

    def testPrintETA(self):
        """Only the printer constructed with an iteration count logs an "eta" field."""
        with EventStorage() as s:
            # NOTE(review): the positional 10 is presumably the max iteration
            # count; confirm against CommonMetricPrinter's signature.
            p1 = CommonMetricPrinter(10)
            p2 = CommonMetricPrinter()

            s.put_scalar("time", 1.0)
            s.step()
            s.put_scalar("time", 1.0)
            s.step()

            with self.assertLogs("detectron2.utils.events") as logs:
                p1.write()
            self.assertIn("eta", logs.output[0])

            with self.assertLogs("detectron2.utils.events") as logs:
                p2.write()
            self.assertNotIn("eta", logs.output[0])
|
|