File size: 6,668 Bytes
8cb7e14
 
 
 
4ff0435
8cb7e14
4ff0435
 
8cb7e14
4ff0435
db419e8
8cb7e14
4ff0435
8cb7e14
 
 
5683a92
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4ff0435
5683a92
4ff0435
5683a92
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8cb7e14
 
463dbd4
8cb7e14
 
 
 
 
 
 
5683a92
 
 
 
 
 
8cb7e14
 
 
 
 
 
 
 
5683a92
 
 
 
 
 
8cb7e14
 
 
 
 
 
 
 
 
 
 
 
5683a92
 
 
 
 
 
 
 
8cb7e14
 
 
 
 
 
 
 
5683a92
 
 
 
 
8cb7e14
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
db419e8
5683a92
 
43c0f71
 
5683a92
db419e8
 
8cb7e14
db419e8
8cb7e14
 
 
 
 
 
 
 
 
 
50f8506
 
 
 
db419e8
09be0e1
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
from copy import deepcopy
from typing import Dict, Any

import hydra
from aiflows.prompt_template import JinjaPrompt

from aiflows.base_flows import AtomicFlow
from aiflows.messages import UpdateMessage_Generic

from aiflows.utils import logging
from aiflows.messages import FlowMessage
# NOTE: calling logging.set_verbosity_debug() here was tried but did not affect the logger for __name__
# (its level stayed at warn, so info messages were not printed); kept as a reminder.
# logging.set_verbosity_debug()
# Module logger, namespaced under "aiflows." so it is a child of the aiflows logger hierarchy.
log = logging.get_logger(f"aiflows.{__name__}")  # ToDo: Is there a better fix?


