Spaces:
Running
Running
import importlib | |
from functools import partial | |
from pathlib import Path | |
from typing import Any, Callable, Dict, List | |
# Category names known to the registry. Each entry gets its own name -> module
# table in ``Registry.mapping`` and a pair of dynamic accessors
# (``register_<category>`` / ``get_<category>``) resolved in
# ``Registry.__getattr__``.
CATEGORIES: List[str] = [
    "prompt",
    "llm",
    "node",
    "worker",
    "tool",
    "encoder",
    "connector",
    "component",
]
class Registry: | |
"""Class for module registration and retrieval.""" | |
def __init__(self): | |
# Initializes a mapping for different categories of modules. | |
self.mapping = {key: {} for key in CATEGORIES} | |
def __getattr__(self, name: str) -> Callable: | |
if name.startswith(("register_", "get_")): | |
prefix, category = name.split("_", 1) | |
if category in CATEGORIES: | |
if prefix == "register": | |
return partial(self.register, category) | |
elif prefix == "get": | |
return partial(self.get, category) | |
raise AttributeError( | |
f"'{self.__class__.__name__}' object has no attribute '{name}'" | |
) | |
def _register(self, category: str, name: str = None): | |
""" | |
Registers a module under a specific category. | |
:param category: The category to register the module under. | |
:param name: The name to register the module as. | |
""" | |
def wrap(module): | |
nonlocal name | |
name = name or module.__name__ | |
if name in self.mapping[category]: | |
raise ValueError( | |
f"Module {name} [{self.mapping[category].get(name)}] already registered in category {category}. Please use a different class name." | |
) | |
self.mapping.setdefault(category, {})[name] = module | |
return module | |
return wrap | |
def _get(self, category: str, name: str): | |
""" | |
Retrieves a module from a specified category. | |
:param category: The category to search in. | |
:param name: The name of the module to retrieve. | |
:raises KeyError: If the module is not found. | |
""" | |
try: | |
return self.mapping[category][name] | |
except KeyError: | |
raise KeyError(f"Module {name} not found in category {category}") | |
def register(self, category: str, name: str = None): | |
""" | |
Registers a module under a general category. | |
:param category: The category to register the module under. | |
:param name: Optional name to register the module as. | |
""" | |
return self._register(category, name) | |
def get(self, category: str, name: str): | |
""" | |
Retrieves a module from a general category. | |
:param category: The category to search in. | |
:param name: The name of the module to retrieve. | |
""" | |
return self._get(category, name) | |
def import_module(self, project_path: List[str] | str = None): | |
"""Import modules from default paths and optional project paths. | |
Args: | |
project_path: Optional path or list of paths to import modules from | |
""" | |
# Handle default paths | |
root_path = Path(__file__).parents[1] | |
default_path = [ | |
root_path.joinpath("models"), | |
root_path.joinpath("tool_system"), | |
root_path.joinpath("services"), | |
root_path.joinpath("memories"), | |
root_path.joinpath("advanced_components"), | |
root_path.joinpath("clients"), | |
] | |
for path in default_path: | |
for module in path.rglob("*.[ps][yo]"): | |
if module.name == "workflow.py": | |
continue | |
module = str(module) | |
if "__init__" in module or "base.py" in module or "entry.py" in module: | |
continue | |
module = "omagent_core" + module.rsplit("omagent_core", 1)[1].rsplit( | |
".", 1 | |
)[0].replace("/", ".") | |
importlib.import_module(module) | |
# Handle project paths | |
if project_path: | |
if isinstance(project_path, (str, Path)): | |
project_path = [project_path] | |
for path in project_path: | |
path = Path(path).absolute() | |
project_root = path.parent | |
for module in path.rglob("*.[ps][yo]"): | |
module = str(module) | |
if "__init__" in module: | |
continue | |
module = ( | |
module.replace(str(project_root) + "/", "") | |
.rsplit(".", 1)[0] | |
.replace("/", ".") | |
) | |
importlib.import_module(module) | |
# Instantiate registry: the shared module-level instance.
registry = Registry()

if __name__ == "__main__":
    # Minimal smoke demo: register a class under the "node" category, then
    # look it up again. The class must be registered first, otherwise the
    # lookup below raises KeyError.
    @registry.register_node()
    class TestNode:
        name: str = "TestNode"

    print(registry.get_node("TestNode"))