韩宇
init
1b7e88c
from pathlib import Path
from typing import Dict, List, Optional, Type
from threading import Thread
from omagent_core.engine.configuration.aaas_config import AaasConfig
import yaml
from omagent_core.engine.configuration.configuration import (TEMPLATE_CONFIG,
Configuration)
from omagent_core.engine.configuration.aaas_config import AAAS_TEMPLATE_CONFIG
from omagent_core.utils.registry import registry
from pydantic import BaseModel
import os
class Container:
    """Holds the connector and component instances used by an agent.

    Connectors and components are pydantic ``BaseModel`` instances stored by
    key. Dedicated slots (STM, LTM, callback, input) record which registered
    component plays each role. The container also carries the Conductor
    ``Configuration`` and the ``AaasConfig``.

    When the ``OMAGENT_MODE`` environment variable is ``"lite"``, connection
    checks are skipped, a missing config file is tolerated, and the STM slot
    is forced to ``"SharedMemSTM"``.
    """

    def __init__(self):
        # key -> connector instance
        self._connectors: Dict[str, BaseModel] = {}
        # key -> component instance
        self._components: Dict[str, BaseModel] = {}
        # Keys into self._components for each role; None until registered.
        self._stm_name: Optional[str] = None
        self._ltm_name: Optional[str] = None
        self._callback_name: Optional[str] = None
        self._input_name: Optional[str] = None
        self.conductor_config = Configuration()
        self.aaas_config = AaasConfig()

    def register_connector(
        self,
        connector: Type[BaseModel],
        name: Optional[str] = None,
        overwrite: bool = False,
        **kwargs,
    ) -> None:
        """Instantiate ``connector(**kwargs)`` and store it under ``name``.

        Args:
            connector: Connector class to instantiate.
            name: Key to store the instance under; defaults to the class name.
            overwrite: Replace an existing connector stored under the same key.
            **kwargs: Keyword arguments forwarded to the connector constructor.
        """
        if name is None:
            name = connector.__name__
        if name not in self._connectors or overwrite:
            self._connectors[name] = connector(**kwargs)

    def get_connector(self, name: str) -> BaseModel:
        """Return the connector stored under ``name``.

        Raises:
            KeyError: If no connector is stored under ``name``.
        """
        if name not in self._connectors:
            raise KeyError(f"There is no connector named '{name}' in container.")
        return self._connectors[name]

    def register_component(
        self,
        component: str | Type[BaseModel],
        name: Optional[str] = None,
        config: Optional[dict] = None,
        overwrite: bool = False,
    ) -> str:
        """Instantiate a component and store it in the container.

        Every field of the component class whose annotation is a ``BaseModel``
        subclass is treated as a connector dependency. The connector stored
        under the field's name is injected into the constructor kwargs. If no
        such connector exists yet, it is first registered from the registry,
        looked up by the field's annotated class name, with no constructor
        arguments.

        Args:
            component: Registry name of the component, or the component class.
            name: Key to store the instance under. Defaults to the registry
                name, or to the class name when a class is passed.
            config: Keyword arguments for the component constructor. The dict
                is copied and never modified.
            overwrite: Re-create the component even if the key is already used.

        Returns:
            The key the component is stored under.

        Raises:
            ValueError: If ``component`` is neither a known registry name nor a
                ``BaseModel`` subclass, or if a required connector class is not
                found in the registry.
        """
        if isinstance(component, str):
            component_cls = registry.get_component(component)
            component_name = component
            if not component_cls:
                raise ValueError(f"{component} not found in registry")
        elif isinstance(component, type) and issubclass(component, BaseModel):
            component_cls = component
            component_name = component.__name__
        else:
            raise ValueError(f"Invalid component type: {type(component)}")

        # Only the key we are actually going to store under decides whether
        # this is a duplicate. Otherwise we could return a key that was never
        # registered.
        key = name or component_name
        if key in self._components and not overwrite:
            return key

        # Work on a copy: never mutate the caller's dict or a shared default.
        init_kwargs = dict(config) if config else {}
        for field_name, cls_name in self._get_required_connectors(component_cls):
            if field_name not in self._connectors:
                connector_cls = registry.get_connector(cls_name)
                if not connector_cls:
                    raise ValueError(
                        f"Connector {cls_name} required by {component_name} "
                        f"not found in registry"
                    )
                self.register_connector(connector_cls, field_name)
            init_kwargs[field_name] = self._connectors[field_name]
        self._components[key] = component_cls(**init_kwargs)
        return key

    def get_component(self, component_name: str) -> BaseModel:
        """Return the component stored under ``component_name``.

        Raises:
            KeyError: If no component is stored under ``component_name``.
        """
        if component_name not in self._components:
            raise KeyError(
                f"There is no component named '{component_name}' in container. You need to register it first."
            )
        return self._components[component_name]

    def _get_required_connectors(self, cls: Type[BaseModel]) -> List[List[str]]:
        """Return ``[field_name, annotated_class_name]`` pairs for connector fields.

        Every field annotated with a ``BaseModel`` subclass counts as a
        connector dependency.
        """
        return [
            [field_name, field.annotation.__name__]
            for field_name, field in cls.model_fields.items()
            if isinstance(field.annotation, type)
            and issubclass(field.annotation, BaseModel)
        ]

    @property
    def components(self) -> Dict[str, BaseModel]:
        """The live key -> component mapping (not a copy)."""
        return self._components

    def register_stm(
        self,
        stm: str | Type[BaseModel],
        name: Optional[str] = None,
        config: Optional[dict] = None,
        overwrite: bool = False,
    ) -> None:
        """Register the short-term-memory component and mark it as the STM.

        In lite mode the key is always ``"SharedMemSTM"``.
        """
        if os.getenv("OMAGENT_MODE") == "lite":
            name = "SharedMemSTM"
        self._stm_name = self.register_component(stm, name, config, overwrite)

    @property
    def stm(self) -> BaseModel:
        """The registered STM component.

        In lite mode, ``SharedMemSTM`` is registered lazily on first access.

        Raises:
            ValueError: If no STM is registered and the mode is not lite.
        """
        if self._stm_name is None:
            if os.getenv("OMAGENT_MODE") == "lite":
                self.register_stm("SharedMemSTM")
                self._stm_name = "SharedMemSTM"
            else:
                raise ValueError(
                    "STM component is not registered. Please use register_stm to register."
                )
        return self.get_component(self._stm_name)

    def register_ltm(
        self,
        ltm: str | Type[BaseModel],
        name: Optional[str] = None,
        config: Optional[dict] = None,
        overwrite: bool = False,
    ) -> None:
        """Register the long-term-memory component and mark it as the LTM."""
        self._ltm_name = self.register_component(ltm, name, config, overwrite)

    @property
    def ltm(self) -> BaseModel:
        """The registered LTM component.

        Raises:
            ValueError: If no LTM has been registered.
        """
        if self._ltm_name is None:
            raise ValueError(
                "LTM component is not registered. Please use register_ltm to register."
            )
        return self.get_component(self._ltm_name)

    def register_callback(
        self,
        callback: str | Type[BaseModel],
        name: Optional[str] = None,
        config: Optional[dict] = None,
        overwrite: bool = False,
    ) -> None:
        """Register the callback component and mark it as the callback."""
        self._callback_name = self.register_component(
            callback, name, config, overwrite
        )

    @property
    def callback(self) -> BaseModel:
        """The registered callback component.

        Raises:
            ValueError: If no callback has been registered.
        """
        if self._callback_name is None:
            raise ValueError(
                "Callback component is not registered. Please use register_callback to register."
            )
        return self.get_component(self._callback_name)

    def register_input(
        self,
        input: str | Type[BaseModel],
        name: Optional[str] = None,
        config: Optional[dict] = None,
        overwrite: bool = False,
    ) -> None:
        """Register the input component and mark it as the input."""
        self._input_name = self.register_component(input, name, config, overwrite)

    @property
    def input(self) -> BaseModel:
        """The registered input component.

        Raises:
            ValueError: If no input has been registered.
        """
        if self._input_name is None:
            raise ValueError(
                "Input component is not registered. Please use register_input to register."
            )
        return self.get_component(self._input_name)

    def compile_config(
        self, output_path: Path, description: bool = True, env_var: bool = True
    ) -> dict:
        """Write a ``container.yaml`` template to ``output_path`` and return it.

        If ``output_path/container.yaml`` already exists, it is loaded and
        returned unchanged instead of being regenerated.

        Args:
            output_path: Directory to write ``container.yaml`` into (str or Path).
            description: Forwarded to each class's ``get_config_template``.
            env_var: Forwarded to each class's ``get_config_template``.

        Returns:
            The configuration dict that was written or loaded.
        """
        container_file = Path(output_path) / "container.yaml"
        if container_file.exists():
            print("container.yaml already exists, skip compiling")
            # NOTE(review): FullLoader is kept for compatibility. Use safe_load
            # if this file can ever come from an untrusted source.
            with open(container_file, "r", encoding="utf-8") as f:
                return yaml.load(f, Loader=yaml.FullLoader)

        config = {
            "conductor_config": TEMPLATE_CONFIG,
            "aaas_config": AAAS_TEMPLATE_CONFIG,
            "connectors": {},
            "components": {},
        }
        exclude_fields = [
            "_parent",
            "component_stm",
            "component_ltm",
            "component_callback",
            "component_input",
        ]
        for name, connector in self._connectors.items():
            config["connectors"][name] = connector.__class__.get_config_template(
                description=description, env_var=env_var, exclude_fields=exclude_fields
            )
        # Connector-typed fields are injected by register_component, so they
        # are excluded from component templates.
        exclude_fields.extend(self._connectors.keys())
        for name, component in self._components.items():
            config["components"][name] = component.__class__.get_config_template(
                description=description, env_var=env_var, exclude_fields=exclude_fields
            )
        # allow_unicode=True may emit non-ASCII text, so write explicitly as UTF-8.
        with open(container_file, "w", encoding="utf-8") as f:
            f.write(yaml.dump(config, sort_keys=False, allow_unicode=True))
        return config

    def from_config(self, config_data: dict | str | Path) -> None:
        """Update the container from a configuration dict or YAML file.

        Any nested dict that contains a ``"value"`` key is collapsed to that
        value. Conductor and AaaS configs are then rebuilt, connectors and
        components are (re-)registered with ``overwrite=True``, and
        ``check_connection`` is run.

        Args:
            config_data: The configuration dict, or a path to a YAML file.

        Raises:
            FileNotFoundError: If a path is given, it does not exist, and the
                mode is not lite. In lite mode a missing file is a no-op.
        """

        def clean_config_dict(config_dict: dict) -> dict:
            """Recursively replace every ``{"value": x, ...}`` dict with ``x``.

            This drops sibling metadata such as ``description`` and
            ``env_var``. Returns a new dict; the input is not modified.
            """
            cleaned = {}
            for key, value in config_dict.items():
                if isinstance(value, dict):
                    if "value" in value:
                        cleaned[key] = value["value"]
                    else:
                        cleaned[key] = clean_config_dict(value)
                else:
                    cleaned[key] = value
            return cleaned

        if isinstance(config_data, (str, Path)):
            config_path = Path(config_data)
            if not config_path.exists():
                if os.getenv("OMAGENT_MODE") == "lite":
                    return
                raise FileNotFoundError(f"Config file not found: {config_data}")
            with open(config_path, "r", encoding="utf-8") as f:
                config_data = yaml.load(f, Loader=yaml.FullLoader)
        config_data = clean_config_dict(config_data)

        if "conductor_config" in config_data:
            self.conductor_config = Configuration(**config_data["conductor_config"])
        if "aaas_config" in config_data:
            self.aaas_config = AaasConfig(**config_data["aaas_config"])

        # connectors (entries whose class is not in the registry are skipped)
        if "connectors" in config_data:
            for name, config in config_data["connectors"].items():
                connector_cls = registry.get_connector(config.pop("name"))
                if connector_cls:
                    self.register_connector(
                        name=name, connector=connector_cls, overwrite=True, **config
                    )

        # components
        if "components" in config_data:
            for name, config in config_data["components"].items():
                self.register_component(
                    component=config.pop("name"),
                    name=name,
                    config=config,
                    overwrite=True,
                )

        self.check_connection()

    def check_connection(self) -> None:
        """Verify every connector and the Conductor server are reachable.

        Skipped entirely in lite mode.

        Raises:
            ConnectionError: On the first failing check. The original error is
                chained as ``__cause__``.
        """
        if os.getenv("OMAGENT_MODE") == "lite":
            return
        for name, connector in self._connectors.items():
            try:
                connector.check_connection()
            except Exception as e:
                raise ConnectionError(
                    f"Connection to {name} failed. Please check your connector config in container.yaml. \n Error Message: {e}"
                ) from e
        try:
            from omagent_core.engine.orkes.orkes_workflow_client import \
                OrkesWorkflowClient

            conductor_client = OrkesWorkflowClient(self.conductor_config)
            conductor_client.check_connection()
        except Exception as e:
            raise ConnectionError(
                f"Connection to Conductor failed. Please check your conductor config in container.yaml. \n Error Message: {e}"
            ) from e
        print("--------------------------------")
        print("All connections passed the connection check")
        print("--------------------------------")
# Default module-level Container instance. Imported modules are cached, so
# every importer of `container` from this module shares this one object.
container: Container = Container()