Spaces:
Sleeping
Sleeping
import logging | |
from traceback import print_exc | |
from typing import List, Dict | |
import os.path as osp | |
import io | |
import copy | |
import re | |
import uuid | |
from matplotlib.pyplot import isinteractive | |
from numpy import isin | |
import sys | |
import os | |
sys.path.append(os.getcwd()) | |
from cllm.agents.base import Action, DataType, Tool, NON_FILE_TYPES | |
from cllm.agents.builtin import TOOLS | |
from cllm.agents.container import auto_type | |
from cllm.utils import get_real_path, get_root_dir, transform_msgs | |
logger = logging.getLogger(__name__) | |
def code(source, type="py"):
    """Wrap *source* in a Markdown fenced code block.

    Args:
        source: Text placed between the fences.
        type: Language tag written right after the opening fence.

    Returns:
        The fenced block as a single string.
    """
    fence = "```"
    return "{}{}\n{}\n{}".format(fence, type, source, fence)
class Interpretor:
    """Run a staged plan of tool calls and thread results between the calls.

    Each action names a tool from ``self.tools``. Values produced by earlier
    actions are stored in a placeholder table ("memory") keyed by the output
    names of those actions; later actions refer to them through input strings
    containing ``<GENERATED>`` or ``<TOOL-GENERATED>``.
    """

    def __init__(self):
        self.tools = TOOLS
        self.non_file_types = NON_FILE_TYPES
        # Placeholder table of the most recently started interpret() run;
        # read by variables().
        self.memory = {}

    def interpret(
        self,
        stages: "List[List[Action]]",
        history_msgs: "List | None" = None,
    ):
        """Execute every action of ``stages`` in order, yielding after each one.

        Args:
            stages: Plan to run. It is deep-copied, so the caller's actions
                are never modified.
            history_msgs: Conversation history handed to every tool as the
                ``history_msgs`` keyword. ``None`` means an empty history.

        Yields:
            ``(executed_action, solution, wrapped_outputs)`` where
            ``executed_action`` is ``(tool_name, inputs, outputs)`` (on
            failure the third item is an error string instead), ``solution``
            is the working copy of the plan and ``wrapped_outputs`` is the
            list built by ``_update_output`` (empty on failure).

        Execution stops after the first action that fails.
        """
        memory = {}
        # Expose the table of this run so variables() has something to read.
        self.memory = memory
        solution = copy.deepcopy(stages)
        if history_msgs is None:
            history_msgs = []
        history_msgs = transform_msgs(copy.deepcopy(history_msgs))

        for actions in solution:
            for action in actions:
                tool = self.load_tool(name=action.tool_name)
                tool_inputs = self.load_args(tool, action.inputs, memory)
                tool_inputs["history_msgs"] = history_msgs
                tool_inputs["root_dir"] = get_root_dir()

                has_error = False
                try:
                    tool_outputs = tool.model(**tool_inputs)
                    action.inputs = self._update_inputs(memory, action.inputs)
                    action.outputs, wrapped_outputs = self._update_output(
                        memory, action, tool_outputs, tool
                    )
                except Exception as e:
                    print_exc()
                    logger.error(
                        "Error when executing %s: %s", action.tool_name, e
                    )
                    has_error = True
                    wrapped_outputs = []
                    executed_action = (
                        action.tool_name,
                        action.inputs,
                        self._describe_error(e),
                    )
                else:
                    logger.info(
                        "Call %s, args %s, return %s",
                        action.tool_name,
                        action.inputs,
                        action.outputs,
                    )
                    executed_action = (
                        action.tool_name,
                        action.inputs,
                        action.outputs,
                    )

                yield executed_action, solution, wrapped_outputs
                if has_error:
                    return

    @staticmethod
    def _describe_error(error):
        """Return the error string reported in place of an action's outputs."""
        if isinstance(error, FileNotFoundError):
            # ``filename`` is None when the exception was raised without a
            # path; basename(None) would raise inside the error handler.
            if error.filename is None:
                return f"FileNotFoundError: {error}"
            return (
                "FileNotFoundError: No such file or directory: "
                f"{osp.basename(error.filename)}"
            )
        return f"Internal error: {error}"

    @staticmethod
    def _is_placeholder(value):
        """Tell whether ``value`` is a string referring to a generated value."""
        return isinstance(value, str) and (
            "<GENERATED>" in value or "<TOOL-GENERATED>" in value
        )

    def _update_output(self, memory, action, tool_outputs, tool):
        """Record a tool's results in ``memory`` and prepare them for display.

        Args:
            memory: Placeholder table; updated in place under each output name.
            action: Executed action; ``action.outputs`` lists the output names
                and ``action.inputs`` is forwarded to ``transform_output``.
            tool_outputs: Raw return value of the tool. With exactly one
                declared output it is treated as that single value.
            tool: Tool description; ``tool.returns[i].type`` is the type of
                the i-th output.

        Returns:
            ``(outputs, wrapped_outputs)``: the stored values (file names for
            file-like types) and the corresponding ``auto_type`` results.
        """
        outputs = []
        wrapped_outputs = []
        if action.outputs is None:
            return outputs, wrapped_outputs

        if len(action.outputs) == 1:
            tool_outputs = [tool_outputs]

        for i, (arg_name, arg_value) in enumerate(
            zip(action.outputs, tool_outputs)
        ):
            # The raw value is stored first; for file-like types it is
            # replaced by the saved file name(s) further down.
            memory[arg_name] = arg_value

            if arg_value is None:
                outputs.append(None)
                wrapped_outputs.append(auto_type(arg_name, DataType.TEXT, None))
                continue

            if isinstance(arg_value, (dict, list)):
                arg_value = self.pretty_floats(arg_value)

            return_type = tool.returns[i].type
            if return_type in self.non_file_types:
                outputs.append(arg_value)
                wrapped_outputs.append(
                    auto_type(arg_name, return_type, arg_value)
                )
                continue

            transformed_output = self.transform_output(
                action.inputs,
                tool.name,
                tool.args,
                arg_value,
                return_type,
            )
            outputs.append(transformed_output)
            memory[arg_name] = transformed_output

            if not isinstance(transformed_output, list):
                wrapped_outputs.append(
                    auto_type(arg_name, return_type, transformed_output)
                )
                continue

            for output in transformed_output:
                # Entries are either file names or dicts whose "mask" key
                # holds the file name.
                path = output if isinstance(output, str) else output["mask"]
                wrapped_outputs.append(auto_type(arg_name, return_type, path))

        return outputs, wrapped_outputs

    def pretty_floats(self, obj):
        """Return ``obj`` with every float rounded to 4 decimal places.

        Dicts, lists and tuples are traversed recursively; tuples come back
        as lists. Any other value is returned unchanged.
        """
        if isinstance(obj, float):
            return round(obj, 4)
        if isinstance(obj, dict):
            return {k: self.pretty_floats(v) for k, v in obj.items()}
        if isinstance(obj, (list, tuple)):
            return [self.pretty_floats(item) for item in obj]
        return obj

    def _update_inputs(self, memory, action_inputs):
        """Return a copy of ``action_inputs`` with placeholders resolved.

        Placeholder strings found in ``memory`` are replaced by the stored
        value; unknown placeholders and non-string values are kept as they
        are. The argument itself is not modified.
        """
        resolved = copy.deepcopy(action_inputs)
        for key, value in resolved.items():
            if self._is_placeholder(value):
                resolved[key] = memory.get(value, value)
        return resolved

    def gen_filename(self, too_name, resource_type):
        """Create a random file name for a resource of ``resource_type``.

        Args:
            too_name: Name of the producing tool. Accepted for interface
                compatibility; it does not appear in the file name.
            resource_type: Selects the extension ("mp4" for VIDEO, "wav" for
                AUDIO, "html" for HTML, "png" otherwise) and the type tag.

        Returns:
            ``"<6 chars of a uuid4>_<type tag>.<ext>"``, where the type tag is
            the part of ``str(resource_type)`` after the last dot.
        """
        if resource_type == DataType.VIDEO:
            ext = "mp4"
        elif resource_type == DataType.AUDIO:
            ext = "wav"
        elif resource_type == DataType.HTML:
            ext = "html"
        else:
            ext = "png"

        this_file_id = str(uuid.uuid4())[:6]
        type_str = str(resource_type).split(".")[-1]
        return f"{this_file_id}_{type_str}.{ext}"

    def _save_resource(self, file_name, resource, resource_type):
        """Write ``resource`` to the real path of ``file_name``.

        A dict with a "mask" key is replaced by that entry first. HTML is
        written as UTF-8 text, everything else as bytes; binary streams are
        read first (and left open). Nothing is written when the resource is
        ``None``.
        """
        if isinstance(resource, dict) and "mask" in resource:
            resource = resource["mask"]

        # Checked before opening so a missing resource never leaves an empty
        # or truncated file behind.
        if resource is None:
            return None

        if resource_type == DataType.HTML:
            with open(get_real_path(file_name), "w", encoding="utf-8") as fout:
                fout.write(resource)
            return None

        if isinstance(resource, io.IOBase):
            resource = resource.read()
        with open(get_real_path(file_name), "wb") as fout:
            fout.write(resource)
        return None

    def _save_as_file(self, tool_name, resource, resource_type):
        """Save ``resource`` under a freshly generated name and return the name."""
        file_name = self.gen_filename(tool_name, resource_type)
        self._save_resource(file_name, resource, resource_type)
        return file_name

    def transform_output(
        self, action_inputs, tool_name, tool_args, tool_output, output_type
    ):
        """Persist a file-like tool output and return file name(s) for it.

        Args:
            action_inputs: Unused; kept for interface compatibility.
            tool_name: Forwarded to ``gen_filename``.
            tool_args: Unused; kept for interface compatibility.
            tool_output: Value returned by the tool.
            output_type: Declared type of the value.

        Returns:
            For non-MASK types: one file name, or a list of file names when
            ``tool_output`` is a list. For MASK: a deep copy of the list whose
            entries have "mask" replaced by a file name, a file name for
            ``bytes``, or ``None`` for ``None``.

        Raises:
            RuntimeError: A MASK output is neither a list, ``bytes`` nor
                ``None``.
        """
        if output_type != DataType.MASK:
            if isinstance(tool_output, list):
                return [
                    self._save_as_file(tool_name, output, output_type)
                    for output in tool_output
                ]
            return self._save_as_file(tool_name, tool_output, output_type)

        tool_output = copy.deepcopy(tool_output)
        if isinstance(tool_output, list):
            for output in tool_output:
                # Already a path, either directly or under "mask".
                if isinstance(output, str) or isinstance(output["mask"], str):
                    continue
                output["mask"] = self._save_as_file(
                    tool_name, output, output_type
                )
        elif isinstance(tool_output, bytes):
            tool_output = self._save_as_file(tool_name, tool_output, output_type)
        elif tool_output is not None:
            raise RuntimeError("Wrong type.")
        return tool_output

    def load_tool(self, name):
        """Return the tool registered under ``name`` (``KeyError`` if absent)."""
        return self.tools[name]

    def load_args(self, tool: "Tool", action_inputs, memory):
        """Build the keyword arguments for ``tool`` from ``action_inputs``.

        Args:
            tool: Tool description; one argument is produced per entry of
                ``tool.args`` (keyed by its ``name``).
            action_inputs: Mapping of argument name to literal or placeholder.
            memory: Placeholder table used to resolve placeholders.

        Returns:
            Mapping of argument name to the actual value.

        Raises:
            AssertionError: A placeholder is not present in ``memory``.
            KeyError: ``action_inputs`` lacks one of the tool's arguments.
        """
        real_args = {}
        for item in tool.args:
            arg_name = item.name
            arg_value = action_inputs[arg_name]
            if self._is_placeholder(arg_value):
                # Raised explicitly: an ``assert`` statement disappears under
                # ``python -O`` and its message here was always None.
                if arg_value not in memory:
                    raise AssertionError(f"Unknown {arg_name}: {arg_value}")
                arg_value = memory[arg_value]
            real_args[arg_name] = arg_value
        return real_args

    def variables(self):
        """Return the latest run's stored values, minus tool names and "print"."""
        return {
            k: v
            for k, v in self.memory.items()
            if k not in self.tools and k != "print"
        }