Spaces:
Sleeping
Sleeping
| import asyncio | |
| import dataclasses | |
| import inspect | |
| import logging | |
| import socket | |
| import time | |
| import traceback | |
| from abc import ABC, abstractmethod | |
| from copy import deepcopy | |
| from typing import Any, Callable, Optional, Union | |
| from omagent_core.base import BotBase | |
| from omagent_core.engine.automator import utils | |
| from omagent_core.engine.automator.utils import convert_from_dict_or_list | |
| from omagent_core.engine.configuration.configuration import Configuration | |
| from omagent_core.engine.http.api_client import ApiClient | |
| from omagent_core.engine.http.models import TaskExecLog | |
| from omagent_core.engine.http.models.task import Task | |
| from omagent_core.engine.http.models.task_result import TaskResult | |
| from omagent_core.engine.http.models.task_result_status import TaskResultStatus | |
| from omagent_core.engine.worker.exception import NonRetryableException | |
| from pydantic import Field | |
| from typing_extensions import Self | |
| ExecuteTaskFunction = Callable[[Union[Task, object]], Union[TaskResult, object]] | |
| logger = logging.getLogger(Configuration.get_logging_formatted_name(__name__)) | |
| def is_callable_input_parameter_a_task( | |
| callable: ExecuteTaskFunction, object_type: Any | |
| ) -> bool: | |
| parameters = inspect.signature(callable).parameters | |
| if len(parameters) != 1: | |
| return False | |
| parameter = parameters[list(parameters.keys())[0]] | |
| return ( | |
| parameter.annotation == object_type | |
| or parameter.annotation == parameter.empty | |
| or parameter.annotation == object | |
| ) | |
| def is_callable_return_value_of_type( | |
| callable: ExecuteTaskFunction, object_type: Any | |
| ) -> bool: | |
| return_annotation = inspect.signature(callable).return_annotation | |
| return return_annotation == object_type | |
| class BaseWorker(BotBase, ABC): | |
| poll_interval: float = Field( | |
| default=100, description="Worker poll interval in millisecond" | |
| ) | |
| domain: Optional[str] = Field(default=None, description="The domain of workflow") | |
| concurrency: int = Field(default=5, description="The concurrency of worker") | |
| _task_type: Optional[str] = None | |
| def model_post_init(self, __context: Any) -> None: | |
| self.task_definition_name = self.id or self.name | |
| self.next_task_index = 0 | |
| self._task_definition_name_cache = None | |
| self.api_client = ApiClient() | |
| self.worker_id = deepcopy(self.get_identity()) | |
| self._workflow_instance_id = None | |
| self._task_type = None | |
| for _, attr_value in self.__dict__.items(): | |
| if isinstance(attr_value, BotBase): | |
| attr_value._parent = self | |
| def task_type(self) -> str: | |
| return self._task_type | |
| def task_type(self, value: str): | |
| self._task_type = value | |
| def workflow_instance_id(self) -> Optional[str]: | |
| return self._workflow_instance_id | |
| def workflow_instance_id(self, value: Optional[str]): | |
| self._workflow_instance_id = value | |
| def _run(self, *args, **kwargs) -> Any: | |
| """Run the Node.""" | |
| def __call__(self, *args: Any, **kwds: Any) -> Any: | |
| print ("__call__") | |
| return self._run(*args, **kwds) | |
| def execute(self, task: Task) -> TaskResult: | |
| task_input = {} | |
| task_output = None | |
| task_result: TaskResult = self.get_task_result_from_task(task) | |
| if task.conversation_info: | |
| self.workflow_instance_id = '|'.join([ | |
| task.workflow_instance_id, | |
| task.conversation_info.get('agentId', ''), | |
| task.conversation_info.get('conversationId', ''), | |
| task.conversation_info.get('chatId', ''), | |
| ]) | |
| else: | |
| self.workflow_instance_id = task.workflow_instance_id | |
| try: | |
| if is_callable_input_parameter_a_task( | |
| callable=self._run, | |
| object_type=Task, | |
| ): | |
| task_output = self._run(task) | |
| else: | |
| params = inspect.signature(self._run).parameters | |
| for input_name in params: | |
| typ = params[input_name].annotation | |
| default_value = params[input_name].default | |
| if input_name in task.input_data: | |
| if typ in utils.simple_types: | |
| task_input[input_name] = task.input_data[input_name] | |
| else: | |
| task_input[input_name] = convert_from_dict_or_list( | |
| typ, task.input_data[input_name] | |
| ) | |
| else: | |
| if default_value is not inspect.Parameter.empty: | |
| task_input[input_name] = default_value | |
| else: | |
| task_input[input_name] = None | |
| if inspect.iscoroutinefunction(self._run): | |
| try: | |
| loop = asyncio.get_running_loop() | |
| except RuntimeError: | |
| loop = asyncio.new_event_loop() | |
| asyncio.set_event_loop(loop) | |
| task_output = loop.run_until_complete( | |
| asyncio.gather(self._run(**task_input), return_exceptions=True) | |
| )[0] | |
| else: | |
| task_output = self._run(**task_input) | |
| if type(task_output) == TaskResult: | |
| task_output.task_id = task.task_id | |
| task_output.workflow_instance_id = task.workflow_instance_id | |
| return task_output | |
| else: | |
| task_result.status = TaskResultStatus.COMPLETED | |
| task_result.output_data = task_output | |
| except NonRetryableException as ne: | |
| task_result.status = TaskResultStatus.FAILED_WITH_TERMINAL_ERROR | |
| if len(ne.args) > 0: | |
| task_result.reason_for_incompletion = ne.args[0] | |
| except Exception as ne: | |
| logger.error( | |
| f"Error executing task {task.task_def_name} with id {task.task_id}. error = {traceback.format_exc()}" | |
| ) | |
| task_result.logs = [ | |
| TaskExecLog( | |
| traceback.format_exc(), task_result.task_id, int(time.time()) | |
| ) | |
| ] | |
| task_result.status = TaskResultStatus.FAILED | |
| if len(ne.args) > 0: | |
| task_result.reason_for_incompletion = ne.args[0] | |
| self.workflow_instance_id = None | |
| if dataclasses.is_dataclass(type(task_result.output_data)): | |
| task_output = dataclasses.asdict(task_result.output_data) | |
| task_result.output_data = task_output | |
| return task_result | |
| if not isinstance(task_result.output_data, dict): | |
| task_output = task_result.output_data | |
| task_result.output_data = self.api_client.sanitize_for_serialization( | |
| task_output | |
| ) | |
| if not isinstance(task_result.output_data, dict): | |
| task_result.output_data = {"result": task_result.output_data} | |
| return task_result | |
| def get_identity(self) -> str: | |
| return self.worker_id if hasattr(self, "worker_id") else socket.gethostname() | |
| def get_polling_interval_in_seconds(self) -> float: | |
| return self.poll_interval / 1000 | |
| def get_task_definition_name(self) -> str: | |
| return self.task_definition_name_cache | |
| def task_definition_names(self): | |
| if isinstance(self.task_definition_name, list): | |
| return self.task_definition_name | |
| else: | |
| return [self.task_definition_name] | |
| def task_definition_name_cache(self): | |
| if self._task_definition_name_cache is None: | |
| self._task_definition_name_cache = self.compute_task_definition_name() | |
| return self._task_definition_name_cache | |
| def clear_task_definition_name_cache(self): | |
| self._task_definition_name_cache = None | |
| def compute_task_definition_name(self): | |
| if isinstance(self.task_definition_name, list): | |
| task_definition_name = self.task_definition_name[self.next_task_index] | |
| self.next_task_index = (self.next_task_index + 1) % len( | |
| self.task_definition_name | |
| ) | |
| return task_definition_name | |
| return self.task_definition_name | |
| def get_task_result_from_task(self, task: Task) -> TaskResult: | |
| return TaskResult( | |
| task_id=task.task_id, | |
| workflow_instance_id=task.workflow_instance_id, | |
| worker_id=self.get_identity(), | |
| biz_meta=task.biz_meta, | |
| callback_url=task.callback_url, | |
| ) | |
| def get_domain(self) -> str: | |
| return self.domain | |
| def paused(self) -> bool: | |
| return False | |