|  | import difflib | 
					
						
						|  | import inspect | 
					
						
						|  | import json | 
					
						
						|  | import os | 
					
						
						|  | import pkgutil | 
					
						
						|  | from abc import abstractmethod | 
					
						
						|  | from copy import deepcopy | 
					
						
						|  | from typing import Dict, List, Optional, Union, final | 
					
						
						|  |  | 
					
						
						|  | from .dataclass import AbstractField, Dataclass, Field, InternalField, fields | 
					
						
						|  | from .logging_utils import get_logger | 
					
						
						|  | from .parsing_utils import ( | 
					
						
						|  | separate_inside_and_outside_square_brackets, | 
					
						
						|  | ) | 
					
						
						|  | from .settings_utils import get_settings | 
					
						
						|  | from .text_utils import camel_to_snake_case, is_camel_case | 
					
						
						|  | from .type_utils import issubtype | 
					
						
						|  | from .utils import artifacts_json_cache, save_json | 
					
						
						|  |  | 
					
						
						|  | logger = get_logger() | 
					
						
						|  | settings = get_settings() | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | class Artifactories: | 
					
						
						|  | def __new__(cls): | 
					
						
						|  | if not hasattr(cls, "instance"): | 
					
						
						|  | cls.instance = super().__new__(cls) | 
					
						
						|  | cls.instance.artifactories = [] | 
					
						
						|  |  | 
					
						
						|  | return cls.instance | 
					
						
						|  |  | 
					
						
						|  | def __iter__(self): | 
					
						
						|  | self._index = 0 | 
					
						
						|  | return self | 
					
						
						|  |  | 
					
						
						|  | def __next__(self): | 
					
						
						|  | while self._index < len(self.artifactories): | 
					
						
						|  | artifactory = self.artifactories[self._index] | 
					
						
						|  | self._index += 1 | 
					
						
						|  | if ( | 
					
						
						|  | settings.use_only_local_catalogs and not artifactory.is_local | 
					
						
						|  | ): | 
					
						
						|  | continue | 
					
						
						|  | return artifactory | 
					
						
						|  | raise StopIteration | 
					
						
						|  |  | 
					
						
						|  | def register(self, artifactory): | 
					
						
						|  | assert isinstance( | 
					
						
						|  | artifactory, Artifactory | 
					
						
						|  | ), "Artifactory must be an instance of Artifactory" | 
					
						
						|  | assert hasattr( | 
					
						
						|  | artifactory, "__contains__" | 
					
						
						|  | ), "Artifactory must have __contains__ method" | 
					
						
						|  | assert hasattr( | 
					
						
						|  | artifactory, "__getitem__" | 
					
						
						|  | ), "Artifactory must have __getitem__ method" | 
					
						
						|  | self.artifactories = [artifactory, *self.artifactories] | 
					
						
						|  |  | 
					
						
						|  | def unregister(self, artifactory): | 
					
						
						|  | assert isinstance( | 
					
						
						|  | artifactory, Artifactory | 
					
						
						|  | ), "Artifactory must be an instance of Artifactory" | 
					
						
						|  | assert hasattr( | 
					
						
						|  | artifactory, "__contains__" | 
					
						
						|  | ), "Artifactory must have __contains__ method" | 
					
						
						|  | assert hasattr( | 
					
						
						|  | artifactory, "__getitem__" | 
					
						
						|  | ), "Artifactory must have __getitem__ method" | 
					
						
						|  | self.artifactories.remove(artifactory) | 
					
						
						|  |  | 
					
						
						|  | def reset(self): | 
					
						
						|  | self.artifactories = [] | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def map_values_in_place(object, mapper): | 
					
						
						|  | if isinstance(object, dict): | 
					
						
						|  | for key, value in object.items(): | 
					
						
						|  | object[key] = mapper(value) | 
					
						
						|  | return object | 
					
						
						|  | if isinstance(object, list): | 
					
						
						|  | for i in range(len(object)): | 
					
						
						|  | object[i] = mapper(object[i]) | 
					
						
						|  | return object | 
					
						
						|  | return mapper(object) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def get_closest_artifact_type(type): | 
					
						
						|  | artifact_type_options = list(Artifact._class_register.keys()) | 
					
						
						|  | matches = difflib.get_close_matches(type, artifact_type_options) | 
					
						
						|  | if matches: | 
					
						
						|  | return matches[0] | 
					
						
						|  | return None | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | class UnrecognizedArtifactTypeError(ValueError): | 
					
						
						|  | def __init__(self, type) -> None: | 
					
						
						|  | maybe_class = "".join(word.capitalize() for word in type.split("_")) | 
					
						
						|  | message = f"'{type}' is not a recognized artifact 'type'. Make sure a the class defined this type (Probably called '{maybe_class}' or similar) is defined and/or imported anywhere in the code executed." | 
					
						
						|  | closest_artifact_type = get_closest_artifact_type(type) | 
					
						
						|  | if closest_artifact_type is not None: | 
					
						
						|  | message += "\n\n" f"Did you mean '{closest_artifact_type}'?" | 
					
						
						|  | super().__init__(message) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | class MissingArtifactTypeError(ValueError): | 
					
						
						|  | def __init__(self, dic) -> None: | 
					
						
						|  | message = ( | 
					
						
						|  | f"Missing 'type' parameter. Expected 'type' in artifact dict, got {dic}" | 
					
						
						|  | ) | 
					
						
						|  | super().__init__(message) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | class Artifact(Dataclass): | 
					
						
						|  | _class_register = {} | 
					
						
						|  |  | 
					
						
						|  | type: str = Field(default=None, final=True, init=False) | 
					
						
						|  | __description__: str = InternalField( | 
					
						
						|  | default=None, required=False, also_positional=False | 
					
						
						|  | ) | 
					
						
						|  | __tags__: Dict[str, str] = InternalField( | 
					
						
						|  | default_factory=dict, required=False, also_positional=False | 
					
						
						|  | ) | 
					
						
						|  | __id__: str = InternalField(default=None, required=False, also_positional=False) | 
					
						
						|  |  | 
					
						
						|  | @classmethod | 
					
						
						|  | def is_artifact_dict(cls, d): | 
					
						
						|  | return isinstance(d, dict) and "type" in d | 
					
						
						|  |  | 
					
						
						|  | @classmethod | 
					
						
						|  | def verify_artifact_dict(cls, d): | 
					
						
						|  | if not isinstance(d, dict): | 
					
						
						|  | raise ValueError( | 
					
						
						|  | f"Artifact dict <{d}> must be of type 'dict', got '{type(d)}'." | 
					
						
						|  | ) | 
					
						
						|  | if "type" not in d: | 
					
						
						|  | raise MissingArtifactTypeError(d) | 
					
						
						|  | if not cls.is_registered_type(d["type"]): | 
					
						
						|  | raise UnrecognizedArtifactTypeError(d["type"]) | 
					
						
						|  |  | 
					
						
						|  | @classmethod | 
					
						
						|  | def get_artifact_type(cls): | 
					
						
						|  | return camel_to_snake_case(cls.__name__) | 
					
						
						|  |  | 
					
						
						|  | @classmethod | 
					
						
						|  | def register_class(cls, artifact_class): | 
					
						
						|  | assert issubclass( | 
					
						
						|  | artifact_class, Artifact | 
					
						
						|  | ), f"Artifact class must be a subclass of Artifact, got '{artifact_class}'" | 
					
						
						|  | assert is_camel_case( | 
					
						
						|  | artifact_class.__name__ | 
					
						
						|  | ), f"Artifact class name must be legal camel case, got '{artifact_class.__name__}'" | 
					
						
						|  |  | 
					
						
						|  | snake_case_key = camel_to_snake_case(artifact_class.__name__) | 
					
						
						|  |  | 
					
						
						|  | if cls.is_registered_type(snake_case_key): | 
					
						
						|  | assert ( | 
					
						
						|  | str(cls._class_register[snake_case_key]) == str(artifact_class) | 
					
						
						|  | ), f"Artifact class name must be unique, '{snake_case_key}' already exists for {cls._class_register[snake_case_key]}. Cannot be overridden by {artifact_class}." | 
					
						
						|  |  | 
					
						
						|  | return snake_case_key | 
					
						
						|  |  | 
					
						
						|  | cls._class_register[snake_case_key] = artifact_class | 
					
						
						|  |  | 
					
						
						|  | return snake_case_key | 
					
						
						|  |  | 
					
						
						|  | def __init_subclass__(cls, **kwargs): | 
					
						
						|  | super().__init_subclass__(**kwargs) | 
					
						
						|  | cls.register_class(cls) | 
					
						
						|  |  | 
					
						
						|  | @classmethod | 
					
						
						|  | def is_artifact_file(cls, path): | 
					
						
						|  | if not os.path.exists(path) or not os.path.isfile(path): | 
					
						
						|  | return False | 
					
						
						|  | with open(path) as f: | 
					
						
						|  | d = json.load(f) | 
					
						
						|  | return cls.is_artifact_dict(d) | 
					
						
						|  |  | 
					
						
						|  | @classmethod | 
					
						
						|  | def is_registered_type(cls, type: str): | 
					
						
						|  | return type in cls._class_register | 
					
						
						|  |  | 
					
						
						|  | @classmethod | 
					
						
						|  | def is_registered_class_name(cls, class_name: str): | 
					
						
						|  | snake_case_key = camel_to_snake_case(class_name) | 
					
						
						|  | return cls.is_registered_type(snake_case_key) | 
					
						
						|  |  | 
					
						
						|  | @classmethod | 
					
						
						|  | def is_registered_class(cls, clz: object): | 
					
						
						|  | return clz in set(cls._class_register.values()) | 
					
						
						|  |  | 
					
						
						|  | @classmethod | 
					
						
						|  | def _recursive_load(cls, obj): | 
					
						
						|  | if isinstance(obj, dict): | 
					
						
						|  | new_d = {} | 
					
						
						|  | for key, value in obj.items(): | 
					
						
						|  | new_d[key] = cls._recursive_load(value) | 
					
						
						|  | obj = new_d | 
					
						
						|  | elif isinstance(obj, list): | 
					
						
						|  | obj = [cls._recursive_load(value) for value in obj] | 
					
						
						|  | else: | 
					
						
						|  | pass | 
					
						
						|  | if cls.is_artifact_dict(obj): | 
					
						
						|  | cls.verify_artifact_dict(obj) | 
					
						
						|  | return cls._class_register[obj.pop("type")](**obj) | 
					
						
						|  |  | 
					
						
						|  | return obj | 
					
						
						|  |  | 
					
						
						|  | @classmethod | 
					
						
						|  | def from_dict(cls, d, overwrite_args=None): | 
					
						
						|  | if overwrite_args is not None: | 
					
						
						|  | d = {**d, **overwrite_args} | 
					
						
						|  | cls.verify_artifact_dict(d) | 
					
						
						|  | return cls._recursive_load(d) | 
					
						
						|  |  | 
					
						
						|  | @classmethod | 
					
						
						|  | def load(cls, path, artifact_identifier=None, overwrite_args=None): | 
					
						
						|  | d = artifacts_json_cache(path) | 
					
						
						|  | new_artifact = cls.from_dict(d, overwrite_args=overwrite_args) | 
					
						
						|  | new_artifact.__id__ = artifact_identifier | 
					
						
						|  | return new_artifact | 
					
						
						|  |  | 
					
						
						|  | def prepare(self): | 
					
						
						|  | pass | 
					
						
						|  |  | 
					
						
						|  | def verify(self): | 
					
						
						|  | pass | 
					
						
						|  |  | 
					
						
						|  | @final | 
					
						
						|  | def __pre_init__(self, **kwargs): | 
					
						
						|  | self._init_dict = get_raw(kwargs) | 
					
						
						|  |  | 
					
						
						|  | @final | 
					
						
						|  | def __post_init__(self): | 
					
						
						|  | self.type = self.register_class(self.__class__) | 
					
						
						|  |  | 
					
						
						|  | for field in fields(self): | 
					
						
						|  | if issubtype( | 
					
						
						|  | field.type, Union[Artifact, List[Artifact], Dict[str, Artifact]] | 
					
						
						|  | ): | 
					
						
						|  | value = getattr(self, field.name) | 
					
						
						|  | value = map_values_in_place(value, maybe_recover_artifact) | 
					
						
						|  | setattr(self, field.name, value) | 
					
						
						|  |  | 
					
						
						|  | self.prepare() | 
					
						
						|  | self.verify() | 
					
						
						|  |  | 
					
						
						|  | def _to_raw_dict(self): | 
					
						
						|  | return {"type": self.type, **self._init_dict} | 
					
						
						|  |  | 
					
						
						|  | def save(self, path): | 
					
						
						|  | data = self.to_dict() | 
					
						
						|  | save_json(path, data) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def get_raw(obj): | 
					
						
						|  | if isinstance(obj, Artifact): | 
					
						
						|  | return obj._to_raw_dict() | 
					
						
						|  |  | 
					
						
						|  | if isinstance(obj, tuple) and hasattr(obj, "_fields"): | 
					
						
						|  | return type(obj)(*[get_raw(v) for v in obj]) | 
					
						
						|  |  | 
					
						
						|  | if isinstance(obj, (list, tuple)): | 
					
						
						|  | return type(obj)([get_raw(v) for v in obj]) | 
					
						
						|  |  | 
					
						
						|  | if isinstance(obj, dict): | 
					
						
						|  | return type(obj)({get_raw(k): get_raw(v) for k, v in obj.items()}) | 
					
						
						|  |  | 
					
						
						|  | return deepcopy(obj) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | class ArtifactList(list, Artifact): | 
					
						
						|  | def prepare(self): | 
					
						
						|  | for artifact in self: | 
					
						
						|  | artifact.prepare() | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | class Artifactory(Artifact): | 
					
						
						|  | is_local: bool = AbstractField() | 
					
						
						|  |  | 
					
						
						|  | @abstractmethod | 
					
						
						|  | def __contains__(self, name: str) -> bool: | 
					
						
						|  | pass | 
					
						
						|  |  | 
					
						
						|  | @abstractmethod | 
					
						
						|  | def __getitem__(self, name) -> Artifact: | 
					
						
						|  | pass | 
					
						
						|  |  | 
					
						
						|  | @abstractmethod | 
					
						
						|  | def get_with_overwrite(self, name, overwrite_args) -> Artifact: | 
					
						
						|  | pass | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | class UnitxtArtifactNotFoundError(Exception): | 
					
						
						|  | def __init__(self, name, artifactories): | 
					
						
						|  | self.name = name | 
					
						
						|  | self.artifactories = artifactories | 
					
						
						|  |  | 
					
						
						|  | def __str__(self): | 
					
						
						|  | msg = f"Artifact {self.name} does not exist, in artifactories:{self.artifactories}." | 
					
						
						|  | if settings.use_only_local_catalogs: | 
					
						
						|  | msg += f" Notice that unitxt.settings.use_only_local_catalogs is set to True, if you want to use remote catalogs set this settings or the environment variable {settings.use_only_local_catalogs_key}." | 
					
						
						|  | return f"Artifact {self.name} does not exist, in artifactories:{self.artifactories}" | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def fetch_artifact(name): | 
					
						
						|  | if Artifact.is_artifact_file(name): | 
					
						
						|  | return Artifact.load(name), None | 
					
						
						|  |  | 
					
						
						|  | artifactory, name, args = get_artifactory_name_and_args(name=name) | 
					
						
						|  |  | 
					
						
						|  | return artifactory.get_with_overwrite(name, overwrite_args=args), artifactory | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def get_artifactory_name_and_args( | 
					
						
						|  | name: str, artifactories: Optional[List[Artifactory]] = None | 
					
						
						|  | ): | 
					
						
						|  | name, args = separate_inside_and_outside_square_brackets(name) | 
					
						
						|  |  | 
					
						
						|  | if artifactories is None: | 
					
						
						|  | artifactories = list(Artifactories()) | 
					
						
						|  |  | 
					
						
						|  | for artifactory in artifactories: | 
					
						
						|  | if name in artifactory: | 
					
						
						|  | return artifactory, name, args | 
					
						
						|  |  | 
					
						
						|  | raise UnitxtArtifactNotFoundError(name, artifactories) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def verbosed_fetch_artifact(identifier): | 
					
						
						|  | artifact, artifactory = fetch_artifact(identifier) | 
					
						
						|  | logger.info(f"Artifact {identifier} is fetched from {artifactory}") | 
					
						
						|  | return artifact | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def reset_artifacts_json_cache(): | 
					
						
						|  | artifacts_json_cache.cache_clear() | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def maybe_recover_artifact(artifact): | 
					
						
						|  | if isinstance(artifact, str): | 
					
						
						|  | return verbosed_fetch_artifact(artifact) | 
					
						
						|  |  | 
					
						
						|  | return artifact | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def register_all_artifacts(path): | 
					
						
						|  | for loader, module_name, _is_pkg in pkgutil.walk_packages(path): | 
					
						
						|  | logger.info(__name__) | 
					
						
						|  | if module_name == __name__: | 
					
						
						|  | continue | 
					
						
						|  | logger.info(f"Loading {module_name}") | 
					
						
						|  |  | 
					
						
						|  | module = loader.find_module(module_name).load_module(module_name) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | for _name, obj in inspect.getmembers(module): | 
					
						
						|  |  | 
					
						
						|  | if inspect.isclass(obj): | 
					
						
						|  |  | 
					
						
						|  | if issubclass(obj, Artifact) and obj is not Artifact: | 
					
						
						|  | logger.info(obj) | 
					
						
						|  |  |