Spaces:
Runtime error
Runtime error
import abc | |
import dataclasses | |
import functools | |
import inspect | |
from dataclasses import Field, fields | |
from typing import Any, Callable, Dict, Optional, Tuple | |
from enum import Enum | |
from marshmallow import ValidationError | |
from dataclasses_json.utils import CatchAllVar | |
# Both halves produced when incoming parameters are split against the fields
# of a dataclass map a parameter name to its value; two names are kept so that
# signatures say which half they carry.
KnownParameters, UnknownParameters = Dict[str, Any], Dict[str, Any]
class _UndefinedParameterAction(abc.ABC):
    """
    Base class for the strategies that decide what to do with parameters
    that do not match any field of a dataclass.

    All hooks are static methods invoked on the strategy class itself (the
    ``Undefined`` enum stores the strategy classes as its member values).
    """

    @staticmethod
    @abc.abstractmethod
    def handle_from_dict(cls, kvs: Dict[Any, Any]) -> Dict[str, Any]:
        """
        Return the parameters to initialize the class with.
        """
        pass

    @staticmethod
    def handle_to_dict(obj, kvs: Dict[Any, Any]) -> Dict[Any, Any]:
        """
        Return the parameters that will be written to the output dict.

        The default implementation returns ``kvs`` unchanged.
        """
        return kvs

    @staticmethod
    def handle_dump(obj) -> Dict[Any, Any]:
        """
        Return the parameters that will be added to the schema dump.

        The default implementation adds nothing.
        """
        return {}

    @staticmethod
    def create_init(obj) -> Callable:
        """
        Return the ``__init__`` to use for the dataclass ``obj``.

        The default implementation keeps the existing ``obj.__init__``.
        """
        return obj.__init__

    @staticmethod
    def _separate_defined_undefined_kvs(cls, kvs: Dict) -> \
            Tuple[KnownParameters, UnknownParameters]:
        """
        Split ``kvs`` into two dictionaries: the defined parameters (keys
        that are the name of a field of the dataclass ``cls``) and the
        undefined parameters (all other keys).

        Both dictionaries keep the relative order of ``kvs``.
        """
        # A set makes each membership test O(1) instead of a list scan.
        field_names = {field.name for field in fields(cls)}
        known_given_parameters: KnownParameters = {}
        unknown_given_parameters: UnknownParameters = {}
        # Single pass over kvs instead of one pass per output dict.
        for key, value in kvs.items():
            if key in field_names:
                known_given_parameters[key] = value
            else:
                unknown_given_parameters[key] = value
        return known_given_parameters, unknown_given_parameters
class _RaiseUndefinedParameters(_UndefinedParameterAction):
    """
    This action raises UndefinedParameterError if it encounters an undefined
    parameter during initialization.
    """

    @staticmethod
    def handle_from_dict(cls, kvs: Dict) -> Dict[str, Any]:
        """
        Return the entries of ``kvs`` whose keys are fields of ``cls``.

        Raises:
            UndefinedParameterError: if any key of ``kvs`` is not a field
                of ``cls``.
        """
        known, unknown = \
            _UndefinedParameterAction._separate_defined_undefined_kvs(
                cls=cls, kvs=kvs)
        if unknown:
            raise UndefinedParameterError(
                f"Received undefined initialization arguments {unknown}")
        return known
