Upload 1150 files
This view is limited to 50 files because it contains too many changes.
- .gitattributes +6 -0
- api/core/__init__.py +1 -0
- api/core/hosting_configuration.py +255 -0
- api/core/indexing_runner.py +754 -0
- api/core/model_manager.py +559 -0
- api/core/moderation/__init__.py +0 -0
- api/core/moderation/api/__builtin__ +1 -0
- api/core/moderation/api/__init__.py +0 -0
- api/core/moderation/api/api.py +96 -0
- api/core/moderation/base.py +115 -0
- api/core/moderation/factory.py +49 -0
- api/core/moderation/input_moderation.py +71 -0
- api/core/moderation/keywords/__builtin__ +1 -0
- api/core/moderation/keywords/__init__.py +0 -0
- api/core/moderation/keywords/keywords.py +73 -0
- api/core/moderation/openai_moderation/__builtin__ +1 -0
- api/core/moderation/openai_moderation/__init__.py +0 -0
- api/core/moderation/openai_moderation/openai_moderation.py +60 -0
- api/core/moderation/output_moderation.py +131 -0
- api/core/ops/__init__.py +0 -0
- api/core/ops/base_trace_instance.py +26 -0
- api/core/ops/entities/__init__.py +0 -0
- api/core/ops/entities/config_entity.py +92 -0
- api/core/ops/entities/trace_entity.py +134 -0
- api/core/ops/langfuse_trace/__init__.py +0 -0
- api/core/ops/langfuse_trace/entities/__init__.py +0 -0
- api/core/ops/langfuse_trace/entities/langfuse_trace_entity.py +282 -0
- api/core/ops/langfuse_trace/langfuse_trace.py +455 -0
- api/core/ops/langsmith_trace/__init__.py +0 -0
- api/core/ops/langsmith_trace/entities/__init__.py +0 -0
- api/core/ops/langsmith_trace/entities/langsmith_trace_entity.py +141 -0
- api/core/ops/langsmith_trace/langsmith_trace.py +524 -0
- api/core/ops/opik_trace/__init__.py +0 -0
- api/core/ops/opik_trace/opik_trace.py +469 -0
- api/core/ops/ops_trace_manager.py +811 -0
- api/core/ops/utils.py +62 -0
- api/core/prompt/__init__.py +0 -0
- api/core/prompt/advanced_prompt_transform.py +287 -0
- api/core/prompt/agent_history_prompt_transform.py +80 -0
- api/core/prompt/entities/__init__.py +0 -0
- api/core/prompt/entities/advanced_prompt_entities.py +50 -0
- api/core/prompt/prompt_templates/__init__.py +0 -0
- api/core/prompt/prompt_templates/advanced_prompt_templates.py +45 -0
- api/core/prompt/prompt_templates/baichuan_chat.json +13 -0
- api/core/prompt/prompt_templates/baichuan_completion.json +9 -0
- api/core/prompt/prompt_templates/common_chat.json +13 -0
- api/core/prompt/prompt_templates/common_completion.json +9 -0
- api/core/prompt/prompt_transform.py +90 -0
- api/core/prompt/simple_prompt_transform.py +327 -0
- api/core/prompt/utils/__init__.py +0 -0
.gitattributes
CHANGED
@@ -6,3 +6,9 @@
 
 *.sh text eol=lf
 api/tests/integration_tests/model_runtime/assets/audio.mp3 filter=lfs diff=lfs merge=lfs -text
+api/core/tools/docs/images/index/image-1.png filter=lfs diff=lfs merge=lfs -text
+api/core/tools/docs/images/index/image-2.png filter=lfs diff=lfs merge=lfs -text
+api/core/tools/docs/images/index/image.png filter=lfs diff=lfs merge=lfs -text
+api/core/tools/provider/builtin/comfyui/_assets/icon.png filter=lfs diff=lfs merge=lfs -text
+api/core/tools/provider/builtin/dalle/_assets/icon.png filter=lfs diff=lfs merge=lfs -text
+api/core/tools/provider/builtin/wecom/_assets/icon.png filter=lfs diff=lfs merge=lfs -text
api/core/__init__.py
ADDED
@@ -0,0 +1 @@
import core.moderation.base
api/core/hosting_configuration.py
ADDED
@@ -0,0 +1,255 @@
from typing import Optional

from flask import Flask
from pydantic import BaseModel

from configs import dify_config
from core.entities.provider_entities import QuotaUnit, RestrictModel
from core.model_runtime.entities.model_entities import ModelType
from models.provider import ProviderQuotaType


class HostingQuota(BaseModel):
    quota_type: ProviderQuotaType
    restrict_models: list[RestrictModel] = []


class TrialHostingQuota(HostingQuota):
    quota_type: ProviderQuotaType = ProviderQuotaType.TRIAL
    quota_limit: int = 0
    """Quota limit for the hosting provider models. -1 means unlimited."""


class PaidHostingQuota(HostingQuota):
    quota_type: ProviderQuotaType = ProviderQuotaType.PAID


class FreeHostingQuota(HostingQuota):
    quota_type: ProviderQuotaType = ProviderQuotaType.FREE


class HostingProvider(BaseModel):
    enabled: bool = False
    credentials: Optional[dict] = None
    quota_unit: Optional[QuotaUnit] = None
    quotas: list[HostingQuota] = []


class HostedModerationConfig(BaseModel):
    enabled: bool = False
    providers: list[str] = []


class HostingConfiguration:
    provider_map: dict[str, HostingProvider] = {}
    moderation_config: Optional[HostedModerationConfig] = None

    def init_app(self, app: Flask) -> None:
        if dify_config.EDITION != "CLOUD":
            return

        self.provider_map["azure_openai"] = self.init_azure_openai()
        self.provider_map["openai"] = self.init_openai()
        self.provider_map["anthropic"] = self.init_anthropic()
        self.provider_map["minimax"] = self.init_minimax()
        self.provider_map["spark"] = self.init_spark()
        self.provider_map["zhipuai"] = self.init_zhipuai()

        self.moderation_config = self.init_moderation_config()

    @staticmethod
    def init_azure_openai() -> HostingProvider:
        quota_unit = QuotaUnit.TIMES
        if dify_config.HOSTED_AZURE_OPENAI_ENABLED:
            credentials = {
                "openai_api_key": dify_config.HOSTED_AZURE_OPENAI_API_KEY,
                "openai_api_base": dify_config.HOSTED_AZURE_OPENAI_API_BASE,
                "base_model_name": "gpt-35-turbo",
            }

            quotas: list[HostingQuota] = []
            hosted_quota_limit = dify_config.HOSTED_AZURE_OPENAI_QUOTA_LIMIT
            trial_quota = TrialHostingQuota(
                quota_limit=hosted_quota_limit,
                restrict_models=[
                    RestrictModel(model="gpt-4", base_model_name="gpt-4", model_type=ModelType.LLM),
                    RestrictModel(model="gpt-4o", base_model_name="gpt-4o", model_type=ModelType.LLM),
                    RestrictModel(model="gpt-4o-mini", base_model_name="gpt-4o-mini", model_type=ModelType.LLM),
                    RestrictModel(model="gpt-4-32k", base_model_name="gpt-4-32k", model_type=ModelType.LLM),
                    RestrictModel(
                        model="gpt-4-1106-preview", base_model_name="gpt-4-1106-preview", model_type=ModelType.LLM
                    ),
                    RestrictModel(
                        model="gpt-4-vision-preview", base_model_name="gpt-4-vision-preview", model_type=ModelType.LLM
                    ),
                    RestrictModel(model="gpt-35-turbo", base_model_name="gpt-35-turbo", model_type=ModelType.LLM),
                    RestrictModel(
                        model="gpt-35-turbo-1106", base_model_name="gpt-35-turbo-1106", model_type=ModelType.LLM
                    ),
                    RestrictModel(
                        model="gpt-35-turbo-instruct", base_model_name="gpt-35-turbo-instruct", model_type=ModelType.LLM
                    ),
                    RestrictModel(
                        model="gpt-35-turbo-16k", base_model_name="gpt-35-turbo-16k", model_type=ModelType.LLM
                    ),
                    RestrictModel(
                        model="text-davinci-003", base_model_name="text-davinci-003", model_type=ModelType.LLM
                    ),
                    RestrictModel(
                        model="text-embedding-ada-002",
                        base_model_name="text-embedding-ada-002",
                        model_type=ModelType.TEXT_EMBEDDING,
                    ),
                    RestrictModel(
                        model="text-embedding-3-small",
                        base_model_name="text-embedding-3-small",
                        model_type=ModelType.TEXT_EMBEDDING,
                    ),
                    RestrictModel(
                        model="text-embedding-3-large",
                        base_model_name="text-embedding-3-large",
                        model_type=ModelType.TEXT_EMBEDDING,
                    ),
                ],
            )
            quotas.append(trial_quota)

            return HostingProvider(enabled=True, credentials=credentials, quota_unit=quota_unit, quotas=quotas)

        return HostingProvider(
            enabled=False,
            quota_unit=quota_unit,
        )

    def init_openai(self) -> HostingProvider:
        quota_unit = QuotaUnit.CREDITS
        quotas: list[HostingQuota] = []

        if dify_config.HOSTED_OPENAI_TRIAL_ENABLED:
            hosted_quota_limit = dify_config.HOSTED_OPENAI_QUOTA_LIMIT
            trial_models = self.parse_restrict_models_from_env("HOSTED_OPENAI_TRIAL_MODELS")
            trial_quota = TrialHostingQuota(quota_limit=hosted_quota_limit, restrict_models=trial_models)
            quotas.append(trial_quota)

        if dify_config.HOSTED_OPENAI_PAID_ENABLED:
            paid_models = self.parse_restrict_models_from_env("HOSTED_OPENAI_PAID_MODELS")
            paid_quota = PaidHostingQuota(restrict_models=paid_models)
            quotas.append(paid_quota)

        if len(quotas) > 0:
            credentials = {
                "openai_api_key": dify_config.HOSTED_OPENAI_API_KEY,
            }

            if dify_config.HOSTED_OPENAI_API_BASE:
                credentials["openai_api_base"] = dify_config.HOSTED_OPENAI_API_BASE

            if dify_config.HOSTED_OPENAI_API_ORGANIZATION:
                credentials["openai_organization"] = dify_config.HOSTED_OPENAI_API_ORGANIZATION

            return HostingProvider(enabled=True, credentials=credentials, quota_unit=quota_unit, quotas=quotas)

        return HostingProvider(
            enabled=False,
            quota_unit=quota_unit,
        )

    @staticmethod
    def init_anthropic() -> HostingProvider:
        quota_unit = QuotaUnit.TOKENS
        quotas: list[HostingQuota] = []

        if dify_config.HOSTED_ANTHROPIC_TRIAL_ENABLED:
            hosted_quota_limit = dify_config.HOSTED_ANTHROPIC_QUOTA_LIMIT
            trial_quota = TrialHostingQuota(quota_limit=hosted_quota_limit)
            quotas.append(trial_quota)

        if dify_config.HOSTED_ANTHROPIC_PAID_ENABLED:
            paid_quota = PaidHostingQuota()
            quotas.append(paid_quota)

        if len(quotas) > 0:
            credentials = {
                "anthropic_api_key": dify_config.HOSTED_ANTHROPIC_API_KEY,
            }

            if dify_config.HOSTED_ANTHROPIC_API_BASE:
                credentials["anthropic_api_url"] = dify_config.HOSTED_ANTHROPIC_API_BASE

            return HostingProvider(enabled=True, credentials=credentials, quota_unit=quota_unit, quotas=quotas)

        return HostingProvider(
            enabled=False,
            quota_unit=quota_unit,
        )

    @staticmethod
    def init_minimax() -> HostingProvider:
        quota_unit = QuotaUnit.TOKENS
        if dify_config.HOSTED_MINIMAX_ENABLED:
            quotas: list[HostingQuota] = [FreeHostingQuota()]

            return HostingProvider(
                enabled=True,
                credentials=None,  # use credentials from the provider
                quota_unit=quota_unit,
                quotas=quotas,
            )

        return HostingProvider(
            enabled=False,
            quota_unit=quota_unit,
        )

    @staticmethod
    def init_spark() -> HostingProvider:
        quota_unit = QuotaUnit.TOKENS
        if dify_config.HOSTED_SPARK_ENABLED:
            quotas: list[HostingQuota] = [FreeHostingQuota()]

            return HostingProvider(
                enabled=True,
                credentials=None,  # use credentials from the provider
                quota_unit=quota_unit,
                quotas=quotas,
            )

        return HostingProvider(
            enabled=False,
            quota_unit=quota_unit,
        )

    @staticmethod
    def init_zhipuai() -> HostingProvider:
        quota_unit = QuotaUnit.TOKENS
        if dify_config.HOSTED_ZHIPUAI_ENABLED:
            quotas: list[HostingQuota] = [FreeHostingQuota()]

            return HostingProvider(
                enabled=True,
                credentials=None,  # use credentials from the provider
                quota_unit=quota_unit,
                quotas=quotas,
            )

        return HostingProvider(
            enabled=False,
            quota_unit=quota_unit,
        )

    @staticmethod
    def init_moderation_config() -> HostedModerationConfig:
        if dify_config.HOSTED_MODERATION_ENABLED and dify_config.HOSTED_MODERATION_PROVIDERS:
            return HostedModerationConfig(enabled=True, providers=dify_config.HOSTED_MODERATION_PROVIDERS.split(","))

        return HostedModerationConfig(enabled=False)

    @staticmethod
    def parse_restrict_models_from_env(env_var: str) -> list[RestrictModel]:
        models_str = dify_config.model_dump().get(env_var)
        models_list = models_str.split(",") if models_str else []
        return [
            RestrictModel(model=model_name.strip(), model_type=ModelType.LLM)
            for model_name in models_list
            if model_name.strip()
        ]
api/core/indexing_runner.py
ADDED
@@ -0,0 +1,754 @@
import concurrent.futures
import datetime
import json
import logging
import re
import threading
import time
import uuid
from typing import Any, Optional, cast

from flask import current_app
from flask_login import current_user  # type: ignore
from sqlalchemy.orm.exc import ObjectDeletedError

from configs import dify_config
from core.entities.knowledge_entities import IndexingEstimate, PreviewDetail, QAPreviewDetail
from core.errors.error import ProviderTokenNotInitError
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.rag.cleaner.clean_processor import CleanProcessor
from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.docstore.dataset_docstore import DatasetDocumentStore
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.models.document import ChildDocument, Document
from core.rag.splitter.fixed_text_splitter import (
    EnhanceRecursiveCharacterTextSplitter,
    FixedRecursiveCharacterTextSplitter,
)
from core.rag.splitter.text_splitter import TextSplitter
from core.tools.utils.web_reader_tool import get_image_upload_file_ids
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from extensions.ext_storage import storage
from libs import helper
from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegment
from models.dataset import Document as DatasetDocument
from models.model import UploadFile
from services.feature_service import FeatureService


class IndexingRunner:
    def __init__(self):
        self.storage = storage
        self.model_manager = ModelManager()

    def run(self, dataset_documents: list[DatasetDocument]):
        """Run the indexing process."""
        for dataset_document in dataset_documents:
            try:
                # get dataset
                dataset = Dataset.query.filter_by(id=dataset_document.dataset_id).first()

                if not dataset:
                    raise ValueError("no dataset found")

                # get the process rule
                processing_rule = (
                    db.session.query(DatasetProcessRule)
                    .filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id)
                    .first()
                )
                if not processing_rule:
                    raise ValueError("no process rule found")
                index_type = dataset_document.doc_form
                index_processor = IndexProcessorFactory(index_type).init_index_processor()
                # extract
                text_docs = self._extract(index_processor, dataset_document, processing_rule.to_dict())

                # transform
                documents = self._transform(
                    index_processor, dataset, text_docs, dataset_document.doc_language, processing_rule.to_dict()
                )
                # save segment
                self._load_segments(dataset, dataset_document, documents)

                # load
                self._load(
                    index_processor=index_processor,
                    dataset=dataset,
                    dataset_document=dataset_document,
                    documents=documents,
                )
            except DocumentIsPausedError:
                raise DocumentIsPausedError("Document paused, document id: {}".format(dataset_document.id))
            except ProviderTokenNotInitError as e:
                dataset_document.indexing_status = "error"
                dataset_document.error = str(e.description)
                dataset_document.stopped_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
                db.session.commit()
            except ObjectDeletedError:
                logging.warning("Document deleted, document id: {}".format(dataset_document.id))
            except Exception as e:
                logging.exception("consume document failed")
                dataset_document.indexing_status = "error"
                dataset_document.error = str(e)
                dataset_document.stopped_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
                db.session.commit()

    def run_in_splitting_status(self, dataset_document: DatasetDocument):
        """Run the indexing process when the index_status is splitting."""
        try:
            # get dataset
            dataset = Dataset.query.filter_by(id=dataset_document.dataset_id).first()

            if not dataset:
                raise ValueError("no dataset found")

            # get exist document_segment list and delete
            document_segments = DocumentSegment.query.filter_by(
                dataset_id=dataset.id, document_id=dataset_document.id
            ).all()

            for document_segment in document_segments:
                db.session.delete(document_segment)
                if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
                    # delete child chunks
                    db.session.query(ChildChunk).filter(ChildChunk.segment_id == document_segment.id).delete()
            db.session.commit()
            # get the process rule
            processing_rule = (
                db.session.query(DatasetProcessRule)
                .filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id)
                .first()
            )
            if not processing_rule:
                raise ValueError("no process rule found")

            index_type = dataset_document.doc_form
            index_processor = IndexProcessorFactory(index_type).init_index_processor()
            # extract
            text_docs = self._extract(index_processor, dataset_document, processing_rule.to_dict())

            # transform
            documents = self._transform(
                index_processor, dataset, text_docs, dataset_document.doc_language, processing_rule.to_dict()
            )
            # save segment
            self._load_segments(dataset, dataset_document, documents)

            # load
            self._load(
                index_processor=index_processor, dataset=dataset, dataset_document=dataset_document, documents=documents
            )
        except DocumentIsPausedError:
            raise DocumentIsPausedError("Document paused, document id: {}".format(dataset_document.id))
        except ProviderTokenNotInitError as e:
            dataset_document.indexing_status = "error"
            dataset_document.error = str(e.description)
            dataset_document.stopped_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
            db.session.commit()
        except Exception as e:
            logging.exception("consume document failed")
            dataset_document.indexing_status = "error"
            dataset_document.error = str(e)
            dataset_document.stopped_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
            db.session.commit()

    def run_in_indexing_status(self, dataset_document: DatasetDocument):
        """Run the indexing process when the index_status is indexing."""
        try:
            # get dataset
            dataset = Dataset.query.filter_by(id=dataset_document.dataset_id).first()

            if not dataset:
                raise ValueError("no dataset found")

            # get exist document_segment list and delete
            document_segments = DocumentSegment.query.filter_by(
                dataset_id=dataset.id, document_id=dataset_document.id
            ).all()

            documents = []
            if document_segments:
                for document_segment in document_segments:
                    # transform segment to node
                    if document_segment.status != "completed":
                        document = Document(
                            page_content=document_segment.content,
                            metadata={
                                "doc_id": document_segment.index_node_id,
                                "doc_hash": document_segment.index_node_hash,
                                "document_id": document_segment.document_id,
                                "dataset_id": document_segment.dataset_id,
                            },
                        )
                        if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
                            child_chunks = document_segment.child_chunks
                            if child_chunks:
                                child_documents = []
                                for child_chunk in child_chunks:
                                    child_document = ChildDocument(
                                        page_content=child_chunk.content,
                                        metadata={
                                            "doc_id": child_chunk.index_node_id,
                                            "doc_hash": child_chunk.index_node_hash,
                                            "document_id": document_segment.document_id,
                                            "dataset_id": document_segment.dataset_id,
                                        },
                                    )
                                    child_documents.append(child_document)
                                document.children = child_documents
                        documents.append(document)

            # build index
            # get the process rule
            processing_rule = (
                db.session.query(DatasetProcessRule)
                .filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id)
                .first()
            )

            index_type = dataset_document.doc_form
            index_processor = IndexProcessorFactory(index_type).init_index_processor()
            self._load(
                index_processor=index_processor, dataset=dataset, dataset_document=dataset_document, documents=documents
            )
        except DocumentIsPausedError:
            raise DocumentIsPausedError("Document paused, document id: {}".format(dataset_document.id))
        except ProviderTokenNotInitError as e:
            dataset_document.indexing_status = "error"
            dataset_document.error = str(e.description)
            dataset_document.stopped_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
            db.session.commit()
        except Exception as e:
            logging.exception("consume document failed")
            dataset_document.indexing_status = "error"
            dataset_document.error = str(e)
            dataset_document.stopped_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
            db.session.commit()

    def indexing_estimate(
        self,
        tenant_id: str,
        extract_settings: list[ExtractSetting],
        tmp_processing_rule: dict,
        doc_form: Optional[str] = None,
        doc_language: str = "English",
        dataset_id: Optional[str] = None,
        indexing_technique: str = "economy",
    ) -> IndexingEstimate:
        """
        Estimate the indexing for the document.
        """
        # check document limit
        features = FeatureService.get_features(tenant_id)
        if features.billing.enabled:
            count = len(extract_settings)
            batch_upload_limit = dify_config.BATCH_UPLOAD_LIMIT
            if count > batch_upload_limit:
                raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")

        embedding_model_instance = None
        if dataset_id:
            dataset = Dataset.query.filter_by(id=dataset_id).first()
            if not dataset:
                raise ValueError("Dataset not found.")
            if dataset.indexing_technique == "high_quality" or indexing_technique == "high_quality":
                if dataset.embedding_model_provider:
                    embedding_model_instance = self.model_manager.get_model_instance(
                        tenant_id=tenant_id,
                        provider=dataset.embedding_model_provider,
                        model_type=ModelType.TEXT_EMBEDDING,
                        model=dataset.embedding_model,
                    )
                else:
                    embedding_model_instance = self.model_manager.get_default_model_instance(
                        tenant_id=tenant_id,
                        model_type=ModelType.TEXT_EMBEDDING,
                    )
        else:
            if indexing_technique == "high_quality":
                embedding_model_instance = self.model_manager.get_default_model_instance(
                    tenant_id=tenant_id,
                    model_type=ModelType.TEXT_EMBEDDING,
                )
        preview_texts = []  # type: ignore

        total_segments = 0
        index_type = doc_form
        index_processor = IndexProcessorFactory(index_type).init_index_processor()
        for extract_setting in extract_settings:
            # extract
            processing_rule = DatasetProcessRule(
                mode=tmp_processing_rule["mode"], rules=json.dumps(tmp_processing_rule["rules"])
            )
            text_docs = index_processor.extract(extract_setting, process_rule_mode=tmp_processing_rule["mode"])
            documents = index_processor.transform(
                text_docs,
                embedding_model_instance=embedding_model_instance,
                process_rule=processing_rule.to_dict(),
                tenant_id=current_user.current_tenant_id,
                doc_language=doc_language,
                preview=True,
            )
            total_segments += len(documents)
            for document in documents:
                if len(preview_texts) < 10:
                    if doc_form and doc_form == "qa_model":
                        preview_detail = QAPreviewDetail(
                            question=document.page_content, answer=document.metadata.get("answer") or ""
                        )
                        preview_texts.append(preview_detail)
                    else:
                        preview_detail = PreviewDetail(content=document.page_content)  # type: ignore
                        if document.children:
                            preview_detail.child_chunks = [child.page_content for child in document.children]  # type: ignore
                        preview_texts.append(preview_detail)

                # delete image files and related db records
                image_upload_file_ids = get_image_upload_file_ids(document.page_content)
                for upload_file_id in image_upload_file_ids:
                    image_file = db.session.query(UploadFile).filter(UploadFile.id == upload_file_id).first()
                    try:
                        if image_file:
                            storage.delete(image_file.key)
                    except Exception:
                        logging.exception(
                            "Delete image_files failed while indexing_estimate, \
                            image_upload_file_is: {}".format(upload_file_id)
                        )
                    db.session.delete(image_file)

        if doc_form and doc_form == "qa_model":
            return IndexingEstimate(total_segments=total_segments * 20, qa_preview=preview_texts, preview=[])
        return IndexingEstimate(total_segments=total_segments, preview=preview_texts)  # type: ignore

    def _extract(
        self, index_processor: BaseIndexProcessor, dataset_document: DatasetDocument, process_rule: dict
    ) -> list[Document]:
        # load file
        if dataset_document.data_source_type not in {"upload_file", "notion_import", "website_crawl"}:
            return []

        data_source_info = dataset_document.data_source_info_dict
        text_docs = []
        if dataset_document.data_source_type == "upload_file":
            if not data_source_info or "upload_file_id" not in data_source_info:
                raise ValueError("no upload file found")

            file_detail = (
                db.session.query(UploadFile).filter(UploadFile.id == data_source_info["upload_file_id"]).one_or_none()
            )

            if file_detail:
                extract_setting = ExtractSetting(
                    datasource_type="upload_file", upload_file=file_detail, document_model=dataset_document.doc_form
                )
                text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"])
        elif dataset_document.data_source_type == "notion_import":
            if (
                not data_source_info
                or "notion_workspace_id" not in data_source_info
                or "notion_page_id" not in data_source_info
            ):
                raise ValueError("no notion import info found")
            extract_setting = ExtractSetting(
                datasource_type="notion_import",
                notion_info={
                    "notion_workspace_id": data_source_info["notion_workspace_id"],
                    "notion_obj_id": data_source_info["notion_page_id"],
                    "notion_page_type": data_source_info["type"],
                    "document": dataset_document,
                    "tenant_id": dataset_document.tenant_id,
                },
                document_model=dataset_document.doc_form,
            )
            text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"])
        elif dataset_document.data_source_type == "website_crawl":
            if (
                not data_source_info
                or "provider" not in data_source_info
                or "url" not in data_source_info
                or "job_id" not in data_source_info
            ):
                raise ValueError("no website import info found")
            extract_setting = ExtractSetting(
                datasource_type="website_crawl",
                website_info={
                    "provider": data_source_info["provider"],
                    "job_id": data_source_info["job_id"],
                    "tenant_id": dataset_document.tenant_id,
                    "url": data_source_info["url"],
                    "mode": data_source_info["mode"],
                    "only_main_content": data_source_info["only_main_content"],
                },
                document_model=dataset_document.doc_form,
            )
            text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"])
        # update document status to splitting
        self._update_document_index_status(
            document_id=dataset_document.id,
            after_indexing_status="splitting",
            extra_update_params={
                DatasetDocument.word_count: sum(len(text_doc.page_content) for text_doc in text_docs),
                DatasetDocument.parsing_completed_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
            },
        )

        # replace doc id to document model id
        text_docs = cast(list[Document], text_docs)
        for text_doc in text_docs:
            if text_doc.metadata is not None:
                text_doc.metadata["document_id"] = dataset_document.id
                text_doc.metadata["dataset_id"] = dataset_document.dataset_id

        return text_docs

    @staticmethod
    def filter_string(text):
        text = re.sub(r"<\|", "<", text)
        text = re.sub(r"\|>", ">", text)
        text = re.sub(r"[\x00-\x08\x0B\x0C\x0E-\x1F\x7F\xEF\xBF\xBE]", "", text)
        # Unicode U+FFFE
        text = re.sub("\ufffe", "", text)
        return text

    @staticmethod
    def _get_splitter(
        processing_rule_mode: str,
        max_tokens: int,
        chunk_overlap: int,
        separator: str,
        embedding_model_instance: Optional[ModelInstance],
    ) -> TextSplitter:
        """
        Get the NodeParser object according to the processing rule.
        """
        if processing_rule_mode in ["custom", "hierarchical"]:
            # The user-defined segmentation rule
            max_segmentation_tokens_length = dify_config.INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH
            if max_tokens < 50 or max_tokens > max_segmentation_tokens_length:
                raise ValueError(f"Custom segment length should be between 50 and {max_segmentation_tokens_length}.")

            if separator:
                separator = separator.replace("\\n", "\n")

            character_splitter = FixedRecursiveCharacterTextSplitter.from_encoder(
                chunk_size=max_tokens,
                chunk_overlap=chunk_overlap,
                fixed_separator=separator,
                separators=["\n\n", "。", ". ", " ", ""],
                embedding_model_instance=embedding_model_instance,
            )
        else:
            # Automatic segmentation
            automatic_rules: dict[str, Any] = dict(DatasetProcessRule.AUTOMATIC_RULES["segmentation"])
            character_splitter = EnhanceRecursiveCharacterTextSplitter.from_encoder(
                chunk_size=automatic_rules["max_tokens"],
                chunk_overlap=automatic_rules["chunk_overlap"],
                separators=["\n\n", "。", ". ", " ", ""],
                embedding_model_instance=embedding_model_instance,
            )

        return character_splitter  # type: ignore

    def _split_to_documents_for_estimate(
        self, text_docs: list[Document], splitter: TextSplitter, processing_rule: DatasetProcessRule
    ) -> list[Document]:
        """
        Split the text documents into nodes.
        """
        all_documents: list[Document] = []
        for text_doc in text_docs:
            # document clean
            document_text = self._document_clean(text_doc.page_content, processing_rule)
            text_doc.page_content = document_text

            # parse document to nodes
            documents = splitter.split_documents([text_doc])

            split_documents = []
            for document in documents:
                if document.page_content is None or not document.page_content.strip():
                    continue
                if document.metadata is not None:
                    doc_id = str(uuid.uuid4())
                    hash = helper.generate_text_hash(document.page_content)
                    document.metadata["doc_id"] = doc_id
                    document.metadata["doc_hash"] = hash

                split_documents.append(document)

            all_documents.extend(split_documents)

        return all_documents

    @staticmethod
    def _document_clean(text: str, processing_rule: DatasetProcessRule) -> str:
        """
        Clean the document text according to the processing rules.
        """
        if processing_rule.mode == "automatic":
            rules = DatasetProcessRule.AUTOMATIC_RULES
        else:
            rules = json.loads(processing_rule.rules) if processing_rule.rules else {}
        document_text = CleanProcessor.clean(text, {"rules": rules})

        return document_text

    @staticmethod
    def format_split_text(text: str) -> list[QAPreviewDetail]:
        regex = r"Q\d+:\s*(.*?)\s*A\d+:\s*([\s\S]*?)(?=Q\d+:|$)"
        matches = re.findall(regex, text, re.UNICODE)

        return [QAPreviewDetail(question=q, answer=re.sub(r"\n\s*", "\n", a.strip())) for q, a in matches if q and a]

    def _load(
        self,
        index_processor: BaseIndexProcessor,
        dataset: Dataset,
        dataset_document: DatasetDocument,
        documents: list[Document],
    ) -> None:
        """
        insert index and update document/segment status to completed
        """

        embedding_model_instance = None
        if dataset.indexing_technique == "high_quality":
            embedding_model_instance = self.model_manager.get_model_instance(
                tenant_id=dataset.tenant_id,
                provider=dataset.embedding_model_provider,
                model_type=ModelType.TEXT_EMBEDDING,
                model=dataset.embedding_model,
            )

        # chunk nodes by chunk size
        indexing_start_at = time.perf_counter()
        tokens = 0
        if dataset_document.doc_form != IndexType.PARENT_CHILD_INDEX:
            # create keyword index
            create_keyword_thread = threading.Thread(
                target=self._process_keyword_index,
                args=(current_app._get_current_object(), dataset.id, dataset_document.id, documents),  # type: ignore
            )
            create_keyword_thread.start()

        max_workers = 10
        if dataset.indexing_technique == "high_quality":
            with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
                futures = []

                # Distribute documents into multiple groups based on the hash values of page_content
                # This is done to prevent multiple threads from processing the same document,
                # Thereby avoiding potential database insertion deadlocks
                document_groups: list[list[Document]] = [[] for _ in range(max_workers)]
                for document in documents:
                    hash = helper.generate_text_hash(document.page_content)
                    group_index = int(hash, 16) % max_workers
                    document_groups[group_index].append(document)
                for chunk_documents in document_groups:
                    if len(chunk_documents) == 0:
                        continue
                    futures.append(
                        executor.submit(
                            self._process_chunk,
                            current_app._get_current_object(),  # type: ignore
                            index_processor,
                            chunk_documents,
                            dataset,
                            dataset_document,
                            embedding_model_instance,
                        )
                    )

                for future in futures:
                    tokens += future.result()
        if dataset_document.doc_form != IndexType.PARENT_CHILD_INDEX:
            create_keyword_thread.join()
        indexing_end_at = time.perf_counter()

        # update document status to completed
        self._update_document_index_status(
            document_id=dataset_document.id,
            after_indexing_status="completed",
            extra_update_params={
                DatasetDocument.tokens: tokens,
                DatasetDocument.completed_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
                DatasetDocument.indexing_latency: indexing_end_at - indexing_start_at,
                DatasetDocument.error: None,
            },
        )

    @staticmethod
    def _process_keyword_index(flask_app, dataset_id, document_id, documents):
        with flask_app.app_context():
            dataset = Dataset.query.filter_by(id=dataset_id).first()
            if not dataset:
                raise ValueError("no dataset found")
            keyword = Keyword(dataset)
            keyword.create(documents)
            if dataset.indexing_technique != "high_quality":
                document_ids = [document.metadata["doc_id"] for document in documents]
                db.session.query(DocumentSegment).filter(
                    DocumentSegment.document_id == document_id,
                    DocumentSegment.dataset_id == dataset_id,
                    DocumentSegment.index_node_id.in_(document_ids),
                    DocumentSegment.status == "indexing",
                ).update(
                    {
                        DocumentSegment.status: "completed",
                        DocumentSegment.enabled: True,
                        DocumentSegment.completed_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
                    }
                )

                db.session.commit()

    def _process_chunk(
        self, flask_app, index_processor, chunk_documents, dataset, dataset_document, embedding_model_instance
    ):
        with flask_app.app_context():
            # check document is paused
            self._check_document_paused_status(dataset_document.id)

            tokens = 0
            if embedding_model_instance:
                tokens += sum(
                    embedding_model_instance.get_text_embedding_num_tokens([document.page_content])
                    for document in chunk_documents
                )

            # load index
            index_processor.load(dataset, chunk_documents, with_keywords=False)

            document_ids = [document.metadata["doc_id"] for document in chunk_documents]
            db.session.query(DocumentSegment).filter(
                DocumentSegment.document_id == dataset_document.id,
                DocumentSegment.dataset_id == dataset.id,
                DocumentSegment.index_node_id.in_(document_ids),
                DocumentSegment.status == "indexing",
            ).update(
                {
                    DocumentSegment.status: "completed",
                    DocumentSegment.enabled: True,
                    DocumentSegment.completed_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
                }
            )

            db.session.commit()

            return tokens

    @staticmethod
    def _check_document_paused_status(document_id: str):
        indexing_cache_key = "document_{}_is_paused".format(document_id)
        result = redis_client.get(indexing_cache_key)
        if result:
            raise DocumentIsPausedError()

    @staticmethod
    def _update_document_index_status(
        document_id: str, after_indexing_status: str, extra_update_params: Optional[dict] = None
    ) -> None:
        """
        Update the document indexing status.
        """
        count = DatasetDocument.query.filter_by(id=document_id, is_paused=True).count()
        if count > 0:
            raise DocumentIsPausedError()
        document = DatasetDocument.query.filter_by(id=document_id).first()
        if not document:
            raise DocumentIsDeletedPausedError()

        update_params = {DatasetDocument.indexing_status: after_indexing_status}

        if extra_update_params:
            update_params.update(extra_update_params)

        DatasetDocument.query.filter_by(id=document_id).update(update_params)
        db.session.commit()

    @staticmethod
    def _update_segments_by_document(dataset_document_id: str, update_params: dict) -> None:
        """
        Update the document segment by document id.
        """
        DocumentSegment.query.filter_by(document_id=dataset_document_id).update(update_params)
        db.session.commit()

    def _transform(
        self,
        index_processor: BaseIndexProcessor,
        dataset: Dataset,
        text_docs: list[Document],
        doc_language: str,
        process_rule: dict,
    ) -> list[Document]:
        # get embedding model instance
        embedding_model_instance = None
        if dataset.indexing_technique == "high_quality":
            if dataset.embedding_model_provider:
                embedding_model_instance = self.model_manager.get_model_instance(
                    tenant_id=dataset.tenant_id,
                    provider=dataset.embedding_model_provider,
                    model_type=ModelType.TEXT_EMBEDDING,
                    model=dataset.embedding_model,
                )
            else:
                embedding_model_instance = self.model_manager.get_default_model_instance(
                    tenant_id=dataset.tenant_id,
                    model_type=ModelType.TEXT_EMBEDDING,
                )

        documents = index_processor.transform(
            text_docs,
            embedding_model_instance=embedding_model_instance,
            process_rule=process_rule,
            tenant_id=dataset.tenant_id,
            doc_language=doc_language,
        )

        return documents

    def _load_segments(self, dataset, dataset_document, documents):
        # save node to document segment
        doc_store = DatasetDocumentStore(
            dataset=dataset, user_id=dataset_document.created_by, document_id=dataset_document.id
        )

        # add document segments
        doc_store.add_documents(docs=documents, save_child=dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX)

        # update document status to indexing
        cur_time = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
        self._update_document_index_status(
            document_id=dataset_document.id,
            after_indexing_status="indexing",
            extra_update_params={
                DatasetDocument.cleaning_completed_at: cur_time,
                DatasetDocument.splitting_completed_at: cur_time,
            },
        )

        # update segment status to indexing
        self._update_segments_by_document(
            dataset_document_id=dataset_document.id,
            update_params={
                DocumentSegment.status: "indexing",
                DocumentSegment.indexing_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
            },
        )
        pass


class DocumentIsPausedError(Exception):
    pass


class DocumentIsDeletedPausedError(Exception):
    pass
api/core/model_manager.py
ADDED
@@ -0,0 +1,559 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
from collections.abc import Callable, Generator, Iterable, Sequence
|
3 |
+
from typing import IO, Any, Optional, Union, cast
|
4 |
+
|
5 |
+
from configs import dify_config
|
6 |
+
from core.entities.embedding_type import EmbeddingInputType
|
7 |
+
from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle
|
8 |
+
from core.entities.provider_entities import ModelLoadBalancingConfiguration
|
9 |
+
from core.errors.error import ProviderTokenNotInitError
|
10 |
+
from core.model_runtime.callbacks.base_callback import Callback
|
11 |
+
from core.model_runtime.entities.llm_entities import LLMResult
|
12 |
+
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
|
13 |
+
from core.model_runtime.entities.model_entities import ModelType
|
14 |
+
from core.model_runtime.entities.rerank_entities import RerankResult
|
15 |
+
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
|
16 |
+
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeConnectionError, InvokeRateLimitError
|
17 |
+
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
18 |
+
from core.model_runtime.model_providers.__base.moderation_model import ModerationModel
|
19 |
+
from core.model_runtime.model_providers.__base.rerank_model import RerankModel
from core.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from core.model_runtime.model_providers.__base.tts_model import TTSModel
from core.provider_manager import ProviderManager
from extensions.ext_redis import redis_client
from models.provider import ProviderType

logger = logging.getLogger(__name__)


class ModelInstance:
    """
    Model instance class
    """

    def __init__(self, provider_model_bundle: ProviderModelBundle, model: str) -> None:
        self.provider_model_bundle = provider_model_bundle
        self.model = model
        self.provider = provider_model_bundle.configuration.provider.provider
        self.credentials = self._fetch_credentials_from_bundle(provider_model_bundle, model)
        self.model_type_instance = self.provider_model_bundle.model_type_instance
        self.load_balancing_manager = self._get_load_balancing_manager(
            configuration=provider_model_bundle.configuration,
            model_type=provider_model_bundle.model_type_instance.model_type,
            model=model,
            credentials=self.credentials,
        )

    @staticmethod
    def _fetch_credentials_from_bundle(provider_model_bundle: ProviderModelBundle, model: str) -> dict:
        """
        Fetch credentials from provider model bundle
        :param provider_model_bundle: provider model bundle
        :param model: model name
        :return:
        """
        configuration = provider_model_bundle.configuration
        model_type = provider_model_bundle.model_type_instance.model_type
        credentials = configuration.get_current_credentials(model_type=model_type, model=model)

        if credentials is None:
            raise ProviderTokenNotInitError(f"Model {model} credentials are not initialized.")

        return credentials

    @staticmethod
    def _get_load_balancing_manager(
        configuration: ProviderConfiguration, model_type: ModelType, model: str, credentials: dict
    ) -> Optional["LBModelManager"]:
        """
        Get the load balancing manager for the model, if load balancing is configured
        :param configuration: provider configuration
        :param model_type: model type
        :param model: model name
        :param credentials: model credentials
        :return:
        """
        if configuration.model_settings and configuration.using_provider_type == ProviderType.CUSTOM:
            current_model_setting = None
            # check if model is disabled by admin
            for model_setting in configuration.model_settings:
                if model_setting.model_type == model_type and model_setting.model == model:
                    current_model_setting = model_setting
                    break

            # check if load balancing is enabled
            if current_model_setting and current_model_setting.load_balancing_configs:
                # use load balancing proxy to choose credentials
                lb_model_manager = LBModelManager(
                    tenant_id=configuration.tenant_id,
                    provider=configuration.provider.provider,
                    model_type=model_type,
                    model=model,
                    load_balancing_configs=current_model_setting.load_balancing_configs,
                    managed_credentials=credentials if configuration.custom_configuration.provider else None,
                )

                return lb_model_manager

        return None

    def invoke_llm(
        self,
        prompt_messages: Sequence[PromptMessage],
        model_parameters: Optional[dict] = None,
        tools: Sequence[PromptMessageTool] | None = None,
        stop: Optional[Sequence[str]] = None,
        stream: bool = True,
        user: Optional[str] = None,
        callbacks: Optional[list[Callback]] = None,
    ) -> Union[LLMResult, Generator]:
        """
        Invoke large language model

        :param prompt_messages: prompt messages
        :param model_parameters: model parameters
        :param tools: tools for tool calling
        :param stop: stop words
        :param stream: is stream response
        :param user: unique user id
        :param callbacks: callbacks
        :return: full response or stream response chunk generator result
        """
        if not isinstance(self.model_type_instance, LargeLanguageModel):
            raise Exception("Model type instance is not LargeLanguageModel")

        self.model_type_instance = cast(LargeLanguageModel, self.model_type_instance)
        return cast(
            Union[LLMResult, Generator],
            self._round_robin_invoke(
                function=self.model_type_instance.invoke,
                model=self.model,
                credentials=self.credentials,
                prompt_messages=prompt_messages,
                model_parameters=model_parameters,
                tools=tools,
                stop=stop,
                stream=stream,
                user=user,
                callbacks=callbacks,
            ),
        )
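For orientation, here is a minimal usage sketch of the invoke path above. It assumes a workspace with an OpenAI provider already configured; the tenant id and model name are placeholders, and the UserPromptMessage import path is taken from the surrounding model runtime package rather than from this diff.

# Illustrative sketch only: tenant id and model name are placeholders.
from core.model_manager import ModelManager
from core.model_runtime.entities.message_entities import UserPromptMessage
from core.model_runtime.entities.model_entities import ModelType

model_instance = ModelManager().get_model_instance(
    tenant_id="tenant-123",
    provider="openai",
    model_type=ModelType.LLM,
    model="gpt-4o-mini",
)
result = model_instance.invoke_llm(
    prompt_messages=[UserPromptMessage(content="Hello")],
    stream=False,  # a single LLMResult instead of a chunk generator
)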

    def get_llm_num_tokens(
        self, prompt_messages: list[PromptMessage], tools: Optional[list[PromptMessageTool]] = None
    ) -> int:
        """
        Get number of tokens for llm

        :param prompt_messages: prompt messages
        :param tools: tools for tool calling
        :return:
        """
        if not isinstance(self.model_type_instance, LargeLanguageModel):
            raise Exception("Model type instance is not LargeLanguageModel")

        self.model_type_instance = cast(LargeLanguageModel, self.model_type_instance)
        return cast(
            int,
            self._round_robin_invoke(
                function=self.model_type_instance.get_num_tokens,
                model=self.model,
                credentials=self.credentials,
                prompt_messages=prompt_messages,
                tools=tools,
            ),
        )

    def invoke_text_embedding(
        self, texts: list[str], user: Optional[str] = None, input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT
    ) -> TextEmbeddingResult:
        """
        Invoke text embedding model

        :param texts: texts to embed
        :param user: unique user id
        :param input_type: input type
        :return: embeddings result
        """
        if not isinstance(self.model_type_instance, TextEmbeddingModel):
            raise Exception("Model type instance is not TextEmbeddingModel")

        self.model_type_instance = cast(TextEmbeddingModel, self.model_type_instance)
        return cast(
            TextEmbeddingResult,
            self._round_robin_invoke(
                function=self.model_type_instance.invoke,
                model=self.model,
                credentials=self.credentials,
                texts=texts,
                user=user,
                input_type=input_type,
            ),
        )

    def get_text_embedding_num_tokens(self, texts: list[str]) -> int:
        """
        Get number of tokens for text embedding

        :param texts: texts to embed
        :return:
        """
        if not isinstance(self.model_type_instance, TextEmbeddingModel):
            raise Exception("Model type instance is not TextEmbeddingModel")

        self.model_type_instance = cast(TextEmbeddingModel, self.model_type_instance)
        return cast(
            int,
            self._round_robin_invoke(
                function=self.model_type_instance.get_num_tokens,
                model=self.model,
                credentials=self.credentials,
                texts=texts,
            ),
        )

    def invoke_rerank(
        self,
        query: str,
        docs: list[str],
        score_threshold: Optional[float] = None,
        top_n: Optional[int] = None,
        user: Optional[str] = None,
    ) -> RerankResult:
        """
        Invoke rerank model

        :param query: search query
        :param docs: docs for reranking
        :param score_threshold: score threshold
        :param top_n: top n
        :param user: unique user id
        :return: rerank result
        """
        if not isinstance(self.model_type_instance, RerankModel):
            raise Exception("Model type instance is not RerankModel")

        self.model_type_instance = cast(RerankModel, self.model_type_instance)
        return cast(
            RerankResult,
            self._round_robin_invoke(
                function=self.model_type_instance.invoke,
                model=self.model,
                credentials=self.credentials,
                query=query,
                docs=docs,
                score_threshold=score_threshold,
                top_n=top_n,
                user=user,
            ),
        )

    def invoke_moderation(self, text: str, user: Optional[str] = None) -> bool:
        """
        Invoke moderation model

        :param text: text to moderate
        :param user: unique user id
        :return: false if text is safe, true otherwise
        """
        if not isinstance(self.model_type_instance, ModerationModel):
            raise Exception("Model type instance is not ModerationModel")

        self.model_type_instance = cast(ModerationModel, self.model_type_instance)
        return cast(
            bool,
            self._round_robin_invoke(
                function=self.model_type_instance.invoke,
                model=self.model,
                credentials=self.credentials,
                text=text,
                user=user,
            ),
        )

    def invoke_speech2text(self, file: IO[bytes], user: Optional[str] = None) -> str:
        """
        Invoke speech-to-text model

        :param file: audio file
        :param user: unique user id
        :return: text for given audio file
        """
        if not isinstance(self.model_type_instance, Speech2TextModel):
            raise Exception("Model type instance is not Speech2TextModel")

        self.model_type_instance = cast(Speech2TextModel, self.model_type_instance)
        return cast(
            str,
            self._round_robin_invoke(
                function=self.model_type_instance.invoke,
                model=self.model,
                credentials=self.credentials,
                file=file,
                user=user,
            ),
        )

    def invoke_tts(self, content_text: str, tenant_id: str, voice: str, user: Optional[str] = None) -> Iterable[bytes]:
        """
        Invoke text-to-speech model

        :param content_text: text content to be converted to speech
        :param tenant_id: user tenant id
        :param voice: model timbre
        :param user: unique user id
        :return: audio chunks for the given text
        """
        if not isinstance(self.model_type_instance, TTSModel):
            raise Exception("Model type instance is not TTSModel")

        self.model_type_instance = cast(TTSModel, self.model_type_instance)
        return cast(
            Iterable[bytes],
            self._round_robin_invoke(
                function=self.model_type_instance.invoke,
                model=self.model,
                credentials=self.credentials,
                content_text=content_text,
                user=user,
                tenant_id=tenant_id,
                voice=voice,
            ),
        )

    def _round_robin_invoke(self, function: Callable[..., Any], *args, **kwargs) -> Any:
        """
        Round-robin invoke
        :param function: function to invoke
        :param args: function args
        :param kwargs: function kwargs
        :return:
        """
        if not self.load_balancing_manager:
            return function(*args, **kwargs)

        last_exception: Union[InvokeRateLimitError, InvokeAuthorizationError, InvokeConnectionError, None] = None
        while True:
            lb_config = self.load_balancing_manager.fetch_next()
            if not lb_config:
                if not last_exception:
                    raise ProviderTokenNotInitError("Model credentials are not initialized.")
                else:
                    raise last_exception

            try:
                if "credentials" in kwargs:
                    del kwargs["credentials"]
                return function(*args, **kwargs, credentials=lb_config.credentials)
            except InvokeRateLimitError as e:
                # expire in 60 seconds
                self.load_balancing_manager.cooldown(lb_config, expire=60)
                last_exception = e
                continue
            except (InvokeAuthorizationError, InvokeConnectionError) as e:
                # expire in 10 seconds
                self.load_balancing_manager.cooldown(lb_config, expire=10)
                last_exception = e
                continue
            except Exception as e:
                raise e

    def get_tts_voices(self, language: Optional[str] = None) -> list:
        """
        Get available voices for the TTS model

        :param language: tts language
        :return: tts model voices
        """
        if not isinstance(self.model_type_instance, TTSModel):
            raise Exception("Model type instance is not TTSModel")

        self.model_type_instance = cast(TTSModel, self.model_type_instance)
        return self.model_type_instance.get_tts_model_voices(
            model=self.model, credentials=self.credentials, language=language
        )


class ModelManager:
    def __init__(self) -> None:
        self._provider_manager = ProviderManager()

    def get_model_instance(self, tenant_id: str, provider: str, model_type: ModelType, model: str) -> ModelInstance:
        """
        Get model instance
        :param tenant_id: tenant id
        :param provider: provider name
        :param model_type: model type
        :param model: model name
        :return:
        """
        if not provider:
            return self.get_default_model_instance(tenant_id, model_type)

        provider_model_bundle = self._provider_manager.get_provider_model_bundle(
            tenant_id=tenant_id, provider=provider, model_type=model_type
        )

        return ModelInstance(provider_model_bundle, model)

    def get_default_provider_model_name(self, tenant_id: str, model_type: ModelType) -> tuple[str, str]:
        """
        Return the first provider and the first model in that provider
        :param tenant_id: tenant id
        :param model_type: model type
        :return: provider name, model name
        """
        return self._provider_manager.get_first_provider_first_model(tenant_id, model_type)

    def get_default_model_instance(self, tenant_id: str, model_type: ModelType) -> ModelInstance:
        """
        Get default model instance
        :param tenant_id: tenant id
        :param model_type: model type
        :return:
        """
        default_model_entity = self._provider_manager.get_default_model(tenant_id=tenant_id, model_type=model_type)

        if not default_model_entity:
            raise ProviderTokenNotInitError(f"Default model not found for {model_type}")

        return self.get_model_instance(
            tenant_id=tenant_id,
            provider=default_model_entity.provider.provider,
            model_type=model_type,
            model=default_model_entity.model,
        )


class LBModelManager:
    def __init__(
        self,
        tenant_id: str,
        provider: str,
        model_type: ModelType,
        model: str,
        load_balancing_configs: list[ModelLoadBalancingConfiguration],
        managed_credentials: Optional[dict] = None,
    ) -> None:
        """
        Load balancing model manager
        :param tenant_id: tenant_id
        :param provider: provider
        :param model_type: model_type
        :param model: model name
        :param load_balancing_configs: all load balancing configurations
        :param managed_credentials: credentials if load balancing configuration name is __inherit__
        """
        self._tenant_id = tenant_id
        self._provider = provider
        self._model_type = model_type
        self._model = model
        self._load_balancing_configs = load_balancing_configs

        for load_balancing_config in self._load_balancing_configs[:]:  # iterate over a shallow copy of the list
            if load_balancing_config.name == "__inherit__":
                if not managed_credentials:
                    # remove __inherit__ if managed credentials are not provided
                    self._load_balancing_configs.remove(load_balancing_config)
                else:
                    load_balancing_config.credentials = managed_credentials

    def fetch_next(self) -> Optional[ModelLoadBalancingConfiguration]:
        """
        Get next model load balancing config
        Strategy: Round Robin
        :return:
        """
        cache_key = "model_lb_index:{}:{}:{}:{}".format(
            self._tenant_id, self._provider, self._model_type.value, self._model
        )

        cooldown_load_balancing_configs = []
        max_index = len(self._load_balancing_configs)

        while True:
            current_index = redis_client.incr(cache_key)
            current_index = cast(int, current_index)
            if current_index >= 10000000:
                current_index = 1
                redis_client.set(cache_key, current_index)

            redis_client.expire(cache_key, 3600)
            if current_index > max_index:
                current_index = current_index % max_index

            real_index = current_index - 1
            if real_index > max_index:
                real_index = 0

            config: ModelLoadBalancingConfiguration = self._load_balancing_configs[real_index]

            if self.in_cooldown(config):
                cooldown_load_balancing_configs.append(config)
                if len(cooldown_load_balancing_configs) >= len(self._load_balancing_configs):
                    # all configs are in cooldown
                    return None

                continue

            if dify_config.DEBUG:
                logger.info(
                    f"Model LB\nid: {config.id}\nname:{config.name}\n"
                    f"tenant_id: {self._tenant_id}\nprovider: {self._provider}\n"
                    f"model_type: {self._model_type.value}\nmodel: {self._model}"
                )

            return config

        return None

    def cooldown(self, config: ModelLoadBalancingConfiguration, expire: int = 60) -> None:
        """
        Cooldown model load balancing config
        :param config: model load balancing config
        :param expire: cooldown time
        :return:
        """
        cooldown_cache_key = "model_lb_index:cooldown:{}:{}:{}:{}:{}".format(
            self._tenant_id, self._provider, self._model_type.value, self._model, config.id
        )

        redis_client.setex(cooldown_cache_key, expire, "true")

    def in_cooldown(self, config: ModelLoadBalancingConfiguration) -> bool:
        """
        Check if model load balancing config is in cooldown
        :param config: model load balancing config
        :return:
        """
        cooldown_cache_key = "model_lb_index:cooldown:{}:{}:{}:{}:{}".format(
            self._tenant_id, self._provider, self._model_type.value, self._model, config.id
        )

        return bool(redis_client.exists(cooldown_cache_key))

    @staticmethod
    def get_config_in_cooldown_and_ttl(
        tenant_id: str, provider: str, model_type: ModelType, model: str, config_id: str
    ) -> tuple[bool, int]:
        """
        Check whether a model load balancing config is in cooldown, and get its remaining TTL
        :param tenant_id: workspace id
        :param provider: provider name
        :param model_type: model type
        :param model: model name
        :param config_id: model load balancing config id
        :return:
        """
        cooldown_cache_key = "model_lb_index:cooldown:{}:{}:{}:{}:{}".format(
            tenant_id, provider, model_type.value, model, config_id
        )

        ttl = redis_client.ttl(cooldown_cache_key)
        if ttl == -2:
            return False, 0

        ttl = cast(int, ttl)
        return True, ttl
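To make the failover behavior of _round_robin_invoke and LBModelManager concrete, here is a compressed, self-contained sketch of the same round-robin-with-cooldown scheme. The real implementation shares the index (Redis INCR) and cooldown flags (Redis SETEX) across workers; plain in-process state stands in for them here.

# Simplified, self-contained illustration of the round-robin + cooldown scheme.
import itertools
import time

configs = ["config-a", "config-b", "config-c"]  # stand-ins for ModelLoadBalancingConfiguration
cooldown_until: dict[str, float] = {}
counter = itertools.count(1)  # Redis INCR equivalent

def cooldown(config: str, expire: int = 60) -> None:
    cooldown_until[config] = time.time() + expire  # SETEX equivalent

def fetch_next() -> str | None:
    cooled: set[str] = set()
    while True:
        index = (next(counter) - 1) % len(configs)
        config = configs[index]
        if cooldown_until.get(config, 0) > time.time():
            cooled.add(config)
            if len(cooled) >= len(configs):
                return None  # every config is cooling down
            continue
        return config

cooldown("config-a", expire=10)
assert fetch_next() == "config-b"  # config-a is skipped while cooling down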
api/core/moderation/__init__.py
ADDED
File without changes

api/core/moderation/api/__builtin__
ADDED
@@ -0,0 +1 @@
3

api/core/moderation/api/__init__.py
ADDED
File without changes

api/core/moderation/api/api.py
ADDED
@@ -0,0 +1,96 @@
from typing import Optional

from pydantic import BaseModel

from core.extension.api_based_extension_requestor import APIBasedExtensionPoint, APIBasedExtensionRequestor
from core.helper.encrypter import decrypt_token
from core.moderation.base import Moderation, ModerationAction, ModerationInputsResult, ModerationOutputsResult
from extensions.ext_database import db
from models.api_based_extension import APIBasedExtension


class ModerationInputParams(BaseModel):
    app_id: str = ""
    inputs: dict = {}
    query: str = ""


class ModerationOutputParams(BaseModel):
    app_id: str = ""
    text: str


class ApiModeration(Moderation):
    name: str = "api"

    @classmethod
    def validate_config(cls, tenant_id: str, config: dict) -> None:
        """
        Validate the incoming form config data.

        :param tenant_id: the id of workspace
        :param config: the form config data
        :return:
        """
        cls._validate_inputs_and_outputs_config(config, False)

        api_based_extension_id = config.get("api_based_extension_id")
        if not api_based_extension_id:
            raise ValueError("api_based_extension_id is required")

        extension = cls._get_api_based_extension(tenant_id, api_based_extension_id)
        if not extension:
            raise ValueError("API-based Extension not found. Please check it again.")

    def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult:
        flagged = False
        preset_response = ""
        if self.config is None:
            raise ValueError("The config is not set.")

        if self.config["inputs_config"]["enabled"]:
            params = ModerationInputParams(app_id=self.app_id, inputs=inputs, query=query)

            result = self._get_config_by_requestor(APIBasedExtensionPoint.APP_MODERATION_INPUT, params.model_dump())
            return ModerationInputsResult(**result)

        return ModerationInputsResult(
            flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response
        )

    def moderation_for_outputs(self, text: str) -> ModerationOutputsResult:
        flagged = False
        preset_response = ""
        if self.config is None:
            raise ValueError("The config is not set.")

        if self.config["outputs_config"]["enabled"]:
            params = ModerationOutputParams(app_id=self.app_id, text=text)

            result = self._get_config_by_requestor(APIBasedExtensionPoint.APP_MODERATION_OUTPUT, params.model_dump())
            return ModerationOutputsResult(**result)

        return ModerationOutputsResult(
            flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response
        )

    def _get_config_by_requestor(self, extension_point: APIBasedExtensionPoint, params: dict) -> dict:
        if self.config is None:
            raise ValueError("The config is not set.")
        extension = self._get_api_based_extension(self.tenant_id, self.config.get("api_based_extension_id", ""))
        if not extension:
            raise ValueError("API-based Extension not found. Please check it again.")
        requestor = APIBasedExtensionRequestor(extension.api_endpoint, decrypt_token(self.tenant_id, extension.api_key))

        result = requestor.request(extension_point, params)
        return result

    @staticmethod
    def _get_api_based_extension(tenant_id: str, api_based_extension_id: str) -> Optional[APIBasedExtension]:
        extension = (
            db.session.query(APIBasedExtension)
            .filter(APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id)
            .first()
        )

        return extension
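The external service referenced by api_based_extension_id must answer with a JSON object whose keys mirror ModerationInputsResult or ModerationOutputsResult, since the result dict is splatted straight into those models. A hedged sketch of a conforming response for the input point, with illustrative field values:

# Shape inferred from the ModerationInputsResult(**result) call above.
response = {
    "flagged": True,
    "action": "direct_output",  # or "overridden"
    "preset_response": "Sorry, I can't help with that.",
    "inputs": {},               # only meaningful for the "overridden" action
    "query": "",
}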
api/core/moderation/base.py
ADDED
@@ -0,0 +1,115 @@
from abc import ABC, abstractmethod
from enum import Enum
from typing import Optional

from pydantic import BaseModel

from core.extension.extensible import Extensible, ExtensionModule


class ModerationAction(Enum):
    DIRECT_OUTPUT = "direct_output"
    OVERRIDDEN = "overridden"


class ModerationInputsResult(BaseModel):
    flagged: bool = False
    action: ModerationAction
    preset_response: str = ""
    inputs: dict = {}
    query: str = ""


class ModerationOutputsResult(BaseModel):
    flagged: bool = False
    action: ModerationAction
    preset_response: str = ""
    text: str = ""


class Moderation(Extensible, ABC):
    """
    The base class of moderation.
    """

    module: ExtensionModule = ExtensionModule.MODERATION

    def __init__(self, app_id: str, tenant_id: str, config: Optional[dict] = None) -> None:
        super().__init__(tenant_id, config)
        self.app_id = app_id

    @classmethod
    @abstractmethod
    def validate_config(cls, tenant_id: str, config: dict) -> None:
        """
        Validate the incoming form config data.

        :param tenant_id: the id of workspace
        :param config: the form config data
        :return:
        """
        raise NotImplementedError

    @abstractmethod
    def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult:
        """
        Moderation for inputs.
        After the user submits their inputs, this method is called to perform sensitive content review
        on those inputs and return the processed results.

        :param inputs: user inputs
        :param query: query string (required in chat app)
        :return:
        """
        raise NotImplementedError

    @abstractmethod
    def moderation_for_outputs(self, text: str) -> ModerationOutputsResult:
        """
        Moderation for outputs.
        When the LLM outputs content, the front end passes the output content (possibly segmented)
        to this method for sensitive content review; the output is shielded if the review fails.

        :param text: LLM output content
        :return:
        """
        raise NotImplementedError

    @classmethod
    def _validate_inputs_and_outputs_config(cls, config: dict, is_preset_response_required: bool) -> None:
        # inputs_config
        inputs_config = config.get("inputs_config")
        if not isinstance(inputs_config, dict):
            raise ValueError("inputs_config must be a dict")

        # outputs_config
        outputs_config = config.get("outputs_config")
        if not isinstance(outputs_config, dict):
            raise ValueError("outputs_config must be a dict")

        inputs_config_enabled = inputs_config.get("enabled")
        outputs_config_enabled = outputs_config.get("enabled")
        if not inputs_config_enabled and not outputs_config_enabled:
            raise ValueError("At least one of inputs_config or outputs_config must be enabled")

        # preset_response
        if not is_preset_response_required:
            return

        if inputs_config_enabled:
            if not inputs_config.get("preset_response"):
                raise ValueError("inputs_config.preset_response is required")

            if len(inputs_config.get("preset_response", "")) > 100:
                raise ValueError("inputs_config.preset_response must not exceed 100 characters")

        if outputs_config_enabled:
            if not outputs_config.get("preset_response"):
                raise ValueError("outputs_config.preset_response is required")

            if len(outputs_config.get("preset_response", "")) > 100:
                raise ValueError("outputs_config.preset_response must not exceed 100 characters")


class ModerationError(Exception):
    pass
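As a reference point for implementers, here is a hypothetical code-based moderation extension built on the base class above. The extension name and trigger word are illustrative only, and the registration steps (package layout, __builtin__ marker) are omitted:

# Hypothetical custom extension; "profanity" and the word "darn" are placeholders.
from core.moderation.base import Moderation, ModerationAction, ModerationInputsResult, ModerationOutputsResult

class ProfanityModeration(Moderation):
    name: str = "profanity"

    @classmethod
    def validate_config(cls, tenant_id: str, config: dict) -> None:
        cls._validate_inputs_and_outputs_config(config, True)

    def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult:
        values = {**inputs, "query__": query}.values()
        flagged = any("darn" in str(v).lower() for v in values)
        preset_response = self.config["inputs_config"]["preset_response"] if self.config else ""
        return ModerationInputsResult(
            flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response
        )

    def moderation_for_outputs(self, text: str) -> ModerationOutputsResult:
        preset_response = self.config["outputs_config"]["preset_response"] if self.config else ""
        return ModerationOutputsResult(
            flagged="darn" in text.lower(), action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response
        )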
api/core/moderation/factory.py
ADDED
@@ -0,0 +1,49 @@
from core.extension.extensible import ExtensionModule
from core.moderation.base import Moderation, ModerationInputsResult, ModerationOutputsResult
from extensions.ext_code_based_extension import code_based_extension


class ModerationFactory:
    __extension_instance: Moderation

    def __init__(self, name: str, app_id: str, tenant_id: str, config: dict) -> None:
        extension_class = code_based_extension.extension_class(ExtensionModule.MODERATION, name)
        self.__extension_instance = extension_class(app_id, tenant_id, config)

    @classmethod
    def validate_config(cls, name: str, tenant_id: str, config: dict) -> None:
        """
        Validate the incoming form config data.

        :param name: the name of extension
        :param tenant_id: the id of workspace
        :param config: the form config data
        :return:
        """
        code_based_extension.validate_form_schema(ExtensionModule.MODERATION, name, config)
        extension_class = code_based_extension.extension_class(ExtensionModule.MODERATION, name)
        # FIXME: mypy error, try to fix it instead of using type: ignore
        extension_class.validate_config(tenant_id, config)  # type: ignore

    def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult:
        """
        Moderation for inputs.
        After the user submits their inputs, this method is called to perform sensitive content review
        on those inputs and return the processed results.

        :param inputs: user inputs
        :param query: query string (required in chat app)
        :return:
        """
        return self.__extension_instance.moderation_for_inputs(inputs, query)

    def moderation_for_outputs(self, text: str) -> ModerationOutputsResult:
        """
        Moderation for outputs.
        When the LLM outputs content, the front end passes the output content (possibly segmented)
        to this method for sensitive content review; the output is shielded if the review fails.

        :param text: LLM output content
        :return:
        """
        return self.__extension_instance.moderation_for_outputs(text)
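A hedged usage sketch of the factory with the built-in keywords extension; the config shape matches keywords.py below, while the app and tenant ids are placeholders:

# Hedged sketch: the factory resolves "keywords" from the code-based extension registry.
factory = ModerationFactory(
    name="keywords",
    app_id="app-123",
    tenant_id="tenant-123",
    config={
        "keywords": "badword\nanotherbadword",
        "inputs_config": {"enabled": True, "preset_response": "Sorry, I can't help with that."},
        "outputs_config": {"enabled": False},
    },
)
result = factory.moderation_for_inputs({"name": "value"}, query="this contains badword")
assert result.flagged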
api/core/moderation/input_moderation.py
ADDED
@@ -0,0 +1,71 @@
import logging
from collections.abc import Mapping
from typing import Any, Optional

from core.app.app_config.entities import AppConfig
from core.moderation.base import ModerationAction, ModerationError
from core.moderation.factory import ModerationFactory
from core.ops.entities.trace_entity import TraceTaskName
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
from core.ops.utils import measure_time

logger = logging.getLogger(__name__)


class InputModeration:
    def check(
        self,
        app_id: str,
        tenant_id: str,
        app_config: AppConfig,
        inputs: Mapping[str, Any],
        query: str,
        message_id: str,
        trace_manager: Optional[TraceQueueManager] = None,
    ) -> tuple[bool, Mapping[str, Any], str]:
        """
        Process sensitive_word_avoidance.
        :param app_id: app id
        :param tenant_id: tenant id
        :param app_config: app config
        :param inputs: inputs
        :param query: query
        :param message_id: message id
        :param trace_manager: trace manager
        :return:
        """
        inputs = dict(inputs)
        if not app_config.sensitive_word_avoidance:
            return False, inputs, query

        sensitive_word_avoidance_config = app_config.sensitive_word_avoidance
        moderation_type = sensitive_word_avoidance_config.type

        moderation_factory = ModerationFactory(
            name=moderation_type, app_id=app_id, tenant_id=tenant_id, config=sensitive_word_avoidance_config.config
        )

        with measure_time() as timer:
            moderation_result = moderation_factory.moderation_for_inputs(inputs, query)

        if trace_manager:
            trace_manager.add_trace_task(
                TraceTask(
                    TraceTaskName.MODERATION_TRACE,
                    message_id=message_id,
                    moderation_result=moderation_result,
                    inputs=inputs,
                    timer=timer,
                )
            )

        if not moderation_result.flagged:
            return False, inputs, query

        if moderation_result.action == ModerationAction.DIRECT_OUTPUT:
            raise ModerationError(moderation_result.preset_response)
        elif moderation_result.action == ModerationAction.OVERRIDDEN:
            inputs = moderation_result.inputs
            query = moderation_result.query

        return True, inputs, query
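The return contract of check, a flagged flag plus possibly rewritten inputs and query, is easiest to see from the calling side. In this hedged sketch the app objects are placeholders supplied by the running application:

# DIRECT_OUTPUT surfaces as a ModerationError whose message is the preset
# response, while OVERRIDDEN silently rewrites the inputs and query.
try:
    flagged, inputs, query = InputModeration().check(
        app_id="app-123",        # placeholder
        tenant_id="tenant-123",  # placeholder
        app_config=app_config,   # an AppConfig with sensitive_word_avoidance set
        inputs={"topic": "weather"},
        query="What's the forecast?",
        message_id="msg-1",
    )
except ModerationError as e:
    answer = str(e)  # reply with the preset response directly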
api/core/moderation/keywords/__builtin__
ADDED
@@ -0,0 +1 @@
2

api/core/moderation/keywords/__init__.py
ADDED
File without changes

api/core/moderation/keywords/keywords.py
ADDED
@@ -0,0 +1,73 @@
from collections.abc import Sequence
from typing import Any

from core.moderation.base import Moderation, ModerationAction, ModerationInputsResult, ModerationOutputsResult


class KeywordsModeration(Moderation):
    name: str = "keywords"

    @classmethod
    def validate_config(cls, tenant_id: str, config: dict) -> None:
        """
        Validate the incoming form config data.

        :param tenant_id: the id of workspace
        :param config: the form config data
        :return:
        """
        cls._validate_inputs_and_outputs_config(config, True)

        if not config.get("keywords"):
            raise ValueError("keywords is required")

        if len(config.get("keywords", "")) > 10000:
            raise ValueError("keywords must not exceed 10000 characters")

        keywords_rows = config["keywords"].split("\n")
        if len(keywords_rows) > 100:
            raise ValueError("the number of rows for the keywords must not exceed 100")

    def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult:
        flagged = False
        preset_response = ""
        if self.config is None:
            raise ValueError("The config is not set.")

        if self.config["inputs_config"]["enabled"]:
            preset_response = self.config["inputs_config"]["preset_response"]

            if query:
                inputs["query__"] = query

            # Filter out empty values
            keywords_list = [keyword for keyword in self.config["keywords"].split("\n") if keyword]

            flagged = self._is_violated(inputs, keywords_list)

        return ModerationInputsResult(
            flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response
        )

    def moderation_for_outputs(self, text: str) -> ModerationOutputsResult:
        flagged = False
        preset_response = ""
        if self.config is None:
            raise ValueError("The config is not set.")

        if self.config["outputs_config"]["enabled"]:
            # Filter out empty values
            keywords_list = [keyword for keyword in self.config["keywords"].split("\n") if keyword]

            flagged = self._is_violated({"text": text}, keywords_list)
            preset_response = self.config["outputs_config"]["preset_response"]

        return ModerationOutputsResult(
            flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response
        )

    def _is_violated(self, inputs: dict, keywords_list: list) -> bool:
        return any(self._check_keywords_in_value(keywords_list, value) for value in inputs.values())

    def _check_keywords_in_value(self, keywords_list: Sequence[str], value: Any) -> bool:
        return any(keyword.lower() in str(value).lower() for keyword in keywords_list)
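The matching rule is a plain case-insensitive substring test over every input value, as this small self-contained worked example shows (keywords and values are illustrative):

# Worked example of the matching rule above: any keyword appearing as a
# case-insensitive substring of any input value flags the whole request.
keywords_list = ["badword", "forbidden"]
inputs = {"topic": "Weather", "query__": "Is BADWORD allowed here?"}

flagged = any(
    keyword.lower() in str(value).lower()
    for value in inputs.values()
    for keyword in keywords_list
)
assert flagged  # "BADWORD" matches "badword" case-insensitively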
api/core/moderation/openai_moderation/__builtin__
ADDED
@@ -0,0 +1 @@
1

api/core/moderation/openai_moderation/__init__.py
ADDED
File without changes

api/core/moderation/openai_moderation/openai_moderation.py
ADDED
@@ -0,0 +1,60 @@
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.moderation.base import Moderation, ModerationAction, ModerationInputsResult, ModerationOutputsResult


class OpenAIModeration(Moderation):
    name: str = "openai_moderation"

    @classmethod
    def validate_config(cls, tenant_id: str, config: dict) -> None:
        """
        Validate the incoming form config data.

        :param tenant_id: the id of workspace
        :param config: the form config data
        :return:
        """
        cls._validate_inputs_and_outputs_config(config, True)

    def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult:
        flagged = False
        preset_response = ""
        if self.config is None:
            raise ValueError("The config is not set.")

        if self.config["inputs_config"]["enabled"]:
            preset_response = self.config["inputs_config"]["preset_response"]

            if query:
                inputs["query__"] = query
            flagged = self._is_violated(inputs)

        return ModerationInputsResult(
            flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response
        )

    def moderation_for_outputs(self, text: str) -> ModerationOutputsResult:
        flagged = False
        preset_response = ""
        if self.config is None:
            raise ValueError("The config is not set.")

        if self.config["outputs_config"]["enabled"]:
            flagged = self._is_violated({"text": text})
            preset_response = self.config["outputs_config"]["preset_response"]

        return ModerationOutputsResult(
            flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response
        )

    def _is_violated(self, inputs: dict):
        # join each input value on its own line for the moderation request
        text = "\n".join(str(value) for value in inputs.values())
        model_manager = ModelManager()
        model_instance = model_manager.get_model_instance(
            tenant_id=self.tenant_id, provider="openai", model_type=ModelType.MODERATION, model="text-moderation-stable"
        )

        openai_moderation = model_instance.invoke_moderation(text=text)

        return openai_moderation
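Tying this back to ModelInstance.invoke_moderation above: the moderation model returns a plain boolean (true when the text is flagged), so callers can branch on it directly. A minimal hedged sketch with a placeholder tenant id:

# Minimal sketch of a direct moderation call, mirroring _is_violated above.
model_instance = ModelManager().get_model_instance(
    tenant_id="tenant-123",  # placeholder
    provider="openai",
    model_type=ModelType.MODERATION,
    model="text-moderation-stable",
)
if model_instance.invoke_moderation(text="some user text"):
    print("flagged")  # True means the text was flagged as unsafe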
api/core/moderation/output_moderation.py
ADDED
@@ -0,0 +1,131 @@
import logging
import threading
import time
from typing import Any, Optional

from flask import Flask, current_app
from pydantic import BaseModel, ConfigDict

from configs import dify_config
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.entities.queue_entities import QueueMessageReplaceEvent
from core.moderation.base import ModerationAction, ModerationOutputsResult
from core.moderation.factory import ModerationFactory

logger = logging.getLogger(__name__)


class ModerationRule(BaseModel):
    type: str
    config: dict[str, Any]


class OutputModeration(BaseModel):
    tenant_id: str
    app_id: str

    rule: ModerationRule
    queue_manager: AppQueueManager

    thread: Optional[threading.Thread] = None
    thread_running: bool = True
    buffer: str = ""
    is_final_chunk: bool = False
    final_output: Optional[str] = None
    model_config = ConfigDict(arbitrary_types_allowed=True)

    def should_direct_output(self) -> bool:
        return self.final_output is not None

    def get_final_output(self) -> str:
        return self.final_output or ""

    def append_new_token(self, token: str) -> None:
        self.buffer += token

        if not self.thread:
            self.thread = self.start_thread()

    def moderation_completion(self, completion: str, public_event: bool = False) -> str:
        self.buffer = completion
        self.is_final_chunk = True

        result = self.moderation(tenant_id=self.tenant_id, app_id=self.app_id, moderation_buffer=completion)

        if not result or not result.flagged:
            return completion

        if result.action == ModerationAction.DIRECT_OUTPUT:
            final_output = result.preset_response
        else:
            final_output = result.text

        if public_event:
            self.queue_manager.publish(QueueMessageReplaceEvent(text=final_output), PublishFrom.TASK_PIPELINE)

        return final_output

    def start_thread(self) -> threading.Thread:
        buffer_size = dify_config.MODERATION_BUFFER_SIZE
        thread = threading.Thread(
            target=self.worker,
            kwargs={
                "flask_app": current_app._get_current_object(),  # type: ignore
                "buffer_size": buffer_size if buffer_size > 0 else dify_config.MODERATION_BUFFER_SIZE,
            },
        )

        thread.start()

        return thread

    def stop_thread(self):
        if self.thread and self.thread.is_alive():
            self.thread_running = False

    def worker(self, flask_app: Flask, buffer_size: int):
        with flask_app.app_context():
            current_length = 0
            while self.thread_running:
                moderation_buffer = self.buffer
                buffer_length = len(moderation_buffer)
                if not self.is_final_chunk:
                    chunk_length = buffer_length - current_length
                    if 0 <= chunk_length < buffer_size:
                        # not enough new tokens accumulated yet; check again shortly
                        time.sleep(1)
                        continue

                current_length = buffer_length

                result = self.moderation(
                    tenant_id=self.tenant_id, app_id=self.app_id, moderation_buffer=moderation_buffer
                )

                if not result or not result.flagged:
                    continue

                if result.action == ModerationAction.DIRECT_OUTPUT:
                    final_output = result.preset_response
                    self.final_output = final_output
                else:
                    final_output = result.text + self.buffer[len(moderation_buffer) :]

                # trigger replace event
                if self.thread_running:
                    self.queue_manager.publish(QueueMessageReplaceEvent(text=final_output), PublishFrom.TASK_PIPELINE)

                if result.action == ModerationAction.DIRECT_OUTPUT:
                    break

    def moderation(self, tenant_id: str, app_id: str, moderation_buffer: str) -> Optional[ModerationOutputsResult]:
        try:
            moderation_factory = ModerationFactory(
                name=self.rule.type, app_id=app_id, tenant_id=tenant_id, config=self.rule.config
            )

            result: ModerationOutputsResult = moderation_factory.moderation_for_outputs(moderation_buffer)
            return result
        except Exception:
            logger.exception(f"Moderation Output error, app_id: {app_id}")

            return None
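The intended call pattern, feeding tokens as they stream and settling the full completion at the end, looks roughly like the following hedged sketch; llm_stream and output_moderation stand in for a real chunk generator and a fully wired instance:

# append_new_token spawns the background worker on the first token;
# moderation_completion does the final synchronous pass.
collected = []
for token in llm_stream:  # hypothetical token generator
    collected.append(token)
    output_moderation.append_new_token(token)
    if output_moderation.should_direct_output():
        answer = output_moderation.get_final_output()
        break
else:
    answer = output_moderation.moderation_completion("".join(collected))
output_moderation.stop_thread()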
api/core/ops/__init__.py
ADDED
File without changes

api/core/ops/base_trace_instance.py
ADDED
@@ -0,0 +1,26 @@
from abc import ABC, abstractmethod

from core.ops.entities.config_entity import BaseTracingConfig
from core.ops.entities.trace_entity import BaseTraceInfo


class BaseTraceInstance(ABC):
    """
    Base trace instance for ops trace services
    """

    @abstractmethod
    def __init__(self, trace_config: BaseTracingConfig):
        """
        Abstract initializer for the trace instance.
        Distribute trace tasks by matching entities
        """
        self.trace_config = trace_config

    @abstractmethod
    def trace(self, trace_info: BaseTraceInfo):
        """
        Abstract method to trace activities.
        Subclasses must implement specific tracing logic for activities.
        """
        ...
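A minimal hedged sketch of a concrete trace instance satisfying this interface: a log-only backend with no external tracing service. The class name is illustrative:

# Hypothetical log-only trace backend built on the ABC above.
import logging

class LoggingTraceInstance(BaseTraceInstance):
    def __init__(self, trace_config: BaseTracingConfig):
        super().__init__(trace_config)
        self._logger = logging.getLogger("ops.logging_trace")

    def trace(self, trace_info: BaseTraceInfo):
        self._logger.info("trace message_id=%s metadata=%s", trace_info.message_id, trace_info.metadata)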
api/core/ops/entities/__init__.py
ADDED
File without changes

api/core/ops/entities/config_entity.py
ADDED
@@ -0,0 +1,92 @@
from enum import Enum

from pydantic import BaseModel, ValidationInfo, field_validator


class TracingProviderEnum(Enum):
    LANGFUSE = "langfuse"
    LANGSMITH = "langsmith"
    OPIK = "opik"


class BaseTracingConfig(BaseModel):
    """
    Base model class for tracing
    """

    ...


class LangfuseConfig(BaseTracingConfig):
    """
    Model class for Langfuse tracing config.
    """

    public_key: str
    secret_key: str
    host: str = "https://api.langfuse.com"

    @field_validator("host")
    @classmethod
    def set_value(cls, v, info: ValidationInfo):
        if v is None or v == "":
            v = "https://api.langfuse.com"
        if not v.startswith("https://") and not v.startswith("http://"):
            raise ValueError("host must start with https:// or http://")

        return v


class LangSmithConfig(BaseTracingConfig):
    """
    Model class for Langsmith tracing config.
    """

    api_key: str
    project: str
    endpoint: str = "https://api.smith.langchain.com"

    @field_validator("endpoint")
    @classmethod
    def set_value(cls, v, info: ValidationInfo):
        if v is None or v == "":
            v = "https://api.smith.langchain.com"
        if not v.startswith("https://"):
            raise ValueError("endpoint must start with https://")

        return v


class OpikConfig(BaseTracingConfig):
    """
    Model class for Opik tracing config.
    """

    api_key: str | None = None
    project: str | None = None
    workspace: str | None = None
    url: str = "https://www.comet.com/opik/api/"

    @field_validator("project")
    @classmethod
    def project_validator(cls, v, info: ValidationInfo):
        if v is None or v == "":
            v = "Default Project"

        return v

    @field_validator("url")
    @classmethod
    def url_validator(cls, v, info: ValidationInfo):
        if v is None or v == "":
            v = "https://www.comet.com/opik/api/"
        if not v.startswith(("https://", "http://")):
            raise ValueError("url must start with https:// or http://")
        if not v.endswith("/api/"):
            raise ValueError("url must end with /api/")

        return v


OPS_FILE_PATH = "ops_trace/"
OPS_TRACE_FAILED_KEY = "FAILED_OPS_TRACE"
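A quick demonstration of the validators above: empty hosts and urls fall back to their defaults, while malformed values raise (pydantic surfaces the validator's ValueError as a ValidationError, itself a ValueError subclass):

# Empty values fall back to defaults; malformed values raise.
cfg = LangfuseConfig(public_key="pk", secret_key="sk", host="")
assert cfg.host == "https://api.langfuse.com"

try:
    OpikConfig(url="https://www.comet.com/opik")  # missing the required /api/ suffix
except ValueError as e:
    print(e)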
api/core/ops/entities/trace_entity.py
ADDED
@@ -0,0 +1,134 @@
from collections.abc import Mapping
from datetime import datetime
from enum import StrEnum
from typing import Any, Optional, Union

from pydantic import BaseModel, ConfigDict, field_validator


class BaseTraceInfo(BaseModel):
    message_id: Optional[str] = None
    message_data: Optional[Any] = None
    inputs: Optional[Union[str, dict[str, Any], list]] = None
    outputs: Optional[Union[str, dict[str, Any], list]] = None
    start_time: Optional[datetime] = None
    end_time: Optional[datetime] = None
    metadata: dict[str, Any]

    @field_validator("inputs", "outputs")
    @classmethod
    def ensure_type(cls, v):
        if v is None:
            return None
        if isinstance(v, str | dict | list):
            return v
        return ""

    class Config:
        json_encoders = {
            datetime: lambda v: v.isoformat(),
        }


class WorkflowTraceInfo(BaseTraceInfo):
    workflow_data: Any
    conversation_id: Optional[str] = None
    workflow_app_log_id: Optional[str] = None
    workflow_id: str
    tenant_id: str
    workflow_run_id: str
    workflow_run_elapsed_time: Union[int, float]
    workflow_run_status: str
    workflow_run_inputs: Mapping[str, Any]
    workflow_run_outputs: Mapping[str, Any]
    workflow_run_version: str
    error: Optional[str] = None
    total_tokens: int
    file_list: list[str]
    query: str
    metadata: dict[str, Any]


class MessageTraceInfo(BaseTraceInfo):
    conversation_model: str
    message_tokens: int
    answer_tokens: int
    total_tokens: int
    error: Optional[str] = None
    file_list: Optional[Union[str, dict[str, Any], list]] = None
    message_file_data: Optional[Any] = None
    conversation_mode: str


class ModerationTraceInfo(BaseTraceInfo):
    flagged: bool
    action: str
    preset_response: str
    query: str


class SuggestedQuestionTraceInfo(BaseTraceInfo):
    total_tokens: int
    status: Optional[str] = None
    error: Optional[str] = None
    from_account_id: Optional[str] = None
    agent_based: Optional[bool] = None
    from_source: Optional[str] = None
    model_provider: Optional[str] = None
    model_id: Optional[str] = None
    suggested_question: list[str]
    level: str
    status_message: Optional[str] = None
    workflow_run_id: Optional[str] = None

    model_config = ConfigDict(protected_namespaces=())


class DatasetRetrievalTraceInfo(BaseTraceInfo):
    documents: Any


class ToolTraceInfo(BaseTraceInfo):
    tool_name: str
    tool_inputs: dict[str, Any]
    tool_outputs: str
    metadata: dict[str, Any]
    message_file_data: Any
    error: Optional[str] = None
    tool_config: dict[str, Any]
    time_cost: Union[int, float]
    tool_parameters: dict[str, Any]
    file_url: Union[str, None, list]


class GenerateNameTraceInfo(BaseTraceInfo):
    conversation_id: Optional[str] = None
    tenant_id: str


class TaskData(BaseModel):
    app_id: str
    trace_info_type: str
    trace_info: Any


trace_info_info_map = {
    "WorkflowTraceInfo": WorkflowTraceInfo,
    "MessageTraceInfo": MessageTraceInfo,
    "ModerationTraceInfo": ModerationTraceInfo,
    "SuggestedQuestionTraceInfo": SuggestedQuestionTraceInfo,
    "DatasetRetrievalTraceInfo": DatasetRetrievalTraceInfo,
    "ToolTraceInfo": ToolTraceInfo,
    "GenerateNameTraceInfo": GenerateNameTraceInfo,
}


class TraceTaskName(StrEnum):
    CONVERSATION_TRACE = "conversation"
    WORKFLOW_TRACE = "workflow"
    MESSAGE_TRACE = "message"
    MODERATION_TRACE = "moderation"
    SUGGESTED_QUESTION_TRACE = "suggested_question"
    DATASET_RETRIEVAL_TRACE = "dataset_retrieval"
    TOOL_TRACE = "tool"
    GENERATE_NAME_TRACE = "generate_conversation_name"
api/core/ops/langfuse_trace/__init__.py
ADDED
File without changes

api/core/ops/langfuse_trace/entities/__init__.py
ADDED
File without changes

api/core/ops/langfuse_trace/entities/langfuse_trace_entity.py
ADDED
@@ -0,0 +1,282 @@
from datetime import datetime
from enum import StrEnum
from typing import Any, Optional, Union

from pydantic import BaseModel, ConfigDict, Field, field_validator
from pydantic_core.core_schema import ValidationInfo

from core.ops.utils import replace_text_with_content


def validate_input_output(v, field_name):
    """
    Validate input output
    :param v:
    :param field_name:
    :return:
    """
    if v == {} or v is None:
        return v
    if isinstance(v, str):
        return [
            {
                "role": "assistant" if field_name == "output" else "user",
                "content": v,
            }
        ]
    elif isinstance(v, list):
        if len(v) > 0 and isinstance(v[0], dict):
            v = replace_text_with_content(data=v)
            return v
        else:
            return [
                {
                    "role": "assistant" if field_name == "output" else "user",
                    "content": str(v),
                }
            ]

    return v
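A small worked example of the normalization above: bare strings become single-message chat lists, with the role chosen by whether the field is an input or an output:

# Strings are wrapped into one-message chat lists keyed by the field name.
assert validate_input_output("hello", "input") == [{"role": "user", "content": "hello"}]
assert validate_input_output("hi there", "output") == [{"role": "assistant", "content": "hi there"}]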
|
42 |
+
class LevelEnum(StrEnum):
|
43 |
+
DEBUG = "DEBUG"
|
44 |
+
WARNING = "WARNING"
|
45 |
+
ERROR = "ERROR"
|
46 |
+
DEFAULT = "DEFAULT"
|
47 |
+
|
48 |
+
|
49 |
+
class LangfuseTrace(BaseModel):
|
50 |
+
"""
|
51 |
+
Langfuse trace model
|
52 |
+
"""
|
53 |
+
|
54 |
+
id: Optional[str] = Field(
|
55 |
+
default=None,
|
56 |
+
description="The id of the trace can be set, defaults to a random id. Used to link traces to external systems "
|
57 |
+
"or when creating a distributed trace. Traces are upserted on id.",
|
58 |
+
)
|
59 |
+
name: Optional[str] = Field(
|
60 |
+
default=None,
|
61 |
+
description="Identifier of the trace. Useful for sorting/filtering in the UI.",
|
62 |
+
)
|
63 |
+
input: Optional[Union[str, dict[str, Any], list, None]] = Field(
|
64 |
+
default=None, description="The input of the trace. Can be any JSON object."
|
65 |
+
)
|
66 |
+
output: Optional[Union[str, dict[str, Any], list, None]] = Field(
|
67 |
+
default=None, description="The output of the trace. Can be any JSON object."
|
68 |
+
)
|
69 |
+
metadata: Optional[dict[str, Any]] = Field(
|
70 |
+
default=None,
|
71 |
+
description="Additional metadata of the trace. Can be any JSON object. Metadata is merged when being updated "
|
72 |
+
"via the API.",
|
73 |
+
)
|
74 |
+
user_id: Optional[str] = Field(
|
75 |
+
default=None,
|
76 |
+
description="The id of the user that triggered the execution. Used to provide user-level analytics.",
|
77 |
+
)
|
78 |
+
session_id: Optional[str] = Field(
|
79 |
+
default=None,
|
80 |
+
description="Used to group multiple traces into a session in Langfuse. Use your own session/thread identifier.",
|
81 |
+
)
|
82 |
+
version: Optional[str] = Field(
|
83 |
+
default=None,
|
84 |
+
description="The version of the trace type. Used to understand how changes to the trace type affect metrics. "
|
85 |
+
"Useful in debugging.",
|
86 |
+
)
|
87 |
+
release: Optional[str] = Field(
|
88 |
+
default=None,
|
89 |
+
description="The release identifier of the current deployment. Used to understand how changes of different "
|
90 |
+
"deployments affect metrics. Useful in debugging.",
|
91 |
+
)
|
92 |
+
tags: Optional[list[str]] = Field(
|
93 |
+
default=None,
|
94 |
+
description="Tags are used to categorize or label traces. Traces can be filtered by tags in the UI and GET "
|
95 |
+
"API. Tags can also be changed in the UI. Tags are merged and never deleted via the API.",
|
96 |
+
)
|
97 |
+
public: Optional[bool] = Field(
|
98 |
+
default=None,
|
99 |
+
description="You can make a trace public to share it via a public link. This allows others to view the trace "
|
100 |
+
"without needing to log in or be members of your Langfuse project.",
|
101 |
+
)
|
102 |
+
|
103 |
+
@field_validator("input", "output")
|
104 |
+
@classmethod
|
105 |
+
def ensure_dict(cls, v, info: ValidationInfo):
|
106 |
+
field_name = info.field_name
|
107 |
+
return validate_input_output(v, field_name)
|
108 |
+
|
109 |
+
|
110 |
+
class LangfuseSpan(BaseModel):
|
111 |
+
"""
|
112 |
+
Langfuse span model
|
113 |
+
"""
|
114 |
+
|
115 |
+
id: Optional[str] = Field(
|
116 |
+
default=None,
|
117 |
+
description="The id of the span can be set, otherwise a random id is generated. Spans are upserted on id.",
|
118 |
+
)
|
119 |
+
session_id: Optional[str] = Field(
|
120 |
+
default=None,
|
121 |
+
description="Used to group multiple spans into a session in Langfuse. Use your own session/thread identifier.",
|
122 |
+
)
|
123 |
+
trace_id: Optional[str] = Field(
|
124 |
+
default=None,
|
125 |
+
description="The id of the trace the span belongs to. Used to link spans to traces.",
|
126 |
+
)
|
127 |
+
user_id: Optional[str] = Field(
|
128 |
+
default=None,
|
129 |
+
description="The id of the user that triggered the execution. Used to provide user-level analytics.",
|
130 |
+
)
|
131 |
+
start_time: Optional[datetime | str] = Field(
|
132 |
+
default_factory=datetime.now,
|
133 |
+
description="The time at which the span started, defaults to the current time.",
|
134 |
+
)
|
135 |
+
end_time: Optional[datetime | str] = Field(
|
136 |
+
default=None,
|
137 |
+
description="The time at which the span ended. Automatically set by span.end().",
|
138 |
+
)
|
139 |
+
name: Optional[str] = Field(
|
140 |
+
default=None,
|
141 |
+
description="Identifier of the span. Useful for sorting/filtering in the UI.",
|
142 |
+
)
|
143 |
+
metadata: Optional[dict[str, Any]] = Field(
|
144 |
+
default=None,
|
145 |
+
description="Additional metadata of the span. Can be any JSON object. Metadata is merged when being updated "
|
146 |
+
"via the API.",
|
147 |
+
)
|
148 |
+
level: Optional[str] = Field(
|
149 |
+
default=None,
|
150 |
+
description="The level of the span. Can be DEBUG, DEFAULT, WARNING or ERROR. Used for sorting/filtering of "
|
151 |
+
"traces with elevated error levels and for highlighting in the UI.",
|
152 |
+
)
|
153 |
+
status_message: Optional[str] = Field(
|
154 |
+
default=None,
|
155 |
+
description="The status message of the span. Additional field for context of the event. E.g. the error "
|
156 |
+
"message of an error event.",
|
157 |
+
)
|
158 |
+
input: Optional[Union[str, dict[str, Any], list, None]] = Field(
|
159 |
+
default=None, description="The input of the span. Can be any JSON object."
|
160 |
+
)
|
161 |
+
output: Optional[Union[str, dict[str, Any], list, None]] = Field(
|
162 |
+
default=None, description="The output of the span. Can be any JSON object."
|
163 |
+
)
|
164 |
+
version: Optional[str] = Field(
|
165 |
+
default=None,
|
166 |
+
description="The version of the span type. Used to understand how changes to the span type affect metrics. "
|
167 |
+
"Useful in debugging.",
|
168 |
+
)
|
169 |
+
parent_observation_id: Optional[str] = Field(
|
170 |
+
default=None,
|
171 |
+
description="The id of the observation the span belongs to. Used to link spans to observations.",
|
172 |
+
)
|
173 |
+
|
174 |
+
@field_validator("input", "output")
|
175 |
+
@classmethod
|
176 |
+
def ensure_dict(cls, v, info: ValidationInfo):
|
177 |
+
field_name = info.field_name
|
178 |
+
return validate_input_output(v, field_name)
|
179 |
+
|
180 |
+
|
181 |
+
class UnitEnum(StrEnum):
|
182 |
+
CHARACTERS = "CHARACTERS"
|
183 |
+
TOKENS = "TOKENS"
|
184 |
+
SECONDS = "SECONDS"
|
185 |
+
MILLISECONDS = "MILLISECONDS"
|
186 |
+
IMAGES = "IMAGES"
|
187 |
+
|
188 |
+
|
189 |
+
class GenerationUsage(BaseModel):
|
190 |
+
promptTokens: Optional[int] = None
|
191 |
+
completionTokens: Optional[int] = None
|
192 |
+
total: Optional[int] = None
|
193 |
+
input: Optional[int] = None
|
194 |
+
output: Optional[int] = None
|
195 |
+
unit: Optional[UnitEnum] = None
|
196 |
+
inputCost: Optional[float] = None
|
197 |
+
outputCost: Optional[float] = None
|
198 |
+
totalCost: Optional[float] = None
|
199 |
+
|
200 |
+
@field_validator("input", "output")
|
201 |
+
@classmethod
|
202 |
+
def ensure_dict(cls, v, info: ValidationInfo):
|
203 |
+
field_name = info.field_name
|
204 |
+
return validate_input_output(v, field_name)
|
205 |
+
|
206 |
+
|
207 |
+
class LangfuseGeneration(BaseModel):
|
208 |
+
id: Optional[str] = Field(
|
209 |
+
default=None,
|
210 |
+
description="The id of the generation can be set, defaults to random id.",
|
211 |
+
)
|
212 |
+
trace_id: Optional[str] = Field(
|
213 |
+
default=None,
|
214 |
+
description="The id of the trace the generation belongs to. Used to link generations to traces.",
|
215 |
+
)
|
216 |
+
parent_observation_id: Optional[str] = Field(
|
217 |
+
default=None,
|
218 |
+
description="The id of the observation the generation belongs to. Used to link generations to observations.",
|
219 |
+
)
|
220 |
+
name: Optional[str] = Field(
|
221 |
+
default=None,
|
222 |
+
description="Identifier of the generation. Useful for sorting/filtering in the UI.",
|
223 |
+
)
|
224 |
+
start_time: Optional[datetime | str] = Field(
|
225 |
+
default_factory=datetime.now,
|
226 |
+
description="The time at which the generation started, defaults to the current time.",
|
227 |
+
)
|
228 |
+
completion_start_time: Optional[datetime | str] = Field(
|
229 |
+
default=None,
|
230 |
+
description="The time at which the completion started (streaming). Set it to get latency analytics broken "
|
231 |
+
"down into time until completion started and completion duration.",
|
232 |
+
)
|
233 |
+
end_time: Optional[datetime | str] = Field(
|
234 |
+
default=None,
|
235 |
+
description="The time at which the generation ended. Automatically set by generation.end().",
|
236 |
+
)
|
237 |
+
model: Optional[str] = Field(default=None, description="The name of the model used for the generation.")
|
238 |
+
model_parameters: Optional[dict[str, Any]] = Field(
|
239 |
+
default=None,
|
240 |
+
description="The parameters of the model used for the generation; can be any key-value pairs.",
|
241 |
+
)
|
242 |
+
input: Optional[Any] = Field(
|
243 |
+
default=None,
|
244 |
+
description="The prompt used for the generation. Can be any string or JSON object.",
|
245 |
+
)
|
246 |
+
output: Optional[Any] = Field(
|
247 |
+
default=None,
|
248 |
+
description="The completion generated by the model. Can be any string or JSON object.",
|
249 |
+
)
|
250 |
+
usage: Optional[GenerationUsage] = Field(
|
251 |
+
default=None,
|
252 |
+
description="The usage object supports the OpenAi structure with tokens and a more generic version with "
|
253 |
+
"detailed costs and units.",
|
254 |
+
)
|
255 |
+
metadata: Optional[dict[str, Any]] = Field(
|
256 |
+
default=None,
|
257 |
+
description="Additional metadata of the generation. Can be any JSON object. Metadata is merged when being "
|
258 |
+
"updated via the API.",
|
259 |
+
)
|
260 |
+
level: Optional[LevelEnum] = Field(
|
261 |
+
default=None,
|
262 |
+
description="The level of the generation. Can be DEBUG, DEFAULT, WARNING or ERROR. Used for sorting/filtering "
|
263 |
+
"of traces with elevated error levels and for highlighting in the UI.",
|
264 |
+
)
|
265 |
+
status_message: Optional[str] = Field(
|
266 |
+
default=None,
|
267 |
+
description="The status message of the generation. Additional field for context of the event. E.g. the error "
|
268 |
+
"message of an error event.",
|
269 |
+
)
|
270 |
+
version: Optional[str] = Field(
|
271 |
+
default=None,
|
272 |
+
description="The version of the generation type. Used to understand how changes to the span type affect "
|
273 |
+
"metrics. Useful in debugging.",
|
274 |
+
)
|
275 |
+
|
276 |
+
model_config = ConfigDict(protected_namespaces=())
|
277 |
+
|
278 |
+
@field_validator("input", "output")
|
279 |
+
@classmethod
|
280 |
+
def ensure_dict(cls, v, info: ValidationInfo):
|
281 |
+
field_name = info.field_name
|
282 |
+
return validate_input_output(v, field_name)
|
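
For readers skimming the diff, a minimal sketch of what validate_input_output produces, runnable inside the api project; the sample values are hypothetical:

from core.ops.langfuse_trace.entities.langfuse_trace_entity import validate_input_output

# Empty-ish values pass through untouched.
assert validate_input_output(None, "input") is None

# A bare string becomes a single user message; outputs are attributed to the assistant.
assert validate_input_output("hello", "input") == [{"role": "user", "content": "hello"}]
assert validate_input_output("hi", "output") == [{"role": "assistant", "content": "hi"}]

# Any other scalar is stringified into the same message shape.
assert validate_input_output(42, "input") == [{"role": "user", "content": "42"}]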
api/core/ops/langfuse_trace/langfuse_trace.py
ADDED
@@ -0,0 +1,455 @@
import json
import logging
import os
from datetime import datetime, timedelta
from typing import Optional

from langfuse import Langfuse  # type: ignore

from core.ops.base_trace_instance import BaseTraceInstance
from core.ops.entities.config_entity import LangfuseConfig
from core.ops.entities.trace_entity import (
    BaseTraceInfo,
    DatasetRetrievalTraceInfo,
    GenerateNameTraceInfo,
    MessageTraceInfo,
    ModerationTraceInfo,
    SuggestedQuestionTraceInfo,
    ToolTraceInfo,
    TraceTaskName,
    WorkflowTraceInfo,
)
from core.ops.langfuse_trace.entities.langfuse_trace_entity import (
    GenerationUsage,
    LangfuseGeneration,
    LangfuseSpan,
    LangfuseTrace,
    LevelEnum,
    UnitEnum,
)
from core.ops.utils import filter_none_values
from extensions.ext_database import db
from models.model import EndUser
from models.workflow import WorkflowNodeExecution

logger = logging.getLogger(__name__)


class LangFuseDataTrace(BaseTraceInstance):
    def __init__(
        self,
        langfuse_config: LangfuseConfig,
    ):
        super().__init__(langfuse_config)
        self.langfuse_client = Langfuse(
            public_key=langfuse_config.public_key,
            secret_key=langfuse_config.secret_key,
            host=langfuse_config.host,
        )
        self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001")

    def trace(self, trace_info: BaseTraceInfo):
        if isinstance(trace_info, WorkflowTraceInfo):
            self.workflow_trace(trace_info)
        if isinstance(trace_info, MessageTraceInfo):
            self.message_trace(trace_info)
        if isinstance(trace_info, ModerationTraceInfo):
            self.moderation_trace(trace_info)
        if isinstance(trace_info, SuggestedQuestionTraceInfo):
            self.suggested_question_trace(trace_info)
        if isinstance(trace_info, DatasetRetrievalTraceInfo):
            self.dataset_retrieval_trace(trace_info)
        if isinstance(trace_info, ToolTraceInfo):
            self.tool_trace(trace_info)
        if isinstance(trace_info, GenerateNameTraceInfo):
            self.generate_name_trace(trace_info)

    def workflow_trace(self, trace_info: WorkflowTraceInfo):
        trace_id = trace_info.workflow_run_id
        user_id = trace_info.metadata.get("user_id")
        metadata = trace_info.metadata
        metadata["workflow_app_log_id"] = trace_info.workflow_app_log_id

        if trace_info.message_id:
            trace_id = trace_info.message_id
            name = TraceTaskName.MESSAGE_TRACE.value
            trace_data = LangfuseTrace(
                id=trace_id,
                user_id=user_id,
                name=name,
                input=dict(trace_info.workflow_run_inputs),
                output=dict(trace_info.workflow_run_outputs),
                metadata=metadata,
                session_id=trace_info.conversation_id,
                tags=["message", "workflow"],
            )
            self.add_trace(langfuse_trace_data=trace_data)
            workflow_span_data = LangfuseSpan(
                id=trace_info.workflow_run_id,
                name=TraceTaskName.WORKFLOW_TRACE.value,
                input=dict(trace_info.workflow_run_inputs),
                output=dict(trace_info.workflow_run_outputs),
                trace_id=trace_id,
                start_time=trace_info.start_time,
                end_time=trace_info.end_time,
                metadata=metadata,
                level=LevelEnum.DEFAULT if trace_info.error == "" else LevelEnum.ERROR,
                status_message=trace_info.error or "",
            )
            self.add_span(langfuse_span_data=workflow_span_data)
        else:
            trace_data = LangfuseTrace(
                id=trace_id,
                user_id=user_id,
                name=TraceTaskName.WORKFLOW_TRACE.value,
                input=dict(trace_info.workflow_run_inputs),
                output=dict(trace_info.workflow_run_outputs),
                metadata=metadata,
                session_id=trace_info.conversation_id,
                tags=["workflow"],
            )
            self.add_trace(langfuse_trace_data=trace_data)

        # fetch all node executions belonging to this workflow run
        workflow_nodes_execution_id_records = (
            db.session.query(WorkflowNodeExecution.id)
            .filter(WorkflowNodeExecution.workflow_run_id == trace_info.workflow_run_id)
            .all()
        )

        for node_execution_id_record in workflow_nodes_execution_id_records:
            node_execution = (
                db.session.query(
                    WorkflowNodeExecution.id,
                    WorkflowNodeExecution.tenant_id,
                    WorkflowNodeExecution.app_id,
                    WorkflowNodeExecution.title,
                    WorkflowNodeExecution.node_type,
                    WorkflowNodeExecution.status,
                    WorkflowNodeExecution.inputs,
                    WorkflowNodeExecution.outputs,
                    WorkflowNodeExecution.created_at,
                    WorkflowNodeExecution.elapsed_time,
                    WorkflowNodeExecution.process_data,
                    WorkflowNodeExecution.execution_metadata,
                )
                .filter(WorkflowNodeExecution.id == node_execution_id_record.id)
                .first()
            )

            if not node_execution:
                continue

            node_execution_id = node_execution.id
            tenant_id = node_execution.tenant_id
            app_id = node_execution.app_id
            node_name = node_execution.title
            node_type = node_execution.node_type
            status = node_execution.status
            if node_type == "llm":
                inputs = (
                    json.loads(node_execution.process_data).get("prompts", {}) if node_execution.process_data else {}
                )
            else:
                inputs = json.loads(node_execution.inputs) if node_execution.inputs else {}
            outputs = json.loads(node_execution.outputs) if node_execution.outputs else {}
            created_at = node_execution.created_at or datetime.now()
            elapsed_time = node_execution.elapsed_time
            finished_at = created_at + timedelta(seconds=elapsed_time)

            metadata = json.loads(node_execution.execution_metadata) if node_execution.execution_metadata else {}
            metadata.update(
                {
                    "workflow_run_id": trace_info.workflow_run_id,
                    "node_execution_id": node_execution_id,
                    "tenant_id": tenant_id,
                    "app_id": app_id,
                    "node_name": node_name,
                    "node_type": node_type,
                    "status": status,
                }
            )
            process_data = json.loads(node_execution.process_data) if node_execution.process_data else {}
            model_provider = process_data.get("model_provider", None)
            model_name = process_data.get("model_name", None)
            if model_provider is not None and model_name is not None:
                metadata.update(
                    {
                        "model_provider": model_provider,
                        "model_name": model_name,
                    }
                )

            # add span
            if trace_info.message_id:
                span_data = LangfuseSpan(
                    id=node_execution_id,
                    name=node_type,
                    input=inputs,
                    output=outputs,
                    trace_id=trace_id,
                    start_time=created_at,
                    end_time=finished_at,
                    metadata=metadata,
                    level=(LevelEnum.DEFAULT if status == "succeeded" else LevelEnum.ERROR),
                    status_message=trace_info.error or "",
                    parent_observation_id=trace_info.workflow_run_id,
                )
            else:
                span_data = LangfuseSpan(
                    id=node_execution_id,
                    name=node_type,
                    input=inputs,
                    output=outputs,
                    trace_id=trace_id,
                    start_time=created_at,
                    end_time=finished_at,
                    metadata=metadata,
                    level=(LevelEnum.DEFAULT if status == "succeeded" else LevelEnum.ERROR),
                    status_message=trace_info.error or "",
                )

            self.add_span(langfuse_span_data=span_data)

            if process_data and process_data.get("model_mode") == "chat":
                total_token = metadata.get("total_tokens", 0)
                # add generation
                generation_usage = GenerationUsage(
                    total=total_token,
                )

                node_generation_data = LangfuseGeneration(
                    name="llm",
                    trace_id=trace_id,
                    model=process_data.get("model_name"),
                    parent_observation_id=node_execution_id,
                    start_time=created_at,
                    end_time=finished_at,
                    input=inputs,
                    output=outputs,
                    metadata=metadata,
                    level=(LevelEnum.DEFAULT if status == "succeeded" else LevelEnum.ERROR),
                    status_message=trace_info.error or "",
                    usage=generation_usage,
                )

                self.add_generation(langfuse_generation_data=node_generation_data)

    def message_trace(self, trace_info: MessageTraceInfo, **kwargs):
        # get message file data
        file_list = trace_info.file_list
        metadata = trace_info.metadata
        message_data = trace_info.message_data
        if message_data is None:
            return
        message_id = message_data.id

        user_id = message_data.from_account_id
        if message_data.from_end_user_id:
            end_user_data: Optional[EndUser] = (
                db.session.query(EndUser).filter(EndUser.id == message_data.from_end_user_id).first()
            )
            if end_user_data is not None:
                user_id = end_user_data.session_id
                metadata["user_id"] = user_id

        trace_data = LangfuseTrace(
            id=message_id,
            user_id=user_id,
            name=TraceTaskName.MESSAGE_TRACE.value,
            input={
                "message": trace_info.inputs,
                "files": file_list,
                "message_tokens": trace_info.message_tokens,
                "answer_tokens": trace_info.answer_tokens,
                "total_tokens": trace_info.total_tokens,
                "error": trace_info.error,
                "provider_response_latency": message_data.provider_response_latency,
                "created_at": trace_info.start_time,
            },
            output=trace_info.outputs,
            metadata=metadata,
            session_id=message_data.conversation_id,
            tags=["message", str(trace_info.conversation_mode)],
            version=None,
            release=None,
            public=None,
        )
        self.add_trace(langfuse_trace_data=trace_data)

        # add the llm generation under the message trace
        generation_usage = GenerationUsage(
            input=trace_info.message_tokens,
            output=trace_info.answer_tokens,
            total=trace_info.total_tokens,
            unit=UnitEnum.TOKENS,
            totalCost=message_data.total_price,
        )

        langfuse_generation_data = LangfuseGeneration(
            name="llm",
            trace_id=message_id,
            start_time=trace_info.start_time,
            end_time=trace_info.end_time,
            model=message_data.model_id,
            input=trace_info.inputs,
            output=message_data.answer,
            metadata=metadata,
            level=(LevelEnum.DEFAULT if message_data.status != "error" else LevelEnum.ERROR),
            status_message=message_data.error or "",
            usage=generation_usage,
        )

        self.add_generation(langfuse_generation_data)

    def moderation_trace(self, trace_info: ModerationTraceInfo):
        if trace_info.message_data is None:
            return
        span_data = LangfuseSpan(
            name=TraceTaskName.MODERATION_TRACE.value,
            input=trace_info.inputs,
            output={
                "action": trace_info.action,
                "flagged": trace_info.flagged,
                "preset_response": trace_info.preset_response,
                "inputs": trace_info.inputs,
            },
            trace_id=trace_info.message_id,
            start_time=trace_info.start_time or trace_info.message_data.created_at,
            end_time=trace_info.end_time or trace_info.message_data.created_at,
            metadata=trace_info.metadata,
        )

        self.add_span(langfuse_span_data=span_data)

    def suggested_question_trace(self, trace_info: SuggestedQuestionTraceInfo):
        message_data = trace_info.message_data
        if message_data is None:
            return
        generation_usage = GenerationUsage(
            total=len(str(trace_info.suggested_question)),
            input=len(trace_info.inputs) if trace_info.inputs else 0,
            output=len(trace_info.suggested_question),
            unit=UnitEnum.CHARACTERS,
        )

        generation_data = LangfuseGeneration(
            name=TraceTaskName.SUGGESTED_QUESTION_TRACE.value,
            input=trace_info.inputs,
            output=str(trace_info.suggested_question),
            trace_id=trace_info.message_id,
            start_time=trace_info.start_time,
            end_time=trace_info.end_time,
            metadata=trace_info.metadata,
            level=(LevelEnum.DEFAULT if message_data.status != "error" else LevelEnum.ERROR),
            status_message=message_data.error or "",
            usage=generation_usage,
        )

        self.add_generation(langfuse_generation_data=generation_data)

    def dataset_retrieval_trace(self, trace_info: DatasetRetrievalTraceInfo):
        if trace_info.message_data is None:
            return
        dataset_retrieval_span_data = LangfuseSpan(
            name=TraceTaskName.DATASET_RETRIEVAL_TRACE.value,
            input=trace_info.inputs,
            output={"documents": trace_info.documents},
            trace_id=trace_info.message_id,
            start_time=trace_info.start_time or trace_info.message_data.created_at,
            end_time=trace_info.end_time or trace_info.message_data.updated_at,
            metadata=trace_info.metadata,
        )

        self.add_span(langfuse_span_data=dataset_retrieval_span_data)

    def tool_trace(self, trace_info: ToolTraceInfo):
        tool_span_data = LangfuseSpan(
            name=trace_info.tool_name,
            input=trace_info.tool_inputs,
            output=trace_info.tool_outputs,
            trace_id=trace_info.message_id,
            start_time=trace_info.start_time,
            end_time=trace_info.end_time,
            metadata=trace_info.metadata,
            level=(LevelEnum.DEFAULT if trace_info.error == "" or trace_info.error is None else LevelEnum.ERROR),
            status_message=trace_info.error,
        )

        self.add_span(langfuse_span_data=tool_span_data)

    def generate_name_trace(self, trace_info: GenerateNameTraceInfo):
        name_generation_trace_data = LangfuseTrace(
            name=TraceTaskName.GENERATE_NAME_TRACE.value,
            input=trace_info.inputs,
            output=trace_info.outputs,
            user_id=trace_info.tenant_id,
            metadata=trace_info.metadata,
            session_id=trace_info.conversation_id,
        )

        self.add_trace(langfuse_trace_data=name_generation_trace_data)

        name_generation_span_data = LangfuseSpan(
            name=TraceTaskName.GENERATE_NAME_TRACE.value,
            input=trace_info.inputs,
            output=trace_info.outputs,
            trace_id=trace_info.conversation_id,
            start_time=trace_info.start_time,
            end_time=trace_info.end_time,
            metadata=trace_info.metadata,
        )
        self.add_span(langfuse_span_data=name_generation_span_data)

    def add_trace(self, langfuse_trace_data: Optional[LangfuseTrace] = None):
        format_trace_data = filter_none_values(langfuse_trace_data.model_dump()) if langfuse_trace_data else {}
        try:
            self.langfuse_client.trace(**format_trace_data)
            logger.debug("LangFuse Trace created successfully")
        except Exception as e:
            raise ValueError(f"LangFuse Failed to create trace: {str(e)}")

    def add_span(self, langfuse_span_data: Optional[LangfuseSpan] = None):
        format_span_data = filter_none_values(langfuse_span_data.model_dump()) if langfuse_span_data else {}
        try:
            self.langfuse_client.span(**format_span_data)
            logger.debug("LangFuse Span created successfully")
        except Exception as e:
            raise ValueError(f"LangFuse Failed to create span: {str(e)}")

    def update_span(self, span, langfuse_span_data: Optional[LangfuseSpan] = None):
        format_span_data = filter_none_values(langfuse_span_data.model_dump()) if langfuse_span_data else {}

        span.end(**format_span_data)

    def add_generation(self, langfuse_generation_data: Optional[LangfuseGeneration] = None):
        format_generation_data = (
            filter_none_values(langfuse_generation_data.model_dump()) if langfuse_generation_data else {}
        )
        try:
            self.langfuse_client.generation(**format_generation_data)
            logger.debug("LangFuse Generation created successfully")
        except Exception as e:
            raise ValueError(f"LangFuse Failed to create generation: {str(e)}")

    def update_generation(self, generation, langfuse_generation_data: Optional[LangfuseGeneration] = None):
        format_generation_data = (
            filter_none_values(langfuse_generation_data.model_dump()) if langfuse_generation_data else {}
        )

        generation.end(**format_generation_data)

    def api_check(self):
        try:
            return self.langfuse_client.auth_check()
        except Exception as e:
            logger.debug(f"LangFuse API check failed: {str(e)}")
            raise ValueError(f"LangFuse API check failed: {str(e)}")

    def get_project_key(self):
        try:
            projects = self.langfuse_client.client.projects.get()
            return projects.data[0].id
        except Exception as e:
            logger.debug(f"LangFuse get project key failed: {str(e)}")
            raise ValueError(f"LangFuse get project key failed: {str(e)}")
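
A minimal wiring sketch for the tracer above, assuming LangfuseConfig (defined in config_entity.py, not shown here) accepts the three fields that __init__ reads; the keys are placeholders:

from core.ops.entities.config_entity import LangfuseConfig
from core.ops.langfuse_trace.langfuse_trace import LangFuseDataTrace

config = LangfuseConfig(
    public_key="pk-lf-...",  # placeholder
    secret_key="sk-lf-...",  # placeholder
    host="https://cloud.langfuse.com",
)
tracer = LangFuseDataTrace(config)
tracer.api_check()  # raises ValueError if the credentials are rejected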
api/core/ops/langsmith_trace/__init__.py
ADDED
File without changes
api/core/ops/langsmith_trace/entities/__init__.py
ADDED
File without changes
api/core/ops/langsmith_trace/entities/langsmith_trace_entity.py
ADDED
@@ -0,0 +1,141 @@
from datetime import datetime
from enum import StrEnum
from typing import Any, Optional, Union

from pydantic import BaseModel, Field, field_validator
from pydantic_core.core_schema import ValidationInfo

from core.ops.utils import replace_text_with_content


class LangSmithRunType(StrEnum):
    tool = "tool"
    chain = "chain"
    llm = "llm"
    retriever = "retriever"
    embedding = "embedding"
    prompt = "prompt"
    parser = "parser"


class LangSmithTokenUsage(BaseModel):
    input_tokens: Optional[int] = None
    output_tokens: Optional[int] = None
    total_tokens: Optional[int] = None


class LangSmithMultiModel(BaseModel):
    file_list: Optional[list[str]] = Field(None, description="List of files")


class LangSmithRunModel(LangSmithTokenUsage, LangSmithMultiModel):
    name: Optional[str] = Field(..., description="Name of the run")
    inputs: Optional[Union[str, dict[str, Any], list, None]] = Field(None, description="Inputs of the run")
    outputs: Optional[Union[str, dict[str, Any], list, None]] = Field(None, description="Outputs of the run")
    run_type: LangSmithRunType = Field(..., description="Type of the run")
    start_time: Optional[datetime | str] = Field(None, description="Start time of the run")
    end_time: Optional[datetime | str] = Field(None, description="End time of the run")
    extra: Optional[dict[str, Any]] = Field(None, description="Extra information of the run")
    error: Optional[str] = Field(None, description="Error message of the run")
    serialized: Optional[dict[str, Any]] = Field(None, description="Serialized data of the run")
    parent_run_id: Optional[str] = Field(None, description="Parent run ID")
    events: Optional[list[dict[str, Any]]] = Field(None, description="Events associated with the run")
    tags: Optional[list[str]] = Field(None, description="Tags associated with the run")
    trace_id: Optional[str] = Field(None, description="Trace ID associated with the run")
    dotted_order: Optional[str] = Field(None, description="Dotted order of the run")
    id: Optional[str] = Field(None, description="ID of the run")
    session_id: Optional[str] = Field(None, description="Session ID associated with the run")
    session_name: Optional[str] = Field(None, description="Session name associated with the run")
    reference_example_id: Optional[str] = Field(None, description="Reference example ID associated with the run")
    input_attachments: Optional[dict[str, Any]] = Field(None, description="Input attachments of the run")
    output_attachments: Optional[dict[str, Any]] = Field(None, description="Output attachments of the run")

    @field_validator("inputs", "outputs")
    @classmethod
    def ensure_dict(cls, v, info: ValidationInfo):
        field_name = info.field_name
        values = info.data
        if v == {} or v is None:
            return v
        usage_metadata = {
            "input_tokens": values.get("input_tokens", 0),
            "output_tokens": values.get("output_tokens", 0),
            "total_tokens": values.get("total_tokens", 0),
        }
        file_list = values.get("file_list", [])
        if isinstance(v, str):
            if field_name == "inputs":
                return {
                    "messages": {
                        "role": "user",
                        "content": v,
                        "usage_metadata": usage_metadata,
                        "file_list": file_list,
                    },
                }
            elif field_name == "outputs":
                return {
                    "choices": {
                        "role": "ai",
                        "content": v,
                        "usage_metadata": usage_metadata,
                        "file_list": file_list,
                    },
                }
        elif isinstance(v, list):
            data = {}
            if len(v) > 0 and isinstance(v[0], dict):
                # rename text to content
                v = replace_text_with_content(data=v)
                if field_name == "inputs":
                    data = {
                        "messages": v,
                    }
                elif field_name == "outputs":
                    data = {
                        "choices": {
                            "role": "ai",
                            "content": v,
                            "usage_metadata": usage_metadata,
                            "file_list": file_list,
                        },
                    }
                return data
            else:
                return {
                    "choices": {
                        "role": "ai" if field_name == "outputs" else "user",
                        "content": str(v),
                        "usage_metadata": usage_metadata,
                        "file_list": file_list,
                    },
                }
        if isinstance(v, dict):
            v["usage_metadata"] = usage_metadata
            v["file_list"] = file_list
            return v
        return v

    # pydantic v2 only registers the validator when @field_validator is the
    # outermost decorator; None is passed through so optional times stay optional.
    @field_validator("start_time", "end_time")
    @classmethod
    def format_time(cls, v, info: ValidationInfo):
        if v is None:
            return v
        if not isinstance(v, datetime):
            raise ValueError(f"{info.field_name} must be a datetime object")
        return v.strftime("%Y-%m-%dT%H:%M:%S.%fZ")


class LangSmithRunUpdateModel(BaseModel):
    run_id: str = Field(..., description="ID of the run")
    trace_id: Optional[str] = Field(None, description="Trace ID associated with the run")
    dotted_order: Optional[str] = Field(None, description="Dotted order of the run")
    parent_run_id: Optional[str] = Field(None, description="Parent run ID")
    end_time: Optional[datetime | str] = Field(None, description="End time of the run")
    error: Optional[str] = Field(None, description="Error message of the run")
    inputs: Optional[dict[str, Any]] = Field(None, description="Inputs of the run")
    outputs: Optional[dict[str, Any]] = Field(None, description="Outputs of the run")
    events: Optional[list[dict[str, Any]]] = Field(None, description="Events associated with the run")
    tags: Optional[list[str]] = Field(None, description="Tags associated with the run")
    extra: Optional[dict[str, Any]] = Field(None, description="Extra information of the run")
    input_attachments: Optional[dict[str, Any]] = Field(None, description="Input attachments of the run")
    output_attachments: Optional[dict[str, Any]] = Field(None, description="Output attachments of the run")
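
A hypothetical check of ensure_dict above: a bare string input is wrapped into the "messages" shape, with usage metadata pulled from the sibling token fields (base-class fields validate before the subclass's inputs field, so they are already in info.data):

from core.ops.langsmith_trace.entities.langsmith_trace_entity import (
    LangSmithRunModel,
    LangSmithRunType,
)

run = LangSmithRunModel(
    name="demo",                      # hypothetical run name
    run_type=LangSmithRunType.chain,
    inputs="What is Dify?",           # a bare string, normalized by ensure_dict
    input_tokens=5,
    output_tokens=0,
    total_tokens=5,
)
assert run.inputs["messages"]["role"] == "user"
assert run.inputs["messages"]["usage_metadata"]["total_tokens"] == 5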
api/core/ops/langsmith_trace/langsmith_trace.py
ADDED
@@ -0,0 +1,524 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import logging
|
3 |
+
import os
|
4 |
+
import uuid
|
5 |
+
from datetime import datetime, timedelta
|
6 |
+
from typing import Optional, cast
|
7 |
+
|
8 |
+
from langsmith import Client
|
9 |
+
from langsmith.schemas import RunBase
|
10 |
+
|
11 |
+
from core.ops.base_trace_instance import BaseTraceInstance
|
12 |
+
from core.ops.entities.config_entity import LangSmithConfig
|
13 |
+
from core.ops.entities.trace_entity import (
|
14 |
+
BaseTraceInfo,
|
15 |
+
DatasetRetrievalTraceInfo,
|
16 |
+
GenerateNameTraceInfo,
|
17 |
+
MessageTraceInfo,
|
18 |
+
ModerationTraceInfo,
|
19 |
+
SuggestedQuestionTraceInfo,
|
20 |
+
ToolTraceInfo,
|
21 |
+
TraceTaskName,
|
22 |
+
WorkflowTraceInfo,
|
23 |
+
)
|
24 |
+
from core.ops.langsmith_trace.entities.langsmith_trace_entity import (
|
25 |
+
LangSmithRunModel,
|
26 |
+
LangSmithRunType,
|
27 |
+
LangSmithRunUpdateModel,
|
28 |
+
)
|
29 |
+
from core.ops.utils import filter_none_values, generate_dotted_order
|
30 |
+
from extensions.ext_database import db
|
31 |
+
from models.model import EndUser, MessageFile
|
32 |
+
from models.workflow import WorkflowNodeExecution
|
33 |
+
|
34 |
+
logger = logging.getLogger(__name__)
|
35 |
+
|
36 |
+
|
37 |
+
class LangSmithDataTrace(BaseTraceInstance):
|
38 |
+
def __init__(
|
39 |
+
self,
|
40 |
+
langsmith_config: LangSmithConfig,
|
41 |
+
):
|
42 |
+
super().__init__(langsmith_config)
|
43 |
+
self.langsmith_key = langsmith_config.api_key
|
44 |
+
self.project_name = langsmith_config.project
|
45 |
+
self.project_id = None
|
46 |
+
self.langsmith_client = Client(api_key=langsmith_config.api_key, api_url=langsmith_config.endpoint)
|
47 |
+
self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001")
|
48 |
+
|
49 |
+
def trace(self, trace_info: BaseTraceInfo):
|
50 |
+
if isinstance(trace_info, WorkflowTraceInfo):
|
51 |
+
self.workflow_trace(trace_info)
|
52 |
+
if isinstance(trace_info, MessageTraceInfo):
|
53 |
+
self.message_trace(trace_info)
|
54 |
+
if isinstance(trace_info, ModerationTraceInfo):
|
55 |
+
self.moderation_trace(trace_info)
|
56 |
+
if isinstance(trace_info, SuggestedQuestionTraceInfo):
|
57 |
+
self.suggested_question_trace(trace_info)
|
58 |
+
if isinstance(trace_info, DatasetRetrievalTraceInfo):
|
59 |
+
self.dataset_retrieval_trace(trace_info)
|
60 |
+
if isinstance(trace_info, ToolTraceInfo):
|
61 |
+
self.tool_trace(trace_info)
|
62 |
+
if isinstance(trace_info, GenerateNameTraceInfo):
|
63 |
+
self.generate_name_trace(trace_info)
|
64 |
+
|
65 |
+
def workflow_trace(self, trace_info: WorkflowTraceInfo):
|
66 |
+
trace_id = trace_info.message_id or trace_info.workflow_run_id
|
67 |
+
if trace_info.start_time is None:
|
68 |
+
trace_info.start_time = datetime.now()
|
69 |
+
message_dotted_order = (
|
70 |
+
generate_dotted_order(trace_info.message_id, trace_info.start_time) if trace_info.message_id else None
|
71 |
+
)
|
72 |
+
workflow_dotted_order = generate_dotted_order(
|
73 |
+
trace_info.workflow_run_id,
|
74 |
+
trace_info.workflow_data.created_at,
|
75 |
+
message_dotted_order,
|
76 |
+
)
|
77 |
+
metadata = trace_info.metadata
|
78 |
+
metadata["workflow_app_log_id"] = trace_info.workflow_app_log_id
|
79 |
+
|
80 |
+
if trace_info.message_id:
|
81 |
+
message_run = LangSmithRunModel(
|
82 |
+
id=trace_info.message_id,
|
83 |
+
name=TraceTaskName.MESSAGE_TRACE.value,
|
84 |
+
inputs=dict(trace_info.workflow_run_inputs),
|
85 |
+
outputs=dict(trace_info.workflow_run_outputs),
|
86 |
+
run_type=LangSmithRunType.chain,
|
87 |
+
start_time=trace_info.start_time,
|
88 |
+
end_time=trace_info.end_time,
|
89 |
+
extra={
|
90 |
+
"metadata": metadata,
|
91 |
+
},
|
92 |
+
tags=["message", "workflow"],
|
93 |
+
error=trace_info.error,
|
94 |
+
trace_id=trace_id,
|
95 |
+
dotted_order=message_dotted_order,
|
96 |
+
file_list=[],
|
97 |
+
serialized=None,
|
98 |
+
parent_run_id=None,
|
99 |
+
events=[],
|
100 |
+
session_id=None,
|
101 |
+
session_name=None,
|
102 |
+
reference_example_id=None,
|
103 |
+
input_attachments={},
|
104 |
+
output_attachments={},
|
105 |
+
)
|
106 |
+
self.add_run(message_run)
|
107 |
+
|
108 |
+
langsmith_run = LangSmithRunModel(
|
109 |
+
file_list=trace_info.file_list,
|
110 |
+
total_tokens=trace_info.total_tokens,
|
111 |
+
id=trace_info.workflow_run_id,
|
112 |
+
name=TraceTaskName.WORKFLOW_TRACE.value,
|
113 |
+
inputs=dict(trace_info.workflow_run_inputs),
|
114 |
+
run_type=LangSmithRunType.tool,
|
115 |
+
start_time=trace_info.workflow_data.created_at,
|
116 |
+
end_time=trace_info.workflow_data.finished_at,
|
117 |
+
outputs=dict(trace_info.workflow_run_outputs),
|
118 |
+
extra={
|
119 |
+
"metadata": metadata,
|
120 |
+
},
|
121 |
+
error=trace_info.error,
|
122 |
+
tags=["workflow"],
|
123 |
+
parent_run_id=trace_info.message_id or None,
|
124 |
+
trace_id=trace_id,
|
125 |
+
dotted_order=workflow_dotted_order,
|
126 |
+
serialized=None,
|
127 |
+
events=[],
|
128 |
+
session_id=None,
|
129 |
+
session_name=None,
|
130 |
+
reference_example_id=None,
|
131 |
+
input_attachments={},
|
132 |
+
output_attachments={},
|
133 |
+
)
|
134 |
+
|
135 |
+
self.add_run(langsmith_run)
|
136 |
+
|
137 |
+
# through workflow_run_id get all_nodes_execution
|
138 |
+
workflow_nodes_execution_id_records = (
|
139 |
+
db.session.query(WorkflowNodeExecution.id)
|
140 |
+
.filter(WorkflowNodeExecution.workflow_run_id == trace_info.workflow_run_id)
|
141 |
+
.all()
|
142 |
+
)
|
143 |
+
|
144 |
+
for node_execution_id_record in workflow_nodes_execution_id_records:
|
145 |
+
node_execution = (
|
146 |
+
db.session.query(
|
147 |
+
WorkflowNodeExecution.id,
|
148 |
+
WorkflowNodeExecution.tenant_id,
|
149 |
+
WorkflowNodeExecution.app_id,
|
150 |
+
WorkflowNodeExecution.title,
|
151 |
+
WorkflowNodeExecution.node_type,
|
152 |
+
WorkflowNodeExecution.status,
|
153 |
+
WorkflowNodeExecution.inputs,
|
154 |
+
WorkflowNodeExecution.outputs,
|
155 |
+
WorkflowNodeExecution.created_at,
|
156 |
+
WorkflowNodeExecution.elapsed_time,
|
157 |
+
WorkflowNodeExecution.process_data,
|
158 |
+
WorkflowNodeExecution.execution_metadata,
|
159 |
+
)
|
160 |
+
.filter(WorkflowNodeExecution.id == node_execution_id_record.id)
|
161 |
+
.first()
|
162 |
+
)
|
163 |
+
|
164 |
+
if not node_execution:
|
165 |
+
continue
|
166 |
+
|
167 |
+
node_execution_id = node_execution.id
|
168 |
+
tenant_id = node_execution.tenant_id
|
169 |
+
app_id = node_execution.app_id
|
170 |
+
node_name = node_execution.title
|
171 |
+
node_type = node_execution.node_type
|
172 |
+
status = node_execution.status
|
173 |
+
if node_type == "llm":
|
174 |
+
inputs = (
|
175 |
+
json.loads(node_execution.process_data).get("prompts", {}) if node_execution.process_data else {}
|
176 |
+
)
|
177 |
+
else:
|
178 |
+
inputs = json.loads(node_execution.inputs) if node_execution.inputs else {}
|
179 |
+
outputs = json.loads(node_execution.outputs) if node_execution.outputs else {}
|
180 |
+
created_at = node_execution.created_at or datetime.now()
|
181 |
+
elapsed_time = node_execution.elapsed_time
|
182 |
+
finished_at = created_at + timedelta(seconds=elapsed_time)
|
183 |
+
|
184 |
+
execution_metadata = (
|
185 |
+
json.loads(node_execution.execution_metadata) if node_execution.execution_metadata else {}
|
186 |
+
)
|
187 |
+
node_total_tokens = execution_metadata.get("total_tokens", 0)
|
188 |
+
metadata = execution_metadata.copy()
|
189 |
+
metadata.update(
|
190 |
+
{
|
191 |
+
"workflow_run_id": trace_info.workflow_run_id,
|
192 |
+
"node_execution_id": node_execution_id,
|
193 |
+
"tenant_id": tenant_id,
|
194 |
+
"app_id": app_id,
|
195 |
+
"app_name": node_name,
|
196 |
+
"node_type": node_type,
|
197 |
+
"status": status,
|
198 |
+
}
|
199 |
+
)
|
200 |
+
|
201 |
+
process_data = json.loads(node_execution.process_data) if node_execution.process_data else {}
|
202 |
+
if process_data and process_data.get("model_mode") == "chat":
|
203 |
+
run_type = LangSmithRunType.llm
|
204 |
+
metadata.update(
|
205 |
+
{
|
206 |
+
"ls_provider": process_data.get("model_provider", ""),
|
207 |
+
"ls_model_name": process_data.get("model_name", ""),
|
208 |
+
}
|
209 |
+
)
|
210 |
+
elif node_type == "knowledge-retrieval":
|
211 |
+
run_type = LangSmithRunType.retriever
|
212 |
+
else:
|
213 |
+
run_type = LangSmithRunType.tool
|
214 |
+
|
215 |
+
node_dotted_order = generate_dotted_order(node_execution_id, created_at, workflow_dotted_order)
|
216 |
+
langsmith_run = LangSmithRunModel(
|
217 |
+
total_tokens=node_total_tokens,
|
218 |
+
name=node_type,
|
219 |
+
inputs=inputs,
|
220 |
+
run_type=run_type,
|
221 |
+
start_time=created_at,
|
222 |
+
end_time=finished_at,
|
223 |
+
outputs=outputs,
|
224 |
+
file_list=trace_info.file_list,
|
225 |
+
extra={
|
226 |
+
"metadata": metadata,
|
227 |
+
},
|
228 |
+
parent_run_id=trace_info.workflow_run_id,
|
229 |
+
tags=["node_execution"],
|
230 |
+
id=node_execution_id,
|
231 |
+
trace_id=trace_id,
|
232 |
+
dotted_order=node_dotted_order,
|
233 |
+
error="",
|
234 |
+
serialized=None,
|
235 |
+
events=[],
|
236 |
+
session_id=None,
|
237 |
+
session_name=None,
|
238 |
+
reference_example_id=None,
|
239 |
+
input_attachments={},
|
240 |
+
output_attachments={},
|
241 |
+
)
|
242 |
+
|
243 |
+
self.add_run(langsmith_run)
|
244 |
+
|
245 |
+
def message_trace(self, trace_info: MessageTraceInfo):
|
246 |
+
# get message file data
|
247 |
+
file_list = cast(list[str], trace_info.file_list) or []
|
248 |
+
message_file_data: Optional[MessageFile] = trace_info.message_file_data
|
249 |
+
file_url = f"{self.file_base_url}/{message_file_data.url}" if message_file_data else ""
|
250 |
+
file_list.append(file_url)
|
251 |
+
metadata = trace_info.metadata
|
252 |
+
message_data = trace_info.message_data
|
253 |
+
if message_data is None:
|
254 |
+
return
|
255 |
+
message_id = message_data.id
|
256 |
+
|
257 |
+
user_id = message_data.from_account_id
|
258 |
+
metadata["user_id"] = user_id
|
259 |
+
|
260 |
+
if message_data.from_end_user_id:
|
261 |
+
end_user_data: Optional[EndUser] = (
|
262 |
+
db.session.query(EndUser).filter(EndUser.id == message_data.from_end_user_id).first()
|
263 |
+
)
|
264 |
+
if end_user_data is not None:
|
265 |
+
end_user_id = end_user_data.session_id
|
266 |
+
metadata["end_user_id"] = end_user_id
|
267 |
+
|
268 |
+
message_run = LangSmithRunModel(
|
269 |
+
input_tokens=trace_info.message_tokens,
|
270 |
+
output_tokens=trace_info.answer_tokens,
|
271 |
+
total_tokens=trace_info.total_tokens,
|
272 |
+
id=message_id,
|
273 |
+
name=TraceTaskName.MESSAGE_TRACE.value,
|
274 |
+
inputs=trace_info.inputs,
|
275 |
+
run_type=LangSmithRunType.chain,
|
276 |
+
start_time=trace_info.start_time,
|
277 |
+
end_time=trace_info.end_time,
|
278 |
+
outputs=message_data.answer,
|
279 |
+
extra={"metadata": metadata},
|
280 |
+
tags=["message", str(trace_info.conversation_mode)],
|
281 |
+
error=trace_info.error,
|
282 |
+
file_list=file_list,
|
283 |
+
serialized=None,
|
284 |
+
events=[],
|
285 |
+
session_id=None,
|
286 |
+
session_name=None,
|
287 |
+
reference_example_id=None,
|
288 |
+
input_attachments={},
|
289 |
+
output_attachments={},
|
290 |
+
trace_id=None,
|
291 |
+
dotted_order=None,
|
292 |
+
parent_run_id=None,
|
293 |
+
)
|
294 |
+
self.add_run(message_run)
|
295 |
+
|
296 |
+
# create llm run parented to message run
|
297 |
+
llm_run = LangSmithRunModel(
|
298 |
+
input_tokens=trace_info.message_tokens,
|
299 |
+
output_tokens=trace_info.answer_tokens,
|
300 |
+
total_tokens=trace_info.total_tokens,
|
301 |
+
name="llm",
|
302 |
+
inputs=trace_info.inputs,
|
303 |
+
run_type=LangSmithRunType.llm,
|
304 |
+
start_time=trace_info.start_time,
|
305 |
+
end_time=trace_info.end_time,
|
306 |
+
outputs=message_data.answer,
|
307 |
+
extra={"metadata": metadata},
|
308 |
+
parent_run_id=message_id,
|
309 |
+
tags=["llm", str(trace_info.conversation_mode)],
|
310 |
+
error=trace_info.error,
|
311 |
+
file_list=file_list,
|
312 |
+
serialized=None,
|
313 |
+
events=[],
|
314 |
+
session_id=None,
|
315 |
+
session_name=None,
|
316 |
+
reference_example_id=None,
|
317 |
+
input_attachments={},
|
318 |
+
output_attachments={},
|
319 |
+
trace_id=None,
|
320 |
+
dotted_order=None,
|
321 |
+
id=str(uuid.uuid4()),
|
322 |
+
)
|
323 |
+
self.add_run(llm_run)
|
324 |
+
|
325 |
+
def moderation_trace(self, trace_info: ModerationTraceInfo):
|
326 |
+
if trace_info.message_data is None:
|
327 |
+
return
|
328 |
+
langsmith_run = LangSmithRunModel(
|
329 |
+
name=TraceTaskName.MODERATION_TRACE.value,
|
330 |
+
inputs=trace_info.inputs,
|
331 |
+
outputs={
|
332 |
+
"action": trace_info.action,
|
333 |
+
"flagged": trace_info.flagged,
|
334 |
+
"preset_response": trace_info.preset_response,
|
335 |
+
"inputs": trace_info.inputs,
|
336 |
+
},
|
337 |
+
run_type=LangSmithRunType.tool,
|
338 |
+
extra={"metadata": trace_info.metadata},
|
339 |
+
tags=["moderation"],
|
340 |
+
parent_run_id=trace_info.message_id,
|
341 |
+
start_time=trace_info.start_time or trace_info.message_data.created_at,
|
342 |
+
end_time=trace_info.end_time or trace_info.message_data.updated_at,
|
343 |
+
id=str(uuid.uuid4()),
|
344 |
+
serialized=None,
|
345 |
+
events=[],
|
346 |
+
session_id=None,
|
347 |
+
session_name=None,
|
348 |
+
reference_example_id=None,
|
349 |
+
input_attachments={},
|
350 |
+
output_attachments={},
|
351 |
+
trace_id=None,
|
352 |
+
dotted_order=None,
|
353 |
+
error="",
|
354 |
+
file_list=[],
|
355 |
+
)
|
356 |
+
|
357 |
+
self.add_run(langsmith_run)
|
358 |
+
|
359 |
+
def suggested_question_trace(self, trace_info: SuggestedQuestionTraceInfo):
|
360 |
+
message_data = trace_info.message_data
|
361 |
+
if message_data is None:
|
362 |
+
return
|
363 |
+
suggested_question_run = LangSmithRunModel(
|
364 |
+
name=TraceTaskName.SUGGESTED_QUESTION_TRACE.value,
|
365 |
+
inputs=trace_info.inputs,
|
366 |
+
outputs=trace_info.suggested_question,
|
367 |
+
run_type=LangSmithRunType.tool,
|
368 |
+
extra={"metadata": trace_info.metadata},
|
369 |
+
tags=["suggested_question"],
|
370 |
+
parent_run_id=trace_info.message_id,
|
371 |
+
start_time=trace_info.start_time or message_data.created_at,
|
372 |
+
end_time=trace_info.end_time or message_data.updated_at,
|
373 |
+
            id=str(uuid.uuid4()),
            serialized=None,
            events=[],
            session_id=None,
            session_name=None,
            reference_example_id=None,
            input_attachments={},
            output_attachments={},
            trace_id=None,
            dotted_order=None,
            error="",
            file_list=[],
        )

        self.add_run(suggested_question_run)

    def dataset_retrieval_trace(self, trace_info: DatasetRetrievalTraceInfo):
        if trace_info.message_data is None:
            return
        dataset_retrieval_run = LangSmithRunModel(
            name=TraceTaskName.DATASET_RETRIEVAL_TRACE.value,
            inputs=trace_info.inputs,
            outputs={"documents": trace_info.documents},
            run_type=LangSmithRunType.retriever,
            extra={"metadata": trace_info.metadata},
            tags=["dataset_retrieval"],
            parent_run_id=trace_info.message_id,
            start_time=trace_info.start_time or trace_info.message_data.created_at,
            end_time=trace_info.end_time or trace_info.message_data.updated_at,
            id=str(uuid.uuid4()),
            serialized=None,
            events=[],
            session_id=None,
            session_name=None,
            reference_example_id=None,
            input_attachments={},
            output_attachments={},
            trace_id=None,
            dotted_order=None,
            error="",
            file_list=[],
        )

        self.add_run(dataset_retrieval_run)

    def tool_trace(self, trace_info: ToolTraceInfo):
        tool_run = LangSmithRunModel(
            name=trace_info.tool_name,
            inputs=trace_info.tool_inputs,
            outputs=trace_info.tool_outputs,
            run_type=LangSmithRunType.tool,
            extra={
                "metadata": trace_info.metadata,
            },
            tags=["tool", trace_info.tool_name],
            parent_run_id=trace_info.message_id,
            start_time=trace_info.start_time,
            end_time=trace_info.end_time,
            file_list=[cast(str, trace_info.file_url)],
            id=str(uuid.uuid4()),
            serialized=None,
            events=[],
            session_id=None,
            session_name=None,
            reference_example_id=None,
            input_attachments={},
            output_attachments={},
            trace_id=None,
            dotted_order=None,
            error=trace_info.error or "",
        )

        self.add_run(tool_run)

    def generate_name_trace(self, trace_info: GenerateNameTraceInfo):
        name_run = LangSmithRunModel(
            name=TraceTaskName.GENERATE_NAME_TRACE.value,
            inputs=trace_info.inputs,
            outputs=trace_info.outputs,
            run_type=LangSmithRunType.tool,
            extra={"metadata": trace_info.metadata},
            tags=["generate_name"],
            start_time=trace_info.start_time or datetime.now(),
            end_time=trace_info.end_time or datetime.now(),
            id=str(uuid.uuid4()),
            serialized=None,
            events=[],
            session_id=None,
            session_name=None,
            reference_example_id=None,
            input_attachments={},
            output_attachments={},
            trace_id=None,
            dotted_order=None,
            error="",
            file_list=[],
            parent_run_id=None,
        )

        self.add_run(name_run)

    def add_run(self, run_data: LangSmithRunModel):
        data = run_data.model_dump()
        if self.project_id:
            data["session_id"] = self.project_id
        elif self.project_name:
            data["session_name"] = self.project_name

        data = filter_none_values(data)
        try:
            self.langsmith_client.create_run(**data)
            logger.debug("LangSmith Run created successfully.")
        except Exception as e:
            raise ValueError(f"LangSmith Failed to create run: {str(e)}")

    def update_run(self, update_run_data: LangSmithRunUpdateModel):
        data = update_run_data.model_dump()
        data = filter_none_values(data)
        try:
            self.langsmith_client.update_run(**data)
            logger.debug("LangSmith Run updated successfully.")
        except Exception as e:
            raise ValueError(f"LangSmith Failed to update run: {str(e)}")

    def api_check(self):
        try:
            random_project_name = f"test_project_{datetime.now().strftime('%Y%m%d%H%M%S')}"
            self.langsmith_client.create_project(project_name=random_project_name)
            self.langsmith_client.delete_project(project_name=random_project_name)
            return True
        except Exception as e:
            logger.debug(f"LangSmith API check failed: {str(e)}")
            raise ValueError(f"LangSmith API check failed: {str(e)}")

    def get_project_url(self):
        try:
            run_data = RunBase(
                id=uuid.uuid4(),
                name="tool",
                inputs={"input": "test"},
                outputs={"output": "test"},
                run_type=LangSmithRunType.tool,
                start_time=datetime.now(),
            )

            project_url = self.langsmith_client.get_run_url(
                run=run_data, project_id=self.project_id, project_name=self.project_name
            )
            return project_url.split("/r/")[0]
        except Exception as e:
            logger.debug(f"LangSmith get run url failed: {str(e)}")
            raise ValueError(f"LangSmith get run url failed: {str(e)}")
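For orientation, here is a minimal sketch of driving this LangSmith tracer directly. The credential and project values are placeholders; the constructor shape and config field names are assumptions taken from how provider_config_map in ops_trace_manager.py below instantiates trace classes (trace_instance(config_class(**config))).

# Hedged usage sketch -- placeholder credentials; config field names taken
# from provider_config_map ("api_key", "project", "endpoint").
from core.ops.entities.config_entity import LangSmithConfig
from core.ops.langsmith_trace.langsmith_trace import LangSmithDataTrace

config = LangSmithConfig(
    api_key="lsv2-placeholder",  # placeholder secret, not a real key
    project="dify-traces",       # placeholder project name
    endpoint="https://api.smith.langchain.com",
)
tracer = LangSmithDataTrace(config)
tracer.api_check()               # creates, then deletes, a throwaway project
print(tracer.get_project_url())  # project URL derived from a dummy RunBase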
api/core/ops/opik_trace/__init__.py
ADDED
File without changes
api/core/ops/opik_trace/opik_trace.py
ADDED
@@ -0,0 +1,469 @@
import json
import logging
import os
import uuid
from datetime import datetime, timedelta
from typing import Optional, cast

from opik import Opik, Trace
from opik.id_helpers import uuid4_to_uuid7

from core.ops.base_trace_instance import BaseTraceInstance
from core.ops.entities.config_entity import OpikConfig
from core.ops.entities.trace_entity import (
    BaseTraceInfo,
    DatasetRetrievalTraceInfo,
    GenerateNameTraceInfo,
    MessageTraceInfo,
    ModerationTraceInfo,
    SuggestedQuestionTraceInfo,
    ToolTraceInfo,
    TraceTaskName,
    WorkflowTraceInfo,
)
from extensions.ext_database import db
from models.model import EndUser, MessageFile
from models.workflow import WorkflowNodeExecution

logger = logging.getLogger(__name__)


def wrap_dict(key_name, data):
    """Make sure that the input data is a dict"""
    if not isinstance(data, dict):
        return {key_name: data}

    return data


def wrap_metadata(metadata, **kwargs):
    """Add common metadata to all Traces and Spans"""
    metadata["created_from"] = "dify"

    metadata.update(kwargs)

    return metadata


def prepare_opik_uuid(user_datetime: Optional[datetime], user_uuid: Optional[str]):
    """Opik needs UUIDv7, while Dify uses UUIDv4 as the identifier of most
    messages and objects. The type hints of BaseTraceInfo indicate that an
    object's start_time and message_id can be null, in which case we cannot
    map it to a UUIDv7. Given that we then have no way to identify the object
    uniquely, generate a new random UUIDv7 in that case.
    """

    if user_datetime is None:
        user_datetime = datetime.now()

    if user_uuid is None:
        user_uuid = str(uuid.uuid4())

    return uuid4_to_uuid7(user_datetime, user_uuid)


class OpikDataTrace(BaseTraceInstance):
    def __init__(
        self,
        opik_config: OpikConfig,
    ):
        super().__init__(opik_config)
        self.opik_client = Opik(
            project_name=opik_config.project,
            workspace=opik_config.workspace,
            host=opik_config.url,
            api_key=opik_config.api_key,
        )
        self.project = opik_config.project
        self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001")

    def trace(self, trace_info: BaseTraceInfo):
        if isinstance(trace_info, WorkflowTraceInfo):
            self.workflow_trace(trace_info)
        if isinstance(trace_info, MessageTraceInfo):
            self.message_trace(trace_info)
        if isinstance(trace_info, ModerationTraceInfo):
            self.moderation_trace(trace_info)
        if isinstance(trace_info, SuggestedQuestionTraceInfo):
            self.suggested_question_trace(trace_info)
        if isinstance(trace_info, DatasetRetrievalTraceInfo):
            self.dataset_retrieval_trace(trace_info)
        if isinstance(trace_info, ToolTraceInfo):
            self.tool_trace(trace_info)
        if isinstance(trace_info, GenerateNameTraceInfo):
            self.generate_name_trace(trace_info)

    def workflow_trace(self, trace_info: WorkflowTraceInfo):
        dify_trace_id = trace_info.workflow_run_id
        opik_trace_id = prepare_opik_uuid(trace_info.start_time, dify_trace_id)
        workflow_metadata = wrap_metadata(
            trace_info.metadata, message_id=trace_info.message_id, workflow_app_log_id=trace_info.workflow_app_log_id
        )
        root_span_id = None

        if trace_info.message_id:
            dify_trace_id = trace_info.message_id
            opik_trace_id = prepare_opik_uuid(trace_info.start_time, dify_trace_id)

            trace_data = {
                "id": opik_trace_id,
                "name": TraceTaskName.MESSAGE_TRACE.value,
                "start_time": trace_info.start_time,
                "end_time": trace_info.end_time,
                "metadata": workflow_metadata,
                "input": wrap_dict("input", trace_info.workflow_run_inputs),
                "output": wrap_dict("output", trace_info.workflow_run_outputs),
                "tags": ["message", "workflow"],
                "project_name": self.project,
            }
            self.add_trace(trace_data)

            root_span_id = prepare_opik_uuid(trace_info.start_time, trace_info.workflow_run_id)
            span_data = {
                "id": root_span_id,
                "parent_span_id": None,
                "trace_id": opik_trace_id,
                "name": TraceTaskName.WORKFLOW_TRACE.value,
                "input": wrap_dict("input", trace_info.workflow_run_inputs),
                "output": wrap_dict("output", trace_info.workflow_run_outputs),
                "start_time": trace_info.start_time,
                "end_time": trace_info.end_time,
                "metadata": workflow_metadata,
                "tags": ["workflow"],
                "project_name": self.project,
            }
            self.add_span(span_data)
        else:
            trace_data = {
                "id": opik_trace_id,
                "name": TraceTaskName.MESSAGE_TRACE.value,
                "start_time": trace_info.start_time,
                "end_time": trace_info.end_time,
                "metadata": workflow_metadata,
                "input": wrap_dict("input", trace_info.workflow_run_inputs),
                "output": wrap_dict("output", trace_info.workflow_run_outputs),
                "tags": ["workflow"],
                "project_name": self.project,
            }
            self.add_trace(trace_data)

        # fetch all node executions for this workflow_run_id
        workflow_nodes_execution_id_records = (
            db.session.query(WorkflowNodeExecution.id)
            .filter(WorkflowNodeExecution.workflow_run_id == trace_info.workflow_run_id)
            .all()
        )

        for node_execution_id_record in workflow_nodes_execution_id_records:
            node_execution = (
                db.session.query(
                    WorkflowNodeExecution.id,
                    WorkflowNodeExecution.tenant_id,
                    WorkflowNodeExecution.app_id,
                    WorkflowNodeExecution.title,
                    WorkflowNodeExecution.node_type,
                    WorkflowNodeExecution.status,
                    WorkflowNodeExecution.inputs,
                    WorkflowNodeExecution.outputs,
                    WorkflowNodeExecution.created_at,
                    WorkflowNodeExecution.elapsed_time,
                    WorkflowNodeExecution.process_data,
                    WorkflowNodeExecution.execution_metadata,
                )
                .filter(WorkflowNodeExecution.id == node_execution_id_record.id)
                .first()
            )

            if not node_execution:
                continue

            node_execution_id = node_execution.id
            tenant_id = node_execution.tenant_id
            app_id = node_execution.app_id
            node_name = node_execution.title
            node_type = node_execution.node_type
            status = node_execution.status
            if node_type == "llm":
                inputs = (
                    json.loads(node_execution.process_data).get("prompts", {}) if node_execution.process_data else {}
                )
            else:
                inputs = json.loads(node_execution.inputs) if node_execution.inputs else {}
            outputs = json.loads(node_execution.outputs) if node_execution.outputs else {}
            created_at = node_execution.created_at or datetime.now()
            elapsed_time = node_execution.elapsed_time
            finished_at = created_at + timedelta(seconds=elapsed_time)

            execution_metadata = (
                json.loads(node_execution.execution_metadata) if node_execution.execution_metadata else {}
            )
            metadata = execution_metadata.copy()
            metadata.update(
                {
                    "workflow_run_id": trace_info.workflow_run_id,
                    "node_execution_id": node_execution_id,
                    "tenant_id": tenant_id,
                    "app_id": app_id,
                    "app_name": node_name,
                    "node_type": node_type,
                    "status": status,
                }
            )

            process_data = json.loads(node_execution.process_data) if node_execution.process_data else {}

            provider = None
            model = None
            total_tokens = 0
            completion_tokens = 0
            prompt_tokens = 0

            if process_data and process_data.get("model_mode") == "chat":
                run_type = "llm"
                provider = process_data.get("model_provider", None)
                model = process_data.get("model_name", "")
                metadata.update(
                    {
                        "ls_provider": provider,
                        "ls_model_name": model,
                    }
                )

                try:
                    if outputs.get("usage"):
                        total_tokens = outputs["usage"].get("total_tokens", 0)
                        prompt_tokens = outputs["usage"].get("prompt_tokens", 0)
                        completion_tokens = outputs["usage"].get("completion_tokens", 0)
                except Exception:
                    logger.error("Failed to extract usage", exc_info=True)

            else:
                run_type = "tool"

            parent_span_id = trace_info.workflow_app_log_id or trace_info.workflow_run_id

            if not total_tokens:
                total_tokens = execution_metadata.get("total_tokens", 0)

            span_data = {
                "trace_id": opik_trace_id,
                "id": prepare_opik_uuid(created_at, node_execution_id),
                "parent_span_id": prepare_opik_uuid(trace_info.start_time, parent_span_id),
                "name": node_type,
                "type": run_type,
                "start_time": created_at,
                "end_time": finished_at,
                "metadata": wrap_metadata(metadata),
                "input": wrap_dict("input", inputs),
                "output": wrap_dict("output", outputs),
                "tags": ["node_execution"],
                "project_name": self.project,
                "usage": {
                    "total_tokens": total_tokens,
                    "completion_tokens": completion_tokens,
                    "prompt_tokens": prompt_tokens,
                },
                "model": model,
                "provider": provider,
            }

            self.add_span(span_data)

    def message_trace(self, trace_info: MessageTraceInfo):
        # get message file data
        file_list = cast(list[str], trace_info.file_list) or []
        message_file_data: Optional[MessageFile] = trace_info.message_file_data

        if message_file_data is not None:
            file_url = f"{self.file_base_url}/{message_file_data.url}" if message_file_data else ""
            file_list.append(file_url)

        message_data = trace_info.message_data
        if message_data is None:
            return

        metadata = trace_info.metadata
        message_id = trace_info.message_id

        user_id = message_data.from_account_id
        metadata["user_id"] = user_id
        metadata["file_list"] = file_list

        if message_data.from_end_user_id:
            end_user_data: Optional[EndUser] = (
                db.session.query(EndUser).filter(EndUser.id == message_data.from_end_user_id).first()
            )
            if end_user_data is not None:
                end_user_id = end_user_data.session_id
                metadata["end_user_id"] = end_user_id

        trace_data = {
            "id": prepare_opik_uuid(trace_info.start_time, message_id),
            "name": TraceTaskName.MESSAGE_TRACE.value,
            "start_time": trace_info.start_time,
            "end_time": trace_info.end_time,
            "metadata": wrap_metadata(metadata),
            "input": trace_info.inputs,
            "output": message_data.answer,
            "tags": ["message", str(trace_info.conversation_mode)],
            "project_name": self.project,
        }
        trace = self.add_trace(trace_data)

        span_data = {
            "trace_id": trace.id,
            "name": "llm",
            "type": "llm",
            "start_time": trace_info.start_time,
            "end_time": trace_info.end_time,
            "metadata": wrap_metadata(metadata),
            "input": {"input": trace_info.inputs},
            "output": {"output": message_data.answer},
            "tags": ["llm", str(trace_info.conversation_mode)],
            "usage": {
                "completion_tokens": trace_info.answer_tokens,
                "prompt_tokens": trace_info.message_tokens,
                "total_tokens": trace_info.total_tokens,
            },
            "project_name": self.project,
        }
        self.add_span(span_data)

    def moderation_trace(self, trace_info: ModerationTraceInfo):
        if trace_info.message_data is None:
            return

        start_time = trace_info.start_time or trace_info.message_data.created_at

        span_data = {
            "trace_id": prepare_opik_uuid(start_time, trace_info.message_id),
            "name": TraceTaskName.MODERATION_TRACE.value,
            "type": "tool",
            "start_time": start_time,
            "end_time": trace_info.end_time or trace_info.message_data.updated_at,
            "metadata": wrap_metadata(trace_info.metadata),
            "input": wrap_dict("input", trace_info.inputs),
            "output": {
                "action": trace_info.action,
                "flagged": trace_info.flagged,
                "preset_response": trace_info.preset_response,
                "inputs": trace_info.inputs,
            },
            "tags": ["moderation"],
        }

        self.add_span(span_data)

    def suggested_question_trace(self, trace_info: SuggestedQuestionTraceInfo):
        message_data = trace_info.message_data
        if message_data is None:
            return

        start_time = trace_info.start_time or message_data.created_at

        span_data = {
            "trace_id": prepare_opik_uuid(start_time, trace_info.message_id),
            "name": TraceTaskName.SUGGESTED_QUESTION_TRACE.value,
            "type": "tool",
            "start_time": start_time,
            "end_time": trace_info.end_time or message_data.updated_at,
            "metadata": wrap_metadata(trace_info.metadata),
            "input": wrap_dict("input", trace_info.inputs),
            "output": wrap_dict("output", trace_info.suggested_question),
            "tags": ["suggested_question"],
        }

        self.add_span(span_data)

    def dataset_retrieval_trace(self, trace_info: DatasetRetrievalTraceInfo):
        if trace_info.message_data is None:
            return

        start_time = trace_info.start_time or trace_info.message_data.created_at

        span_data = {
            "trace_id": prepare_opik_uuid(start_time, trace_info.message_id),
            "name": TraceTaskName.DATASET_RETRIEVAL_TRACE.value,
            "type": "tool",
            "start_time": start_time,
            "end_time": trace_info.end_time or trace_info.message_data.updated_at,
            "metadata": wrap_metadata(trace_info.metadata),
            "input": wrap_dict("input", trace_info.inputs),
            "output": {"documents": trace_info.documents},
            "tags": ["dataset_retrieval"],
        }

        self.add_span(span_data)

    def tool_trace(self, trace_info: ToolTraceInfo):
        span_data = {
            "trace_id": prepare_opik_uuid(trace_info.start_time, trace_info.message_id),
            "name": trace_info.tool_name,
            "type": "tool",
            "start_time": trace_info.start_time,
            "end_time": trace_info.end_time,
            "metadata": wrap_metadata(trace_info.metadata),
            "input": wrap_dict("input", trace_info.tool_inputs),
            "output": wrap_dict("output", trace_info.tool_outputs),
            "tags": ["tool", trace_info.tool_name],
        }

        self.add_span(span_data)

    def generate_name_trace(self, trace_info: GenerateNameTraceInfo):
        trace_data = {
            "id": prepare_opik_uuid(trace_info.start_time, trace_info.message_id),
            "name": TraceTaskName.GENERATE_NAME_TRACE.value,
            "start_time": trace_info.start_time,
            "end_time": trace_info.end_time,
            "metadata": wrap_metadata(trace_info.metadata),
            "input": trace_info.inputs,
            "output": trace_info.outputs,
            "tags": ["generate_name"],
            "project_name": self.project,
        }

        trace = self.add_trace(trace_data)

        span_data = {
            "trace_id": trace.id,
            "name": TraceTaskName.GENERATE_NAME_TRACE.value,
            "start_time": trace_info.start_time,
            "end_time": trace_info.end_time,
            "metadata": wrap_metadata(trace_info.metadata),
            "input": wrap_dict("input", trace_info.inputs),
            "output": wrap_dict("output", trace_info.outputs),
            "tags": ["generate_name"],
        }

        self.add_span(span_data)

    def add_trace(self, opik_trace_data: dict) -> Trace:
        try:
            trace = self.opik_client.trace(**opik_trace_data)
            logger.debug("Opik Trace created successfully")
            return trace
        except Exception as e:
            raise ValueError(f"Opik Failed to create trace: {str(e)}")

    def add_span(self, opik_span_data: dict):
        try:
            self.opik_client.span(**opik_span_data)
            logger.debug("Opik Span created successfully")
        except Exception as e:
            raise ValueError(f"Opik Failed to create span: {str(e)}")

    def api_check(self):
        try:
            self.opik_client.auth_check()
            return True
        except Exception as e:
            logger.info(f"Opik API check failed: {str(e)}", exc_info=True)
            raise ValueError(f"Opik API check failed: {str(e)}")

    def get_project_url(self):
        try:
            return self.opik_client.get_project_url(project_name=self.project)
        except Exception as e:
            logger.info(f"Opik get run url failed: {str(e)}", exc_info=True)
            raise ValueError(f"Opik get run url failed: {str(e)}")
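Note that workflow_trace derives the same span id twice: root_span_id and each node span's parent_span_id both come from the pair (trace_info.start_time, workflow_run_id). The parent/child wiring therefore relies on prepare_opik_uuid being a deterministic mapping, which in turn assumes uuid4_to_uuid7 is deterministic for a fixed (datetime, uuid4) input. A small sketch of that assumed property; the concrete timestamp and id are arbitrary examples:

# Determinism sketch; relies on the assumption stated above.
import uuid
from datetime import datetime

from core.ops.opik_trace.opik_trace import prepare_opik_uuid

ts = datetime(2025, 1, 1, 12, 0, 0)   # arbitrary example timestamp
dify_id = str(uuid.uuid4())           # a Dify-style UUIDv4 identifier
# Two independent calls with the same inputs must agree, otherwise the root
# workflow span and the node spans' parent_span_id would diverge.
assert prepare_opik_uuid(ts, dify_id) == prepare_opik_uuid(ts, dify_id)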
api/core/ops/ops_trace_manager.py
ADDED
@@ -0,0 +1,811 @@
1 |
+
import json
|
2 |
+
import logging
|
3 |
+
import os
|
4 |
+
import queue
|
5 |
+
import threading
|
6 |
+
import time
|
7 |
+
from datetime import timedelta
|
8 |
+
from typing import Any, Optional, Union
|
9 |
+
from uuid import UUID, uuid4
|
10 |
+
|
11 |
+
from flask import current_app
|
12 |
+
from sqlalchemy import select
|
13 |
+
from sqlalchemy.orm import Session
|
14 |
+
|
15 |
+
from core.helper.encrypter import decrypt_token, encrypt_token, obfuscated_token
|
16 |
+
from core.ops.entities.config_entity import (
|
17 |
+
OPS_FILE_PATH,
|
18 |
+
LangfuseConfig,
|
19 |
+
LangSmithConfig,
|
20 |
+
OpikConfig,
|
21 |
+
TracingProviderEnum,
|
22 |
+
)
|
23 |
+
from core.ops.entities.trace_entity import (
|
24 |
+
DatasetRetrievalTraceInfo,
|
25 |
+
GenerateNameTraceInfo,
|
26 |
+
MessageTraceInfo,
|
27 |
+
ModerationTraceInfo,
|
28 |
+
SuggestedQuestionTraceInfo,
|
29 |
+
TaskData,
|
30 |
+
ToolTraceInfo,
|
31 |
+
TraceTaskName,
|
32 |
+
WorkflowTraceInfo,
|
33 |
+
)
|
34 |
+
from core.ops.langfuse_trace.langfuse_trace import LangFuseDataTrace
|
35 |
+
from core.ops.langsmith_trace.langsmith_trace import LangSmithDataTrace
|
36 |
+
from core.ops.opik_trace.opik_trace import OpikDataTrace
|
37 |
+
from core.ops.utils import get_message_data
|
38 |
+
from extensions.ext_database import db
|
39 |
+
from extensions.ext_storage import storage
|
40 |
+
from models.model import App, AppModelConfig, Conversation, Message, MessageFile, TraceAppConfig
|
41 |
+
from models.workflow import WorkflowAppLog, WorkflowRun
|
42 |
+
from tasks.ops_trace_task import process_trace_tasks
|
43 |
+
|
44 |
+
provider_config_map: dict[str, dict[str, Any]] = {
|
45 |
+
TracingProviderEnum.LANGFUSE.value: {
|
46 |
+
"config_class": LangfuseConfig,
|
47 |
+
"secret_keys": ["public_key", "secret_key"],
|
48 |
+
"other_keys": ["host", "project_key"],
|
49 |
+
"trace_instance": LangFuseDataTrace,
|
50 |
+
},
|
51 |
+
TracingProviderEnum.LANGSMITH.value: {
|
52 |
+
"config_class": LangSmithConfig,
|
53 |
+
"secret_keys": ["api_key"],
|
54 |
+
"other_keys": ["project", "endpoint"],
|
55 |
+
"trace_instance": LangSmithDataTrace,
|
56 |
+
},
|
57 |
+
TracingProviderEnum.OPIK.value: {
|
58 |
+
"config_class": OpikConfig,
|
59 |
+
"secret_keys": ["api_key"],
|
60 |
+
"other_keys": ["project", "url", "workspace"],
|
61 |
+
"trace_instance": OpikDataTrace,
|
62 |
+
},
|
63 |
+
}
|
64 |
+
|
65 |
+
|
66 |
+
class OpsTraceManager:
|
67 |
+
@classmethod
|
68 |
+
def encrypt_tracing_config(
|
69 |
+
cls, tenant_id: str, tracing_provider: str, tracing_config: dict, current_trace_config=None
|
70 |
+
):
|
71 |
+
"""
|
72 |
+
Encrypt tracing config.
|
73 |
+
:param tenant_id: tenant id
|
74 |
+
:param tracing_provider: tracing provider
|
75 |
+
:param tracing_config: tracing config dictionary to be encrypted
|
76 |
+
:param current_trace_config: current tracing configuration for keeping existing values
|
77 |
+
:return: encrypted tracing configuration
|
78 |
+
"""
|
79 |
+
# Get the configuration class and the keys that require encryption
|
80 |
+
config_class, secret_keys, other_keys = (
|
81 |
+
provider_config_map[tracing_provider]["config_class"],
|
82 |
+
provider_config_map[tracing_provider]["secret_keys"],
|
83 |
+
provider_config_map[tracing_provider]["other_keys"],
|
84 |
+
)
|
85 |
+
|
86 |
+
new_config = {}
|
87 |
+
# Encrypt necessary keys
|
88 |
+
for key in secret_keys:
|
89 |
+
if key in tracing_config:
|
90 |
+
if "*" in tracing_config[key]:
|
91 |
+
# If the key contains '*', retain the original value from the current config
|
92 |
+
new_config[key] = current_trace_config.get(key, tracing_config[key])
|
93 |
+
else:
|
94 |
+
# Otherwise, encrypt the key
|
95 |
+
new_config[key] = encrypt_token(tenant_id, tracing_config[key])
|
96 |
+
|
97 |
+
for key in other_keys:
|
98 |
+
new_config[key] = tracing_config.get(key, "")
|
99 |
+
|
100 |
+
# Create a new instance of the config class with the new configuration
|
101 |
+
encrypted_config = config_class(**new_config)
|
102 |
+
return encrypted_config.model_dump()
|
103 |
+
|
104 |
+
@classmethod
|
105 |
+
def decrypt_tracing_config(cls, tenant_id: str, tracing_provider: str, tracing_config: dict):
|
106 |
+
"""
|
107 |
+
Decrypt tracing config
|
108 |
+
:param tenant_id: tenant id
|
109 |
+
:param tracing_provider: tracing provider
|
110 |
+
:param tracing_config: tracing config
|
111 |
+
:return:
|
112 |
+
"""
|
113 |
+
config_class, secret_keys, other_keys = (
|
114 |
+
provider_config_map[tracing_provider]["config_class"],
|
115 |
+
provider_config_map[tracing_provider]["secret_keys"],
|
116 |
+
provider_config_map[tracing_provider]["other_keys"],
|
117 |
+
)
|
118 |
+
new_config = {}
|
119 |
+
for key in secret_keys:
|
120 |
+
if key in tracing_config:
|
121 |
+
new_config[key] = decrypt_token(tenant_id, tracing_config[key])
|
122 |
+
|
123 |
+
for key in other_keys:
|
124 |
+
new_config[key] = tracing_config.get(key, "")
|
125 |
+
|
126 |
+
return config_class(**new_config).model_dump()
|
127 |
+
|
128 |
+
@classmethod
|
129 |
+
def obfuscated_decrypt_token(cls, tracing_provider: str, decrypt_tracing_config: dict):
|
130 |
+
"""
|
131 |
+
Decrypt tracing config
|
132 |
+
:param tracing_provider: tracing provider
|
133 |
+
:param decrypt_tracing_config: tracing config
|
134 |
+
:return:
|
135 |
+
"""
|
136 |
+
config_class, secret_keys, other_keys = (
|
137 |
+
provider_config_map[tracing_provider]["config_class"],
|
138 |
+
provider_config_map[tracing_provider]["secret_keys"],
|
139 |
+
provider_config_map[tracing_provider]["other_keys"],
|
140 |
+
)
|
141 |
+
new_config = {}
|
142 |
+
for key in secret_keys:
|
143 |
+
if key in decrypt_tracing_config:
|
144 |
+
new_config[key] = obfuscated_token(decrypt_tracing_config[key])
|
145 |
+
|
146 |
+
for key in other_keys:
|
147 |
+
new_config[key] = decrypt_tracing_config.get(key, "")
|
148 |
+
return config_class(**new_config).model_dump()
|
149 |
+
|
150 |
+
@classmethod
|
151 |
+
def get_decrypted_tracing_config(cls, app_id: str, tracing_provider: str):
|
152 |
+
"""
|
153 |
+
Get decrypted tracing config
|
154 |
+
:param app_id: app id
|
155 |
+
:param tracing_provider: tracing provider
|
156 |
+
:return:
|
157 |
+
"""
|
158 |
+
trace_config_data: Optional[TraceAppConfig] = (
|
159 |
+
db.session.query(TraceAppConfig)
|
160 |
+
.filter(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider)
|
161 |
+
.first()
|
162 |
+
)
|
163 |
+
|
164 |
+
if not trace_config_data:
|
165 |
+
return None
|
166 |
+
|
167 |
+
# decrypt_token
|
168 |
+
app = db.session.query(App).filter(App.id == app_id).first()
|
169 |
+
if not app:
|
170 |
+
raise ValueError("App not found")
|
171 |
+
|
172 |
+
tenant_id = app.tenant_id
|
173 |
+
decrypt_tracing_config = cls.decrypt_tracing_config(
|
174 |
+
tenant_id, tracing_provider, trace_config_data.tracing_config
|
175 |
+
)
|
176 |
+
|
177 |
+
return decrypt_tracing_config
|
178 |
+
|
179 |
+
@classmethod
|
180 |
+
def get_ops_trace_instance(
|
181 |
+
cls,
|
182 |
+
app_id: Optional[Union[UUID, str]] = None,
|
183 |
+
):
|
184 |
+
"""
|
185 |
+
Get ops trace through model config
|
186 |
+
:param app_id: app_id
|
187 |
+
:return:
|
188 |
+
"""
|
189 |
+
if isinstance(app_id, UUID):
|
190 |
+
app_id = str(app_id)
|
191 |
+
|
192 |
+
if app_id is None:
|
193 |
+
return None
|
194 |
+
|
195 |
+
app: Optional[App] = db.session.query(App).filter(App.id == app_id).first()
|
196 |
+
|
197 |
+
if app is None:
|
198 |
+
return None
|
199 |
+
|
200 |
+
app_ops_trace_config = json.loads(app.tracing) if app.tracing else None
|
201 |
+
|
202 |
+
if app_ops_trace_config is None:
|
203 |
+
return None
|
204 |
+
|
205 |
+
tracing_provider = app_ops_trace_config.get("tracing_provider")
|
206 |
+
|
207 |
+
if tracing_provider is None or tracing_provider not in provider_config_map:
|
208 |
+
return None
|
209 |
+
|
210 |
+
# decrypt_token
|
211 |
+
decrypt_trace_config = cls.get_decrypted_tracing_config(app_id, tracing_provider)
|
212 |
+
if app_ops_trace_config.get("enabled"):
|
213 |
+
trace_instance, config_class = (
|
214 |
+
provider_config_map[tracing_provider]["trace_instance"],
|
215 |
+
provider_config_map[tracing_provider]["config_class"],
|
216 |
+
)
|
217 |
+
tracing_instance = trace_instance(config_class(**decrypt_trace_config))
|
218 |
+
return tracing_instance
|
219 |
+
|
220 |
+
return None
|
221 |
+
|
222 |
+
@classmethod
|
223 |
+
def get_app_config_through_message_id(cls, message_id: str):
|
224 |
+
app_model_config = None
|
225 |
+
message_data = db.session.query(Message).filter(Message.id == message_id).first()
|
226 |
+
if not message_data:
|
227 |
+
return None
|
228 |
+
conversation_id = message_data.conversation_id
|
229 |
+
conversation_data = db.session.query(Conversation).filter(Conversation.id == conversation_id).first()
|
230 |
+
if not conversation_data:
|
231 |
+
return None
|
232 |
+
|
233 |
+
if conversation_data.app_model_config_id:
|
234 |
+
app_model_config = (
|
235 |
+
db.session.query(AppModelConfig)
|
236 |
+
.filter(AppModelConfig.id == conversation_data.app_model_config_id)
|
237 |
+
.first()
|
238 |
+
)
|
239 |
+
elif conversation_data.app_model_config_id is None and conversation_data.override_model_configs:
|
240 |
+
app_model_config = conversation_data.override_model_configs
|
241 |
+
|
242 |
+
return app_model_config
|
243 |
+
|
244 |
+
@classmethod
|
245 |
+
def update_app_tracing_config(cls, app_id: str, enabled: bool, tracing_provider: str):
|
246 |
+
"""
|
247 |
+
Update app tracing config
|
248 |
+
:param app_id: app id
|
249 |
+
:param enabled: enabled
|
250 |
+
:param tracing_provider: tracing provider
|
251 |
+
:return:
|
252 |
+
"""
|
253 |
+
# auth check
|
254 |
+
if tracing_provider not in provider_config_map and tracing_provider is not None:
|
255 |
+
raise ValueError(f"Invalid tracing provider: {tracing_provider}")
|
256 |
+
|
257 |
+
app_config: Optional[App] = db.session.query(App).filter(App.id == app_id).first()
|
258 |
+
if not app_config:
|
259 |
+
raise ValueError("App not found")
|
260 |
+
app_config.tracing = json.dumps(
|
261 |
+
{
|
262 |
+
"enabled": enabled,
|
263 |
+
"tracing_provider": tracing_provider,
|
264 |
+
}
|
265 |
+
)
|
266 |
+
db.session.commit()
|
267 |
+
|
268 |
+
@classmethod
|
269 |
+
def get_app_tracing_config(cls, app_id: str):
|
270 |
+
"""
|
271 |
+
Get app tracing config
|
272 |
+
:param app_id: app id
|
273 |
+
:return:
|
274 |
+
"""
|
275 |
+
app: Optional[App] = db.session.query(App).filter(App.id == app_id).first()
|
276 |
+
if not app:
|
277 |
+
raise ValueError("App not found")
|
278 |
+
if not app.tracing:
|
279 |
+
return {"enabled": False, "tracing_provider": None}
|
280 |
+
app_trace_config = json.loads(app.tracing)
|
281 |
+
return app_trace_config
|
282 |
+
|
283 |
+
@staticmethod
|
284 |
+
def check_trace_config_is_effective(tracing_config: dict, tracing_provider: str):
|
285 |
+
"""
|
286 |
+
Check trace config is effective
|
287 |
+
:param tracing_config: tracing config
|
288 |
+
:param tracing_provider: tracing provider
|
289 |
+
:return:
|
290 |
+
"""
|
291 |
+
config_type, trace_instance = (
|
292 |
+
provider_config_map[tracing_provider]["config_class"],
|
293 |
+
provider_config_map[tracing_provider]["trace_instance"],
|
294 |
+
)
|
295 |
+
tracing_config = config_type(**tracing_config)
|
296 |
+
return trace_instance(tracing_config).api_check()
|
297 |
+
|
298 |
+
@staticmethod
|
299 |
+
def get_trace_config_project_key(tracing_config: dict, tracing_provider: str):
|
300 |
+
"""
|
301 |
+
get trace config is project key
|
302 |
+
:param tracing_config: tracing config
|
303 |
+
:param tracing_provider: tracing provider
|
304 |
+
:return:
|
305 |
+
"""
|
306 |
+
config_type, trace_instance = (
|
307 |
+
provider_config_map[tracing_provider]["config_class"],
|
308 |
+
provider_config_map[tracing_provider]["trace_instance"],
|
309 |
+
)
|
310 |
+
tracing_config = config_type(**tracing_config)
|
311 |
+
return trace_instance(tracing_config).get_project_key()
|
312 |
+
|
313 |
+
@staticmethod
|
314 |
+
def get_trace_config_project_url(tracing_config: dict, tracing_provider: str):
|
315 |
+
"""
|
316 |
+
get trace config is project key
|
317 |
+
:param tracing_config: tracing config
|
318 |
+
:param tracing_provider: tracing provider
|
319 |
+
:return:
|
320 |
+
"""
|
321 |
+
config_type, trace_instance = (
|
322 |
+
provider_config_map[tracing_provider]["config_class"],
|
323 |
+
provider_config_map[tracing_provider]["trace_instance"],
|
324 |
+
)
|
325 |
+
tracing_config = config_type(**tracing_config)
|
326 |
+
return trace_instance(tracing_config).get_project_url()
|
327 |
+
|
328 |
+
|
329 |
+
class TraceTask:
|
330 |
+
def __init__(
|
331 |
+
self,
|
332 |
+
trace_type: Any,
|
333 |
+
message_id: Optional[str] = None,
|
334 |
+
workflow_run: Optional[WorkflowRun] = None,
|
335 |
+
conversation_id: Optional[str] = None,
|
336 |
+
user_id: Optional[str] = None,
|
337 |
+
timer: Optional[Any] = None,
|
338 |
+
**kwargs,
|
339 |
+
):
|
340 |
+
self.trace_type = trace_type
|
341 |
+
self.message_id = message_id
|
342 |
+
self.workflow_run_id = workflow_run.id if workflow_run else None
|
343 |
+
self.conversation_id = conversation_id
|
344 |
+
self.user_id = user_id
|
345 |
+
self.timer = timer
|
346 |
+
self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001")
|
347 |
+
self.app_id = None
|
348 |
+
|
349 |
+
self.kwargs = kwargs
|
350 |
+
|
351 |
+
def execute(self):
|
352 |
+
return self.preprocess()
|
353 |
+
|
354 |
+
def preprocess(self):
|
355 |
+
preprocess_map = {
|
356 |
+
TraceTaskName.CONVERSATION_TRACE: lambda: self.conversation_trace(**self.kwargs),
|
357 |
+
TraceTaskName.WORKFLOW_TRACE: lambda: self.workflow_trace(
|
358 |
+
workflow_run_id=self.workflow_run_id, conversation_id=self.conversation_id, user_id=self.user_id
|
359 |
+
),
|
360 |
+
TraceTaskName.MESSAGE_TRACE: lambda: self.message_trace(message_id=self.message_id),
|
361 |
+
TraceTaskName.MODERATION_TRACE: lambda: self.moderation_trace(
|
362 |
+
message_id=self.message_id, timer=self.timer, **self.kwargs
|
363 |
+
),
|
364 |
+
TraceTaskName.SUGGESTED_QUESTION_TRACE: lambda: self.suggested_question_trace(
|
365 |
+
message_id=self.message_id, timer=self.timer, **self.kwargs
|
366 |
+
),
|
367 |
+
TraceTaskName.DATASET_RETRIEVAL_TRACE: lambda: self.dataset_retrieval_trace(
|
368 |
+
message_id=self.message_id, timer=self.timer, **self.kwargs
|
369 |
+
),
|
370 |
+
TraceTaskName.TOOL_TRACE: lambda: self.tool_trace(
|
371 |
+
message_id=self.message_id, timer=self.timer, **self.kwargs
|
372 |
+
),
|
373 |
+
TraceTaskName.GENERATE_NAME_TRACE: lambda: self.generate_name_trace(
|
374 |
+
conversation_id=self.conversation_id, timer=self.timer, **self.kwargs
|
375 |
+
),
|
376 |
+
}
|
377 |
+
|
378 |
+
return preprocess_map.get(self.trace_type, lambda: None)()
|
379 |
+
|
380 |
+
# process methods for different trace types
|
381 |
+
def conversation_trace(self, **kwargs):
|
382 |
+
return kwargs
|
383 |
+
|
384 |
+
def workflow_trace(
|
385 |
+
self,
|
386 |
+
*,
|
387 |
+
workflow_run_id: str | None,
|
388 |
+
conversation_id: str | None,
|
389 |
+
user_id: str | None,
|
390 |
+
):
|
391 |
+
if not workflow_run_id:
|
392 |
+
return {}
|
393 |
+
|
394 |
+
with Session(db.engine) as session:
|
395 |
+
workflow_run_stmt = select(WorkflowRun).where(WorkflowRun.id == workflow_run_id)
|
396 |
+
workflow_run = session.scalars(workflow_run_stmt).first()
|
397 |
+
if not workflow_run:
|
398 |
+
raise ValueError("Workflow run not found")
|
399 |
+
|
400 |
+
workflow_id = workflow_run.workflow_id
|
401 |
+
tenant_id = workflow_run.tenant_id
|
402 |
+
workflow_run_id = workflow_run.id
|
403 |
+
workflow_run_elapsed_time = workflow_run.elapsed_time
|
404 |
+
workflow_run_status = workflow_run.status
|
405 |
+
workflow_run_inputs = workflow_run.inputs_dict
|
406 |
+
workflow_run_outputs = workflow_run.outputs_dict
|
407 |
+
workflow_run_version = workflow_run.version
|
408 |
+
error = workflow_run.error or ""
|
409 |
+
|
410 |
+
total_tokens = workflow_run.total_tokens
|
411 |
+
|
412 |
+
file_list = workflow_run_inputs.get("sys.file") or []
|
413 |
+
query = workflow_run_inputs.get("query") or workflow_run_inputs.get("sys.query") or ""
|
414 |
+
|
415 |
+
# get workflow_app_log_id
|
416 |
+
workflow_app_log_data_stmt = select(WorkflowAppLog.id).where(
|
417 |
+
WorkflowAppLog.tenant_id == tenant_id,
|
418 |
+
WorkflowAppLog.app_id == workflow_run.app_id,
|
419 |
+
WorkflowAppLog.workflow_run_id == workflow_run.id,
|
420 |
+
)
|
421 |
+
workflow_app_log_id = session.scalar(workflow_app_log_data_stmt)
|
422 |
+
# get message_id
|
423 |
+
message_id = None
|
424 |
+
if conversation_id:
|
425 |
+
message_data_stmt = select(Message.id).where(
|
426 |
+
Message.conversation_id == conversation_id,
|
427 |
+
Message.workflow_run_id == workflow_run_id,
|
428 |
+
)
|
429 |
+
message_id = session.scalar(message_data_stmt)
|
430 |
+
|
431 |
+
metadata = {
|
432 |
+
"workflow_id": workflow_id,
|
433 |
+
"conversation_id": conversation_id,
|
434 |
+
"workflow_run_id": workflow_run_id,
|
435 |
+
"tenant_id": tenant_id,
|
436 |
+
"elapsed_time": workflow_run_elapsed_time,
|
437 |
+
"status": workflow_run_status,
|
438 |
+
"version": workflow_run_version,
|
439 |
+
"total_tokens": total_tokens,
|
440 |
+
"file_list": file_list,
|
441 |
+
"triggered_form": workflow_run.triggered_from,
|
442 |
+
"user_id": user_id,
|
443 |
+
}
|
444 |
+
|
445 |
+
workflow_trace_info = WorkflowTraceInfo(
|
446 |
+
workflow_data=workflow_run.to_dict(),
|
447 |
+
conversation_id=conversation_id,
|
448 |
+
workflow_id=workflow_id,
|
449 |
+
tenant_id=tenant_id,
|
450 |
+
workflow_run_id=workflow_run_id,
|
451 |
+
workflow_run_elapsed_time=workflow_run_elapsed_time,
|
452 |
+
workflow_run_status=workflow_run_status,
|
453 |
+
workflow_run_inputs=workflow_run_inputs,
|
454 |
+
workflow_run_outputs=workflow_run_outputs,
|
455 |
+
workflow_run_version=workflow_run_version,
|
456 |
+
error=error,
|
457 |
+
total_tokens=total_tokens,
|
458 |
+
file_list=file_list,
|
459 |
+
query=query,
|
460 |
+
metadata=metadata,
|
461 |
+
workflow_app_log_id=workflow_app_log_id,
|
462 |
+
message_id=message_id,
|
463 |
+
start_time=workflow_run.created_at,
|
464 |
+
end_time=workflow_run.finished_at,
|
465 |
+
)
|
466 |
+
return workflow_trace_info
|
467 |
+
|
468 |
+
def message_trace(self, message_id: str | None):
|
469 |
+
if not message_id:
|
470 |
+
return {}
|
471 |
+
message_data = get_message_data(message_id)
|
472 |
+
if not message_data:
|
473 |
+
return {}
|
474 |
+
conversation_mode_stmt = select(Conversation.mode).where(Conversation.id == message_data.conversation_id)
|
475 |
+
conversation_mode = db.session.scalars(conversation_mode_stmt).all()
|
476 |
+
if not conversation_mode or len(conversation_mode) == 0:
|
477 |
+
return {}
|
478 |
+
conversation_mode = conversation_mode[0]
|
479 |
+
created_at = message_data.created_at
|
480 |
+
inputs = message_data.message
|
481 |
+
|
482 |
+
# get message file data
|
483 |
+
message_file_data = db.session.query(MessageFile).filter_by(message_id=message_id).first()
|
484 |
+
file_list = []
|
485 |
+
if message_file_data and message_file_data.url is not None:
|
486 |
+
file_url = f"{self.file_base_url}/{message_file_data.url}" if message_file_data else ""
|
487 |
+
file_list.append(file_url)
|
488 |
+
|
489 |
+
metadata = {
|
490 |
+
"conversation_id": message_data.conversation_id,
|
491 |
+
"ls_provider": message_data.model_provider,
|
492 |
+
"ls_model_name": message_data.model_id,
|
493 |
+
"status": message_data.status,
|
494 |
+
"from_end_user_id": message_data.from_end_user_id,
|
495 |
+
"from_account_id": message_data.from_account_id,
|
496 |
+
"agent_based": message_data.agent_based,
|
497 |
+
"workflow_run_id": message_data.workflow_run_id,
|
498 |
+
"from_source": message_data.from_source,
|
499 |
+
"message_id": message_id,
|
500 |
+
}
|
501 |
+
|
502 |
+
message_tokens = message_data.message_tokens
|
503 |
+
|
504 |
+
message_trace_info = MessageTraceInfo(
|
505 |
+
message_id=message_id,
|
506 |
+
message_data=message_data.to_dict(),
|
507 |
+
conversation_model=conversation_mode,
|
508 |
+
message_tokens=message_tokens,
|
509 |
+
answer_tokens=message_data.answer_tokens,
|
510 |
+
total_tokens=message_tokens + message_data.answer_tokens,
|
511 |
+
error=message_data.error or "",
|
512 |
+
inputs=inputs,
|
513 |
+
outputs=message_data.answer,
|
514 |
+
file_list=file_list,
|
515 |
+
start_time=created_at,
|
516 |
+
end_time=created_at + timedelta(seconds=message_data.provider_response_latency),
|
517 |
+
metadata=metadata,
|
518 |
+
message_file_data=message_file_data,
|
519 |
+
conversation_mode=conversation_mode,
|
520 |
+
)
|
521 |
+
|
522 |
+
return message_trace_info
|
523 |
+
|
524 |
+
def moderation_trace(self, message_id, timer, **kwargs):
|
525 |
+
moderation_result = kwargs.get("moderation_result")
|
526 |
+
if not moderation_result:
|
527 |
+
return {}
|
528 |
+
inputs = kwargs.get("inputs")
|
529 |
+
message_data = get_message_data(message_id)
|
530 |
+
if not message_data:
|
531 |
+
return {}
|
532 |
+
metadata = {
|
533 |
+
"message_id": message_id,
|
534 |
+
"action": moderation_result.action,
|
535 |
+
"preset_response": moderation_result.preset_response,
|
536 |
+
"query": moderation_result.query,
|
537 |
+
}
|
538 |
+
|
539 |
+
# get workflow_app_log_id
|
540 |
+
workflow_app_log_id = None
|
541 |
+
if message_data.workflow_run_id:
|
542 |
+
workflow_app_log_data = (
|
543 |
+
db.session.query(WorkflowAppLog).filter_by(workflow_run_id=message_data.workflow_run_id).first()
|
544 |
+
)
|
545 |
+
workflow_app_log_id = str(workflow_app_log_data.id) if workflow_app_log_data else None
|
546 |
+
|
547 |
+
moderation_trace_info = ModerationTraceInfo(
|
548 |
+
message_id=workflow_app_log_id or message_id,
|
549 |
+
inputs=inputs,
|
550 |
+
message_data=message_data.to_dict(),
|
551 |
+
flagged=moderation_result.flagged,
|
552 |
+
action=moderation_result.action,
|
553 |
+
preset_response=moderation_result.preset_response,
|
554 |
+
query=moderation_result.query,
|
555 |
+
start_time=timer.get("start"),
|
556 |
+
end_time=timer.get("end"),
|
557 |
+
metadata=metadata,
|
558 |
+
)
|
559 |
+
|
560 |
+
return moderation_trace_info
|
561 |
+
|
562 |
+
def suggested_question_trace(self, message_id, timer, **kwargs):
|
563 |
+
suggested_question = kwargs.get("suggested_question", [])
|
564 |
+
message_data = get_message_data(message_id)
|
565 |
+
if not message_data:
|
566 |
+
return {}
|
567 |
+
metadata = {
|
568 |
+
"message_id": message_id,
|
569 |
+
"ls_provider": message_data.model_provider,
|
570 |
+
"ls_model_name": message_data.model_id,
|
571 |
+
"status": message_data.status,
|
572 |
+
"from_end_user_id": message_data.from_end_user_id,
|
573 |
+
"from_account_id": message_data.from_account_id,
|
574 |
+
"agent_based": message_data.agent_based,
|
575 |
+
"workflow_run_id": message_data.workflow_run_id,
|
576 |
+
"from_source": message_data.from_source,
|
577 |
+
}
|
578 |
+
|
579 |
+
# get workflow_app_log_id
|
580 |
+
workflow_app_log_id = None
|
581 |
+
if message_data.workflow_run_id:
|
582 |
+
workflow_app_log_data = (
|
583 |
+
db.session.query(WorkflowAppLog).filter_by(workflow_run_id=message_data.workflow_run_id).first()
|
584 |
+
)
|
585 |
+
workflow_app_log_id = str(workflow_app_log_data.id) if workflow_app_log_data else None
|
586 |
+
|
587 |
+
suggested_question_trace_info = SuggestedQuestionTraceInfo(
|
588 |
+
message_id=workflow_app_log_id or message_id,
|
589 |
+
message_data=message_data.to_dict(),
|
590 |
+
inputs=message_data.message,
|
591 |
+
outputs=message_data.answer,
|
592 |
+
start_time=timer.get("start"),
|
593 |
+
end_time=timer.get("end"),
|
594 |
+
metadata=metadata,
|
595 |
+
total_tokens=message_data.message_tokens + message_data.answer_tokens,
|
596 |
+
status=message_data.status,
|
597 |
+
error=message_data.error,
|
598 |
+
from_account_id=message_data.from_account_id,
|
599 |
+
agent_based=message_data.agent_based,
|
600 |
+
from_source=message_data.from_source,
|
601 |
+
model_provider=message_data.model_provider,
|
602 |
+
model_id=message_data.model_id,
|
603 |
+
suggested_question=suggested_question,
|
604 |
+
level=message_data.status,
|
605 |
+
status_message=message_data.error,
|
606 |
+
)
|
607 |
+
|
608 |
+
return suggested_question_trace_info
|
609 |
+
|
610 |
+
def dataset_retrieval_trace(self, message_id, timer, **kwargs):
|
611 |
+
documents = kwargs.get("documents")
|
612 |
+
message_data = get_message_data(message_id)
|
613 |
+
if not message_data:
|
614 |
+
return {}
|
615 |
+
|
616 |
+
metadata = {
|
617 |
+
"message_id": message_id,
|
618 |
+
"ls_provider": message_data.model_provider,
|
619 |
+
"ls_model_name": message_data.model_id,
|
620 |
+
"status": message_data.status,
|
621 |
+
"from_end_user_id": message_data.from_end_user_id,
|
622 |
+
"from_account_id": message_data.from_account_id,
|
623 |
+
"agent_based": message_data.agent_based,
|
624 |
+
"workflow_run_id": message_data.workflow_run_id,
|
625 |
+
"from_source": message_data.from_source,
|
626 |
+
}
|
627 |
+
|
628 |
+
dataset_retrieval_trace_info = DatasetRetrievalTraceInfo(
|
629 |
+
message_id=message_id,
|
630 |
+
inputs=message_data.query or message_data.inputs,
|
631 |
+
documents=[doc.model_dump() for doc in documents] if documents else [],
|
632 |
+
start_time=timer.get("start"),
|
633 |
+
end_time=timer.get("end"),
|
634 |
+
metadata=metadata,
|
635 |
+
message_data=message_data.to_dict(),
|
636 |
+
)
|
637 |
+
|
638 |
+
return dataset_retrieval_trace_info
|
639 |
+
|
640 |
+
def tool_trace(self, message_id, timer, **kwargs):
|
641 |
+
tool_name = kwargs.get("tool_name", "")
|
642 |
+
tool_inputs = kwargs.get("tool_inputs", {})
|
643 |
+
tool_outputs = kwargs.get("tool_outputs", {})
|
644 |
+
message_data = get_message_data(message_id)
|
645 |
+
if not message_data:
|
646 |
+
return {}
|
647 |
+
tool_config = {}
|
648 |
+
time_cost = 0
|
649 |
+
error = None
|
650 |
+
tool_parameters = {}
|
651 |
+
created_time = message_data.created_at
|
652 |
+
end_time = message_data.updated_at
|
653 |
+
agent_thoughts = message_data.agent_thoughts
|
654 |
+
for agent_thought in agent_thoughts:
|
655 |
+
if tool_name in agent_thought.tools:
|
656 |
+
created_time = agent_thought.created_at
|
657 |
+
tool_meta_data = agent_thought.tool_meta.get(tool_name, {})
|
658 |
+
tool_config = tool_meta_data.get("tool_config", {})
|
659 |
+
time_cost = tool_meta_data.get("time_cost", 0)
|
660 |
+
end_time = created_time + timedelta(seconds=time_cost)
|
661 |
+
error = tool_meta_data.get("error", "")
|
662 |
+
tool_parameters = tool_meta_data.get("tool_parameters", {})
|
663 |
+
metadata = {
|
664 |
+
"message_id": message_id,
|
665 |
+
"tool_name": tool_name,
|
666 |
+
"tool_inputs": tool_inputs,
|
667 |
+
"tool_outputs": tool_outputs,
|
668 |
+
"tool_config": tool_config,
|
669 |
+
"time_cost": time_cost,
|
670 |
+
"error": error,
|
671 |
+
"tool_parameters": tool_parameters,
|
672 |
+
}
|
673 |
+
|
674 |
+
file_url = ""
|
675 |
+
message_file_data = db.session.query(MessageFile).filter_by(message_id=message_id).first()
|
676 |
+
if message_file_data:
|
677 |
+
message_file_id = message_file_data.id if message_file_data else None
|
678 |
+
type = message_file_data.type
|
679 |
+
created_by_role = message_file_data.created_by_role
|
680 |
+
created_user_id = message_file_data.created_by
|
681 |
+
file_url = f"{self.file_base_url}/{message_file_data.url}"
|
682 |
+
|
683 |
+
metadata.update(
|
684 |
+
{
|
685 |
+
"message_file_id": message_file_id,
|
686 |
+
"created_by_role": created_by_role,
|
687 |
+
"created_user_id": created_user_id,
|
688 |
+
"type": type,
|
689 |
+
}
|
690 |
+
)
|
691 |
+
|
+        tool_trace_info = ToolTraceInfo(
+            message_id=message_id,
+            message_data=message_data.to_dict(),
+            tool_name=tool_name,
+            start_time=timer.get("start") if timer else created_time,
+            end_time=timer.get("end") if timer else end_time,
+            tool_inputs=tool_inputs,
+            tool_outputs=tool_outputs,
+            metadata=metadata,
+            message_file_data=message_file_data,
+            error=error,
+            inputs=message_data.message,
+            outputs=message_data.answer,
+            tool_config=tool_config,
+            time_cost=time_cost,
+            tool_parameters=tool_parameters,
+            file_url=file_url,
+        )
+
+        return tool_trace_info
+
+    def generate_name_trace(self, conversation_id, timer, **kwargs):
+        generate_conversation_name = kwargs.get("generate_conversation_name")
+        inputs = kwargs.get("inputs")
+        tenant_id = kwargs.get("tenant_id")
+        if not tenant_id:
+            return {}
+        start_time = timer.get("start")
+        end_time = timer.get("end")
+
+        metadata = {
+            "conversation_id": conversation_id,
+            "tenant_id": tenant_id,
+        }
+
+        generate_name_trace_info = GenerateNameTraceInfo(
+            conversation_id=conversation_id,
+            inputs=inputs,
+            outputs=generate_conversation_name,
+            start_time=start_time,
+            end_time=end_time,
+            metadata=metadata,
+            tenant_id=tenant_id,
+        )
+
+        return generate_name_trace_info
+
+
+trace_manager_timer: Optional[threading.Timer] = None
+trace_manager_queue: queue.Queue = queue.Queue()
+trace_manager_interval = int(os.getenv("TRACE_QUEUE_MANAGER_INTERVAL", 5))
+trace_manager_batch_size = int(os.getenv("TRACE_QUEUE_MANAGER_BATCH_SIZE", 100))
+
+
+class TraceQueueManager:
+    def __init__(self, app_id=None, user_id=None):
+        global trace_manager_timer
+
+        self.app_id = app_id
+        self.user_id = user_id
+        self.trace_instance = OpsTraceManager.get_ops_trace_instance(app_id)
+        self.flask_app = current_app._get_current_object()  # type: ignore
+        if trace_manager_timer is None:
+            self.start_timer()
+
+    def add_trace_task(self, trace_task: TraceTask):
+        global trace_manager_timer, trace_manager_queue
+        try:
+            if self.trace_instance:
+                trace_task.app_id = self.app_id
+                trace_manager_queue.put(trace_task)
+        except Exception as e:
+            logging.exception(f"Error adding trace task, trace_type {trace_task.trace_type}")
+        finally:
+            self.start_timer()
+
+    def collect_tasks(self):
+        global trace_manager_queue
+        tasks: list[TraceTask] = []
+        while len(tasks) < trace_manager_batch_size and not trace_manager_queue.empty():
+            task = trace_manager_queue.get_nowait()
+            tasks.append(task)
+            trace_manager_queue.task_done()
+        return tasks
+
+    def run(self):
+        try:
+            tasks = self.collect_tasks()
+            if tasks:
+                self.send_to_celery(tasks)
+        except Exception as e:
+            logging.exception("Error processing trace tasks")
+
+    def start_timer(self):
+        global trace_manager_timer
+        if trace_manager_timer is None or not trace_manager_timer.is_alive():
+            trace_manager_timer = threading.Timer(trace_manager_interval, self.run)
+            trace_manager_timer.name = f"trace_manager_timer_{time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())}"
+            trace_manager_timer.daemon = False
+            trace_manager_timer.start()
+
+    def send_to_celery(self, tasks: list[TraceTask]):
+        with self.flask_app.app_context():
+            for task in tasks:
+                if task.app_id is None:
+                    continue
+                file_id = uuid4().hex
+                trace_info = task.execute()
+                task_data = TaskData(
+                    app_id=task.app_id,
+                    trace_info_type=type(trace_info).__name__,
+                    trace_info=trace_info.model_dump() if trace_info else None,
+                )
+                file_path = f"{OPS_FILE_PATH}{task.app_id}/{file_id}.json"
+                storage.save(file_path, task_data.model_dump_json().encode("utf-8"))
+                file_info = {
+                    "file_id": file_id,
+                    "app_id": task.app_id,
+                }
+                process_trace_tasks.delay(file_info)
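A minimal usage sketch for the queue manager above (illustrative only, not part of the diff; build_trace_task() is a hypothetical stand-in for however the caller constructs a TraceTask, which is defined earlier in this file):

    trace_queue_manager = TraceQueueManager(app_id=app_id)   # arms the shared one-shot timer on first use
    trace_queue_manager.add_trace_task(build_trace_task())   # enqueued only when tracing is configured for the app

When the timer fires, run() drains up to trace_manager_batch_size queued tasks and hands them to Celery via send_to_celery(); add_trace_task() re-arms the one-shot timer in its finally block, so the cycle continues for as long as tasks keep arriving.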
api/core/ops/utils.py
ADDED
@@ -0,0 +1,62 @@
+from contextlib import contextmanager
+from datetime import datetime
+from typing import Optional, Union
+
+from extensions.ext_database import db
+from models.model import Message
+
+
+def filter_none_values(data: dict):
+    new_data = {}
+    for key, value in data.items():
+        if value is None:
+            continue
+        if isinstance(value, datetime):
+            new_data[key] = value.isoformat()
+        else:
+            new_data[key] = value
+    return new_data
+
+
+def get_message_data(message_id: str):
+    return db.session.query(Message).filter(Message.id == message_id).first()
+
+
+@contextmanager
+def measure_time():
+    timing_info = {"start": datetime.now(), "end": None}
+    try:
+        yield timing_info
+    finally:
+        timing_info["end"] = datetime.now()
+
+
+def replace_text_with_content(data):
+    if isinstance(data, dict):
+        new_data = {}
+        for key, value in data.items():
+            if key == "text":
+                new_data["content"] = value
+            else:
+                new_data[key] = replace_text_with_content(value)
+        return new_data
+    elif isinstance(data, list):
+        return [replace_text_with_content(item) for item in data]
+    else:
+        return data
+
+
+def generate_dotted_order(
+    run_id: str, start_time: Union[str, datetime], parent_dotted_order: Optional[str] = None
+) -> str:
+    """
+    generate dotted_order for langsmith
+    """
+    start_time = datetime.fromisoformat(start_time) if isinstance(start_time, str) else start_time
+    timestamp = start_time.strftime("%Y%m%dT%H%M%S%f")[:-3] + "Z"
+    current_segment = f"{timestamp}{run_id}"
+
+    if parent_dotted_order is None:
+        return current_segment
+
+    return f"{parent_dotted_order}.{current_segment}"
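A quick sketch of how the helpers above compose (illustrative only, not part of the diff; do_work() is a hypothetical workload):

    from core.ops.utils import filter_none_values, measure_time

    with measure_time() as timer:
        do_work()  # timed by the context manager; "end" is set even if it raises

    # datetime values are serialized to ISO 8601 and None-valued keys are dropped
    payload = filter_none_values({"start": timer["start"], "end": timer["end"], "error": None})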
api/core/prompt/__init__.py
ADDED
File without changes
api/core/prompt/advanced_prompt_transform.py
ADDED
@@ -0,0 +1,287 @@
+from collections.abc import Mapping, Sequence
+from typing import Optional, cast
+
+from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
+from core.file import file_manager
+from core.file.models import File
+from core.helper.code_executor.jinja2.jinja2_formatter import Jinja2Formatter
+from core.memory.token_buffer_memory import TokenBufferMemory
+from core.model_runtime.entities import (
+    AssistantPromptMessage,
+    PromptMessage,
+    PromptMessageContent,
+    PromptMessageRole,
+    SystemPromptMessage,
+    TextPromptMessageContent,
+    UserPromptMessage,
+)
+from core.model_runtime.entities.message_entities import ImagePromptMessageContent
+from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
+from core.prompt.prompt_transform import PromptTransform
+from core.prompt.utils.prompt_template_parser import PromptTemplateParser
+from core.workflow.entities.variable_pool import VariablePool
+
+
+class AdvancedPromptTransform(PromptTransform):
+    """
+    Advanced Prompt Transform for Workflow LLM Node.
+    """
+
+    def __init__(
+        self,
+        with_variable_tmpl: bool = False,
+        image_detail_config: ImagePromptMessageContent.DETAIL = ImagePromptMessageContent.DETAIL.LOW,
+    ) -> None:
+        self.with_variable_tmpl = with_variable_tmpl
+        self.image_detail_config = image_detail_config
+
+    def get_prompt(
+        self,
+        *,
+        prompt_template: Sequence[ChatModelMessage] | CompletionModelPromptTemplate,
+        inputs: Mapping[str, str],
+        query: str,
+        files: Sequence[File],
+        context: Optional[str],
+        memory_config: Optional[MemoryConfig],
+        memory: Optional[TokenBufferMemory],
+        model_config: ModelConfigWithCredentialsEntity,
+    ) -> list[PromptMessage]:
+        prompt_messages = []
+
+        if isinstance(prompt_template, CompletionModelPromptTemplate):
+            prompt_messages = self._get_completion_model_prompt_messages(
+                prompt_template=prompt_template,
+                inputs=inputs,
+                query=query,
+                files=files,
+                context=context,
+                memory_config=memory_config,
+                memory=memory,
+                model_config=model_config,
+            )
+        elif isinstance(prompt_template, list) and all(isinstance(item, ChatModelMessage) for item in prompt_template):
+            prompt_messages = self._get_chat_model_prompt_messages(
+                prompt_template=prompt_template,
+                inputs=inputs,
+                query=query,
+                files=files,
+                context=context,
+                memory_config=memory_config,
+                memory=memory,
+                model_config=model_config,
+            )
+
+        return prompt_messages
+
+    def _get_completion_model_prompt_messages(
+        self,
+        prompt_template: CompletionModelPromptTemplate,
+        inputs: Mapping[str, str],
+        query: Optional[str],
+        files: Sequence[File],
+        context: Optional[str],
+        memory_config: Optional[MemoryConfig],
+        memory: Optional[TokenBufferMemory],
+        model_config: ModelConfigWithCredentialsEntity,
+    ) -> list[PromptMessage]:
+        """
+        Get completion model prompt messages.
+        """
+        raw_prompt = prompt_template.text
+
+        prompt_messages: list[PromptMessage] = []
+
+        if prompt_template.edition_type == "basic" or not prompt_template.edition_type:
+            parser = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl)
+            prompt_inputs: Mapping[str, str] = {k: inputs[k] for k in parser.variable_keys if k in inputs}
+
+            prompt_inputs = self._set_context_variable(context, parser, prompt_inputs)
+
+            if memory and memory_config and memory_config.role_prefix:
+                role_prefix = memory_config.role_prefix
+                prompt_inputs = self._set_histories_variable(
+                    memory=memory,
+                    memory_config=memory_config,
+                    raw_prompt=raw_prompt,
+                    role_prefix=role_prefix,
+                    parser=parser,
+                    prompt_inputs=prompt_inputs,
+                    model_config=model_config,
+                )
+
+            if query:
+                prompt_inputs = self._set_query_variable(query, parser, prompt_inputs)
+
+            prompt = parser.format(prompt_inputs)
+        else:
+            prompt = raw_prompt
+            prompt_inputs = inputs
+
+            prompt = Jinja2Formatter.format(prompt, prompt_inputs)
+
+        if files:
+            prompt_message_contents: list[PromptMessageContent] = []
+            prompt_message_contents.append(TextPromptMessageContent(data=prompt))
+            for file in files:
+                prompt_message_contents.append(file_manager.to_prompt_message_content(file))
+
+            prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
+        else:
+            prompt_messages.append(UserPromptMessage(content=prompt))
+
+        return prompt_messages
+
+    def _get_chat_model_prompt_messages(
+        self,
+        prompt_template: list[ChatModelMessage],
+        inputs: Mapping[str, str],
+        query: Optional[str],
+        files: Sequence[File],
+        context: Optional[str],
+        memory_config: Optional[MemoryConfig],
+        memory: Optional[TokenBufferMemory],
+        model_config: ModelConfigWithCredentialsEntity,
+    ) -> list[PromptMessage]:
+        """
+        Get chat model prompt messages.
+        """
+        prompt_messages: list[PromptMessage] = []
+        for prompt_item in prompt_template:
+            raw_prompt = prompt_item.text
+
+            if prompt_item.edition_type == "basic" or not prompt_item.edition_type:
+                if self.with_variable_tmpl:
+                    vp = VariablePool()
+                    for k, v in inputs.items():
+                        if k.startswith("#"):
+                            vp.add(k[1:-1].split("."), v)
+                    raw_prompt = raw_prompt.replace("{{#context#}}", context or "")
+                    prompt = vp.convert_template(raw_prompt).text
+                else:
+                    parser = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl)
+                    prompt_inputs: Mapping[str, str] = {k: inputs[k] for k in parser.variable_keys if k in inputs}
+                    prompt_inputs = self._set_context_variable(
+                        context=context, parser=parser, prompt_inputs=prompt_inputs
+                    )
+                    prompt = parser.format(prompt_inputs)
+            elif prompt_item.edition_type == "jinja2":
+                prompt = raw_prompt
+                prompt_inputs = inputs
+                prompt = Jinja2Formatter.format(template=prompt, inputs=prompt_inputs)
+            else:
+                raise ValueError(f"Invalid edition type: {prompt_item.edition_type}")
+
+            if prompt_item.role == PromptMessageRole.USER:
+                prompt_messages.append(UserPromptMessage(content=prompt))
+            elif prompt_item.role == PromptMessageRole.SYSTEM and prompt:
+                prompt_messages.append(SystemPromptMessage(content=prompt))
+            elif prompt_item.role == PromptMessageRole.ASSISTANT:
+                prompt_messages.append(AssistantPromptMessage(content=prompt))
+
+        if query and memory_config and memory_config.query_prompt_template:
+            parser = PromptTemplateParser(
+                template=memory_config.query_prompt_template, with_variable_tmpl=self.with_variable_tmpl
+            )
+            prompt_inputs = {k: inputs[k] for k in parser.variable_keys if k in inputs}
+            prompt_inputs["#sys.query#"] = query
+
+            prompt_inputs = self._set_context_variable(context, parser, prompt_inputs)
+
+            query = parser.format(prompt_inputs)
+
+        if memory and memory_config:
+            prompt_messages = self._append_chat_histories(memory, memory_config, prompt_messages, model_config)
+
+            if files and query is not None:
+                prompt_message_contents: list[PromptMessageContent] = []
+                prompt_message_contents.append(TextPromptMessageContent(data=query))
+                for file in files:
+                    prompt_message_contents.append(file_manager.to_prompt_message_content(file))
+                prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
+            else:
+                prompt_messages.append(UserPromptMessage(content=query))
+        elif files:
+            if not query:
+                # get last message
+                last_message = prompt_messages[-1] if prompt_messages else None
+                if last_message and last_message.role == PromptMessageRole.USER:
+                    # get last user message content and add files
+                    prompt_message_contents = [TextPromptMessageContent(data=cast(str, last_message.content))]
+                    for file in files:
+                        prompt_message_contents.append(file_manager.to_prompt_message_content(file))
+
+                    last_message.content = prompt_message_contents
+                else:
+                    prompt_message_contents = [TextPromptMessageContent(data="")]  # not for query
+                    for file in files:
+                        prompt_message_contents.append(file_manager.to_prompt_message_content(file))
+
+                    prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
+            else:
+                prompt_message_contents = [TextPromptMessageContent(data=query)]
+                for file in files:
+                    prompt_message_contents.append(file_manager.to_prompt_message_content(file))
+
+                prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
+        elif query:
+            prompt_messages.append(UserPromptMessage(content=query))
+
+        return prompt_messages
+
+    def _set_context_variable(
+        self, context: str | None, parser: PromptTemplateParser, prompt_inputs: Mapping[str, str]
+    ) -> Mapping[str, str]:
+        prompt_inputs = dict(prompt_inputs)
+        if "#context#" in parser.variable_keys:
+            if context:
+                prompt_inputs["#context#"] = context
+            else:
+                prompt_inputs["#context#"] = ""
+
+        return prompt_inputs
+
+    def _set_query_variable(
+        self, query: str, parser: PromptTemplateParser, prompt_inputs: Mapping[str, str]
+    ) -> Mapping[str, str]:
+        prompt_inputs = dict(prompt_inputs)
+        if "#query#" in parser.variable_keys:
+            if query:
+                prompt_inputs["#query#"] = query
+            else:
+                prompt_inputs["#query#"] = ""
+
+        return prompt_inputs
+
+    def _set_histories_variable(
+        self,
+        memory: TokenBufferMemory,
+        memory_config: MemoryConfig,
+        raw_prompt: str,
+        role_prefix: MemoryConfig.RolePrefix,
+        parser: PromptTemplateParser,
+        prompt_inputs: Mapping[str, str],
+        model_config: ModelConfigWithCredentialsEntity,
+    ) -> Mapping[str, str]:
+        prompt_inputs = dict(prompt_inputs)
+        if "#histories#" in parser.variable_keys:
+            if memory:
+                inputs = {"#histories#": "", **prompt_inputs}
+                parser = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl)
+                prompt_inputs = {k: inputs[k] for k in parser.variable_keys if k in inputs}
+                tmp_human_message = UserPromptMessage(content=parser.format(prompt_inputs))
+
+                rest_tokens = self._calculate_rest_token([tmp_human_message], model_config)
+
+                histories = self._get_history_messages_from_memory(
+                    memory=memory,
+                    memory_config=memory_config,
+                    max_token_limit=rest_tokens,
+                    human_prefix=role_prefix.user,
+                    ai_prefix=role_prefix.assistant,
+                )
+                prompt_inputs["#histories#"] = histories
+            else:
+                prompt_inputs["#histories#"] = ""
+
+        return prompt_inputs
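A minimal usage sketch for the transform above (illustrative only, not part of the diff; model_config is assumed to be a ModelConfigWithCredentialsEntity supplied by the workflow runtime):

    from core.model_runtime.entities import PromptMessageRole
    from core.prompt.entities.advanced_prompt_entities import ChatModelMessage

    transform = AdvancedPromptTransform()
    prompt_messages = transform.get_prompt(
        prompt_template=[
            ChatModelMessage(text="You are a helpful assistant.", role=PromptMessageRole.SYSTEM),
            ChatModelMessage(text="{{question}}", role=PromptMessageRole.USER),
        ],
        inputs={"question": "What is RAG?"},
        query="What is RAG?",
        files=[],
        context=None,
        memory_config=None,
        memory=None,
        model_config=model_config,  # assumed to come from the caller
    )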
api/core/prompt/agent_history_prompt_transform.py
ADDED
@@ -0,0 +1,80 @@
+from typing import Optional, cast
+
+from core.app.entities.app_invoke_entities import (
+    ModelConfigWithCredentialsEntity,
+)
+from core.memory.token_buffer_memory import TokenBufferMemory
+from core.model_runtime.entities.message_entities import (
+    PromptMessage,
+    SystemPromptMessage,
+    UserPromptMessage,
+)
+from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
+from core.prompt.prompt_transform import PromptTransform
+
+
+class AgentHistoryPromptTransform(PromptTransform):
+    """
+    History Prompt Transform for Agent App
+    """
+
+    def __init__(
+        self,
+        model_config: ModelConfigWithCredentialsEntity,
+        prompt_messages: list[PromptMessage],
+        history_messages: list[PromptMessage],
+        memory: Optional[TokenBufferMemory] = None,
+    ):
+        self.model_config = model_config
+        self.prompt_messages = prompt_messages
+        self.history_messages = history_messages
+        self.memory = memory
+
+    def get_prompt(self) -> list[PromptMessage]:
+        prompt_messages: list[PromptMessage] = []
+        num_system = 0
+        for prompt_message in self.history_messages:
+            if isinstance(prompt_message, SystemPromptMessage):
+                prompt_messages.append(prompt_message)
+                num_system += 1
+
+        if not self.memory:
+            return prompt_messages
+
+        max_token_limit = self._calculate_rest_token(self.prompt_messages, self.model_config)
+
+        model_type_instance = self.model_config.provider_model_bundle.model_type_instance
+        model_type_instance = cast(LargeLanguageModel, model_type_instance)
+
+        curr_message_tokens = model_type_instance.get_num_tokens(
+            self.memory.model_instance.model, self.memory.model_instance.credentials, self.history_messages
+        )
+        if curr_message_tokens <= max_token_limit:
+            return self.history_messages
+
+        # number of prompts appended for the current message
+        num_prompt = 0
+        # append prompt messages in descending order
+        for prompt_message in self.history_messages[::-1]:
+            if isinstance(prompt_message, SystemPromptMessage):
+                continue
+            prompt_messages.append(prompt_message)
+            num_prompt += 1
+            # a message group starts with a UserPromptMessage
+            if isinstance(prompt_message, UserPromptMessage):
+                curr_message_tokens = model_type_instance.get_num_tokens(
+                    self.memory.model_instance.model, self.memory.model_instance.credentials, prompt_messages
+                )
+                # if the current message overflows the token limit, drop all prompts of this message and break
+                if curr_message_tokens > max_token_limit:
+                    prompt_messages = prompt_messages[:-num_prompt]
+                    break
+                num_prompt = 0
+        # return prompt messages in ascending order
+        message_prompts = prompt_messages[num_system:]
+        message_prompts.reverse()
+
+        # merge system and message prompts
+        prompt_messages = prompt_messages[:num_system]
+        prompt_messages.extend(message_prompts)
+        return prompt_messages
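Sketch of the intended call pattern (illustrative only; the message lists and model_config would come from the agent runner):

    transform = AgentHistoryPromptTransform(
        model_config=model_config,       # assumed from the caller
        prompt_messages=current_prompt,  # messages of the current invocation
        history_messages=history,        # full history, possibly over the token budget
        memory=memory,                   # TokenBufferMemory, or None to skip trimming
    )
    trimmed = transform.get_prompt()  # system messages kept, oldest turns dropped until the budget fits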
api/core/prompt/entities/__init__.py
ADDED
File without changes
api/core/prompt/entities/advanced_prompt_entities.py
ADDED
@@ -0,0 +1,50 @@
+from typing import Literal, Optional
+
+from pydantic import BaseModel
+
+from core.model_runtime.entities.message_entities import PromptMessageRole
+
+
+class ChatModelMessage(BaseModel):
+    """
+    Chat Message.
+    """
+
+    text: str
+    role: PromptMessageRole
+    edition_type: Optional[Literal["basic", "jinja2"]] = None
+
+
+class CompletionModelPromptTemplate(BaseModel):
+    """
+    Completion Model Prompt Template.
+    """
+
+    text: str
+    edition_type: Optional[Literal["basic", "jinja2"]] = None
+
+
+class MemoryConfig(BaseModel):
+    """
+    Memory Config.
+    """
+
+    class RolePrefix(BaseModel):
+        """
+        Role Prefix.
+        """
+
+        user: str
+        assistant: str
+
+    class WindowConfig(BaseModel):
+        """
+        Window Config.
+        """
+
+        enabled: bool
+        size: Optional[int] = None
+
+    role_prefix: Optional[RolePrefix] = None
+    window: WindowConfig
+    query_prompt_template: Optional[str] = None
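For example, a memory configuration that keeps the last 10 messages and labels speakers for completion-style prompts could look like this (illustrative values only):

    memory_config = MemoryConfig(
        role_prefix=MemoryConfig.RolePrefix(user="Human", assistant="Assistant"),
        window=MemoryConfig.WindowConfig(enabled=True, size=10),
        query_prompt_template="\n\nHuman: {{#sys.query#}}\n\nAssistant: ",
    )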
api/core/prompt/prompt_templates/__init__.py
ADDED
File without changes
api/core/prompt/prompt_templates/advanced_prompt_templates.py
ADDED
@@ -0,0 +1,45 @@
+CONTEXT = "Use the following context as your learned knowledge, inside <context></context> XML tags.\n\n<context>\n{{#context#}}\n</context>\n\nWhen answer to user:\n- If you don't know, just say that you don't know.\n- If you don't know when you are not sure, ask for clarification.\nAvoid mentioning that you obtained the information from the context.\nAnd answer according to the language of the user's question.\n"  # noqa: E501
+
+BAICHUAN_CONTEXT = "用户在与一个客观的助手对话。助手会尊重找到的材料,给出全面专业的解释,但不会过度演绎。同时回答中不会暴露引用的材料:\n\n```\n{{#context#}}\n```\n\n"  # noqa: E501
+
+CHAT_APP_COMPLETION_PROMPT_CONFIG = {
+    "completion_prompt_config": {
+        "prompt": {
+            "text": "{{#pre_prompt#}}\nHere are the chat histories between human and assistant, inside <histories></histories> XML tags.\n\n<histories>\n{{#histories#}}\n</histories>\n\n\nHuman: {{#query#}}\n\nAssistant: "  # noqa: E501
+        },
+        "conversation_histories_role": {"user_prefix": "Human", "assistant_prefix": "Assistant"},
+    },
+    "stop": ["Human:"],
+}
+
+CHAT_APP_CHAT_PROMPT_CONFIG = {"chat_prompt_config": {"prompt": [{"role": "system", "text": "{{#pre_prompt#}}"}]}}
+
+COMPLETION_APP_CHAT_PROMPT_CONFIG = {"chat_prompt_config": {"prompt": [{"role": "user", "text": "{{#pre_prompt#}}"}]}}
+
+COMPLETION_APP_COMPLETION_PROMPT_CONFIG = {
+    "completion_prompt_config": {"prompt": {"text": "{{#pre_prompt#}}"}},
+    "stop": ["Human:"],
+}
+
+BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG = {
+    "completion_prompt_config": {
+        "prompt": {
+            "text": "{{#pre_prompt#}}\n\n用户和助手的历史对话内容如下:\n```\n{{#histories#}}\n```\n\n\n\n用户:{{#query#}}"  # noqa: E501
+        },
+        "conversation_histories_role": {"user_prefix": "用户", "assistant_prefix": "助手"},
+    },
+    "stop": ["用户:"],
+}
+
+BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG = {
+    "chat_prompt_config": {"prompt": [{"role": "system", "text": "{{#pre_prompt#}}"}]}
+}
+
+BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG = {
+    "chat_prompt_config": {"prompt": [{"role": "user", "text": "{{#pre_prompt#}}"}]}
+}
+
+BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG = {
+    "completion_prompt_config": {"prompt": {"text": "{{#pre_prompt#}}"}},
+    "stop": ["用户:"],
+}
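The {{#context#}} placeholder in these templates is filled by PromptTemplateParser at prompt-build time; a plain string replacement shows the intended effect (sketch only, not the real substitution path):

    system_prompt = CONTEXT.replace("{{#context#}}", "Tokyo is the capital of Japan.")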
api/core/prompt/prompt_templates/baichuan_chat.json
ADDED
@@ -0,0 +1,13 @@
+{
+  "human_prefix": "用户",
+  "assistant_prefix": "助手",
+  "context_prompt": "用户在与一个客观的助手对话。助手会尊重找到的材料,给出全面专业的解释,但不会过度演绎。同时回答中不会暴露引用的材料:\n\n```\n{{#context#}}\n```\n\n",
+  "histories_prompt": "用户和助手的历史对话内容如下:\n```\n{{#histories#}}\n```\n\n",
+  "system_prompt_orders": [
+    "context_prompt",
+    "pre_prompt",
+    "histories_prompt"
+  ],
+  "query_prompt": "\n\n用户:{{#query#}}",
+  "stops": ["用户:"]
+}
api/core/prompt/prompt_templates/baichuan_completion.json
ADDED
@@ -0,0 +1,9 @@
+{
+  "context_prompt": "用户在与一个客观的助手对话。助手会尊重找到的材料,给出全面专业的解释,但不会过度演绎。同时回答中不会暴露引用的材料:\n\n```\n{{#context#}}\n```\n",
+  "system_prompt_orders": [
+    "context_prompt",
+    "pre_prompt"
+  ],
+  "query_prompt": "{{#query#}}",
+  "stops": null
+}
api/core/prompt/prompt_templates/common_chat.json
ADDED
@@ -0,0 +1,13 @@
+{
+  "human_prefix": "Human",
+  "assistant_prefix": "Assistant",
+  "context_prompt": "Use the following context as your learned knowledge, inside <context></context> XML tags.\n\n<context>\n{{#context#}}\n</context>\n\nWhen answer to user:\n- If you don't know, just say that you don't know.\n- If you don't know when you are not sure, ask for clarification.\nAvoid mentioning that you obtained the information from the context.\nAnd answer according to the language of the user's question.\n\n",
+  "histories_prompt": "Here is the chat histories between human and assistant, inside <histories></histories> XML tags.\n\n<histories>\n{{#histories#}}\n</histories>\n\n",
+  "system_prompt_orders": [
+    "context_prompt",
+    "pre_prompt",
+    "histories_prompt"
+  ],
+  "query_prompt": "\n\nHuman: {{#query#}}\n\nAssistant: ",
+  "stops": ["\nHuman:", "</histories>"]
+}
api/core/prompt/prompt_templates/common_completion.json
ADDED
@@ -0,0 +1,9 @@
+{
+  "context_prompt": "Use the following context as your learned knowledge, inside <context></context> XML tags.\n\n<context>\n{{#context#}}\n</context>\n\nWhen answer to user:\n- If you don't know, just say that you don't know.\n- If you don't know when you are not sure, ask for clarification.\nAvoid mentioning that you obtained the information from the context.\nAnd answer according to the language of the user's question.\n\n",
+  "system_prompt_orders": [
+    "context_prompt",
+    "pre_prompt"
+  ],
+  "query_prompt": "{{#query#}}",
+  "stops": null
+}
api/core/prompt/prompt_transform.py
ADDED
@@ -0,0 +1,90 @@
+from typing import Any, Optional
+
+from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
+from core.memory.token_buffer_memory import TokenBufferMemory
+from core.model_manager import ModelInstance
+from core.model_runtime.entities.message_entities import PromptMessage
+from core.model_runtime.entities.model_entities import ModelPropertyKey
+from core.prompt.entities.advanced_prompt_entities import MemoryConfig
+
+
+class PromptTransform:
+    def _append_chat_histories(
+        self,
+        memory: TokenBufferMemory,
+        memory_config: MemoryConfig,
+        prompt_messages: list[PromptMessage],
+        model_config: ModelConfigWithCredentialsEntity,
+    ) -> list[PromptMessage]:
+        rest_tokens = self._calculate_rest_token(prompt_messages, model_config)
+        histories = self._get_history_messages_list_from_memory(memory, memory_config, rest_tokens)
+        prompt_messages.extend(histories)
+
+        return prompt_messages
+
+    def _calculate_rest_token(
+        self, prompt_messages: list[PromptMessage], model_config: ModelConfigWithCredentialsEntity
+    ) -> int:
+        rest_tokens = 2000
+
+        model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
+        if model_context_tokens:
+            model_instance = ModelInstance(
+                provider_model_bundle=model_config.provider_model_bundle, model=model_config.model
+            )
+
+            curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages)
+
+            max_tokens = 0
+            for parameter_rule in model_config.model_schema.parameter_rules:
+                if parameter_rule.name == "max_tokens" or (
+                    parameter_rule.use_template and parameter_rule.use_template == "max_tokens"
+                ):
+                    max_tokens = (
+                        model_config.parameters.get(parameter_rule.name)
+                        or model_config.parameters.get(parameter_rule.use_template or "")
+                    ) or 0
+
+            rest_tokens = model_context_tokens - max_tokens - curr_message_tokens
+            rest_tokens = max(rest_tokens, 0)
+
+        return rest_tokens
+
+    def _get_history_messages_from_memory(
+        self,
+        memory: TokenBufferMemory,
+        memory_config: MemoryConfig,
+        max_token_limit: int,
+        human_prefix: Optional[str] = None,
+        ai_prefix: Optional[str] = None,
+    ) -> str:
+        """Get memory messages."""
+        kwargs: dict[str, Any] = {"max_token_limit": max_token_limit}
+
+        if human_prefix:
+            kwargs["human_prefix"] = human_prefix
+
+        if ai_prefix:
+            kwargs["ai_prefix"] = ai_prefix
+
+        if memory_config.window.enabled and memory_config.window.size is not None and memory_config.window.size > 0:
+            kwargs["message_limit"] = memory_config.window.size
+
+        return memory.get_history_prompt_text(**kwargs)
+
+    def _get_history_messages_list_from_memory(
+        self, memory: TokenBufferMemory, memory_config: MemoryConfig, max_token_limit: int
+    ) -> list[PromptMessage]:
+        """Get memory messages."""
+        return list(
+            memory.get_history_prompt_messages(
+                max_token_limit=max_token_limit,
+                message_limit=memory_config.window.size
+                if (
+                    memory_config.window.enabled
+                    and memory_config.window.size is not None
+                    and memory_config.window.size > 0
+                )
+                else None,
+            )
+        )
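As a worked example of the budget computed in _calculate_rest_token (illustrative numbers, not from the diff): with a context window of 4096 tokens (ModelPropertyKey.CONTEXT_SIZE), 512 tokens reserved for the completion via max_tokens, and a current prompt of 1000 tokens, the history budget is 4096 - 512 - 1000 = 2584 tokens. A negative result is clamped to 0 by max(rest_tokens, 0), and models with no declared context size fall back to the 2000-token default.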
api/core/prompt/simple_prompt_transform.py
ADDED
@@ -0,0 +1,327 @@
+import enum
+import json
+import os
+from collections.abc import Mapping, Sequence
+from typing import TYPE_CHECKING, Any, Optional, cast
+
+from core.app.app_config.entities import PromptTemplateEntity
+from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
+from core.file import file_manager
+from core.memory.token_buffer_memory import TokenBufferMemory
+from core.model_runtime.entities.message_entities import (
+    PromptMessage,
+    PromptMessageContent,
+    SystemPromptMessage,
+    TextPromptMessageContent,
+    UserPromptMessage,
+)
+from core.prompt.entities.advanced_prompt_entities import MemoryConfig
+from core.prompt.prompt_transform import PromptTransform
+from core.prompt.utils.prompt_template_parser import PromptTemplateParser
+from models.model import AppMode
+
+if TYPE_CHECKING:
+    from core.file.models import File
+
+
+class ModelMode(enum.StrEnum):
+    COMPLETION = "completion"
+    CHAT = "chat"
+
+    @classmethod
+    def value_of(cls, value: str) -> "ModelMode":
+        """
+        Get value of given mode.
+
+        :param value: mode value
+        :return: mode
+        """
+        for mode in cls:
+            if mode.value == value:
+                return mode
+        raise ValueError(f"invalid mode value {value}")
+
+
+prompt_file_contents: dict[str, Any] = {}
+
+
+class SimplePromptTransform(PromptTransform):
+    """
+    Simple Prompt Transform for Chatbot App Basic Mode.
+    """
+
+    def get_prompt(
+        self,
+        app_mode: AppMode,
+        prompt_template_entity: PromptTemplateEntity,
+        inputs: Mapping[str, str],
+        query: str,
+        files: Sequence["File"],
+        context: Optional[str],
+        memory: Optional[TokenBufferMemory],
+        model_config: ModelConfigWithCredentialsEntity,
+    ) -> tuple[list[PromptMessage], Optional[list[str]]]:
+        inputs = {key: str(value) for key, value in inputs.items()}
+
+        model_mode = ModelMode.value_of(model_config.mode)
+        if model_mode == ModelMode.CHAT:
+            prompt_messages, stops = self._get_chat_model_prompt_messages(
+                app_mode=app_mode,
+                pre_prompt=prompt_template_entity.simple_prompt_template or "",
+                inputs=inputs,
+                query=query,
+                files=files,
+                context=context,
+                memory=memory,
+                model_config=model_config,
+            )
+        else:
+            prompt_messages, stops = self._get_completion_model_prompt_messages(
+                app_mode=app_mode,
+                pre_prompt=prompt_template_entity.simple_prompt_template or "",
+                inputs=inputs,
+                query=query,
+                files=files,
+                context=context,
+                memory=memory,
+                model_config=model_config,
+            )
+
+        return prompt_messages, stops
+
+    def get_prompt_str_and_rules(
+        self,
+        app_mode: AppMode,
+        model_config: ModelConfigWithCredentialsEntity,
+        pre_prompt: str,
+        inputs: dict,
+        query: Optional[str] = None,
+        context: Optional[str] = None,
+        histories: Optional[str] = None,
+    ) -> tuple[str, dict]:
+        # get prompt template
+        prompt_template_config = self.get_prompt_template(
+            app_mode=app_mode,
+            provider=model_config.provider,
+            model=model_config.model,
+            pre_prompt=pre_prompt,
+            has_context=context is not None,
+            query_in_prompt=query is not None,
+            with_memory_prompt=histories is not None,
+        )
+
+        variables = {k: inputs[k] for k in prompt_template_config["custom_variable_keys"] if k in inputs}
+
+        for v in prompt_template_config["special_variable_keys"]:
+            # support #context#, #query# and #histories#
+            if v == "#context#":
+                variables["#context#"] = context or ""
+            elif v == "#query#":
+                variables["#query#"] = query or ""
+            elif v == "#histories#":
+                variables["#histories#"] = histories or ""
+
+        prompt_template = prompt_template_config["prompt_template"]
+        prompt = prompt_template.format(variables)
+
+        return prompt, prompt_template_config["prompt_rules"]
+
+    def get_prompt_template(
+        self,
+        app_mode: AppMode,
+        provider: str,
+        model: str,
+        pre_prompt: str,
+        has_context: bool,
+        query_in_prompt: bool,
+        with_memory_prompt: bool = False,
+    ) -> dict:
+        prompt_rules = self._get_prompt_rule(app_mode=app_mode, provider=provider, model=model)
+
+        custom_variable_keys = []
+        special_variable_keys = []
+
+        prompt = ""
+        for order in prompt_rules["system_prompt_orders"]:
+            if order == "context_prompt" and has_context:
+                prompt += prompt_rules["context_prompt"]
+                special_variable_keys.append("#context#")
+            elif order == "pre_prompt" and pre_prompt:
+                prompt += pre_prompt + "\n"
+                pre_prompt_template = PromptTemplateParser(template=pre_prompt)
+                custom_variable_keys = pre_prompt_template.variable_keys
+            elif order == "histories_prompt" and with_memory_prompt:
+                prompt += prompt_rules["histories_prompt"]
+                special_variable_keys.append("#histories#")
+
+        if query_in_prompt:
+            prompt += prompt_rules.get("query_prompt", "{{#query#}}")
+            special_variable_keys.append("#query#")
+
+        return {
+            "prompt_template": PromptTemplateParser(template=prompt),
+            "custom_variable_keys": custom_variable_keys,
+            "special_variable_keys": special_variable_keys,
+            "prompt_rules": prompt_rules,
+        }
+
+    def _get_chat_model_prompt_messages(
+        self,
+        app_mode: AppMode,
+        pre_prompt: str,
+        inputs: dict,
+        query: str,
+        context: Optional[str],
+        files: Sequence["File"],
+        memory: Optional[TokenBufferMemory],
+        model_config: ModelConfigWithCredentialsEntity,
+    ) -> tuple[list[PromptMessage], Optional[list[str]]]:
+        prompt_messages: list[PromptMessage] = []
+
+        # get prompt
+        prompt, _ = self.get_prompt_str_and_rules(
+            app_mode=app_mode,
+            model_config=model_config,
+            pre_prompt=pre_prompt,
+            inputs=inputs,
+            query=None,
+            context=context,
+        )
+
+        if prompt and query:
+            prompt_messages.append(SystemPromptMessage(content=prompt))
+
+        if memory:
+            prompt_messages = self._append_chat_histories(
+                memory=memory,
+                memory_config=MemoryConfig(
+                    window=MemoryConfig.WindowConfig(
+                        enabled=False,
+                    )
+                ),
+                prompt_messages=prompt_messages,
+                model_config=model_config,
+            )
+
+        if query:
+            prompt_messages.append(self.get_last_user_message(query, files))
+        else:
+            prompt_messages.append(self.get_last_user_message(prompt, files))
+
+        return prompt_messages, None
+
+    def _get_completion_model_prompt_messages(
+        self,
+        app_mode: AppMode,
+        pre_prompt: str,
+        inputs: dict,
+        query: str,
+        context: Optional[str],
+        files: Sequence["File"],
+        memory: Optional[TokenBufferMemory],
+        model_config: ModelConfigWithCredentialsEntity,
+    ) -> tuple[list[PromptMessage], Optional[list[str]]]:
+        # get prompt
+        prompt, prompt_rules = self.get_prompt_str_and_rules(
+            app_mode=app_mode,
+            model_config=model_config,
+            pre_prompt=pre_prompt,
+            inputs=inputs,
+            query=query,
+            context=context,
+        )
+
+        if memory:
+            tmp_human_message = UserPromptMessage(content=prompt)
+
+            rest_tokens = self._calculate_rest_token([tmp_human_message], model_config)
+            histories = self._get_history_messages_from_memory(
+                memory=memory,
+                memory_config=MemoryConfig(
+                    window=MemoryConfig.WindowConfig(
+                        enabled=False,
+                    )
+                ),
+                max_token_limit=rest_tokens,
+                human_prefix=prompt_rules.get("human_prefix", "Human"),
+                ai_prefix=prompt_rules.get("assistant_prefix", "Assistant"),
+            )
+
+            # get prompt
+            prompt, prompt_rules = self.get_prompt_str_and_rules(
+                app_mode=app_mode,
+                model_config=model_config,
+                pre_prompt=pre_prompt,
+                inputs=inputs,
+                query=query,
+                context=context,
+                histories=histories,
+            )
+
+        stops = prompt_rules.get("stops")
+        if stops is not None and len(stops) == 0:
+            stops = None
+
+        return [self.get_last_user_message(prompt, files)], stops
+
+    def get_last_user_message(self, prompt: str, files: Sequence["File"]) -> UserPromptMessage:
+        if files:
+            prompt_message_contents: list[PromptMessageContent] = []
+            prompt_message_contents.append(TextPromptMessageContent(data=prompt))
+            for file in files:
+                prompt_message_contents.append(file_manager.to_prompt_message_content(file))
+
+            prompt_message = UserPromptMessage(content=prompt_message_contents)
+        else:
+            prompt_message = UserPromptMessage(content=prompt)
+
+        return prompt_message
+
+    def _get_prompt_rule(self, app_mode: AppMode, provider: str, model: str) -> dict:
+        """
+        Get simple prompt rule.
+        :param app_mode: app mode
+        :param provider: model provider
+        :param model: model name
+        :return:
+        """
+        prompt_file_name = self._prompt_file_name(app_mode=app_mode, provider=provider, model=model)
+
+        # Check if the prompt file is already loaded
+        if prompt_file_name in prompt_file_contents:
+            return cast(dict, prompt_file_contents[prompt_file_name])
+
+        # Get the absolute path of the subdirectory
+        prompt_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "prompt_templates")
+        json_file_path = os.path.join(prompt_path, f"{prompt_file_name}.json")
+
+        # Open the JSON file and read its content
+        with open(json_file_path, encoding="utf-8") as json_file:
+            content = json.load(json_file)
+
+        # Store the content of the prompt file
+        prompt_file_contents[prompt_file_name] = content
+
+        return cast(dict, content)
+
+    def _prompt_file_name(self, app_mode: AppMode, provider: str, model: str) -> str:
+        # baichuan
+        is_baichuan = False
+        if provider == "baichuan":
+            is_baichuan = True
+        else:
+            baichuan_supported_providers = ["huggingface_hub", "openllm", "xinference"]
+            if provider in baichuan_supported_providers and "baichuan" in model.lower():
+                is_baichuan = True
+
+        if is_baichuan:
+            if app_mode == AppMode.COMPLETION:
+                return "baichuan_completion"
+            else:
+                return "baichuan_chat"
+
+        # common
+        if app_mode == AppMode.COMPLETION:
+            return "common_completion"
+        else:
+            return "common_chat"
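Illustrative resolution of the rule files above (sketch only; AppMode.CHAT is assumed to be the chat app mode): for a chat app on a non-Baichuan model, _prompt_file_name() returns "common_chat", so _get_prompt_rule() loads prompt_templates/common_chat.json and caches it in prompt_file_contents.

    transform = SimplePromptTransform()
    config = transform.get_prompt_template(
        app_mode=AppMode.CHAT,
        provider="openai",
        model="gpt-4o",
        pre_prompt="You answer questions about {{product}}.",
        has_context=True,
        query_in_prompt=True,
    )
    # config["custom_variable_keys"]  -> ["product"]
    # config["special_variable_keys"] -> ["#context#", "#query#"]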
api/core/prompt/utils/__init__.py
ADDED
File without changes