Spaces:
Sleeping
Sleeping
| from pathlib import Path | |
| from typing import Dict, List, Optional, Type | |
| from threading import Thread | |
| from omagent_core.engine.configuration.aaas_config import AaasConfig | |
| import yaml | |
| from omagent_core.engine.configuration.configuration import (TEMPLATE_CONFIG, | |
| Configuration) | |
| from omagent_core.engine.configuration.aaas_config import AAAS_TEMPLATE_CONFIG | |
| from omagent_core.utils.registry import registry | |
| from pydantic import BaseModel | |
| import os | |
class Container:
    """Registry of named connector and component instances plus engine config.

    Connectors and components are pydantic ``BaseModel`` instances stored by
    registration key. The container additionally remembers which registered
    components act as the STM, LTM, callback and input, and holds the
    Conductor (``Configuration``) and AaaS (``AaasConfig``) settings.
    """

    def __init__(self):
        # registration key -> instantiated connector
        self._connectors: Dict[str, BaseModel] = {}
        # registration key -> instantiated component
        self._components: Dict[str, BaseModel] = {}
        # Keys into self._components for the special-purpose components.
        self._stm_name: Optional[str] = None
        self._ltm_name: Optional[str] = None
        self._callback_name: Optional[str] = None
        self._input_name: Optional[str] = None
        self.conductor_config = Configuration()
        self.aaas_config = AaasConfig()

    def register_connector(
        self,
        connector: Type[BaseModel],
        name: Optional[str] = None,
        overwrite: bool = False,
        **kwargs,
    ) -> None:
        """Instantiate ``connector(**kwargs)`` and store it under ``name``.

        Args:
            connector: Connector class to instantiate.
            name: Registration key; defaults to the class name.
            overwrite: Replace an existing connector with the same key. When
                False and the key is already taken, the call is a no-op.
            **kwargs: Keyword arguments forwarded to the connector constructor.
        """
        if name is None:
            name = connector.__name__
        if name not in self._connectors or overwrite:
            self._connectors[name] = connector(**kwargs)

    def get_connector(self, name: str) -> BaseModel:
        """Return the connector registered under ``name``.

        Raises:
            KeyError: If no connector is registered under ``name``.
        """
        if name not in self._connectors:
            raise KeyError(f"There is no connector named '{name}' in container.")
        return self._connectors[name]

    def register_component(
        self,
        component: str | Type[BaseModel],
        name: Optional[str] = None,
        config: Optional[dict] = None,
        overwrite: bool = False,
    ) -> str:
        """Instantiate a component and store it in the container.

        Every field of the component class whose annotation is a ``BaseModel``
        subclass is filled with the connector registered under that field's
        name. If no such connector exists yet, one is registered from the
        registry using the annotation's class name. The injected connector
        overrides any value given for that field in ``config``.

        Args:
            component: Registry name of the component, or the component class.
            name: Registration key; defaults to the registry name / class name.
            config: Keyword arguments for the component constructor. The dict
                is copied and never mutated.
            overwrite: Re-instantiate even if the key is already registered.

        Returns:
            The key under which the component is stored.

        Raises:
            ValueError: If ``component`` is an unknown registry name or an
                invalid type, or a required connector class is not in the
                registry.
        """
        if isinstance(component, str):
            component_cls = registry.get_component(component)
            component_name = component
            if not component_cls:
                raise ValueError(f"{component} not found in registry")
        elif isinstance(component, type) and issubclass(component, BaseModel):
            component_cls = component
            component_name = component.__name__
        else:
            raise ValueError(f"Invalid component type: {type(component)}")

        key = name or component_name
        # Only skip when the key we would return actually exists; otherwise
        # callers would get back a key that get_component cannot resolve.
        if key in self._components and not overwrite:
            return key

        # Copy so neither the caller's dict nor a shared default is mutated
        # when connector instances are injected below.
        init_kwargs = dict(config or {})
        for field_name, connector_cls_name in self._get_required_connectors(
            component_cls
        ):
            if field_name not in self._connectors:
                connector_cls = registry.get_connector(connector_cls_name)
                if not connector_cls:
                    raise ValueError(
                        f"Connector {connector_cls_name} required by field "
                        f"'{field_name}' of {component_name} not found in registry"
                    )
                self.register_connector(connector_cls, field_name)
            init_kwargs[field_name] = self._connectors[field_name]
        self._components[key] = component_cls(**init_kwargs)
        return key

    def get_component(self, component_name: str) -> BaseModel:
        """Return the component registered under ``component_name``.

        Raises:
            KeyError: If no component is registered under that key.
        """
        if component_name not in self._components:
            raise KeyError(
                f"There is no component named '{component_name}' in container. You need to register it first."
            )
        return self._components[component_name]

    def _get_required_connectors(
        self, cls: Type[BaseModel]
    ) -> list[tuple[str, str]]:
        """Return ``(field_name, annotation_class_name)`` pairs for each field
        of ``cls`` whose annotation is itself a ``BaseModel`` subclass."""
        return [
            (field_name, field.annotation.__name__)
            for field_name, field in cls.model_fields.items()
            if isinstance(field.annotation, type)
            and issubclass(field.annotation, BaseModel)
        ]

    def components(self) -> Dict[str, BaseModel]:
        """Return the internal key -> component mapping (not a copy)."""
        return self._components

    def register_stm(
        self,
        stm: str | Type[BaseModel],
        name: Optional[str] = None,
        config: Optional[dict] = None,
        overwrite: bool = False,
    ):
        """Register a component and mark it as the short-term memory (STM).

        When ``OMAGENT_MODE`` is ``"lite"`` the registration key is forced to
        ``"SharedMemSTM"`` regardless of ``name``.
        """
        if os.getenv("OMAGENT_MODE") == "lite":
            name = "SharedMemSTM"
        self._stm_name = self.register_component(stm, name, config, overwrite)

    def stm(self) -> BaseModel:
        """Return the STM component.

        In lite mode an unregistered STM is registered on demand as the
        ``"SharedMemSTM"`` registry component.

        Raises:
            ValueError: If no STM is registered and lite mode is off.
        """
        if self._stm_name is None:
            if os.getenv("OMAGENT_MODE") != "lite":
                raise ValueError(
                    "STM component is not registered. Please use register_stm to register."
                )
            # register_stm forces the key to "SharedMemSTM" in lite mode and
            # stores it in self._stm_name.
            self.register_stm("SharedMemSTM")
        return self.get_component(self._stm_name)

    def register_ltm(
        self,
        ltm: str | Type[BaseModel],
        name: Optional[str] = None,
        config: Optional[dict] = None,
        overwrite: bool = False,
    ):
        """Register a component and mark it as the long-term memory (LTM)."""
        self._ltm_name = self.register_component(ltm, name, config, overwrite)

    def ltm(self) -> BaseModel:
        """Return the LTM component.

        Raises:
            ValueError: If no LTM has been registered.
        """
        if self._ltm_name is None:
            raise ValueError(
                "LTM component is not registered. Please use register_ltm to register."
            )
        return self.get_component(self._ltm_name)

    def register_callback(
        self,
        callback: str | Type[BaseModel],
        name: Optional[str] = None,
        config: Optional[dict] = None,
        overwrite: bool = False,
    ):
        """Register a component and mark it as the callback component."""
        self._callback_name = self.register_component(
            callback, name, config, overwrite
        )

    def callback(self) -> BaseModel:
        """Return the callback component.

        Raises:
            ValueError: If no callback has been registered.
        """
        if self._callback_name is None:
            raise ValueError(
                "Callback component is not registered. Please use register_callback to register."
            )
        return self.get_component(self._callback_name)

    def register_input(
        self,
        input: str | Type[BaseModel],
        name: Optional[str] = None,
        config: Optional[dict] = None,
        overwrite: bool = False,
    ):
        """Register a component and mark it as the input component."""
        self._input_name = self.register_component(input, name, config, overwrite)

    def input(self) -> BaseModel:
        """Return the input component.

        Raises:
            ValueError: If no input component has been registered.
        """
        if self._input_name is None:
            raise ValueError(
                "Input component is not registered. Please use register_input to register."
            )
        return self.get_component(self._input_name)

    def compile_config(
        self, output_path: Path, description: bool = True, env_var: bool = True
    ) -> dict:
        """Write a ``container.yaml`` config template into ``output_path``.

        If ``container.yaml`` already exists there, it is loaded and returned
        instead of being regenerated.

        Args:
            output_path: Directory (``Path`` or ``str``) holding ``container.yaml``.
            description: Forwarded to each class's ``get_config_template``.
            env_var: Forwarded to each class's ``get_config_template``.

        Returns:
            The configuration that was written, or the one loaded from disk.
        """
        config_file = Path(output_path) / "container.yaml"
        if config_file.exists():
            print("container.yaml already exists, skip compiling")
            # NOTE(review): FullLoader on a local, developer-owned file; switch
            # to yaml.safe_load if this path can ever be attacker-controlled.
            with open(config_file, "r", encoding="utf-8") as f:
                return yaml.load(f, Loader=yaml.FullLoader)

        config = {
            "conductor_config": TEMPLATE_CONFIG,
            "aaas_config": AAAS_TEMPLATE_CONFIG,
            "connectors": {},
            "components": {},
        }
        exclude_fields = [
            "_parent",
            "component_stm",
            "component_ltm",
            "component_callback",
            "component_input",
        ]
        for name, connector in self._connectors.items():
            config["connectors"][name] = connector.__class__.get_config_template(
                description=description, env_var=env_var, exclude_fields=exclude_fields
            )
        # register_component injects connector instances into fields named
        # after the connector keys, so those fields are omitted from
        # component templates.
        exclude_fields.extend(self._connectors.keys())
        for name, component in self._components.items():
            config["components"][name] = component.__class__.get_config_template(
                description=description, env_var=env_var, exclude_fields=exclude_fields
            )
        with open(config_file, "w", encoding="utf-8") as f:
            f.write(yaml.dump(config, sort_keys=False, allow_unicode=True))
        return config

    def from_config(self, config_data: dict | str | Path) -> None:
        """Update the container from a configuration dict or YAML file.

        Args:
            config_data: A config dict, or a path to a YAML file, with optional
                ``conductor_config``, ``aaas_config``, ``connectors`` and
                ``components`` sections. Each connector/component entry must
                carry a ``name`` key naming its registry class.

        Raises:
            FileNotFoundError: If a path is given that does not exist and lite
                mode is off. In lite mode the call silently returns.
            ConnectionError: Propagated from ``check_connection``.
        """

        def clean_config_dict(config_dict: dict) -> dict:
            """Recursively collapse template entries to their plain values.

            Any nested dict that has a ``"value"`` key is replaced by that
            value, which drops sibling keys such as ``description`` and
            ``env_var``. Other dicts are cleaned recursively.
            """
            cleaned = {}
            for key, value in config_dict.items():
                if isinstance(value, dict):
                    if "value" in value:
                        cleaned[key] = value["value"]
                    else:
                        cleaned[key] = clean_config_dict(value)
                else:
                    cleaned[key] = value
            return cleaned

        if isinstance(config_data, str | Path):
            config_path = Path(config_data)
            if not config_path.exists():
                if os.getenv("OMAGENT_MODE") == "lite":
                    return
                raise FileNotFoundError(f"Config file not found: {config_data}")
            with open(config_path, "r", encoding="utf-8") as f:
                # An empty YAML file loads as None; treat it as an empty config.
                config_data = yaml.load(f, Loader=yaml.FullLoader) or {}
        config_data = clean_config_dict(config_data)

        if "conductor_config" in config_data:
            self.conductor_config = Configuration(**config_data["conductor_config"])
        if "aaas_config" in config_data:
            self.aaas_config = AaasConfig(**config_data["aaas_config"])

        # connectors: entries whose class is not in the registry are skipped.
        for name, config in (config_data.get("connectors") or {}).items():
            connector_cls = registry.get_connector(config.pop("name"))
            if connector_cls:
                self.register_connector(
                    name=name, connector=connector_cls, overwrite=True, **config
                )

        # components
        for name, config in (config_data.get("components") or {}).items():
            self.register_component(
                component=config.pop("name"),
                name=name,
                config=config,
                overwrite=True,
            )
        self.check_connection()

    def check_connection(self):
        """Check every registered connector and the Conductor server.

        Skipped entirely when ``OMAGENT_MODE`` is ``"lite"``. Prints a banner
        when all checks pass.

        Raises:
            ConnectionError: For the first failing connector or for Conductor,
                chained to the underlying exception.
        """
        if os.getenv("OMAGENT_MODE") == "lite":
            return
        for name, connector in self._connectors.items():
            try:
                connector.check_connection()
            except Exception as e:
                raise ConnectionError(
                    f"Connection to {name} failed. Please check your connector config in container.yaml. \n Error Message: {e}"
                ) from e
        try:
            # Imported lazily so lite mode never needs the Orkes client.
            from omagent_core.engine.orkes.orkes_workflow_client import \
                OrkesWorkflowClient

            conductor_client = OrkesWorkflowClient(self.conductor_config)
            conductor_client.check_connection()
        except Exception as e:
            raise ConnectionError(
                f"Connection to Conductor failed. Please check your conductor config in container.yaml. \n Error Message: {e}"
            ) from e
        print("--------------------------------")
        print("All connections passed the connection check")
        print("--------------------------------")
# Module-level Container instance, created at import time. Constructing it
# also builds a default Configuration() and AaasConfig().
container = Container()