# Annotation for the single catch-all field of a dataclass that collects
# undefined parameters; _CatchAllUndefinedParameters locates that field by
# comparing each field's type against Optional[CatchAllVar].
CatchAll = Optional[
    CatchAllVar
]
class _IgnoreUndefinedParameters(_UndefinedParameterAction):
    """
    This action does nothing when it encounters undefined parameters.
    The undefined parameters can not be retrieved after the class has been
    created.
    """

    @staticmethod
    def handle_from_dict(cls, kvs: Dict) -> Dict[str, Any]:
        """Return only the entries of ``kvs`` whose keys are fields of cls."""
        known_given_parameters, _ = \
            _UndefinedParameterAction._separate_defined_undefined_kvs(
                cls=cls, kvs=kvs)
        return known_given_parameters

    @staticmethod
    def create_init(obj) -> Callable:
        """
        Build a replacement ``__init__`` for the dataclass ``obj``.

        The replacement drops keyword arguments that are not fields of
        ``obj`` and positional arguments beyond the slots still free after
        the known keywords are counted, then calls the original ``__init__``
        with keyword arguments only.
        """
        original_init = obj.__init__
        init_signature = inspect.signature(original_init)

        # wraps() copies name/doc and sets __wrapped__, so inspect.signature
        # on the replacement still reports the real field signature (this
        # matters if create_init is ever applied to the same class again).
        @functools.wraps(original_init)
        def _ignore_init(self, *args, **kwargs):
            known_kwargs, _ = \
                _UndefinedParameterAction._separate_defined_undefined_kvs(
                    obj, kwargs)
            num_params_takeable = len(
                init_signature.parameters) - 1  # don't count self
            # Clamp at 0: a negative bound would make the slice below drop
            # arguments from the end instead of keeping none.
            num_args_takeable = max(
                0, num_params_takeable - len(known_kwargs))
            args = args[:num_args_takeable]
            bound_parameters = init_signature.bind_partial(self, *args,
                                                           **known_kwargs)
            bound_parameters.apply_defaults()

            arguments = bound_parameters.arguments
            arguments.pop("self", None)
            final_parameters = \
                _IgnoreUndefinedParameters.handle_from_dict(obj, arguments)
            original_init(self, **final_parameters)

        return _ignore_init
