Spaces:
Sleeping
Sleeping
| from copy import deepcopy | |
| from lagent.schema import ActionReturn, ActionStatusCode, FunctionCall | |
| from .hook import Hook | |
| class ActionPreprocessor(Hook): | |
| """The ActionPreprocessor is a hook that preprocesses the action message | |
| and postprocesses the action return message. | |
| """ | |
| def before_action(self, executor, message, session_id): | |
| assert isinstance(message.formatted, FunctionCall) or ( | |
| isinstance(message.formatted, dict) and 'name' in message.content | |
| and 'parameters' in message.formatted) or ( | |
| 'action' in message.formatted | |
| and 'parameters' in message.formatted['action'] | |
| and 'name' in message.formatted['action']) | |
| if isinstance(message.formatted, dict): | |
| name = message.formatted.get('name', | |
| message.formatted['action']['name']) | |
| parameters = message.formatted.get( | |
| 'parameters', message.formatted['action']['parameters']) | |
| else: | |
| name = message.formatted.name | |
| parameters = message.formatted.parameters | |
| message.content = dict(name=name, parameters=parameters) | |
| return message | |
| def after_action(self, executor, message, session_id): | |
| action_return = message.content | |
| if isinstance(action_return, ActionReturn): | |
| if action_return.state == ActionStatusCode.SUCCESS: | |
| response = action_return.format_result() | |
| else: | |
| response = action_return.errmsg | |
| else: | |
| response = action_return | |
| message.content = response | |
| return message | |
| class InternLMActionProcessor(ActionPreprocessor): | |
| def __init__(self, code_parameter: str = 'command'): | |
| self.code_parameter = code_parameter | |
| def before_action(self, executor, message, session_id): | |
| message = deepcopy(message) | |
| assert isinstance(message.formatted, dict) and set( | |
| message.formatted).issuperset( | |
| {'tool_type', 'thought', 'action', 'status'}) | |
| if isinstance(message.formatted['action'], str): | |
| # encapsulate code interpreter arguments | |
| action_name = next(iter(executor.actions)) | |
| parameters = {self.code_parameter: message.formatted['action']} | |
| if action_name in ['AsyncIPythonInterpreter']: | |
| parameters['session_id'] = session_id | |
| message.formatted['action'] = dict( | |
| name=action_name, parameters=parameters) | |
| return super().before_action(executor, message, session_id) | |