Spaces:
Running
on
Zero
Running
on
Zero
| # Copyright (c) Facebook, Inc. and its affiliates. | |
| import unittest | |
| import torch | |
| from detectron2.modeling.meta_arch import GeneralizedRCNN | |
| from detectron2.utils.registry import _convert_target_to_string, locate | |
class A:
    """Outer fixture class used only to host a nested class."""

    class B:
        """Nested fixture class; its ``__qualname__`` is ``A.B`` while its ``__name__`` is just ``B``."""
class TestLocate(unittest.TestCase):
    """Round-trip tests for ``_convert_target_to_string`` and ``locate``."""

    def _test_obj(self, obj):
        """Assert that ``obj`` survives a string round trip unchanged.

        ``obj`` is converted to a string with ``_convert_target_to_string``.
        ``locate`` must then resolve that string back to the very same
        object (identity, not mere equality).
        """
        target_str = _convert_target_to_string(obj)
        located = locate(target_str)
        self.assertIs(obj, located)

    def test_basic(self):
        # A class imported from a regular detectron2 module.
        self._test_obj(GeneralizedRCNN)

    def test_inside_class(self):
        # A nested class: requires using __qualname__ instead of __name__.
        self._test_obj(A.B)

    def test_builtin(self):
        # A builtin function and a builtin type, checked in this order.
        for builtin_obj in (len, dict):
            self._test_obj(builtin_obj)

    def test_pytorch_optim(self):
        # pydoc.locate does not work for it, so locate has to cover this case.
        self._test_obj(torch.optim.SGD)

    def test_failure(self):
        # An unresolvable name must raise ImportError.
        self.assertRaises(ImportError, locate, "asdf")

    def test_compress_target(self):
        from detectron2.data.transforms import RandomCrop

        compressed = _convert_target_to_string(RandomCrop)
        # The string should use the public package path; it should not
        # contain the private 'augmentation_impl' module.
        self.assertEqual(compressed, "detectron2.data.transforms.RandomCrop")
        self.assertIs(RandomCrop, locate(compressed))