import copy
import json
import logging
import os
import sys
from collections import OrderedDict
from typing import List

# Make the current working directory importable so the cllm package resolves
# when this module is run from the repository root.
sys.path.append(os.getcwd())

from ..base import Action, NON_FILE_TYPES

# from cllm.services.tog import TaskSolver, TaskDecomposer, config
# from cllm.services.nlp.llms import ChatOpenAI, MessageMemory
from cllm.services.tog.api import tog, task_decomposer


logger = logging.getLogger(__name__)


class Planner:
    def __init__(
        self, streaming=False, backend="remote", device="cuda:0", **llm_kwargs
    ):
        self.streaming = streaming
        if backend == "local":
            pass
            # self.cfg = config
            # self.device = device
            # self.mem = MessageMemory(**self.cfg.memory)
            # self.llm = ChatOpenAI(temperature=0.2, **llm_kwargs)
            # self.tog = TaskSolver(self.llm, self.cfg.task_solver_config, device).solve
            # self.decomposer = TaskDecomposer(device, self.cfg.task_decomposer_cfg).solve
        elif backend == "remote":
            self.decomposer = task_decomposer
            self.tog = tog
        else:
            raise ValueError("Backend should be chosen from [remote, local]")

    def _find_latest_resource(self, resources, resource_type):
        # Scan the resources in reverse insertion order and return the most
        # recently added key whose type matches the requested one. Values may
        # carry a dotted prefix, so compare only the trailing type segment,
        # consistent with _check_task_decomposition and wrap_request.
        for key, val in reversed(list(resources.items())):
            if val.split(".")[-1] == resource_type:
                return key
        return None

    def _check_task_decomposition(
        self, task_decomposition: str | list, available_resources: dict
    ):
        """Validate and repair the decomposed subtasks against available resources.

        Any file-typed argument that refers to a missing resource, or to a
        resource of the wrong type, is rebound to the most recently produced
        resource of the required type. Returns the repaired decomposition as a
        JSON string, or None if no suitable resource can be found.
        """
        copy_task_decomposition = copy.deepcopy(task_decomposition)
        available_resources = copy.deepcopy(available_resources)
        if isinstance(copy_task_decomposition, str):
            copy_task_decomposition = json.loads(copy_task_decomposition)

        for subtask in copy_task_decomposition:
            for arg in subtask["args"]:
                if arg["type"] in NON_FILE_TYPES:
                    continue

                # Resource values may be dotted strings; the trailing segment
                # is the resource type.
                r_type = available_resources.get(arg["value"], "None").split(".")[-1]
                if arg["value"] not in available_resources or arg["type"] != r_type:
                    new_value = self._find_latest_resource(
                        available_resources, arg["type"]
                    )
                    if new_value is None:
                        logger.error(
                            f"No available resource for {arg['value']} with type {arg['type']}"
                        )
                        return None

                    arg["value"] = new_value

            # Outputs of this subtask become available to later subtasks.
            for ret in subtask["returns"]:
                available_resources[ret["value"]] = ret["type"]
        return json.dumps(copy_task_decomposition, indent=2, ensure_ascii=False)
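
    # Illustrative sketch of the subtask structure the check above assumes,
    # inferred from the fields it reads ("args" and "returns"); the
    # authoritative schema is whatever the task_decomposer service emits:
    #
    #   [
    #     {
    #       "args": [{"name": "image", "value": "input.png", "type": "image"}],
    #       "returns": [{"value": "caption_1", "type": "text"}],
    #       ...
    #     }
    #   ]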

    def wrap_request(self, request, memory):
        logger.info(memory)
        resource_list = {k: v.split(".")[-1] for k, v in memory.items()}
        request = f"Resource list: {resource_list}\n{request}"
        logger.info(f"Input: {request}")
        return request

    def solve_streaming(self, request: str, memory: dict = OrderedDict()):
        request = self.wrap_request(request, memory)
        sub_tasks = self.decomposer(request, streaming=self.streaming)
        logger.info(f"Task decomposition: \n{sub_tasks}")
        sub_tasks = self._check_task_decomposition(sub_tasks, memory)
        yield sub_tasks
        if sub_tasks in [None, "", []]:
            yield None
        else:
            solutions = self.tog(request, sub_tasks, streaming=self.streaming)
            yield solutions

    def solve(self, request: str, memory: dict = OrderedDict()):
        request = self.wrap_request(request, memory)
        sub_tasks = self.decomposer(request)
        solutions = self.tog(request, sub_tasks)
        logger.info(f"solutions: {solutions}")
        return sub_tasks, solutions

    def plan(self, task, memory: dict = OrderedDict()) -> List:
        if self.streaming:
            return self.solve_streaming(task, memory)
        else:
            return self.solve(task, memory)

    def _check_solutions(self, solution: List | str) -> bool:
        if isinstance(solution, str):
            solution = json.loads(solution)
        if len(solution) == 0:
            return False

        valid = True
        for i, stage_candidate in enumerate(solution):
            if len(stage_candidate) == 0:
                logger.error(f"No solution is found in {i + 1}-th subtask.")
                valid = False
            elif (
                "solution" in stage_candidate[0]
                and len(stage_candidate[0]["solution"]) == 0
            ):
                logger.error(f"No solution is found in {i + 1}-th subtask.")
                valid = False
            else:
                logger.info(f"Solutions for {i + 1}-th subtask:\n{stage_candidate}")
        return valid

    def parse(self, solution: List | str) -> List[List[Action]]:
        if isinstance(solution, str):
            solution = json.loads(solution)

        if not self._check_solutions(solution):
            return None

        if isinstance(solution[0], Action):
            return solution

        stages = []
        for stage_candidate in solution:
            # Use the top-ranked candidate solution for each subtask.
            stage = stage_candidate[0]["solution"]
            actions = []
            for action in stage:
                inputs = {arg["name"]: arg["value"] for arg in action["args"]}
                outputs = [r["value"] for r in action["returns"]]
                actions.append(
                    Action(action["tool_name"], inputs=inputs, outputs=outputs)
                )
            stages.append(actions)
        return stages
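
    # Illustrative sketch of the solution structure parse() assumes, inferred
    # from the fields accessed above (the authoritative schema comes from the
    # tog service): one list of ranked candidates per subtask, where each
    # candidate's "solution" is a list of tool calls. The tool name shown is
    # hypothetical.
    #
    #   [
    #     [
    #       {
    #         "solution": [
    #           {
    #             "tool_name": "image_captioning",
    #             "args": [{"name": "image", "value": "input.png", "type": "image"}],
    #             "returns": [{"value": "caption_1", "type": "text"}]
    #           }
    #         ]
    #       }
    #     ]
    #   ]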

    def __call__(
        self, request: str, memory: dict = OrderedDict()
    ) -> List[List[Action]]:
        # solve() returns (sub_tasks, solutions); only the solutions are parsed
        # into executable stages.
        _, solutions = self.solve(request, memory)
        return self.parse(solutions)
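

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only): it assumes the remote
    # tog/task_decomposer services are reachable and that `memory` maps
    # resource names to type strings (e.g. "input.png" -> "image"), as implied
    # by wrap_request(). The request text and resource name are placeholders.
    logging.basicConfig(level=logging.INFO)

    planner = Planner(streaming=False, backend="remote")
    memory = OrderedDict({"input.png": "image"})
    stages = planner("Describe the objects in input.png", memory)
    if stages is None:
        logger.error("Planning failed: no valid solution was returned.")
    else:
        for i, actions in enumerate(stages, start=1):
            logger.info(f"Stage {i}: {actions}")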