ragflow / agent /component /template.py
WANGRUI-ZB
Template conversion adds Jinja2 syntax support (#4545)
0e469cf
raw
history blame
3.94 kB
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import json
import logging
import re

from jinja2 import Template as Jinja2Template

from agent.component.base import ComponentBase, ComponentParamBase
class TemplateParam(ComponentParamBase):
    """
    Define the Template component parameters.
    """

    def __init__(self):
        super().__init__()
        # Template text; validated as non-empty by check().
        self.content = ""
        # Parameter descriptors. The Template component in this file reads each
        # entry's "key" and "component_id" fields.
        self.parameters = []

    def check(self):
        """Validate the parameters and return True.

        Delegates the emptiness test on ``content`` to the inherited
        ``check_empty`` helper (presumably it raises when the value is
        empty -- confirm in ComponentParamBase).
        """
        self.check_empty(self.content, "[Template] Content")
        return True
class Template(ComponentBase):
    """Fill the configured template text with values taken from other components.

    Each entry of ``self._param.parameters`` binds a ``key`` to a
    ``component_id``.  The bound values are rendered into
    ``self._param.content`` first through Jinja2 and then through the legacy
    ``{key}`` placeholder syntax.
    """

    component_name = "Template"

    def get_dependent_components(self):
        """Return the ids of the components this template reads from.

        References whose id contains "answer" or "begin" (case-insensitive)
        are left out.  For ``component@key`` references only the part before
        the first "@" is reported.  The ids are de-duplicated, so the order
        of the returned list is not defined.
        """
        dependents = set()
        for para in self._param.parameters:
            reference = para.get("component_id")
            if not reference:
                continue
            lowered = reference.lower()
            if "answer" in lowered or "begin" in lowered:
                continue
            dependents.add(reference.split("@")[0])
        return list(dependents)

    def _run(self, history, **kwargs):
        """Collect the parameter values, render the template and return it.

        Args:
            history: Not used by this component.
            **kwargs: Extra render variables.  Values bound through
                ``self._param.parameters`` are added to (and may override)
                these.

        Returns:
            Whatever ``Template.be_output`` returns for the rendered text
            (presumably inherited from ComponentBase -- confirm there).

        Raises:
            AssertionError: If a ``component@key`` reference names a key that
                is missing from that component's ``_param.query``.
        """
        content = self._param.content

        self._param.inputs = []
        for para in self._param.parameters:
            reference = para.get("component_id")
            if not reference:
                continue

            # "component@key" points at one query parameter of that component.
            # partition() splits on the first "@" only, so a key that itself
            # contains "@" no longer breaks tuple unpacking.
            component_id, has_key, key = reference.partition("@")
            if has_key:
                for query_item in self._canvas.get_component(component_id)["obj"]._param.query:
                    if query_item["key"] == key:
                        self.make_kwargs(para, kwargs, query_item.get("value", ""))
                        break
                else:
                    # Raised explicitly: an ``assert`` statement is removed
                    # under ``python -O`` and the miss would go unnoticed.
                    raise AssertionError(f"Can't find parameter '{key}' for {component_id}")
                continue

            cpn = self._canvas.get_component(component_id)["obj"]

            if cpn.component_name.lower() == "answer":
                # An "answer" component is bound to the first entry of
                # get_history(1), or "" when that history is empty.
                hist = self._canvas.get_history(1)
                self.make_kwargs(para, kwargs, hist[0]["content"] if hist else "")
                continue

            # The second item of output() is expected to expose ``columns``
            # and column indexing (looks like a pandas DataFrame -- confirm).
            _, out = cpn.output(allow_partial=False)
            result = ""
            if "content" in out.columns:
                result = "\n".join(o if isinstance(o, str) else str(o) for o in out["content"])
            self.make_kwargs(para, kwargs, result)

        # Jinja2 pass.  Parsing happens in the constructor, so it sits inside
        # the try as well: content that is not valid Jinja2 must fall through
        # to the legacy "{key}" substitution instead of aborting the run.
        try:
            content = Jinja2Template(content).render(kwargs)
        except Exception:
            logging.warning(
                "[Template] Jinja2 rendering failed; falling back to {key} substitution only.",
                exc_info=True,
            )

        # Legacy pass: replace every "{key}" with the JSON text of its value.
        for name, value in kwargs.items():
            try:
                text = json.dumps(value, ensure_ascii=False)
            except (TypeError, ValueError):
                # Not JSON-serialisable: re.sub still needs a string.
                text = value if isinstance(value, str) else str(value)

            placeholder = r"\{%s\}" % re.escape(name)
            try:
                # ``text`` is deliberately used as a replacement *template*:
                # that is what turns JSON escapes such as "\n" back into the
                # real characters.
                content = re.sub(placeholder, text, content)
            except re.error:
                # JSON escapes that ``re`` does not know (e.g. "\u001b" for a
                # control character) are rejected as a bad escape.  Insert the
                # value literally rather than failing the whole component.
                literal = value if isinstance(value, str) else text
                content = re.sub(placeholder, lambda _match: literal, content)

        # Drop every double quote (plain or backslash-escaped), which also
        # removes the quotes json.dumps put around string values.
        content = re.sub(r"(\\\"|\")", "", content)
        # Surround each run of "#" with spaces.
        content = re.sub(r"(#+)", r" \1 ", content)

        return Template.be_output(content)

    def make_kwargs(self, para, kwargs, value):
        """Record ``value`` as an input and bind it to ``para["key"]`` in ``kwargs``.

        The raw value is appended to ``self._param.inputs``.  The bound value
        is the JSON-decoded form when ``value`` parses as JSON (so "123"
        becomes the int 123), otherwise ``value`` unchanged.
        """
        self._param.inputs.append({"component_id": para["component_id"], "content": value})
        try:
            value = json.loads(value)
        except (TypeError, ValueError, RecursionError):
            # Best effort: text that is not JSON is passed through unchanged.
            pass
        kwargs[para["key"]] = value