Spaces:
Running
Running
from pathlib import Path | |
from typing import Dict, List, Optional, Type | |
from threading import Thread | |
from omagent_core.engine.configuration.aaas_config import AaasConfig | |
import yaml | |
from omagent_core.engine.configuration.configuration import (TEMPLATE_CONFIG, | |
Configuration) | |
from omagent_core.engine.configuration.aaas_config import AAAS_TEMPLATE_CONFIG | |
from omagent_core.utils.registry import registry | |
from pydantic import BaseModel | |
import os | |
class Container:
    """Holds connector and component instances plus Conductor / AaaS configuration.

    Components are looked up by name in ``registry`` (or passed directly as
    classes), instantiated with their config, and stored under a key. Any model
    field of a component whose annotation is a ``BaseModel`` subclass is treated
    as a connector dependency: the connector is auto-registered under the
    field's name and injected into the component's constructor arguments.
    """

    def __init__(self):
        self._connectors: Dict[str, BaseModel] = {}
        self._components: Dict[str, BaseModel] = {}
        # Keys into self._components for the component serving each role.
        self._stm_name: Optional[str] = None
        self._ltm_name: Optional[str] = None
        self._callback_name: Optional[str] = None
        self._input_name: Optional[str] = None
        self.conductor_config = Configuration()
        self.aaas_config = AaasConfig()

    # ------------------------------------------------------------------
    # Connectors
    # ------------------------------------------------------------------

    def register_connector(
        self,
        connector: Type[BaseModel],
        name: Optional[str] = None,
        overwrite: bool = False,
        **kwargs,
    ) -> None:
        """Register a connector.

        Instantiates ``connector(**kwargs)`` and stores it under ``name``
        (default: the connector class name). An existing entry with the same
        name is kept unless ``overwrite`` is True.
        """
        if name is None:
            name = connector.__name__
        if name not in self._connectors or overwrite:
            self._connectors[name] = connector(**kwargs)

    def get_connector(self, name: str) -> BaseModel:
        """Return the connector registered under ``name``.

        Raises:
            KeyError: If no connector with that name is registered.
        """
        if name not in self._connectors:
            raise KeyError(f"There is no connector named '{name}' in container.")
        return self._connectors[name]

    # ------------------------------------------------------------------
    # Components
    # ------------------------------------------------------------------

    def register_component(
        self,
        component: str | Type[BaseModel],
        name: Optional[str] = None,
        config: Optional[dict] = None,
        overwrite: bool = False,
    ) -> str:
        """Instantiate a component and store it in the container.

        Args:
            component: Registry name of the component, or the component class
                itself (must be a ``BaseModel`` subclass).
            name: Key under which the instance is stored. Defaults to the
                registry name (string input) or the class name (class input).
            config: Keyword arguments for the component constructor. The dict
                is copied, so neither the caller's dict nor any shared default
                is mutated when connector instances are injected.
            overwrite: Replace an instance already stored under the same key.

        Returns:
            The key under which the component is stored.

        Raises:
            ValueError: If ``component`` is not found in the registry, is not a
                ``BaseModel`` subclass, or requires a connector class that is
                not found in the registry.
        """
        if isinstance(component, str):
            component_cls = registry.get_component(component)
            if not component_cls:
                raise ValueError(f"{component} not found in registry")
            component_name = component
        elif isinstance(component, type) and issubclass(component, BaseModel):
            component_cls = component
            component_name = component.__name__
        else:
            raise ValueError(f"Invalid component type: {type(component)}")

        key = name or component_name
        # Only the key this call would write to matters. Also checking the class
        # name made a second, differently-named instance of the same class
        # return its name without ever being registered.
        if key in self._components and not overwrite:
            return key

        # Copy: connector instances are injected below and must not leak into
        # the caller's dict.
        kwargs = dict(config) if config else {}
        for field_name, connector_cls_name in self._get_required_connectors(
            component_cls
        ):
            if field_name not in self._connectors:
                connector_cls = registry.get_connector(connector_cls_name)
                if not connector_cls:
                    raise ValueError(
                        f"Connector '{connector_cls_name}' required by field "
                        f"'{field_name}' of {component_name} not found in registry"
                    )
                # Connectors are keyed by the component's field name.
                self.register_connector(connector_cls, field_name)
            kwargs[field_name] = self._connectors[field_name]
        self._components[key] = component_cls(**kwargs)
        return key

    def get_component(self, component_name: str) -> BaseModel:
        """Return the component registered under ``component_name``.

        Raises:
            KeyError: If no component with that name is registered.
        """
        if component_name not in self._components:
            raise KeyError(
                f"There is no component named '{component_name}' in container. You need to register it first."
            )
        return self._components[component_name]

    def _get_required_connectors(self, cls: Type[BaseModel]) -> List[List[str]]:
        """Return ``[field_name, annotation_class_name]`` pairs for every model
        field of ``cls`` whose annotation is a ``BaseModel`` subclass."""
        required_connectors = []
        for field_name, field in cls.model_fields.items():
            annotation = field.annotation
            if isinstance(annotation, type) and issubclass(annotation, BaseModel):
                required_connectors.append([field_name, annotation.__name__])
        return required_connectors

    def components(self) -> Dict[str, BaseModel]:
        """Return the internal dict of registered components (not a copy)."""
        return self._components

    # ------------------------------------------------------------------
    # Role components (STM / LTM / callback / input)
    # ------------------------------------------------------------------

    def register_stm(
        self,
        stm: str | Type[BaseModel],
        name: Optional[str] = None,
        config: Optional[dict] = None,
        overwrite: bool = False,
    ):
        """Register a component and mark it as the short-term memory (STM)."""
        if os.getenv("OMAGENT_MODE") == "lite":
            # NOTE(review): in lite mode only the storage key is forced to
            # "SharedMemSTM"; the ``stm`` argument is still what gets
            # instantiated. Confirm that this is intended.
            name = "SharedMemSTM"
        self._stm_name = self.register_component(stm, name, config, overwrite)

    def stm(self) -> BaseModel:
        """Return the STM component.

        In lite mode (``OMAGENT_MODE=lite``) a ``SharedMemSTM`` is registered
        on first access if no STM was registered explicitly.

        Raises:
            ValueError: If no STM is registered and not running in lite mode.
        """
        if self._stm_name is None:
            if os.getenv("OMAGENT_MODE") != "lite":
                raise ValueError(
                    "STM component is not registered. Please use register_stm to register."
                )
            # register_stm stores the resulting key in self._stm_name.
            self.register_stm("SharedMemSTM")
        return self.get_component(self._stm_name)

    def register_ltm(
        self,
        ltm: str | Type[BaseModel],
        name: Optional[str] = None,
        config: Optional[dict] = None,
        overwrite: bool = False,
    ):
        """Register a component and mark it as the long-term memory (LTM)."""
        self._ltm_name = self.register_component(ltm, name, config, overwrite)

    def ltm(self) -> BaseModel:
        """Return the LTM component.

        Raises:
            ValueError: If no LTM has been registered.
        """
        if self._ltm_name is None:
            raise ValueError(
                "LTM component is not registered. Please use register_ltm to register."
            )
        return self.get_component(self._ltm_name)

    def register_callback(
        self,
        callback: str | Type[BaseModel],
        name: Optional[str] = None,
        config: Optional[dict] = None,
        overwrite: bool = False,
    ):
        """Register a component and mark it as the callback component."""
        self._callback_name = self.register_component(
            callback, name, config, overwrite
        )

    def callback(self) -> BaseModel:
        """Return the callback component.

        Raises:
            ValueError: If no callback component has been registered.
        """
        if self._callback_name is None:
            raise ValueError(
                "Callback component is not registered. Please use register_callback to register."
            )
        return self.get_component(self._callback_name)

    def register_input(
        self,
        input: str | Type[BaseModel],
        name: Optional[str] = None,
        config: Optional[dict] = None,
        overwrite: bool = False,
    ):
        """Register a component and mark it as the input component."""
        self._input_name = self.register_component(input, name, config, overwrite)

    def input(self) -> BaseModel:
        """Return the input component.

        Raises:
            ValueError: If no input component has been registered.
        """
        if self._input_name is None:
            raise ValueError(
                "Input component is not registered. Please use register_input to register."
            )
        return self.get_component(self._input_name)

    # ------------------------------------------------------------------
    # Configuration I/O
    # ------------------------------------------------------------------

    def compile_config(
        self, output_path: Path, description: bool = True, env_var: bool = True
    ) -> dict:
        """Write a ``container.yaml`` template for everything registered so far.

        Args:
            output_path: Directory in which ``container.yaml`` is written
                (a ``str`` path is also accepted).
            description: Forwarded to each class's ``get_config_template``.
            env_var: Forwarded to each class's ``get_config_template``.

        Returns:
            The config dict that was written. If ``container.yaml`` already
            exists, nothing is written and the file's parsed contents are
            returned instead.
        """
        config_file = Path(output_path) / "container.yaml"
        if config_file.exists():
            print("container.yaml already exists, skip compiling")
            with open(config_file, "r", encoding="utf-8") as f:
                return yaml.load(f, Loader=yaml.FullLoader)

        config = {
            "conductor_config": TEMPLATE_CONFIG,
            "aaas_config": AAAS_TEMPLATE_CONFIG,
            "connectors": {},
            "components": {},
        }
        exclude_fields = [
            "_parent",
            "component_stm",
            "component_ltm",
            "component_callback",
            "component_input",
        ]
        for name, connector in self._connectors.items():
            config["connectors"][name] = connector.__class__.get_config_template(
                description=description, env_var=env_var, exclude_fields=exclude_fields
            )
        # register_component injects connector instances into fields named after
        # the connector keys, so those fields are left out of component templates.
        exclude_fields.extend(self._connectors.keys())
        for name, component in self._components.items():
            config["components"][name] = component.__class__.get_config_template(
                description=description, env_var=env_var, exclude_fields=exclude_fields
            )
        # Explicit UTF-8: allow_unicode=True may emit non-ASCII text, which the
        # platform default encoding cannot always represent.
        with open(config_file, "w", encoding="utf-8") as f:
            f.write(yaml.dump(config, sort_keys=False, allow_unicode=True))
        return config

    def from_config(self, config_data: dict | str | Path) -> None:
        """Update the container from a configuration.

        Args:
            config_data: Either a dict with optional ``conductor_config``,
                ``aaas_config``, ``connectors`` and ``components`` sections, or
                a path to a YAML file containing such a mapping. In lite mode
                a missing file is silently ignored.

        Raises:
            FileNotFoundError: If the config file does not exist (not in lite
                mode).
            ConnectionError: Propagated from ``check_connection``.
        """

        def clean_config_dict(config_dict: dict) -> dict:
            """Recursively flatten template-style entries into plain values.

            A nested dict containing a "value" key is replaced by that value,
            dropping sibling metadata keys such as "description" / "env_var".
            Other nested dicts are cleaned recursively, and non-dict values are
            kept as-is. Returns a new dict.
            """
            cleaned = {}
            for key, value in config_dict.items():
                if isinstance(value, dict):
                    if "value" in value:
                        cleaned[key] = value["value"]
                    else:
                        cleaned[key] = clean_config_dict(value)
                else:
                    cleaned[key] = value
            return cleaned

        if isinstance(config_data, (str, Path)):
            if not Path(config_data).exists():
                if os.getenv("OMAGENT_MODE") == "lite":
                    return
                raise FileNotFoundError(f"Config file not found: {config_data}")
            with open(config_data, "r", encoding="utf-8") as f:
                # An empty YAML file loads as None; treat it as an empty config.
                config_data = yaml.load(f, Loader=yaml.FullLoader) or {}
        config_data = clean_config_dict(config_data)

        if "conductor_config" in config_data:
            self.conductor_config = Configuration(**config_data["conductor_config"])
        if "aaas_config" in config_data:
            self.aaas_config = AaasConfig(**config_data["aaas_config"])

        # Connectors first, so components can pick up already-configured
        # connector instances instead of default-constructed ones.
        for name, config in config_data.get("connectors", {}).items():
            connector_cls = registry.get_connector(config.pop("name"))
            if connector_cls:
                self.register_connector(
                    name=name, connector=connector_cls, overwrite=True, **config
                )

        for name, config in config_data.get("components", {}).items():
            self.register_component(
                component=config.pop("name"),
                name=name,
                config=config,
                overwrite=True,
            )
        self.check_connection()

    def check_connection(self):
        """Verify every connector and the Conductor server are reachable.

        Calls ``check_connection()`` on each registered connector and on an
        ``OrkesWorkflowClient`` built from ``conductor_config``. Does nothing in
        lite mode.

        Raises:
            ConnectionError: On the first failing check, chained to the
                original exception.
        """
        if os.getenv("OMAGENT_MODE") == "lite":
            return
        for name, connector in self._connectors.items():
            try:
                connector.check_connection()
            except Exception as e:
                raise ConnectionError(
                    f"Connection to {name} failed. Please check your connector config in container.yaml. \n Error Message: {e}"
                ) from e
        try:
            from omagent_core.engine.orkes.orkes_workflow_client import \
                OrkesWorkflowClient

            conductor_client = OrkesWorkflowClient(self.conductor_config)
            conductor_client.check_connection()
        except Exception as e:
            raise ConnectionError(
                f"Connection to Conductor failed. Please check your conductor config in container.yaml. \n Error Message: {e}"
            ) from e
        print("--------------------------------")
        print("All connections passed the connection check")
        print("--------------------------------")
# Module-level Container instance. Python caches imported modules, so every
# importer of this module receives this same object.
container: Container = Container()