class HumanStandardInputFlow(AtomicFlow):
    r""" This class implements a HumanStandardInputFlow. It's used to read input from the user/human. Typically used to get feedback from the user/human.
    
    *Configuration Parameters*:
    
    - `name` (str): The name of the flow.
    
    - `description` (str): A description of the flow. This description is used to generate the help message of the flow.
    Default: "Reads input from the user's standard input."
    
    - `request_multi_line_input_flag` (bool): If True, the user/human is requested to enter a multi-line input.
    If False, the user/human is requested to enter a single-line input. Default: No default, this parameter is required.
    
    - `end_of_input_string` (str): The string that the user/human should enter on its own line to indicate that the input is finished.
    This parameter is only used if "request_multi_line_input_flag" is True. Default: "EOI"
        
    - `query_message_prompt_template` (JinjaPrompt): The prompt template used to generate the query message. By default it is of type aiflows.prompt_template.JinjaPrompt.
    None of the parameters of the prompt are defined by default and therefore need to be defined if one wants to use the query_message_prompt_template. Default parameters are defined in
    aiflows.prompt_template.jinja2_prompts.JinjaPrompt.
    
    - The other parameters are inherited from the default configuration of AtomicFlow (see AtomicFlow)
    
    *input_interface*:
    
    - No Input Interface. By default, the input interface expects no input. But if inputs are expected from the query_message_prompt_template, then the input interface should contain the keys specified in the input_variables of the query_message_prompt_template.
    
    *output_interface*:
    
    - `human_input` (str): The message input by the user/human.
        
    :param query_message_prompt_template: The prompt template used to generate the query message. Expected if the class is instantiated programmatically.
    :type query_message_prompt_template: JinjaPrompt
    :param \**kwargs: The keyword arguments passed to the AtomicFlow constructor. Use to create the flow_config. Includes request_multi_line_input_flag, end_of_input_string, input_keys, description of Configuration Parameters.
    :type \**kwargs: Dict[str, Any]    
    """
    REQUIRED_KEYS_CONFIG = ["request_multi_line_input_flag"]

    # Template rendered (with the input message data) into the query shown to the human.
    query_message_prompt_template: JinjaPrompt = None

    def __init__(self, query_message_prompt_template, **kwargs):
        super().__init__(**kwargs)
        self.query_message_prompt_template = query_message_prompt_template

    @classmethod
    def _set_up_prompts(cls, config):
        """ Instantiates the prompt templates from the config.
        
        :param config: The configuration of the flow.
        :type config: Dict[str, Any]
        :return: A dictionary of keyword arguments to pass to the constructor of the flow.
        :rtype: Dict[str, Any]
        """
        # _convert_="partial" asks hydra to convert the (possibly OmegaConf) config into
        # plain Python containers where possible before calling the target.
        prompt_template = hydra.utils.instantiate(
            config["query_message_prompt_template"], _convert_="partial"
        )
        return {"query_message_prompt_template": prompt_template}

    @classmethod
    def instantiate_from_config(cls, config):
        """ Instantiates the flow from a config file.
        
        :param config: The configuration of the flow.
        :type config: Dict[str, Any]
        :return: The instantiated flow.
        :rtype: HumanStandardInputFlow
        """
        # Deep-copy so that the flow never mutates the caller's config.
        flow_config = deepcopy(config)

        kwargs = {"flow_config": flow_config}

        # ~~~ Set up prompts ~~~
        kwargs.update(cls._set_up_prompts(flow_config))

        # ~~~ Instantiate flow ~~~
        return cls(**kwargs)

    @staticmethod
    def _get_message(prompt_template, input_data: Dict[str, Any]):
        """ Returns the message content given the prompt template and the input data.
        
        :param prompt_template: The prompt template.
        :type prompt_template: JinjaPrompt
        :param input_data: The input data.
        :type input_data: Dict[str, Any]
        :return: The message content.
        :raises KeyError: If input_data lacks any of the prompt template's input variables.
        """
        missing = [name for name in prompt_template.input_variables if name not in input_data]
        if missing:
            # Report every missing variable at once instead of failing on the first one.
            raise KeyError(
                f"Input data is missing variables required by the prompt template: {missing}"
            )

        template_kwargs = {name: input_data[name] for name in prompt_template.input_variables}
        return prompt_template.format(**template_kwargs)

    def _read_input(self):
        """ Reads the input from the user/human's standard input.
        
        In single-line mode one line is read. In multi-line mode lines are read until a line
        equal to `end_of_input_string` is entered (the terminator itself is not included) or
        standard input is closed.
        
        :return: The input read from the user/human's standard input.
        :rtype: str
        """
        if not self.flow_config["request_multi_line_input_flag"]:
            log.info("Please enter your single-line response and press enter.")
            return input()

        # Fall back to the documented default when the config does not set a terminator.
        end_of_input_string = self.flow_config.get("end_of_input_string", "EOI")
        log.info(f"Please enter your multi-line response below. "
                 f"To submit the response, write `{end_of_input_string}` on a new line and press enter.")

        content = []
        while True:
            try:
                line = input()
            except EOFError:
                # Standard input was closed before the terminator was entered:
                # keep what has been typed so far instead of discarding it.
                break
            if line == end_of_input_string:
                break
            content.append(line)
        return "\n".join(content)

    def run(self,
            input_message: FlowMessage):
        """ Runs the HumanStandardInputFlow. It's used to read input from the user/human's standard input.
        
        Renders the query message from the input data, logs it as a state update, shows it to
        the human, reads the reply and sends it back as `{"human_input": ...}`.
        
        :param input_message: The input message
        :type input_message: FlowMessage
        """
        input_data = input_message.data

        query_message = self._get_message(self.query_message_prompt_template, input_data)

        state_update_message = UpdateMessage_Generic(
            created_by=self.flow_config["name"],
            updated_flow=self.flow_config["name"],
            data={"query_message": query_message},
        )
        self._log_message(state_update_message)

        log.info(query_message)
        human_input = self._read_input()

        reply_message = self.package_output_message(
            input_message=input_message,
            response={"human_input": human_input},
        )

        self.send_message(reply_message)