Spaces:
Paused
Paused
| #!/usr/bin/env python | |
| # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved | |
| import os | |
| import tempfile | |
| import unittest | |
| import torch | |
| from detectron2.config import configurable, downgrade_config, get_cfg, upgrade_config | |
| from detectron2.layers import ShapeSpec | |
# YAML for a version-0 config. Indentation is significant: RPN_HEAD must be
# nested under MODEL. test_auto_upgrade expects this v0 RPN head name to be
# exposed as MODEL.RPN.HEAD_NAME once the config has been upgraded.
_V0_CFG = """
MODEL:
  RPN_HEAD:
    NAME: "TEST"
VERSION: 0
"""

# YAML with no VERSION key, so its version has to be guessed when merged
# (test_guess_v1). WEIGHT must be nested under MODEL; presumably the guess keys
# off MODEL.WEIGHT -- TODO confirm against the version-guessing helper.
_V1_CFG = """
MODEL:
  WEIGHT: "/path/to/weight"
"""
class TestConfigVersioning(unittest.TestCase):
    """Tests for config version downgrade/upgrade and auto-upgrade on merge."""

    def test_upgrade_downgrade_consistency(self):
        cfg = get_cfg()
        # check that custom is preserved
        cfg.USER_CUSTOM = 1

        down = downgrade_config(cfg, to_version=0)
        up = upgrade_config(down)
        # assertEqual reports both values on failure, unlike assertTrue(a == b).
        self.assertEqual(up, cfg)

    def _merge_cfg_str(self, cfg, merge_str):
        """Merge the YAML text ``merge_str`` into ``cfg`` via a temporary file.

        The text is written to a ``.yaml`` temp file (created with
        ``delete=False`` so it survives being closed), the file is closed, and
        ``cfg.merge_from_file`` reads it. The file is always removed afterwards.

        Returns:
            ``cfg`` itself, which is modified in place by ``merge_from_file``.
        """
        f = tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False)
        try:
            # ``with`` closes the handle even if write() raises, so os.remove
            # below never targets a still-open file (which can fail on Windows).
            with f:
                f.write(merge_str)
            cfg.merge_from_file(f.name)
        finally:
            os.remove(f.name)
        return cfg

    def test_auto_upgrade(self):
        cfg = get_cfg()
        latest_ver = cfg.VERSION
        # A user-added key, set before merging an older-version config; it is
        # expected to be kept, as test_upgrade_downgrade_consistency requires.
        cfg.USER_CUSTOM = 1

        self._merge_cfg_str(cfg, _V0_CFG)

        # The v0 RPN head name is expected under its upgraded key.
        self.assertEqual(cfg.MODEL.RPN.HEAD_NAME, "TEST")
        self.assertEqual(cfg.VERSION, latest_ver)
        self.assertEqual(cfg.USER_CUSTOM, 1)

    def test_guess_v1(self):
        cfg = get_cfg()
        latest_ver = cfg.VERSION
        # _V1_CFG carries no VERSION key; merging it must still leave cfg at
        # the latest version.
        self._merge_cfg_str(cfg, _V1_CFG)
        self.assertEqual(cfg.VERSION, latest_ver)
class _TestClassA(torch.nn.Module):
    """Base helper module: a ``@configurable`` ``__init__`` plus a ``from_config`` classmethod."""

    # @configurable is what lets the tests call _TestClassA(cfg) / _TestClassA(cfg=cfg);
    # the bare signature below would otherwise reject a lone config argument.
    @configurable
    def __init__(self, arg1, arg2, arg3=3):
        super().__init__()
        self.arg1 = arg1
        self.arg2 = arg2
        self.arg3 = arg3
        # Checking the exact values lets the tests detect mis-routed arguments:
        # a wrong value (e.g. ARG2=999 in the config) raises AssertionError.
        assert arg1 == 1
        assert arg2 == 2
        assert arg3 == 3

    @classmethod
    def from_config(cls, cfg):
        # Only arg1/arg2 are read from the config; arg3 is not.
        args = {"arg1": cfg.ARG1, "arg2": cfg.ARG2}
        return args
