Spaces:
Build error
Build error
| # Copyright (c) Facebook, Inc. and its affiliates. | |
| import json | |
| import os | |
| import tempfile | |
| import time | |
| import unittest | |
| import torch | |
| from torch import nn | |
| from detectron2.config import configurable, get_cfg | |
| from detectron2.engine import DefaultTrainer, SimpleTrainer, hooks | |
| from detectron2.modeling.meta_arch import META_ARCH_REGISTRY | |
| from detectron2.utils.events import CommonMetricPrinter, JSONWriter | |
@META_ARCH_REGISTRY.register()
class _SimpleModel(nn.Module):
    """Minimal model used to drive the trainers under test.

    It owns a single ``nn.Linear`` so the optimizer has parameters to update,
    and ``forward`` returns a loss dict. It is registered in
    ``META_ARCH_REGISTRY`` so a config can select it with
    ``cfg.MODEL.META_ARCHITECTURE = "_SimpleModel"``.
    """

    @configurable
    def __init__(self, sleep_sec=0):
        """
        Args:
            sleep_sec (float): seconds to sleep at the start of every
                ``forward`` call; 0 (the default) disables sleeping.
        """
        super().__init__()
        self.mod = nn.Linear(10, 20)
        self.sleep_sec = sleep_sec

    @classmethod
    def from_config(cls, cfg):
        # Nothing is read from the config: building from a cfg uses the
        # ``__init__`` defaults. Must be a classmethod for ``@configurable``.
        return {}

    def forward(self, x):
        """Return ``{"loss": scalar tensor}`` for a batch ``x``.

        The loss is ``x.sum()`` plus the mean of every parameter, so each
        parameter receives a gradient even though ``self.mod`` is never
        applied to the input.
        """
        if self.sleep_sec > 0:
            time.sleep(self.sleep_sec)
        # Use a distinct name so the comprehension does not shadow the input ``x``.
        param_term = sum([p.mean() for p in self.parameters()])
        return {"loss": x.sum() + param_term}
class TestTrainer(unittest.TestCase):
    """Tests for ``SimpleTrainer``, the writer/eval hooks and ``DefaultTrainer``."""

    def _data_loader(self, device):
        """Infinite generator of random 3x3 tensors placed on ``device``."""
        device = torch.device(device)
        while True:
            yield torch.rand(3, 3).to(device)

    def test_simple_trainer(self, device="cpu"):
        """Run 10 iterations of ``SimpleTrainer`` on ``device``."""
        model = _SimpleModel().to(device=device)
        trainer = SimpleTrainer(
            model, self._data_loader(device), torch.optim.SGD(model.parameters(), 0.1)
        )
        trainer.train(0, 10)

    # Without this guard the test errors out on machines that have no GPU.
    @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
    def test_simple_trainer_cuda(self):
        """Same as ``test_simple_trainer`` but on CUDA."""
        self.test_simple_trainer(device="cuda")

    def test_writer_hooks(self):
        """Check what CommonMetricPrinter and JSONWriter emit during training."""
        # Each forward sleeps so iterations take non-zero wall-clock time.
        model = _SimpleModel(sleep_sec=0.1)
        trainer = SimpleTrainer(
            model, self._data_loader("cpu"), torch.optim.SGD(model.parameters(), 0.1)
        )

        max_iter = 50

        with tempfile.TemporaryDirectory(prefix="detectron2_test") as d:
            json_file = os.path.join(d, "metrics.json")
            writers = [CommonMetricPrinter(max_iter), JSONWriter(json_file)]

            trainer.register_hooks(
                [hooks.EvalHook(0, lambda: {"metric": 100}), hooks.PeriodicWriter(writers)]
            )
            with self.assertLogs(writers[0].logger) as logs:
                trainer.train(0, max_iter)

            with open(json_file, "r") as f:
                data = [json.loads(line.strip()) for line in f]
                self.assertEqual([x["iteration"] for x in data], [19, 39, 49, 50])
                # the eval metric is in the last line with iter 50
                self.assertIn("metric", data[-1], "Eval metric must be in last line of JSON!")

            # test logged messages from CommonMetricPrinter
            self.assertEqual(len(logs.output), 3)
            # ``expected_iter`` rather than ``iter`` to avoid shadowing the builtin.
            for log, expected_iter in zip(logs.output, [19, 39, 49]):
                self.assertIn(f"iter: {expected_iter}", log)

            self.assertIn("eta: 0:00:00", logs.output[-1], "Last ETA must be 0!")

    def test_default_trainer(self):
        """Build a ``DefaultTrainer`` and check that its ``model`` property stays in sync."""
        cfg = get_cfg()
        # _SimpleModel is device-agnostic; pin CPU so the test does not need a GPU.
        cfg.MODEL.DEVICE = "cpu"
        cfg.MODEL.META_ARCHITECTURE = "_SimpleModel"
        cfg.DATASETS.TRAIN = ("coco_2017_val_100",)
        with tempfile.TemporaryDirectory(prefix="detectron2_test") as d:
            cfg.OUTPUT_DIR = d
            trainer = DefaultTrainer(cfg)

            # test property
            self.assertIs(trainer.model, trainer._trainer.model)
            trainer.model = _SimpleModel()
            self.assertIs(trainer.model, trainer._trainer.model)