# Provenance (apparently git metadata: author / commit message / short hash): 韩宇 / init / 1b7e88c
import importlib
from functools import partial
from pathlib import Path
from typing import Any, Callable, Dict, List
CATEGORIES = [
"prompt",
"llm",
"node",
"worker",
"tool",
"encoder",
"connector",
"component",
]
class Registry:
"""Class for module registration and retrieval."""
def __init__(self):
# Initializes a mapping for different categories of modules.
self.mapping = {key: {} for key in CATEGORIES}
def __getattr__(self, name: str) -> Callable:
if name.startswith(("register_", "get_")):
prefix, category = name.split("_", 1)
if category in CATEGORIES:
if prefix == "register":
return partial(self.register, category)
elif prefix == "get":
return partial(self.get, category)
raise AttributeError(
f"'{self.__class__.__name__}' object has no attribute '{name}'"
)
def _register(self, category: str, name: str = None):
"""
Registers a module under a specific category.
:param category: The category to register the module under.
:param name: The name to register the module as.
"""
def wrap(module):
nonlocal name
name = name or module.__name__
if name in self.mapping[category]:
raise ValueError(
f"Module {name} [{self.mapping[category].get(name)}] already registered in category {category}. Please use a different class name."
)
self.mapping.setdefault(category, {})[name] = module
return module
return wrap
def _get(self, category: str, name: str):
"""
Retrieves a module from a specified category.
:param category: The category to search in.
:param name: The name of the module to retrieve.
:raises KeyError: If the module is not found.
"""
try:
return self.mapping[category][name]
except KeyError:
raise KeyError(f"Module {name} not found in category {category}")
def register(self, category: str, name: str = None):
"""
Registers a module under a general category.
:param category: The category to register the module under.
:param name: Optional name to register the module as.
"""
return self._register(category, name)
def get(self, category: str, name: str):
"""
Retrieves a module from a general category.
:param category: The category to search in.
:param name: The name of the module to retrieve.
"""
return self._get(category, name)
def import_module(self, project_path: List[str] | str = None):
"""Import modules from default paths and optional project paths.
Args:
project_path: Optional path or list of paths to import modules from
"""
# Handle default paths
root_path = Path(__file__).parents[1]
default_path = [
root_path.joinpath("models"),
root_path.joinpath("tool_system"),
root_path.joinpath("services"),
root_path.joinpath("memories"),
root_path.joinpath("advanced_components"),
root_path.joinpath("clients"),
]
for path in default_path:
for module in path.rglob("*.[ps][yo]"):
if module.name == "workflow.py":
continue
module = str(module)
if "__init__" in module or "base.py" in module or "entry.py" in module:
continue
module = "omagent_core" + module.rsplit("omagent_core", 1)[1].rsplit(
".", 1
)[0].replace("/", ".")
importlib.import_module(module)
# Handle project paths
if project_path:
if isinstance(project_path, (str, Path)):
project_path = [project_path]
for path in project_path:
path = Path(path).absolute()
project_root = path.parent
for module in path.rglob("*.[ps][yo]"):
module = str(module)
if "__init__" in module:
continue
module = (
module.replace(str(project_root) + "/", "")
.rsplit(".", 1)[0]
.replace("/", ".")
)
importlib.import_module(module)
# Module-level registry instance, shared by everything that imports this module.
registry = Registry()

if __name__ == "__main__":
    # Manual smoke check: register a class through a dynamic helper, then
    # look it up by its class name.
    @registry.register_node()
    class TestNode:
        # A real class attribute; a bare annotation would define none.
        name = "TestNode"

    print(registry.get_node("TestNode"))