Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved. | |
from typing import Any, Type | |
from mmpretrain.registry import MODELS | |
class ExtendModule:
    """Combine a base language model with an adapter.

    Instantiating this class does not return an ``ExtendModule``. Instead it
    obtains the base module (building it from a config dict if necessary),
    grafts the adapter class onto that instance as a mixin, and returns
    whatever ``adapter_module.extend_init(base, **kwargs)`` returns.

    Args:
        base (object): Base module. Either an already-built instance that
            represents a language model, or a dict that ``MODELS.build``
            can turn into one.
        adapter (dict): Dict used to build the adapter. Its ``'type'`` entry
            is looked up with ``MODELS.get`` to obtain the adapter class; the
            remaining entries are forwarded as keyword arguments to that
            class's ``extend_init``. The caller's dict is not modified.
    """

    def __new__(cls, base: object, adapter: dict):
        if isinstance(base, dict):
            base = MODELS.build(base)

        # Work on a shallow copy so popping 'type' does not mutate the
        # caller's config (a second build from the same dict would otherwise
        # fail with KeyError).
        adapter_cfg = dict(adapter)
        adapter_module = MODELS.get(adapter_cfg.pop('type'))
        cls.extend_instance(base, adapter_module)
        return adapter_module.extend_init(base, **adapter_cfg)

    @classmethod
    def extend_instance(cls, base: object, mixin: Type[Any]) -> None:
        """Apply a mixin to an already-created instance, in place.

        The instance's class is replaced by a newly created class with bases
        ``(mixin, original_class)`` that keeps the original class name.

        Args:
            base (object): Base module instance; its ``__class__`` is
                reassigned.
            mixin (Type[Any]): Adapter class type to mix in.
        """
        base_cls = base.__class__
        # mixin needs to go first in the bases so its methods (e.g. forward())
        # take precedence in the MRO and can reach the base via super().
        base.__class__ = type(base_cls.__name__, (mixin, base_cls), {})
def getattr_recursive(obj, att):
    """Resolve a dotted attribute path on ``obj``.

    ``getattr_recursive(obj, 'a.b.c')`` is equivalent to ``obj.a.b.c``.
    An empty path returns ``obj`` itself, and a single trailing ``'.'`` is
    ignored (``'a.b.'`` resolves like ``'a.b'``).
    """
    segments = att.split('.')
    # The walk stops when the unconsumed remainder is empty, which only
    # happens for the final segment (path '' or a path ending in '.').
    if segments[-1] == '':
        segments.pop()
    current = obj
    for name in segments:
        current = getattr(current, name)
    return current
def setattr_recursive(obj, att, val):
    """Assign ``val`` to a dotted attribute path on ``obj``.

    ``setattr_recursive(obj, 'a.b.c', val)`` is equivalent to
    ``obj.a.b.c = val``: everything before the last ``'.'`` is resolved with
    ``getattr_recursive`` and the final segment is set on the result. A path
    without any ``'.'`` is set directly on ``obj``.
    """
    parent_path, sep, leaf = att.rpartition('.')
    target = getattr_recursive(obj, parent_path) if sep else obj
    setattr(target, leaf, val)