class _CatchAllUndefinedParameters(_UndefinedParameterAction):
    """
    This class allows adding a field of type utils.CatchAll which acts as a
    dictionary into which all undefined parameters will be written.
    These parameters are not affected by LetterCase.
    If no undefined parameters are given and the catch-all field itself is
    not supplied, this dictionary will be empty.
    """

    class _SentinelNoDefault:
        # Marker returned by _get_default when the catch-all field has
        # neither a default nor a default_factory.
        pass

    @staticmethod
    def handle_from_dict(cls, kvs: Dict) -> Dict[str, Any]:
        """
        Return the known parameters of ``kvs`` with every undefined
        parameter collected into the catch-all field of ``cls``.

        Raises:
            UndefinedParameterError: if the catch-all field is supplied with
                a value that is neither its default nor a dict, or if ``cls``
                does not have exactly one catch-all field.
        """
        known, unknown = _UndefinedParameterAction \
            ._separate_defined_undefined_kvs(cls=cls, kvs=kvs)
        catch_all_field = _CatchAllUndefinedParameters._get_catch_all_field(
            cls=cls)

        value_to_write: Any
        if catch_all_field.name in known:
            given_value = known[catch_all_field.name]
            already_parsed = isinstance(given_value, dict)
            default_value = _CatchAllUndefinedParameters._get_default(
                catch_all_field=catch_all_field)
            received_default = default_value == given_value

            if received_default and len(unknown) == 0:
                value_to_write = default_value
            elif received_default and len(unknown) > 0:
                value_to_write = unknown
            elif already_parsed:
                # Did not receive default
                value_to_write = given_value
                if len(unknown) > 0:
                    # Merge into a copy so the dict supplied by the caller
                    # is not mutated behind their back.
                    value_to_write = dict(given_value)
                    value_to_write.update(unknown)
            else:
                error_message = f"Received input field with " \
                                f"same name as catch-all field: " \
                                f"'{catch_all_field.name}': " \
                                f"'{known[catch_all_field.name]}'"
                raise UndefinedParameterError(error_message)
        else:
            value_to_write = unknown

        known[catch_all_field.name] = value_to_write
        return known

    @staticmethod
    def _get_default(catch_all_field: Field) -> Any:
        """
        Return the default of ``catch_all_field``.

        This is its ``default`` if one is set, otherwise a freshly built
        value from its ``default_factory`` if one is set, otherwise the
        ``_SentinelNoDefault`` class.
        """
        if catch_all_field.default is not dataclasses.MISSING:
            return catch_all_field.default
        # access to the default factory currently causes
        # a false-positive mypy error (16. Dec 2019):
        # https://github.com/python/mypy/issues/6910
        default_factory = catch_all_field.default_factory  # type: ignore
        if default_factory is not dataclasses.MISSING:
            # This might be unwanted if the default factory constructs
            # something expensive,
            # because we have to construct it again just for this test
            return default_factory()
        return _CatchAllUndefinedParameters._SentinelNoDefault

    @staticmethod
    def handle_to_dict(obj, kvs: Dict[Any, Any]) -> Dict[Any, Any]:
        """
        Remove the catch-all entry from ``kvs`` and, if its value is a dict,
        merge those items into ``kvs`` at top level.

        ``kvs`` is modified in place and returned.
        """
        catch_all_field = \
            _CatchAllUndefinedParameters._get_catch_all_field(obj)
        undefined_parameters = kvs.pop(catch_all_field.name)
        if isinstance(undefined_parameters, dict):
            kvs.update(
                undefined_parameters)  # If desired handle letter case here
        return kvs

    @staticmethod
    def handle_dump(obj) -> Dict[Any, Any]:
        """Return the current value of ``obj``'s catch-all field."""
        catch_all_field = _CatchAllUndefinedParameters._get_catch_all_field(
            cls=obj)
        return getattr(obj, catch_all_field.name)

    @staticmethod
    def create_init(obj) -> Callable:
        """
        Build a replacement ``__init__`` for the dataclass ``obj``.

        Surplus positional arguments are stored under ``_UNKNOWN<i>`` keys
        and, together with undefined keyword arguments, routed through
        ``handle_from_dict`` into the catch-all field before the original
        ``__init__`` is called with keyword arguments only.
        """
        original_init = obj.__init__
        init_signature = inspect.signature(original_init)

        # wraps() copies name/doc and sets __wrapped__, so inspect.signature
        # on the replacement still reports the real field signature.
        @functools.wraps(original_init)
        def _catch_all_init(self, *args, **kwargs):
            known_kwargs, unknown_kwargs = \
                _CatchAllUndefinedParameters._separate_defined_undefined_kvs(
                    obj, kwargs)
            num_params_takeable = len(
                init_signature.parameters) - 1  # don't count self
            if _CatchAllUndefinedParameters._get_catch_all_field(
                    obj).name not in known_kwargs:
                # The catch-all field is filled from the leftovers, so it
                # must not absorb a positional argument itself.
                num_params_takeable -= 1
            # Clamp at 0: a negative bound would make the slices below split
            # the positional arguments from the wrong end.
            num_args_takeable = max(
                0, num_params_takeable - len(known_kwargs))
            args, unknown_args = args[:num_args_takeable], \
                args[num_args_takeable:]
            bound_parameters = init_signature.bind_partial(self, *args,
                                                           **known_kwargs)

            unknown_args = {f"_UNKNOWN{i}": v for i, v in
                            enumerate(unknown_args)}
            arguments = bound_parameters.arguments
            arguments.update(unknown_args)
            arguments.update(unknown_kwargs)
            arguments.pop("self", None)
            final_parameters = _CatchAllUndefinedParameters.handle_from_dict(
                obj, arguments)
            original_init(self, **final_parameters)

        return _catch_all_init

    @staticmethod
    def _get_catch_all_field(cls) -> Field:
        """
        Return the single field of ``cls`` whose type is
        ``Optional[CatchAllVar]``.

        Raises:
            UndefinedParameterError: if ``cls`` has no such field or more
                than one.
        """
        catch_all_fields = [f for f in fields(cls)
                            if f.type == Optional[CatchAllVar]]
        number_of_catch_all_fields = len(catch_all_fields)
        if number_of_catch_all_fields == 0:
            raise UndefinedParameterError(
                "No field of type dataclasses_json.CatchAll defined")
        if number_of_catch_all_fields > 1:
            raise UndefinedParameterError(
                f"Multiple catch-all fields supplied: "
                f"{number_of_catch_all_fields}.")
        return catch_all_fields[0]
class Undefined(Enum):
    """
    Select the behavior when an undefined parameter (one that matches no
    field of the dataclass) is encountered during class initialization.

    Each member's value is the action class implementing that behavior.
    """

    # Write undefined parameters into the class's CatchAll field.
    INCLUDE = _CatchAllUndefinedParameters
    # Raise UndefinedParameterError on any undefined parameter.
    RAISE = _RaiseUndefinedParameters
    # Drop undefined parameters silently.
    EXCLUDE = _IgnoreUndefinedParameters
class UndefinedParameterError(ValidationError):
    """
    Raised when something has gone wrong handling undefined parameters.

    It subclasses marshmallow's ValidationError, so handlers catching that
    type also catch this one.
    """