class _TestClassB(_TestClassA):
    """Helper whose ``from_config`` takes an extra positional ``input_shape``."""

    # The __init__ docstring below is inspected by testPatchedAttr, so the
    # decorated __init__ must keep it.
    @configurable
    def __init__(self, input_shape, arg1, arg2, arg3=3):
        """
        Doc of _TestClassB
        """
        assert input_shape == "shape"
        super().__init__(arg1, arg2, arg3)

    @classmethod
    def from_config(cls, cfg, input_shape):  # test extra positional arg in from_config
        args = {"arg1": cfg.ARG1, "arg2": cfg.ARG2}
        # input_shape is not in the config; it is passed through to __init__.
        args["input_shape"] = input_shape
        return args
class _LegacySubClass(_TestClassB):
    # an old subclass written in cfg style: its own __init__ is not decorated;
    # it takes the config positionally and hands it straight to the parent
    # __init__. arg4 is accepted but never used.
    def __init__(self, cfg, input_shape, arg4=4):
        super().__init__(cfg, input_shape)
        expected_values = (("arg1", 1), ("arg2", 2), ("arg3", 3))
        for attr_name, wanted in expected_values:
            assert getattr(self, attr_name) == wanted
class _NewSubClassNewInit(_TestClassB):
    # test new subclass with a new __init__
    # @configurable is required: the tests call this as (cfg, input_shape="shape").
    # With the bare signature, cfg would bind to input_shape and the keyword
    # input_shape would then collide with it (TypeError).
    @configurable
    def __init__(self, input_shape, arg4=4, **kwargs):
        # Remaining keyword args (arg1/arg2/arg3) are forwarded to the parent.
        super().__init__(input_shape, **kwargs)
        assert self.arg1 == 1
        assert self.arg2 == 2
        assert self.arg3 == 3
class _LegacySubClassNotCfg(_TestClassB):
    # an old subclass written in cfg style, but argument is not called "cfg":
    # the config arrives as `config` and is forwarded positionally to the
    # parent __init__ unchanged.
    def __init__(self, config, input_shape):
        super().__init__(config, input_shape)
        for attr_name, wanted in (("arg1", 1), ("arg2", 2), ("arg3", 3)):
            assert getattr(self, attr_name) == wanted
class _TestClassC(_TestClassB):
    # Inherits _TestClassB.__init__; only from_config is overridden.
    @classmethod
    def from_config(cls, cfg, input_shape, **kwargs):  # test extra kwarg overwrite
        args = {"arg1": cfg.ARG1, "arg2": cfg.ARG2}
        args["input_shape"] = input_shape
        # Explicit keyword args (e.g. arg2=2) take precedence over the values
        # read from cfg above.
        args.update(kwargs)
        return args
class _TestClassD(_TestClassA):
    # _TestClassA.from_config does not have input_shape args.
    # Test whether input_shape will be forwarded to __init__
    # The annotations are inspected by testPatchedAttr, so they must survive
    # the decoration.
    @configurable
    def __init__(self, input_shape: ShapeSpec, arg1: int, arg2, arg3=3):
        assert input_shape == "shape"
        super().__init__(arg1, arg2, arg3)
