Upload 163 files
Note: this view is limited to 50 files because it contains too many changes; see the raw diff for the full change set.
- api/core/agent/__init__.py +0 -0
- api/core/agent/base_agent_runner.py +539 -0
- api/core/agent/cot_agent_runner.py +437 -0
- api/core/agent/cot_chat_agent_runner.py +117 -0
- api/core/agent/cot_completion_agent_runner.py +88 -0
- api/core/agent/entities.py +82 -0
- api/core/agent/fc_agent_runner.py +473 -0
- api/core/agent/output_parser/cot_output_parser.py +208 -0
- api/core/agent/prompt/template.py +106 -0
- api/core/app/__init__.py +0 -0
- api/core/app/app_config/__init__.py +0 -0
- api/core/app/app_config/base_app_config_manager.py +49 -0
- api/core/app/app_config/common/__init__.py +0 -0
- api/core/app/app_config/common/sensitive_word_avoidance/__init__.py +0 -0
- api/core/app/app_config/common/sensitive_word_avoidance/manager.py +45 -0
- api/core/app/app_config/easy_ui_based_app/__init__.py +0 -0
- api/core/app/app_config/easy_ui_based_app/agent/__init__.py +0 -0
- api/core/app/app_config/easy_ui_based_app/agent/manager.py +81 -0
- api/core/app/app_config/easy_ui_based_app/dataset/__init__.py +0 -0
- api/core/app/app_config/easy_ui_based_app/dataset/manager.py +221 -0
- api/core/app/app_config/easy_ui_based_app/model_config/__init__.py +0 -0
- api/core/app/app_config/easy_ui_based_app/model_config/converter.py +87 -0
- api/core/app/app_config/easy_ui_based_app/model_config/manager.py +114 -0
- api/core/app/app_config/easy_ui_based_app/prompt_template/__init__.py +0 -0
- api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py +138 -0
- api/core/app/app_config/easy_ui_based_app/variables/__init__.py +0 -0
- api/core/app/app_config/easy_ui_based_app/variables/manager.py +168 -0
- api/core/app/app_config/entities.py +267 -0
- api/core/app/app_config/features/__init__.py +0 -0
- api/core/app/app_config/features/file_upload/__init__.py +0 -0
- api/core/app/app_config/features/file_upload/manager.py +44 -0
- api/core/app/app_config/features/more_like_this/__init__.py +0 -0
- api/core/app/app_config/features/more_like_this/manager.py +36 -0
- api/core/app/app_config/features/opening_statement/__init__.py +0 -0
- api/core/app/app_config/features/opening_statement/manager.py +41 -0
- api/core/app/app_config/features/retrieval_resource/__init__.py +0 -0
- api/core/app/app_config/features/retrieval_resource/manager.py +31 -0
- api/core/app/app_config/features/speech_to_text/__init__.py +0 -0
- api/core/app/app_config/features/speech_to_text/manager.py +36 -0
- api/core/app/app_config/features/suggested_questions_after_answer/__init__.py +0 -0
- api/core/app/app_config/features/suggested_questions_after_answer/manager.py +39 -0
- api/core/app/app_config/features/text_to_speech/__init__.py +0 -0
- api/core/app/app_config/features/text_to_speech/manager.py +45 -0
- api/core/app/app_config/workflow_ui_based_app/__init__.py +0 -0
- api/core/app/app_config/workflow_ui_based_app/variables/__init__.py +0 -0
- api/core/app/app_config/workflow_ui_based_app/variables/manager.py +22 -0
- api/core/app/apps/README.md +48 -0
- api/core/app/apps/__init__.py +0 -0
- api/core/app/apps/advanced_chat/__init__.py +0 -0
- api/core/app/apps/advanced_chat/app_config_manager.py +91 -0
api/core/agent/__init__.py
ADDED
File without changes
api/core/agent/base_agent_runner.py
ADDED
@@ -0,0 +1,539 @@
```python
import json
import logging
import uuid
from datetime import UTC, datetime
from typing import Optional, Union, cast

from core.agent.entities import AgentEntity, AgentToolEntity
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfig
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.apps.base_app_runner import AppRunner
from core.app.entities.app_invoke_entities import (
    AgentChatAppGenerateEntity,
    ModelConfigWithCredentialsEntity,
)
from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.file import file_manager
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities import (
    AssistantPromptMessage,
    LLMUsage,
    PromptMessage,
    PromptMessageContent,
    PromptMessageTool,
    SystemPromptMessage,
    TextPromptMessageContent,
    ToolPromptMessage,
    UserPromptMessage,
)
from core.model_runtime.entities.message_entities import ImagePromptMessageContent
from core.model_runtime.entities.model_entities import ModelFeature
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.utils.encoders import jsonable_encoder
from core.prompt.utils.extract_thread_messages import extract_thread_messages
from core.tools.entities.tool_entities import (
    ToolParameter,
    ToolRuntimeVariablePool,
)
from core.tools.tool.dataset_retriever_tool import DatasetRetrieverTool
from core.tools.tool.tool import Tool
from core.tools.tool_manager import ToolManager
from extensions.ext_database import db
from factories import file_factory
from models.model import Conversation, Message, MessageAgentThought, MessageFile
from models.tools import ToolConversationVariables

logger = logging.getLogger(__name__)


class BaseAgentRunner(AppRunner):
    def __init__(
        self,
        *,
        tenant_id: str,
        application_generate_entity: AgentChatAppGenerateEntity,
        conversation: Conversation,
        app_config: AgentChatAppConfig,
        model_config: ModelConfigWithCredentialsEntity,
        config: AgentEntity,
        queue_manager: AppQueueManager,
        message: Message,
        user_id: str,
        memory: Optional[TokenBufferMemory] = None,
        prompt_messages: Optional[list[PromptMessage]] = None,
        variables_pool: Optional[ToolRuntimeVariablePool] = None,
        db_variables: Optional[ToolConversationVariables] = None,
        model_instance: ModelInstance,
    ) -> None:
        self.tenant_id = tenant_id
        self.application_generate_entity = application_generate_entity
        self.conversation = conversation
        self.app_config = app_config
        self.model_config = model_config
        self.config = config
        self.queue_manager = queue_manager
        self.message = message
        self.user_id = user_id
        self.memory = memory
        self.history_prompt_messages = self.organize_agent_history(prompt_messages=prompt_messages or [])
        self.variables_pool = variables_pool
        self.db_variables_pool = db_variables
        self.model_instance = model_instance

        # init callback
        self.agent_callback = DifyAgentCallbackHandler()
        # init dataset tools
        hit_callback = DatasetIndexToolCallbackHandler(
            queue_manager=queue_manager,
            app_id=self.app_config.app_id,
            message_id=message.id,
            user_id=user_id,
            invoke_from=self.application_generate_entity.invoke_from,
        )
        self.dataset_tools = DatasetRetrieverTool.get_dataset_tools(
            tenant_id=tenant_id,
            dataset_ids=app_config.dataset.dataset_ids if app_config.dataset else [],
            retrieve_config=app_config.dataset.retrieve_config if app_config.dataset else None,
            return_resource=app_config.additional_features.show_retrieve_source,
            invoke_from=application_generate_entity.invoke_from,
            hit_callback=hit_callback,
        )
        # get how many agent thoughts have been created
        self.agent_thought_count = (
            db.session.query(MessageAgentThought)
            .filter(
                MessageAgentThought.message_id == self.message.id,
            )
            .count()
        )
        db.session.close()

        # check if model supports stream tool call
        llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
        model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials)
        features = model_schema.features if model_schema and model_schema.features else []
        self.stream_tool_call = ModelFeature.STREAM_TOOL_CALL in features
        self.files = application_generate_entity.files if ModelFeature.VISION in features else []
        self.query: Optional[str] = ""
        self._current_thoughts: list[PromptMessage] = []

    def _repack_app_generate_entity(
        self, app_generate_entity: AgentChatAppGenerateEntity
    ) -> AgentChatAppGenerateEntity:
        """
        Repack app generate entity
        """
        if app_generate_entity.app_config.prompt_template.simple_prompt_template is None:
            app_generate_entity.app_config.prompt_template.simple_prompt_template = ""

        return app_generate_entity

    def _convert_tool_to_prompt_message_tool(self, tool: AgentToolEntity) -> tuple[PromptMessageTool, Tool]:
        """
        convert tool to prompt message tool
        """
        tool_entity = ToolManager.get_agent_tool_runtime(
            tenant_id=self.tenant_id,
            app_id=self.app_config.app_id,
            agent_tool=tool,
            invoke_from=self.application_generate_entity.invoke_from,
        )
        tool_entity.load_variables(self.variables_pool)

        message_tool = PromptMessageTool(
            name=tool.tool_name,
            description=tool_entity.description.llm if tool_entity.description else "",
            parameters={
                "type": "object",
                "properties": {},
                "required": [],
            },
        )

        parameters = tool_entity.get_all_runtime_parameters()
        for parameter in parameters:
            if parameter.form != ToolParameter.ToolParameterForm.LLM:
                continue

            parameter_type = parameter.type.as_normal_type()
            if parameter.type in {
                ToolParameter.ToolParameterType.SYSTEM_FILES,
                ToolParameter.ToolParameterType.FILE,
                ToolParameter.ToolParameterType.FILES,
            }:
                continue
            enum = []
            if parameter.type == ToolParameter.ToolParameterType.SELECT:
                enum = [option.value for option in parameter.options] if parameter.options else []

            message_tool.parameters["properties"][parameter.name] = {
                "type": parameter_type,
                "description": parameter.llm_description or "",
            }

            if len(enum) > 0:
                message_tool.parameters["properties"][parameter.name]["enum"] = enum

            if parameter.required:
                message_tool.parameters["required"].append(parameter.name)

        return message_tool, tool_entity

    def _convert_dataset_retriever_tool_to_prompt_message_tool(self, tool: DatasetRetrieverTool) -> PromptMessageTool:
        """
        convert dataset retriever tool to prompt message tool
        """
        prompt_tool = PromptMessageTool(
            name=tool.identity.name if tool.identity else "unknown",
            description=tool.description.llm if tool.description else "",
            parameters={
                "type": "object",
                "properties": {},
                "required": [],
            },
        )

        for parameter in tool.get_runtime_parameters():
            parameter_type = "string"

            prompt_tool.parameters["properties"][parameter.name] = {
                "type": parameter_type,
                "description": parameter.llm_description or "",
            }

            if parameter.required:
                if parameter.name not in prompt_tool.parameters["required"]:
                    prompt_tool.parameters["required"].append(parameter.name)

        return prompt_tool

    def _init_prompt_tools(self) -> tuple[dict[str, Tool], list[PromptMessageTool]]:
        """
        Init tools
        """
        tool_instances = {}
        prompt_messages_tools = []

        for tool in self.app_config.agent.tools or [] if self.app_config.agent else []:
            try:
                prompt_tool, tool_entity = self._convert_tool_to_prompt_message_tool(tool)
            except Exception:
                # api tool may be deleted
                continue
            # save tool entity
            tool_instances[tool.tool_name] = tool_entity
            # save prompt tool
            prompt_messages_tools.append(prompt_tool)

        # convert dataset tools into ModelRuntime Tool format
        for dataset_tool in self.dataset_tools:
            prompt_tool = self._convert_dataset_retriever_tool_to_prompt_message_tool(dataset_tool)
            # save prompt tool
            prompt_messages_tools.append(prompt_tool)
            # save tool entity
            if dataset_tool.identity is not None:
                tool_instances[dataset_tool.identity.name] = dataset_tool

        return tool_instances, prompt_messages_tools

    def update_prompt_message_tool(self, tool: Tool, prompt_tool: PromptMessageTool) -> PromptMessageTool:
        """
        update prompt message tool
        """
        # try to get tool runtime parameters
        tool_runtime_parameters = tool.get_runtime_parameters()

        for parameter in tool_runtime_parameters:
            if parameter.form != ToolParameter.ToolParameterForm.LLM:
                continue

            parameter_type = parameter.type.as_normal_type()
            if parameter.type in {
                ToolParameter.ToolParameterType.SYSTEM_FILES,
                ToolParameter.ToolParameterType.FILE,
                ToolParameter.ToolParameterType.FILES,
            }:
                continue
            enum = []
            if parameter.type == ToolParameter.ToolParameterType.SELECT:
                enum = [option.value for option in parameter.options] if parameter.options else []

            prompt_tool.parameters["properties"][parameter.name] = {
                "type": parameter_type,
                "description": parameter.llm_description or "",
            }

            if len(enum) > 0:
                prompt_tool.parameters["properties"][parameter.name]["enum"] = enum

            if parameter.required:
                if parameter.name not in prompt_tool.parameters["required"]:
                    prompt_tool.parameters["required"].append(parameter.name)

        return prompt_tool

    def create_agent_thought(
        self, message_id: str, message: str, tool_name: str, tool_input: str, messages_ids: list[str]
    ) -> MessageAgentThought:
        """
        Create agent thought
        """
        thought = MessageAgentThought(
            message_id=message_id,
            message_chain_id=None,
            thought="",
            tool=tool_name,
            tool_labels_str="{}",
            tool_meta_str="{}",
            tool_input=tool_input,
            message=message,
            message_token=0,
            message_unit_price=0,
            message_price_unit=0,
            message_files=json.dumps(messages_ids) if messages_ids else "",
            answer="",
            observation="",
            answer_token=0,
            answer_unit_price=0,
            answer_price_unit=0,
            tokens=0,
            total_price=0,
            position=self.agent_thought_count + 1,
            currency="USD",
            latency=0,
            created_by_role="account",
            created_by=self.user_id,
        )

        db.session.add(thought)
        db.session.commit()
        db.session.refresh(thought)
        db.session.close()

        self.agent_thought_count += 1

        return thought

    def save_agent_thought(
        self,
        agent_thought: MessageAgentThought,
        tool_name: str,
        tool_input: Union[str, dict],
        thought: str,
        observation: Union[str, dict, None],
        tool_invoke_meta: Union[str, dict, None],
        answer: str,
        messages_ids: list[str],
        llm_usage: LLMUsage | None = None,
    ):
        """
        Save agent thought
        """
        queried_thought = (
            db.session.query(MessageAgentThought).filter(MessageAgentThought.id == agent_thought.id).first()
        )
        if not queried_thought:
            raise ValueError(f"Agent thought {agent_thought.id} not found")
        agent_thought = queried_thought

        if thought:
            agent_thought.thought = thought

        if tool_name:
            agent_thought.tool = tool_name

        if tool_input:
            if isinstance(tool_input, dict):
                try:
                    tool_input = json.dumps(tool_input, ensure_ascii=False)
                except Exception as e:
                    tool_input = json.dumps(tool_input)

            agent_thought.tool_input = tool_input

        if observation:
            if isinstance(observation, dict):
                try:
                    observation = json.dumps(observation, ensure_ascii=False)
                except Exception as e:
                    observation = json.dumps(observation)

            agent_thought.observation = observation

        if answer:
            agent_thought.answer = answer

        if messages_ids is not None and len(messages_ids) > 0:
            agent_thought.message_files = json.dumps(messages_ids)

        if llm_usage:
            agent_thought.message_token = llm_usage.prompt_tokens
            agent_thought.message_price_unit = llm_usage.prompt_price_unit
            agent_thought.message_unit_price = llm_usage.prompt_unit_price
            agent_thought.answer_token = llm_usage.completion_tokens
            agent_thought.answer_price_unit = llm_usage.completion_price_unit
            agent_thought.answer_unit_price = llm_usage.completion_unit_price
            agent_thought.tokens = llm_usage.total_tokens
            agent_thought.total_price = llm_usage.total_price

        # check if tool labels is not empty
        labels = agent_thought.tool_labels or {}
        tools = agent_thought.tool.split(";") if agent_thought.tool else []
        for tool in tools:
            if not tool:
                continue
            if tool not in labels:
                tool_label = ToolManager.get_tool_label(tool)
                if tool_label:
                    labels[tool] = tool_label.to_dict()
                else:
                    labels[tool] = {"en_US": tool, "zh_Hans": tool}

        agent_thought.tool_labels_str = json.dumps(labels)

        if tool_invoke_meta is not None:
            if isinstance(tool_invoke_meta, dict):
                try:
                    tool_invoke_meta = json.dumps(tool_invoke_meta, ensure_ascii=False)
                except Exception as e:
                    tool_invoke_meta = json.dumps(tool_invoke_meta)

            agent_thought.tool_meta_str = tool_invoke_meta

        db.session.commit()
        db.session.close()

    def update_db_variables(self, tool_variables: ToolRuntimeVariablePool, db_variables: ToolConversationVariables):
        """
        convert tool variables to db variables
        """
        queried_variables = (
            db.session.query(ToolConversationVariables)
            .filter(
                ToolConversationVariables.conversation_id == self.message.conversation_id,
            )
            .first()
        )

        if not queried_variables:
            return

        db_variables = queried_variables

        db_variables.updated_at = datetime.now(UTC).replace(tzinfo=None)
        db_variables.variables_str = json.dumps(jsonable_encoder(tool_variables.pool))
        db.session.commit()
        db.session.close()

    def organize_agent_history(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
        """
        Organize agent history
        """
        result: list[PromptMessage] = []
        # check if there is a system message in the beginning of the conversation
        for prompt_message in prompt_messages:
            if isinstance(prompt_message, SystemPromptMessage):
                result.append(prompt_message)

        messages: list[Message] = (
            db.session.query(Message)
            .filter(
                Message.conversation_id == self.message.conversation_id,
            )
            .order_by(Message.created_at.desc())
            .all()
        )

        messages = list(reversed(extract_thread_messages(messages)))

        for message in messages:
            if message.id == self.message.id:
                continue

            result.append(self.organize_agent_user_prompt(message))
            agent_thoughts: list[MessageAgentThought] = message.agent_thoughts
            if agent_thoughts:
                for agent_thought in agent_thoughts:
                    tools = agent_thought.tool
                    if tools:
                        tools = tools.split(";")
                        tool_calls: list[AssistantPromptMessage.ToolCall] = []
                        tool_call_response: list[ToolPromptMessage] = []
                        try:
                            tool_inputs = json.loads(agent_thought.tool_input)
                        except Exception as e:
                            tool_inputs = {tool: {} for tool in tools}
                        try:
                            tool_responses = json.loads(agent_thought.observation)
                        except Exception as e:
                            tool_responses = dict.fromkeys(tools, agent_thought.observation)

                        for tool in tools:
                            # generate a uuid for tool call
                            tool_call_id = str(uuid.uuid4())
                            tool_calls.append(
                                AssistantPromptMessage.ToolCall(
                                    id=tool_call_id,
                                    type="function",
                                    function=AssistantPromptMessage.ToolCall.ToolCallFunction(
                                        name=tool,
                                        arguments=json.dumps(tool_inputs.get(tool, {})),
                                    ),
                                )
                            )
                            tool_call_response.append(
                                ToolPromptMessage(
                                    content=tool_responses.get(tool, agent_thought.observation),
                                    name=tool,
                                    tool_call_id=tool_call_id,
                                )
                            )

                        result.extend(
                            [
                                AssistantPromptMessage(
                                    content=agent_thought.thought,
                                    tool_calls=tool_calls,
                                ),
                                *tool_call_response,
                            ]
                        )
                    if not tools:
                        result.append(AssistantPromptMessage(content=agent_thought.thought))
            else:
                if message.answer:
                    result.append(AssistantPromptMessage(content=message.answer))

        db.session.close()

        return result

    def organize_agent_user_prompt(self, message: Message) -> UserPromptMessage:
        files = db.session.query(MessageFile).filter(MessageFile.message_id == message.id).all()
        if not files:
            return UserPromptMessage(content=message.query)
        file_extra_config = FileUploadConfigManager.convert(message.app_model_config.to_dict())
        if not file_extra_config:
            return UserPromptMessage(content=message.query)

        image_detail_config = file_extra_config.image_config.detail if file_extra_config.image_config else None
        image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW

        file_objs = file_factory.build_from_message_files(
            message_files=files, tenant_id=self.tenant_id, config=file_extra_config
        )
        if not file_objs:
            return UserPromptMessage(content=message.query)
        prompt_message_contents: list[PromptMessageContent] = []
        prompt_message_contents.append(TextPromptMessageContent(data=message.query))
        for file in file_objs:
            prompt_message_contents.append(
                file_manager.to_prompt_message_content(
                    file,
                    image_detail_config=image_detail_config,
                )
            )
        return UserPromptMessage(content=prompt_message_contents)
```
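As context for the conversion above: `_convert_tool_to_prompt_message_tool` and `update_prompt_message_tool` both emit a JSON-Schema-style function spec for the model. A minimal, self-contained sketch of the same pattern — the `ToolParam` dataclass and `to_prompt_tool` helper are hypothetical illustrations, not part of this upload:

```python
from dataclasses import dataclass, field

@dataclass
class ToolParam:
    # hypothetical stand-in for ToolParameter; not part of this codebase
    name: str
    type: str
    description: str = ""
    required: bool = False
    options: list[str] = field(default_factory=list)

def to_prompt_tool(name: str, description: str, params: list[ToolParam]) -> dict:
    """Build the same {"type": "object", ...} schema shape the runner produces."""
    spec = {
        "name": name,
        "description": description,
        "parameters": {"type": "object", "properties": {}, "required": []},
    }
    for p in params:
        prop = {"type": p.type, "description": p.description}
        if p.options:  # SELECT-style parameters become a JSON-Schema enum
            prop["enum"] = p.options
        spec["parameters"]["properties"][p.name] = prop
        if p.required:
            spec["parameters"]["required"].append(p.name)
    return spec

print(to_prompt_tool("get_weather", "Look up current weather", [
    ToolParam("city", "string", "City name", required=True),
    ToolParam("unit", "string", options=["celsius", "fahrenheit"]),
]))
```

Note that the runner rebuilds this spec after every tool invocation (the `update_prompt_message_tool` loop), since a tool's runtime parameters can change between calls.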
api/core/agent/cot_agent_runner.py
ADDED
@@ -0,0 +1,437 @@
```python
import json
from abc import ABC, abstractmethod
from collections.abc import Generator, Mapping
from typing import Any, Optional

from core.agent.base_agent_runner import BaseAgentRunner
from core.agent.entities import AgentScratchpadUnit
from core.agent.output_parser.cot_output_parser import CotAgentOutputParser
from core.app.apps.base_app_queue_manager import PublishFrom
from core.app.entities.queue_entities import QueueAgentThoughtEvent, QueueMessageEndEvent, QueueMessageFileEvent
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
from core.model_runtime.entities.message_entities import (
    AssistantPromptMessage,
    PromptMessage,
    PromptMessageTool,
    ToolPromptMessage,
    UserPromptMessage,
)
from core.ops.ops_trace_manager import TraceQueueManager
from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
from core.tools.entities.tool_entities import ToolInvokeMeta
from core.tools.tool.tool import Tool
from core.tools.tool_engine import ToolEngine
from models.model import Message


class CotAgentRunner(BaseAgentRunner, ABC):
    _is_first_iteration = True
    _ignore_observation_providers = ["wenxin"]
    _historic_prompt_messages: list[PromptMessage] | None = None
    _agent_scratchpad: list[AgentScratchpadUnit] | None = None
    _instruction: str = ""  # FIXME this must be str for now
    _query: str | None = None
    _prompt_messages_tools: list[PromptMessageTool] = []

    def run(
        self,
        message: Message,
        query: str,
        inputs: Mapping[str, str],
    ) -> Generator:
        """
        Run Cot agent application
        """
        app_generate_entity = self.application_generate_entity
        self._repack_app_generate_entity(app_generate_entity)
        self._init_react_state(query)

        trace_manager = app_generate_entity.trace_manager

        # check model mode
        if "Observation" not in app_generate_entity.model_conf.stop:
            if app_generate_entity.model_conf.provider not in self._ignore_observation_providers:
                app_generate_entity.model_conf.stop.append("Observation")

        app_config = self.app_config

        # init instruction
        inputs = inputs or {}
        instruction = app_config.prompt_template.simple_prompt_template
        self._instruction = self._fill_in_inputs_from_external_data_tools(instruction=instruction or "", inputs=inputs)

        iteration_step = 1
        max_iteration_steps = min(app_config.agent.max_iteration if app_config.agent else 5, 5) + 1

        # convert tools into ModelRuntime Tool format
        tool_instances, self._prompt_messages_tools = self._init_prompt_tools()

        function_call_state = True
        llm_usage: dict[str, Optional[LLMUsage]] = {"usage": None}
        final_answer = ""

        def increase_usage(final_llm_usage_dict: dict[str, Optional[LLMUsage]], usage: LLMUsage):
            if not final_llm_usage_dict["usage"]:
                final_llm_usage_dict["usage"] = usage
            else:
                llm_usage = final_llm_usage_dict["usage"]
                llm_usage.prompt_tokens += usage.prompt_tokens
                llm_usage.completion_tokens += usage.completion_tokens
                llm_usage.prompt_price += usage.prompt_price
                llm_usage.completion_price += usage.completion_price
                llm_usage.total_price += usage.total_price

        model_instance = self.model_instance

        while function_call_state and iteration_step <= max_iteration_steps:
            # continue to run until there is not any tool call
            function_call_state = False

            if iteration_step == max_iteration_steps:
                # the last iteration, remove all tools
                self._prompt_messages_tools = []

            message_file_ids: list[str] = []

            agent_thought = self.create_agent_thought(
                message_id=message.id, message="", tool_name="", tool_input="", messages_ids=message_file_ids
            )

            if iteration_step > 1:
                self.queue_manager.publish(
                    QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
                )

            # recalc llm max tokens
            prompt_messages = self._organize_prompt_messages()
            self.recalc_llm_max_tokens(self.model_config, prompt_messages)
            # invoke model
            chunks = model_instance.invoke_llm(
                prompt_messages=prompt_messages,
                model_parameters=app_generate_entity.model_conf.parameters,
                tools=[],
                stop=app_generate_entity.model_conf.stop,
                stream=True,
                user=self.user_id,
                callbacks=[],
            )

            if not isinstance(chunks, Generator):
                raise ValueError("Expected streaming response from LLM")

            # check llm result
            if not chunks:
                raise ValueError("failed to invoke llm")

            usage_dict: dict[str, Optional[LLMUsage]] = {"usage": None}
            react_chunks = CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict)
            scratchpad = AgentScratchpadUnit(
                agent_response="",
                thought="",
                action_str="",
                observation="",
                action=None,
            )

            # publish agent thought if it's first iteration
            if iteration_step == 1:
                self.queue_manager.publish(
                    QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
                )

            for chunk in react_chunks:
                if isinstance(chunk, AgentScratchpadUnit.Action):
                    action = chunk
                    # detect action
                    if scratchpad.agent_response is not None:
                        scratchpad.agent_response += json.dumps(chunk.model_dump())
                    scratchpad.action_str = json.dumps(chunk.model_dump())
                    scratchpad.action = action
                else:
                    if scratchpad.agent_response is not None:
                        scratchpad.agent_response += chunk
                    if scratchpad.thought is not None:
                        scratchpad.thought += chunk
                    yield LLMResultChunk(
                        model=self.model_config.model,
                        prompt_messages=prompt_messages,
                        system_fingerprint="",
                        delta=LLMResultChunkDelta(index=0, message=AssistantPromptMessage(content=chunk), usage=None),
                    )
            if scratchpad.thought is not None:
                scratchpad.thought = scratchpad.thought.strip() or "I am thinking about how to help you"
            if self._agent_scratchpad is not None:
                self._agent_scratchpad.append(scratchpad)

            # get llm usage
            if "usage" in usage_dict:
                if usage_dict["usage"] is not None:
                    increase_usage(llm_usage, usage_dict["usage"])
            else:
                usage_dict["usage"] = LLMUsage.empty_usage()

            self.save_agent_thought(
                agent_thought=agent_thought,
                tool_name=(scratchpad.action.action_name if scratchpad.action and not scratchpad.is_final() else ""),
                tool_input={scratchpad.action.action_name: scratchpad.action.action_input} if scratchpad.action else {},
                tool_invoke_meta={},
                thought=scratchpad.thought or "",
                observation="",
                answer=scratchpad.agent_response or "",
                messages_ids=[],
                llm_usage=usage_dict["usage"],
            )

            if not scratchpad.is_final():
                self.queue_manager.publish(
                    QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
                )

            if not scratchpad.action:
                # failed to extract action, return final answer directly
                final_answer = ""
            else:
                if scratchpad.action.action_name.lower() == "final answer":
                    # action is final answer, return final answer directly
                    try:
                        if isinstance(scratchpad.action.action_input, dict):
                            final_answer = json.dumps(scratchpad.action.action_input)
                        elif isinstance(scratchpad.action.action_input, str):
                            final_answer = scratchpad.action.action_input
                        else:
                            final_answer = f"{scratchpad.action.action_input}"
                    except json.JSONDecodeError:
                        final_answer = f"{scratchpad.action.action_input}"
                else:
                    function_call_state = True
                    # action is tool call, invoke tool
                    tool_invoke_response, tool_invoke_meta = self._handle_invoke_action(
                        action=scratchpad.action,
                        tool_instances=tool_instances,
                        message_file_ids=message_file_ids,
                        trace_manager=trace_manager,
                    )
                    scratchpad.observation = tool_invoke_response
                    scratchpad.agent_response = tool_invoke_response

                    self.save_agent_thought(
                        agent_thought=agent_thought,
                        tool_name=scratchpad.action.action_name,
                        tool_input={scratchpad.action.action_name: scratchpad.action.action_input},
                        thought=scratchpad.thought or "",
                        observation={scratchpad.action.action_name: tool_invoke_response},
                        tool_invoke_meta={scratchpad.action.action_name: tool_invoke_meta.to_dict()},
                        answer=scratchpad.agent_response,
                        messages_ids=message_file_ids,
                        llm_usage=usage_dict["usage"],
                    )

                    self.queue_manager.publish(
                        QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
                    )

                # update prompt tool message
                for prompt_tool in self._prompt_messages_tools:
                    self.update_prompt_message_tool(tool_instances[prompt_tool.name], prompt_tool)

            iteration_step += 1

        yield LLMResultChunk(
            model=model_instance.model,
            prompt_messages=prompt_messages,
            delta=LLMResultChunkDelta(
                index=0, message=AssistantPromptMessage(content=final_answer), usage=llm_usage["usage"]
            ),
            system_fingerprint="",
        )

        # save agent thought
        self.save_agent_thought(
            agent_thought=agent_thought,
            tool_name="",
            tool_input={},
            tool_invoke_meta={},
            thought=final_answer,
            observation={},
            answer=final_answer,
            messages_ids=[],
        )
        if self.variables_pool is not None and self.db_variables_pool is not None:
            self.update_db_variables(self.variables_pool, self.db_variables_pool)
        # publish end event
        self.queue_manager.publish(
            QueueMessageEndEvent(
                llm_result=LLMResult(
                    model=model_instance.model,
                    prompt_messages=prompt_messages,
                    message=AssistantPromptMessage(content=final_answer),
                    usage=llm_usage["usage"] or LLMUsage.empty_usage(),
                    system_fingerprint="",
                )
            ),
            PublishFrom.APPLICATION_MANAGER,
        )

    def _handle_invoke_action(
        self,
        action: AgentScratchpadUnit.Action,
        tool_instances: dict[str, Tool],
        message_file_ids: list[str],
        trace_manager: Optional[TraceQueueManager] = None,
    ) -> tuple[str, ToolInvokeMeta]:
        """
        handle invoke action
        :param action: action
        :param tool_instances: tool instances
        :param message_file_ids: message file ids
        :param trace_manager: trace manager
        :return: observation, meta
        """
        # action is tool call, invoke tool
        tool_call_name = action.action_name
        tool_call_args = action.action_input
        tool_instance = tool_instances.get(tool_call_name)

        if not tool_instance:
            answer = f"there is not a tool named {tool_call_name}"
            return answer, ToolInvokeMeta.error_instance(answer)

        if isinstance(tool_call_args, str):
            try:
                tool_call_args = json.loads(tool_call_args)
            except json.JSONDecodeError:
                pass

        # invoke tool
        tool_invoke_response, message_files, tool_invoke_meta = ToolEngine.agent_invoke(
            tool=tool_instance,
            tool_parameters=tool_call_args,
            user_id=self.user_id,
            tenant_id=self.tenant_id,
            message=self.message,
            invoke_from=self.application_generate_entity.invoke_from,
            agent_tool_callback=self.agent_callback,
            trace_manager=trace_manager,
        )

        # publish files
        for message_file_id, save_as in message_files:
            if save_as is not None and self.variables_pool:
                # FIXME the save_as type is confusing, it should be a string or not
                self.variables_pool.set_file(tool_name=tool_call_name, value=message_file_id, name=str(save_as))

            # publish message file
            self.queue_manager.publish(
                QueueMessageFileEvent(message_file_id=message_file_id), PublishFrom.APPLICATION_MANAGER
            )
            # add message file ids
            message_file_ids.append(message_file_id)

        return tool_invoke_response, tool_invoke_meta

    def _convert_dict_to_action(self, action: dict) -> AgentScratchpadUnit.Action:
        """
        convert dict to action
        """
        return AgentScratchpadUnit.Action(action_name=action["action"], action_input=action["action_input"])

    def _fill_in_inputs_from_external_data_tools(self, instruction: str, inputs: Mapping[str, Any]) -> str:
        """
        fill in inputs from external data tools
        """
        for key, value in inputs.items():
            try:
                instruction = instruction.replace(f"{{{{{key}}}}}", str(value))
            except Exception as e:
                continue

        return instruction

    def _init_react_state(self, query) -> None:
        """
        init agent scratchpad
        """
        self._query = query
        self._agent_scratchpad = []
        self._historic_prompt_messages = self._organize_historic_prompt_messages()

    @abstractmethod
    def _organize_prompt_messages(self) -> list[PromptMessage]:
        """
        organize prompt messages
        """

    def _format_assistant_message(self, agent_scratchpad: list[AgentScratchpadUnit]) -> str:
        """
        format assistant message
        """
        message = ""
        for scratchpad in agent_scratchpad:
            if scratchpad.is_final():
                message += f"Final Answer: {scratchpad.agent_response}"
            else:
                message += f"Thought: {scratchpad.thought}\n\n"
                if scratchpad.action_str:
                    message += f"Action: {scratchpad.action_str}\n\n"
                if scratchpad.observation:
                    message += f"Observation: {scratchpad.observation}\n\n"

        return message

    def _organize_historic_prompt_messages(
        self, current_session_messages: Optional[list[PromptMessage]] = None
    ) -> list[PromptMessage]:
        """
        organize historic prompt messages
        """
        result: list[PromptMessage] = []
        scratchpads: list[AgentScratchpadUnit] = []
        current_scratchpad: AgentScratchpadUnit | None = None

        for message in self.history_prompt_messages:
            if isinstance(message, AssistantPromptMessage):
                if not current_scratchpad:
                    if not isinstance(message.content, str | None):
                        raise NotImplementedError("expected str type")
                    current_scratchpad = AgentScratchpadUnit(
                        agent_response=message.content,
                        thought=message.content or "I am thinking about how to help you",
                        action_str="",
                        action=None,
                        observation=None,
                    )
                    scratchpads.append(current_scratchpad)
                if message.tool_calls:
                    try:
                        current_scratchpad.action = AgentScratchpadUnit.Action(
                            action_name=message.tool_calls[0].function.name,
                            action_input=json.loads(message.tool_calls[0].function.arguments),
                        )
                        current_scratchpad.action_str = json.dumps(current_scratchpad.action.to_dict())
                    except:
                        pass
            elif isinstance(message, ToolPromptMessage):
                if not current_scratchpad:
                    continue
                if isinstance(message.content, str):
                    current_scratchpad.observation = message.content
                else:
                    raise NotImplementedError("expected str type")
            elif isinstance(message, UserPromptMessage):
                if scratchpads:
                    result.append(AssistantPromptMessage(content=self._format_assistant_message(scratchpads)))
                    scratchpads = []
                    current_scratchpad = None

                result.append(message)

        if scratchpads:
            result.append(AssistantPromptMessage(content=self._format_assistant_message(scratchpads)))

        historic_prompts = AgentHistoryPromptTransform(
            model_config=self.model_config,
            prompt_messages=current_session_messages or [],
            history_messages=result,
            memory=self.memory,
        ).get_prompt()
        return historic_prompts
```
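For readers skimming `run` above, the loop's contract is easy to miss: it repeats until the output parser stops producing tool actions or the iteration budget is spent, and the final pass strips all tools so the model has to answer. A rough, self-contained sketch of that control flow, where `fake_llm_step` is a made-up stand-in for the model plus `CotAgentOutputParser`:

```python
def fake_llm_step(tools: list[str], step: int):
    # Pretend the model calls a tool twice, then answers.
    if tools and step < 3:
        return ("action", "current_time", {})
    return ("final", "It is noon.", None)

def react_loop(tools: list[str], max_iteration: int = 5) -> str:
    # Mirrors: max_iteration_steps = min(max_iteration, 5) + 1
    max_steps = min(max_iteration, 5) + 1
    step, keep_going, final_answer = 1, True, ""
    while keep_going and step <= max_steps:
        keep_going = False                 # stays False unless a tool is called
        if step == max_steps:
            tools = []                     # last iteration: remove all tools
        kind, payload, args = fake_llm_step(tools, step)
        if kind == "action":
            keep_going = True              # observed a tool call, loop again
            _observation = f"invoked {payload} with {args}"
        else:
            final_answer = payload
        step += 1
    return final_answer

print(react_loop(["current_time"]))
```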
api/core/agent/cot_chat_agent_runner.py
ADDED
@@ -0,0 +1,117 @@
```python
import json

from core.agent.cot_agent_runner import CotAgentRunner
from core.file import file_manager
from core.model_runtime.entities import (
    AssistantPromptMessage,
    PromptMessage,
    PromptMessageContent,
    SystemPromptMessage,
    TextPromptMessageContent,
    UserPromptMessage,
)
from core.model_runtime.entities.message_entities import ImagePromptMessageContent
from core.model_runtime.utils.encoders import jsonable_encoder


class CotChatAgentRunner(CotAgentRunner):
    def _organize_system_prompt(self) -> SystemPromptMessage:
        """
        Organize system prompt
        """
        if not self.app_config.agent:
            raise ValueError("Agent configuration is not set")

        prompt_entity = self.app_config.agent.prompt
        if not prompt_entity:
            raise ValueError("Agent prompt configuration is not set")
        first_prompt = prompt_entity.first_prompt

        system_prompt = (
            first_prompt.replace("{{instruction}}", self._instruction)
            .replace("{{tools}}", json.dumps(jsonable_encoder(self._prompt_messages_tools)))
            .replace("{{tool_names}}", ", ".join([tool.name for tool in self._prompt_messages_tools]))
        )

        return SystemPromptMessage(content=system_prompt)

    def _organize_user_query(self, query, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
        """
        Organize user query
        """
        if self.files:
            prompt_message_contents: list[PromptMessageContent] = []
            prompt_message_contents.append(TextPromptMessageContent(data=query))

            # get image detail config
            image_detail_config = (
                self.application_generate_entity.file_upload_config.image_config.detail
                if (
                    self.application_generate_entity.file_upload_config
                    and self.application_generate_entity.file_upload_config.image_config
                )
                else None
            )
            image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW
            for file in self.files:
                prompt_message_contents.append(
                    file_manager.to_prompt_message_content(
                        file,
                        image_detail_config=image_detail_config,
                    )
                )

            prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
        else:
            prompt_messages.append(UserPromptMessage(content=query))

        return prompt_messages

    def _organize_prompt_messages(self) -> list[PromptMessage]:
        """
        Organize
        """
        # organize system prompt
        system_message = self._organize_system_prompt()

        # organize current assistant messages
        agent_scratchpad = self._agent_scratchpad
        if not agent_scratchpad:
            assistant_messages = []
        else:
            assistant_message = AssistantPromptMessage(content="")
            assistant_message.content = ""  # FIXME: type check tell mypy that assistant_message.content is str
            for unit in agent_scratchpad:
                if unit.is_final():
                    assistant_message.content += f"Final Answer: {unit.agent_response}"
                else:
                    assistant_message.content += f"Thought: {unit.thought}\n\n"
                    if unit.action_str:
                        assistant_message.content += f"Action: {unit.action_str}\n\n"
                    if unit.observation:
                        assistant_message.content += f"Observation: {unit.observation}\n\n"

            assistant_messages = [assistant_message]

        # query messages
        query_messages = self._organize_user_query(self._query, [])

        if assistant_messages:
            # organize historic prompt messages
            historic_messages = self._organize_historic_prompt_messages(
                [system_message, *query_messages, *assistant_messages, UserPromptMessage(content="continue")]
            )
            messages = [
                system_message,
                *historic_messages,
                *query_messages,
                *assistant_messages,
                UserPromptMessage(content="continue"),
            ]
        else:
            # organize historic prompt messages
            historic_messages = self._organize_historic_prompt_messages([system_message, *query_messages])
            messages = [system_message, *historic_messages, *query_messages]

        # join all messages
        return messages
```
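The scratchpad flattening in `_organize_prompt_messages` above mirrors `_format_assistant_message` in the base runner: each unit becomes a `Thought:`/`Action:`/`Observation:` block, and a final unit becomes `Final Answer:`. A runnable toy version, where `Unit` is a hypothetical stand-in for `AgentScratchpadUnit`:

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class Unit:
    # hypothetical stand-in for AgentScratchpadUnit
    thought: str = ""
    action_str: str = ""
    observation: str = ""
    final: Optional[str] = None  # set when the unit carries a Final Answer

def format_scratchpad(units: list[Unit]) -> str:
    content = ""
    for u in units:
        if u.final is not None:
            content += f"Final Answer: {u.final}"
        else:
            content += f"Thought: {u.thought}\n\n"
            if u.action_str:
                content += f"Action: {u.action_str}\n\n"
            if u.observation:
                content += f"Observation: {u.observation}\n\n"
    return content

print(format_scratchpad([
    Unit(thought="Need the time", action_str='{"action": "current_time"}', observation="12:00"),
    Unit(final="It is noon."),
]))
```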
api/core/agent/cot_completion_agent_runner.py
ADDED
@@ -0,0 +1,88 @@
```python
import json
from typing import Optional

from core.agent.cot_agent_runner import CotAgentRunner
from core.model_runtime.entities.message_entities import (
    AssistantPromptMessage,
    PromptMessage,
    TextPromptMessageContent,
    UserPromptMessage,
)
from core.model_runtime.utils.encoders import jsonable_encoder


class CotCompletionAgentRunner(CotAgentRunner):
    def _organize_instruction_prompt(self) -> str:
        """
        Organize instruction prompt
        """
        if self.app_config.agent is None:
            raise ValueError("Agent configuration is not set")
        prompt_entity = self.app_config.agent.prompt
        if prompt_entity is None:
            raise ValueError("prompt entity is not set")
        first_prompt = prompt_entity.first_prompt

        system_prompt = (
            first_prompt.replace("{{instruction}}", self._instruction)
            .replace("{{tools}}", json.dumps(jsonable_encoder(self._prompt_messages_tools)))
            .replace("{{tool_names}}", ", ".join([tool.name for tool in self._prompt_messages_tools]))
        )

        return system_prompt

    def _organize_historic_prompt(self, current_session_messages: Optional[list[PromptMessage]] = None) -> str:
        """
        Organize historic prompt
        """
        historic_prompt_messages = self._organize_historic_prompt_messages(current_session_messages)
        historic_prompt = ""

        for message in historic_prompt_messages:
            if isinstance(message, UserPromptMessage):
                historic_prompt += f"Question: {message.content}\n\n"
            elif isinstance(message, AssistantPromptMessage):
                if isinstance(message.content, str):
                    historic_prompt += message.content + "\n\n"
                elif isinstance(message.content, list):
                    for content in message.content:
                        if not isinstance(content, TextPromptMessageContent):
                            continue
                        historic_prompt += content.data

        return historic_prompt

    def _organize_prompt_messages(self) -> list[PromptMessage]:
        """
        Organize prompt messages
        """
        # organize system prompt
        system_prompt = self._organize_instruction_prompt()

        # organize historic prompt messages
        historic_prompt = self._organize_historic_prompt()

        # organize current assistant messages
        agent_scratchpad = self._agent_scratchpad
        assistant_prompt = ""
        for unit in agent_scratchpad or []:
            if unit.is_final():
                assistant_prompt += f"Final Answer: {unit.agent_response}"
            else:
                assistant_prompt += f"Thought: {unit.thought}\n\n"
                if unit.action_str:
                    assistant_prompt += f"Action: {unit.action_str}\n\n"
                if unit.observation:
                    assistant_prompt += f"Observation: {unit.observation}\n\n"

        # query messages
        query_prompt = f"Question: {self._query}"

        # join all messages
        prompt = (
            system_prompt.replace("{{historic_messages}}", historic_prompt)
            .replace("{{agent_scratchpad}}", assistant_prompt)
            .replace("{{query}}", query_prompt)
        )

        return [UserPromptMessage(content=prompt)]
```
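Unlike the chat variant, the completion runner renders everything into a single text prompt via `{{...}}` placeholder substitution. A self-contained sketch of that rendering step — `TEMPLATE` here is an invented example, not the repository's actual agent prompt from `api/core/agent/prompt/template.py`:

```python
# Hypothetical template for illustration only.
TEMPLATE = """{{instruction}}

{{historic_messages}}{{query}}
{{agent_scratchpad}}"""

def render(instruction: str, history: str, scratchpad: str, query: str) -> str:
    # Same plain str.replace chaining as _organize_prompt_messages above.
    return (TEMPLATE.replace("{{instruction}}", instruction)
                    .replace("{{historic_messages}}", history)
                    .replace("{{agent_scratchpad}}", scratchpad)
                    .replace("{{query}}", f"Question: {query}"))

print(render("You are a helpful agent.",
             "Question: hi\n\nHello!\n\n",
             "Thought: greet back\n\n",
             "what time is it?"))
```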
api/core/agent/entities.py
ADDED
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
+from enum import Enum
+from typing import Any, Literal, Optional, Union
+
+from pydantic import BaseModel
+
+
+class AgentToolEntity(BaseModel):
+    """
+    Agent Tool Entity.
+    """
+
+    provider_type: Literal["builtin", "api", "workflow"]
+    provider_id: str
+    tool_name: str
+    tool_parameters: dict[str, Any] = {}
+
+
+class AgentPromptEntity(BaseModel):
+    """
+    Agent Prompt Entity.
+    """
+
+    first_prompt: str
+    next_iteration: str
+
+
+class AgentScratchpadUnit(BaseModel):
+    """
+    Agent Scratchpad Unit.
+    """
+
+    class Action(BaseModel):
+        """
+        Action Entity.
+        """
+
+        action_name: str
+        action_input: Union[dict, str]
+
+        def to_dict(self) -> dict:
+            """
+            Convert to dictionary.
+            """
+            return {
+                "action": self.action_name,
+                "action_input": self.action_input,
+            }
+
+    agent_response: Optional[str] = None
+    thought: Optional[str] = None
+    action_str: Optional[str] = None
+    observation: Optional[str] = None
+    action: Optional[Action] = None
+
+    def is_final(self) -> bool:
+        """
+        Check if the scratchpad unit is final.
+        """
+        return self.action is None or (
+            "final" in self.action.action_name.lower() and "answer" in self.action.action_name.lower()
+        )
+
+
+class AgentEntity(BaseModel):
+    """
+    Agent Entity.
+    """
+
+    class Strategy(Enum):
+        """
+        Agent Strategy.
+        """
+
+        CHAIN_OF_THOUGHT = "chain-of-thought"
+        FUNCTION_CALLING = "function-calling"
+
+    provider: str
+    model: str
+    strategy: Strategy
+    prompt: Optional[AgentPromptEntity] = None
+    tools: list[AgentToolEntity] | None = None
+    max_iteration: int = 5
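A small hedged sketch of the `is_final()` contract above: a "Final Answer" action (or a missing action) terminates the ReAct loop, while any other action name keeps it going. The tool names and inputs are invented.

```python
from core.agent.entities import AgentScratchpadUnit

final_unit = AgentScratchpadUnit(
    agent_response="Done.",
    action=AgentScratchpadUnit.Action(action_name="Final Answer", action_input="Done."),
)
tool_unit = AgentScratchpadUnit(
    thought="Need a search.",
    action=AgentScratchpadUnit.Action(action_name="web_search", action_input={"q": "dify"}),
)

assert final_unit.is_final()     # "final" and "answer" both appear in the action name
assert not tool_unit.is_final()  # a regular tool call keeps the loop going
assert tool_unit.action is not None and tool_unit.action.to_dict() == {
    "action": "web_search",
    "action_input": {"q": "dify"},
}
```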
api/core/agent/fc_agent_runner.py
ADDED
@@ -0,0 +1,473 @@
+import json
+import logging
+from collections.abc import Generator
+from copy import deepcopy
+from typing import Any, Optional, Union
+
+from core.agent.base_agent_runner import BaseAgentRunner
+from core.app.apps.base_app_queue_manager import PublishFrom
+from core.app.entities.queue_entities import QueueAgentThoughtEvent, QueueMessageEndEvent, QueueMessageFileEvent
+from core.file import file_manager
+from core.model_runtime.entities import (
+    AssistantPromptMessage,
+    LLMResult,
+    LLMResultChunk,
+    LLMResultChunkDelta,
+    LLMUsage,
+    PromptMessage,
+    PromptMessageContent,
+    PromptMessageContentType,
+    SystemPromptMessage,
+    TextPromptMessageContent,
+    ToolPromptMessage,
+    UserPromptMessage,
+)
+from core.model_runtime.entities.message_entities import ImagePromptMessageContent
+from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
+from core.tools.entities.tool_entities import ToolInvokeMeta
+from core.tools.tool_engine import ToolEngine
+from models.model import Message
+
+logger = logging.getLogger(__name__)
+
+
+class FunctionCallAgentRunner(BaseAgentRunner):
+    def run(self, message: Message, query: str, **kwargs: Any) -> Generator[LLMResultChunk, None, None]:
+        """
+        Run FunctionCall agent application
+        """
+        self.query = query
+        app_generate_entity = self.application_generate_entity
+
+        app_config = self.app_config
+        assert app_config is not None, "app_config is required"
+        assert app_config.agent is not None, "app_config.agent is required"
+
+        # convert tools into ModelRuntime Tool format
+        tool_instances, prompt_messages_tools = self._init_prompt_tools()
+
+        iteration_step = 1
+        max_iteration_steps = min(app_config.agent.max_iteration, 5) + 1
+
+        # continue to run until there is no tool call left
+        function_call_state = True
+        llm_usage: dict[str, LLMUsage] = {"usage": LLMUsage.empty_usage()}
+        final_answer = ""
+
+        # get tracing instance
+        trace_manager = app_generate_entity.trace_manager
+
+        def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage):
+            if not final_llm_usage_dict["usage"]:
+                final_llm_usage_dict["usage"] = usage
+            else:
+                llm_usage = final_llm_usage_dict["usage"]
+                llm_usage.prompt_tokens += usage.prompt_tokens
+                llm_usage.completion_tokens += usage.completion_tokens
+                llm_usage.prompt_price += usage.prompt_price
+                llm_usage.completion_price += usage.completion_price
+                llm_usage.total_price += usage.total_price
+
+        model_instance = self.model_instance
+
+        while function_call_state and iteration_step <= max_iteration_steps:
+            function_call_state = False
+
+            if iteration_step == max_iteration_steps:
+                # the last iteration, remove all tools
+                prompt_messages_tools = []
+
+            message_file_ids: list[str] = []
+            agent_thought = self.create_agent_thought(
+                message_id=message.id, message="", tool_name="", tool_input="", messages_ids=message_file_ids
+            )
+
+            # recalc llm max tokens
+            prompt_messages = self._organize_prompt_messages()
+            self.recalc_llm_max_tokens(self.model_config, prompt_messages)
+            # invoke model
+            chunks: Union[Generator[LLMResultChunk, None, None], LLMResult] = model_instance.invoke_llm(
+                prompt_messages=prompt_messages,
+                model_parameters=app_generate_entity.model_conf.parameters,
+                tools=prompt_messages_tools,
+                stop=app_generate_entity.model_conf.stop,
+                stream=self.stream_tool_call,
+                user=self.user_id,
+                callbacks=[],
+            )
+
+            tool_calls: list[tuple[str, str, dict[str, Any]]] = []
+
+            # save full response
+            response = ""
+
+            # save tool call names and inputs
+            tool_call_names = ""
+            tool_call_inputs = ""
+
+            current_llm_usage = None
+
+            if self.stream_tool_call and isinstance(chunks, Generator):
+                is_first_chunk = True
+                for chunk in chunks:
+                    if is_first_chunk:
+                        self.queue_manager.publish(
+                            QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
+                        )
+                        is_first_chunk = False
+                    # check if there is any tool call
+                    if self.check_tool_calls(chunk):
+                        function_call_state = True
+                        tool_calls.extend(self.extract_tool_calls(chunk) or [])
+                        tool_call_names = ";".join([tool_call[1] for tool_call in tool_calls])
+                        try:
+                            tool_call_inputs = json.dumps(
+                                {tool_call[1]: tool_call[2] for tool_call in tool_calls}, ensure_ascii=False
+                            )
+                        except json.JSONDecodeError:
+                            # ensure ascii to avoid encoding error
+                            tool_call_inputs = json.dumps({tool_call[1]: tool_call[2] for tool_call in tool_calls})
+
+                    if chunk.delta.message and chunk.delta.message.content:
+                        if isinstance(chunk.delta.message.content, list):
+                            for content in chunk.delta.message.content:
+                                response += content.data
+                        else:
+                            response += str(chunk.delta.message.content)
+
+                    if chunk.delta.usage:
+                        increase_usage(llm_usage, chunk.delta.usage)
+                        current_llm_usage = chunk.delta.usage
+
+                    yield chunk
+            elif not self.stream_tool_call and isinstance(chunks, LLMResult):
+                result = chunks
+                # check if there is any tool call
+                if self.check_blocking_tool_calls(result):
+                    function_call_state = True
+                    tool_calls.extend(self.extract_blocking_tool_calls(result) or [])
+                    tool_call_names = ";".join([tool_call[1] for tool_call in tool_calls])
+                    try:
+                        tool_call_inputs = json.dumps(
+                            {tool_call[1]: tool_call[2] for tool_call in tool_calls}, ensure_ascii=False
+                        )
+                    except json.JSONDecodeError:
+                        # ensure ascii to avoid encoding error
+                        tool_call_inputs = json.dumps({tool_call[1]: tool_call[2] for tool_call in tool_calls})
+
+                if result.usage:
+                    increase_usage(llm_usage, result.usage)
+                    current_llm_usage = result.usage
+
+                if result.message and result.message.content:
+                    if isinstance(result.message.content, list):
+                        for content in result.message.content:
+                            response += content.data
+                    else:
+                        response += str(result.message.content)
+
+                if not result.message.content:
+                    result.message.content = ""
+
+                self.queue_manager.publish(
+                    QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
+                )
+
+                yield LLMResultChunk(
+                    model=model_instance.model,
+                    prompt_messages=result.prompt_messages,
+                    system_fingerprint=result.system_fingerprint,
+                    delta=LLMResultChunkDelta(
+                        index=0,
+                        message=result.message,
+                        usage=result.usage,
+                    ),
+                )
+            else:
+                raise RuntimeError(f"invalid chunks type: {type(chunks)}")
+
+            assistant_message = AssistantPromptMessage(content="", tool_calls=[])
+            if tool_calls:
+                assistant_message.tool_calls = [
+                    AssistantPromptMessage.ToolCall(
+                        id=tool_call[0],
+                        type="function",
+                        function=AssistantPromptMessage.ToolCall.ToolCallFunction(
+                            name=tool_call[1], arguments=json.dumps(tool_call[2], ensure_ascii=False)
+                        ),
+                    )
+                    for tool_call in tool_calls
+                ]
+            else:
+                assistant_message.content = response
+
+            self._current_thoughts.append(assistant_message)
+
+            # save thought
+            self.save_agent_thought(
+                agent_thought=agent_thought,
+                tool_name=tool_call_names,
+                tool_input=tool_call_inputs,
+                thought=response,
+                tool_invoke_meta=None,
+                observation=None,
+                answer=response,
+                messages_ids=[],
+                llm_usage=current_llm_usage,
+            )
+            self.queue_manager.publish(
+                QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
+            )
+
+            final_answer += response + "\n"
+
+            # call tools
+            tool_responses = []
+            for tool_call_id, tool_call_name, tool_call_args in tool_calls:
+                tool_instance = tool_instances.get(tool_call_name)
+                if not tool_instance:
+                    tool_response = {
+                        "tool_call_id": tool_call_id,
+                        "tool_call_name": tool_call_name,
+                        "tool_response": f"there is not a tool named {tool_call_name}",
+                        "meta": ToolInvokeMeta.error_instance(f"there is not a tool named {tool_call_name}").to_dict(),
+                    }
+                else:
+                    # invoke tool
+                    tool_invoke_response, message_files, tool_invoke_meta = ToolEngine.agent_invoke(
+                        tool=tool_instance,
+                        tool_parameters=tool_call_args,
+                        user_id=self.user_id,
+                        tenant_id=self.tenant_id,
+                        message=self.message,
+                        invoke_from=self.application_generate_entity.invoke_from,
+                        agent_tool_callback=self.agent_callback,
+                        trace_manager=trace_manager,
+                    )
+                    # publish files
+                    for message_file_id, save_as in message_files:
+                        if save_as:
+                            if self.variables_pool:
+                                self.variables_pool.set_file(
+                                    tool_name=tool_call_name, value=message_file_id, name=save_as
+                                )
+
+                        # publish message file
+                        self.queue_manager.publish(
+                            QueueMessageFileEvent(message_file_id=message_file_id), PublishFrom.APPLICATION_MANAGER
+                        )
+                        # add message file ids
+                        message_file_ids.append(message_file_id)
+
+                    tool_response = {
+                        "tool_call_id": tool_call_id,
+                        "tool_call_name": tool_call_name,
+                        "tool_response": tool_invoke_response,
+                        "meta": tool_invoke_meta.to_dict(),
+                    }
+
+                tool_responses.append(tool_response)
+                if tool_response["tool_response"] is not None:
+                    self._current_thoughts.append(
+                        ToolPromptMessage(
+                            content=str(tool_response["tool_response"]),
+                            tool_call_id=tool_call_id,
+                            name=tool_call_name,
+                        )
+                    )
+
+            if len(tool_responses) > 0:
+                # save agent thought
+                self.save_agent_thought(
+                    agent_thought=agent_thought,
+                    tool_name="",
+                    tool_input="",
+                    thought="",
+                    tool_invoke_meta={
+                        tool_response["tool_call_name"]: tool_response["meta"] for tool_response in tool_responses
+                    },
+                    observation={
+                        tool_response["tool_call_name"]: tool_response["tool_response"]
+                        for tool_response in tool_responses
+                    },
+                    answer="",
+                    messages_ids=message_file_ids,
+                )
+                self.queue_manager.publish(
+                    QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
+                )
+
+            # update prompt tool
+            for prompt_tool in prompt_messages_tools:
+                self.update_prompt_message_tool(tool_instances[prompt_tool.name], prompt_tool)
+
+            iteration_step += 1
+
+        if self.variables_pool and self.db_variables_pool:
+            self.update_db_variables(self.variables_pool, self.db_variables_pool)
+        # publish end event
+        self.queue_manager.publish(
+            QueueMessageEndEvent(
+                llm_result=LLMResult(
+                    model=model_instance.model,
+                    prompt_messages=prompt_messages,
+                    message=AssistantPromptMessage(content=final_answer),
+                    usage=llm_usage["usage"] or LLMUsage.empty_usage(),
+                    system_fingerprint="",
+                )
+            ),
+            PublishFrom.APPLICATION_MANAGER,
+        )
+
+    def check_tool_calls(self, llm_result_chunk: LLMResultChunk) -> bool:
+        """
+        Check if there is any tool call in llm result chunk
+        """
+        if llm_result_chunk.delta.message.tool_calls:
+            return True
+        return False
+
+    def check_blocking_tool_calls(self, llm_result: LLMResult) -> bool:
+        """
+        Check if there is any blocking tool call in llm result
+        """
+        if llm_result.message.tool_calls:
+            return True
+        return False
+
+    def extract_tool_calls(
+        self, llm_result_chunk: LLMResultChunk
+    ) -> Union[None, list[tuple[str, str, dict[str, Any]]]]:
+        """
+        Extract tool calls from llm result chunk
+
+        Returns:
+            List[Tuple[str, str, Dict[str, Any]]]: [(tool_call_id, tool_call_name, tool_call_args)]
+        """
+        tool_calls = []
+        for prompt_message in llm_result_chunk.delta.message.tool_calls:
+            args = {}
+            if prompt_message.function.arguments != "":
+                args = json.loads(prompt_message.function.arguments)
+
+            tool_calls.append(
+                (
+                    prompt_message.id,
+                    prompt_message.function.name,
+                    args,
+                )
+            )
+
+        return tool_calls
+
+    def extract_blocking_tool_calls(self, llm_result: LLMResult) -> Union[None, list[tuple[str, str, dict[str, Any]]]]:
+        """
+        Extract blocking tool calls from llm result
+
+        Returns:
+            List[Tuple[str, str, Dict[str, Any]]]: [(tool_call_id, tool_call_name, tool_call_args)]
+        """
+        tool_calls = []
+        for prompt_message in llm_result.message.tool_calls:
+            args = {}
+            if prompt_message.function.arguments != "":
+                args = json.loads(prompt_message.function.arguments)
+
+            tool_calls.append(
+                (
+                    prompt_message.id,
+                    prompt_message.function.name,
+                    args,
+                )
+            )
+
+        return tool_calls
+
+    def _init_system_message(
+        self, prompt_template: str, prompt_messages: Optional[list[PromptMessage]] = None
+    ) -> list[PromptMessage]:
+        """
+        Initialize system message
+        """
+        if not prompt_messages and prompt_template:
+            return [
+                SystemPromptMessage(content=prompt_template),
+            ]
+
+        if prompt_messages and not isinstance(prompt_messages[0], SystemPromptMessage) and prompt_template:
+            prompt_messages.insert(0, SystemPromptMessage(content=prompt_template))
+
+        return prompt_messages or []
+
+    def _organize_user_query(self, query: str, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
+        """
+        Organize user query
+        """
+        if self.files:
+            prompt_message_contents: list[PromptMessageContent] = []
+            prompt_message_contents.append(TextPromptMessageContent(data=query))
+
+            # get image detail config
+            image_detail_config = (
+                self.application_generate_entity.file_upload_config.image_config.detail
+                if (
+                    self.application_generate_entity.file_upload_config
+                    and self.application_generate_entity.file_upload_config.image_config
+                )
+                else None
+            )
+            image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW
+            for file in self.files:
+                prompt_message_contents.append(
+                    file_manager.to_prompt_message_content(
+                        file,
+                        image_detail_config=image_detail_config,
+                    )
+                )
+
+            prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
+        else:
+            prompt_messages.append(UserPromptMessage(content=query))
+
+        return prompt_messages
+
+    def _clear_user_prompt_image_messages(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
+        """
+        As for now, gpt supports both fc and vision at the first iteration.
+        We need to remove the image messages from the prompt messages at the first iteration.
+        """
+        prompt_messages = deepcopy(prompt_messages)
+
+        for prompt_message in prompt_messages:
+            if isinstance(prompt_message, UserPromptMessage):
+                if isinstance(prompt_message.content, list):
+                    prompt_message.content = "\n".join(
+                        [
+                            content.data
+                            if content.type == PromptMessageContentType.TEXT
+                            else "[image]"
+                            if content.type == PromptMessageContentType.IMAGE
+                            else "[file]"
+                            for content in prompt_message.content
+                        ]
+                    )
+
+        return prompt_messages
+
+    def _organize_prompt_messages(self):
+        prompt_template = self.app_config.prompt_template.simple_prompt_template or ""
+        self.history_prompt_messages = self._init_system_message(prompt_template, self.history_prompt_messages)
+        query_prompt_messages = self._organize_user_query(self.query or "", [])
+
+        self.history_prompt_messages = AgentHistoryPromptTransform(
+            model_config=self.model_config,
+            prompt_messages=[*query_prompt_messages, *self._current_thoughts],
+            history_messages=self.history_prompt_messages,
+            memory=self.memory,
+        ).get_prompt()
+
+        prompt_messages = [*self.history_prompt_messages, *query_prompt_messages, *self._current_thoughts]
+        if len(self._current_thoughts) != 0:
+            # clear messages after the first iteration
+            prompt_messages = self._clear_user_prompt_image_messages(prompt_messages)
+        return prompt_messages
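A hedged note on the runner above: `extract_tool_calls` and `extract_blocking_tool_calls` normalize provider tool calls into `(tool_call_id, tool_call_name, tool_call_args)` triples, which the tool-invocation loop then consumes. The sketch below (not part of the upload) mirrors that extraction with stand-in objects carrying only the fields the method reads; the id and tool name are invented.

```python
import json
from types import SimpleNamespace

# Stand-in mirroring only the attributes the extraction loop reads from a chunk.
call = SimpleNamespace(
    id="call_0",
    function=SimpleNamespace(name="weather_tool", arguments=json.dumps({"city": "Berlin"})),
)
chunk = SimpleNamespace(delta=SimpleNamespace(message=SimpleNamespace(tool_calls=[call])))

tool_calls = []
for prompt_message in chunk.delta.message.tool_calls:
    args = json.loads(prompt_message.function.arguments) if prompt_message.function.arguments else {}
    tool_calls.append((prompt_message.id, prompt_message.function.name, args))

assert tool_calls == [("call_0", "weather_tool", {"city": "Berlin"})]
```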
api/core/agent/output_parser/cot_output_parser.py
ADDED
@@ -0,0 +1,208 @@
+import json
+import re
+from collections.abc import Generator
+from typing import Union
+
+from core.agent.entities import AgentScratchpadUnit
+from core.model_runtime.entities.llm_entities import LLMResultChunk
+
+
+class CotAgentOutputParser:
+    @classmethod
+    def handle_react_stream_output(
+        cls, llm_response: Generator[LLMResultChunk, None, None], usage_dict: dict
+    ) -> Generator[Union[str, AgentScratchpadUnit.Action], None, None]:
+        def parse_action(json_str):
+            try:
+                action = json.loads(json_str, strict=False)
+                action_name = None
+                action_input = None
+
+                # cohere always returns a list
+                if isinstance(action, list) and len(action) == 1:
+                    action = action[0]
+
+                for key, value in action.items():
+                    if "input" in key.lower():
+                        action_input = value
+                    else:
+                        action_name = value
+
+                if action_name is not None and action_input is not None:
+                    return AgentScratchpadUnit.Action(
+                        action_name=action_name,
+                        action_input=action_input,
+                    )
+                else:
+                    return json_str or ""
+            except Exception:
+                return json_str or ""
+
+        def extract_json_from_code_block(code_block) -> Generator[Union[str, AgentScratchpadUnit.Action], None, None]:
+            code_blocks = re.findall(r"```(.*?)```", code_block, re.DOTALL)
+            if not code_blocks:
+                return
+            for block in code_blocks:
+                # drop a leading language tag such as "json" before parsing
+                json_text = re.sub(r"^[a-zA-Z]+\n", "", block.strip(), flags=re.MULTILINE)
+                yield parse_action(json_text)
+
+        # streaming state: backtick-fenced code blocks, bare JSON objects, and
+        # "action:" / "thought:" prefixes are each tracked with a small cache
+        code_block_cache = ""
+        code_block_delimiter_count = 0
+        in_code_block = False
+        json_cache = ""
+        json_quote_count = 0
+        in_json = False
+        got_json = False
+
+        action_cache = ""
+        action_str = "action:"
+        action_idx = 0
+
+        thought_cache = ""
+        thought_str = "thought:"
+        thought_idx = 0
+
+        last_character = ""
+
+        for response in llm_response:
+            if response.delta.usage:
+                usage_dict["usage"] = response.delta.usage
+            response_content = response.delta.message.content
+            if not isinstance(response_content, str):
+                continue
+
+            # stream
+            index = 0
+            while index < len(response_content):
+                steps = 1
+                delta = response_content[index : index + steps]
+                yield_delta = False
+
+                if delta == "`":
+                    last_character = delta
+                    code_block_cache += delta
+                    code_block_delimiter_count += 1
+                else:
+                    if not in_code_block:
+                        if code_block_delimiter_count > 0:
+                            last_character = delta
+                            yield code_block_cache
+                            code_block_cache = ""
+                    else:
+                        last_character = delta
+                        code_block_cache += delta
+                    code_block_delimiter_count = 0
+
+                if not in_code_block and not in_json:
+                    if delta.lower() == action_str[action_idx] and action_idx == 0:
+                        if last_character not in {"\n", " ", ""}:
+                            yield_delta = True
+                        else:
+                            last_character = delta
+                            action_cache += delta
+                            action_idx += 1
+                            if action_idx == len(action_str):
+                                action_cache = ""
+                                action_idx = 0
+                            index += steps
+                            continue
+                    elif delta.lower() == action_str[action_idx] and action_idx > 0:
+                        last_character = delta
+                        action_cache += delta
+                        action_idx += 1
+                        if action_idx == len(action_str):
+                            action_cache = ""
+                            action_idx = 0
+                        index += steps
+                        continue
+                    else:
+                        if action_cache:
+                            last_character = delta
+                            yield action_cache
+                            action_cache = ""
+                            action_idx = 0
+
+                    if delta.lower() == thought_str[thought_idx] and thought_idx == 0:
+                        if last_character not in {"\n", " ", ""}:
+                            yield_delta = True
+                        else:
+                            last_character = delta
+                            thought_cache += delta
+                            thought_idx += 1
+                            if thought_idx == len(thought_str):
+                                thought_cache = ""
+                                thought_idx = 0
+                            index += steps
+                            continue
+                    elif delta.lower() == thought_str[thought_idx] and thought_idx > 0:
+                        last_character = delta
+                        thought_cache += delta
+                        thought_idx += 1
+                        if thought_idx == len(thought_str):
+                            thought_cache = ""
+                            thought_idx = 0
+                        index += steps
+                        continue
+                    else:
+                        if thought_cache:
+                            last_character = delta
+                            yield thought_cache
+                            thought_cache = ""
+                            thought_idx = 0
+
+                    if yield_delta:
+                        index += steps
+                        last_character = delta
+                        yield delta
+                        continue
+
+                if code_block_delimiter_count == 3:
+                    if in_code_block:
+                        last_character = delta
+                        yield from extract_json_from_code_block(code_block_cache)
+                        code_block_cache = ""
+
+                    in_code_block = not in_code_block
+                    code_block_delimiter_count = 0
+
+                if not in_code_block:
+                    # handle single json
+                    if delta == "{":
+                        json_quote_count += 1
+                        in_json = True
+                        last_character = delta
+                        json_cache += delta
+                    elif delta == "}":
+                        last_character = delta
+                        json_cache += delta
+                        if json_quote_count > 0:
+                            json_quote_count -= 1
+                            if json_quote_count == 0:
+                                in_json = False
+                                got_json = True
+                                index += steps
+                                continue
+                    else:
+                        if in_json:
+                            last_character = delta
+                            json_cache += delta
+
+                    if got_json:
+                        got_json = False
+                        last_character = delta
+                        yield parse_action(json_cache)
+                        json_cache = ""
+                        json_quote_count = 0
+                        in_json = False
+
+                if not in_code_block and not in_json:
+                    last_character = delta
+                    yield delta.replace("`", "")
+
+                index += steps
+
+        if code_block_cache:
+            yield code_block_cache
+
+        if json_cache:
+            yield parse_action(json_cache)
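An illustrative driver for the streaming parser above (a sketch, not an authoritative test): plain text is re-emitted piece by piece, while a fenced action blob is folded into an `AgentScratchpadUnit.Action`. The chunk stand-in carries only the `delta.usage` and `delta.message.content` fields the parser reads, and the tool name is invented.

```python
from types import SimpleNamespace

from core.agent.entities import AgentScratchpadUnit
from core.agent.output_parser.cot_output_parser import CotAgentOutputParser

def chunks(texts):
    # minimal stand-in for an LLMResultChunk stream
    for text in texts:
        yield SimpleNamespace(delta=SimpleNamespace(usage=None, message=SimpleNamespace(content=text)))

stream = chunks([
    'I should call a tool.\n```json\n{"action": "weather_tool", "action_input": {"city": "Berlin"}}\n```',
])

usage = {}
pieces = list(CotAgentOutputParser.handle_react_stream_output(stream, usage))
actions = [p for p in pieces if isinstance(p, AgentScratchpadUnit.Action)]
assert actions and actions[0].action_name == "weather_tool"
```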
api/core/agent/prompt/template.py
ADDED
@@ -0,0 +1,106 @@
+ENGLISH_REACT_COMPLETION_PROMPT_TEMPLATES = """Respond to the human as helpfully and accurately as possible.
+
+{{instruction}}
+
+You have access to the following tools:
+
+{{tools}}
+
+Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
+Valid "action" values: "Final Answer" or {{tool_names}}
+
+Provide only ONE action per $JSON_BLOB, as shown:
+
+```
+{
+  "action": $TOOL_NAME,
+  "action_input": $ACTION_INPUT
+}
+```
+
+Follow this format:
+
+Question: input question to answer
+Thought: consider previous and subsequent steps
+Action:
+```
+$JSON_BLOB
+```
+Observation: action result
+... (repeat Thought/Action/Observation N times)
+Thought: I know what to respond
+Action:
+```
+{
+  "action": "Final Answer",
+  "action_input": "Final response to human"
+}
+```
+
+Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary. Respond directly if appropriate. Format is Action:```$JSON_BLOB```then Observation:.
+{{historic_messages}}
+Question: {{query}}
+{{agent_scratchpad}}
+Thought:"""  # noqa: E501
+
+
+ENGLISH_REACT_COMPLETION_AGENT_SCRATCHPAD_TEMPLATES = """Observation: {{observation}}
+Thought:"""
+
+ENGLISH_REACT_CHAT_PROMPT_TEMPLATES = """Respond to the human as helpfully and accurately as possible.
+
+{{instruction}}
+
+You have access to the following tools:
+
+{{tools}}
+
+Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
+Valid "action" values: "Final Answer" or {{tool_names}}
+
+Provide only ONE action per $JSON_BLOB, as shown:
+
+```
+{
+  "action": $TOOL_NAME,
+  "action_input": $ACTION_INPUT
+}
+```
+
+Follow this format:
+
+Question: input question to answer
+Thought: consider previous and subsequent steps
+Action:
+```
+$JSON_BLOB
+```
+Observation: action result
+... (repeat Thought/Action/Observation N times)
+Thought: I know what to respond
+Action:
+```
+{
+  "action": "Final Answer",
+  "action_input": "Final response to human"
+}
+```
+
+Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary. Respond directly if appropriate. Format is Action:```$JSON_BLOB```then Observation:.
+"""  # noqa: E501
+
+
+ENGLISH_REACT_CHAT_AGENT_SCRATCHPAD_TEMPLATES = ""
+
+REACT_PROMPT_TEMPLATES = {
+    "english": {
+        "chat": {
+            "prompt": ENGLISH_REACT_CHAT_PROMPT_TEMPLATES,
+            "agent_scratchpad": ENGLISH_REACT_CHAT_AGENT_SCRATCHPAD_TEMPLATES,
+        },
+        "completion": {
+            "prompt": ENGLISH_REACT_COMPLETION_PROMPT_TEMPLATES,
+            "agent_scratchpad": ENGLISH_REACT_COMPLETION_AGENT_SCRATCHPAD_TEMPLATES,
+        },
+    }
+}
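The templates above use plain `{{...}}` placeholders that the chain-of-thought runners substitute with `str.replace` (see `_organize_prompt_messages` earlier in this upload), not a template engine. A short sketch with invented instruction and tool values:

```python
from core.agent.prompt.template import REACT_PROMPT_TEMPLATES

prompt = REACT_PROMPT_TEMPLATES["english"]["completion"]["prompt"]
prompt = (
    prompt.replace("{{instruction}}", "You are a helpful assistant.")
    .replace("{{tools}}", "weather_tool: look up current weather")
    .replace("{{tool_names}}", '"weather_tool"')
    .replace("{{historic_messages}}", "")
    .replace("{{query}}", "What is the weather in Berlin?")
    .replace("{{agent_scratchpad}}", "")
)
print(prompt)  # a fully rendered ReAct completion prompt
```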
api/core/app/__init__.py
ADDED
File without changes
api/core/app/app_config/__init__.py
ADDED
File without changes
api/core/app/app_config/base_app_config_manager.py
ADDED
@@ -0,0 +1,49 @@
+from collections.abc import Mapping
+from typing import Any
+
+from core.app.app_config.entities import AppAdditionalFeatures
+from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
+from core.app.app_config.features.more_like_this.manager import MoreLikeThisConfigManager
+from core.app.app_config.features.opening_statement.manager import OpeningStatementConfigManager
+from core.app.app_config.features.retrieval_resource.manager import RetrievalResourceConfigManager
+from core.app.app_config.features.speech_to_text.manager import SpeechToTextConfigManager
+from core.app.app_config.features.suggested_questions_after_answer.manager import (
+    SuggestedQuestionsAfterAnswerConfigManager,
+)
+from core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager
+from models.model import AppMode
+
+
+class BaseAppConfigManager:
+    @classmethod
+    def convert_features(cls, config_dict: Mapping[str, Any], app_mode: AppMode) -> AppAdditionalFeatures:
+        """
+        Convert app model config dict to an additional-features entity
+
+        :param config_dict: app config
+        :param app_mode: app mode
+        """
+        config_dict = dict(config_dict.items())
+
+        additional_features = AppAdditionalFeatures()
+        additional_features.show_retrieve_source = RetrievalResourceConfigManager.convert(config=config_dict)
+
+        additional_features.file_upload = FileUploadConfigManager.convert(
+            config=config_dict, is_vision=app_mode in {AppMode.CHAT, AppMode.COMPLETION, AppMode.AGENT_CHAT}
+        )
+
+        additional_features.opening_statement, additional_features.suggested_questions = (
+            OpeningStatementConfigManager.convert(config=config_dict)
+        )
+
+        additional_features.suggested_questions_after_answer = SuggestedQuestionsAfterAnswerConfigManager.convert(
+            config=config_dict
+        )
+
+        additional_features.more_like_this = MoreLikeThisConfigManager.convert(config=config_dict)
+
+        additional_features.speech_to_text = SpeechToTextConfigManager.convert(config=config_dict)
+
+        additional_features.text_to_speech = TextToSpeechConfigManager.convert(config=config_dict)
+
+        return additional_features
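A hedged usage sketch: `convert_features` delegates every key to a feature sub-manager, and each sub-manager is expected to supply its own default when its key is missing, so an empty dict should come back as a features entity with everything disabled. That assumes the defaults implemented in the imported managers, which are not shown in this part of the upload.

```python
from core.app.app_config.base_app_config_manager import BaseAppConfigManager
from models.model import AppMode

# An empty config relies entirely on each sub-manager's defaults.
features = BaseAppConfigManager.convert_features({}, AppMode.CHAT)
print(features.show_retrieve_source, features.more_like_this, features.speech_to_text)
```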
api/core/app/app_config/common/__init__.py
ADDED
File without changes
api/core/app/app_config/common/sensitive_word_avoidance/__init__.py
ADDED
File without changes
api/core/app/app_config/common/sensitive_word_avoidance/manager.py
ADDED
@@ -0,0 +1,45 @@
+from typing import Optional
+
+from core.app.app_config.entities import SensitiveWordAvoidanceEntity
+from core.moderation.factory import ModerationFactory
+
+
+class SensitiveWordAvoidanceConfigManager:
+    @classmethod
+    def convert(cls, config: dict) -> Optional[SensitiveWordAvoidanceEntity]:
+        sensitive_word_avoidance_dict = config.get("sensitive_word_avoidance")
+        if not sensitive_word_avoidance_dict:
+            return None
+
+        if sensitive_word_avoidance_dict.get("enabled"):
+            return SensitiveWordAvoidanceEntity(
+                type=sensitive_word_avoidance_dict.get("type"),
+                config=sensitive_word_avoidance_dict.get("config"),
+            )
+        else:
+            return None
+
+    @classmethod
+    def validate_and_set_defaults(
+        cls, tenant_id, config: dict, only_structure_validate: bool = False
+    ) -> tuple[dict, list[str]]:
+        if not config.get("sensitive_word_avoidance"):
+            config["sensitive_word_avoidance"] = {"enabled": False}
+
+        if not isinstance(config["sensitive_word_avoidance"], dict):
+            raise ValueError("sensitive_word_avoidance must be of dict type")
+
+        if "enabled" not in config["sensitive_word_avoidance"] or not config["sensitive_word_avoidance"]["enabled"]:
+            config["sensitive_word_avoidance"]["enabled"] = False
+
+        if config["sensitive_word_avoidance"]["enabled"]:
+            if not config["sensitive_word_avoidance"].get("type"):
+                raise ValueError("sensitive_word_avoidance.type is required")
+
+            if not only_structure_validate:
+                typ = config["sensitive_word_avoidance"]["type"]
+                sensitive_word_avoidance_config = config["sensitive_word_avoidance"]["config"]
+
+                ModerationFactory.validate_config(name=typ, tenant_id=tenant_id, config=sensitive_word_avoidance_config)
+
+        return config, ["sensitive_word_avoidance"]
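A small sketch of the round trip above (the tenant id is invented): `validate_and_set_defaults` backfills a disabled block for a missing key, and `convert` turns a disabled block into `None`.

```python
from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager

config, keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(
    tenant_id="tenant-1", config={}, only_structure_validate=True
)
assert config["sensitive_word_avoidance"] == {"enabled": False}
assert keys == ["sensitive_word_avoidance"]
assert SensitiveWordAvoidanceConfigManager.convert(config) is None  # disabled -> no entity
```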
api/core/app/app_config/easy_ui_based_app/__init__.py
ADDED
File without changes
api/core/app/app_config/easy_ui_based_app/agent/__init__.py
ADDED
File without changes
api/core/app/app_config/easy_ui_based_app/agent/manager.py
ADDED
@@ -0,0 +1,81 @@
+from typing import Optional
+
+from core.agent.entities import AgentEntity, AgentPromptEntity, AgentToolEntity
+from core.agent.prompt.template import REACT_PROMPT_TEMPLATES
+
+
+class AgentConfigManager:
+    @classmethod
+    def convert(cls, config: dict) -> Optional[AgentEntity]:
+        """
+        Convert app model config dict to an agent entity
+
+        :param config: model config args
+        """
+        if "agent_mode" in config and config["agent_mode"] and "enabled" in config["agent_mode"]:
+            agent_dict = config.get("agent_mode", {})
+            agent_strategy = agent_dict.get("strategy", "cot")
+
+            if agent_strategy == "function_call":
+                strategy = AgentEntity.Strategy.FUNCTION_CALLING
+            elif agent_strategy in {"cot", "react"}:
+                strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT
+            else:
+                # old configs, try to detect default strategy
+                if config["model"]["provider"] == "openai":
+                    strategy = AgentEntity.Strategy.FUNCTION_CALLING
+                else:
+                    strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT
+
+            agent_tools = []
+            for tool in agent_dict.get("tools", []):
+                keys = tool.keys()
+                if len(keys) >= 4:
+                    if "enabled" not in tool or not tool["enabled"]:
+                        continue
+
+                    agent_tool_properties = {
+                        "provider_type": tool["provider_type"],
+                        "provider_id": tool["provider_id"],
+                        "tool_name": tool["tool_name"],
+                        "tool_parameters": tool.get("tool_parameters", {}),
+                    }
+
+                    agent_tools.append(AgentToolEntity(**agent_tool_properties))
+
+            if "strategy" in config["agent_mode"] and config["agent_mode"]["strategy"] not in {
+                "react_router",
+                "router",
+            }:
+                agent_prompt = agent_dict.get("prompt", None) or {}
+                # check model mode
+                model_mode = config.get("model", {}).get("mode", "completion")
+                if model_mode == "completion":
+                    agent_prompt_entity = AgentPromptEntity(
+                        first_prompt=agent_prompt.get(
+                            "first_prompt", REACT_PROMPT_TEMPLATES["english"]["completion"]["prompt"]
+                        ),
+                        next_iteration=agent_prompt.get(
+                            "next_iteration", REACT_PROMPT_TEMPLATES["english"]["completion"]["agent_scratchpad"]
+                        ),
+                    )
+                else:
+                    agent_prompt_entity = AgentPromptEntity(
+                        first_prompt=agent_prompt.get(
+                            "first_prompt", REACT_PROMPT_TEMPLATES["english"]["chat"]["prompt"]
+                        ),
+                        next_iteration=agent_prompt.get(
+                            "next_iteration", REACT_PROMPT_TEMPLATES["english"]["chat"]["agent_scratchpad"]
+                        ),
+                    )
+
+            return AgentEntity(
+                provider=config["model"]["provider"],
+                model=config["model"]["name"],
+                strategy=strategy,
+                prompt=agent_prompt_entity,
+                tools=agent_tools,
+                max_iteration=agent_dict.get("max_iteration", 5),
+            )
+
+        return None
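A hedged sketch of the strategy mapping above: legacy `"cot"`/`"react"` values resolve to chain-of-thought and `"function_call"` to function calling, with the ReAct prompt templates filled in when no custom prompt is supplied. The provider and model name below are invented.

```python
from core.agent.entities import AgentEntity
from core.app.app_config.easy_ui_based_app.agent.manager import AgentConfigManager

config = {
    "model": {"provider": "openai", "name": "gpt-4", "mode": "chat"},
    "agent_mode": {
        "enabled": True,
        "strategy": "react",
        "max_iteration": 3,
        "tools": [],
    },
}
agent = AgentConfigManager.convert(config)
assert agent is not None
assert agent.strategy == AgentEntity.Strategy.CHAIN_OF_THOUGHT  # "react"/"cot" map to CoT
assert agent.max_iteration == 3
```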
api/core/app/app_config/easy_ui_based_app/dataset/__init__.py
ADDED
File without changes
api/core/app/app_config/easy_ui_based_app/dataset/manager.py
ADDED
@@ -0,0 +1,221 @@
+import uuid
+from typing import Optional
+
+from core.app.app_config.entities import DatasetEntity, DatasetRetrieveConfigEntity
+from core.entities.agent_entities import PlanningStrategy
+from models.model import AppMode
+from services.dataset_service import DatasetService
+
+
+class DatasetConfigManager:
+    @classmethod
+    def convert(cls, config: dict) -> Optional[DatasetEntity]:
+        """
+        Convert app model config dict to a dataset entity
+
+        :param config: model config args
+        """
+        dataset_ids = []
+        if "datasets" in config.get("dataset_configs", {}):
+            datasets = config.get("dataset_configs", {}).get("datasets", {"strategy": "router", "datasets": []})
+
+            for dataset in datasets.get("datasets", []):
+                keys = list(dataset.keys())
+                if len(keys) == 0 or keys[0] != "dataset":
+                    continue
+
+                dataset = dataset["dataset"]
+
+                if "enabled" not in dataset or not dataset["enabled"]:
+                    continue
+
+                dataset_id = dataset.get("id", None)
+                if dataset_id:
+                    dataset_ids.append(dataset_id)
+
+        if (
+            "agent_mode" in config
+            and config["agent_mode"]
+            and "enabled" in config["agent_mode"]
+            and config["agent_mode"]["enabled"]
+        ):
+            agent_dict = config.get("agent_mode", {})
+
+            for tool in agent_dict.get("tools", []):
+                keys = tool.keys()
+                if len(keys) == 1:
+                    # old standard
+                    key = list(tool.keys())[0]
+
+                    if key != "dataset":
+                        continue
+
+                    tool_item = tool[key]
+
+                    if "enabled" not in tool_item or not tool_item["enabled"]:
+                        continue
+
+                    dataset_id = tool_item["id"]
+                    dataset_ids.append(dataset_id)
+
+        if len(dataset_ids) == 0:
+            return None
+
+        # dataset configs
+        if "dataset_configs" in config and config.get("dataset_configs"):
+            dataset_configs = config.get("dataset_configs")
+        else:
+            dataset_configs = {"retrieval_model": "multiple"}
+        if dataset_configs is None:
+            return None
+        query_variable = config.get("dataset_query_variable")
+
+        if dataset_configs["retrieval_model"] == "single":
+            return DatasetEntity(
+                dataset_ids=dataset_ids,
+                retrieve_config=DatasetRetrieveConfigEntity(
+                    query_variable=query_variable,
+                    retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of(
+                        dataset_configs["retrieval_model"]
+                    ),
+                ),
+            )
+        else:
+            return DatasetEntity(
+                dataset_ids=dataset_ids,
+                retrieve_config=DatasetRetrieveConfigEntity(
+                    query_variable=query_variable,
+                    retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of(
+                        dataset_configs["retrieval_model"]
+                    ),
+                    top_k=dataset_configs.get("top_k", 4),
+                    score_threshold=dataset_configs.get("score_threshold"),
+                    reranking_model=dataset_configs.get("reranking_model"),
+                    weights=dataset_configs.get("weights"),
+                    reranking_enabled=dataset_configs.get("reranking_enabled", True),
+                    rerank_mode=dataset_configs.get("reranking_mode", "reranking_model"),
+                ),
+            )
+
+    @classmethod
+    def validate_and_set_defaults(cls, tenant_id: str, app_mode: AppMode, config: dict) -> tuple[dict, list[str]]:
+        """
+        Validate and set defaults for dataset feature
+
+        :param tenant_id: tenant ID
+        :param app_mode: app mode
+        :param config: app model config args
+        """
+        # Extract dataset config for legacy compatibility
+        config = cls.extract_dataset_config_for_legacy_compatibility(tenant_id, app_mode, config)
+
+        # dataset_configs
+        if not config.get("dataset_configs"):
+            config["dataset_configs"] = {"retrieval_model": "single"}
+
+        if not config["dataset_configs"].get("datasets"):
+            config["dataset_configs"]["datasets"] = {"strategy": "router", "datasets": []}
+
+        if not isinstance(config["dataset_configs"], dict):
+            raise ValueError("dataset_configs must be of object type")
+
+        need_manual_query_datasets = config.get("dataset_configs") and config["dataset_configs"].get(
+            "datasets", {}
+        ).get("datasets")
+
+        if need_manual_query_datasets and app_mode == AppMode.COMPLETION:
+            # Only check when mode is completion
+            dataset_query_variable = config.get("dataset_query_variable")
+
+            if not dataset_query_variable:
+                raise ValueError("Dataset query variable is required when a dataset exists")
+
+        return config, ["agent_mode", "dataset_configs", "dataset_query_variable"]
+
+    @classmethod
+    def extract_dataset_config_for_legacy_compatibility(cls, tenant_id: str, app_mode: AppMode, config: dict) -> dict:
+        """
+        Extract dataset config for legacy compatibility
+
+        :param tenant_id: tenant ID
+        :param app_mode: app mode
+        :param config: app model config args
+        """
+        # Extract dataset config for legacy compatibility
+        if not config.get("agent_mode"):
+            config["agent_mode"] = {"enabled": False, "tools": []}
+
+        if not isinstance(config["agent_mode"], dict):
+            raise ValueError("agent_mode must be of object type")
+
+        # enabled
+        if "enabled" not in config["agent_mode"] or not config["agent_mode"]["enabled"]:
+            config["agent_mode"]["enabled"] = False
+
+        if not isinstance(config["agent_mode"]["enabled"], bool):
+            raise ValueError("enabled in agent_mode must be of boolean type")
+
+        # tools
+        if not config["agent_mode"].get("tools"):
+            config["agent_mode"]["tools"] = []
+
+        if not isinstance(config["agent_mode"]["tools"], list):
+            raise ValueError("tools in agent_mode must be a list of objects")
+
+        # strategy
+        if not config["agent_mode"].get("strategy"):
+            config["agent_mode"]["strategy"] = PlanningStrategy.ROUTER.value
+
+        has_datasets = False
+        if config["agent_mode"]["strategy"] in {PlanningStrategy.ROUTER.value, PlanningStrategy.REACT_ROUTER.value}:
+            for tool in config["agent_mode"]["tools"]:
+                key = list(tool.keys())[0]
+                if key == "dataset":
+                    # old style, use tool name as key
+                    tool_item = tool[key]
+
+                    if "enabled" not in tool_item or not tool_item["enabled"]:
+                        tool_item["enabled"] = False
+
+                    if not isinstance(tool_item["enabled"], bool):
+                        raise ValueError("enabled in agent_mode.tools must be of boolean type")
+
+                    if "id" not in tool_item:
+                        raise ValueError("id is required in dataset")
+
+                    try:
+                        uuid.UUID(tool_item["id"])
+                    except ValueError:
+                        raise ValueError("id in dataset must be of UUID type")
+
+                    if not cls.is_dataset_exists(tenant_id, tool_item["id"]):
+                        raise ValueError("Dataset ID does not exist, please check your permission.")
+
+                    has_datasets = True
+
+        need_manual_query_datasets = has_datasets and config["agent_mode"]["enabled"]
+
+        if need_manual_query_datasets and app_mode == AppMode.COMPLETION:
+            # Only check when mode is completion
+            dataset_query_variable = config.get("dataset_query_variable")
+
+            if not dataset_query_variable:
+                raise ValueError("Dataset query variable is required when a dataset exists")
+
+        return config
+
+    @classmethod
+    def is_dataset_exists(cls, tenant_id: str, dataset_id: str) -> bool:
+        # verify that the dataset ID exists and belongs to the tenant
+        dataset = DatasetService.get_dataset(dataset_id)
+
+        if not dataset:
+            return False
+
+        if dataset.tenant_id != tenant_id:
+            return False
+
+        return True
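A hedged sketch of `convert` with the new-style config shape above. The dataset UUID is invented, and `RetrieveStrategy.value_of("multiple")` is assumed to resolve as its name suggests; that enum is defined elsewhere in the upload.

```python
from core.app.app_config.easy_ui_based_app.dataset.manager import DatasetConfigManager

config = {
    "dataset_configs": {
        "retrieval_model": "multiple",
        "top_k": 4,
        "datasets": {
            "datasets": [
                {"dataset": {"enabled": True, "id": "11111111-1111-1111-1111-111111111111"}},
            ]
        },
    },
}
entity = DatasetConfigManager.convert(config)
assert entity is not None
assert entity.dataset_ids == ["11111111-1111-1111-1111-111111111111"]
```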
api/core/app/app_config/easy_ui_based_app/model_config/__init__.py
ADDED
File without changes
api/core/app/app_config/easy_ui_based_app/model_config/converter.py
ADDED
@@ -0,0 +1,87 @@
+from typing import cast
+
+from core.app.app_config.entities import EasyUIBasedAppConfig
+from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
+from core.entities.model_entities import ModelStatus
+from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
+from core.model_runtime.entities.model_entities import ModelType
+from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
+from core.provider_manager import ProviderManager
+
+
+class ModelConfigConverter:
+    @classmethod
+    def convert(cls, app_config: EasyUIBasedAppConfig) -> ModelConfigWithCredentialsEntity:
+        """
+        Convert app model config dict to entity.
+        :param app_config: app config
+        :raises ProviderTokenNotInitError: provider token not init error
+        :return: app orchestration config entity
+        """
+        model_config = app_config.model
+
+        provider_manager = ProviderManager()
+        provider_model_bundle = provider_manager.get_provider_model_bundle(
+            tenant_id=app_config.tenant_id, provider=model_config.provider, model_type=ModelType.LLM
+        )
+
+        provider_name = provider_model_bundle.configuration.provider.provider
+        model_name = model_config.model
+
+        model_type_instance = provider_model_bundle.model_type_instance
+        model_type_instance = cast(LargeLanguageModel, model_type_instance)
+
+        # check model credentials
+        model_credentials = provider_model_bundle.configuration.get_current_credentials(
+            model_type=ModelType.LLM, model=model_config.model
+        )
+
+        if model_credentials is None:
+            raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
+
+        # check model
+        provider_model = provider_model_bundle.configuration.get_provider_model(
+            model=model_config.model, model_type=ModelType.LLM
+        )
+
+        if provider_model is None:
+            model_name = model_config.model
+            raise ValueError(f"Model {model_name} not exist.")
+
+        if provider_model.status == ModelStatus.NO_CONFIGURE:
+            raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
+        elif provider_model.status == ModelStatus.NO_PERMISSION:
+            raise ModelCurrentlyNotSupportError(f"Dify Hosted OpenAI {model_name} currently not support.")
+        elif provider_model.status == ModelStatus.QUOTA_EXCEEDED:
+            raise QuotaExceededError(f"Model provider {provider_name} quota exceeded.")
+
+        # model config
+        completion_params = model_config.parameters
+        stop = []
+        if "stop" in completion_params:
+            stop = completion_params["stop"]
+            del completion_params["stop"]
+
+        # get model mode
+        model_mode = model_config.mode
+        if not model_mode:
+            mode_enum = model_type_instance.get_model_mode(model=model_config.model, credentials=model_credentials)
+
+            model_mode = mode_enum.value
+
+        model_schema = model_type_instance.get_model_schema(model_config.model, model_credentials)
+
+        if not model_schema:
+            raise ValueError(f"Model {model_name} not exist.")
+
+        return ModelConfigWithCredentialsEntity(
+            provider=model_config.provider,
+            model=model_config.model,
+            model_schema=model_schema,
+            mode=model_mode,
+            provider_model_bundle=provider_model_bundle,
+            credentials=model_credentials,
+            parameters=completion_params,
+            stop=stop,
+        )
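One detail worth isolating from the converter above: `stop` is popped out of `completion_params` and carried separately on the resulting entity. A stand-alone sketch of that split, using plain dicts and an invented parameter:

```python
completion_params = {"temperature": 0.7, "stop": ["\nObservation:"]}

# "stop" travels as a dedicated field, not a model parameter
stop = []
if "stop" in completion_params:
    stop = completion_params["stop"]
    del completion_params["stop"]

assert completion_params == {"temperature": 0.7}
assert stop == ["\nObservation:"]
```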
api/core/app/app_config/easy_ui_based_app/model_config/manager.py
ADDED
@@ -0,0 +1,114 @@
1 |
+
from collections.abc import Mapping
|
2 |
+
from typing import Any
|
3 |
+
|
4 |
+
from core.app.app_config.entities import ModelConfigEntity
|
5 |
+
from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
|
6 |
+
from core.model_runtime.model_providers import model_provider_factory
|
7 |
+
from core.provider_manager import ProviderManager
|
8 |
+
|
9 |
+
|
10 |
+
class ModelConfigManager:
|
11 |
+
@classmethod
|
12 |
+
def convert(cls, config: dict) -> ModelConfigEntity:
|
13 |
+
"""
|
14 |
+
Convert model config to model config
|
15 |
+
|
16 |
+
:param config: model config args
|
17 |
+
"""
|
18 |
+
# model config
|
19 |
+
model_config = config.get("model")
|
20 |
+
|
21 |
+
if not model_config:
|
22 |
+
raise ValueError("model is required")
|
23 |
+
|
24 |
+
completion_params = model_config.get("completion_params")
|
25 |
+
stop = []
|
26 |
+
if "stop" in completion_params:
|
27 |
+
stop = completion_params["stop"]
|
28 |
+
del completion_params["stop"]
|
29 |
+
|
30 |
+
# get model mode
|
31 |
+
model_mode = model_config.get("mode")
|
32 |
+
|
33 |
+
return ModelConfigEntity(
|
34 |
+
provider=config["model"]["provider"],
|
35 |
+
model=config["model"]["name"],
|
36 |
+
mode=model_mode,
|
37 |
+
parameters=completion_params,
|
38 |
+
stop=stop,
|
39 |
+
)
|
40 |
+
|
41 |
+
@classmethod
|
42 |
+
def validate_and_set_defaults(cls, tenant_id: str, config: Mapping[str, Any]) -> tuple[dict, list[str]]:
|
43 |
+
"""
|
44 |
+
Validate and set defaults for model config
|
45 |
+
|
46 |
+
:param tenant_id: tenant id
|
47 |
+
:param config: app model config args
|
48 |
+
"""
|
49 |
+
if "model" not in config:
|
50 |
+
raise ValueError("model is required")
|
51 |
+
|
52 |
+
if not isinstance(config["model"], dict):
|
53 |
+
raise ValueError("model must be of object type")
|
54 |
+
|
55 |
+
# model.provider
|
56 |
+
provider_entities = model_provider_factory.get_providers()
|
57 |
+
model_provider_names = [provider.provider for provider in provider_entities]
|
58 |
+
if "provider" not in config["model"] or config["model"]["provider"] not in model_provider_names:
|
59 |
+
raise ValueError(f"model.provider is required and must be in {str(model_provider_names)}")
|
60 |
+
|
61 |
+
# model.name
|
62 |
+
if "name" not in config["model"]:
|
63 |
+
raise ValueError("model.name is required")
|
64 |
+
|
65 |
+
provider_manager = ProviderManager()
|
66 |
+
models = provider_manager.get_configurations(tenant_id).get_models(
|
67 |
+
provider=config["model"]["provider"], model_type=ModelType.LLM
|
68 |
+
)
|
69 |
+
|
70 |
+
if not models:
|
71 |
+
raise ValueError("model.name must be in the specified model list")
|
72 |
+
|
73 |
+
model_ids = [m.model for m in models]
|
74 |
+
if config["model"]["name"] not in model_ids:
|
75 |
+
raise ValueError("model.name must be in the specified model list")
|
76 |
+
|
77 |
+
model_mode = None
|
78 |
+
for model in models:
|
79 |
+
if model.model == config["model"]["name"]:
|
80 |
+
model_mode = model.model_properties.get(ModelPropertyKey.MODE)
|
81 |
+
break
|
82 |
+
|
83 |
+
# model.mode
|
84 |
+
if model_mode:
|
85 |
+
config["model"]["mode"] = model_mode
|
86 |
+
else:
|
87 |
+
config["model"]["mode"] = "completion"
|
88 |
+
|
89 |
+
# model.completion_params
|
90 |
+
if "completion_params" not in config["model"]:
|
91 |
+
raise ValueError("model.completion_params is required")
|
92 |
+
|
93 |
+
config["model"]["completion_params"] = cls.validate_model_completion_params(
|
94 |
+
config["model"]["completion_params"]
|
95 |
+
)
|
96 |
+
|
97 |
+
return dict(config), ["model"]
|
98 |
+
|
99 |
+
@classmethod
|
100 |
+
def validate_model_completion_params(cls, cp: dict) -> dict:
|
101 |
+
# model.completion_params
|
102 |
+
if not isinstance(cp, dict):
|
103 |
+
raise ValueError("model.completion_params must be of object type")
|
104 |
+
|
105 |
+
# stop
|
106 |
+
if "stop" not in cp:
|
107 |
+
cp["stop"] = []
|
108 |
+
elif not isinstance(cp["stop"], list):
|
109 |
+
raise ValueError("stop in model.completion_params must be of list type")
|
110 |
+
|
111 |
+
if len(cp["stop"]) > 4:
|
112 |
+
raise ValueError("stop sequences must be less than 4")
|
113 |
+
|
114 |
+
return cp
|
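For orientation, a minimal sketch of driving `ModelConfigManager.convert`; the provider and model names below are placeholders, not part of the diff:

```python
from core.app.app_config.easy_ui_based_app.model_config.manager import ModelConfigManager

# Hypothetical app model config args; provider/name values are illustrative only.
config = {
    "model": {
        "provider": "openai",
        "name": "gpt-3.5-turbo",
        "mode": "chat",
        "completion_params": {"temperature": 0.7, "stop": ["\nHuman:"]},
    }
}

entity = ModelConfigManager.convert(config)
print(entity.provider, entity.model, entity.stop)  # "stop" is popped out of the params into entity.stop
```

Note that `convert` mutates the passed-in dict: the `stop` key is deleted from `completion_params` and surfaced on the entity.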
api/core/app/app_config/easy_ui_based_app/prompt_template/__init__.py
ADDED
File without changes
api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py
ADDED
@@ -0,0 +1,138 @@
+from core.app.app_config.entities import (
+    AdvancedChatMessageEntity,
+    AdvancedChatPromptTemplateEntity,
+    AdvancedCompletionPromptTemplateEntity,
+    PromptTemplateEntity,
+)
+from core.model_runtime.entities.message_entities import PromptMessageRole
+from core.prompt.simple_prompt_transform import ModelMode
+from models.model import AppMode
+
+
+class PromptTemplateConfigManager:
+    @classmethod
+    def convert(cls, config: dict) -> PromptTemplateEntity:
+        if not config.get("prompt_type"):
+            raise ValueError("prompt_type is required")
+
+        prompt_type = PromptTemplateEntity.PromptType.value_of(config["prompt_type"])
+        if prompt_type == PromptTemplateEntity.PromptType.SIMPLE:
+            simple_prompt_template = config.get("pre_prompt", "")
+            return PromptTemplateEntity(prompt_type=prompt_type, simple_prompt_template=simple_prompt_template)
+        else:
+            advanced_chat_prompt_template = None
+            chat_prompt_config = config.get("chat_prompt_config", {})
+            if chat_prompt_config:
+                chat_prompt_messages = []
+                for message in chat_prompt_config.get("prompt", []):
+                    chat_prompt_messages.append(
+                        AdvancedChatMessageEntity(
+                            **{"text": message["text"], "role": PromptMessageRole.value_of(message["role"])}
+                        )
+                    )
+
+                advanced_chat_prompt_template = AdvancedChatPromptTemplateEntity(messages=chat_prompt_messages)
+
+            advanced_completion_prompt_template = None
+            completion_prompt_config = config.get("completion_prompt_config", {})
+            if completion_prompt_config:
+                completion_prompt_template_params = {
+                    "prompt": completion_prompt_config["prompt"]["text"],
+                }
+
+                if "conversation_histories_role" in completion_prompt_config:
+                    completion_prompt_template_params["role_prefix"] = {
+                        "user": completion_prompt_config["conversation_histories_role"]["user_prefix"],
+                        "assistant": completion_prompt_config["conversation_histories_role"]["assistant_prefix"],
+                    }
+
+                advanced_completion_prompt_template = AdvancedCompletionPromptTemplateEntity(
+                    **completion_prompt_template_params
+                )
+
+            return PromptTemplateEntity(
+                prompt_type=prompt_type,
+                advanced_chat_prompt_template=advanced_chat_prompt_template,
+                advanced_completion_prompt_template=advanced_completion_prompt_template,
+            )
+
+    @classmethod
+    def validate_and_set_defaults(cls, app_mode: AppMode, config: dict) -> tuple[dict, list[str]]:
+        """
+        Validate pre_prompt and set defaults for the prompt feature,
+        depending on config['model'].
+
+        :param app_mode: app mode
+        :param config: app model config args
+        """
+        if not config.get("prompt_type"):
+            config["prompt_type"] = PromptTemplateEntity.PromptType.SIMPLE.value
+
+        prompt_type_vals = [typ.value for typ in PromptTemplateEntity.PromptType]
+        if config["prompt_type"] not in prompt_type_vals:
+            raise ValueError(f"prompt_type must be in {prompt_type_vals}")
+
+        # chat_prompt_config
+        if not config.get("chat_prompt_config"):
+            config["chat_prompt_config"] = {}
+
+        if not isinstance(config["chat_prompt_config"], dict):
+            raise ValueError("chat_prompt_config must be of object type")
+
+        # completion_prompt_config
+        if not config.get("completion_prompt_config"):
+            config["completion_prompt_config"] = {}
+
+        if not isinstance(config["completion_prompt_config"], dict):
+            raise ValueError("completion_prompt_config must be of object type")
+
+        if config["prompt_type"] == PromptTemplateEntity.PromptType.ADVANCED.value:
+            if not config["chat_prompt_config"] and not config["completion_prompt_config"]:
+                raise ValueError(
+                    "chat_prompt_config or completion_prompt_config is required when prompt_type is advanced"
+                )
+
+            model_mode_vals = [mode.value for mode in ModelMode]
+            if config["model"]["mode"] not in model_mode_vals:
+                raise ValueError(f"model.mode must be in {model_mode_vals} when prompt_type is advanced")
+
+            if app_mode == AppMode.CHAT and config["model"]["mode"] == ModelMode.COMPLETION.value:
+                user_prefix = config["completion_prompt_config"]["conversation_histories_role"]["user_prefix"]
+                assistant_prefix = config["completion_prompt_config"]["conversation_histories_role"]["assistant_prefix"]
+
+                if not user_prefix:
+                    config["completion_prompt_config"]["conversation_histories_role"]["user_prefix"] = "Human"
+
+                if not assistant_prefix:
+                    config["completion_prompt_config"]["conversation_histories_role"]["assistant_prefix"] = "Assistant"
+
+            if config["model"]["mode"] == ModelMode.CHAT.value:
+                prompt_list = config["chat_prompt_config"]["prompt"]
+
+                if len(prompt_list) > 10:
+                    raise ValueError("prompt messages must not exceed 10")
+        else:
+            # pre_prompt, for simple mode
+            if not config.get("pre_prompt"):
+                config["pre_prompt"] = ""
+
+            if not isinstance(config["pre_prompt"], str):
+                raise ValueError("pre_prompt must be of string type")
+
+        return config, ["prompt_type", "pre_prompt", "chat_prompt_config", "completion_prompt_config"]
+
+    @classmethod
+    def validate_post_prompt_and_set_defaults(cls, config: dict) -> dict:
+        """
+        Validate post_prompt and set defaults for prompt feature
+
+        :param config: app model config args
+        """
+        # post_prompt
+        if not config.get("post_prompt"):
+            config["post_prompt"] = ""
+
+        if not isinstance(config["post_prompt"], str):
+            raise ValueError("post_prompt must be of string type")
+
+        return config
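A hedged usage sketch for the advanced chat branch of `PromptTemplateConfigManager.convert`; it assumes "system" and "user" are valid `PromptMessageRole` values, and the message texts are placeholders:

```python
from core.app.app_config.easy_ui_based_app.prompt_template.manager import PromptTemplateConfigManager

config = {
    "prompt_type": "advanced",
    "chat_prompt_config": {
        "prompt": [
            {"role": "system", "text": "You are a helpful assistant."},  # role strings assumed valid
            {"role": "user", "text": "{{query}}"},
        ]
    },
}

template = PromptTemplateConfigManager.convert(config)
print(template.prompt_type)  # PromptType.ADVANCED
print(len(template.advanced_chat_prompt_template.messages))  # 2
```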
api/core/app/app_config/easy_ui_based_app/variables/__init__.py
ADDED
File without changes
api/core/app/app_config/easy_ui_based_app/variables/manager.py
ADDED
@@ -0,0 +1,168 @@
+import re
+
+from core.app.app_config.entities import ExternalDataVariableEntity, VariableEntity, VariableEntityType
+from core.external_data_tool.factory import ExternalDataToolFactory
+
+
+class BasicVariablesConfigManager:
+    @classmethod
+    def convert(cls, config: dict) -> tuple[list[VariableEntity], list[ExternalDataVariableEntity]]:
+        """
+        Convert app model config args to variable entities and external data variables.
+
+        :param config: app model config args
+        """
+        external_data_variables = []
+        variable_entities = []
+
+        # old external_data_tools
+        external_data_tools = config.get("external_data_tools", [])
+        for external_data_tool in external_data_tools:
+            if "enabled" not in external_data_tool or not external_data_tool["enabled"]:
+                continue
+
+            external_data_variables.append(
+                ExternalDataVariableEntity(
+                    variable=external_data_tool["variable"],
+                    type=external_data_tool["type"],
+                    config=external_data_tool["config"],
+                )
+            )
+
+        # variables and external_data_tools
+        for variables in config.get("user_input_form", []):
+            variable_type = list(variables.keys())[0]
+            if variable_type == VariableEntityType.EXTERNAL_DATA_TOOL:
+                variable = variables[variable_type]
+                if "config" not in variable:
+                    continue
+
+                external_data_variables.append(
+                    ExternalDataVariableEntity(
+                        variable=variable["variable"], type=variable["type"], config=variable["config"]
+                    )
+                )
+            elif variable_type in {
+                VariableEntityType.TEXT_INPUT,
+                VariableEntityType.PARAGRAPH,
+                VariableEntityType.NUMBER,
+                VariableEntityType.SELECT,
+            }:
+                variable = variables[variable_type]
+                variable_entities.append(
+                    VariableEntity(
+                        type=variable_type,
+                        variable=variable.get("variable"),
+                        description=variable.get("description") or "",
+                        label=variable.get("label"),
+                        required=variable.get("required", False),
+                        max_length=variable.get("max_length"),
+                        options=variable.get("options") or [],
+                    )
+                )
+
+        return variable_entities, external_data_variables
+
+    @classmethod
+    def validate_and_set_defaults(cls, tenant_id: str, config: dict) -> tuple[dict, list[str]]:
+        """
+        Validate and set defaults for user input form
+
+        :param tenant_id: workspace id
+        :param config: app model config args
+        """
+        related_config_keys = []
+        config, current_related_config_keys = cls.validate_variables_and_set_defaults(config)
+        related_config_keys.extend(current_related_config_keys)
+
+        config, current_related_config_keys = cls.validate_external_data_tools_and_set_defaults(tenant_id, config)
+        related_config_keys.extend(current_related_config_keys)
+
+        return config, related_config_keys
+
+    @classmethod
+    def validate_variables_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
+        """
+        Validate and set defaults for user input form
+
+        :param config: app model config args
+        """
+        if not config.get("user_input_form"):
+            config["user_input_form"] = []
+
+        if not isinstance(config["user_input_form"], list):
+            raise ValueError("user_input_form must be a list of objects")
+
+        variables = []
+        for item in config["user_input_form"]:
+            key = list(item.keys())[0]
+            if key not in {"text-input", "select", "paragraph", "number", "external_data_tool"}:
+                raise ValueError(
+                    "Keys in user_input_form list can only be 'text-input', 'paragraph', 'select', "
+                    "'number' or 'external_data_tool'"
+                )
+
+            form_item = item[key]
+            if "label" not in form_item:
+                raise ValueError("label is required in user_input_form")
+
+            if not isinstance(form_item["label"], str):
+                raise ValueError("label in user_input_form must be of string type")
+
+            if "variable" not in form_item:
+                raise ValueError("variable is required in user_input_form")
+
+            if not isinstance(form_item["variable"], str):
+                raise ValueError("variable in user_input_form must be of string type")
+
+            pattern = re.compile(r"^(?!\d)[\u4e00-\u9fa5A-Za-z0-9_\U0001F300-\U0001F64F\U0001F680-\U0001F6FF]{1,100}$")
+            if pattern.match(form_item["variable"]) is None:
+                raise ValueError("variable in user_input_form must be a string, and cannot start with a number")
+
+            variables.append(form_item["variable"])
+
+            if "required" not in form_item or not form_item["required"]:
+                form_item["required"] = False
+
+            if not isinstance(form_item["required"], bool):
+                raise ValueError("required in user_input_form must be of boolean type")
+
+            if key == "select":
+                if "options" not in form_item or not form_item["options"]:
+                    form_item["options"] = []
+
+                if not isinstance(form_item["options"], list):
+                    raise ValueError("options in user_input_form must be a list of strings")
+
+                if "default" in form_item and form_item["default"] and form_item["default"] not in form_item["options"]:
+                    raise ValueError("default value in user_input_form must be in the options list")
+
+        return config, ["user_input_form"]
+
+    @classmethod
+    def validate_external_data_tools_and_set_defaults(cls, tenant_id: str, config: dict) -> tuple[dict, list[str]]:
+        """
+        Validate and set defaults for external data fetch feature
+
+        :param tenant_id: workspace id
+        :param config: app model config args
+        """
+        if not config.get("external_data_tools"):
+            config["external_data_tools"] = []
+
+        if not isinstance(config["external_data_tools"], list):
+            raise ValueError("external_data_tools must be of list type")
+
+        for tool in config["external_data_tools"]:
+            if "enabled" not in tool or not tool["enabled"]:
+                tool["enabled"] = False
+
+            if not tool["enabled"]:
+                continue
+
+            if "type" not in tool or not tool["type"]:
+                raise ValueError("external_data_tools[].type is required")
+
+            typ = tool["type"]
+            tool_config = tool["config"]  # local name so the outer `config` dict is not shadowed and returned by mistake
+
+            ExternalDataToolFactory.validate_config(name=typ, tenant_id=tenant_id, config=tool_config)
+
+        return config, ["external_data_tools"]
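A minimal sketch of how `BasicVariablesConfigManager.convert` splits a `user_input_form` into variable entities and external data variables; the field names below are illustrative:

```python
from core.app.app_config.easy_ui_based_app.variables.manager import BasicVariablesConfigManager

config = {
    "user_input_form": [
        {"text-input": {"variable": "topic", "label": "Topic", "required": True, "max_length": 48}},
        {"select": {"variable": "tone", "label": "Tone", "options": ["formal", "casual"]}},
    ]
}

variables, external = BasicVariablesConfigManager.convert(config)
print([v.variable for v in variables])  # ['topic', 'tone']
print(external)  # [] (no external data tools configured in this sketch)
```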
api/core/app/app_config/entities.py
ADDED
@@ -0,0 +1,267 @@
+from collections.abc import Sequence
+from enum import Enum, StrEnum
+from typing import Any, Optional
+
+from pydantic import BaseModel, Field, field_validator
+
+from core.file import FileTransferMethod, FileType, FileUploadConfig
+from core.model_runtime.entities.message_entities import PromptMessageRole
+from models.model import AppMode
+
+
+class ModelConfigEntity(BaseModel):
+    """
+    Model Config Entity.
+    """
+
+    provider: str
+    model: str
+    mode: Optional[str] = None
+    parameters: dict[str, Any] = {}
+    stop: list[str] = []
+
+
+class AdvancedChatMessageEntity(BaseModel):
+    """
+    Advanced Chat Message Entity.
+    """
+
+    text: str
+    role: PromptMessageRole
+
+
+class AdvancedChatPromptTemplateEntity(BaseModel):
+    """
+    Advanced Chat Prompt Template Entity.
+    """
+
+    messages: list[AdvancedChatMessageEntity]
+
+
+class AdvancedCompletionPromptTemplateEntity(BaseModel):
+    """
+    Advanced Completion Prompt Template Entity.
+    """
+
+    class RolePrefixEntity(BaseModel):
+        """
+        Role Prefix Entity.
+        """
+
+        user: str
+        assistant: str
+
+    prompt: str
+    role_prefix: Optional[RolePrefixEntity] = None
+
+
+class PromptTemplateEntity(BaseModel):
+    """
+    Prompt Template Entity.
+    """
+
+    class PromptType(Enum):
+        """
+        Prompt Type.
+        'simple', 'advanced'
+        """
+
+        SIMPLE = "simple"
+        ADVANCED = "advanced"
+
+        @classmethod
+        def value_of(cls, value: str):
+            """
+            Get value of given mode.
+
+            :param value: mode value
+            :return: mode
+            """
+            for mode in cls:
+                if mode.value == value:
+                    return mode
+            raise ValueError(f"invalid prompt type value {value}")
+
+    prompt_type: PromptType
+    simple_prompt_template: Optional[str] = None
+    advanced_chat_prompt_template: Optional[AdvancedChatPromptTemplateEntity] = None
+    advanced_completion_prompt_template: Optional[AdvancedCompletionPromptTemplateEntity] = None
+
+
+class VariableEntityType(StrEnum):
+    TEXT_INPUT = "text-input"
+    SELECT = "select"
+    PARAGRAPH = "paragraph"
+    NUMBER = "number"
+    EXTERNAL_DATA_TOOL = "external_data_tool"
+    FILE = "file"
+    FILE_LIST = "file-list"
+
+
+class VariableEntity(BaseModel):
+    """
+    Variable Entity.
+    """
+
+    variable: str
+    label: str
+    description: str = ""
+    type: VariableEntityType
+    required: bool = False
+    max_length: Optional[int] = None
+    options: Sequence[str] = Field(default_factory=list)
+    allowed_file_types: Sequence[FileType] = Field(default_factory=list)
+    allowed_file_extensions: Sequence[str] = Field(default_factory=list)
+    allowed_file_upload_methods: Sequence[FileTransferMethod] = Field(default_factory=list)
+
+    @field_validator("description", mode="before")
+    @classmethod
+    def convert_none_description(cls, v: Any) -> str:
+        return v or ""
+
+    @field_validator("options", mode="before")
+    @classmethod
+    def convert_none_options(cls, v: Any) -> Sequence[str]:
+        return v or []
+
+
+class ExternalDataVariableEntity(BaseModel):
+    """
+    External Data Variable Entity.
+    """
+
+    variable: str
+    type: str
+    config: dict[str, Any] = {}
+
+
+class DatasetRetrieveConfigEntity(BaseModel):
+    """
+    Dataset Retrieve Config Entity.
+    """
+
+    class RetrieveStrategy(Enum):
+        """
+        Dataset Retrieve Strategy.
+        'single' or 'multiple'
+        """
+
+        SINGLE = "single"
+        MULTIPLE = "multiple"
+
+        @classmethod
+        def value_of(cls, value: str):
+            """
+            Get value of given mode.
+
+            :param value: mode value
+            :return: mode
+            """
+            for mode in cls:
+                if mode.value == value:
+                    return mode
+            raise ValueError(f"invalid retrieve strategy value {value}")
+
+    query_variable: Optional[str] = None  # Only when app mode is completion
+
+    retrieve_strategy: RetrieveStrategy
+    top_k: Optional[int] = None
+    score_threshold: Optional[float] = 0.0
+    rerank_mode: Optional[str] = "reranking_model"
+    reranking_model: Optional[dict] = None
+    weights: Optional[dict] = None
+    reranking_enabled: Optional[bool] = True
+
+
+class DatasetEntity(BaseModel):
+    """
+    Dataset Config Entity.
+    """
+
+    dataset_ids: list[str]
+    retrieve_config: DatasetRetrieveConfigEntity
+
+
+class SensitiveWordAvoidanceEntity(BaseModel):
+    """
+    Sensitive Word Avoidance Entity.
+    """
+
+    type: str
+    config: dict[str, Any] = {}
+
+
+class TextToSpeechEntity(BaseModel):
+    """
+    Text To Speech Entity.
+    """
+
+    enabled: bool
+    voice: Optional[str] = None
+    language: Optional[str] = None
+
+
+class TracingConfigEntity(BaseModel):
+    """
+    Tracing Config Entity.
+    """
+
+    enabled: bool
+    tracing_provider: str
+
+
+class AppAdditionalFeatures(BaseModel):
+    file_upload: Optional[FileUploadConfig] = None
+    opening_statement: Optional[str] = None
+    suggested_questions: list[str] = []
+    suggested_questions_after_answer: bool = False
+    show_retrieve_source: bool = False
+    more_like_this: bool = False
+    speech_to_text: bool = False
+    text_to_speech: Optional[TextToSpeechEntity] = None
+    trace_config: Optional[TracingConfigEntity] = None
+
+
+class AppConfig(BaseModel):
+    """
+    Application Config Entity.
+    """
+
+    tenant_id: str
+    app_id: str
+    app_mode: AppMode
+    additional_features: AppAdditionalFeatures
+    variables: list[VariableEntity] = []
+    sensitive_word_avoidance: Optional[SensitiveWordAvoidanceEntity] = None
+
+
+class EasyUIBasedAppModelConfigFrom(Enum):
+    """
+    App Model Config From.
+    """
+
+    ARGS = "args"
+    APP_LATEST_CONFIG = "app-latest-config"
+    CONVERSATION_SPECIFIC_CONFIG = "conversation-specific-config"
+
+
+class EasyUIBasedAppConfig(AppConfig):
+    """
+    Easy UI Based App Config Entity.
+    """
+
+    app_model_config_from: EasyUIBasedAppModelConfigFrom
+    app_model_config_id: str
+    app_model_config_dict: dict
+    model: ModelConfigEntity
+    prompt_template: PromptTemplateEntity
+    dataset: Optional[DatasetEntity] = None
+    external_data_variables: list[ExternalDataVariableEntity] = []
+
+
+class WorkflowUIBasedAppConfig(AppConfig):
+    """
+    Workflow UI Based App Config Entity.
+    """
+
+    workflow_id: str
api/core/app/app_config/features/__init__.py
ADDED
File without changes
api/core/app/app_config/features/file_upload/__init__.py
ADDED
File without changes
api/core/app/app_config/features/file_upload/manager.py
ADDED
@@ -0,0 +1,44 @@
+from collections.abc import Mapping
+from typing import Any
+
+from core.file import FileUploadConfig
+
+
+class FileUploadConfigManager:
+    @classmethod
+    def convert(cls, config: Mapping[str, Any], is_vision: bool = True):
+        """
+        Convert file upload config args to a FileUploadConfig.
+
+        :param config: app model config args
+        :param is_vision: if True, the feature is a vision feature
+        """
+        file_upload_dict = config.get("file_upload")
+        if file_upload_dict:
+            if file_upload_dict.get("enabled"):
+                transform_methods = file_upload_dict.get("allowed_file_upload_methods", [])
+                data = {
+                    "image_config": {
+                        "number_limits": file_upload_dict["number_limits"],
+                        "transfer_methods": transform_methods,
+                    }
+                }
+
+                if is_vision:
+                    data["image_config"]["detail"] = file_upload_dict.get("image", {}).get("detail", "low")
+
+                return FileUploadConfig.model_validate(data)
+
+    @classmethod
+    def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
+        """
+        Validate and set defaults for file upload feature
+
+        :param config: app model config args
+        """
+        if not config.get("file_upload"):
+            config["file_upload"] = {}
+        else:
+            FileUploadConfig.model_validate(config["file_upload"])
+
+        return config, ["file_upload"]
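A hedged sketch of `FileUploadConfigManager.convert`; the exact fields `FileUploadConfig` accepts live in `core.file`, so the `image_config` shape below is inferred from this manager, and the transfer-method strings are assumptions:

```python
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager

features = {
    "file_upload": {
        "enabled": True,
        "number_limits": 3,
        "allowed_file_upload_methods": ["remote_url", "local_file"],  # assumed method names
        "image": {"detail": "high"},
    }
}

file_upload_config = FileUploadConfigManager.convert(features, is_vision=True)
print(file_upload_config)  # a FileUploadConfig when enabled; None when absent or disabled
```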
api/core/app/app_config/features/more_like_this/__init__.py
ADDED
File without changes
api/core/app/app_config/features/more_like_this/manager.py
ADDED
@@ -0,0 +1,36 @@
+class MoreLikeThisConfigManager:
+    @classmethod
+    def convert(cls, config: dict) -> bool:
+        """
+        Convert more_like_this config args to a bool.
+
+        :param config: app model config args
+        """
+        more_like_this = False
+        more_like_this_dict = config.get("more_like_this")
+        if more_like_this_dict:
+            if more_like_this_dict.get("enabled"):
+                more_like_this = True
+
+        return more_like_this
+
+    @classmethod
+    def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
+        """
+        Validate and set defaults for more like this feature
+
+        :param config: app model config args
+        """
+        if not config.get("more_like_this"):
+            config["more_like_this"] = {"enabled": False}
+
+        if not isinstance(config["more_like_this"], dict):
+            raise ValueError("more_like_this must be of dict type")
+
+        if "enabled" not in config["more_like_this"] or not config["more_like_this"]["enabled"]:
+            config["more_like_this"]["enabled"] = False
+
+        if not isinstance(config["more_like_this"]["enabled"], bool):
+            raise ValueError("enabled in more_like_this must be of boolean type")
+
+        return config, ["more_like_this"]
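The same enabled-flag pattern repeats in the retrieval-resource, speech-to-text, and suggested-questions managers below; a single sketch using `MoreLikeThisConfigManager` stands in for all of them:

```python
from core.app.app_config.features.more_like_this.manager import MoreLikeThisConfigManager

config = {"more_like_this": {"enabled": True}}

# A missing or falsy section validates to {"enabled": False} and converts to False.
config, keys = MoreLikeThisConfigManager.validate_and_set_defaults(config)
print(MoreLikeThisConfigManager.convert(config), keys)  # True ['more_like_this']
```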
api/core/app/app_config/features/opening_statement/__init__.py
ADDED
File without changes
api/core/app/app_config/features/opening_statement/manager.py
ADDED
@@ -0,0 +1,41 @@
+class OpeningStatementConfigManager:
+    @classmethod
+    def convert(cls, config: dict) -> tuple[str, list]:
+        """
+        Convert opening statement config args to the opening statement and suggested questions.
+
+        :param config: app model config args
+        """
+        # opening statement
+        opening_statement = config.get("opening_statement", "")
+
+        # suggested questions
+        suggested_questions_list = config.get("suggested_questions", [])
+
+        return opening_statement, suggested_questions_list
+
+    @classmethod
+    def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
+        """
+        Validate and set defaults for opening statement feature
+
+        :param config: app model config args
+        """
+        if not config.get("opening_statement"):
+            config["opening_statement"] = ""
+
+        if not isinstance(config["opening_statement"], str):
+            raise ValueError("opening_statement must be of string type")
+
+        # suggested_questions
+        if not config.get("suggested_questions"):
+            config["suggested_questions"] = []
+
+        if not isinstance(config["suggested_questions"], list):
+            raise ValueError("suggested_questions must be of list type")
+
+        for question in config["suggested_questions"]:
+            if not isinstance(question, str):
+                raise ValueError("Elements in suggested_questions list must be of string type")
+
+        return config, ["opening_statement", "suggested_questions"]
api/core/app/app_config/features/retrieval_resource/__init__.py
ADDED
File without changes
api/core/app/app_config/features/retrieval_resource/manager.py
ADDED
@@ -0,0 +1,31 @@
+class RetrievalResourceConfigManager:
+    @classmethod
+    def convert(cls, config: dict) -> bool:
+        show_retrieve_source = False
+        retriever_resource_dict = config.get("retriever_resource")
+        if retriever_resource_dict:
+            if retriever_resource_dict.get("enabled"):
+                show_retrieve_source = True
+
+        return show_retrieve_source
+
+    @classmethod
+    def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
+        """
+        Validate and set defaults for retriever resource feature
+
+        :param config: app model config args
+        """
+        if not config.get("retriever_resource"):
+            config["retriever_resource"] = {"enabled": False}
+
+        if not isinstance(config["retriever_resource"], dict):
+            raise ValueError("retriever_resource must be of dict type")
+
+        if "enabled" not in config["retriever_resource"] or not config["retriever_resource"]["enabled"]:
+            config["retriever_resource"]["enabled"] = False
+
+        if not isinstance(config["retriever_resource"]["enabled"], bool):
+            raise ValueError("enabled in retriever_resource must be of boolean type")
+
+        return config, ["retriever_resource"]
api/core/app/app_config/features/speech_to_text/__init__.py
ADDED
File without changes
api/core/app/app_config/features/speech_to_text/manager.py
ADDED
@@ -0,0 +1,36 @@
+class SpeechToTextConfigManager:
+    @classmethod
+    def convert(cls, config: dict) -> bool:
+        """
+        Convert speech_to_text config args to a bool.
+
+        :param config: app model config args
+        """
+        speech_to_text = False
+        speech_to_text_dict = config.get("speech_to_text")
+        if speech_to_text_dict:
+            if speech_to_text_dict.get("enabled"):
+                speech_to_text = True
+
+        return speech_to_text
+
+    @classmethod
+    def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
+        """
+        Validate and set defaults for speech to text feature
+
+        :param config: app model config args
+        """
+        if not config.get("speech_to_text"):
+            config["speech_to_text"] = {"enabled": False}
+
+        if not isinstance(config["speech_to_text"], dict):
+            raise ValueError("speech_to_text must be of dict type")
+
+        if "enabled" not in config["speech_to_text"] or not config["speech_to_text"]["enabled"]:
+            config["speech_to_text"]["enabled"] = False
+
+        if not isinstance(config["speech_to_text"]["enabled"], bool):
+            raise ValueError("enabled in speech_to_text must be of boolean type")
+
+        return config, ["speech_to_text"]
api/core/app/app_config/features/suggested_questions_after_answer/__init__.py
ADDED
File without changes
api/core/app/app_config/features/suggested_questions_after_answer/manager.py
ADDED
@@ -0,0 +1,39 @@
+class SuggestedQuestionsAfterAnswerConfigManager:
+    @classmethod
+    def convert(cls, config: dict) -> bool:
+        """
+        Convert suggested_questions_after_answer config args to a bool.
+
+        :param config: app model config args
+        """
+        suggested_questions_after_answer = False
+        suggested_questions_after_answer_dict = config.get("suggested_questions_after_answer")
+        if suggested_questions_after_answer_dict:
+            if suggested_questions_after_answer_dict.get("enabled"):
+                suggested_questions_after_answer = True
+
+        return suggested_questions_after_answer
+
+    @classmethod
+    def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
+        """
+        Validate and set defaults for suggested questions feature
+
+        :param config: app model config args
+        """
+        if not config.get("suggested_questions_after_answer"):
+            config["suggested_questions_after_answer"] = {"enabled": False}
+
+        if not isinstance(config["suggested_questions_after_answer"], dict):
+            raise ValueError("suggested_questions_after_answer must be of dict type")
+
+        if (
+            "enabled" not in config["suggested_questions_after_answer"]
+            or not config["suggested_questions_after_answer"]["enabled"]
+        ):
+            config["suggested_questions_after_answer"]["enabled"] = False
+
+        if not isinstance(config["suggested_questions_after_answer"]["enabled"], bool):
+            raise ValueError("enabled in suggested_questions_after_answer must be of boolean type")
+
+        return config, ["suggested_questions_after_answer"]
api/core/app/app_config/features/text_to_speech/__init__.py
ADDED
File without changes
api/core/app/app_config/features/text_to_speech/manager.py
ADDED
@@ -0,0 +1,45 @@
+from core.app.app_config.entities import TextToSpeechEntity
+
+
+class TextToSpeechConfigManager:
+    @classmethod
+    def convert(cls, config: dict):
+        """
+        Convert text_to_speech config args to a TextToSpeechEntity.
+
+        :param config: app model config args
+        """
+        text_to_speech = None
+        text_to_speech_dict = config.get("text_to_speech")
+        if text_to_speech_dict:
+            if text_to_speech_dict.get("enabled"):
+                text_to_speech = TextToSpeechEntity(
+                    enabled=text_to_speech_dict.get("enabled"),
+                    voice=text_to_speech_dict.get("voice"),
+                    language=text_to_speech_dict.get("language"),
+                )
+
+        return text_to_speech
+
+    @classmethod
+    def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
+        """
+        Validate and set defaults for text to speech feature
+
+        :param config: app model config args
+        """
+        if not config.get("text_to_speech"):
+            config["text_to_speech"] = {"enabled": False, "voice": "", "language": ""}
+
+        if not isinstance(config["text_to_speech"], dict):
+            raise ValueError("text_to_speech must be of dict type")
+
+        if "enabled" not in config["text_to_speech"] or not config["text_to_speech"]["enabled"]:
+            config["text_to_speech"]["enabled"] = False
+            config["text_to_speech"]["voice"] = ""
+            config["text_to_speech"]["language"] = ""
+
+        if not isinstance(config["text_to_speech"]["enabled"], bool):
+            raise ValueError("enabled in text_to_speech must be of boolean type")
+
+        return config, ["text_to_speech"]
api/core/app/app_config/workflow_ui_based_app/__init__.py
ADDED
File without changes
api/core/app/app_config/workflow_ui_based_app/variables/__init__.py
ADDED
File without changes
api/core/app/app_config/workflow_ui_based_app/variables/manager.py
ADDED
@@ -0,0 +1,22 @@
+from core.app.app_config.entities import VariableEntity
+from models.workflow import Workflow
+
+
+class WorkflowVariablesConfigManager:
+    @classmethod
+    def convert(cls, workflow: Workflow) -> list[VariableEntity]:
+        """
+        Convert workflow start variables to variables
+
+        :param workflow: workflow instance
+        """
+        variables = []
+
+        # find start node
+        user_input_form = workflow.user_input_form()
+
+        # variables
+        for variable in user_input_form:
+            variables.append(VariableEntity.model_validate(variable))
+
+        return variables
api/core/app/apps/README.md
ADDED
@@ -0,0 +1,48 @@
+## Guidelines for Database Connection Management in App Runner and Task Pipeline
+
+The App Runner contains long-running tasks such as LLM generation and external requests, while Flask-SQLAlchemy's connection-pooling strategy allocates one connection (transaction) per request. That approach keeps a connection occupied even while non-DB work is running, so under high concurrency many long-running tasks can exhaust the pool and new connections can no longer be acquired.
+
+Therefore, database operations in App Runner and Task Pipeline must close connections immediately after use, and it is better to pass IDs rather than Model objects to avoid detach errors.
+
+Examples:
+
+1. Creating a new record:
+
+```python
+app = App(id=1)
+db.session.add(app)
+db.session.commit()
+db.session.refresh(app)  # Retrieve table default values, like created_at, cached in the app object; unaffected after close
+
+# Handle non-long-running tasks or store the content of the App instance in memory (via variable assignment).
+
+db.session.close()
+
+return app.id
+```
+
+2. Fetching a record from the table:
+
+```python
+app = db.session.query(App).filter(App.id == app_id).first()
+
+created_at = app.created_at
+
+db.session.close()
+
+# Handle tasks (including long-running ones).
+
+```
+
+3. Updating a table field:
+
+```python
+app = db.session.query(App).filter(App.id == app_id).first()
+
+app.updated_at = datetime.utcnow()  # the `time` module has no utcnow(); use datetime (from datetime import datetime)
+db.session.commit()
+db.session.close()
+
+return app_id
+```
+
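To make the detach risk above concrete, a hedged sketch of the anti-pattern the guideline warns about (whether the attribute access actually raises depends on session settings such as `expire_on_commit`):

```python
# Anti-pattern: holding a Model object across a closed session.
app = db.session.query(App).filter(App.id == app_id).first()
db.session.close()

run_long_task()  # hypothetical long-running work

# `app` is now detached; expired or lazy-loaded attributes may raise
# sqlalchemy.orm.exc.DetachedInstanceError here.
print(app.name)

# Safer: keep app_id only, and re-query inside whatever needs the data.
```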
api/core/app/apps/__init__.py
ADDED
File without changes
api/core/app/apps/advanced_chat/__init__.py
ADDED
File without changes
api/core/app/apps/advanced_chat/app_config_manager.py
ADDED
@@ -0,0 +1,91 @@
+from core.app.app_config.base_app_config_manager import BaseAppConfigManager
+from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager
+from core.app.app_config.entities import WorkflowUIBasedAppConfig
+from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
+from core.app.app_config.features.opening_statement.manager import OpeningStatementConfigManager
+from core.app.app_config.features.retrieval_resource.manager import RetrievalResourceConfigManager
+from core.app.app_config.features.speech_to_text.manager import SpeechToTextConfigManager
+from core.app.app_config.features.suggested_questions_after_answer.manager import (
+    SuggestedQuestionsAfterAnswerConfigManager,
+)
+from core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager
+from core.app.app_config.workflow_ui_based_app.variables.manager import WorkflowVariablesConfigManager
+from models.model import App, AppMode
+from models.workflow import Workflow
+
+
+class AdvancedChatAppConfig(WorkflowUIBasedAppConfig):
+    """
+    Advanced Chatbot App Config Entity.
+    """
+
+    pass
+
+
+class AdvancedChatAppConfigManager(BaseAppConfigManager):
+    @classmethod
+    def get_app_config(cls, app_model: App, workflow: Workflow) -> AdvancedChatAppConfig:
+        features_dict = workflow.features_dict
+
+        app_mode = AppMode.value_of(app_model.mode)
+        app_config = AdvancedChatAppConfig(
+            tenant_id=app_model.tenant_id,
+            app_id=app_model.id,
+            app_mode=app_mode,
+            workflow_id=workflow.id,
+            sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(config=features_dict),
+            variables=WorkflowVariablesConfigManager.convert(workflow=workflow),
+            additional_features=cls.convert_features(features_dict, app_mode),
+        )
+
+        return app_config
+
+    @classmethod
+    def config_validate(cls, tenant_id: str, config: dict, only_structure_validate: bool = False) -> dict:
+        """
+        Validate for advanced chat app model config
+
+        :param tenant_id: tenant id
+        :param config: app model config args
+        :param only_structure_validate: if True, only structure validation will be performed
+        """
+        related_config_keys = []
+
+        # file upload validation
+        config, current_related_config_keys = FileUploadConfigManager.validate_and_set_defaults(config=config)
+        related_config_keys.extend(current_related_config_keys)
+
+        # opening_statement
+        config, current_related_config_keys = OpeningStatementConfigManager.validate_and_set_defaults(config)
+        related_config_keys.extend(current_related_config_keys)
+
+        # suggested_questions_after_answer
+        config, current_related_config_keys = SuggestedQuestionsAfterAnswerConfigManager.validate_and_set_defaults(
+            config
+        )
+        related_config_keys.extend(current_related_config_keys)
+
+        # speech_to_text
+        config, current_related_config_keys = SpeechToTextConfigManager.validate_and_set_defaults(config)
+        related_config_keys.extend(current_related_config_keys)
+
+        # text_to_speech
+        config, current_related_config_keys = TextToSpeechConfigManager.validate_and_set_defaults(config)
+        related_config_keys.extend(current_related_config_keys)
+
+        # return retriever resource
+        config, current_related_config_keys = RetrievalResourceConfigManager.validate_and_set_defaults(config)
+        related_config_keys.extend(current_related_config_keys)
+
+        # moderation validation
+        config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(
+            tenant_id=tenant_id, config=config, only_structure_validate=only_structure_validate
+        )
+        related_config_keys.extend(current_related_config_keys)
+
+        related_config_keys = list(set(related_config_keys))
+
+        # Filter out extra parameters
+        filtered_config = {key: config.get(key) for key in related_config_keys}
+
+        return filtered_config
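A usage sketch for `config_validate`; the tenant id is a placeholder, and `only_structure_validate=True` is passed on the assumption that it lets the moderation check skip tenant-specific validation:

```python
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager

features = {"opening_statement": "Hi!", "speech_to_text": {"enabled": True}}

filtered = AdvancedChatAppConfigManager.config_validate(
    tenant_id="tenant-id",  # placeholder
    config=features,
    only_structure_validate=True,
)
print(sorted(filtered))  # only the related config keys survive the final filter
```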