Evgeny Zhukov
Origin: https://github.com/ali-vilab/UniAnimate/commit/d7814fa44a0a1154524b92fce0e3133a2604d333
2ba4412
# Copyright 2021 Alibaba Group Holding Limited. All Rights Reserved.
# Registry class & build_from_config function partially modified from
# https://github.com/open-mmlab/mmcv/blob/master/mmcv/utils/registry.py
# Copyright 2018-2020 Open-MMLab. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import inspect
import warnings
def build_from_config(cfg, registry, **kwargs):
""" Default builder function.
Args:
cfg (dict): A dict which contains parameters passes to target class or function.
Must contains key 'type', indicates the target class or function name.
registry (Registry): An registry to search target class or function.
kwargs (dict, optional): Other params not in config dict.
Returns:
Target class object or object returned by invoking function.
Raises:
TypeError:
KeyError:
Exception:
"""
if not isinstance(cfg, dict):
raise TypeError(f"config must be type dict, got {type(cfg)}")
if "type" not in cfg:
raise KeyError(f"config must contain key type, got {cfg}")
if not isinstance(registry, Registry):
raise TypeError(f"registry must be type Registry, got {type(registry)}")
cfg = copy.deepcopy(cfg)
req_type = cfg.pop("type")
req_type_entry = req_type
if isinstance(req_type, str):
req_type_entry = registry.get(req_type)
if req_type_entry is None:
try:
print(f"For Windows users, we explicitly import registry function {req_type} !!!")
from tools.inferences.inference_unianimate_entrance import inference_unianimate_entrance
from tools.inferences.inference_unianimate_long_entrance import inference_unianimate_long_entrance
from tools.modules.diffusions.diffusion_ddim import DiffusionDDIM
from tools.modules.diffusions.diffusion_ddim import DiffusionDDIMLong
from tools.modules.autoencoder import AutoencoderKL
from tools.modules.clip_embedder import FrozenOpenCLIPTextVisualEmbedder
from tools.modules.unet.unet_unianimate import UNetSD_UniAnimate
req_type_entry = eval(req_type)
except:
raise KeyError(f"{req_type} not found in {registry.name} registry")
if kwargs is not None:
cfg.update(kwargs)
if inspect.isclass(req_type_entry):
try:
return req_type_entry(**cfg)
except Exception as e:
raise Exception(f"Failed to init class {req_type_entry}, with {e}")
elif inspect.isfunction(req_type_entry):
try:
return req_type_entry(**cfg)
except Exception as e:
raise Exception(f"Failed to invoke function {req_type_entry}, with {e}")
else:
raise TypeError(f"type must be str or class, got {type(req_type_entry)}")
class Registry(object):
""" A registry maps key to classes or functions.
Example:
>>> MODELS = Registry('MODELS')
>>> @MODELS.register_class()
>>> class ResNet(object):
>>> pass
>>> resnet = MODELS.build(dict(type="ResNet"))
>>>
>>> import torchvision
>>> @MODELS.register_function("InceptionV3")
>>> def get_inception_v3(pretrained=False, progress=True):
>>> return torchvision.models.inception_v3(pretrained=pretrained, progress=progress)
>>> inception_v3 = MODELS.build(dict(type='InceptionV3', pretrained=True))
Args:
name (str): Registry name.
build_func (func, None): Instance construct function. Default is build_from_config.
allow_types (tuple): Indicates how to construct the instance, by constructing class or invoking function.
"""
def __init__(self, name, build_func=None, allow_types=("class", "function")):
self.name = name
self.allow_types = allow_types
self.class_map = {}
self.func_map = {}
self.build_func = build_func or build_from_config
def get(self, req_type):
return self.class_map.get(req_type) or self.func_map.get(req_type)
def build(self, *args, **kwargs):
return self.build_func(*args, **kwargs, registry=self)
def register_class(self, name=None):
def _register(cls):
if not inspect.isclass(cls):
raise TypeError(f"Module must be type class, got {type(cls)}")
if "class" not in self.allow_types:
raise TypeError(f"Register {self.name} only allows type {self.allow_types}, got class")
module_name = name or cls.__name__
if module_name in self.class_map:
warnings.warn(f"Class {module_name} already registered by {self.class_map[module_name]}, "
f"will be replaced by {cls}")
self.class_map[module_name] = cls
return cls
return _register
def register_function(self, name=None):
def _register(func):
if not inspect.isfunction(func):
raise TypeError(f"Registry must be type function, got {type(func)}")
if "function" not in self.allow_types:
raise TypeError(f"Registry {self.name} only allows type {self.allow_types}, got function")
func_name = name or func.__name__
if func_name in self.class_map:
warnings.warn(f"Function {func_name} already registered by {self.func_map[func_name]}, "
f"will be replaced by {func}")
self.func_map[func_name] = func
return func
return _register
def _list(self):
keys = sorted(list(self.class_map.keys()) + list(self.func_map.keys()))
descriptions = []
for key in keys:
if key in self.class_map:
descriptions.append(f"{key}: {self.class_map[key]}")
else:
descriptions.append(
f"{key}: <function '{self.func_map[key].__module__}.{self.func_map[key].__name__}'>")
return "\n".join(descriptions)
def __repr__(self):
description = self._list()
description = '\n'.join(['\t' + s for s in description.split('\n')])
return f"{self.__class__.__name__} [{self.name}], \n" + description