class TestConfigurable(unittest.TestCase):
    """Tests for constructing ``@configurable`` helper classes with and without a config."""

    def testInitWithArgs(self):
        # Plain construction with explicit arguments (no config) must keep working.
        _ = _TestClassA(arg1=1, arg2=2, arg3=3)
        _ = _TestClassB("shape", arg1=1, arg2=2)
        _ = _TestClassC("shape", arg1=1, arg2=2)
        _ = _TestClassD("shape", arg1=1, arg2=2, arg3=3)

    def testPatchedAttr(self):
        # __init__ must still expose the docstring and annotations written in
        # the class bodies.
        self.assertIn("Doc", _TestClassB.__init__.__doc__)
        self.assertEqual(_TestClassD.__init__.__annotations__["arg1"], int)

    def testInitWithCfg(self):
        cfg = get_cfg()
        cfg.ARG1 = 1
        cfg.ARG2 = 2
        cfg.ARG3 = 3
        _ = _TestClassA(cfg)
        _ = _TestClassB(cfg, input_shape="shape")
        _ = _TestClassC(cfg, input_shape="shape")
        _ = _TestClassD(cfg, input_shape="shape")
        _ = _LegacySubClass(cfg, input_shape="shape")
        _ = _NewSubClassNewInit(cfg, input_shape="shape")
        _ = _LegacySubClassNotCfg(cfg, input_shape="shape")
        with self.assertRaises(TypeError):
            # disallow forwarding positional args to __init__ since it's prone to errors
            _ = _TestClassD(cfg, "shape")

        # call with kwargs instead
        _ = _TestClassA(cfg=cfg)
        _ = _TestClassB(cfg=cfg, input_shape="shape")
        _ = _TestClassC(cfg=cfg, input_shape="shape")
        _ = _TestClassD(cfg=cfg, input_shape="shape")
        _ = _LegacySubClass(cfg=cfg, input_shape="shape")
        _ = _NewSubClassNewInit(cfg=cfg, input_shape="shape")
        _ = _LegacySubClassNotCfg(config=cfg, input_shape="shape")

    def testInitWithCfgOverwrite(self):
        cfg = get_cfg()
        cfg.ARG1 = 1
        cfg.ARG2 = 999  # wrong config
        with self.assertRaises(AssertionError):
            _ = _TestClassA(cfg, arg3=3)

        # overwrite arg2 with correct config later:
        _ = _TestClassA(cfg, arg2=2, arg3=3)
        _ = _TestClassB(cfg, input_shape="shape", arg2=2, arg3=3)
        _ = _TestClassC(cfg, input_shape="shape", arg2=2, arg3=3)
        _ = _TestClassD(cfg, input_shape="shape", arg2=2, arg3=3)

        # call with kwargs cfg=cfg instead
        _ = _TestClassA(cfg=cfg, arg2=2, arg3=3)
        _ = _TestClassB(cfg=cfg, input_shape="shape", arg2=2, arg3=3)
        _ = _TestClassC(cfg=cfg, input_shape="shape", arg2=2, arg3=3)
        _ = _TestClassD(cfg=cfg, input_shape="shape", arg2=2, arg3=3)

    def testInitWithCfgWrongArgs(self):
        cfg = get_cfg()
        cfg.ARG1 = 1
        cfg.ARG2 = 2
        # `not_exist` matches no from_config or __init__ parameter and must be rejected.
        with self.assertRaises(TypeError):
            _ = _TestClassB(cfg, "shape", not_exist=1)
        with self.assertRaises(TypeError):
            _ = _TestClassC(cfg, "shape", not_exist=1)
        with self.assertRaises(TypeError):
            _ = _TestClassD(cfg, "shape", not_exist=1)

    def testBadClass(self):
        # No from_config at all: expected to raise AttributeError.
        class _BadClass1:
            @configurable
            def __init__(self, a=1, b=2):
                pass

        # from_config is a plain method, not a classmethod: expected TypeError.
        class _BadClass2:
            @configurable
            def __init__(self, a=1, b=2):
                pass

            def from_config(self, cfg):  # noqa
                pass

        # from_config is a classmethod but its first argument is misnamed:
        # expected TypeError.
        class _BadClass3:
            @configurable
            def __init__(self, a=1, b=2):
                pass

            # bad name: must be cfg
            @classmethod
            def from_config(cls, config):  # noqa
                pass

        with self.assertRaises(AttributeError):
            _ = _BadClass1(a=1)

        with self.assertRaises(TypeError):
            _ = _BadClass2(a=1)

        with self.assertRaises(TypeError):
            _ = _BadClass3(get_cfg())