File size: 6,668 Bytes
8cb7e14 4ff0435 8cb7e14 4ff0435 8cb7e14 4ff0435 db419e8 8cb7e14 4ff0435 8cb7e14 5683a92 4ff0435 5683a92 4ff0435 5683a92 8cb7e14 463dbd4 8cb7e14 5683a92 8cb7e14 5683a92 8cb7e14 5683a92 8cb7e14 5683a92 8cb7e14 db419e8 5683a92 43c0f71 5683a92 db419e8 8cb7e14 db419e8 8cb7e14 50f8506 db419e8 09be0e1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 |
from copy import deepcopy
from typing import Dict, Any
import hydra
from aiflows.prompt_template import JinjaPrompt
from aiflows.base_flows import AtomicFlow
from aiflows.messages import UpdateMessage_Generic
from aiflows.utils import logging
from aiflows.messages import FlowMessage
# NOTE: per the original author, calling `logging.set_verbosity_debug()` here had no effect
# on this module's logger: its level stayed at WARN, so INFO records were not printed.
# The "aiflows." prefix namespaces this module's logger under "aiflows" (presumably so it
# picks up the library's logging configuration — verify against aiflows.utils.logging).
# TODO: find a cleaner way to obtain a correctly configured logger for this module.
log = logging.get_logger("aiflows." + __name__)
class HumanStandardInputFlow(AtomicFlow):
    r"""This class implements a HumanStandardInputFlow. It's used to read input from the user/human.
    Typically used to get feedback from the user/human.

    *Configuration Parameters*:

    - `name` (str): The name of the flow.
    - `description` (str): A description of the flow. This description is used to generate the help message of the flow.
      Default: "Reads input from the user's standard input."
    - `request_multi_line_input_flag` (bool): If True, the user/human is requested to enter a multi-line input.
      If False, the user/human is requested to enter a single-line input. Default: No default, this parameter is required.
    - `end_of_input_string` (str): The string that the user/human should enter to indicate that the input is finished.
      This parameter is only used if "request_multi_line_input_flag" is True. Default: "EOI"
    - `query_message_prompt_template` (JinjaPrompt): The prompt template used to generate the query message.
      By default it is of type aiflows.prompt_template.JinjaPrompt. None of the parameters of the prompt are
      defined by default and therefore need to be defined if one wants to use the query_message_prompt_template.
      Default parameters are defined in aiflows.prompt_template.jinja2_prompts.JinjaPrompt.
    - The other parameters are inherited from the default configuration of AtomicFlow (see AtomicFlow)

    *input_interface*:

    - No Input Interface. By default, the input interface expects no input. But if inputs are expected from the
      query_message_prompt_template, then the input interface should contain the keys specified in the
      input_variables of the query_message_prompt_template.

    *output_interface*:

    - `human_input` (str): The message entered by the user/human.

    :param query_message_prompt_template: The prompt template used to generate the query message. Expected if the class is instantiated programmatically.
    :type query_message_prompt_template: JinjaPrompt
    :param \**kwargs: The keyword arguments passed to the AtomicFlow constructor. Used to create the flow_config. Includes request_multi_line_input_flag, end_of_input_string, input_keys, description of Configuration Parameters.
    :type \**kwargs: Dict[str, Any]
    """

    REQUIRED_KEYS_CONFIG = ["request_multi_line_input_flag"]

    # Class-level default; each instance overwrites it in __init__.
    query_message_prompt_template: JinjaPrompt = None

    def __init__(self, query_message_prompt_template, **kwargs):
        super().__init__(**kwargs)
        self.query_message_prompt_template = query_message_prompt_template

    @classmethod
    def _set_up_prompts(cls, config):
        """Instantiates the prompt templates from the config.

        :param config: The configuration of the flow.
        :type config: Dict[str, Any]
        :return: A dictionary of keyword arguments to pass to the constructor of the flow.
        :rtype: Dict[str, Any]
        """
        kwargs = {}
        # The template is described in the config and built via hydra; `_convert_="partial"`
        # is forwarded to hydra.utils.instantiate unchanged.
        kwargs["query_message_prompt_template"] = hydra.utils.instantiate(
            config["query_message_prompt_template"], _convert_="partial"
        )
        return kwargs

    @classmethod
    def instantiate_from_config(cls, config):
        """Instantiates the flow from a config.

        The config is deep-copied first, so the caller's config object is never mutated.

        :param config: The configuration of the flow.
        :type config: Dict[str, Any]
        :return: The instantiated flow.
        """
        flow_config = deepcopy(config)

        kwargs = {"flow_config": flow_config}

        # ~~~ Set up prompts ~~~
        kwargs.update(cls._set_up_prompts(flow_config))

        # ~~~ Instantiate flow ~~~
        return cls(**kwargs)

    @staticmethod
    def _get_message(prompt_template, input_data: Dict[str, Any]):
        """Returns the message content given the prompt template and the input data.

        Only the variables listed in ``prompt_template.input_variables`` are taken from
        ``input_data``; any extra keys are ignored.

        :param prompt_template: The prompt template.
        :type prompt_template: JinjaPrompt
        :param input_data: The input data.
        :type input_data: Dict[str, Any]
        :return: The message content.
        :raises KeyError: If an input variable required by the template is missing from ``input_data``.
        """
        try:
            template_kwargs = {
                input_variable: input_data[input_variable]
                for input_variable in prompt_template.input_variables
            }
        except KeyError as err:
            # Re-raise the same exception type, but say which variable is missing and
            # which keys were actually available.
            raise KeyError(
                f"Input variable {err.args[0]!r} required by the prompt template is missing "
                f"from the input data (available keys: {list(input_data)})"
            ) from err

        return prompt_template.format(**template_kwargs)

    def _read_input(self):
        """Reads the input from the user/human's standard input.

        In single-line mode, one line is read. In multi-line mode, lines are read until a line
        exactly equal to ``flow_config["end_of_input_string"]`` is entered. The terminator line
        itself is not included. If standard input is closed before the terminator arrives
        (EOFError), the lines read so far are returned.

        :return: The input read from the user/human's standard input.
        :rtype: str
        """
        if not self.flow_config["request_multi_line_input_flag"]:
            log.info("Please enter your single-line response and press enter.")
            return input()

        end_of_input_string = self.flow_config["end_of_input_string"]
        log.info(f"Please enter your multi-line response below. "
                 f"To submit the response, write `{end_of_input_string}` on a new line and press enter.")

        content = []
        while True:
            try:
                line = input()
            except EOFError:
                # stdin was closed (e.g. Ctrl-D, or piped input ran out) before the terminator.
                # Keep what was typed instead of crashing and discarding it.
                log.warning(f"Reached the end of standard input before `{end_of_input_string}`; "
                            f"submitting the lines read so far.")
                break
            if line == end_of_input_string:
                break
            content.append(line)

        return "\n".join(content)

    def run(self, input_message: FlowMessage):
        """Runs the HumanStandardInputFlow. It's used to read input from the user/human's standard input.

        Steps:

        1. Renders the query message from the input data.
        2. Passes it, wrapped in an UpdateMessage_Generic, to ``_log_message``.
        3. Logs the query at INFO level.
        4. Reads the human's answer.
        5. Sends a reply whose response is ``{"human_input": <answer>}``.

        :param input_message: The input message
        :type input_message: FlowMessage
        """
        input_data = input_message.data

        query_message = self._get_message(self.query_message_prompt_template, input_data)
        state_update_message = UpdateMessage_Generic(
            created_by=self.flow_config["name"],
            updated_flow=self.flow_config["name"],
            data={"query_message": query_message},
        )
        self._log_message(state_update_message)

        # Show the query to the human, then block until they answer.
        log.info(query_message)
        human_input = self._read_input()

        reply_message = self.package_output_message(
            input_message=input_message,
            response={"human_input": human_input},
        )
        self.send_message(reply_message)
|