CatPtain committed
Commit bcc0d8a · verified · 1 Parent(s): 93dd3cc

Upload 1150 files

This view is limited to 50 files because it contains too many changes. See the raw diff for the full change set.
Files changed (50):
  1. .gitattributes +6 -0
  2. api/core/__init__.py +1 -0
  3. api/core/hosting_configuration.py +255 -0
  4. api/core/indexing_runner.py +754 -0
  5. api/core/model_manager.py +559 -0
  6. api/core/moderation/__init__.py +0 -0
  7. api/core/moderation/api/__builtin__ +1 -0
  8. api/core/moderation/api/__init__.py +0 -0
  9. api/core/moderation/api/api.py +96 -0
  10. api/core/moderation/base.py +115 -0
  11. api/core/moderation/factory.py +49 -0
  12. api/core/moderation/input_moderation.py +71 -0
  13. api/core/moderation/keywords/__builtin__ +1 -0
  14. api/core/moderation/keywords/__init__.py +0 -0
  15. api/core/moderation/keywords/keywords.py +73 -0
  16. api/core/moderation/openai_moderation/__builtin__ +1 -0
  17. api/core/moderation/openai_moderation/__init__.py +0 -0
  18. api/core/moderation/openai_moderation/openai_moderation.py +60 -0
  19. api/core/moderation/output_moderation.py +131 -0
  20. api/core/ops/__init__.py +0 -0
  21. api/core/ops/base_trace_instance.py +26 -0
  22. api/core/ops/entities/__init__.py +0 -0
  23. api/core/ops/entities/config_entity.py +92 -0
  24. api/core/ops/entities/trace_entity.py +134 -0
  25. api/core/ops/langfuse_trace/__init__.py +0 -0
  26. api/core/ops/langfuse_trace/entities/__init__.py +0 -0
  27. api/core/ops/langfuse_trace/entities/langfuse_trace_entity.py +282 -0
  28. api/core/ops/langfuse_trace/langfuse_trace.py +455 -0
  29. api/core/ops/langsmith_trace/__init__.py +0 -0
  30. api/core/ops/langsmith_trace/entities/__init__.py +0 -0
  31. api/core/ops/langsmith_trace/entities/langsmith_trace_entity.py +141 -0
  32. api/core/ops/langsmith_trace/langsmith_trace.py +524 -0
  33. api/core/ops/opik_trace/__init__.py +0 -0
  34. api/core/ops/opik_trace/opik_trace.py +469 -0
  35. api/core/ops/ops_trace_manager.py +811 -0
  36. api/core/ops/utils.py +62 -0
  37. api/core/prompt/__init__.py +0 -0
  38. api/core/prompt/advanced_prompt_transform.py +287 -0
  39. api/core/prompt/agent_history_prompt_transform.py +80 -0
  40. api/core/prompt/entities/__init__.py +0 -0
  41. api/core/prompt/entities/advanced_prompt_entities.py +50 -0
  42. api/core/prompt/prompt_templates/__init__.py +0 -0
  43. api/core/prompt/prompt_templates/advanced_prompt_templates.py +45 -0
  44. api/core/prompt/prompt_templates/baichuan_chat.json +13 -0
  45. api/core/prompt/prompt_templates/baichuan_completion.json +9 -0
  46. api/core/prompt/prompt_templates/common_chat.json +13 -0
  47. api/core/prompt/prompt_templates/common_completion.json +9 -0
  48. api/core/prompt/prompt_transform.py +90 -0
  49. api/core/prompt/simple_prompt_transform.py +327 -0
  50. api/core/prompt/utils/__init__.py +0 -0
.gitattributes CHANGED
@@ -6,3 +6,9 @@
 
 *.sh text eol=lf
 api/tests/integration_tests/model_runtime/assets/audio.mp3 filter=lfs diff=lfs merge=lfs -text
+api/core/tools/docs/images/index/image-1.png filter=lfs diff=lfs merge=lfs -text
+api/core/tools/docs/images/index/image-2.png filter=lfs diff=lfs merge=lfs -text
+api/core/tools/docs/images/index/image.png filter=lfs diff=lfs merge=lfs -text
+api/core/tools/provider/builtin/comfyui/_assets/icon.png filter=lfs diff=lfs merge=lfs -text
+api/core/tools/provider/builtin/dalle/_assets/icon.png filter=lfs diff=lfs merge=lfs -text
+api/core/tools/provider/builtin/wecom/_assets/icon.png filter=lfs diff=lfs merge=lfs -text
api/core/__init__.py ADDED
@@ -0,0 +1 @@
+import core.moderation.base
api/core/hosting_configuration.py ADDED
@@ -0,0 +1,255 @@
+from typing import Optional
+
+from flask import Flask
+from pydantic import BaseModel
+
+from configs import dify_config
+from core.entities.provider_entities import QuotaUnit, RestrictModel
+from core.model_runtime.entities.model_entities import ModelType
+from models.provider import ProviderQuotaType
+
+
+class HostingQuota(BaseModel):
+    quota_type: ProviderQuotaType
+    restrict_models: list[RestrictModel] = []
+
+
+class TrialHostingQuota(HostingQuota):
+    quota_type: ProviderQuotaType = ProviderQuotaType.TRIAL
+    quota_limit: int = 0
+    """Quota limit for the hosting provider models. -1 means unlimited."""
+
+
+class PaidHostingQuota(HostingQuota):
+    quota_type: ProviderQuotaType = ProviderQuotaType.PAID
+
+
+class FreeHostingQuota(HostingQuota):
+    quota_type: ProviderQuotaType = ProviderQuotaType.FREE
+
+
+class HostingProvider(BaseModel):
+    enabled: bool = False
+    credentials: Optional[dict] = None
+    quota_unit: Optional[QuotaUnit] = None
+    quotas: list[HostingQuota] = []
+
+
+class HostedModerationConfig(BaseModel):
+    enabled: bool = False
+    providers: list[str] = []
+
+
+class HostingConfiguration:
+    provider_map: dict[str, HostingProvider] = {}
+    moderation_config: Optional[HostedModerationConfig] = None
+
+    def init_app(self, app: Flask) -> None:
+        if dify_config.EDITION != "CLOUD":
+            return
+
+        self.provider_map["azure_openai"] = self.init_azure_openai()
+        self.provider_map["openai"] = self.init_openai()
+        self.provider_map["anthropic"] = self.init_anthropic()
+        self.provider_map["minimax"] = self.init_minimax()
+        self.provider_map["spark"] = self.init_spark()
+        self.provider_map["zhipuai"] = self.init_zhipuai()
+
+        self.moderation_config = self.init_moderation_config()
+
+    @staticmethod
+    def init_azure_openai() -> HostingProvider:
+        quota_unit = QuotaUnit.TIMES
+        if dify_config.HOSTED_AZURE_OPENAI_ENABLED:
+            credentials = {
+                "openai_api_key": dify_config.HOSTED_AZURE_OPENAI_API_KEY,
+                "openai_api_base": dify_config.HOSTED_AZURE_OPENAI_API_BASE,
+                "base_model_name": "gpt-35-turbo",
+            }
+
+            quotas: list[HostingQuota] = []
+            hosted_quota_limit = dify_config.HOSTED_AZURE_OPENAI_QUOTA_LIMIT
+            trial_quota = TrialHostingQuota(
+                quota_limit=hosted_quota_limit,
+                restrict_models=[
+                    RestrictModel(model="gpt-4", base_model_name="gpt-4", model_type=ModelType.LLM),
+                    RestrictModel(model="gpt-4o", base_model_name="gpt-4o", model_type=ModelType.LLM),
+                    RestrictModel(model="gpt-4o-mini", base_model_name="gpt-4o-mini", model_type=ModelType.LLM),
+                    RestrictModel(model="gpt-4-32k", base_model_name="gpt-4-32k", model_type=ModelType.LLM),
+                    RestrictModel(
+                        model="gpt-4-1106-preview", base_model_name="gpt-4-1106-preview", model_type=ModelType.LLM
+                    ),
+                    RestrictModel(
+                        model="gpt-4-vision-preview", base_model_name="gpt-4-vision-preview", model_type=ModelType.LLM
+                    ),
+                    RestrictModel(model="gpt-35-turbo", base_model_name="gpt-35-turbo", model_type=ModelType.LLM),
+                    RestrictModel(
+                        model="gpt-35-turbo-1106", base_model_name="gpt-35-turbo-1106", model_type=ModelType.LLM
+                    ),
+                    RestrictModel(
+                        model="gpt-35-turbo-instruct", base_model_name="gpt-35-turbo-instruct", model_type=ModelType.LLM
+                    ),
+                    RestrictModel(
+                        model="gpt-35-turbo-16k", base_model_name="gpt-35-turbo-16k", model_type=ModelType.LLM
+                    ),
+                    RestrictModel(
+                        model="text-davinci-003", base_model_name="text-davinci-003", model_type=ModelType.LLM
+                    ),
+                    RestrictModel(
+                        model="text-embedding-ada-002",
+                        base_model_name="text-embedding-ada-002",
+                        model_type=ModelType.TEXT_EMBEDDING,
+                    ),
+                    RestrictModel(
+                        model="text-embedding-3-small",
+                        base_model_name="text-embedding-3-small",
+                        model_type=ModelType.TEXT_EMBEDDING,
+                    ),
+                    RestrictModel(
+                        model="text-embedding-3-large",
+                        base_model_name="text-embedding-3-large",
+                        model_type=ModelType.TEXT_EMBEDDING,
+                    ),
+                ],
+            )
+            quotas.append(trial_quota)
+
+            return HostingProvider(enabled=True, credentials=credentials, quota_unit=quota_unit, quotas=quotas)
+
+        return HostingProvider(
+            enabled=False,
+            quota_unit=quota_unit,
+        )
+
+    def init_openai(self) -> HostingProvider:
+        quota_unit = QuotaUnit.CREDITS
+        quotas: list[HostingQuota] = []
+
+        if dify_config.HOSTED_OPENAI_TRIAL_ENABLED:
+            hosted_quota_limit = dify_config.HOSTED_OPENAI_QUOTA_LIMIT
+            trial_models = self.parse_restrict_models_from_env("HOSTED_OPENAI_TRIAL_MODELS")
+            trial_quota = TrialHostingQuota(quota_limit=hosted_quota_limit, restrict_models=trial_models)
+            quotas.append(trial_quota)
+
+        if dify_config.HOSTED_OPENAI_PAID_ENABLED:
+            paid_models = self.parse_restrict_models_from_env("HOSTED_OPENAI_PAID_MODELS")
+            paid_quota = PaidHostingQuota(restrict_models=paid_models)
+            quotas.append(paid_quota)
+
+        if len(quotas) > 0:
+            credentials = {
+                "openai_api_key": dify_config.HOSTED_OPENAI_API_KEY,
+            }
+
+            if dify_config.HOSTED_OPENAI_API_BASE:
+                credentials["openai_api_base"] = dify_config.HOSTED_OPENAI_API_BASE
+
+            if dify_config.HOSTED_OPENAI_API_ORGANIZATION:
+                credentials["openai_organization"] = dify_config.HOSTED_OPENAI_API_ORGANIZATION
+
+            return HostingProvider(enabled=True, credentials=credentials, quota_unit=quota_unit, quotas=quotas)
+
+        return HostingProvider(
+            enabled=False,
+            quota_unit=quota_unit,
+        )
+
+    @staticmethod
+    def init_anthropic() -> HostingProvider:
+        quota_unit = QuotaUnit.TOKENS
+        quotas: list[HostingQuota] = []
+
+        if dify_config.HOSTED_ANTHROPIC_TRIAL_ENABLED:
+            hosted_quota_limit = dify_config.HOSTED_ANTHROPIC_QUOTA_LIMIT
+            trial_quota = TrialHostingQuota(quota_limit=hosted_quota_limit)
+            quotas.append(trial_quota)
+
+        if dify_config.HOSTED_ANTHROPIC_PAID_ENABLED:
+            paid_quota = PaidHostingQuota()
+            quotas.append(paid_quota)
+
+        if len(quotas) > 0:
+            credentials = {
+                "anthropic_api_key": dify_config.HOSTED_ANTHROPIC_API_KEY,
+            }
+
+            if dify_config.HOSTED_ANTHROPIC_API_BASE:
+                credentials["anthropic_api_url"] = dify_config.HOSTED_ANTHROPIC_API_BASE
+
+            return HostingProvider(enabled=True, credentials=credentials, quota_unit=quota_unit, quotas=quotas)
+
+        return HostingProvider(
+            enabled=False,
+            quota_unit=quota_unit,
+        )
+
+    @staticmethod
+    def init_minimax() -> HostingProvider:
+        quota_unit = QuotaUnit.TOKENS
+        if dify_config.HOSTED_MINIMAX_ENABLED:
+            quotas: list[HostingQuota] = [FreeHostingQuota()]
+
+            return HostingProvider(
+                enabled=True,
+                credentials=None,  # use credentials from the provider
+                quota_unit=quota_unit,
+                quotas=quotas,
+            )
+
+        return HostingProvider(
+            enabled=False,
+            quota_unit=quota_unit,
+        )
+
+    @staticmethod
+    def init_spark() -> HostingProvider:
+        quota_unit = QuotaUnit.TOKENS
+        if dify_config.HOSTED_SPARK_ENABLED:
+            quotas: list[HostingQuota] = [FreeHostingQuota()]
+
+            return HostingProvider(
+                enabled=True,
+                credentials=None,  # use credentials from the provider
+                quota_unit=quota_unit,
+                quotas=quotas,
+            )
+
+        return HostingProvider(
+            enabled=False,
+            quota_unit=quota_unit,
+        )
+
+    @staticmethod
+    def init_zhipuai() -> HostingProvider:
+        quota_unit = QuotaUnit.TOKENS
+        if dify_config.HOSTED_ZHIPUAI_ENABLED:
+            quotas: list[HostingQuota] = [FreeHostingQuota()]
+
+            return HostingProvider(
+                enabled=True,
+                credentials=None,  # use credentials from the provider
+                quota_unit=quota_unit,
+                quotas=quotas,
+            )
+
+        return HostingProvider(
+            enabled=False,
+            quota_unit=quota_unit,
+        )
+
+    @staticmethod
+    def init_moderation_config() -> HostedModerationConfig:
+        if dify_config.HOSTED_MODERATION_ENABLED and dify_config.HOSTED_MODERATION_PROVIDERS:
+            return HostedModerationConfig(enabled=True, providers=dify_config.HOSTED_MODERATION_PROVIDERS.split(","))
+
+        return HostedModerationConfig(enabled=False)
+
+    @staticmethod
+    def parse_restrict_models_from_env(env_var: str) -> list[RestrictModel]:
+        models_str = dify_config.model_dump().get(env_var)
+        models_list = models_str.split(",") if models_str else []
+        return [
+            RestrictModel(model=model_name.strip(), model_type=ModelType.LLM)
+            for model_name in models_list
+            if model_name.strip()
+        ]
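
Note on the hunk above: `parse_restrict_models_from_env` decides which models a trial or paid quota exposes. As a quick illustration of its comma-splitting behavior, here is a minimal standalone sketch; `StubRestrictModel` and the literal value are hypothetical stand-ins for Dify's `RestrictModel` and `dify_config`, not the project's actual API.

from dataclasses import dataclass


@dataclass
class StubRestrictModel:
    # Hypothetical stand-in for core.entities.provider_entities.RestrictModel
    model: str
    model_type: str = "llm"


def parse_restrict_models(models_str: str | None) -> list[StubRestrictModel]:
    # Same logic as HostingConfiguration.parse_restrict_models_from_env:
    # split on commas, strip whitespace, drop empty fragments.
    models_list = models_str.split(",") if models_str else []
    return [StubRestrictModel(model=name.strip()) for name in models_list if name.strip()]


# Example: trailing commas and padding are tolerated.
assert [m.model for m in parse_restrict_models("gpt-4o-mini, gpt-4,,  ")] == ["gpt-4o-mini", "gpt-4"]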
api/core/indexing_runner.py ADDED
@@ -0,0 +1,754 @@
+import concurrent.futures
+import datetime
+import json
+import logging
+import re
+import threading
+import time
+import uuid
+from typing import Any, Optional, cast
+
+from flask import current_app
+from flask_login import current_user  # type: ignore
+from sqlalchemy.orm.exc import ObjectDeletedError
+
+from configs import dify_config
+from core.entities.knowledge_entities import IndexingEstimate, PreviewDetail, QAPreviewDetail
+from core.errors.error import ProviderTokenNotInitError
+from core.model_manager import ModelInstance, ModelManager
+from core.model_runtime.entities.model_entities import ModelType
+from core.rag.cleaner.clean_processor import CleanProcessor
+from core.rag.datasource.keyword.keyword_factory import Keyword
+from core.rag.docstore.dataset_docstore import DatasetDocumentStore
+from core.rag.extractor.entity.extract_setting import ExtractSetting
+from core.rag.index_processor.constant.index_type import IndexType
+from core.rag.index_processor.index_processor_base import BaseIndexProcessor
+from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
+from core.rag.models.document import ChildDocument, Document
+from core.rag.splitter.fixed_text_splitter import (
+    EnhanceRecursiveCharacterTextSplitter,
+    FixedRecursiveCharacterTextSplitter,
+)
+from core.rag.splitter.text_splitter import TextSplitter
+from core.tools.utils.web_reader_tool import get_image_upload_file_ids
+from extensions.ext_database import db
+from extensions.ext_redis import redis_client
+from extensions.ext_storage import storage
+from libs import helper
+from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegment
+from models.dataset import Document as DatasetDocument
+from models.model import UploadFile
+from services.feature_service import FeatureService
+
+
+class IndexingRunner:
+    def __init__(self):
+        self.storage = storage
+        self.model_manager = ModelManager()
+
+    def run(self, dataset_documents: list[DatasetDocument]):
+        """Run the indexing process."""
+        for dataset_document in dataset_documents:
+            try:
+                # get dataset
+                dataset = Dataset.query.filter_by(id=dataset_document.dataset_id).first()
+
+                if not dataset:
+                    raise ValueError("no dataset found")
+
+                # get the process rule
+                processing_rule = (
+                    db.session.query(DatasetProcessRule)
+                    .filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id)
+                    .first()
+                )
+                if not processing_rule:
+                    raise ValueError("no process rule found")
+                index_type = dataset_document.doc_form
+                index_processor = IndexProcessorFactory(index_type).init_index_processor()
+                # extract
+                text_docs = self._extract(index_processor, dataset_document, processing_rule.to_dict())
+
+                # transform
+                documents = self._transform(
+                    index_processor, dataset, text_docs, dataset_document.doc_language, processing_rule.to_dict()
+                )
+                # save segment
+                self._load_segments(dataset, dataset_document, documents)
+
+                # load
+                self._load(
+                    index_processor=index_processor,
+                    dataset=dataset,
+                    dataset_document=dataset_document,
+                    documents=documents,
+                )
+            except DocumentIsPausedError:
+                raise DocumentIsPausedError("Document paused, document id: {}".format(dataset_document.id))
+            except ProviderTokenNotInitError as e:
+                dataset_document.indexing_status = "error"
+                dataset_document.error = str(e.description)
+                dataset_document.stopped_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
+                db.session.commit()
+            except ObjectDeletedError:
+                logging.warning("Document deleted, document id: {}".format(dataset_document.id))
+            except Exception as e:
+                logging.exception("consume document failed")
+                dataset_document.indexing_status = "error"
+                dataset_document.error = str(e)
+                dataset_document.stopped_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
+                db.session.commit()
+
+    def run_in_splitting_status(self, dataset_document: DatasetDocument):
+        """Run the indexing process when the index_status is splitting."""
+        try:
+            # get dataset
+            dataset = Dataset.query.filter_by(id=dataset_document.dataset_id).first()
+
+            if not dataset:
+                raise ValueError("no dataset found")
+
+            # get exist document_segment list and delete
+            document_segments = DocumentSegment.query.filter_by(
+                dataset_id=dataset.id, document_id=dataset_document.id
+            ).all()
+
+            for document_segment in document_segments:
+                db.session.delete(document_segment)
+                if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
+                    # delete child chunks
+                    db.session.query(ChildChunk).filter(ChildChunk.segment_id == document_segment.id).delete()
+            db.session.commit()
+            # get the process rule
+            processing_rule = (
+                db.session.query(DatasetProcessRule)
+                .filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id)
+                .first()
+            )
+            if not processing_rule:
+                raise ValueError("no process rule found")
+
+            index_type = dataset_document.doc_form
+            index_processor = IndexProcessorFactory(index_type).init_index_processor()
+            # extract
+            text_docs = self._extract(index_processor, dataset_document, processing_rule.to_dict())
+
+            # transform
+            documents = self._transform(
+                index_processor, dataset, text_docs, dataset_document.doc_language, processing_rule.to_dict()
+            )
+            # save segment
+            self._load_segments(dataset, dataset_document, documents)
+
+            # load
+            self._load(
+                index_processor=index_processor, dataset=dataset, dataset_document=dataset_document, documents=documents
+            )
+        except DocumentIsPausedError:
+            raise DocumentIsPausedError("Document paused, document id: {}".format(dataset_document.id))
+        except ProviderTokenNotInitError as e:
+            dataset_document.indexing_status = "error"
+            dataset_document.error = str(e.description)
+            dataset_document.stopped_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
+            db.session.commit()
+        except Exception as e:
+            logging.exception("consume document failed")
+            dataset_document.indexing_status = "error"
+            dataset_document.error = str(e)
+            dataset_document.stopped_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
+            db.session.commit()
+
+    def run_in_indexing_status(self, dataset_document: DatasetDocument):
+        """Run the indexing process when the index_status is indexing."""
+        try:
+            # get dataset
+            dataset = Dataset.query.filter_by(id=dataset_document.dataset_id).first()
+
+            if not dataset:
+                raise ValueError("no dataset found")
+
+            # get exist document_segment list and delete
+            document_segments = DocumentSegment.query.filter_by(
+                dataset_id=dataset.id, document_id=dataset_document.id
+            ).all()
+
+            documents = []
+            if document_segments:
+                for document_segment in document_segments:
+                    # transform segment to node
+                    if document_segment.status != "completed":
+                        document = Document(
+                            page_content=document_segment.content,
+                            metadata={
+                                "doc_id": document_segment.index_node_id,
+                                "doc_hash": document_segment.index_node_hash,
+                                "document_id": document_segment.document_id,
+                                "dataset_id": document_segment.dataset_id,
+                            },
+                        )
+                        if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
+                            child_chunks = document_segment.child_chunks
+                            if child_chunks:
+                                child_documents = []
+                                for child_chunk in child_chunks:
+                                    child_document = ChildDocument(
+                                        page_content=child_chunk.content,
+                                        metadata={
+                                            "doc_id": child_chunk.index_node_id,
+                                            "doc_hash": child_chunk.index_node_hash,
+                                            "document_id": document_segment.document_id,
+                                            "dataset_id": document_segment.dataset_id,
+                                        },
+                                    )
+                                    child_documents.append(child_document)
+                                document.children = child_documents
+                        documents.append(document)
+
+            # build index
+            # get the process rule
+            processing_rule = (
+                db.session.query(DatasetProcessRule)
+                .filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id)
+                .first()
+            )
+
+            index_type = dataset_document.doc_form
+            index_processor = IndexProcessorFactory(index_type).init_index_processor()
+            self._load(
+                index_processor=index_processor, dataset=dataset, dataset_document=dataset_document, documents=documents
+            )
+        except DocumentIsPausedError:
+            raise DocumentIsPausedError("Document paused, document id: {}".format(dataset_document.id))
+        except ProviderTokenNotInitError as e:
+            dataset_document.indexing_status = "error"
+            dataset_document.error = str(e.description)
+            dataset_document.stopped_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
+            db.session.commit()
+        except Exception as e:
+            logging.exception("consume document failed")
+            dataset_document.indexing_status = "error"
+            dataset_document.error = str(e)
+            dataset_document.stopped_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
+            db.session.commit()
+
+    def indexing_estimate(
+        self,
+        tenant_id: str,
+        extract_settings: list[ExtractSetting],
+        tmp_processing_rule: dict,
+        doc_form: Optional[str] = None,
+        doc_language: str = "English",
+        dataset_id: Optional[str] = None,
+        indexing_technique: str = "economy",
+    ) -> IndexingEstimate:
+        """
+        Estimate the indexing for the document.
+        """
+        # check document limit
+        features = FeatureService.get_features(tenant_id)
+        if features.billing.enabled:
+            count = len(extract_settings)
+            batch_upload_limit = dify_config.BATCH_UPLOAD_LIMIT
+            if count > batch_upload_limit:
+                raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")
+
+        embedding_model_instance = None
+        if dataset_id:
+            dataset = Dataset.query.filter_by(id=dataset_id).first()
+            if not dataset:
+                raise ValueError("Dataset not found.")
+            if dataset.indexing_technique == "high_quality" or indexing_technique == "high_quality":
+                if dataset.embedding_model_provider:
+                    embedding_model_instance = self.model_manager.get_model_instance(
+                        tenant_id=tenant_id,
+                        provider=dataset.embedding_model_provider,
+                        model_type=ModelType.TEXT_EMBEDDING,
+                        model=dataset.embedding_model,
+                    )
+                else:
+                    embedding_model_instance = self.model_manager.get_default_model_instance(
+                        tenant_id=tenant_id,
+                        model_type=ModelType.TEXT_EMBEDDING,
+                    )
+        else:
+            if indexing_technique == "high_quality":
+                embedding_model_instance = self.model_manager.get_default_model_instance(
+                    tenant_id=tenant_id,
+                    model_type=ModelType.TEXT_EMBEDDING,
+                )
+        preview_texts = []  # type: ignore
+
+        total_segments = 0
+        index_type = doc_form
+        index_processor = IndexProcessorFactory(index_type).init_index_processor()
+        for extract_setting in extract_settings:
+            # extract
+            processing_rule = DatasetProcessRule(
+                mode=tmp_processing_rule["mode"], rules=json.dumps(tmp_processing_rule["rules"])
+            )
+            text_docs = index_processor.extract(extract_setting, process_rule_mode=tmp_processing_rule["mode"])
+            documents = index_processor.transform(
+                text_docs,
+                embedding_model_instance=embedding_model_instance,
+                process_rule=processing_rule.to_dict(),
+                tenant_id=current_user.current_tenant_id,
+                doc_language=doc_language,
+                preview=True,
+            )
+            total_segments += len(documents)
+            for document in documents:
+                if len(preview_texts) < 10:
+                    if doc_form and doc_form == "qa_model":
+                        preview_detail = QAPreviewDetail(
+                            question=document.page_content, answer=document.metadata.get("answer") or ""
+                        )
+                        preview_texts.append(preview_detail)
+                    else:
+                        preview_detail = PreviewDetail(content=document.page_content)  # type: ignore
+                        if document.children:
+                            preview_detail.child_chunks = [child.page_content for child in document.children]  # type: ignore
+                        preview_texts.append(preview_detail)
+
+                # delete image files and related db records
+                image_upload_file_ids = get_image_upload_file_ids(document.page_content)
+                for upload_file_id in image_upload_file_ids:
+                    image_file = db.session.query(UploadFile).filter(UploadFile.id == upload_file_id).first()
+                    try:
+                        if image_file:
+                            storage.delete(image_file.key)
+                    except Exception:
+                        logging.exception(
+                            "Delete image_files failed while indexing_estimate, \
+                            image_upload_file_is: {}".format(upload_file_id)
+                        )
+                    db.session.delete(image_file)
+
+        if doc_form and doc_form == "qa_model":
+            return IndexingEstimate(total_segments=total_segments * 20, qa_preview=preview_texts, preview=[])
+        return IndexingEstimate(total_segments=total_segments, preview=preview_texts)  # type: ignore
+
+    def _extract(
+        self, index_processor: BaseIndexProcessor, dataset_document: DatasetDocument, process_rule: dict
+    ) -> list[Document]:
+        # load file
+        if dataset_document.data_source_type not in {"upload_file", "notion_import", "website_crawl"}:
+            return []
+
+        data_source_info = dataset_document.data_source_info_dict
+        text_docs = []
+        if dataset_document.data_source_type == "upload_file":
+            if not data_source_info or "upload_file_id" not in data_source_info:
+                raise ValueError("no upload file found")
+
+            file_detail = (
+                db.session.query(UploadFile).filter(UploadFile.id == data_source_info["upload_file_id"]).one_or_none()
+            )
+
+            if file_detail:
+                extract_setting = ExtractSetting(
+                    datasource_type="upload_file", upload_file=file_detail, document_model=dataset_document.doc_form
+                )
+                text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"])
+        elif dataset_document.data_source_type == "notion_import":
+            if (
+                not data_source_info
+                or "notion_workspace_id" not in data_source_info
+                or "notion_page_id" not in data_source_info
+            ):
+                raise ValueError("no notion import info found")
+            extract_setting = ExtractSetting(
+                datasource_type="notion_import",
+                notion_info={
+                    "notion_workspace_id": data_source_info["notion_workspace_id"],
+                    "notion_obj_id": data_source_info["notion_page_id"],
+                    "notion_page_type": data_source_info["type"],
+                    "document": dataset_document,
+                    "tenant_id": dataset_document.tenant_id,
+                },
+                document_model=dataset_document.doc_form,
+            )
+            text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"])
+        elif dataset_document.data_source_type == "website_crawl":
+            if (
+                not data_source_info
+                or "provider" not in data_source_info
+                or "url" not in data_source_info
+                or "job_id" not in data_source_info
+            ):
+                raise ValueError("no website import info found")
+            extract_setting = ExtractSetting(
+                datasource_type="website_crawl",
+                website_info={
+                    "provider": data_source_info["provider"],
+                    "job_id": data_source_info["job_id"],
+                    "tenant_id": dataset_document.tenant_id,
+                    "url": data_source_info["url"],
+                    "mode": data_source_info["mode"],
+                    "only_main_content": data_source_info["only_main_content"],
+                },
+                document_model=dataset_document.doc_form,
+            )
+            text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"])
+        # update document status to splitting
+        self._update_document_index_status(
+            document_id=dataset_document.id,
+            after_indexing_status="splitting",
+            extra_update_params={
+                DatasetDocument.word_count: sum(len(text_doc.page_content) for text_doc in text_docs),
+                DatasetDocument.parsing_completed_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
+            },
+        )
+
+        # replace doc id to document model id
+        text_docs = cast(list[Document], text_docs)
+        for text_doc in text_docs:
+            if text_doc.metadata is not None:
+                text_doc.metadata["document_id"] = dataset_document.id
+                text_doc.metadata["dataset_id"] = dataset_document.dataset_id
+
+        return text_docs
+
+    @staticmethod
+    def filter_string(text):
+        text = re.sub(r"<\|", "<", text)
+        text = re.sub(r"\|>", ">", text)
+        text = re.sub(r"[\x00-\x08\x0B\x0C\x0E-\x1F\x7F\xEF\xBF\xBE]", "", text)
+        # Unicode U+FFFE
+        text = re.sub("\ufffe", "", text)
+        return text
+
+    @staticmethod
+    def _get_splitter(
+        processing_rule_mode: str,
+        max_tokens: int,
+        chunk_overlap: int,
+        separator: str,
+        embedding_model_instance: Optional[ModelInstance],
+    ) -> TextSplitter:
+        """
+        Get the NodeParser object according to the processing rule.
+        """
+        if processing_rule_mode in ["custom", "hierarchical"]:
+            # The user-defined segmentation rule
+            max_segmentation_tokens_length = dify_config.INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH
+            if max_tokens < 50 or max_tokens > max_segmentation_tokens_length:
+                raise ValueError(f"Custom segment length should be between 50 and {max_segmentation_tokens_length}.")
+
+            if separator:
+                separator = separator.replace("\\n", "\n")
+
+            character_splitter = FixedRecursiveCharacterTextSplitter.from_encoder(
+                chunk_size=max_tokens,
+                chunk_overlap=chunk_overlap,
+                fixed_separator=separator,
+                separators=["\n\n", "。", ". ", " ", ""],
+                embedding_model_instance=embedding_model_instance,
+            )
+        else:
+            # Automatic segmentation
+            automatic_rules: dict[str, Any] = dict(DatasetProcessRule.AUTOMATIC_RULES["segmentation"])
+            character_splitter = EnhanceRecursiveCharacterTextSplitter.from_encoder(
+                chunk_size=automatic_rules["max_tokens"],
+                chunk_overlap=automatic_rules["chunk_overlap"],
+                separators=["\n\n", "。", ". ", " ", ""],
+                embedding_model_instance=embedding_model_instance,
+            )
+
+        return character_splitter  # type: ignore
+
+    def _split_to_documents_for_estimate(
+        self, text_docs: list[Document], splitter: TextSplitter, processing_rule: DatasetProcessRule
+    ) -> list[Document]:
+        """
+        Split the text documents into nodes.
+        """
+        all_documents: list[Document] = []
+        for text_doc in text_docs:
+            # document clean
+            document_text = self._document_clean(text_doc.page_content, processing_rule)
+            text_doc.page_content = document_text
+
+            # parse document to nodes
+            documents = splitter.split_documents([text_doc])
+
+            split_documents = []
+            for document in documents:
+                if document.page_content is None or not document.page_content.strip():
+                    continue
+                if document.metadata is not None:
+                    doc_id = str(uuid.uuid4())
+                    hash = helper.generate_text_hash(document.page_content)
+                    document.metadata["doc_id"] = doc_id
+                    document.metadata["doc_hash"] = hash
+
+                split_documents.append(document)
+
+            all_documents.extend(split_documents)
+
+        return all_documents
+
+    @staticmethod
+    def _document_clean(text: str, processing_rule: DatasetProcessRule) -> str:
+        """
+        Clean the document text according to the processing rules.
+        """
+        if processing_rule.mode == "automatic":
+            rules = DatasetProcessRule.AUTOMATIC_RULES
+        else:
+            rules = json.loads(processing_rule.rules) if processing_rule.rules else {}
+        document_text = CleanProcessor.clean(text, {"rules": rules})
+
+        return document_text
+
+    @staticmethod
+    def format_split_text(text: str) -> list[QAPreviewDetail]:
+        regex = r"Q\d+:\s*(.*?)\s*A\d+:\s*([\s\S]*?)(?=Q\d+:|$)"
+        matches = re.findall(regex, text, re.UNICODE)
+
+        return [QAPreviewDetail(question=q, answer=re.sub(r"\n\s*", "\n", a.strip())) for q, a in matches if q and a]
+
+    def _load(
+        self,
+        index_processor: BaseIndexProcessor,
+        dataset: Dataset,
+        dataset_document: DatasetDocument,
+        documents: list[Document],
+    ) -> None:
+        """
+        insert index and update document/segment status to completed
+        """
+
+        embedding_model_instance = None
+        if dataset.indexing_technique == "high_quality":
+            embedding_model_instance = self.model_manager.get_model_instance(
+                tenant_id=dataset.tenant_id,
+                provider=dataset.embedding_model_provider,
+                model_type=ModelType.TEXT_EMBEDDING,
+                model=dataset.embedding_model,
+            )
+
+        # chunk nodes by chunk size
+        indexing_start_at = time.perf_counter()
+        tokens = 0
+        if dataset_document.doc_form != IndexType.PARENT_CHILD_INDEX:
+            # create keyword index
+            create_keyword_thread = threading.Thread(
+                target=self._process_keyword_index,
+                args=(current_app._get_current_object(), dataset.id, dataset_document.id, documents),  # type: ignore
+            )
+            create_keyword_thread.start()
+
+        max_workers = 10
+        if dataset.indexing_technique == "high_quality":
+            with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
+                futures = []
+
+                # Distribute documents into multiple groups based on the hash values of page_content
+                # This is done to prevent multiple threads from processing the same document,
+                # Thereby avoiding potential database insertion deadlocks
+                document_groups: list[list[Document]] = [[] for _ in range(max_workers)]
+                for document in documents:
+                    hash = helper.generate_text_hash(document.page_content)
+                    group_index = int(hash, 16) % max_workers
+                    document_groups[group_index].append(document)
+                for chunk_documents in document_groups:
+                    if len(chunk_documents) == 0:
+                        continue
+                    futures.append(
+                        executor.submit(
+                            self._process_chunk,
+                            current_app._get_current_object(),  # type: ignore
+                            index_processor,
+                            chunk_documents,
+                            dataset,
+                            dataset_document,
+                            embedding_model_instance,
+                        )
+                    )
+
+                for future in futures:
+                    tokens += future.result()
+        if dataset_document.doc_form != IndexType.PARENT_CHILD_INDEX:
+            create_keyword_thread.join()
+        indexing_end_at = time.perf_counter()
+
+        # update document status to completed
+        self._update_document_index_status(
+            document_id=dataset_document.id,
+            after_indexing_status="completed",
+            extra_update_params={
+                DatasetDocument.tokens: tokens,
+                DatasetDocument.completed_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
+                DatasetDocument.indexing_latency: indexing_end_at - indexing_start_at,
+                DatasetDocument.error: None,
+            },
+        )
+
+    @staticmethod
+    def _process_keyword_index(flask_app, dataset_id, document_id, documents):
+        with flask_app.app_context():
+            dataset = Dataset.query.filter_by(id=dataset_id).first()
+            if not dataset:
+                raise ValueError("no dataset found")
+            keyword = Keyword(dataset)
+            keyword.create(documents)
+            if dataset.indexing_technique != "high_quality":
+                document_ids = [document.metadata["doc_id"] for document in documents]
+                db.session.query(DocumentSegment).filter(
+                    DocumentSegment.document_id == document_id,
+                    DocumentSegment.dataset_id == dataset_id,
+                    DocumentSegment.index_node_id.in_(document_ids),
+                    DocumentSegment.status == "indexing",
+                ).update(
+                    {
+                        DocumentSegment.status: "completed",
+                        DocumentSegment.enabled: True,
+                        DocumentSegment.completed_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
+                    }
+                )
+
+            db.session.commit()
+
+    def _process_chunk(
+        self, flask_app, index_processor, chunk_documents, dataset, dataset_document, embedding_model_instance
+    ):
+        with flask_app.app_context():
+            # check document is paused
+            self._check_document_paused_status(dataset_document.id)
+
+            tokens = 0
+            if embedding_model_instance:
+                tokens += sum(
+                    embedding_model_instance.get_text_embedding_num_tokens([document.page_content])
+                    for document in chunk_documents
+                )
+
+            # load index
+            index_processor.load(dataset, chunk_documents, with_keywords=False)
+
+            document_ids = [document.metadata["doc_id"] for document in chunk_documents]
+            db.session.query(DocumentSegment).filter(
+                DocumentSegment.document_id == dataset_document.id,
+                DocumentSegment.dataset_id == dataset.id,
+                DocumentSegment.index_node_id.in_(document_ids),
+                DocumentSegment.status == "indexing",
+            ).update(
+                {
+                    DocumentSegment.status: "completed",
+                    DocumentSegment.enabled: True,
+                    DocumentSegment.completed_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
+                }
+            )
+
+            db.session.commit()
+
+            return tokens
+
+    @staticmethod
+    def _check_document_paused_status(document_id: str):
+        indexing_cache_key = "document_{}_is_paused".format(document_id)
+        result = redis_client.get(indexing_cache_key)
+        if result:
+            raise DocumentIsPausedError()
+
+    @staticmethod
+    def _update_document_index_status(
+        document_id: str, after_indexing_status: str, extra_update_params: Optional[dict] = None
+    ) -> None:
+        """
+        Update the document indexing status.
+        """
+        count = DatasetDocument.query.filter_by(id=document_id, is_paused=True).count()
+        if count > 0:
+            raise DocumentIsPausedError()
+        document = DatasetDocument.query.filter_by(id=document_id).first()
+        if not document:
+            raise DocumentIsDeletedPausedError()
+
+        update_params = {DatasetDocument.indexing_status: after_indexing_status}
+
+        if extra_update_params:
+            update_params.update(extra_update_params)
+
+        DatasetDocument.query.filter_by(id=document_id).update(update_params)
+        db.session.commit()
+
+    @staticmethod
+    def _update_segments_by_document(dataset_document_id: str, update_params: dict) -> None:
+        """
+        Update the document segment by document id.
+        """
+        DocumentSegment.query.filter_by(document_id=dataset_document_id).update(update_params)
+        db.session.commit()
+
+    def _transform(
+        self,
+        index_processor: BaseIndexProcessor,
+        dataset: Dataset,
+        text_docs: list[Document],
+        doc_language: str,
+        process_rule: dict,
+    ) -> list[Document]:
+        # get embedding model instance
+        embedding_model_instance = None
+        if dataset.indexing_technique == "high_quality":
+            if dataset.embedding_model_provider:
+                embedding_model_instance = self.model_manager.get_model_instance(
+                    tenant_id=dataset.tenant_id,
+                    provider=dataset.embedding_model_provider,
+                    model_type=ModelType.TEXT_EMBEDDING,
+                    model=dataset.embedding_model,
+                )
+            else:
+                embedding_model_instance = self.model_manager.get_default_model_instance(
+                    tenant_id=dataset.tenant_id,
+                    model_type=ModelType.TEXT_EMBEDDING,
+                )
+
+        documents = index_processor.transform(
+            text_docs,
+            embedding_model_instance=embedding_model_instance,
+            process_rule=process_rule,
+            tenant_id=dataset.tenant_id,
+            doc_language=doc_language,
+        )
+
+        return documents
+
+    def _load_segments(self, dataset, dataset_document, documents):
+        # save node to document segment
+        doc_store = DatasetDocumentStore(
+            dataset=dataset, user_id=dataset_document.created_by, document_id=dataset_document.id
+        )
+
+        # add document segments
+        doc_store.add_documents(docs=documents, save_child=dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX)
+
+        # update document status to indexing
+        cur_time = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
+        self._update_document_index_status(
+            document_id=dataset_document.id,
+            after_indexing_status="indexing",
+            extra_update_params={
+                DatasetDocument.cleaning_completed_at: cur_time,
+                DatasetDocument.splitting_completed_at: cur_time,
+            },
+        )
+
+        # update segment status to indexing
+        self._update_segments_by_document(
+            dataset_document_id=dataset_document.id,
+            update_params={
+                DocumentSegment.status: "indexing",
+                DocumentSegment.indexing_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
+            },
+        )
+        pass
+
+
+class DocumentIsPausedError(Exception):
+    pass
+
+
+class DocumentIsDeletedPausedError(Exception):
+    pass
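
A note on `_load` above: documents are bucketed by a hash of `page_content` before being handed to the thread pool, so identical chunks always go to the same worker and cannot race each other into a database insertion deadlock. Here is a minimal sketch of that grouping step, using `hashlib.sha256` as a stand-in for `helper.generate_text_hash` (whose actual hash function may differ):

import hashlib


def group_by_content_hash(contents: list[str], max_workers: int = 10) -> list[list[str]]:
    # Bucket chunks by content hash, as in IndexingRunner._load: the hex digest
    # is interpreted as an integer and taken modulo the worker count, so equal
    # contents deterministically land in the same group.
    groups: list[list[str]] = [[] for _ in range(max_workers)]
    for content in contents:
        digest = hashlib.sha256(content.encode("utf-8")).hexdigest()
        groups[int(digest, 16) % max_workers].append(content)
    return groups


# Example: the duplicate "alpha" chunks share one bucket, so only one worker
# ever inserts that row.
groups = group_by_content_hash(["alpha", "beta", "alpha"], max_workers=4)
assert sum(1 for g in groups if "alpha" in g) == 1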
api/core/model_manager.py ADDED
@@ -0,0 +1,559 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from collections.abc import Callable, Generator, Iterable, Sequence
3
+ from typing import IO, Any, Optional, Union, cast
4
+
5
+ from configs import dify_config
6
+ from core.entities.embedding_type import EmbeddingInputType
7
+ from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle
8
+ from core.entities.provider_entities import ModelLoadBalancingConfiguration
9
+ from core.errors.error import ProviderTokenNotInitError
10
+ from core.model_runtime.callbacks.base_callback import Callback
11
+ from core.model_runtime.entities.llm_entities import LLMResult
12
+ from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
13
+ from core.model_runtime.entities.model_entities import ModelType
14
+ from core.model_runtime.entities.rerank_entities import RerankResult
15
+ from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
16
+ from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeConnectionError, InvokeRateLimitError
17
+ from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
18
+ from core.model_runtime.model_providers.__base.moderation_model import ModerationModel
19
+ from core.model_runtime.model_providers.__base.rerank_model import RerankModel
20
+ from core.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel
21
+ from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
22
+ from core.model_runtime.model_providers.__base.tts_model import TTSModel
23
+ from core.provider_manager import ProviderManager
24
+ from extensions.ext_redis import redis_client
25
+ from models.provider import ProviderType
26
+
27
+ logger = logging.getLogger(__name__)
28
+
29
+
30
+ class ModelInstance:
31
+ """
32
+ Model instance class
33
+ """
34
+
35
+ def __init__(self, provider_model_bundle: ProviderModelBundle, model: str) -> None:
36
+ self.provider_model_bundle = provider_model_bundle
37
+ self.model = model
38
+ self.provider = provider_model_bundle.configuration.provider.provider
39
+ self.credentials = self._fetch_credentials_from_bundle(provider_model_bundle, model)
40
+ self.model_type_instance = self.provider_model_bundle.model_type_instance
41
+ self.load_balancing_manager = self._get_load_balancing_manager(
42
+ configuration=provider_model_bundle.configuration,
43
+ model_type=provider_model_bundle.model_type_instance.model_type,
44
+ model=model,
45
+ credentials=self.credentials,
46
+ )
47
+
48
+ @staticmethod
49
+ def _fetch_credentials_from_bundle(provider_model_bundle: ProviderModelBundle, model: str) -> dict:
50
+ """
51
+ Fetch credentials from provider model bundle
52
+ :param provider_model_bundle: provider model bundle
53
+ :param model: model name
54
+ :return:
55
+ """
56
+ configuration = provider_model_bundle.configuration
57
+ model_type = provider_model_bundle.model_type_instance.model_type
58
+ credentials = configuration.get_current_credentials(model_type=model_type, model=model)
59
+
60
+ if credentials is None:
61
+ raise ProviderTokenNotInitError(f"Model {model} credentials is not initialized.")
62
+
63
+ return credentials
64
+
65
+ @staticmethod
66
+ def _get_load_balancing_manager(
67
+ configuration: ProviderConfiguration, model_type: ModelType, model: str, credentials: dict
68
+ ) -> Optional["LBModelManager"]:
69
+ """
70
+ Get load balancing model credentials
71
+ :param configuration: provider configuration
72
+ :param model_type: model type
73
+ :param model: model name
74
+ :param credentials: model credentials
75
+ :return:
76
+ """
77
+ if configuration.model_settings and configuration.using_provider_type == ProviderType.CUSTOM:
78
+ current_model_setting = None
79
+ # check if model is disabled by admin
80
+ for model_setting in configuration.model_settings:
81
+ if model_setting.model_type == model_type and model_setting.model == model:
82
+ current_model_setting = model_setting
83
+ break
84
+
85
+ # check if load balancing is enabled
86
+ if current_model_setting and current_model_setting.load_balancing_configs:
87
+ # use load balancing proxy to choose credentials
88
+ lb_model_manager = LBModelManager(
89
+ tenant_id=configuration.tenant_id,
90
+ provider=configuration.provider.provider,
91
+ model_type=model_type,
92
+ model=model,
93
+ load_balancing_configs=current_model_setting.load_balancing_configs,
94
+ managed_credentials=credentials if configuration.custom_configuration.provider else None,
95
+ )
96
+
97
+ return lb_model_manager
98
+
99
+ return None
100
+
101
+ def invoke_llm(
102
+ self,
103
+ prompt_messages: Sequence[PromptMessage],
104
+ model_parameters: Optional[dict] = None,
105
+ tools: Sequence[PromptMessageTool] | None = None,
106
+ stop: Optional[Sequence[str]] = None,
107
+ stream: bool = True,
108
+ user: Optional[str] = None,
109
+ callbacks: Optional[list[Callback]] = None,
110
+ ) -> Union[LLMResult, Generator]:
111
+ """
112
+ Invoke large language model
113
+
114
+ :param prompt_messages: prompt messages
115
+ :param model_parameters: model parameters
116
+ :param tools: tools for tool calling
117
+ :param stop: stop words
118
+ :param stream: is stream response
119
+ :param user: unique user id
120
+ :param callbacks: callbacks
121
+ :return: full response or stream response chunk generator result
122
+ """
123
+ if not isinstance(self.model_type_instance, LargeLanguageModel):
124
+ raise Exception("Model type instance is not LargeLanguageModel")
125
+
126
+ self.model_type_instance = cast(LargeLanguageModel, self.model_type_instance)
127
+ return cast(
128
+ Union[LLMResult, Generator],
129
+ self._round_robin_invoke(
130
+ function=self.model_type_instance.invoke,
131
+ model=self.model,
132
+ credentials=self.credentials,
133
+ prompt_messages=prompt_messages,
134
+ model_parameters=model_parameters,
135
+ tools=tools,
136
+ stop=stop,
137
+ stream=stream,
138
+ user=user,
139
+ callbacks=callbacks,
140
+ ),
141
+ )
142
+
143
+ def get_llm_num_tokens(
144
+ self, prompt_messages: list[PromptMessage], tools: Optional[list[PromptMessageTool]] = None
145
+ ) -> int:
146
+ """
147
+ Get number of tokens for llm
148
+
149
+ :param prompt_messages: prompt messages
150
+ :param tools: tools for tool calling
151
+ :return:
152
+ """
153
+ if not isinstance(self.model_type_instance, LargeLanguageModel):
154
+ raise Exception("Model type instance is not LargeLanguageModel")
155
+
156
+ self.model_type_instance = cast(LargeLanguageModel, self.model_type_instance)
157
+ return cast(
158
+ int,
159
+ self._round_robin_invoke(
160
+ function=self.model_type_instance.get_num_tokens,
161
+ model=self.model,
162
+ credentials=self.credentials,
163
+ prompt_messages=prompt_messages,
164
+ tools=tools,
165
+ ),
166
+ )
167
+
168
+ def invoke_text_embedding(
169
+ self, texts: list[str], user: Optional[str] = None, input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT
170
+ ) -> TextEmbeddingResult:
171
+ """
172
+ Invoke large language model
173
+
174
+ :param texts: texts to embed
175
+ :param user: unique user id
176
+ :param input_type: input type
177
+ :return: embeddings result
178
+ """
179
+ if not isinstance(self.model_type_instance, TextEmbeddingModel):
180
+ raise Exception("Model type instance is not TextEmbeddingModel")
181
+
182
+ self.model_type_instance = cast(TextEmbeddingModel, self.model_type_instance)
183
+ return cast(
184
+ TextEmbeddingResult,
185
+ self._round_robin_invoke(
186
+ function=self.model_type_instance.invoke,
187
+ model=self.model,
188
+ credentials=self.credentials,
189
+ texts=texts,
190
+ user=user,
191
+ input_type=input_type,
192
+ ),
193
+ )
194
+
195
+ def get_text_embedding_num_tokens(self, texts: list[str]) -> int:
196
+ """
197
+ Get number of tokens for text embedding
198
+
199
+ :param texts: texts to embed
200
+ :return:
201
+ """
202
+ if not isinstance(self.model_type_instance, TextEmbeddingModel):
203
+ raise Exception("Model type instance is not TextEmbeddingModel")
204
+
205
+ self.model_type_instance = cast(TextEmbeddingModel, self.model_type_instance)
206
+ return cast(
207
+ int,
208
+ self._round_robin_invoke(
209
+ function=self.model_type_instance.get_num_tokens,
210
+ model=self.model,
211
+ credentials=self.credentials,
212
+ texts=texts,
213
+ ),
214
+ )
215
+
216
+ def invoke_rerank(
217
+ self,
218
+ query: str,
219
+ docs: list[str],
220
+ score_threshold: Optional[float] = None,
221
+ top_n: Optional[int] = None,
222
+ user: Optional[str] = None,
223
+ ) -> RerankResult:
224
+ """
225
+ Invoke rerank model
226
+
227
+ :param query: search query
228
+ :param docs: docs for reranking
229
+ :param score_threshold: score threshold
230
+ :param top_n: top n
231
+ :param user: unique user id
232
+ :return: rerank result
233
+ """
234
+ if not isinstance(self.model_type_instance, RerankModel):
235
+ raise Exception("Model type instance is not RerankModel")
236
+
237
+ self.model_type_instance = cast(RerankModel, self.model_type_instance)
238
+ return cast(
239
+ RerankResult,
240
+ self._round_robin_invoke(
241
+ function=self.model_type_instance.invoke,
242
+ model=self.model,
243
+ credentials=self.credentials,
244
+ query=query,
245
+ docs=docs,
246
+ score_threshold=score_threshold,
247
+ top_n=top_n,
248
+ user=user,
249
+ ),
250
+ )
251
+
252
+ def invoke_moderation(self, text: str, user: Optional[str] = None) -> bool:
253
+ """
254
+ Invoke moderation model
255
+
256
+ :param text: text to moderate
257
+ :param user: unique user id
258
+ :return: false if text is safe, true otherwise
259
+ """
260
+ if not isinstance(self.model_type_instance, ModerationModel):
261
+ raise Exception("Model type instance is not ModerationModel")
262
+
263
+ self.model_type_instance = cast(ModerationModel, self.model_type_instance)
264
+ return cast(
265
+ bool,
266
+ self._round_robin_invoke(
267
+ function=self.model_type_instance.invoke,
268
+ model=self.model,
269
+ credentials=self.credentials,
270
+ text=text,
271
+ user=user,
272
+ ),
273
+ )
274
+
275
+ def invoke_speech2text(self, file: IO[bytes], user: Optional[str] = None) -> str:
276
+ """
277
+ Invoke large language model
278
+
279
+ :param file: audio file
280
+ :param user: unique user id
281
+ :return: text for given audio file
282
+ """
283
+ if not isinstance(self.model_type_instance, Speech2TextModel):
284
+ raise Exception("Model type instance is not Speech2TextModel")
285
+
286
+ self.model_type_instance = cast(Speech2TextModel, self.model_type_instance)
287
+ return cast(
288
+ str,
289
+ self._round_robin_invoke(
290
+ function=self.model_type_instance.invoke,
291
+ model=self.model,
292
+ credentials=self.credentials,
293
+ file=file,
294
+ user=user,
295
+ ),
296
+ )
297
+
298
+ def invoke_tts(self, content_text: str, tenant_id: str, voice: str, user: Optional[str] = None) -> Iterable[bytes]:
299
+ """
300
+ Invoke large language tts model
301
+
302
+ :param content_text: text content to be translated
303
+ :param tenant_id: user tenant id
304
+ :param voice: model timbre
305
+ :param user: unique user id
306
+ :return: text for given audio file
307
+ """
308
+ if not isinstance(self.model_type_instance, TTSModel):
309
+ raise Exception("Model type instance is not TTSModel")
310
+
311
+ self.model_type_instance = cast(TTSModel, self.model_type_instance)
312
+ return cast(
313
+ Iterable[bytes],
314
+ self._round_robin_invoke(
315
+ function=self.model_type_instance.invoke,
316
+ model=self.model,
317
+ credentials=self.credentials,
318
+ content_text=content_text,
319
+ user=user,
320
+ tenant_id=tenant_id,
321
+ voice=voice,
322
+ ),
323
+ )
324
+
325
+ def _round_robin_invoke(self, function: Callable[..., Any], *args, **kwargs) -> Any:
326
+ """
327
+ Round-robin invoke
328
+ :param function: function to invoke
329
+ :param args: function args
330
+ :param kwargs: function kwargs
331
+ :return:
332
+ """
333
+ if not self.load_balancing_manager:
334
+ return function(*args, **kwargs)
335
+
336
+ last_exception: Union[InvokeRateLimitError, InvokeAuthorizationError, InvokeConnectionError, None] = None
337
+ while True:
338
+ lb_config = self.load_balancing_manager.fetch_next()
339
+ if not lb_config:
340
+ if not last_exception:
341
+ raise ProviderTokenNotInitError("Model credentials are not initialized.")
342
+ else:
343
+ raise last_exception
344
+
345
+ try:
346
+ if "credentials" in kwargs:
347
+ del kwargs["credentials"]
348
+ return function(*args, **kwargs, credentials=lb_config.credentials)
349
+ except InvokeRateLimitError as e:
350
+ # expire in 60 seconds
351
+ self.load_balancing_manager.cooldown(lb_config, expire=60)
352
+ last_exception = e
353
+ continue
354
+ except (InvokeAuthorizationError, InvokeConnectionError) as e:
355
+ # expire in 10 seconds
356
+ self.load_balancing_manager.cooldown(lb_config, expire=10)
357
+ last_exception = e
358
+ continue
359
+ except Exception as e:
360
+ raise e
361
+
362
+ def get_tts_voices(self, language: Optional[str] = None) -> list:
363
+ """
364
+ Get available voices of the TTS model
365
+
366
+ :param language: tts language
367
+ :return: tts model voices
368
+ """
369
+ if not isinstance(self.model_type_instance, TTSModel):
370
+ raise Exception("Model type instance is not TTSModel")
371
+
372
+ self.model_type_instance = cast(TTSModel, self.model_type_instance)
373
+ return self.model_type_instance.get_tts_model_voices(
374
+ model=self.model, credentials=self.credentials, language=language
375
+ )
376
+
377
+
378
+ class ModelManager:
379
+ def __init__(self) -> None:
380
+ self._provider_manager = ProviderManager()
381
+
382
+ def get_model_instance(self, tenant_id: str, provider: str, model_type: ModelType, model: str) -> ModelInstance:
383
+ """
384
+ Get model instance
385
+ :param tenant_id: tenant id
386
+ :param provider: provider name
387
+ :param model_type: model type
388
+ :param model: model name
389
+ :return:
390
+ """
391
+ if not provider:
392
+ return self.get_default_model_instance(tenant_id, model_type)
393
+
394
+ provider_model_bundle = self._provider_manager.get_provider_model_bundle(
395
+ tenant_id=tenant_id, provider=provider, model_type=model_type
396
+ )
397
+
398
+ return ModelInstance(provider_model_bundle, model)
399
+
400
+ def get_default_provider_model_name(self, tenant_id: str, model_type: ModelType) -> tuple[str, str]:
401
+ """
402
+ Return first provider and the first model in the provider
403
+ :param tenant_id: tenant id
404
+ :param model_type: model type
405
+ :return: provider name, model name
406
+ """
407
+ return self._provider_manager.get_first_provider_first_model(tenant_id, model_type)
408
+
409
+ def get_default_model_instance(self, tenant_id: str, model_type: ModelType) -> ModelInstance:
410
+ """
411
+ Get default model instance
412
+ :param tenant_id: tenant id
413
+ :param model_type: model type
414
+ :return:
415
+ """
416
+ default_model_entity = self._provider_manager.get_default_model(tenant_id=tenant_id, model_type=model_type)
417
+
418
+ if not default_model_entity:
419
+ raise ProviderTokenNotInitError(f"Default model not found for {model_type}")
420
+
421
+ return self.get_model_instance(
422
+ tenant_id=tenant_id,
423
+ provider=default_model_entity.provider.provider,
424
+ model_type=model_type,
425
+ model=default_model_entity.model,
426
+ )
427
+
428
+
429
+ class LBModelManager:
430
+ def __init__(
431
+ self,
432
+ tenant_id: str,
433
+ provider: str,
434
+ model_type: ModelType,
435
+ model: str,
436
+ load_balancing_configs: list[ModelLoadBalancingConfiguration],
437
+ managed_credentials: Optional[dict] = None,
438
+ ) -> None:
439
+ """
440
+ Load balancing model manager
441
+ :param tenant_id: tenant_id
442
+ :param provider: provider
443
+ :param model_type: model_type
444
+ :param model: model name
445
+ :param load_balancing_configs: all load balancing configurations
446
+ :param managed_credentials: credentials if load balancing configuration name is __inherit__
447
+ """
448
+ self._tenant_id = tenant_id
449
+ self._provider = provider
450
+ self._model_type = model_type
451
+ self._model = model
452
+ self._load_balancing_configs = load_balancing_configs
453
+
454
+ for load_balancing_config in self._load_balancing_configs[:]: # Iterate over a shallow copy of the list
455
+ if load_balancing_config.name == "__inherit__":
456
+ if not managed_credentials:
457
+ # remove __inherit__ if managed credentials is not provided
458
+ self._load_balancing_configs.remove(load_balancing_config)
459
+ else:
460
+ load_balancing_config.credentials = managed_credentials
461
+
462
+ def fetch_next(self) -> Optional[ModelLoadBalancingConfiguration]:
463
+ """
464
+ Get next model load balancing config
465
+ Strategy: Round Robin
466
+ :return:
467
+ """
468
+ cache_key = "model_lb_index:{}:{}:{}:{}".format(
469
+ self._tenant_id, self._provider, self._model_type.value, self._model
470
+ )
471
+
472
+ cooldown_load_balancing_configs = []
473
+ max_index = len(self._load_balancing_configs)
474
+
475
+ while True:
476
+ current_index = redis_client.incr(cache_key)
477
+ current_index = cast(int, current_index)
478
+ if current_index >= 10000000:
479
+ current_index = 1
480
+ redis_client.set(cache_key, current_index)
481
+
482
+ redis_client.expire(cache_key, 3600)
483
+ if current_index > max_index:
484
+ current_index = current_index % max_index
485
+
486
+ real_index = current_index - 1
487
+ if real_index < 0:
488
+ real_index = max_index - 1
489
+
490
+ config: ModelLoadBalancingConfiguration = self._load_balancing_configs[real_index]
491
+
492
+ if self.in_cooldown(config):
493
+ cooldown_load_balancing_configs.append(config)
494
+ if len(cooldown_load_balancing_configs) >= len(self._load_balancing_configs):
495
+ # all configs are in cooldown
496
+ return None
497
+
498
+ continue
499
+
500
+ if dify_config.DEBUG:
501
+ logger.info(
502
+ f"Model LB\nid: {config.id}\nname:{config.name}\n"
503
+ f"tenant_id: {self._tenant_id}\nprovider: {self._provider}\n"
504
+ f"model_type: {self._model_type.value}\nmodel: {self._model}"
505
+ )
506
+
507
+ return config
508
+
509
+ return None
510
+
511
+ def cooldown(self, config: ModelLoadBalancingConfiguration, expire: int = 60) -> None:
512
+ """
513
+ Cooldown model load balancing config
514
+ :param config: model load balancing config
515
+ :param expire: cooldown time
516
+ :return:
517
+ """
518
+ cooldown_cache_key = "model_lb_index:cooldown:{}:{}:{}:{}:{}".format(
519
+ self._tenant_id, self._provider, self._model_type.value, self._model, config.id
520
+ )
521
+
522
+ redis_client.setex(cooldown_cache_key, expire, "true")
523
+
524
+ def in_cooldown(self, config: ModelLoadBalancingConfiguration) -> bool:
525
+ """
526
+ Check if model load balancing config is in cooldown
527
+ :param config: model load balancing config
528
+ :return:
529
+ """
530
+ cooldown_cache_key = "model_lb_index:cooldown:{}:{}:{}:{}:{}".format(
531
+ self._tenant_id, self._provider, self._model_type.value, self._model, config.id
532
+ )
533
+
534
+ res: bool = redis_client.exists(cooldown_cache_key)
535
+ return res
536
+
537
+ @staticmethod
538
+ def get_config_in_cooldown_and_ttl(
539
+ tenant_id: str, provider: str, model_type: ModelType, model: str, config_id: str
540
+ ) -> tuple[bool, int]:
541
+ """
542
+ Check whether the model load balancing config is in cooldown, and return its remaining TTL
543
+ :param tenant_id: workspace id
544
+ :param provider: provider name
545
+ :param model_type: model type
546
+ :param model: model name
547
+ :param config_id: model load balancing config id
548
+ :return:
549
+ """
550
+ cooldown_cache_key = "model_lb_index:cooldown:{}:{}:{}:{}:{}".format(
551
+ tenant_id, provider, model_type.value, model, config_id
552
+ )
553
+
554
+ ttl = redis_client.ttl(cooldown_cache_key)
555
+ if ttl == -2:
556
+ return False, 0
557
+
558
+ ttl = cast(int, ttl)
559
+ return True, ttl
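
Taken together, ModelInstance, ModelManager, and LBModelManager above are easiest to read from the call site. A minimal usage sketch follows; the tenant id, provider, and model names are placeholders, and it assumes the workspace already has credentials configured for that provider:

from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType

manager = ModelManager()

# Resolve a rerank model instance for a workspace (ids/names are hypothetical).
model_instance = manager.get_model_instance(
    tenant_id="tenant-123",
    provider="cohere",
    model_type=ModelType.RERANK,
    model="rerank-english-v2.0",
)

# Every invoke_* call funnels through _round_robin_invoke: with load balancing
# enabled, a rate-limited credential cools down for 60s, an auth/connection
# failure for 10s, and the next configured credential is tried.
result = model_instance.invoke_rerank(
    query="What is Dify?",
    docs=["Dify is an LLM app development platform.", "Unrelated text."],
    top_n=1,
)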
api/core/moderation/__init__.py ADDED
File without changes
api/core/moderation/api/__builtin__ ADDED
@@ -0,0 +1 @@
1
+ 3
api/core/moderation/api/__init__.py ADDED
File without changes
api/core/moderation/api/api.py ADDED
@@ -0,0 +1,96 @@
1
+ from typing import Optional
2
+
3
+ from pydantic import BaseModel
4
+
5
+ from core.extension.api_based_extension_requestor import APIBasedExtensionPoint, APIBasedExtensionRequestor
6
+ from core.helper.encrypter import decrypt_token
7
+ from core.moderation.base import Moderation, ModerationAction, ModerationInputsResult, ModerationOutputsResult
8
+ from extensions.ext_database import db
9
+ from models.api_based_extension import APIBasedExtension
10
+
11
+
12
+ class ModerationInputParams(BaseModel):
13
+ app_id: str = ""
14
+ inputs: dict = {}
15
+ query: str = ""
16
+
17
+
18
+ class ModerationOutputParams(BaseModel):
19
+ app_id: str = ""
20
+ text: str
21
+
22
+
23
+ class ApiModeration(Moderation):
24
+ name: str = "api"
25
+
26
+ @classmethod
27
+ def validate_config(cls, tenant_id: str, config: dict) -> None:
28
+ """
29
+ Validate the incoming form config data.
30
+
31
+ :param tenant_id: the id of workspace
32
+ :param config: the form config data
33
+ :return:
34
+ """
35
+ cls._validate_inputs_and_outputs_config(config, False)
36
+
37
+ api_based_extension_id = config.get("api_based_extension_id")
38
+ if not api_based_extension_id:
39
+ raise ValueError("api_based_extension_id is required")
40
+
41
+ extension = cls._get_api_based_extension(tenant_id, api_based_extension_id)
42
+ if not extension:
43
+ raise ValueError("API-based Extension not found. Please check it again.")
44
+
45
+ def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult:
46
+ flagged = False
47
+ preset_response = ""
48
+ if self.config is None:
49
+ raise ValueError("The config is not set.")
50
+
51
+ if self.config["inputs_config"]["enabled"]:
52
+ params = ModerationInputParams(app_id=self.app_id, inputs=inputs, query=query)
53
+
54
+ result = self._get_config_by_requestor(APIBasedExtensionPoint.APP_MODERATION_INPUT, params.model_dump())
55
+ return ModerationInputsResult(**result)
56
+
57
+ return ModerationInputsResult(
58
+ flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response
59
+ )
60
+
61
+ def moderation_for_outputs(self, text: str) -> ModerationOutputsResult:
62
+ flagged = False
63
+ preset_response = ""
64
+ if self.config is None:
65
+ raise ValueError("The config is not set.")
66
+
67
+ if self.config["outputs_config"]["enabled"]:
68
+ params = ModerationOutputParams(app_id=self.app_id, text=text)
69
+
70
+ result = self._get_config_by_requestor(APIBasedExtensionPoint.APP_MODERATION_OUTPUT, params.model_dump())
71
+ return ModerationOutputsResult(**result)
72
+
73
+ return ModerationOutputsResult(
74
+ flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response
75
+ )
76
+
77
+ def _get_config_by_requestor(self, extension_point: APIBasedExtensionPoint, params: dict) -> dict:
78
+ if self.config is None:
79
+ raise ValueError("The config is not set.")
80
+ extension = self._get_api_based_extension(self.tenant_id, self.config.get("api_based_extension_id", ""))
81
+ if not extension:
82
+ raise ValueError("API-based Extension not found. Please check it again.")
83
+ requestor = APIBasedExtensionRequestor(extension.api_endpoint, decrypt_token(self.tenant_id, extension.api_key))
84
+
85
+ result = requestor.request(extension_point, params)
86
+ return result
87
+
88
+ @staticmethod
89
+ def _get_api_based_extension(tenant_id: str, api_based_extension_id: str) -> Optional[APIBasedExtension]:
90
+ extension = (
91
+ db.session.query(APIBasedExtension)
92
+ .filter(APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id)
93
+ .first()
94
+ )
95
+
96
+ return extension
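
A quick sketch of the contract ApiModeration expects from the configured API-based extension: the payload is the model_dump() of the params model above, and the endpoint must answer with JSON that parses into the corresponding result model (the values below are hypothetical):

from core.moderation.api.api import ModerationInputParams
from core.moderation.base import ModerationInputsResult

# Payload posted to the extension for APIBasedExtensionPoint.APP_MODERATION_INPUT.
params = ModerationInputParams(app_id="app-123", inputs={"name": "Alice"}, query="hello")
payload = params.model_dump()

# Expected response shape; it is fed directly into ModerationInputsResult(**result).
response_json = {"flagged": True, "action": "direct_output", "preset_response": "Blocked."}
result = ModerationInputsResult(**response_json)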
api/core/moderation/base.py ADDED
@@ -0,0 +1,115 @@
1
+ from abc import ABC, abstractmethod
2
+ from enum import Enum
3
+ from typing import Optional
4
+
5
+ from pydantic import BaseModel
6
+
7
+ from core.extension.extensible import Extensible, ExtensionModule
8
+
9
+
10
+ class ModerationAction(Enum):
11
+ DIRECT_OUTPUT = "direct_output"
12
+ OVERRIDDEN = "overridden"
13
+
14
+
15
+ class ModerationInputsResult(BaseModel):
16
+ flagged: bool = False
17
+ action: ModerationAction
18
+ preset_response: str = ""
19
+ inputs: dict = {}
20
+ query: str = ""
21
+
22
+
23
+ class ModerationOutputsResult(BaseModel):
24
+ flagged: bool = False
25
+ action: ModerationAction
26
+ preset_response: str = ""
27
+ text: str = ""
28
+
29
+
30
+ class Moderation(Extensible, ABC):
31
+ """
32
+ The base class of moderation.
33
+ """
34
+
35
+ module: ExtensionModule = ExtensionModule.MODERATION
36
+
37
+ def __init__(self, app_id: str, tenant_id: str, config: Optional[dict] = None) -> None:
38
+ super().__init__(tenant_id, config)
39
+ self.app_id = app_id
40
+
41
+ @classmethod
42
+ @abstractmethod
43
+ def validate_config(cls, tenant_id: str, config: dict) -> None:
44
+ """
45
+ Validate the incoming form config data.
46
+
47
+ :param tenant_id: the id of workspace
48
+ :param config: the form config data
49
+ :return:
50
+ """
51
+ raise NotImplementedError
52
+
53
+ @abstractmethod
54
+ def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult:
55
+ """
56
+ Moderation for inputs.
57
+ After the user inputs, this method will be called to perform sensitive content review
58
+ on the user inputs and return the processed results.
59
+
60
+ :param inputs: user inputs
61
+ :param query: query string (required in chat app)
62
+ :return:
63
+ """
64
+ raise NotImplementedError
65
+
66
+ @abstractmethod
67
+ def moderation_for_outputs(self, text: str) -> ModerationOutputsResult:
68
+ """
69
+ Moderation for outputs.
70
+ When LLM outputs content, the front end will pass the output content (may be segmented)
71
+ to this method for sensitive content review, and the output content will be shielded if the review fails.
72
+
73
+ :param text: LLM output content
74
+ :return:
75
+ """
76
+ raise NotImplementedError
77
+
78
+ @classmethod
79
+ def _validate_inputs_and_outputs_config(cls, config: dict, is_preset_response_required: bool) -> None:
80
+ # inputs_config
81
+ inputs_config = config.get("inputs_config")
82
+ if not isinstance(inputs_config, dict):
83
+ raise ValueError("inputs_config must be a dict")
84
+
85
+ # outputs_config
86
+ outputs_config = config.get("outputs_config")
87
+ if not isinstance(outputs_config, dict):
88
+ raise ValueError("outputs_config must be a dict")
89
+
90
+ inputs_config_enabled = inputs_config.get("enabled")
91
+ outputs_config_enabled = outputs_config.get("enabled")
92
+ if not inputs_config_enabled and not outputs_config_enabled:
93
+ raise ValueError("At least one of inputs_config or outputs_config must be enabled")
94
+
95
+ # preset_response
96
+ if not is_preset_response_required:
97
+ return
98
+
99
+ if inputs_config_enabled:
100
+ if not inputs_config.get("preset_response"):
101
+ raise ValueError("inputs_config.preset_response is required")
102
+
103
+ if len(inputs_config.get("preset_response", "")) > 100:
104
+ raise ValueError("inputs_config.preset_response must not exceed 100 characters")
105
+
106
+ if outputs_config_enabled:
107
+ if not outputs_config.get("preset_response"):
108
+ raise ValueError("outputs_config.preset_response is required")
109
+
110
+ if len(outputs_config.get("preset_response", "")) > 100:
111
+ raise ValueError("outputs_config.preset_response must not exceed 100 characters")
112
+
113
+
114
+ class ModerationError(Exception):
115
+ pass
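
A custom moderation module only has to implement the three abstract members above; a bare-bones sketch (the class and its behaviour are illustrative, not part of this commit):

from core.moderation.base import Moderation, ModerationAction, ModerationInputsResult, ModerationOutputsResult


class NoopModeration(Moderation):
    # Example subclass that never flags anything.
    name: str = "noop"

    @classmethod
    def validate_config(cls, tenant_id: str, config: dict) -> None:
        cls._validate_inputs_and_outputs_config(config, False)

    def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult:
        return ModerationInputsResult(flagged=False, action=ModerationAction.DIRECT_OUTPUT)

    def moderation_for_outputs(self, text: str) -> ModerationOutputsResult:
        return ModerationOutputsResult(flagged=False, action=ModerationAction.DIRECT_OUTPUT)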
api/core/moderation/factory.py ADDED
@@ -0,0 +1,49 @@
1
+ from core.extension.extensible import ExtensionModule
2
+ from core.moderation.base import Moderation, ModerationInputsResult, ModerationOutputsResult
3
+ from extensions.ext_code_based_extension import code_based_extension
4
+
5
+
6
+ class ModerationFactory:
7
+ __extension_instance: Moderation
8
+
9
+ def __init__(self, name: str, app_id: str, tenant_id: str, config: dict) -> None:
10
+ extension_class = code_based_extension.extension_class(ExtensionModule.MODERATION, name)
11
+ self.__extension_instance = extension_class(app_id, tenant_id, config)
12
+
13
+ @classmethod
14
+ def validate_config(cls, name: str, tenant_id: str, config: dict) -> None:
15
+ """
16
+ Validate the incoming form config data.
17
+
18
+ :param name: the name of extension
19
+ :param tenant_id: the id of workspace
20
+ :param config: the form config data
21
+ :return:
22
+ """
23
+ code_based_extension.validate_form_schema(ExtensionModule.MODERATION, name, config)
24
+ extension_class = code_based_extension.extension_class(ExtensionModule.MODERATION, name)
25
+ # FIXME: mypy error, try to fix it instead of using type: ignore
26
+ extension_class.validate_config(tenant_id, config) # type: ignore
27
+
28
+ def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult:
29
+ """
30
+ Moderation for inputs.
31
+ After the user inputs, this method will be called to perform sensitive content review
32
+ on the user inputs and return the processed results.
33
+
34
+ :param inputs: user inputs
35
+ :param query: query string (required in chat app)
36
+ :return:
37
+ """
38
+ return self.__extension_instance.moderation_for_inputs(inputs, query)
39
+
40
+ def moderation_for_outputs(self, text: str) -> ModerationOutputsResult:
41
+ """
42
+ Moderation for outputs.
43
+ When LLM outputs content, the front end will pass the output content (may be segmented)
44
+ to this method for sensitive content review, and the output content will be shielded if the review fails.
45
+
46
+ :param text: LLM output content
47
+ :return:
48
+ """
49
+ return self.__extension_instance.moderation_for_outputs(text)
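
Typical use of the factory, assuming the code-based extension registry has been initialized (as it is in the running API app); the ids and config values are placeholders:

from core.moderation.factory import ModerationFactory

config = {
    "inputs_config": {"enabled": True, "preset_response": "Please rephrase."},
    "outputs_config": {"enabled": False},
    "keywords": "badword",
}
# Validate first; raises ValueError on a malformed config.
ModerationFactory.validate_config(name="keywords", tenant_id="tenant-123", config=config)

factory = ModerationFactory(name="keywords", app_id="app-123", tenant_id="tenant-123", config=config)
result = factory.moderation_for_inputs({"question": "contains badword"})
# result.flagged is True; result.preset_response is "Please rephrase."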
api/core/moderation/input_moderation.py ADDED
@@ -0,0 +1,71 @@
1
+ import logging
2
+ from collections.abc import Mapping
3
+ from typing import Any, Optional
4
+
5
+ from core.app.app_config.entities import AppConfig
6
+ from core.moderation.base import ModerationAction, ModerationError
7
+ from core.moderation.factory import ModerationFactory
8
+ from core.ops.entities.trace_entity import TraceTaskName
9
+ from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
10
+ from core.ops.utils import measure_time
11
+
12
+ logger = logging.getLogger(__name__)
13
+
14
+
15
+ class InputModeration:
16
+ def check(
17
+ self,
18
+ app_id: str,
19
+ tenant_id: str,
20
+ app_config: AppConfig,
21
+ inputs: Mapping[str, Any],
22
+ query: str,
23
+ message_id: str,
24
+ trace_manager: Optional[TraceQueueManager] = None,
25
+ ) -> tuple[bool, Mapping[str, Any], str]:
26
+ """
27
+ Process sensitive_word_avoidance.
28
+ :param app_id: app id
29
+ :param tenant_id: tenant id
30
+ :param app_config: app config
31
+ :param inputs: inputs
32
+ :param query: query
33
+ :param message_id: message id
34
+ :param trace_manager: trace manager
35
+ :return:
36
+ """
37
+ inputs = dict(inputs)
38
+ if not app_config.sensitive_word_avoidance:
39
+ return False, inputs, query
40
+
41
+ sensitive_word_avoidance_config = app_config.sensitive_word_avoidance
42
+ moderation_type = sensitive_word_avoidance_config.type
43
+
44
+ moderation_factory = ModerationFactory(
45
+ name=moderation_type, app_id=app_id, tenant_id=tenant_id, config=sensitive_word_avoidance_config.config
46
+ )
47
+
48
+ with measure_time() as timer:
49
+ moderation_result = moderation_factory.moderation_for_inputs(inputs, query)
50
+
51
+ if trace_manager:
52
+ trace_manager.add_trace_task(
53
+ TraceTask(
54
+ TraceTaskName.MODERATION_TRACE,
55
+ message_id=message_id,
56
+ moderation_result=moderation_result,
57
+ inputs=inputs,
58
+ timer=timer,
59
+ )
60
+ )
61
+
62
+ if not moderation_result.flagged:
63
+ return False, inputs, query
64
+
65
+ if moderation_result.action == ModerationAction.DIRECT_OUTPUT:
66
+ raise ModerationError(moderation_result.preset_response)
67
+ elif moderation_result.action == ModerationAction.OVERRIDDEN:
68
+ inputs = moderation_result.inputs
69
+ query = moderation_result.query
70
+
71
+ return True, inputs, query
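
Callers get back (moderated, inputs, query) and must be ready for ModerationError, which carries the preset response for the DIRECT_OUTPUT action. A sketch of the calling pattern (app_config and the ids come from the app runner and are assumed here):

from core.moderation.base import ModerationError
from core.moderation.input_moderation import InputModeration

moderation = InputModeration()
try:
    moderated, inputs, query = moderation.check(
        app_id="app-123",
        tenant_id="tenant-123",
        app_config=app_config,  # AppConfig provided by the caller
        inputs={"name": "Alice"},
        query="hello",
        message_id="msg-123",
    )
except ModerationError as e:
    # DIRECT_OUTPUT: answer with the preset response instead of invoking the model.
    answer = str(e)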
api/core/moderation/keywords/__builtin__ ADDED
@@ -0,0 +1 @@
1
+ 2
api/core/moderation/keywords/__init__.py ADDED
File without changes
api/core/moderation/keywords/keywords.py ADDED
@@ -0,0 +1,73 @@
1
+ from collections.abc import Sequence
2
+ from typing import Any
3
+
4
+ from core.moderation.base import Moderation, ModerationAction, ModerationInputsResult, ModerationOutputsResult
5
+
6
+
7
+ class KeywordsModeration(Moderation):
8
+ name: str = "keywords"
9
+
10
+ @classmethod
11
+ def validate_config(cls, tenant_id: str, config: dict) -> None:
12
+ """
13
+ Validate the incoming form config data.
14
+
15
+ :param tenant_id: the id of workspace
16
+ :param config: the form config data
17
+ :return:
18
+ """
19
+ cls._validate_inputs_and_outputs_config(config, True)
20
+
21
+ if not config.get("keywords"):
22
+ raise ValueError("keywords is required")
23
+
24
+ if len(config.get("keywords", [])) > 10000:
25
+ raise ValueError("keywords length must be less than 10000")
26
+
27
+ keywords_rows = config["keywords"].split("\n")
28
+ if len(keywords_rows) > 100:
29
+ raise ValueError("the number of rows for the keywords must be less than 100")
30
+
31
+ def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult:
32
+ flagged = False
33
+ preset_response = ""
34
+ if self.config is None:
35
+ raise ValueError("The config is not set.")
36
+
37
+ if self.config["inputs_config"]["enabled"]:
38
+ preset_response = self.config["inputs_config"]["preset_response"]
39
+
40
+ if query:
41
+ inputs["query__"] = query
42
+
43
+ # Filter out empty values
44
+ keywords_list = [keyword for keyword in self.config["keywords"].split("\n") if keyword]
45
+
46
+ flagged = self._is_violated(inputs, keywords_list)
47
+
48
+ return ModerationInputsResult(
49
+ flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response
50
+ )
51
+
52
+ def moderation_for_outputs(self, text: str) -> ModerationOutputsResult:
53
+ flagged = False
54
+ preset_response = ""
55
+ if self.config is None:
56
+ raise ValueError("The config is not set.")
57
+
58
+ if self.config["outputs_config"]["enabled"]:
59
+ # Filter out empty values
60
+ keywords_list = [keyword for keyword in self.config["keywords"].split("\n") if keyword]
61
+
62
+ flagged = self._is_violated({"text": text}, keywords_list)
63
+ preset_response = self.config["outputs_config"]["preset_response"]
64
+
65
+ return ModerationOutputsResult(
66
+ flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response
67
+ )
68
+
69
+ def _is_violated(self, inputs: dict, keywords_list: list) -> bool:
70
+ return any(self._check_keywords_in_value(keywords_list, value) for value in inputs.values())
71
+
72
+ def _check_keywords_in_value(self, keywords_list: Sequence[str], value: Any) -> bool:
73
+ return any(keyword.lower() in str(value).lower() for keyword in keywords_list)
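
Matching is a case-insensitive substring test over every input value, with the query folded in under the key query__. A quick check (values are hypothetical; direct construction assumes the Extensible base accepts the arguments shown in Moderation.__init__ above):

from core.moderation.keywords.keywords import KeywordsModeration

mod = KeywordsModeration(
    app_id="app-123",
    tenant_id="tenant-123",
    config={
        "inputs_config": {"enabled": True, "preset_response": "Not allowed."},
        "outputs_config": {"enabled": False},
        "keywords": "Foo\nBar",
    },
)
result = mod.moderation_for_inputs({"question": "this mentions foo"})
# result.flagged is True: "foo" matches "Foo" case-insensitively.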
api/core/moderation/openai_moderation/__builtin__ ADDED
@@ -0,0 +1 @@
1
+ 1
api/core/moderation/openai_moderation/__init__.py ADDED
File without changes
api/core/moderation/openai_moderation/openai_moderation.py ADDED
@@ -0,0 +1,60 @@
1
+ from core.model_manager import ModelManager
2
+ from core.model_runtime.entities.model_entities import ModelType
3
+ from core.moderation.base import Moderation, ModerationAction, ModerationInputsResult, ModerationOutputsResult
4
+
5
+
6
+ class OpenAIModeration(Moderation):
7
+ name: str = "openai_moderation"
8
+
9
+ @classmethod
10
+ def validate_config(cls, tenant_id: str, config: dict) -> None:
11
+ """
12
+ Validate the incoming form config data.
13
+
14
+ :param tenant_id: the id of workspace
15
+ :param config: the form config data
16
+ :return:
17
+ """
18
+ cls._validate_inputs_and_outputs_config(config, True)
19
+
20
+ def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult:
21
+ flagged = False
22
+ preset_response = ""
23
+ if self.config is None:
24
+ raise ValueError("The config is not set.")
25
+
26
+ if self.config["inputs_config"]["enabled"]:
27
+ preset_response = self.config["inputs_config"]["preset_response"]
28
+
29
+ if query:
30
+ inputs["query__"] = query
31
+ flagged = self._is_violated(inputs)
32
+
33
+ return ModerationInputsResult(
34
+ flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response
35
+ )
36
+
37
+ def moderation_for_outputs(self, text: str) -> ModerationOutputsResult:
38
+ flagged = False
39
+ preset_response = ""
40
+ if self.config is None:
41
+ raise ValueError("The config is not set.")
42
+
43
+ if self.config["outputs_config"]["enabled"]:
44
+ flagged = self._is_violated({"text": text})
45
+ preset_response = self.config["outputs_config"]["preset_response"]
46
+
47
+ return ModerationOutputsResult(
48
+ flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response
49
+ )
50
+
51
+ def _is_violated(self, inputs: dict) -> bool:
52
+ text = "\n".join(str(value) for value in inputs.values())
53
+ model_manager = ModelManager()
54
+ model_instance = model_manager.get_model_instance(
55
+ tenant_id=self.tenant_id, provider="openai", model_type=ModelType.MODERATION, model="text-moderation-stable"
56
+ )
57
+
58
+ openai_moderation = model_instance.invoke_moderation(text=text)
59
+
60
+ return openai_moderation
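
With the join fixed above, _is_violated concatenates every input value and sends the text to the workspace's OpenAI moderation model; conceptually (values are hypothetical):

inputs = {"query__": "some user text", "name": "Alice"}
text = "\n".join(str(value) for value in inputs.values())
# -> "some user text\nAlice", passed to invoke_moderation(text=text),
#    which returns True if flagged and False if the text is safe.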
api/core/moderation/output_moderation.py ADDED
@@ -0,0 +1,131 @@
1
+ import logging
2
+ import threading
3
+ import time
4
+ from typing import Any, Optional
5
+
6
+ from flask import Flask, current_app
7
+ from pydantic import BaseModel, ConfigDict
8
+
9
+ from configs import dify_config
10
+ from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
11
+ from core.app.entities.queue_entities import QueueMessageReplaceEvent
12
+ from core.moderation.base import ModerationAction, ModerationOutputsResult
13
+ from core.moderation.factory import ModerationFactory
14
+
15
+ logger = logging.getLogger(__name__)
16
+
17
+
18
+ class ModerationRule(BaseModel):
19
+ type: str
20
+ config: dict[str, Any]
21
+
22
+
23
+ class OutputModeration(BaseModel):
24
+ tenant_id: str
25
+ app_id: str
26
+
27
+ rule: ModerationRule
28
+ queue_manager: AppQueueManager
29
+
30
+ thread: Optional[threading.Thread] = None
31
+ thread_running: bool = True
32
+ buffer: str = ""
33
+ is_final_chunk: bool = False
34
+ final_output: Optional[str] = None
35
+ model_config = ConfigDict(arbitrary_types_allowed=True)
36
+
37
+ def should_direct_output(self) -> bool:
38
+ return self.final_output is not None
39
+
40
+ def get_final_output(self) -> str:
41
+ return self.final_output or ""
42
+
43
+ def append_new_token(self, token: str) -> None:
44
+ self.buffer += token
45
+
46
+ if not self.thread:
47
+ self.thread = self.start_thread()
48
+
49
+ def moderation_completion(self, completion: str, public_event: bool = False) -> str:
50
+ self.buffer = completion
51
+ self.is_final_chunk = True
52
+
53
+ result = self.moderation(tenant_id=self.tenant_id, app_id=self.app_id, moderation_buffer=completion)
54
+
55
+ if not result or not result.flagged:
56
+ return completion
57
+
58
+ if result.action == ModerationAction.DIRECT_OUTPUT:
59
+ final_output = result.preset_response
60
+ else:
61
+ final_output = result.text
62
+
63
+ if public_event:
64
+ self.queue_manager.publish(QueueMessageReplaceEvent(text=final_output), PublishFrom.TASK_PIPELINE)
65
+
66
+ return final_output
67
+
68
+ def start_thread(self) -> threading.Thread:
69
+ buffer_size = dify_config.MODERATION_BUFFER_SIZE
70
+ thread = threading.Thread(
71
+ target=self.worker,
72
+ kwargs={
73
+ "flask_app": current_app._get_current_object(), # type: ignore
74
+ "buffer_size": buffer_size if buffer_size > 0 else dify_config.MODERATION_BUFFER_SIZE,
75
+ },
76
+ )
77
+
78
+ thread.start()
79
+
80
+ return thread
81
+
82
+ def stop_thread(self):
83
+ if self.thread and self.thread.is_alive():
84
+ self.thread_running = False
85
+
86
+ def worker(self, flask_app: Flask, buffer_size: int):
87
+ with flask_app.app_context():
88
+ current_length = 0
89
+ while self.thread_running:
90
+ moderation_buffer = self.buffer
91
+ buffer_length = len(moderation_buffer)
92
+ if not self.is_final_chunk:
93
+ chunk_length = buffer_length - current_length
94
+ if 0 <= chunk_length < buffer_size:
95
+ time.sleep(1)
96
+ continue
97
+
98
+ current_length = buffer_length
99
+
100
+ result = self.moderation(
101
+ tenant_id=self.tenant_id, app_id=self.app_id, moderation_buffer=moderation_buffer
102
+ )
103
+
104
+ if not result or not result.flagged:
105
+ continue
106
+
107
+ if result.action == ModerationAction.DIRECT_OUTPUT:
108
+ final_output = result.preset_response
109
+ self.final_output = final_output
110
+ else:
111
+ final_output = result.text + self.buffer[len(moderation_buffer) :]
112
+
113
+ # trigger replace event
114
+ if self.thread_running:
115
+ self.queue_manager.publish(QueueMessageReplaceEvent(text=final_output), PublishFrom.TASK_PIPELINE)
116
+
117
+ if result.action == ModerationAction.DIRECT_OUTPUT:
118
+ break
119
+
120
+ def moderation(self, tenant_id: str, app_id: str, moderation_buffer: str) -> Optional[ModerationOutputsResult]:
121
+ try:
122
+ moderation_factory = ModerationFactory(
123
+ name=self.rule.type, app_id=app_id, tenant_id=tenant_id, config=self.rule.config
124
+ )
125
+
126
+ result: ModerationOutputsResult = moderation_factory.moderation_for_outputs(moderation_buffer)
127
+ return result
128
+ except Exception as e:
129
+ logger.exception(f"Moderation Output error, app_id: {app_id}")
130
+
131
+ return None
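
In streaming mode the caller appends tokens and the background worker re-checks the buffer roughly every MODERATION_BUFFER_SIZE characters; the final chunk is checked synchronously. A rough sketch of the call pattern, assuming output_moderation is an OutputModeration built by the app runner:

# Streaming: feed tokens; the worker thread periodically moderates the buffer.
output_moderation.append_new_token(chunk_text)
if output_moderation.should_direct_output():
    answer = output_moderation.get_final_output()  # preset response; stop streaming

# End of stream: one final synchronous pass over the whole completion.
final_answer = output_moderation.moderation_completion(completion=answer, public_event=True)
output_moderation.stop_thread()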
api/core/ops/__init__.py ADDED
File without changes
api/core/ops/base_trace_instance.py ADDED
@@ -0,0 +1,26 @@
1
+ from abc import ABC, abstractmethod
2
+
3
+ from core.ops.entities.config_entity import BaseTracingConfig
4
+ from core.ops.entities.trace_entity import BaseTraceInfo
5
+
6
+
7
+ class BaseTraceInstance(ABC):
8
+ """
9
+ Base trace instance for ops trace services
10
+ """
11
+
12
+ @abstractmethod
13
+ def __init__(self, trace_config: BaseTracingConfig):
14
+ """
15
+ Abstract initializer for the trace instance.
16
+ Subclasses set up their tracing client from the given config here.
17
+ """
18
+ self.trace_config = trace_config
19
+
20
+ @abstractmethod
21
+ def trace(self, trace_info: BaseTraceInfo):
22
+ """
23
+ Abstract method to trace activities.
24
+ Subclasses must implement specific tracing logic for activities.
25
+ """
26
+ ...
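
Concrete tracers (the Langfuse, LangSmith, and Opik instances below) subclass this pair of methods; a minimal illustrative subclass, not part of this commit:

from core.ops.base_trace_instance import BaseTraceInstance
from core.ops.entities.config_entity import BaseTracingConfig
from core.ops.entities.trace_entity import BaseTraceInfo


class ConsoleTrace(BaseTraceInstance):
    # Toy tracer that just prints the trace info type and metadata.
    def __init__(self, trace_config: BaseTracingConfig):
        super().__init__(trace_config)

    def trace(self, trace_info: BaseTraceInfo):
        print(type(trace_info).__name__, trace_info.metadata)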
api/core/ops/entities/__init__.py ADDED
File without changes
api/core/ops/entities/config_entity.py ADDED
@@ -0,0 +1,92 @@
1
+ from enum import Enum
2
+
3
+ from pydantic import BaseModel, ValidationInfo, field_validator
4
+
5
+
6
+ class TracingProviderEnum(Enum):
7
+ LANGFUSE = "langfuse"
8
+ LANGSMITH = "langsmith"
9
+ OPIK = "opik"
10
+
11
+
12
+ class BaseTracingConfig(BaseModel):
13
+ """
14
+ Base model class for tracing
15
+ """
16
+
17
+ ...
18
+
19
+
20
+ class LangfuseConfig(BaseTracingConfig):
21
+ """
22
+ Model class for Langfuse tracing config.
23
+ """
24
+
25
+ public_key: str
26
+ secret_key: str
27
+ host: str = "https://api.langfuse.com"
28
+
29
+ @field_validator("host")
30
+ @classmethod
31
+ def set_value(cls, v, info: ValidationInfo):
32
+ if v is None or v == "":
33
+ v = "https://api.langfuse.com"
34
+ if not v.startswith("https://") and not v.startswith("http://"):
35
+ raise ValueError("host must start with https:// or http://")
36
+
37
+ return v
38
+
39
+
40
+ class LangSmithConfig(BaseTracingConfig):
41
+ """
42
+ Model class for Langsmith tracing config.
43
+ """
44
+
45
+ api_key: str
46
+ project: str
47
+ endpoint: str = "https://api.smith.langchain.com"
48
+
49
+ @field_validator("endpoint")
50
+ @classmethod
51
+ def set_value(cls, v, info: ValidationInfo):
52
+ if v is None or v == "":
53
+ v = "https://api.smith.langchain.com"
54
+ if not v.startswith("https://"):
55
+ raise ValueError("endpoint must start with https://")
56
+
57
+ return v
58
+
59
+
60
+ class OpikConfig(BaseTracingConfig):
61
+ """
62
+ Model class for Opik tracing config.
63
+ """
64
+
65
+ api_key: str | None = None
66
+ project: str | None = None
67
+ workspace: str | None = None
68
+ url: str = "https://www.comet.com/opik/api/"
69
+
70
+ @field_validator("project")
71
+ @classmethod
72
+ def project_validator(cls, v, info: ValidationInfo):
73
+ if v is None or v == "":
74
+ v = "Default Project"
75
+
76
+ return v
77
+
78
+ @field_validator("url")
79
+ @classmethod
80
+ def url_validator(cls, v, info: ValidationInfo):
81
+ if v is None or v == "":
82
+ v = "https://www.comet.com/opik/api/"
83
+ if not v.startswith(("https://", "http://")):
84
+ raise ValueError("url must start with https:// or http://")
85
+ if not v.endswith("/api/"):
86
+ raise ValueError("url should ends with /api/")
87
+
88
+ return v
89
+
90
+
91
+ OPS_FILE_PATH = "ops_trace/"
92
+ OPS_TRACE_FAILED_KEY = "FAILED_OPS_TRACE"
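
The field validators above normalize empty values to defaults and reject malformed endpoints; a quick demonstration:

from core.ops.entities.config_entity import LangfuseConfig, OpikConfig

lf = LangfuseConfig(public_key="pk", secret_key="sk", host="")
assert lf.host == "https://api.langfuse.com"  # empty host falls back to the default

opik = OpikConfig(api_key="key", project="")
assert opik.project == "Default Project"
# OpikConfig(url="https://example.com") raises a validation error:
# the url must end with /api/.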
api/core/ops/entities/trace_entity.py ADDED
@@ -0,0 +1,134 @@
1
+ from collections.abc import Mapping
2
+ from datetime import datetime
3
+ from enum import StrEnum
4
+ from typing import Any, Optional, Union
5
+
6
+ from pydantic import BaseModel, ConfigDict, field_validator
7
+
8
+
9
+ class BaseTraceInfo(BaseModel):
10
+ message_id: Optional[str] = None
11
+ message_data: Optional[Any] = None
12
+ inputs: Optional[Union[str, dict[str, Any], list]] = None
13
+ outputs: Optional[Union[str, dict[str, Any], list]] = None
14
+ start_time: Optional[datetime] = None
15
+ end_time: Optional[datetime] = None
16
+ metadata: dict[str, Any]
17
+
18
+ @field_validator("inputs", "outputs")
19
+ @classmethod
20
+ def ensure_type(cls, v):
21
+ if v is None:
22
+ return None
23
+ if isinstance(v, str | dict | list):
24
+ return v
25
+ return ""
26
+
27
+ class Config:
28
+ json_encoders = {
29
+ datetime: lambda v: v.isoformat(),
30
+ }
31
+
32
+
33
+ class WorkflowTraceInfo(BaseTraceInfo):
34
+ workflow_data: Any
35
+ conversation_id: Optional[str] = None
36
+ workflow_app_log_id: Optional[str] = None
37
+ workflow_id: str
38
+ tenant_id: str
39
+ workflow_run_id: str
40
+ workflow_run_elapsed_time: Union[int, float]
41
+ workflow_run_status: str
42
+ workflow_run_inputs: Mapping[str, Any]
43
+ workflow_run_outputs: Mapping[str, Any]
44
+ workflow_run_version: str
45
+ error: Optional[str] = None
46
+ total_tokens: int
47
+ file_list: list[str]
48
+ query: str
49
+ metadata: dict[str, Any]
50
+
51
+
52
+ class MessageTraceInfo(BaseTraceInfo):
53
+ conversation_model: str
54
+ message_tokens: int
55
+ answer_tokens: int
56
+ total_tokens: int
57
+ error: Optional[str] = None
58
+ file_list: Optional[Union[str, dict[str, Any], list]] = None
59
+ message_file_data: Optional[Any] = None
60
+ conversation_mode: str
61
+
62
+
63
+ class ModerationTraceInfo(BaseTraceInfo):
64
+ flagged: bool
65
+ action: str
66
+ preset_response: str
67
+ query: str
68
+
69
+
70
+ class SuggestedQuestionTraceInfo(BaseTraceInfo):
71
+ total_tokens: int
72
+ status: Optional[str] = None
73
+ error: Optional[str] = None
74
+ from_account_id: Optional[str] = None
75
+ agent_based: Optional[bool] = None
76
+ from_source: Optional[str] = None
77
+ model_provider: Optional[str] = None
78
+ model_id: Optional[str] = None
79
+ suggested_question: list[str]
80
+ level: str
81
+ status_message: Optional[str] = None
82
+ workflow_run_id: Optional[str] = None
83
+
84
+ model_config = ConfigDict(protected_namespaces=())
85
+
86
+
87
+ class DatasetRetrievalTraceInfo(BaseTraceInfo):
88
+ documents: Any
89
+
90
+
91
+ class ToolTraceInfo(BaseTraceInfo):
92
+ tool_name: str
93
+ tool_inputs: dict[str, Any]
94
+ tool_outputs: str
95
+ metadata: dict[str, Any]
96
+ message_file_data: Any
97
+ error: Optional[str] = None
98
+ tool_config: dict[str, Any]
99
+ time_cost: Union[int, float]
100
+ tool_parameters: dict[str, Any]
101
+ file_url: Union[str, None, list]
102
+
103
+
104
+ class GenerateNameTraceInfo(BaseTraceInfo):
105
+ conversation_id: Optional[str] = None
106
+ tenant_id: str
107
+
108
+
109
+ class TaskData(BaseModel):
110
+ app_id: str
111
+ trace_info_type: str
112
+ trace_info: Any
113
+
114
+
115
+ trace_info_info_map = {
116
+ "WorkflowTraceInfo": WorkflowTraceInfo,
117
+ "MessageTraceInfo": MessageTraceInfo,
118
+ "ModerationTraceInfo": ModerationTraceInfo,
119
+ "SuggestedQuestionTraceInfo": SuggestedQuestionTraceInfo,
120
+ "DatasetRetrievalTraceInfo": DatasetRetrievalTraceInfo,
121
+ "ToolTraceInfo": ToolTraceInfo,
122
+ "GenerateNameTraceInfo": GenerateNameTraceInfo,
123
+ }
124
+
125
+
126
+ class TraceTaskName(StrEnum):
127
+ CONVERSATION_TRACE = "conversation"
128
+ WORKFLOW_TRACE = "workflow"
129
+ MESSAGE_TRACE = "message"
130
+ MODERATION_TRACE = "moderation"
131
+ SUGGESTED_QUESTION_TRACE = "suggested_question"
132
+ DATASET_RETRIEVAL_TRACE = "dataset_retrieval"
133
+ TOOL_TRACE = "tool"
134
+ GENERATE_NAME_TRACE = "generate_conversation_name"
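
trace_info_info_map exists so that serialized trace tasks can be rehydrated by type name; a sketch with a hypothetical payload:

from core.ops.entities.trace_entity import TaskData, trace_info_info_map

task = TaskData(
    app_id="app-123",
    trace_info_type="GenerateNameTraceInfo",
    trace_info={"tenant_id": "tenant-123", "metadata": {}},
)
cls = trace_info_info_map[task.trace_info_type]
trace_info = cls(**task.trace_info)  # -> GenerateNameTraceInfo instance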
api/core/ops/langfuse_trace/__init__.py ADDED
File without changes
api/core/ops/langfuse_trace/entities/__init__.py ADDED
File without changes
api/core/ops/langfuse_trace/entities/langfuse_trace_entity.py ADDED
@@ -0,0 +1,282 @@
1
+ from datetime import datetime
2
+ from enum import StrEnum
3
+ from typing import Any, Optional, Union
4
+
5
+ from pydantic import BaseModel, ConfigDict, Field, field_validator
6
+ from pydantic_core.core_schema import ValidationInfo
7
+
8
+ from core.ops.utils import replace_text_with_content
9
+
10
+
11
+ def validate_input_output(v, field_name):
12
+ """
13
+ Validate input output
14
+ :param v:
15
+ :param field_name:
16
+ :return:
17
+ """
18
+ if v == {} or v is None:
19
+ return v
20
+ if isinstance(v, str):
21
+ return [
22
+ {
23
+ "role": "assistant" if field_name == "output" else "user",
24
+ "content": v,
25
+ }
26
+ ]
27
+ elif isinstance(v, list):
28
+ if len(v) > 0 and isinstance(v[0], dict):
29
+ v = replace_text_with_content(data=v)
30
+ return v
31
+ else:
32
+ return [
33
+ {
34
+ "role": "assistant" if field_name == "output" else "user",
35
+ "content": str(v),
36
+ }
37
+ ]
38
+
39
+ return v
40
+
41
+
42
+ class LevelEnum(StrEnum):
43
+ DEBUG = "DEBUG"
44
+ WARNING = "WARNING"
45
+ ERROR = "ERROR"
46
+ DEFAULT = "DEFAULT"
47
+
48
+
49
+ class LangfuseTrace(BaseModel):
50
+ """
51
+ Langfuse trace model
52
+ """
53
+
54
+ id: Optional[str] = Field(
55
+ default=None,
56
+ description="The id of the trace can be set, defaults to a random id. Used to link traces to external systems "
57
+ "or when creating a distributed trace. Traces are upserted on id.",
58
+ )
59
+ name: Optional[str] = Field(
60
+ default=None,
61
+ description="Identifier of the trace. Useful for sorting/filtering in the UI.",
62
+ )
63
+ input: Optional[Union[str, dict[str, Any], list, None]] = Field(
64
+ default=None, description="The input of the trace. Can be any JSON object."
65
+ )
66
+ output: Optional[Union[str, dict[str, Any], list, None]] = Field(
67
+ default=None, description="The output of the trace. Can be any JSON object."
68
+ )
69
+ metadata: Optional[dict[str, Any]] = Field(
70
+ default=None,
71
+ description="Additional metadata of the trace. Can be any JSON object. Metadata is merged when being updated "
72
+ "via the API.",
73
+ )
74
+ user_id: Optional[str] = Field(
75
+ default=None,
76
+ description="The id of the user that triggered the execution. Used to provide user-level analytics.",
77
+ )
78
+ session_id: Optional[str] = Field(
79
+ default=None,
80
+ description="Used to group multiple traces into a session in Langfuse. Use your own session/thread identifier.",
81
+ )
82
+ version: Optional[str] = Field(
83
+ default=None,
84
+ description="The version of the trace type. Used to understand how changes to the trace type affect metrics. "
85
+ "Useful in debugging.",
86
+ )
87
+ release: Optional[str] = Field(
88
+ default=None,
89
+ description="The release identifier of the current deployment. Used to understand how changes of different "
90
+ "deployments affect metrics. Useful in debugging.",
91
+ )
92
+ tags: Optional[list[str]] = Field(
93
+ default=None,
94
+ description="Tags are used to categorize or label traces. Traces can be filtered by tags in the UI and GET "
95
+ "API. Tags can also be changed in the UI. Tags are merged and never deleted via the API.",
96
+ )
97
+ public: Optional[bool] = Field(
98
+ default=None,
99
+ description="You can make a trace public to share it via a public link. This allows others to view the trace "
100
+ "without needing to log in or be members of your Langfuse project.",
101
+ )
102
+
103
+ @field_validator("input", "output")
104
+ @classmethod
105
+ def ensure_dict(cls, v, info: ValidationInfo):
106
+ field_name = info.field_name
107
+ return validate_input_output(v, field_name)
108
+
109
+
110
+ class LangfuseSpan(BaseModel):
111
+ """
112
+ Langfuse span model
113
+ """
114
+
115
+ id: Optional[str] = Field(
116
+ default=None,
117
+ description="The id of the span can be set, otherwise a random id is generated. Spans are upserted on id.",
118
+ )
119
+ session_id: Optional[str] = Field(
120
+ default=None,
121
+ description="Used to group multiple spans into a session in Langfuse. Use your own session/thread identifier.",
122
+ )
123
+ trace_id: Optional[str] = Field(
124
+ default=None,
125
+ description="The id of the trace the span belongs to. Used to link spans to traces.",
126
+ )
127
+ user_id: Optional[str] = Field(
128
+ default=None,
129
+ description="The id of the user that triggered the execution. Used to provide user-level analytics.",
130
+ )
131
+ start_time: Optional[datetime | str] = Field(
132
+ default_factory=datetime.now,
133
+ description="The time at which the span started, defaults to the current time.",
134
+ )
135
+ end_time: Optional[datetime | str] = Field(
136
+ default=None,
137
+ description="The time at which the span ended. Automatically set by span.end().",
138
+ )
139
+ name: Optional[str] = Field(
140
+ default=None,
141
+ description="Identifier of the span. Useful for sorting/filtering in the UI.",
142
+ )
143
+ metadata: Optional[dict[str, Any]] = Field(
144
+ default=None,
145
+ description="Additional metadata of the span. Can be any JSON object. Metadata is merged when being updated "
146
+ "via the API.",
147
+ )
148
+ level: Optional[str] = Field(
149
+ default=None,
150
+ description="The level of the span. Can be DEBUG, DEFAULT, WARNING or ERROR. Used for sorting/filtering of "
151
+ "traces with elevated error levels and for highlighting in the UI.",
152
+ )
153
+ status_message: Optional[str] = Field(
154
+ default=None,
155
+ description="The status message of the span. Additional field for context of the event. E.g. the error "
156
+ "message of an error event.",
157
+ )
158
+ input: Optional[Union[str, dict[str, Any], list, None]] = Field(
159
+ default=None, description="The input of the span. Can be any JSON object."
160
+ )
161
+ output: Optional[Union[str, dict[str, Any], list, None]] = Field(
162
+ default=None, description="The output of the span. Can be any JSON object."
163
+ )
164
+ version: Optional[str] = Field(
165
+ default=None,
166
+ description="The version of the span type. Used to understand how changes to the span type affect metrics. "
167
+ "Useful in debugging.",
168
+ )
169
+ parent_observation_id: Optional[str] = Field(
170
+ default=None,
171
+ description="The id of the observation the span belongs to. Used to link spans to observations.",
172
+ )
173
+
174
+ @field_validator("input", "output")
175
+ @classmethod
176
+ def ensure_dict(cls, v, info: ValidationInfo):
177
+ field_name = info.field_name
178
+ return validate_input_output(v, field_name)
179
+
180
+
181
+ class UnitEnum(StrEnum):
182
+ CHARACTERS = "CHARACTERS"
183
+ TOKENS = "TOKENS"
184
+ SECONDS = "SECONDS"
185
+ MILLISECONDS = "MILLISECONDS"
186
+ IMAGES = "IMAGES"
187
+
188
+
189
+ class GenerationUsage(BaseModel):
190
+ promptTokens: Optional[int] = None
191
+ completionTokens: Optional[int] = None
192
+ total: Optional[int] = None
193
+ input: Optional[int] = None
194
+ output: Optional[int] = None
195
+ unit: Optional[UnitEnum] = None
196
+ inputCost: Optional[float] = None
197
+ outputCost: Optional[float] = None
198
+ totalCost: Optional[float] = None
199
+
200
+ @field_validator("input", "output")
201
+ @classmethod
202
+ def ensure_dict(cls, v, info: ValidationInfo):
203
+ field_name = info.field_name
204
+ return validate_input_output(v, field_name)
205
+
206
+
207
+ class LangfuseGeneration(BaseModel):
208
+ id: Optional[str] = Field(
209
+ default=None,
210
+ description="The id of the generation can be set, defaults to random id.",
211
+ )
212
+ trace_id: Optional[str] = Field(
213
+ default=None,
214
+ description="The id of the trace the generation belongs to. Used to link generations to traces.",
215
+ )
216
+ parent_observation_id: Optional[str] = Field(
217
+ default=None,
218
+ description="The id of the observation the generation belongs to. Used to link generations to observations.",
219
+ )
220
+ name: Optional[str] = Field(
221
+ default=None,
222
+ description="Identifier of the generation. Useful for sorting/filtering in the UI.",
223
+ )
224
+ start_time: Optional[datetime | str] = Field(
225
+ default_factory=datetime.now,
226
+ description="The time at which the generation started, defaults to the current time.",
227
+ )
228
+ completion_start_time: Optional[datetime | str] = Field(
229
+ default=None,
230
+ description="The time at which the completion started (streaming). Set it to get latency analytics broken "
231
+ "down into time until completion started and completion duration.",
232
+ )
233
+ end_time: Optional[datetime | str] = Field(
234
+ default=None,
235
+ description="The time at which the generation ended. Automatically set by generation.end().",
236
+ )
237
+ model: Optional[str] = Field(default=None, description="The name of the model used for the generation.")
238
+ model_parameters: Optional[dict[str, Any]] = Field(
239
+ default=None,
240
+ description="The parameters of the model used for the generation; can be any key-value pairs.",
241
+ )
242
+ input: Optional[Any] = Field(
243
+ default=None,
244
+ description="The prompt used for the generation. Can be any string or JSON object.",
245
+ )
246
+ output: Optional[Any] = Field(
247
+ default=None,
248
+ description="The completion generated by the model. Can be any string or JSON object.",
249
+ )
250
+ usage: Optional[GenerationUsage] = Field(
251
+ default=None,
252
+ description="The usage object supports the OpenAi structure with tokens and a more generic version with "
253
+ "detailed costs and units.",
254
+ )
255
+ metadata: Optional[dict[str, Any]] = Field(
256
+ default=None,
257
+ description="Additional metadata of the generation. Can be any JSON object. Metadata is merged when being "
258
+ "updated via the API.",
259
+ )
260
+ level: Optional[LevelEnum] = Field(
261
+ default=None,
262
+ description="The level of the generation. Can be DEBUG, DEFAULT, WARNING or ERROR. Used for sorting/filtering "
263
+ "of traces with elevated error levels and for highlighting in the UI.",
264
+ )
265
+ status_message: Optional[str] = Field(
266
+ default=None,
267
+ description="The status message of the generation. Additional field for context of the event. E.g. the error "
268
+ "message of an error event.",
269
+ )
270
+ version: Optional[str] = Field(
271
+ default=None,
272
+ description="The version of the generation type. Used to understand how changes to the span type affect "
273
+ "metrics. Useful in debugging.",
274
+ )
275
+
276
+ model_config = ConfigDict(protected_namespaces=())
277
+
278
+ @field_validator("input", "output")
279
+ @classmethod
280
+ def ensure_dict(cls, v, info: ValidationInfo):
281
+ field_name = info.field_name
282
+ return validate_input_output(v, field_name)
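
The shared validate_input_output helper normalizes plain strings into chat-style message lists, with the role inferred from the field name; for example:

from core.ops.langfuse_trace.entities.langfuse_trace_entity import LangfuseTrace

trace = LangfuseTrace(input="hello", output="hi there")
assert trace.input == [{"role": "user", "content": "hello"}]
assert trace.output == [{"role": "assistant", "content": "hi there"}]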
api/core/ops/langfuse_trace/langfuse_trace.py ADDED
@@ -0,0 +1,455 @@
1
+ import json
2
+ import logging
3
+ import os
4
+ from datetime import datetime, timedelta
5
+ from typing import Optional
6
+
7
+ from langfuse import Langfuse # type: ignore
8
+
9
+ from core.ops.base_trace_instance import BaseTraceInstance
10
+ from core.ops.entities.config_entity import LangfuseConfig
11
+ from core.ops.entities.trace_entity import (
12
+ BaseTraceInfo,
13
+ DatasetRetrievalTraceInfo,
14
+ GenerateNameTraceInfo,
15
+ MessageTraceInfo,
16
+ ModerationTraceInfo,
17
+ SuggestedQuestionTraceInfo,
18
+ ToolTraceInfo,
19
+ TraceTaskName,
20
+ WorkflowTraceInfo,
21
+ )
22
+ from core.ops.langfuse_trace.entities.langfuse_trace_entity import (
23
+ GenerationUsage,
24
+ LangfuseGeneration,
25
+ LangfuseSpan,
26
+ LangfuseTrace,
27
+ LevelEnum,
28
+ UnitEnum,
29
+ )
30
+ from core.ops.utils import filter_none_values
31
+ from extensions.ext_database import db
32
+ from models.model import EndUser
33
+ from models.workflow import WorkflowNodeExecution
34
+
35
+ logger = logging.getLogger(__name__)
36
+
37
+
38
+ class LangFuseDataTrace(BaseTraceInstance):
39
+ def __init__(
40
+ self,
41
+ langfuse_config: LangfuseConfig,
42
+ ):
43
+ super().__init__(langfuse_config)
44
+ self.langfuse_client = Langfuse(
45
+ public_key=langfuse_config.public_key,
46
+ secret_key=langfuse_config.secret_key,
47
+ host=langfuse_config.host,
48
+ )
49
+ self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001")
50
+
51
+ def trace(self, trace_info: BaseTraceInfo):
52
+ if isinstance(trace_info, WorkflowTraceInfo):
53
+ self.workflow_trace(trace_info)
54
+ if isinstance(trace_info, MessageTraceInfo):
55
+ self.message_trace(trace_info)
56
+ if isinstance(trace_info, ModerationTraceInfo):
57
+ self.moderation_trace(trace_info)
58
+ if isinstance(trace_info, SuggestedQuestionTraceInfo):
59
+ self.suggested_question_trace(trace_info)
60
+ if isinstance(trace_info, DatasetRetrievalTraceInfo):
61
+ self.dataset_retrieval_trace(trace_info)
62
+ if isinstance(trace_info, ToolTraceInfo):
63
+ self.tool_trace(trace_info)
64
+ if isinstance(trace_info, GenerateNameTraceInfo):
65
+ self.generate_name_trace(trace_info)
66
+
67
+ def workflow_trace(self, trace_info: WorkflowTraceInfo):
68
+ trace_id = trace_info.workflow_run_id
69
+ user_id = trace_info.metadata.get("user_id")
70
+ metadata = trace_info.metadata
71
+ metadata["workflow_app_log_id"] = trace_info.workflow_app_log_id
72
+
73
+ if trace_info.message_id:
74
+ trace_id = trace_info.message_id
75
+ name = TraceTaskName.MESSAGE_TRACE.value
76
+ trace_data = LangfuseTrace(
77
+ id=trace_id,
78
+ user_id=user_id,
79
+ name=name,
80
+ input=dict(trace_info.workflow_run_inputs),
81
+ output=dict(trace_info.workflow_run_outputs),
82
+ metadata=metadata,
83
+ session_id=trace_info.conversation_id,
84
+ tags=["message", "workflow"],
85
+ )
86
+ self.add_trace(langfuse_trace_data=trace_data)
87
+ workflow_span_data = LangfuseSpan(
88
+ id=trace_info.workflow_run_id,
89
+ name=TraceTaskName.WORKFLOW_TRACE.value,
90
+ input=dict(trace_info.workflow_run_inputs),
91
+ output=dict(trace_info.workflow_run_outputs),
92
+ trace_id=trace_id,
93
+ start_time=trace_info.start_time,
94
+ end_time=trace_info.end_time,
95
+ metadata=metadata,
96
+ level=LevelEnum.DEFAULT if not trace_info.error else LevelEnum.ERROR,
97
+ status_message=trace_info.error or "",
98
+ )
99
+ self.add_span(langfuse_span_data=workflow_span_data)
100
+ else:
101
+ trace_data = LangfuseTrace(
102
+ id=trace_id,
103
+ user_id=user_id,
104
+ name=TraceTaskName.WORKFLOW_TRACE.value,
105
+ input=dict(trace_info.workflow_run_inputs),
106
+ output=dict(trace_info.workflow_run_outputs),
107
+ metadata=metadata,
108
+ session_id=trace_info.conversation_id,
109
+ tags=["workflow"],
110
+ )
111
+ self.add_trace(langfuse_trace_data=trace_data)
112
+
113
+ # through workflow_run_id get all_nodes_execution
114
+ workflow_nodes_execution_id_records = (
115
+ db.session.query(WorkflowNodeExecution.id)
116
+ .filter(WorkflowNodeExecution.workflow_run_id == trace_info.workflow_run_id)
117
+ .all()
118
+ )
119
+
120
+ for node_execution_id_record in workflow_nodes_execution_id_records:
121
+ node_execution = (
122
+ db.session.query(
123
+ WorkflowNodeExecution.id,
124
+ WorkflowNodeExecution.tenant_id,
125
+ WorkflowNodeExecution.app_id,
126
+ WorkflowNodeExecution.title,
127
+ WorkflowNodeExecution.node_type,
128
+ WorkflowNodeExecution.status,
129
+ WorkflowNodeExecution.inputs,
130
+ WorkflowNodeExecution.outputs,
131
+ WorkflowNodeExecution.created_at,
132
+ WorkflowNodeExecution.elapsed_time,
133
+ WorkflowNodeExecution.process_data,
134
+ WorkflowNodeExecution.execution_metadata,
135
+ )
136
+ .filter(WorkflowNodeExecution.id == node_execution_id_record.id)
137
+ .first()
138
+ )
139
+
140
+ if not node_execution:
141
+ continue
142
+
143
+ node_execution_id = node_execution.id
144
+ tenant_id = node_execution.tenant_id
145
+ app_id = node_execution.app_id
146
+ node_name = node_execution.title
147
+ node_type = node_execution.node_type
148
+ status = node_execution.status
149
+ if node_type == "llm":
150
+ inputs = (
151
+ json.loads(node_execution.process_data).get("prompts", {}) if node_execution.process_data else {}
152
+ )
153
+ else:
154
+ inputs = json.loads(node_execution.inputs) if node_execution.inputs else {}
155
+ outputs = json.loads(node_execution.outputs) if node_execution.outputs else {}
156
+ created_at = node_execution.created_at or datetime.now()
157
+ elapsed_time = node_execution.elapsed_time
158
+ finished_at = created_at + timedelta(seconds=elapsed_time)
159
+
160
+ metadata = json.loads(node_execution.execution_metadata) if node_execution.execution_metadata else {}
161
+ metadata.update(
162
+ {
163
+ "workflow_run_id": trace_info.workflow_run_id,
164
+ "node_execution_id": node_execution_id,
165
+ "tenant_id": tenant_id,
166
+ "app_id": app_id,
167
+ "node_name": node_name,
168
+ "node_type": node_type,
169
+ "status": status,
170
+ }
171
+ )
172
+ process_data = json.loads(node_execution.process_data) if node_execution.process_data else {}
173
+ model_provider = process_data.get("model_provider", None)
174
+ model_name = process_data.get("model_name", None)
175
+ if model_provider is not None and model_name is not None:
176
+ metadata.update(
177
+ {
178
+ "model_provider": model_provider,
179
+ "model_name": model_name,
180
+ }
181
+ )
182
+
183
+ # add span
184
+ if trace_info.message_id:
185
+ span_data = LangfuseSpan(
186
+ id=node_execution_id,
187
+ name=node_type,
188
+ input=inputs,
189
+ output=outputs,
190
+ trace_id=trace_id,
191
+ start_time=created_at,
192
+ end_time=finished_at,
193
+ metadata=metadata,
194
+ level=(LevelEnum.DEFAULT if status == "succeeded" else LevelEnum.ERROR),
195
+ status_message=trace_info.error or "",
196
+ parent_observation_id=trace_info.workflow_run_id,
197
+ )
198
+ else:
199
+ span_data = LangfuseSpan(
200
+ id=node_execution_id,
201
+ name=node_type,
202
+ input=inputs,
203
+ output=outputs,
204
+ trace_id=trace_id,
205
+ start_time=created_at,
206
+ end_time=finished_at,
207
+ metadata=metadata,
208
+ level=(LevelEnum.DEFAULT if status == "succeeded" else LevelEnum.ERROR),
209
+ status_message=trace_info.error or "",
210
+ )
211
+
212
+ self.add_span(langfuse_span_data=span_data)
213
+
214
+ if process_data and process_data.get("model_mode") == "chat":
215
+ total_token = metadata.get("total_tokens", 0)
216
+ # add generation
217
+ generation_usage = GenerationUsage(
218
+ total=total_token,
219
+ )
220
+
221
+ node_generation_data = LangfuseGeneration(
222
+ name="llm",
223
+ trace_id=trace_id,
224
+ model=process_data.get("model_name"),
225
+ parent_observation_id=node_execution_id,
226
+ start_time=created_at,
227
+ end_time=finished_at,
228
+ input=inputs,
229
+ output=outputs,
230
+ metadata=metadata,
231
+ level=(LevelEnum.DEFAULT if status == "succeeded" else LevelEnum.ERROR),
232
+ status_message=trace_info.error or "",
233
+ usage=generation_usage,
234
+ )
235
+
236
+ self.add_generation(langfuse_generation_data=node_generation_data)
237
+
238
+ def message_trace(self, trace_info: MessageTraceInfo, **kwargs):
239
+ # get message file data
240
+ file_list = trace_info.file_list
241
+ metadata = trace_info.metadata
242
+ message_data = trace_info.message_data
243
+ if message_data is None:
244
+ return
245
+ message_id = message_data.id
246
+
247
+ user_id = message_data.from_account_id
248
+ if message_data.from_end_user_id:
249
+ end_user_data: Optional[EndUser] = (
250
+ db.session.query(EndUser).filter(EndUser.id == message_data.from_end_user_id).first()
251
+ )
252
+ if end_user_data is not None:
253
+ user_id = end_user_data.session_id
254
+ metadata["user_id"] = user_id
255
+
256
+ trace_data = LangfuseTrace(
257
+ id=message_id,
258
+ user_id=user_id,
259
+ name=TraceTaskName.MESSAGE_TRACE.value,
260
+ input={
261
+ "message": trace_info.inputs,
262
+ "files": file_list,
263
+ "message_tokens": trace_info.message_tokens,
264
+ "answer_tokens": trace_info.answer_tokens,
265
+ "total_tokens": trace_info.total_tokens,
266
+ "error": trace_info.error,
267
+ "provider_response_latency": message_data.provider_response_latency,
268
+ "created_at": trace_info.start_time,
269
+ },
270
+ output=trace_info.outputs,
271
+ metadata=metadata,
272
+ session_id=message_data.conversation_id,
273
+ tags=["message", str(trace_info.conversation_mode)],
274
+ version=None,
275
+ release=None,
276
+ public=None,
277
+ )
278
+ self.add_trace(langfuse_trace_data=trace_data)
279
+
280
+ # add the LLM generation with token usage
281
+ generation_usage = GenerationUsage(
282
+ input=trace_info.message_tokens,
283
+ output=trace_info.answer_tokens,
284
+ total=trace_info.total_tokens,
285
+ unit=UnitEnum.TOKENS,
286
+ totalCost=message_data.total_price,
287
+ )
288
+
289
+ langfuse_generation_data = LangfuseGeneration(
290
+ name="llm",
291
+ trace_id=message_id,
292
+ start_time=trace_info.start_time,
293
+ end_time=trace_info.end_time,
294
+ model=message_data.model_id,
295
+ input=trace_info.inputs,
296
+ output=message_data.answer,
297
+ metadata=metadata,
298
+ level=(LevelEnum.DEFAULT if message_data.status != "error" else LevelEnum.ERROR),
299
+ status_message=message_data.error or "",
300
+ usage=generation_usage,
301
+ )
302
+
303
+ self.add_generation(langfuse_generation_data)
304
+
305
+ def moderation_trace(self, trace_info: ModerationTraceInfo):
306
+ if trace_info.message_data is None:
307
+ return
308
+ span_data = LangfuseSpan(
309
+ name=TraceTaskName.MODERATION_TRACE.value,
310
+ input=trace_info.inputs,
311
+ output={
312
+ "action": trace_info.action,
313
+ "flagged": trace_info.flagged,
314
+ "preset_response": trace_info.preset_response,
315
+ "inputs": trace_info.inputs,
316
+ },
317
+ trace_id=trace_info.message_id,
318
+ start_time=trace_info.start_time or trace_info.message_data.created_at,
319
+ end_time=trace_info.end_time or trace_info.message_data.created_at,
320
+ metadata=trace_info.metadata,
321
+ )
322
+
323
+ self.add_span(langfuse_span_data=span_data)
324
+
325
+ def suggested_question_trace(self, trace_info: SuggestedQuestionTraceInfo):
326
+ message_data = trace_info.message_data
327
+ if message_data is None:
328
+ return
329
+ generation_usage = GenerationUsage(
330
+ total=len(str(trace_info.suggested_question)),
331
+ input=len(trace_info.inputs) if trace_info.inputs else 0,
332
+ output=len(trace_info.suggested_question),
333
+ unit=UnitEnum.CHARACTERS,
334
+ )
335
+
336
+ generation_data = LangfuseGeneration(
337
+ name=TraceTaskName.SUGGESTED_QUESTION_TRACE.value,
338
+ input=trace_info.inputs,
339
+ output=str(trace_info.suggested_question),
340
+ trace_id=trace_info.message_id,
341
+ start_time=trace_info.start_time,
342
+ end_time=trace_info.end_time,
343
+ metadata=trace_info.metadata,
344
+ level=(LevelEnum.DEFAULT if message_data.status != "error" else LevelEnum.ERROR),
345
+ status_message=message_data.error or "",
346
+ usage=generation_usage,
347
+ )
348
+
349
+ self.add_generation(langfuse_generation_data=generation_data)
350
+
351
+ def dataset_retrieval_trace(self, trace_info: DatasetRetrievalTraceInfo):
352
+ if trace_info.message_data is None:
353
+ return
354
+ dataset_retrieval_span_data = LangfuseSpan(
355
+ name=TraceTaskName.DATASET_RETRIEVAL_TRACE.value,
356
+ input=trace_info.inputs,
357
+ output={"documents": trace_info.documents},
358
+ trace_id=trace_info.message_id,
359
+ start_time=trace_info.start_time or trace_info.message_data.created_at,
360
+ end_time=trace_info.end_time or trace_info.message_data.updated_at,
361
+ metadata=trace_info.metadata,
362
+ )
363
+
364
+ self.add_span(langfuse_span_data=dataset_retrieval_span_data)
365
+
366
+ def tool_trace(self, trace_info: ToolTraceInfo):
367
+ tool_span_data = LangfuseSpan(
368
+ name=trace_info.tool_name,
369
+ input=trace_info.tool_inputs,
370
+ output=trace_info.tool_outputs,
371
+ trace_id=trace_info.message_id,
372
+ start_time=trace_info.start_time,
373
+ end_time=trace_info.end_time,
374
+ metadata=trace_info.metadata,
375
+ level=(LevelEnum.DEFAULT if trace_info.error == "" or trace_info.error is None else LevelEnum.ERROR),
376
+ status_message=trace_info.error,
377
+ )
378
+
379
+ self.add_span(langfuse_span_data=tool_span_data)
380
+
381
+ def generate_name_trace(self, trace_info: GenerateNameTraceInfo):
382
+ name_generation_trace_data = LangfuseTrace(
383
+ name=TraceTaskName.GENERATE_NAME_TRACE.value,
384
+ input=trace_info.inputs,
385
+ output=trace_info.outputs,
386
+ user_id=trace_info.tenant_id,
387
+ metadata=trace_info.metadata,
388
+ session_id=trace_info.conversation_id,
389
+ )
390
+
391
+ self.add_trace(langfuse_trace_data=name_generation_trace_data)
392
+
393
+ name_generation_span_data = LangfuseSpan(
394
+ name=TraceTaskName.GENERATE_NAME_TRACE.value,
395
+ input=trace_info.inputs,
396
+ output=trace_info.outputs,
397
+ trace_id=trace_info.conversation_id,
398
+ start_time=trace_info.start_time,
399
+ end_time=trace_info.end_time,
400
+ metadata=trace_info.metadata,
401
+ )
402
+ self.add_span(langfuse_span_data=name_generation_span_data)
403
+
404
+ def add_trace(self, langfuse_trace_data: Optional[LangfuseTrace] = None):
405
+ format_trace_data = filter_none_values(langfuse_trace_data.model_dump()) if langfuse_trace_data else {}
406
+ try:
407
+ self.langfuse_client.trace(**format_trace_data)
408
+ logger.debug("LangFuse Trace created successfully")
409
+ except Exception as e:
410
+ raise ValueError(f"LangFuse Failed to create trace: {str(e)}")
411
+
412
+ def add_span(self, langfuse_span_data: Optional[LangfuseSpan] = None):
413
+ format_span_data = filter_none_values(langfuse_span_data.model_dump()) if langfuse_span_data else {}
414
+ try:
415
+ self.langfuse_client.span(**format_span_data)
416
+ logger.debug("LangFuse Span created successfully")
417
+ except Exception as e:
418
+ raise ValueError(f"LangFuse Failed to create span: {str(e)}")
419
+
420
+ def update_span(self, span, langfuse_span_data: Optional[LangfuseSpan] = None):
421
+ format_span_data = filter_none_values(langfuse_span_data.model_dump()) if langfuse_span_data else {}
422
+
423
+ span.end(**format_span_data)
424
+
425
+ def add_generation(self, langfuse_generation_data: Optional[LangfuseGeneration] = None):
426
+ format_generation_data = (
427
+ filter_none_values(langfuse_generation_data.model_dump()) if langfuse_generation_data else {}
428
+ )
429
+ try:
430
+ self.langfuse_client.generation(**format_generation_data)
431
+ logger.debug("LangFuse Generation created successfully")
432
+ except Exception as e:
433
+ raise ValueError(f"LangFuse Failed to create generation: {str(e)}")
434
+
435
+ def update_generation(self, generation, langfuse_generation_data: Optional[LangfuseGeneration] = None):
436
+ format_generation_data = (
437
+ filter_none_values(langfuse_generation_data.model_dump()) if langfuse_generation_data else {}
438
+ )
439
+
440
+ generation.end(**format_generation_data)
441
+
442
+ def api_check(self):
443
+ try:
444
+ return self.langfuse_client.auth_check()
445
+ except Exception as e:
446
+ logger.debug(f"LangFuse API check failed: {str(e)}")
447
+ raise ValueError(f"LangFuse API check failed: {str(e)}")
448
+
449
+ def get_project_key(self):
450
+ try:
451
+ projects = self.langfuse_client.client.projects.get()
452
+ return projects.data[0].id
453
+ except Exception as e:
454
+ logger.debug(f"LangFuse get project key failed: {str(e)}")
455
+ raise ValueError(f"LangFuse get project key failed: {str(e)}")
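All of the `add_trace` / `add_span` / `add_generation` helpers above share one pattern: dump the Pydantic model, drop `None` fields, and splat the rest into the Langfuse client call. A minimal sketch of that pattern, assuming `filter_none_values` (defined in `api/core/ops/utils.py`, not shown in this diff) simply drops `None`-valued keys:

```python
from typing import Any


def filter_none_values_sketch(data: dict[str, Any]) -> dict[str, Any]:
    # Hedged stand-in for core.ops.utils.filter_none_values: keep only keys
    # with real values so the Langfuse SDK can apply its own defaults.
    return {key: value for key, value in data.items() if value is not None}


# Example: a sparsely populated span payload becomes clean keyword arguments.
payload = {"id": "span-1", "trace_id": "trace-1", "session_id": None, "tags": ["workflow"]}
print(filter_none_values_sketch(payload))
# {'id': 'span-1', 'trace_id': 'trace-1', 'tags': ['workflow']}
```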
api/core/ops/langsmith_trace/__init__.py ADDED
File without changes
api/core/ops/langsmith_trace/entities/__init__.py ADDED
File without changes
api/core/ops/langsmith_trace/entities/langsmith_trace_entity.py ADDED
@@ -0,0 +1,141 @@
1
+ from datetime import datetime
2
+ from enum import StrEnum
3
+ from typing import Any, Optional, Union
4
+
5
+ from pydantic import BaseModel, Field, field_validator
6
+ from pydantic_core.core_schema import ValidationInfo
7
+
8
+ from core.ops.utils import replace_text_with_content
9
+
10
+
11
+ class LangSmithRunType(StrEnum):
12
+ tool = "tool"
13
+ chain = "chain"
14
+ llm = "llm"
15
+ retriever = "retriever"
16
+ embedding = "embedding"
17
+ prompt = "prompt"
18
+ parser = "parser"
19
+
20
+
21
+ class LangSmithTokenUsage(BaseModel):
22
+ input_tokens: Optional[int] = None
23
+ output_tokens: Optional[int] = None
24
+ total_tokens: Optional[int] = None
25
+
26
+
27
+ class LangSmithMultiModel(BaseModel):
28
+ file_list: Optional[list[str]] = Field(None, description="List of files")
29
+
30
+
31
+ class LangSmithRunModel(LangSmithTokenUsage, LangSmithMultiModel):
32
+ name: Optional[str] = Field(..., description="Name of the run")
33
+ inputs: Optional[Union[str, dict[str, Any], list, None]] = Field(None, description="Inputs of the run")
34
+ outputs: Optional[Union[str, dict[str, Any], list, None]] = Field(None, description="Outputs of the run")
35
+ run_type: LangSmithRunType = Field(..., description="Type of the run")
36
+ start_time: Optional[datetime | str] = Field(None, description="Start time of the run")
37
+ end_time: Optional[datetime | str] = Field(None, description="End time of the run")
38
+ extra: Optional[dict[str, Any]] = Field(None, description="Extra information of the run")
39
+ error: Optional[str] = Field(None, description="Error message of the run")
40
+ serialized: Optional[dict[str, Any]] = Field(None, description="Serialized data of the run")
41
+ parent_run_id: Optional[str] = Field(None, description="Parent run ID")
42
+ events: Optional[list[dict[str, Any]]] = Field(None, description="Events associated with the run")
43
+ tags: Optional[list[str]] = Field(None, description="Tags associated with the run")
44
+ trace_id: Optional[str] = Field(None, description="Trace ID associated with the run")
45
+ dotted_order: Optional[str] = Field(None, description="Dotted order of the run")
46
+ id: Optional[str] = Field(None, description="ID of the run")
47
+ session_id: Optional[str] = Field(None, description="Session ID associated with the run")
48
+ session_name: Optional[str] = Field(None, description="Session name associated with the run")
49
+ reference_example_id: Optional[str] = Field(None, description="Reference example ID associated with the run")
50
+ input_attachments: Optional[dict[str, Any]] = Field(None, description="Input attachments of the run")
51
+ output_attachments: Optional[dict[str, Any]] = Field(None, description="Output attachments of the run")
52
+
53
+ @field_validator("inputs", "outputs")
54
+ @classmethod
55
+ def ensure_dict(cls, v, info: ValidationInfo):
56
+ field_name = info.field_name
57
+ values = info.data
58
+ if v == {} or v is None:
59
+ return v
60
+ usage_metadata = {
61
+ "input_tokens": values.get("input_tokens", 0),
62
+ "output_tokens": values.get("output_tokens", 0),
63
+ "total_tokens": values.get("total_tokens", 0),
64
+ }
65
+ file_list = values.get("file_list", [])
66
+ if isinstance(v, str):
67
+ if field_name == "inputs":
68
+ return {
69
+ "messages": {
70
+ "role": "user",
71
+ "content": v,
72
+ "usage_metadata": usage_metadata,
73
+ "file_list": file_list,
74
+ },
75
+ }
76
+ elif field_name == "outputs":
77
+ return {
78
+ "choices": {
79
+ "role": "ai",
80
+ "content": v,
81
+ "usage_metadata": usage_metadata,
82
+ "file_list": file_list,
83
+ },
84
+ }
85
+ elif isinstance(v, list):
86
+ data = {}
87
+ if len(v) > 0 and isinstance(v[0], dict):
88
+ # rename text to content
89
+ v = replace_text_with_content(data=v)
90
+ if field_name == "inputs":
91
+ data = {
92
+ "messages": v,
93
+ }
94
+ elif field_name == "outputs":
95
+ data = {
96
+ "choices": {
97
+ "role": "ai",
98
+ "content": v,
99
+ "usage_metadata": usage_metadata,
100
+ "file_list": file_list,
101
+ },
102
+ }
103
+ return data
104
+ else:
105
+ return {
106
+ "choices": {
107
+ "role": "ai" if field_name == "outputs" else "user",
108
+ "content": str(v),
109
+ "usage_metadata": usage_metadata,
110
+ "file_list": file_list,
111
+ },
112
+ }
113
+ if isinstance(v, dict):
114
+ v["usage_metadata"] = usage_metadata
115
+ v["file_list"] = file_list
116
+ return v
117
+ return v
118
+
119
+ @field_validator("start_time", "end_time")
120
+ @classmethod
121
+ def format_time(cls, v, info: ValidationInfo):
122
+ if not isinstance(v, datetime):
123
+ raise ValueError(f"{info.field_name} must be a datetime object")
124
+ else:
125
+ return v.strftime("%Y-%m-%dT%H:%M:%S.%fZ")
126
+
127
+
128
+ class LangSmithRunUpdateModel(BaseModel):
129
+ run_id: str = Field(..., description="ID of the run")
130
+ trace_id: Optional[str] = Field(None, description="Trace ID associated with the run")
131
+ dotted_order: Optional[str] = Field(None, description="Dotted order of the run")
132
+ parent_run_id: Optional[str] = Field(None, description="Parent run ID")
133
+ end_time: Optional[datetime | str] = Field(None, description="End time of the run")
134
+ error: Optional[str] = Field(None, description="Error message of the run")
135
+ inputs: Optional[dict[str, Any]] = Field(None, description="Inputs of the run")
136
+ outputs: Optional[dict[str, Any]] = Field(None, description="Outputs of the run")
137
+ events: Optional[list[dict[str, Any]]] = Field(None, description="Events associated with the run")
138
+ tags: Optional[list[str]] = Field(None, description="Tags associated with the run")
139
+ extra: Optional[dict[str, Any]] = Field(None, description="Extra information of the run")
140
+ input_attachments: Optional[dict[str, Any]] = Field(None, description="Input attachments of the run")
141
+ output_attachments: Optional[dict[str, Any]] = Field(None, description="Output attachments of the run")
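The `ensure_dict` validator above does the heavy lifting for LangSmith compatibility: plain strings become `messages`/`choices` dicts annotated with token usage. A small hedged usage check (all field values here are invented):

```python
run = LangSmithRunModel(
    name="demo",
    run_type=LangSmithRunType.chain,
    input_tokens=3,
    output_tokens=5,
    total_tokens=8,
    inputs="hello",
    outputs="hi there",
)

# String inputs are wrapped as a user message carrying usage metadata...
assert run.inputs["messages"]["role"] == "user"
assert run.inputs["messages"]["usage_metadata"]["total_tokens"] == 8
# ...and string outputs become an "ai" choice.
assert run.outputs["choices"]["content"] == "hi there"
```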
api/core/ops/langsmith_trace/langsmith_trace.py ADDED
@@ -0,0 +1,524 @@
1
+ import json
2
+ import logging
3
+ import os
4
+ import uuid
5
+ from datetime import datetime, timedelta
6
+ from typing import Optional, cast
7
+
8
+ from langsmith import Client
9
+ from langsmith.schemas import RunBase
10
+
11
+ from core.ops.base_trace_instance import BaseTraceInstance
12
+ from core.ops.entities.config_entity import LangSmithConfig
13
+ from core.ops.entities.trace_entity import (
14
+ BaseTraceInfo,
15
+ DatasetRetrievalTraceInfo,
16
+ GenerateNameTraceInfo,
17
+ MessageTraceInfo,
18
+ ModerationTraceInfo,
19
+ SuggestedQuestionTraceInfo,
20
+ ToolTraceInfo,
21
+ TraceTaskName,
22
+ WorkflowTraceInfo,
23
+ )
24
+ from core.ops.langsmith_trace.entities.langsmith_trace_entity import (
25
+ LangSmithRunModel,
26
+ LangSmithRunType,
27
+ LangSmithRunUpdateModel,
28
+ )
29
+ from core.ops.utils import filter_none_values, generate_dotted_order
30
+ from extensions.ext_database import db
31
+ from models.model import EndUser, MessageFile
32
+ from models.workflow import WorkflowNodeExecution
33
+
34
+ logger = logging.getLogger(__name__)
35
+
36
+
37
+ class LangSmithDataTrace(BaseTraceInstance):
38
+ def __init__(
39
+ self,
40
+ langsmith_config: LangSmithConfig,
41
+ ):
42
+ super().__init__(langsmith_config)
43
+ self.langsmith_key = langsmith_config.api_key
44
+ self.project_name = langsmith_config.project
45
+ self.project_id = None
46
+ self.langsmith_client = Client(api_key=langsmith_config.api_key, api_url=langsmith_config.endpoint)
47
+ self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001")
48
+
49
+ def trace(self, trace_info: BaseTraceInfo):
50
+ if isinstance(trace_info, WorkflowTraceInfo):
51
+ self.workflow_trace(trace_info)
52
+ if isinstance(trace_info, MessageTraceInfo):
53
+ self.message_trace(trace_info)
54
+ if isinstance(trace_info, ModerationTraceInfo):
55
+ self.moderation_trace(trace_info)
56
+ if isinstance(trace_info, SuggestedQuestionTraceInfo):
57
+ self.suggested_question_trace(trace_info)
58
+ if isinstance(trace_info, DatasetRetrievalTraceInfo):
59
+ self.dataset_retrieval_trace(trace_info)
60
+ if isinstance(trace_info, ToolTraceInfo):
61
+ self.tool_trace(trace_info)
62
+ if isinstance(trace_info, GenerateNameTraceInfo):
63
+ self.generate_name_trace(trace_info)
64
+
65
+ def workflow_trace(self, trace_info: WorkflowTraceInfo):
66
+ trace_id = trace_info.message_id or trace_info.workflow_run_id
67
+ if trace_info.start_time is None:
68
+ trace_info.start_time = datetime.now()
69
+ message_dotted_order = (
70
+ generate_dotted_order(trace_info.message_id, trace_info.start_time) if trace_info.message_id else None
71
+ )
72
+ workflow_dotted_order = generate_dotted_order(
73
+ trace_info.workflow_run_id,
74
+ trace_info.workflow_data.created_at,
75
+ message_dotted_order,
76
+ )
77
+ metadata = trace_info.metadata
78
+ metadata["workflow_app_log_id"] = trace_info.workflow_app_log_id
79
+
80
+ if trace_info.message_id:
81
+ message_run = LangSmithRunModel(
82
+ id=trace_info.message_id,
83
+ name=TraceTaskName.MESSAGE_TRACE.value,
84
+ inputs=dict(trace_info.workflow_run_inputs),
85
+ outputs=dict(trace_info.workflow_run_outputs),
86
+ run_type=LangSmithRunType.chain,
87
+ start_time=trace_info.start_time,
88
+ end_time=trace_info.end_time,
89
+ extra={
90
+ "metadata": metadata,
91
+ },
92
+ tags=["message", "workflow"],
93
+ error=trace_info.error,
94
+ trace_id=trace_id,
95
+ dotted_order=message_dotted_order,
96
+ file_list=[],
97
+ serialized=None,
98
+ parent_run_id=None,
99
+ events=[],
100
+ session_id=None,
101
+ session_name=None,
102
+ reference_example_id=None,
103
+ input_attachments={},
104
+ output_attachments={},
105
+ )
106
+ self.add_run(message_run)
107
+
108
+ langsmith_run = LangSmithRunModel(
109
+ file_list=trace_info.file_list,
110
+ total_tokens=trace_info.total_tokens,
111
+ id=trace_info.workflow_run_id,
112
+ name=TraceTaskName.WORKFLOW_TRACE.value,
113
+ inputs=dict(trace_info.workflow_run_inputs),
114
+ run_type=LangSmithRunType.tool,
115
+ start_time=trace_info.workflow_data.created_at,
116
+ end_time=trace_info.workflow_data.finished_at,
117
+ outputs=dict(trace_info.workflow_run_outputs),
118
+ extra={
119
+ "metadata": metadata,
120
+ },
121
+ error=trace_info.error,
122
+ tags=["workflow"],
123
+ parent_run_id=trace_info.message_id or None,
124
+ trace_id=trace_id,
125
+ dotted_order=workflow_dotted_order,
126
+ serialized=None,
127
+ events=[],
128
+ session_id=None,
129
+ session_name=None,
130
+ reference_example_id=None,
131
+ input_attachments={},
132
+ output_attachments={},
133
+ )
134
+
135
+ self.add_run(langsmith_run)
136
+
137
+ # fetch all node executions for this workflow run via workflow_run_id
138
+ workflow_nodes_execution_id_records = (
139
+ db.session.query(WorkflowNodeExecution.id)
140
+ .filter(WorkflowNodeExecution.workflow_run_id == trace_info.workflow_run_id)
141
+ .all()
142
+ )
143
+
144
+ for node_execution_id_record in workflow_nodes_execution_id_records:
145
+ node_execution = (
146
+ db.session.query(
147
+ WorkflowNodeExecution.id,
148
+ WorkflowNodeExecution.tenant_id,
149
+ WorkflowNodeExecution.app_id,
150
+ WorkflowNodeExecution.title,
151
+ WorkflowNodeExecution.node_type,
152
+ WorkflowNodeExecution.status,
153
+ WorkflowNodeExecution.inputs,
154
+ WorkflowNodeExecution.outputs,
155
+ WorkflowNodeExecution.created_at,
156
+ WorkflowNodeExecution.elapsed_time,
157
+ WorkflowNodeExecution.process_data,
158
+ WorkflowNodeExecution.execution_metadata,
159
+ )
160
+ .filter(WorkflowNodeExecution.id == node_execution_id_record.id)
161
+ .first()
162
+ )
163
+
164
+ if not node_execution:
165
+ continue
166
+
167
+ node_execution_id = node_execution.id
168
+ tenant_id = node_execution.tenant_id
169
+ app_id = node_execution.app_id
170
+ node_name = node_execution.title
171
+ node_type = node_execution.node_type
172
+ status = node_execution.status
173
+ if node_type == "llm":
174
+ inputs = (
175
+ json.loads(node_execution.process_data).get("prompts", {}) if node_execution.process_data else {}
176
+ )
177
+ else:
178
+ inputs = json.loads(node_execution.inputs) if node_execution.inputs else {}
179
+ outputs = json.loads(node_execution.outputs) if node_execution.outputs else {}
180
+ created_at = node_execution.created_at or datetime.now()
181
+ elapsed_time = node_execution.elapsed_time
182
+ finished_at = created_at + timedelta(seconds=elapsed_time)
183
+
184
+ execution_metadata = (
185
+ json.loads(node_execution.execution_metadata) if node_execution.execution_metadata else {}
186
+ )
187
+ node_total_tokens = execution_metadata.get("total_tokens", 0)
188
+ metadata = execution_metadata.copy()
189
+ metadata.update(
190
+ {
191
+ "workflow_run_id": trace_info.workflow_run_id,
192
+ "node_execution_id": node_execution_id,
193
+ "tenant_id": tenant_id,
194
+ "app_id": app_id,
195
+ "app_name": node_name,
196
+ "node_type": node_type,
197
+ "status": status,
198
+ }
199
+ )
200
+
201
+ process_data = json.loads(node_execution.process_data) if node_execution.process_data else {}
202
+ if process_data and process_data.get("model_mode") == "chat":
203
+ run_type = LangSmithRunType.llm
204
+ metadata.update(
205
+ {
206
+ "ls_provider": process_data.get("model_provider", ""),
207
+ "ls_model_name": process_data.get("model_name", ""),
208
+ }
209
+ )
210
+ elif node_type == "knowledge-retrieval":
211
+ run_type = LangSmithRunType.retriever
212
+ else:
213
+ run_type = LangSmithRunType.tool
214
+
215
+ node_dotted_order = generate_dotted_order(node_execution_id, created_at, workflow_dotted_order)
216
+ langsmith_run = LangSmithRunModel(
217
+ total_tokens=node_total_tokens,
218
+ name=node_type,
219
+ inputs=inputs,
220
+ run_type=run_type,
221
+ start_time=created_at,
222
+ end_time=finished_at,
223
+ outputs=outputs,
224
+ file_list=trace_info.file_list,
225
+ extra={
226
+ "metadata": metadata,
227
+ },
228
+ parent_run_id=trace_info.workflow_run_id,
229
+ tags=["node_execution"],
230
+ id=node_execution_id,
231
+ trace_id=trace_id,
232
+ dotted_order=node_dotted_order,
233
+ error="",
234
+ serialized=None,
235
+ events=[],
236
+ session_id=None,
237
+ session_name=None,
238
+ reference_example_id=None,
239
+ input_attachments={},
240
+ output_attachments={},
241
+ )
242
+
243
+ self.add_run(langsmith_run)
244
+
245
+ def message_trace(self, trace_info: MessageTraceInfo):
246
+ # get message file data
247
+ file_list = cast(list[str], trace_info.file_list) or []
248
+ message_file_data: Optional[MessageFile] = trace_info.message_file_data
249
+ file_url = f"{self.file_base_url}/{message_file_data.url}" if message_file_data else ""
250
+ if file_url:
+ file_list.append(file_url)
251
+ metadata = trace_info.metadata
252
+ message_data = trace_info.message_data
253
+ if message_data is None:
254
+ return
255
+ message_id = message_data.id
256
+
257
+ user_id = message_data.from_account_id
258
+ metadata["user_id"] = user_id
259
+
260
+ if message_data.from_end_user_id:
261
+ end_user_data: Optional[EndUser] = (
262
+ db.session.query(EndUser).filter(EndUser.id == message_data.from_end_user_id).first()
263
+ )
264
+ if end_user_data is not None:
265
+ end_user_id = end_user_data.session_id
266
+ metadata["end_user_id"] = end_user_id
267
+
268
+ message_run = LangSmithRunModel(
269
+ input_tokens=trace_info.message_tokens,
270
+ output_tokens=trace_info.answer_tokens,
271
+ total_tokens=trace_info.total_tokens,
272
+ id=message_id,
273
+ name=TraceTaskName.MESSAGE_TRACE.value,
274
+ inputs=trace_info.inputs,
275
+ run_type=LangSmithRunType.chain,
276
+ start_time=trace_info.start_time,
277
+ end_time=trace_info.end_time,
278
+ outputs=message_data.answer,
279
+ extra={"metadata": metadata},
280
+ tags=["message", str(trace_info.conversation_mode)],
281
+ error=trace_info.error,
282
+ file_list=file_list,
283
+ serialized=None,
284
+ events=[],
285
+ session_id=None,
286
+ session_name=None,
287
+ reference_example_id=None,
288
+ input_attachments={},
289
+ output_attachments={},
290
+ trace_id=None,
291
+ dotted_order=None,
292
+ parent_run_id=None,
293
+ )
294
+ self.add_run(message_run)
295
+
296
+ # create the LLM child run parented to the message run
297
+ llm_run = LangSmithRunModel(
298
+ input_tokens=trace_info.message_tokens,
299
+ output_tokens=trace_info.answer_tokens,
300
+ total_tokens=trace_info.total_tokens,
301
+ name="llm",
302
+ inputs=trace_info.inputs,
303
+ run_type=LangSmithRunType.llm,
304
+ start_time=trace_info.start_time,
305
+ end_time=trace_info.end_time,
306
+ outputs=message_data.answer,
307
+ extra={"metadata": metadata},
308
+ parent_run_id=message_id,
309
+ tags=["llm", str(trace_info.conversation_mode)],
310
+ error=trace_info.error,
311
+ file_list=file_list,
312
+ serialized=None,
313
+ events=[],
314
+ session_id=None,
315
+ session_name=None,
316
+ reference_example_id=None,
317
+ input_attachments={},
318
+ output_attachments={},
319
+ trace_id=None,
320
+ dotted_order=None,
321
+ id=str(uuid.uuid4()),
322
+ )
323
+ self.add_run(llm_run)
324
+
325
+ def moderation_trace(self, trace_info: ModerationTraceInfo):
326
+ if trace_info.message_data is None:
327
+ return
328
+ langsmith_run = LangSmithRunModel(
329
+ name=TraceTaskName.MODERATION_TRACE.value,
330
+ inputs=trace_info.inputs,
331
+ outputs={
332
+ "action": trace_info.action,
333
+ "flagged": trace_info.flagged,
334
+ "preset_response": trace_info.preset_response,
335
+ "inputs": trace_info.inputs,
336
+ },
337
+ run_type=LangSmithRunType.tool,
338
+ extra={"metadata": trace_info.metadata},
339
+ tags=["moderation"],
340
+ parent_run_id=trace_info.message_id,
341
+ start_time=trace_info.start_time or trace_info.message_data.created_at,
342
+ end_time=trace_info.end_time or trace_info.message_data.updated_at,
343
+ id=str(uuid.uuid4()),
344
+ serialized=None,
345
+ events=[],
346
+ session_id=None,
347
+ session_name=None,
348
+ reference_example_id=None,
349
+ input_attachments={},
350
+ output_attachments={},
351
+ trace_id=None,
352
+ dotted_order=None,
353
+ error="",
354
+ file_list=[],
355
+ )
356
+
357
+ self.add_run(langsmith_run)
358
+
359
+ def suggested_question_trace(self, trace_info: SuggestedQuestionTraceInfo):
360
+ message_data = trace_info.message_data
361
+ if message_data is None:
362
+ return
363
+ suggested_question_run = LangSmithRunModel(
364
+ name=TraceTaskName.SUGGESTED_QUESTION_TRACE.value,
365
+ inputs=trace_info.inputs,
366
+ outputs=trace_info.suggested_question,
367
+ run_type=LangSmithRunType.tool,
368
+ extra={"metadata": trace_info.metadata},
369
+ tags=["suggested_question"],
370
+ parent_run_id=trace_info.message_id,
371
+ start_time=trace_info.start_time or message_data.created_at,
372
+ end_time=trace_info.end_time or message_data.updated_at,
373
+ id=str(uuid.uuid4()),
374
+ serialized=None,
375
+ events=[],
376
+ session_id=None,
377
+ session_name=None,
378
+ reference_example_id=None,
379
+ input_attachments={},
380
+ output_attachments={},
381
+ trace_id=None,
382
+ dotted_order=None,
383
+ error="",
384
+ file_list=[],
385
+ )
386
+
387
+ self.add_run(suggested_question_run)
388
+
389
+ def dataset_retrieval_trace(self, trace_info: DatasetRetrievalTraceInfo):
390
+ if trace_info.message_data is None:
391
+ return
392
+ dataset_retrieval_run = LangSmithRunModel(
393
+ name=TraceTaskName.DATASET_RETRIEVAL_TRACE.value,
394
+ inputs=trace_info.inputs,
395
+ outputs={"documents": trace_info.documents},
396
+ run_type=LangSmithRunType.retriever,
397
+ extra={"metadata": trace_info.metadata},
398
+ tags=["dataset_retrieval"],
399
+ parent_run_id=trace_info.message_id,
400
+ start_time=trace_info.start_time or trace_info.message_data.created_at,
401
+ end_time=trace_info.end_time or trace_info.message_data.updated_at,
402
+ id=str(uuid.uuid4()),
403
+ serialized=None,
404
+ events=[],
405
+ session_id=None,
406
+ session_name=None,
407
+ reference_example_id=None,
408
+ input_attachments={},
409
+ output_attachments={},
410
+ trace_id=None,
411
+ dotted_order=None,
412
+ error="",
413
+ file_list=[],
414
+ )
415
+
416
+ self.add_run(dataset_retrieval_run)
417
+
418
+ def tool_trace(self, trace_info: ToolTraceInfo):
419
+ tool_run = LangSmithRunModel(
420
+ name=trace_info.tool_name,
421
+ inputs=trace_info.tool_inputs,
422
+ outputs=trace_info.tool_outputs,
423
+ run_type=LangSmithRunType.tool,
424
+ extra={
425
+ "metadata": trace_info.metadata,
426
+ },
427
+ tags=["tool", trace_info.tool_name],
428
+ parent_run_id=trace_info.message_id,
429
+ start_time=trace_info.start_time,
430
+ end_time=trace_info.end_time,
431
+ file_list=[cast(str, trace_info.file_url)],
432
+ id=str(uuid.uuid4()),
433
+ serialized=None,
434
+ events=[],
435
+ session_id=None,
436
+ session_name=None,
437
+ reference_example_id=None,
438
+ input_attachments={},
439
+ output_attachments={},
440
+ trace_id=None,
441
+ dotted_order=None,
442
+ error=trace_info.error or "",
443
+ )
444
+
445
+ self.add_run(tool_run)
446
+
447
+ def generate_name_trace(self, trace_info: GenerateNameTraceInfo):
448
+ name_run = LangSmithRunModel(
449
+ name=TraceTaskName.GENERATE_NAME_TRACE.value,
450
+ inputs=trace_info.inputs,
451
+ outputs=trace_info.outputs,
452
+ run_type=LangSmithRunType.tool,
453
+ extra={"metadata": trace_info.metadata},
454
+ tags=["generate_name"],
455
+ start_time=trace_info.start_time or datetime.now(),
456
+ end_time=trace_info.end_time or datetime.now(),
457
+ id=str(uuid.uuid4()),
458
+ serialized=None,
459
+ events=[],
460
+ session_id=None,
461
+ session_name=None,
462
+ reference_example_id=None,
463
+ input_attachments={},
464
+ output_attachments={},
465
+ trace_id=None,
466
+ dotted_order=None,
467
+ error="",
468
+ file_list=[],
469
+ parent_run_id=None,
470
+ )
471
+
472
+ self.add_run(name_run)
473
+
474
+ def add_run(self, run_data: LangSmithRunModel):
475
+ data = run_data.model_dump()
476
+ if self.project_id:
477
+ data["session_id"] = self.project_id
478
+ elif self.project_name:
479
+ data["session_name"] = self.project_name
480
+
481
+ data = filter_none_values(data)
482
+ try:
483
+ self.langsmith_client.create_run(**data)
484
+ logger.debug("LangSmith Run created successfully.")
485
+ except Exception as e:
486
+ raise ValueError(f"LangSmith Failed to create run: {str(e)}")
487
+
488
+ def update_run(self, update_run_data: LangSmithRunUpdateModel):
489
+ data = update_run_data.model_dump()
490
+ data = filter_none_values(data)
491
+ try:
492
+ self.langsmith_client.update_run(**data)
493
+ logger.debug("LangSmith Run updated successfully.")
494
+ except Exception as e:
495
+ raise ValueError(f"LangSmith Failed to update run: {str(e)}")
496
+
497
+ def api_check(self):
498
+ try:
499
+ random_project_name = f"test_project_{datetime.now().strftime('%Y%m%d%H%M%S')}"
500
+ self.langsmith_client.create_project(project_name=random_project_name)
501
+ self.langsmith_client.delete_project(project_name=random_project_name)
502
+ return True
503
+ except Exception as e:
504
+ logger.debug(f"LangSmith API check failed: {str(e)}")
505
+ raise ValueError(f"LangSmith API check failed: {str(e)}")
506
+
507
+ def get_project_url(self):
508
+ try:
509
+ run_data = RunBase(
510
+ id=uuid.uuid4(),
511
+ name="tool",
512
+ inputs={"input": "test"},
513
+ outputs={"output": "test"},
514
+ run_type=LangSmithRunType.tool,
515
+ start_time=datetime.now(),
516
+ )
517
+
518
+ project_url = self.langsmith_client.get_run_url(
519
+ run=run_data, project_id=self.project_id, project_name=self.project_name
520
+ )
521
+ return project_url.split("/r/")[0]
522
+ except Exception as e:
523
+ logger.debug(f"LangSmith get run url failed: {str(e)}")
524
+ raise ValueError(f"LangSmith get run url failed: {str(e)}")
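`generate_dotted_order` (imported from `core.ops.utils` and not shown in this diff) is what makes the message → workflow → node hierarchy sortable in LangSmith. The sketch below shows the general scheme as an assumption, not the shipped helper: each segment is a timestamp plus the run id, and children append their segment to the parent's dotted order.

```python
from datetime import datetime
from typing import Optional


def dotted_order_sketch(run_id: str, start_time: datetime,
                        parent_dotted_order: Optional[str] = None) -> str:
    # Segments look like "20240101T120000000000Z<run_id>"; nesting joins them
    # with ".", so lexicographic order mirrors execution order and depth.
    segment = start_time.strftime("%Y%m%dT%H%M%S%fZ") + run_id
    return f"{parent_dotted_order}.{segment}" if parent_dotted_order else segment


message = dotted_order_sketch("msg-1", datetime(2024, 1, 1, 12, 0, 0))
workflow = dotted_order_sketch("wf-1", datetime(2024, 1, 1, 12, 0, 1), message)
print(workflow)
# 20240101T120000000000Zmsg-1.20240101T120001000000Zwf-1
```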
api/core/ops/opik_trace/__init__.py ADDED
File without changes
api/core/ops/opik_trace/opik_trace.py ADDED
@@ -0,0 +1,469 @@
1
+ import json
2
+ import logging
3
+ import os
4
+ import uuid
5
+ from datetime import datetime, timedelta
6
+ from typing import Optional, cast
7
+
8
+ from opik import Opik, Trace
9
+ from opik.id_helpers import uuid4_to_uuid7
10
+
11
+ from core.ops.base_trace_instance import BaseTraceInstance
12
+ from core.ops.entities.config_entity import OpikConfig
13
+ from core.ops.entities.trace_entity import (
14
+ BaseTraceInfo,
15
+ DatasetRetrievalTraceInfo,
16
+ GenerateNameTraceInfo,
17
+ MessageTraceInfo,
18
+ ModerationTraceInfo,
19
+ SuggestedQuestionTraceInfo,
20
+ ToolTraceInfo,
21
+ TraceTaskName,
22
+ WorkflowTraceInfo,
23
+ )
24
+ from extensions.ext_database import db
25
+ from models.model import EndUser, MessageFile
26
+ from models.workflow import WorkflowNodeExecution
27
+
28
+ logger = logging.getLogger(__name__)
29
+
30
+
31
+ def wrap_dict(key_name, data):
32
+ """Make sure that the input data is a dict"""
33
+ if not isinstance(data, dict):
34
+ return {key_name: data}
35
+
36
+ return data
37
+
38
+
39
+ def wrap_metadata(metadata, **kwargs):
40
+ """Add common metadata to all Traces and Spans"""
41
+ metadata["created_from"] = "dify"
42
+
43
+ metadata.update(kwargs)
44
+
45
+ return metadata
46
+
47
+
48
+ def prepare_opik_uuid(user_datetime: Optional[datetime], user_uuid: Optional[str]):
49
+ """Opik needs UUIDv7 identifiers, while Dify uses UUIDv4 for most
50
+ messages and objects. The type hints on BaseTraceInfo indicate that
51
+ start_time and message_id may be null, in which case we cannot map
52
+ the object to a UUIDv7. Since there is then no way to identify the
53
+ object uniquely, generate a new random UUIDv7 instead.
54
+ """
55
+
56
+ if user_datetime is None:
57
+ user_datetime = datetime.now()
58
+
59
+ if user_uuid is None:
60
+ user_uuid = str(uuid.uuid4())
61
+
62
+ return uuid4_to_uuid7(user_datetime, user_uuid)
63
+
64
+
65
+ class OpikDataTrace(BaseTraceInstance):
66
+ def __init__(
67
+ self,
68
+ opik_config: OpikConfig,
69
+ ):
70
+ super().__init__(opik_config)
71
+ self.opik_client = Opik(
72
+ project_name=opik_config.project,
73
+ workspace=opik_config.workspace,
74
+ host=opik_config.url,
75
+ api_key=opik_config.api_key,
76
+ )
77
+ self.project = opik_config.project
78
+ self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001")
79
+
80
+ def trace(self, trace_info: BaseTraceInfo):
81
+ if isinstance(trace_info, WorkflowTraceInfo):
82
+ self.workflow_trace(trace_info)
83
+ if isinstance(trace_info, MessageTraceInfo):
84
+ self.message_trace(trace_info)
85
+ if isinstance(trace_info, ModerationTraceInfo):
86
+ self.moderation_trace(trace_info)
87
+ if isinstance(trace_info, SuggestedQuestionTraceInfo):
88
+ self.suggested_question_trace(trace_info)
89
+ if isinstance(trace_info, DatasetRetrievalTraceInfo):
90
+ self.dataset_retrieval_trace(trace_info)
91
+ if isinstance(trace_info, ToolTraceInfo):
92
+ self.tool_trace(trace_info)
93
+ if isinstance(trace_info, GenerateNameTraceInfo):
94
+ self.generate_name_trace(trace_info)
95
+
96
+ def workflow_trace(self, trace_info: WorkflowTraceInfo):
97
+ dify_trace_id = trace_info.workflow_run_id
98
+ opik_trace_id = prepare_opik_uuid(trace_info.start_time, dify_trace_id)
99
+ workflow_metadata = wrap_metadata(
100
+ trace_info.metadata, message_id=trace_info.message_id, workflow_app_log_id=trace_info.workflow_app_log_id
101
+ )
102
+ root_span_id = None
103
+
104
+ if trace_info.message_id:
105
+ dify_trace_id = trace_info.message_id
106
+ opik_trace_id = prepare_opik_uuid(trace_info.start_time, dify_trace_id)
107
+
108
+ trace_data = {
109
+ "id": opik_trace_id,
110
+ "name": TraceTaskName.MESSAGE_TRACE.value,
111
+ "start_time": trace_info.start_time,
112
+ "end_time": trace_info.end_time,
113
+ "metadata": workflow_metadata,
114
+ "input": wrap_dict("input", trace_info.workflow_run_inputs),
115
+ "output": wrap_dict("output", trace_info.workflow_run_outputs),
116
+ "tags": ["message", "workflow"],
117
+ "project_name": self.project,
118
+ }
119
+ self.add_trace(trace_data)
120
+
121
+ root_span_id = prepare_opik_uuid(trace_info.start_time, trace_info.workflow_run_id)
122
+ span_data = {
123
+ "id": root_span_id,
124
+ "parent_span_id": None,
125
+ "trace_id": opik_trace_id,
126
+ "name": TraceTaskName.WORKFLOW_TRACE.value,
127
+ "input": wrap_dict("input", trace_info.workflow_run_inputs),
128
+ "output": wrap_dict("output", trace_info.workflow_run_outputs),
129
+ "start_time": trace_info.start_time,
130
+ "end_time": trace_info.end_time,
131
+ "metadata": workflow_metadata,
132
+ "tags": ["workflow"],
133
+ "project_name": self.project,
134
+ }
135
+ self.add_span(span_data)
136
+ else:
137
+ trace_data = {
138
+ "id": opik_trace_id,
139
+ "name": TraceTaskName.WORKFLOW_TRACE.value,
140
+ "start_time": trace_info.start_time,
141
+ "end_time": trace_info.end_time,
142
+ "metadata": workflow_metadata,
143
+ "input": wrap_dict("input", trace_info.workflow_run_inputs),
144
+ "output": wrap_dict("output", trace_info.workflow_run_outputs),
145
+ "tags": ["workflow"],
146
+ "project_name": self.project,
147
+ }
148
+ self.add_trace(trace_data)
149
+
150
+ # fetch all node executions for this workflow run via workflow_run_id
151
+ workflow_nodes_execution_id_records = (
152
+ db.session.query(WorkflowNodeExecution.id)
153
+ .filter(WorkflowNodeExecution.workflow_run_id == trace_info.workflow_run_id)
154
+ .all()
155
+ )
156
+
157
+ for node_execution_id_record in workflow_nodes_execution_id_records:
158
+ node_execution = (
159
+ db.session.query(
160
+ WorkflowNodeExecution.id,
161
+ WorkflowNodeExecution.tenant_id,
162
+ WorkflowNodeExecution.app_id,
163
+ WorkflowNodeExecution.title,
164
+ WorkflowNodeExecution.node_type,
165
+ WorkflowNodeExecution.status,
166
+ WorkflowNodeExecution.inputs,
167
+ WorkflowNodeExecution.outputs,
168
+ WorkflowNodeExecution.created_at,
169
+ WorkflowNodeExecution.elapsed_time,
170
+ WorkflowNodeExecution.process_data,
171
+ WorkflowNodeExecution.execution_metadata,
172
+ )
173
+ .filter(WorkflowNodeExecution.id == node_execution_id_record.id)
174
+ .first()
175
+ )
176
+
177
+ if not node_execution:
178
+ continue
179
+
180
+ node_execution_id = node_execution.id
181
+ tenant_id = node_execution.tenant_id
182
+ app_id = node_execution.app_id
183
+ node_name = node_execution.title
184
+ node_type = node_execution.node_type
185
+ status = node_execution.status
186
+ if node_type == "llm":
187
+ inputs = (
188
+ json.loads(node_execution.process_data).get("prompts", {}) if node_execution.process_data else {}
189
+ )
190
+ else:
191
+ inputs = json.loads(node_execution.inputs) if node_execution.inputs else {}
192
+ outputs = json.loads(node_execution.outputs) if node_execution.outputs else {}
193
+ created_at = node_execution.created_at or datetime.now()
194
+ elapsed_time = node_execution.elapsed_time
195
+ finished_at = created_at + timedelta(seconds=elapsed_time)
196
+
197
+ execution_metadata = (
198
+ json.loads(node_execution.execution_metadata) if node_execution.execution_metadata else {}
199
+ )
200
+ metadata = execution_metadata.copy()
201
+ metadata.update(
202
+ {
203
+ "workflow_run_id": trace_info.workflow_run_id,
204
+ "node_execution_id": node_execution_id,
205
+ "tenant_id": tenant_id,
206
+ "app_id": app_id,
207
+ "app_name": node_name,
208
+ "node_type": node_type,
209
+ "status": status,
210
+ }
211
+ )
212
+
213
+ process_data = json.loads(node_execution.process_data) if node_execution.process_data else {}
214
+
215
+ provider = None
216
+ model = None
217
+ total_tokens = 0
218
+ completion_tokens = 0
219
+ prompt_tokens = 0
220
+
221
+ if process_data and process_data.get("model_mode") == "chat":
222
+ run_type = "llm"
223
+ provider = process_data.get("model_provider", None)
224
+ model = process_data.get("model_name", "")
225
+ metadata.update(
226
+ {
227
+ "ls_provider": provider,
228
+ "ls_model_name": model,
229
+ }
230
+ )
231
+
232
+ try:
233
+ if outputs.get("usage"):
234
+ total_tokens = outputs["usage"].get("total_tokens", 0)
235
+ prompt_tokens = outputs["usage"].get("prompt_tokens", 0)
236
+ completion_tokens = outputs["usage"].get("completion_tokens", 0)
237
+ except Exception:
238
+ logger.error("Failed to extract usage", exc_info=True)
239
+
240
+ else:
241
+ run_type = "tool"
242
+
243
+ parent_span_id = trace_info.workflow_app_log_id or trace_info.workflow_run_id
244
+
245
+ if not total_tokens:
246
+ total_tokens = execution_metadata.get("total_tokens", 0)
247
+
248
+ span_data = {
249
+ "trace_id": opik_trace_id,
250
+ "id": prepare_opik_uuid(created_at, node_execution_id),
251
+ "parent_span_id": prepare_opik_uuid(trace_info.start_time, parent_span_id),
252
+ "name": node_type,
253
+ "type": run_type,
254
+ "start_time": created_at,
255
+ "end_time": finished_at,
256
+ "metadata": wrap_metadata(metadata),
257
+ "input": wrap_dict("input", inputs),
258
+ "output": wrap_dict("output", outputs),
259
+ "tags": ["node_execution"],
260
+ "project_name": self.project,
261
+ "usage": {
262
+ "total_tokens": total_tokens,
263
+ "completion_tokens": completion_tokens,
264
+ "prompt_tokens": prompt_tokens,
265
+ },
266
+ "model": model,
267
+ "provider": provider,
268
+ }
269
+
270
+ self.add_span(span_data)
271
+
272
+ def message_trace(self, trace_info: MessageTraceInfo):
273
+ # get message file data
274
+ file_list = cast(list[str], trace_info.file_list) or []
275
+ message_file_data: Optional[MessageFile] = trace_info.message_file_data
276
+
277
+ if message_file_data is not None:
278
+ file_url = f"{self.file_base_url}/{message_file_data.url}" if message_file_data else ""
279
+ file_list.append(file_url)
280
+
281
+ message_data = trace_info.message_data
282
+ if message_data is None:
283
+ return
284
+
285
+ metadata = trace_info.metadata
286
+ message_id = trace_info.message_id
287
+
288
+ user_id = message_data.from_account_id
289
+ metadata["user_id"] = user_id
290
+ metadata["file_list"] = file_list
291
+
292
+ if message_data.from_end_user_id:
293
+ end_user_data: Optional[EndUser] = (
294
+ db.session.query(EndUser).filter(EndUser.id == message_data.from_end_user_id).first()
295
+ )
296
+ if end_user_data is not None:
297
+ end_user_id = end_user_data.session_id
298
+ metadata["end_user_id"] = end_user_id
299
+
300
+ trace_data = {
301
+ "id": prepare_opik_uuid(trace_info.start_time, message_id),
302
+ "name": TraceTaskName.MESSAGE_TRACE.value,
303
+ "start_time": trace_info.start_time,
304
+ "end_time": trace_info.end_time,
305
+ "metadata": wrap_metadata(metadata),
306
+ "input": trace_info.inputs,
307
+ "output": message_data.answer,
308
+ "tags": ["message", str(trace_info.conversation_mode)],
309
+ "project_name": self.project,
310
+ }
311
+ trace = self.add_trace(trace_data)
312
+
313
+ span_data = {
314
+ "trace_id": trace.id,
315
+ "name": "llm",
316
+ "type": "llm",
317
+ "start_time": trace_info.start_time,
318
+ "end_time": trace_info.end_time,
319
+ "metadata": wrap_metadata(metadata),
320
+ "input": {"input": trace_info.inputs},
321
+ "output": {"output": message_data.answer},
322
+ "tags": ["llm", str(trace_info.conversation_mode)],
323
+ "usage": {
324
+ "completion_tokens": trace_info.answer_tokens,
325
+ "prompt_tokens": trace_info.message_tokens,
326
+ "total_tokens": trace_info.total_tokens,
327
+ },
328
+ "project_name": self.project,
329
+ }
330
+ self.add_span(span_data)
331
+
332
+ def moderation_trace(self, trace_info: ModerationTraceInfo):
333
+ if trace_info.message_data is None:
334
+ return
335
+
336
+ start_time = trace_info.start_time or trace_info.message_data.created_at
337
+
338
+ span_data = {
339
+ "trace_id": prepare_opik_uuid(start_time, trace_info.message_id),
340
+ "name": TraceTaskName.MODERATION_TRACE.value,
341
+ "type": "tool",
342
+ "start_time": start_time,
343
+ "end_time": trace_info.end_time or trace_info.message_data.updated_at,
344
+ "metadata": wrap_metadata(trace_info.metadata),
345
+ "input": wrap_dict("input", trace_info.inputs),
346
+ "output": {
347
+ "action": trace_info.action,
348
+ "flagged": trace_info.flagged,
349
+ "preset_response": trace_info.preset_response,
350
+ "inputs": trace_info.inputs,
351
+ },
352
+ "tags": ["moderation"],
353
+ }
354
+
355
+ self.add_span(span_data)
356
+
357
+ def suggested_question_trace(self, trace_info: SuggestedQuestionTraceInfo):
358
+ message_data = trace_info.message_data
359
+ if message_data is None:
360
+ return
361
+
362
+ start_time = trace_info.start_time or message_data.created_at
363
+
364
+ span_data = {
365
+ "trace_id": prepare_opik_uuid(start_time, trace_info.message_id),
366
+ "name": TraceTaskName.SUGGESTED_QUESTION_TRACE.value,
367
+ "type": "tool",
368
+ "start_time": start_time,
369
+ "end_time": trace_info.end_time or message_data.updated_at,
370
+ "metadata": wrap_metadata(trace_info.metadata),
371
+ "input": wrap_dict("input", trace_info.inputs),
372
+ "output": wrap_dict("output", trace_info.suggested_question),
373
+ "tags": ["suggested_question"],
374
+ }
375
+
376
+ self.add_span(span_data)
377
+
378
+ def dataset_retrieval_trace(self, trace_info: DatasetRetrievalTraceInfo):
379
+ if trace_info.message_data is None:
380
+ return
381
+
382
+ start_time = trace_info.start_time or trace_info.message_data.created_at
383
+
384
+ span_data = {
385
+ "trace_id": prepare_opik_uuid(start_time, trace_info.message_id),
386
+ "name": TraceTaskName.DATASET_RETRIEVAL_TRACE.value,
387
+ "type": "tool",
388
+ "start_time": start_time,
389
+ "end_time": trace_info.end_time or trace_info.message_data.updated_at,
390
+ "metadata": wrap_metadata(trace_info.metadata),
391
+ "input": wrap_dict("input", trace_info.inputs),
392
+ "output": {"documents": trace_info.documents},
393
+ "tags": ["dataset_retrieval"],
394
+ }
395
+
396
+ self.add_span(span_data)
397
+
398
+ def tool_trace(self, trace_info: ToolTraceInfo):
399
+ span_data = {
400
+ "trace_id": prepare_opik_uuid(trace_info.start_time, trace_info.message_id),
401
+ "name": trace_info.tool_name,
402
+ "type": "tool",
403
+ "start_time": trace_info.start_time,
404
+ "end_time": trace_info.end_time,
405
+ "metadata": wrap_metadata(trace_info.metadata),
406
+ "input": wrap_dict("input", trace_info.tool_inputs),
407
+ "output": wrap_dict("output", trace_info.tool_outputs),
408
+ "tags": ["tool", trace_info.tool_name],
409
+ }
410
+
411
+ self.add_span(span_data)
412
+
413
+ def generate_name_trace(self, trace_info: GenerateNameTraceInfo):
414
+ trace_data = {
415
+ "id": prepare_opik_uuid(trace_info.start_time, trace_info.message_id),
416
+ "name": TraceTaskName.GENERATE_NAME_TRACE.value,
417
+ "start_time": trace_info.start_time,
418
+ "end_time": trace_info.end_time,
419
+ "metadata": wrap_metadata(trace_info.metadata),
420
+ "input": trace_info.inputs,
421
+ "output": trace_info.outputs,
422
+ "tags": ["generate_name"],
423
+ "project_name": self.project,
424
+ }
425
+
426
+ trace = self.add_trace(trace_data)
427
+
428
+ span_data = {
429
+ "trace_id": trace.id,
430
+ "name": TraceTaskName.GENERATE_NAME_TRACE.value,
431
+ "start_time": trace_info.start_time,
432
+ "end_time": trace_info.end_time,
433
+ "metadata": wrap_metadata(trace_info.metadata),
434
+ "input": wrap_dict("input", trace_info.inputs),
435
+ "output": wrap_dict("output", trace_info.outputs),
436
+ "tags": ["generate_name"],
437
+ }
438
+
439
+ self.add_span(span_data)
440
+
441
+ def add_trace(self, opik_trace_data: dict) -> Trace:
442
+ try:
443
+ trace = self.opik_client.trace(**opik_trace_data)
444
+ logger.debug("Opik Trace created successfully")
445
+ return trace
446
+ except Exception as e:
447
+ raise ValueError(f"Opik Failed to create trace: {str(e)}")
448
+
449
+ def add_span(self, opik_span_data: dict):
450
+ try:
451
+ self.opik_client.span(**opik_span_data)
452
+ logger.debug("Opik Span created successfully")
453
+ except Exception as e:
454
+ raise ValueError(f"Opik Failed to create span: {str(e)}")
455
+
456
+ def api_check(self):
457
+ try:
458
+ self.opik_client.auth_check()
459
+ return True
460
+ except Exception as e:
461
+ logger.info(f"Opik API check failed: {str(e)}", exc_info=True)
462
+ raise ValueError(f"Opik API check failed: {str(e)}")
463
+
464
+ def get_project_url(self):
465
+ try:
466
+ return self.opik_client.get_project_url(project_name=self.project)
467
+ except Exception as e:
468
+ logger.info(f"Opik get run url failed: {str(e)}", exc_info=True)
469
+ raise ValueError(f"Opik get run url failed: {str(e)}")
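The span/trace linking above only works if `prepare_opik_uuid` is deterministic: `workflow_trace` derives the parent span id from the same `(start_time, workflow_run_id)` pair in two separate places and relies on both calls agreeing, which assumes `opik.id_helpers.uuid4_to_uuid7` maps a given timestamp/id pair to the same UUIDv7 every time. A hedged usage sketch with made-up identifiers:

```python
from datetime import datetime

started = datetime(2024, 1, 1, 12, 0, 0)
dify_run_id = "9f8b7c6d-5e4f-4a3b-9c1d-0e9f8a7b6c5d"  # hypothetical UUIDv4

# Same inputs, same UUIDv7 -- so a span created later can still point at
# the trace id that was derived when the trace was first recorded.
first = prepare_opik_uuid(started, dify_run_id)
second = prepare_opik_uuid(started, dify_run_id)
assert first == second
```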
api/core/ops/ops_trace_manager.py ADDED
@@ -0,0 +1,811 @@
+import json
+import logging
+import os
+import queue
+import threading
+import time
+from datetime import timedelta
+from typing import Any, Optional, Union
+from uuid import UUID, uuid4
+
+from flask import current_app
+from sqlalchemy import select
+from sqlalchemy.orm import Session
+
+from core.helper.encrypter import decrypt_token, encrypt_token, obfuscated_token
+from core.ops.entities.config_entity import (
+    OPS_FILE_PATH,
+    LangfuseConfig,
+    LangSmithConfig,
+    OpikConfig,
+    TracingProviderEnum,
+)
+from core.ops.entities.trace_entity import (
+    DatasetRetrievalTraceInfo,
+    GenerateNameTraceInfo,
+    MessageTraceInfo,
+    ModerationTraceInfo,
+    SuggestedQuestionTraceInfo,
+    TaskData,
+    ToolTraceInfo,
+    TraceTaskName,
+    WorkflowTraceInfo,
+)
+from core.ops.langfuse_trace.langfuse_trace import LangFuseDataTrace
+from core.ops.langsmith_trace.langsmith_trace import LangSmithDataTrace
+from core.ops.opik_trace.opik_trace import OpikDataTrace
+from core.ops.utils import get_message_data
+from extensions.ext_database import db
+from extensions.ext_storage import storage
+from models.model import App, AppModelConfig, Conversation, Message, MessageFile, TraceAppConfig
+from models.workflow import WorkflowAppLog, WorkflowRun
+from tasks.ops_trace_task import process_trace_tasks
+
+provider_config_map: dict[str, dict[str, Any]] = {
+    TracingProviderEnum.LANGFUSE.value: {
+        "config_class": LangfuseConfig,
+        "secret_keys": ["public_key", "secret_key"],
+        "other_keys": ["host", "project_key"],
+        "trace_instance": LangFuseDataTrace,
+    },
+    TracingProviderEnum.LANGSMITH.value: {
+        "config_class": LangSmithConfig,
+        "secret_keys": ["api_key"],
+        "other_keys": ["project", "endpoint"],
+        "trace_instance": LangSmithDataTrace,
+    },
+    TracingProviderEnum.OPIK.value: {
+        "config_class": OpikConfig,
+        "secret_keys": ["api_key"],
+        "other_keys": ["project", "url", "workspace"],
+        "trace_instance": OpikDataTrace,
+    },
+}
+
+
+class OpsTraceManager:
+    @classmethod
+    def encrypt_tracing_config(
+        cls, tenant_id: str, tracing_provider: str, tracing_config: dict, current_trace_config=None
+    ):
+        """
+        Encrypt tracing config.
+        :param tenant_id: tenant id
+        :param tracing_provider: tracing provider
+        :param tracing_config: tracing config dictionary to be encrypted
+        :param current_trace_config: current tracing configuration for keeping existing values
+        :return: encrypted tracing configuration
+        """
+        # Get the configuration class and the keys that require encryption
+        config_class, secret_keys, other_keys = (
+            provider_config_map[tracing_provider]["config_class"],
+            provider_config_map[tracing_provider]["secret_keys"],
+            provider_config_map[tracing_provider]["other_keys"],
+        )
+
+        new_config = {}
+        # Encrypt necessary keys
+        for key in secret_keys:
+            if key in tracing_config:
+                if "*" in tracing_config[key]:
+                    # If the key contains '*', retain the original value from the current config
+                    new_config[key] = current_trace_config.get(key, tracing_config[key])
+                else:
+                    # Otherwise, encrypt the key
+                    new_config[key] = encrypt_token(tenant_id, tracing_config[key])
+
+        for key in other_keys:
+            new_config[key] = tracing_config.get(key, "")
+
+        # Create a new instance of the config class with the new configuration
+        encrypted_config = config_class(**new_config)
+        return encrypted_config.model_dump()
+
+    @classmethod
+    def decrypt_tracing_config(cls, tenant_id: str, tracing_provider: str, tracing_config: dict):
+        """
+        Decrypt tracing config
+        :param tenant_id: tenant id
+        :param tracing_provider: tracing provider
+        :param tracing_config: tracing config
+        :return:
+        """
+        config_class, secret_keys, other_keys = (
+            provider_config_map[tracing_provider]["config_class"],
+            provider_config_map[tracing_provider]["secret_keys"],
+            provider_config_map[tracing_provider]["other_keys"],
+        )
+        new_config = {}
+        for key in secret_keys:
+            if key in tracing_config:
+                new_config[key] = decrypt_token(tenant_id, tracing_config[key])
+
+        for key in other_keys:
+            new_config[key] = tracing_config.get(key, "")
+
+        return config_class(**new_config).model_dump()
+
+    @classmethod
+    def obfuscated_decrypt_token(cls, tracing_provider: str, decrypt_tracing_config: dict):
+        """
+        Obfuscate the secret keys of a decrypted tracing config
+        :param tracing_provider: tracing provider
+        :param decrypt_tracing_config: decrypted tracing config
+        :return:
+        """
+        config_class, secret_keys, other_keys = (
+            provider_config_map[tracing_provider]["config_class"],
+            provider_config_map[tracing_provider]["secret_keys"],
+            provider_config_map[tracing_provider]["other_keys"],
+        )
+        new_config = {}
+        for key in secret_keys:
+            if key in decrypt_tracing_config:
+                new_config[key] = obfuscated_token(decrypt_tracing_config[key])
+
+        for key in other_keys:
+            new_config[key] = decrypt_tracing_config.get(key, "")
+        return config_class(**new_config).model_dump()
+
+    @classmethod
+    def get_decrypted_tracing_config(cls, app_id: str, tracing_provider: str):
+        """
+        Get decrypted tracing config
+        :param app_id: app id
+        :param tracing_provider: tracing provider
+        :return:
+        """
+        trace_config_data: Optional[TraceAppConfig] = (
+            db.session.query(TraceAppConfig)
+            .filter(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider)
+            .first()
+        )
+
+        if not trace_config_data:
+            return None
+
+        # decrypt_token
+        app = db.session.query(App).filter(App.id == app_id).first()
+        if not app:
+            raise ValueError("App not found")
+
+        tenant_id = app.tenant_id
+        decrypt_tracing_config = cls.decrypt_tracing_config(
+            tenant_id, tracing_provider, trace_config_data.tracing_config
+        )
+
+        return decrypt_tracing_config
+
+    @classmethod
+    def get_ops_trace_instance(
+        cls,
+        app_id: Optional[Union[UUID, str]] = None,
+    ):
+        """
+        Get ops trace through model config
+        :param app_id: app_id
+        :return:
+        """
+        if isinstance(app_id, UUID):
+            app_id = str(app_id)
+
+        if app_id is None:
+            return None
+
+        app: Optional[App] = db.session.query(App).filter(App.id == app_id).first()
+
+        if app is None:
+            return None
+
+        app_ops_trace_config = json.loads(app.tracing) if app.tracing else None
+
+        if app_ops_trace_config is None:
+            return None
+
+        tracing_provider = app_ops_trace_config.get("tracing_provider")
+
+        if tracing_provider is None or tracing_provider not in provider_config_map:
+            return None
+
+        # decrypt_token
+        decrypt_trace_config = cls.get_decrypted_tracing_config(app_id, tracing_provider)
+        if app_ops_trace_config.get("enabled"):
+            trace_instance, config_class = (
+                provider_config_map[tracing_provider]["trace_instance"],
+                provider_config_map[tracing_provider]["config_class"],
+            )
+            tracing_instance = trace_instance(config_class(**decrypt_trace_config))
+            return tracing_instance
+
+        return None
+
+    @classmethod
+    def get_app_config_through_message_id(cls, message_id: str):
+        app_model_config = None
+        message_data = db.session.query(Message).filter(Message.id == message_id).first()
+        if not message_data:
+            return None
+        conversation_id = message_data.conversation_id
+        conversation_data = db.session.query(Conversation).filter(Conversation.id == conversation_id).first()
+        if not conversation_data:
+            return None
+
+        if conversation_data.app_model_config_id:
+            app_model_config = (
+                db.session.query(AppModelConfig)
+                .filter(AppModelConfig.id == conversation_data.app_model_config_id)
+                .first()
+            )
+        elif conversation_data.app_model_config_id is None and conversation_data.override_model_configs:
+            app_model_config = conversation_data.override_model_configs
+
+        return app_model_config
+
+    @classmethod
+    def update_app_tracing_config(cls, app_id: str, enabled: bool, tracing_provider: str):
+        """
+        Update app tracing config
+        :param app_id: app id
+        :param enabled: enabled
+        :param tracing_provider: tracing provider
+        :return:
+        """
+        # auth check
+        if tracing_provider not in provider_config_map and tracing_provider is not None:
+            raise ValueError(f"Invalid tracing provider: {tracing_provider}")
+
+        app_config: Optional[App] = db.session.query(App).filter(App.id == app_id).first()
+        if not app_config:
+            raise ValueError("App not found")
+        app_config.tracing = json.dumps(
+            {
+                "enabled": enabled,
+                "tracing_provider": tracing_provider,
+            }
+        )
+        db.session.commit()
+
+    @classmethod
+    def get_app_tracing_config(cls, app_id: str):
+        """
+        Get app tracing config
+        :param app_id: app id
+        :return:
+        """
+        app: Optional[App] = db.session.query(App).filter(App.id == app_id).first()
+        if not app:
+            raise ValueError("App not found")
+        if not app.tracing:
+            return {"enabled": False, "tracing_provider": None}
+        app_trace_config = json.loads(app.tracing)
+        return app_trace_config
+
+    @staticmethod
+    def check_trace_config_is_effective(tracing_config: dict, tracing_provider: str):
+        """
+        Check whether the trace config is effective
+        :param tracing_config: tracing config
+        :param tracing_provider: tracing provider
+        :return:
+        """
+        config_type, trace_instance = (
+            provider_config_map[tracing_provider]["config_class"],
+            provider_config_map[tracing_provider]["trace_instance"],
+        )
+        tracing_config = config_type(**tracing_config)
+        return trace_instance(tracing_config).api_check()
+
+    @staticmethod
+    def get_trace_config_project_key(tracing_config: dict, tracing_provider: str):
+        """
+        Get the project key for the given trace config
+        :param tracing_config: tracing config
+        :param tracing_provider: tracing provider
+        :return:
+        """
+        config_type, trace_instance = (
+            provider_config_map[tracing_provider]["config_class"],
+            provider_config_map[tracing_provider]["trace_instance"],
+        )
+        tracing_config = config_type(**tracing_config)
+        return trace_instance(tracing_config).get_project_key()
+
+    @staticmethod
+    def get_trace_config_project_url(tracing_config: dict, tracing_provider: str):
+        """
+        Get the project url for the given trace config
+        :param tracing_config: tracing config
+        :param tracing_provider: tracing provider
+        :return:
+        """
+        config_type, trace_instance = (
+            provider_config_map[tracing_provider]["config_class"],
+            provider_config_map[tracing_provider]["trace_instance"],
+        )
+        tracing_config = config_type(**tracing_config)
+        return trace_instance(tracing_config).get_project_url()
+
+
+class TraceTask:
+    def __init__(
+        self,
+        trace_type: Any,
+        message_id: Optional[str] = None,
+        workflow_run: Optional[WorkflowRun] = None,
+        conversation_id: Optional[str] = None,
+        user_id: Optional[str] = None,
+        timer: Optional[Any] = None,
+        **kwargs,
+    ):
+        self.trace_type = trace_type
+        self.message_id = message_id
+        self.workflow_run_id = workflow_run.id if workflow_run else None
+        self.conversation_id = conversation_id
+        self.user_id = user_id
+        self.timer = timer
+        self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001")
+        self.app_id = None
+
+        self.kwargs = kwargs
+
+    def execute(self):
+        return self.preprocess()
+
+    def preprocess(self):
+        preprocess_map = {
+            TraceTaskName.CONVERSATION_TRACE: lambda: self.conversation_trace(**self.kwargs),
+            TraceTaskName.WORKFLOW_TRACE: lambda: self.workflow_trace(
+                workflow_run_id=self.workflow_run_id, conversation_id=self.conversation_id, user_id=self.user_id
+            ),
+            TraceTaskName.MESSAGE_TRACE: lambda: self.message_trace(message_id=self.message_id),
+            TraceTaskName.MODERATION_TRACE: lambda: self.moderation_trace(
+                message_id=self.message_id, timer=self.timer, **self.kwargs
+            ),
+            TraceTaskName.SUGGESTED_QUESTION_TRACE: lambda: self.suggested_question_trace(
+                message_id=self.message_id, timer=self.timer, **self.kwargs
+            ),
+            TraceTaskName.DATASET_RETRIEVAL_TRACE: lambda: self.dataset_retrieval_trace(
+                message_id=self.message_id, timer=self.timer, **self.kwargs
+            ),
+            TraceTaskName.TOOL_TRACE: lambda: self.tool_trace(
+                message_id=self.message_id, timer=self.timer, **self.kwargs
+            ),
+            TraceTaskName.GENERATE_NAME_TRACE: lambda: self.generate_name_trace(
+                conversation_id=self.conversation_id, timer=self.timer, **self.kwargs
+            ),
+        }
+
+        return preprocess_map.get(self.trace_type, lambda: None)()
+
+    # process methods for different trace types
+    def conversation_trace(self, **kwargs):
+        return kwargs
+
+    def workflow_trace(
+        self,
+        *,
+        workflow_run_id: str | None,
+        conversation_id: str | None,
+        user_id: str | None,
+    ):
+        if not workflow_run_id:
+            return {}
+
+        with Session(db.engine) as session:
+            workflow_run_stmt = select(WorkflowRun).where(WorkflowRun.id == workflow_run_id)
+            workflow_run = session.scalars(workflow_run_stmt).first()
+            if not workflow_run:
+                raise ValueError("Workflow run not found")
+
+            workflow_id = workflow_run.workflow_id
+            tenant_id = workflow_run.tenant_id
+            workflow_run_id = workflow_run.id
+            workflow_run_elapsed_time = workflow_run.elapsed_time
+            workflow_run_status = workflow_run.status
+            workflow_run_inputs = workflow_run.inputs_dict
+            workflow_run_outputs = workflow_run.outputs_dict
+            workflow_run_version = workflow_run.version
+            error = workflow_run.error or ""
+
+            total_tokens = workflow_run.total_tokens
+
+            file_list = workflow_run_inputs.get("sys.file") or []
+            query = workflow_run_inputs.get("query") or workflow_run_inputs.get("sys.query") or ""
+
+            # get workflow_app_log_id
+            workflow_app_log_data_stmt = select(WorkflowAppLog.id).where(
+                WorkflowAppLog.tenant_id == tenant_id,
+                WorkflowAppLog.app_id == workflow_run.app_id,
+                WorkflowAppLog.workflow_run_id == workflow_run.id,
+            )
+            workflow_app_log_id = session.scalar(workflow_app_log_data_stmt)
+            # get message_id
+            message_id = None
+            if conversation_id:
+                message_data_stmt = select(Message.id).where(
+                    Message.conversation_id == conversation_id,
+                    Message.workflow_run_id == workflow_run_id,
+                )
+                message_id = session.scalar(message_data_stmt)
+
+            metadata = {
+                "workflow_id": workflow_id,
+                "conversation_id": conversation_id,
+                "workflow_run_id": workflow_run_id,
+                "tenant_id": tenant_id,
+                "elapsed_time": workflow_run_elapsed_time,
+                "status": workflow_run_status,
+                "version": workflow_run_version,
+                "total_tokens": total_tokens,
+                "file_list": file_list,
+                "triggered_from": workflow_run.triggered_from,
+                "user_id": user_id,
+            }
+
+            workflow_trace_info = WorkflowTraceInfo(
+                workflow_data=workflow_run.to_dict(),
+                conversation_id=conversation_id,
+                workflow_id=workflow_id,
+                tenant_id=tenant_id,
+                workflow_run_id=workflow_run_id,
+                workflow_run_elapsed_time=workflow_run_elapsed_time,
+                workflow_run_status=workflow_run_status,
+                workflow_run_inputs=workflow_run_inputs,
+                workflow_run_outputs=workflow_run_outputs,
+                workflow_run_version=workflow_run_version,
+                error=error,
+                total_tokens=total_tokens,
+                file_list=file_list,
+                query=query,
+                metadata=metadata,
+                workflow_app_log_id=workflow_app_log_id,
+                message_id=message_id,
+                start_time=workflow_run.created_at,
+                end_time=workflow_run.finished_at,
+            )
+            return workflow_trace_info
+
+    def message_trace(self, message_id: str | None):
+        if not message_id:
+            return {}
+        message_data = get_message_data(message_id)
+        if not message_data:
+            return {}
+        conversation_mode_stmt = select(Conversation.mode).where(Conversation.id == message_data.conversation_id)
+        conversation_mode = db.session.scalars(conversation_mode_stmt).all()
+        if not conversation_mode or len(conversation_mode) == 0:
+            return {}
+        conversation_mode = conversation_mode[0]
+        created_at = message_data.created_at
+        inputs = message_data.message
+
+        # get message file data
+        message_file_data = db.session.query(MessageFile).filter_by(message_id=message_id).first()
+        file_list = []
+        if message_file_data and message_file_data.url is not None:
+            file_url = f"{self.file_base_url}/{message_file_data.url}" if message_file_data else ""
+            file_list.append(file_url)
+
+        metadata = {
+            "conversation_id": message_data.conversation_id,
+            "ls_provider": message_data.model_provider,
+            "ls_model_name": message_data.model_id,
+            "status": message_data.status,
+            "from_end_user_id": message_data.from_end_user_id,
+            "from_account_id": message_data.from_account_id,
+            "agent_based": message_data.agent_based,
+            "workflow_run_id": message_data.workflow_run_id,
+            "from_source": message_data.from_source,
+            "message_id": message_id,
+        }
+
+        message_tokens = message_data.message_tokens
+
+        message_trace_info = MessageTraceInfo(
+            message_id=message_id,
+            message_data=message_data.to_dict(),
+            conversation_model=conversation_mode,
+            message_tokens=message_tokens,
+            answer_tokens=message_data.answer_tokens,
+            total_tokens=message_tokens + message_data.answer_tokens,
+            error=message_data.error or "",
+            inputs=inputs,
+            outputs=message_data.answer,
+            file_list=file_list,
+            start_time=created_at,
+            end_time=created_at + timedelta(seconds=message_data.provider_response_latency),
+            metadata=metadata,
+            message_file_data=message_file_data,
+            conversation_mode=conversation_mode,
+        )
+
+        return message_trace_info
+
+    def moderation_trace(self, message_id, timer, **kwargs):
+        moderation_result = kwargs.get("moderation_result")
+        if not moderation_result:
+            return {}
+        inputs = kwargs.get("inputs")
+        message_data = get_message_data(message_id)
+        if not message_data:
+            return {}
+        metadata = {
+            "message_id": message_id,
+            "action": moderation_result.action,
+            "preset_response": moderation_result.preset_response,
+            "query": moderation_result.query,
+        }
+
+        # get workflow_app_log_id
+        workflow_app_log_id = None
+        if message_data.workflow_run_id:
+            workflow_app_log_data = (
+                db.session.query(WorkflowAppLog).filter_by(workflow_run_id=message_data.workflow_run_id).first()
+            )
+            workflow_app_log_id = str(workflow_app_log_data.id) if workflow_app_log_data else None
+
+        moderation_trace_info = ModerationTraceInfo(
+            message_id=workflow_app_log_id or message_id,
+            inputs=inputs,
+            message_data=message_data.to_dict(),
+            flagged=moderation_result.flagged,
+            action=moderation_result.action,
+            preset_response=moderation_result.preset_response,
+            query=moderation_result.query,
+            start_time=timer.get("start"),
+            end_time=timer.get("end"),
+            metadata=metadata,
+        )
+
+        return moderation_trace_info
+
+    def suggested_question_trace(self, message_id, timer, **kwargs):
+        suggested_question = kwargs.get("suggested_question", [])
+        message_data = get_message_data(message_id)
+        if not message_data:
+            return {}
+        metadata = {
+            "message_id": message_id,
+            "ls_provider": message_data.model_provider,
+            "ls_model_name": message_data.model_id,
+            "status": message_data.status,
+            "from_end_user_id": message_data.from_end_user_id,
+            "from_account_id": message_data.from_account_id,
+            "agent_based": message_data.agent_based,
+            "workflow_run_id": message_data.workflow_run_id,
+            "from_source": message_data.from_source,
+        }
+
+        # get workflow_app_log_id
+        workflow_app_log_id = None
+        if message_data.workflow_run_id:
+            workflow_app_log_data = (
+                db.session.query(WorkflowAppLog).filter_by(workflow_run_id=message_data.workflow_run_id).first()
+            )
+            workflow_app_log_id = str(workflow_app_log_data.id) if workflow_app_log_data else None
+
+        suggested_question_trace_info = SuggestedQuestionTraceInfo(
+            message_id=workflow_app_log_id or message_id,
+            message_data=message_data.to_dict(),
+            inputs=message_data.message,
+            outputs=message_data.answer,
+            start_time=timer.get("start"),
+            end_time=timer.get("end"),
+            metadata=metadata,
+            total_tokens=message_data.message_tokens + message_data.answer_tokens,
+            status=message_data.status,
+            error=message_data.error,
+            from_account_id=message_data.from_account_id,
+            agent_based=message_data.agent_based,
+            from_source=message_data.from_source,
+            model_provider=message_data.model_provider,
+            model_id=message_data.model_id,
+            suggested_question=suggested_question,
+            level=message_data.status,
+            status_message=message_data.error,
+        )
+
+        return suggested_question_trace_info
+
+    def dataset_retrieval_trace(self, message_id, timer, **kwargs):
+        documents = kwargs.get("documents")
+        message_data = get_message_data(message_id)
+        if not message_data:
+            return {}
+
+        metadata = {
+            "message_id": message_id,
+            "ls_provider": message_data.model_provider,
+            "ls_model_name": message_data.model_id,
+            "status": message_data.status,
+            "from_end_user_id": message_data.from_end_user_id,
+            "from_account_id": message_data.from_account_id,
+            "agent_based": message_data.agent_based,
+            "workflow_run_id": message_data.workflow_run_id,
+            "from_source": message_data.from_source,
+        }
+
+        dataset_retrieval_trace_info = DatasetRetrievalTraceInfo(
+            message_id=message_id,
+            inputs=message_data.query or message_data.inputs,
+            documents=[doc.model_dump() for doc in documents] if documents else [],
+            start_time=timer.get("start"),
+            end_time=timer.get("end"),
+            metadata=metadata,
+            message_data=message_data.to_dict(),
+        )
+
+        return dataset_retrieval_trace_info
+
+    def tool_trace(self, message_id, timer, **kwargs):
+        tool_name = kwargs.get("tool_name", "")
+        tool_inputs = kwargs.get("tool_inputs", {})
+        tool_outputs = kwargs.get("tool_outputs", {})
+        message_data = get_message_data(message_id)
+        if not message_data:
+            return {}
+        tool_config = {}
+        time_cost = 0
+        error = None
+        tool_parameters = {}
+        created_time = message_data.created_at
+        end_time = message_data.updated_at
+        agent_thoughts = message_data.agent_thoughts
+        for agent_thought in agent_thoughts:
+            if tool_name in agent_thought.tools:
+                created_time = agent_thought.created_at
+                tool_meta_data = agent_thought.tool_meta.get(tool_name, {})
+                tool_config = tool_meta_data.get("tool_config", {})
+                time_cost = tool_meta_data.get("time_cost", 0)
+                end_time = created_time + timedelta(seconds=time_cost)
+                error = tool_meta_data.get("error", "")
+                tool_parameters = tool_meta_data.get("tool_parameters", {})
+        metadata = {
+            "message_id": message_id,
+            "tool_name": tool_name,
+            "tool_inputs": tool_inputs,
+            "tool_outputs": tool_outputs,
+            "tool_config": tool_config,
+            "time_cost": time_cost,
+            "error": error,
+            "tool_parameters": tool_parameters,
+        }
+
+        file_url = ""
+        message_file_data = db.session.query(MessageFile).filter_by(message_id=message_id).first()
+        if message_file_data:
+            message_file_id = message_file_data.id if message_file_data else None
+            type = message_file_data.type
+            created_by_role = message_file_data.created_by_role
+            created_user_id = message_file_data.created_by
+            file_url = f"{self.file_base_url}/{message_file_data.url}"
+
+            metadata.update(
+                {
+                    "message_file_id": message_file_id,
+                    "created_by_role": created_by_role,
+                    "created_user_id": created_user_id,
+                    "type": type,
+                }
+            )
+
+        tool_trace_info = ToolTraceInfo(
+            message_id=message_id,
+            message_data=message_data.to_dict(),
+            tool_name=tool_name,
+            start_time=timer.get("start") if timer else created_time,
+            end_time=timer.get("end") if timer else end_time,
+            tool_inputs=tool_inputs,
+            tool_outputs=tool_outputs,
+            metadata=metadata,
+            message_file_data=message_file_data,
+            error=error,
+            inputs=message_data.message,
+            outputs=message_data.answer,
+            tool_config=tool_config,
+            time_cost=time_cost,
+            tool_parameters=tool_parameters,
+            file_url=file_url,
+        )
+
+        return tool_trace_info
+
+    def generate_name_trace(self, conversation_id, timer, **kwargs):
+        generate_conversation_name = kwargs.get("generate_conversation_name")
+        inputs = kwargs.get("inputs")
+        tenant_id = kwargs.get("tenant_id")
+        if not tenant_id:
+            return {}
+        start_time = timer.get("start")
+        end_time = timer.get("end")
+
+        metadata = {
+            "conversation_id": conversation_id,
+            "tenant_id": tenant_id,
+        }
+
+        generate_name_trace_info = GenerateNameTraceInfo(
+            conversation_id=conversation_id,
+            inputs=inputs,
+            outputs=generate_conversation_name,
+            start_time=start_time,
+            end_time=end_time,
+            metadata=metadata,
+            tenant_id=tenant_id,
+        )
+
+        return generate_name_trace_info
+
+
+trace_manager_timer: Optional[threading.Timer] = None
+trace_manager_queue: queue.Queue = queue.Queue()
+trace_manager_interval = int(os.getenv("TRACE_QUEUE_MANAGER_INTERVAL", 5))
+trace_manager_batch_size = int(os.getenv("TRACE_QUEUE_MANAGER_BATCH_SIZE", 100))
+
+
+class TraceQueueManager:
+    def __init__(self, app_id=None, user_id=None):
+        global trace_manager_timer
+
+        self.app_id = app_id
+        self.user_id = user_id
+        self.trace_instance = OpsTraceManager.get_ops_trace_instance(app_id)
+        self.flask_app = current_app._get_current_object()  # type: ignore
+        if trace_manager_timer is None:
+            self.start_timer()
+
+    def add_trace_task(self, trace_task: TraceTask):
+        global trace_manager_timer, trace_manager_queue
+        try:
+            if self.trace_instance:
+                trace_task.app_id = self.app_id
+                trace_manager_queue.put(trace_task)
+        except Exception as e:
+            logging.exception(f"Error adding trace task, trace_type {trace_task.trace_type}")
+        finally:
+            self.start_timer()
+
+    def collect_tasks(self):
+        global trace_manager_queue
+        tasks: list[TraceTask] = []
+        while len(tasks) < trace_manager_batch_size and not trace_manager_queue.empty():
+            task = trace_manager_queue.get_nowait()
+            tasks.append(task)
+            trace_manager_queue.task_done()
+        return tasks
+
+    def run(self):
+        try:
+            tasks = self.collect_tasks()
+            if tasks:
+                self.send_to_celery(tasks)
+        except Exception as e:
+            logging.exception("Error processing trace tasks")
+
+    def start_timer(self):
+        global trace_manager_timer
+        if trace_manager_timer is None or not trace_manager_timer.is_alive():
+            trace_manager_timer = threading.Timer(trace_manager_interval, self.run)
+            trace_manager_timer.name = f"trace_manager_timer_{time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())}"
+            trace_manager_timer.daemon = False
+            trace_manager_timer.start()
+
+    def send_to_celery(self, tasks: list[TraceTask]):
+        with self.flask_app.app_context():
+            for task in tasks:
+                if task.app_id is None:
+                    continue
+                file_id = uuid4().hex
+                trace_info = task.execute()
+                task_data = TaskData(
+                    app_id=task.app_id,
+                    trace_info_type=type(trace_info).__name__,
+                    trace_info=trace_info.model_dump() if trace_info else None,
+                )
+                file_path = f"{OPS_FILE_PATH}{task.app_id}/{file_id}.json"
+                storage.save(file_path, task_data.model_dump_json().encode("utf-8"))
+                file_info = {
+                    "file_id": file_id,
+                    "app_id": task.app_id,
+                }
+                process_trace_tasks.delay(file_info)
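
Editor's note: a minimal sketch of how a caller enqueues a trace through the classes above, assuming a Flask app context; `do_moderation`, `app_id`, `message_id`, and `inputs` are hypothetical placeholders for the real call site. The task is serialized to storage and handed to Celery on the next timer tick.

    from core.ops.entities.trace_entity import TraceTaskName
    from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
    from core.ops.utils import measure_time

    with measure_time() as timer:
        result = do_moderation(inputs)  # placeholder for the real moderation call

    trace_manager = TraceQueueManager(app_id=app_id)
    trace_manager.add_trace_task(
        TraceTask(
            TraceTaskName.MODERATION_TRACE,
            message_id=message_id,
            timer=timer,
            moderation_result=result,
            inputs=inputs,
        )
    )
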
api/core/ops/utils.py ADDED
@@ -0,0 +1,62 @@
+from contextlib import contextmanager
+from datetime import datetime
+from typing import Optional, Union
+
+from extensions.ext_database import db
+from models.model import Message
+
+
+def filter_none_values(data: dict):
+    new_data = {}
+    for key, value in data.items():
+        if value is None:
+            continue
+        if isinstance(value, datetime):
+            new_data[key] = value.isoformat()
+        else:
+            new_data[key] = value
+    return new_data
+
+
+def get_message_data(message_id: str):
+    return db.session.query(Message).filter(Message.id == message_id).first()
+
+
+@contextmanager
+def measure_time():
+    timing_info = {"start": datetime.now(), "end": None}
+    try:
+        yield timing_info
+    finally:
+        timing_info["end"] = datetime.now()
+
+
+def replace_text_with_content(data):
+    if isinstance(data, dict):
+        new_data = {}
+        for key, value in data.items():
+            if key == "text":
+                new_data["content"] = value
+            else:
+                new_data[key] = replace_text_with_content(value)
+        return new_data
+    elif isinstance(data, list):
+        return [replace_text_with_content(item) for item in data]
+    else:
+        return data
+
+
+def generate_dotted_order(
+    run_id: str, start_time: Union[str, datetime], parent_dotted_order: Optional[str] = None
+) -> str:
+    """
+    Generate the dotted_order value for LangSmith.
+    """
+    start_time = datetime.fromisoformat(start_time) if isinstance(start_time, str) else start_time
+    timestamp = start_time.strftime("%Y%m%dT%H%M%S%f")[:-3] + "Z"
+    current_segment = f"{timestamp}{run_id}"
+
+    if parent_dotted_order is None:
+        return current_segment
+
+    return f"{parent_dotted_order}.{current_segment}"
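
Editor's note: a small usage sketch of the helpers above (the run ids and timestamps are made-up examples). `dotted_order` chains a millisecond-precision timestamp with the run id, then appends each child segment to its parent's order.

    from datetime import datetime

    from core.ops.utils import generate_dotted_order, measure_time

    with measure_time() as timer:
        ...  # timed work
    print(timer["start"], timer["end"])

    root = generate_dotted_order("run-1", datetime(2024, 1, 1, 12, 0, 0))
    child = generate_dotted_order("run-2", datetime(2024, 1, 1, 12, 0, 1), parent_dotted_order=root)
    # child == "20240101T120000000Zrun-1.20240101T120001000Zrun-2"
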
api/core/prompt/__init__.py ADDED
File without changes
api/core/prompt/advanced_prompt_transform.py ADDED
@@ -0,0 +1,287 @@
+from collections.abc import Mapping, Sequence
+from typing import Optional, cast
+
+from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
+from core.file import file_manager
+from core.file.models import File
+from core.helper.code_executor.jinja2.jinja2_formatter import Jinja2Formatter
+from core.memory.token_buffer_memory import TokenBufferMemory
+from core.model_runtime.entities import (
+    AssistantPromptMessage,
+    PromptMessage,
+    PromptMessageContent,
+    PromptMessageRole,
+    SystemPromptMessage,
+    TextPromptMessageContent,
+    UserPromptMessage,
+)
+from core.model_runtime.entities.message_entities import ImagePromptMessageContent
+from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
+from core.prompt.prompt_transform import PromptTransform
+from core.prompt.utils.prompt_template_parser import PromptTemplateParser
+from core.workflow.entities.variable_pool import VariablePool
+
+
+class AdvancedPromptTransform(PromptTransform):
+    """
+    Advanced Prompt Transform for Workflow LLM Node.
+    """
+
+    def __init__(
+        self,
+        with_variable_tmpl: bool = False,
+        image_detail_config: ImagePromptMessageContent.DETAIL = ImagePromptMessageContent.DETAIL.LOW,
+    ) -> None:
+        self.with_variable_tmpl = with_variable_tmpl
+        self.image_detail_config = image_detail_config
+
+    def get_prompt(
+        self,
+        *,
+        prompt_template: Sequence[ChatModelMessage] | CompletionModelPromptTemplate,
+        inputs: Mapping[str, str],
+        query: str,
+        files: Sequence[File],
+        context: Optional[str],
+        memory_config: Optional[MemoryConfig],
+        memory: Optional[TokenBufferMemory],
+        model_config: ModelConfigWithCredentialsEntity,
+    ) -> list[PromptMessage]:
+        prompt_messages = []
+
+        if isinstance(prompt_template, CompletionModelPromptTemplate):
+            prompt_messages = self._get_completion_model_prompt_messages(
+                prompt_template=prompt_template,
+                inputs=inputs,
+                query=query,
+                files=files,
+                context=context,
+                memory_config=memory_config,
+                memory=memory,
+                model_config=model_config,
+            )
+        elif isinstance(prompt_template, list) and all(isinstance(item, ChatModelMessage) for item in prompt_template):
+            prompt_messages = self._get_chat_model_prompt_messages(
+                prompt_template=prompt_template,
+                inputs=inputs,
+                query=query,
+                files=files,
+                context=context,
+                memory_config=memory_config,
+                memory=memory,
+                model_config=model_config,
+            )
+
+        return prompt_messages
+
+    def _get_completion_model_prompt_messages(
+        self,
+        prompt_template: CompletionModelPromptTemplate,
+        inputs: Mapping[str, str],
+        query: Optional[str],
+        files: Sequence[File],
+        context: Optional[str],
+        memory_config: Optional[MemoryConfig],
+        memory: Optional[TokenBufferMemory],
+        model_config: ModelConfigWithCredentialsEntity,
+    ) -> list[PromptMessage]:
+        """
+        Get completion model prompt messages.
+        """
+        raw_prompt = prompt_template.text
+
+        prompt_messages: list[PromptMessage] = []
+
+        if prompt_template.edition_type == "basic" or not prompt_template.edition_type:
+            parser = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl)
+            prompt_inputs: Mapping[str, str] = {k: inputs[k] for k in parser.variable_keys if k in inputs}
+
+            prompt_inputs = self._set_context_variable(context, parser, prompt_inputs)
+
+            if memory and memory_config and memory_config.role_prefix:
+                role_prefix = memory_config.role_prefix
+                prompt_inputs = self._set_histories_variable(
+                    memory=memory,
+                    memory_config=memory_config,
+                    raw_prompt=raw_prompt,
+                    role_prefix=role_prefix,
+                    parser=parser,
+                    prompt_inputs=prompt_inputs,
+                    model_config=model_config,
+                )
+
+            if query:
+                prompt_inputs = self._set_query_variable(query, parser, prompt_inputs)
+
+            prompt = parser.format(prompt_inputs)
+        else:
+            prompt = raw_prompt
+            prompt_inputs = inputs
+
+            prompt = Jinja2Formatter.format(prompt, prompt_inputs)
+
+        if files:
+            prompt_message_contents: list[PromptMessageContent] = []
+            prompt_message_contents.append(TextPromptMessageContent(data=prompt))
+            for file in files:
+                prompt_message_contents.append(file_manager.to_prompt_message_content(file))
+
+            prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
+        else:
+            prompt_messages.append(UserPromptMessage(content=prompt))
+
+        return prompt_messages
+
+    def _get_chat_model_prompt_messages(
+        self,
+        prompt_template: list[ChatModelMessage],
+        inputs: Mapping[str, str],
+        query: Optional[str],
+        files: Sequence[File],
+        context: Optional[str],
+        memory_config: Optional[MemoryConfig],
+        memory: Optional[TokenBufferMemory],
+        model_config: ModelConfigWithCredentialsEntity,
+    ) -> list[PromptMessage]:
+        """
+        Get chat model prompt messages.
+        """
+        prompt_messages: list[PromptMessage] = []
+        for prompt_item in prompt_template:
+            raw_prompt = prompt_item.text
+
+            if prompt_item.edition_type == "basic" or not prompt_item.edition_type:
+                if self.with_variable_tmpl:
+                    vp = VariablePool()
+                    for k, v in inputs.items():
+                        if k.startswith("#"):
+                            vp.add(k[1:-1].split("."), v)
+                    raw_prompt = raw_prompt.replace("{{#context#}}", context or "")
+                    prompt = vp.convert_template(raw_prompt).text
+                else:
+                    parser = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl)
+                    prompt_inputs: Mapping[str, str] = {k: inputs[k] for k in parser.variable_keys if k in inputs}
+                    prompt_inputs = self._set_context_variable(
+                        context=context, parser=parser, prompt_inputs=prompt_inputs
+                    )
+                    prompt = parser.format(prompt_inputs)
+            elif prompt_item.edition_type == "jinja2":
+                prompt = raw_prompt
+                prompt_inputs = inputs
+                prompt = Jinja2Formatter.format(template=prompt, inputs=prompt_inputs)
+            else:
+                raise ValueError(f"Invalid edition type: {prompt_item.edition_type}")
+
+            if prompt_item.role == PromptMessageRole.USER:
+                prompt_messages.append(UserPromptMessage(content=prompt))
+            elif prompt_item.role == PromptMessageRole.SYSTEM and prompt:
+                prompt_messages.append(SystemPromptMessage(content=prompt))
+            elif prompt_item.role == PromptMessageRole.ASSISTANT:
+                prompt_messages.append(AssistantPromptMessage(content=prompt))
+
+        if query and memory_config and memory_config.query_prompt_template:
+            parser = PromptTemplateParser(
+                template=memory_config.query_prompt_template, with_variable_tmpl=self.with_variable_tmpl
+            )
+            prompt_inputs = {k: inputs[k] for k in parser.variable_keys if k in inputs}
+            prompt_inputs["#sys.query#"] = query
+
+            prompt_inputs = self._set_context_variable(context, parser, prompt_inputs)
+
+            query = parser.format(prompt_inputs)
+
+        if memory and memory_config:
+            prompt_messages = self._append_chat_histories(memory, memory_config, prompt_messages, model_config)
+
+            if files and query is not None:
+                prompt_message_contents: list[PromptMessageContent] = []
+                prompt_message_contents.append(TextPromptMessageContent(data=query))
+                for file in files:
+                    prompt_message_contents.append(file_manager.to_prompt_message_content(file))
+                prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
+            else:
+                prompt_messages.append(UserPromptMessage(content=query))
+        elif files:
+            if not query:
+                # get last message
+                last_message = prompt_messages[-1] if prompt_messages else None
+                if last_message and last_message.role == PromptMessageRole.USER:
+                    # get last user message content and add files
+                    prompt_message_contents = [TextPromptMessageContent(data=cast(str, last_message.content))]
+                    for file in files:
+                        prompt_message_contents.append(file_manager.to_prompt_message_content(file))
+
+                    last_message.content = prompt_message_contents
+                else:
+                    prompt_message_contents = [TextPromptMessageContent(data="")]  # not for query
+                    for file in files:
+                        prompt_message_contents.append(file_manager.to_prompt_message_content(file))
+
+                    prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
+            else:
+                prompt_message_contents = [TextPromptMessageContent(data=query)]
+                for file in files:
+                    prompt_message_contents.append(file_manager.to_prompt_message_content(file))
+
+                prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
+        elif query:
+            prompt_messages.append(UserPromptMessage(content=query))
+
+        return prompt_messages
+
+    def _set_context_variable(
+        self, context: str | None, parser: PromptTemplateParser, prompt_inputs: Mapping[str, str]
+    ) -> Mapping[str, str]:
+        prompt_inputs = dict(prompt_inputs)
+        if "#context#" in parser.variable_keys:
+            if context:
+                prompt_inputs["#context#"] = context
+            else:
+                prompt_inputs["#context#"] = ""
+
+        return prompt_inputs
+
+    def _set_query_variable(
+        self, query: str, parser: PromptTemplateParser, prompt_inputs: Mapping[str, str]
+    ) -> Mapping[str, str]:
+        prompt_inputs = dict(prompt_inputs)
+        if "#query#" in parser.variable_keys:
+            if query:
+                prompt_inputs["#query#"] = query
+            else:
+                prompt_inputs["#query#"] = ""
+
+        return prompt_inputs
+
+    def _set_histories_variable(
+        self,
+        memory: TokenBufferMemory,
+        memory_config: MemoryConfig,
+        raw_prompt: str,
+        role_prefix: MemoryConfig.RolePrefix,
+        parser: PromptTemplateParser,
+        prompt_inputs: Mapping[str, str],
+        model_config: ModelConfigWithCredentialsEntity,
+    ) -> Mapping[str, str]:
+        prompt_inputs = dict(prompt_inputs)
+        if "#histories#" in parser.variable_keys:
+            if memory:
+                inputs = {"#histories#": "", **prompt_inputs}
+                parser = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl)
+                prompt_inputs = {k: inputs[k] for k in parser.variable_keys if k in inputs}
+                tmp_human_message = UserPromptMessage(content=parser.format(prompt_inputs))
+
+                rest_tokens = self._calculate_rest_token([tmp_human_message], model_config)
+
+                histories = self._get_history_messages_from_memory(
+                    memory=memory,
+                    memory_config=memory_config,
+                    max_token_limit=rest_tokens,
+                    human_prefix=role_prefix.user,
+                    ai_prefix=role_prefix.assistant,
+                )
+                prompt_inputs["#histories#"] = histories
+            else:
+                prompt_inputs["#histories#"] = ""
+
+        return prompt_inputs
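
Editor's note: a minimal calling sketch for the transform above, using a two-message chat template in "basic" edition mode. `model_config` is a placeholder for a real ModelConfigWithCredentialsEntity supplied by the caller; the template texts are illustrative.

    from core.model_runtime.entities.message_entities import PromptMessageRole
    from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
    from core.prompt.entities.advanced_prompt_entities import ChatModelMessage

    template = [
        ChatModelMessage(text="You answer using {{#context#}}.", role=PromptMessageRole.SYSTEM),
        ChatModelMessage(text="{{question}}", role=PromptMessageRole.USER),
    ]
    transform = AdvancedPromptTransform()
    messages = transform.get_prompt(
        prompt_template=template,
        inputs={"question": "What is Dify?"},
        query="",
        files=[],
        context="Dify is an LLM app platform.",
        memory_config=None,
        memory=None,
        model_config=model_config,  # hypothetical placeholder
    )
    # messages == [SystemPromptMessage(...), UserPromptMessage(...)]
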
api/core/prompt/agent_history_prompt_transform.py ADDED
@@ -0,0 +1,80 @@
+from typing import Optional, cast
+
+from core.app.entities.app_invoke_entities import (
+    ModelConfigWithCredentialsEntity,
+)
+from core.memory.token_buffer_memory import TokenBufferMemory
+from core.model_runtime.entities.message_entities import (
+    PromptMessage,
+    SystemPromptMessage,
+    UserPromptMessage,
+)
+from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
+from core.prompt.prompt_transform import PromptTransform
+
+
+class AgentHistoryPromptTransform(PromptTransform):
+    """
+    History Prompt Transform for Agent App
+    """
+
+    def __init__(
+        self,
+        model_config: ModelConfigWithCredentialsEntity,
+        prompt_messages: list[PromptMessage],
+        history_messages: list[PromptMessage],
+        memory: Optional[TokenBufferMemory] = None,
+    ):
+        self.model_config = model_config
+        self.prompt_messages = prompt_messages
+        self.history_messages = history_messages
+        self.memory = memory
+
+    def get_prompt(self) -> list[PromptMessage]:
+        prompt_messages: list[PromptMessage] = []
+        num_system = 0
+        for prompt_message in self.history_messages:
+            if isinstance(prompt_message, SystemPromptMessage):
+                prompt_messages.append(prompt_message)
+                num_system += 1
+
+        if not self.memory:
+            return prompt_messages
+
+        max_token_limit = self._calculate_rest_token(self.prompt_messages, self.model_config)
+
+        model_type_instance = self.model_config.provider_model_bundle.model_type_instance
+        model_type_instance = cast(LargeLanguageModel, model_type_instance)
+
+        curr_message_tokens = model_type_instance.get_num_tokens(
+            self.memory.model_instance.model, self.memory.model_instance.credentials, self.history_messages
+        )
+        if curr_message_tokens <= max_token_limit:
+            return self.history_messages
+
+        # number of prompts appended to the current message
+        num_prompt = 0
+        # append prompt messages in descending order
+        for prompt_message in self.history_messages[::-1]:
+            if isinstance(prompt_message, SystemPromptMessage):
+                continue
+            prompt_messages.append(prompt_message)
+            num_prompt += 1
+            # a message (turn) starts with a UserPromptMessage
+            if isinstance(prompt_message, UserPromptMessage):
+                curr_message_tokens = model_type_instance.get_num_tokens(
+                    self.memory.model_instance.model, self.memory.model_instance.credentials, prompt_messages
+                )
+                # if the current message overflows the token budget, drop all the prompts
+                # belonging to this message and stop
+                if curr_message_tokens > max_token_limit:
+                    prompt_messages = prompt_messages[:-num_prompt]
+                    break
+                num_prompt = 0
+        # return prompt messages in ascending order
+        message_prompts = prompt_messages[num_system:]
+        message_prompts.reverse()
+
+        # merge system and message prompts
+        prompt_messages = prompt_messages[:num_system]
+        prompt_messages.extend(message_prompts)
+        return prompt_messages
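
Editor's note: the trimming loop above works turn-by-turn. It walks history backwards and only checks the token budget at each UserPromptMessage boundary, so a user/assistant pair is kept or dropped atomically. A standalone sketch of the same idea, hedged: plain (role, text) tuples and a whitespace word count stand in for PromptMessage and the model's tokenizer.

    def trim_history(history: list[tuple[str, str]], budget: int) -> list[tuple[str, str]]:
        kept: list[tuple[str, str]] = []
        pending: list[tuple[str, str]] = []
        used = 0
        for role, text in reversed(history):
            pending.append((role, text))
            if role == "user":  # a turn starts with the user message
                turn_tokens = sum(len(t.split()) for _, t in pending)
                if used + turn_tokens > budget:
                    break  # drop the whole partial turn, like prompt_messages[:-num_prompt]
                used += turn_tokens
                kept.extend(pending)
                pending = []
        kept.reverse()
        return kept

    history = [("user", "a " * 50), ("assistant", "b"), ("user", "hi"), ("assistant", "hello")]
    print(trim_history(history, budget=10))
    # [("user", "hi"), ("assistant", "hello")]: the oldest, oversized turn is dropped whole
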
api/core/prompt/entities/__init__.py ADDED
File without changes
api/core/prompt/entities/advanced_prompt_entities.py ADDED
@@ -0,0 +1,50 @@
+from typing import Literal, Optional
+
+from pydantic import BaseModel
+
+from core.model_runtime.entities.message_entities import PromptMessageRole
+
+
+class ChatModelMessage(BaseModel):
+    """
+    Chat Message.
+    """
+
+    text: str
+    role: PromptMessageRole
+    edition_type: Optional[Literal["basic", "jinja2"]] = None
+
+
+class CompletionModelPromptTemplate(BaseModel):
+    """
+    Completion Model Prompt Template.
+    """
+
+    text: str
+    edition_type: Optional[Literal["basic", "jinja2"]] = None
+
+
+class MemoryConfig(BaseModel):
+    """
+    Memory Config.
+    """
+
+    class RolePrefix(BaseModel):
+        """
+        Role Prefix.
+        """
+
+        user: str
+        assistant: str
+
+    class WindowConfig(BaseModel):
+        """
+        Window Config.
+        """
+
+        enabled: bool
+        size: Optional[int] = None
+
+    role_prefix: Optional[RolePrefix] = None
+    window: WindowConfig
+    query_prompt_template: Optional[str] = None
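
Editor's note: a minimal construction sketch for these Pydantic models; the field values are illustrative only.

    from core.prompt.entities.advanced_prompt_entities import MemoryConfig

    memory_config = MemoryConfig(
        role_prefix=MemoryConfig.RolePrefix(user="Human", assistant="Assistant"),
        window=MemoryConfig.WindowConfig(enabled=True, size=10),  # keep at most 10 messages
        query_prompt_template="{{#sys.query#}}",
    )
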
api/core/prompt/prompt_templates/__init__.py ADDED
File without changes
api/core/prompt/prompt_templates/advanced_prompt_templates.py ADDED
@@ -0,0 +1,45 @@
+CONTEXT = "Use the following context as your learned knowledge, inside <context></context> XML tags.\n\n<context>\n{{#context#}}\n</context>\n\nWhen answer to user:\n- If you don't know, just say that you don't know.\n- If you don't know when you are not sure, ask for clarification.\nAvoid mentioning that you obtained the information from the context.\nAnd answer according to the language of the user's question.\n"  # noqa: E501
+
+BAICHUAN_CONTEXT = "用户在与一个客观的助手对话。助手会尊重找到的材料,给出全面专业的解释,但不会过度演绎。同时回答中不会暴露引用的材料:\n\n```\n{{#context#}}\n```\n\n"  # noqa: E501
+
+CHAT_APP_COMPLETION_PROMPT_CONFIG = {
+    "completion_prompt_config": {
+        "prompt": {
+            "text": "{{#pre_prompt#}}\nHere are the chat histories between human and assistant, inside <histories></histories> XML tags.\n\n<histories>\n{{#histories#}}\n</histories>\n\n\nHuman: {{#query#}}\n\nAssistant: "  # noqa: E501
+        },
+        "conversation_histories_role": {"user_prefix": "Human", "assistant_prefix": "Assistant"},
+    },
+    "stop": ["Human:"],
+}
+
+CHAT_APP_CHAT_PROMPT_CONFIG = {"chat_prompt_config": {"prompt": [{"role": "system", "text": "{{#pre_prompt#}}"}]}}
+
+COMPLETION_APP_CHAT_PROMPT_CONFIG = {"chat_prompt_config": {"prompt": [{"role": "user", "text": "{{#pre_prompt#}}"}]}}
+
+COMPLETION_APP_COMPLETION_PROMPT_CONFIG = {
+    "completion_prompt_config": {"prompt": {"text": "{{#pre_prompt#}}"}},
+    "stop": ["Human:"],
+}
+
+BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG = {
+    "completion_prompt_config": {
+        "prompt": {
+            "text": "{{#pre_prompt#}}\n\n用户和助手的历史对话内容如下:\n```\n{{#histories#}}\n```\n\n\n\n用户:{{#query#}}"  # noqa: E501
+        },
+        "conversation_histories_role": {"user_prefix": "用户", "assistant_prefix": "助手"},
+    },
+    "stop": ["用户:"],
+}
+
+BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG = {
+    "chat_prompt_config": {"prompt": [{"role": "system", "text": "{{#pre_prompt#}}"}]}
+}
+
+BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG = {
+    "chat_prompt_config": {"prompt": [{"role": "user", "text": "{{#pre_prompt#}}"}]}
+}
+
+BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG = {
+    "completion_prompt_config": {"prompt": {"text": "{{#pre_prompt#}}"}},
+    "stop": ["用户:"],
+}
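
Editor's note: these configs are plain dicts, and the `{{#...#}}` markers are filled later by PromptTemplateParser. A rough illustration of the intended expansion, hand-substituted with made-up values rather than the parser itself:

    text = CHAT_APP_COMPLETION_PROMPT_CONFIG["completion_prompt_config"]["prompt"]["text"]
    filled = (
        text.replace("{{#pre_prompt#}}", "You are a helpful assistant.")
        .replace("{{#histories#}}", "Human: hi\nAssistant: hello")
        .replace("{{#query#}}", "What can you do?")
    )
    print(filled)  # a ready-to-send completion prompt ending in "Assistant: "
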
api/core/prompt/prompt_templates/baichuan_chat.json ADDED
@@ -0,0 +1,13 @@
+{
+  "human_prefix": "用户",
+  "assistant_prefix": "助手",
+  "context_prompt": "用户在与一个客观的助手对话。助手会尊重找到的材料,给出全面专业的解释,但不会过度演绎。同时回答中不会暴露引用的材料:\n\n```\n{{#context#}}\n```\n\n",
+  "histories_prompt": "用户和助手的历史对话内容如下:\n```\n{{#histories#}}\n```\n\n",
+  "system_prompt_orders": [
+    "context_prompt",
+    "pre_prompt",
+    "histories_prompt"
+  ],
+  "query_prompt": "\n\n用户:{{#query#}}",
+  "stops": ["用户:"]
+}
api/core/prompt/prompt_templates/baichuan_completion.json ADDED
@@ -0,0 +1,9 @@
+{
+  "context_prompt": "用户在与一个客观的助手对话。助手会尊重找到的材料,给出全面专业的解释,但不会过度演绎。同时回答中不会暴露引用的材料:\n\n```\n{{#context#}}\n```\n",
+  "system_prompt_orders": [
+    "context_prompt",
+    "pre_prompt"
+  ],
+  "query_prompt": "{{#query#}}",
+  "stops": null
+}
api/core/prompt/prompt_templates/common_chat.json ADDED
@@ -0,0 +1,13 @@
+{
+  "human_prefix": "Human",
+  "assistant_prefix": "Assistant",
+  "context_prompt": "Use the following context as your learned knowledge, inside <context></context> XML tags.\n\n<context>\n{{#context#}}\n</context>\n\nWhen answer to user:\n- If you don't know, just say that you don't know.\n- If you don't know when you are not sure, ask for clarification.\nAvoid mentioning that you obtained the information from the context.\nAnd answer according to the language of the user's question.\n\n",
+  "histories_prompt": "Here is the chat histories between human and assistant, inside <histories></histories> XML tags.\n\n<histories>\n{{#histories#}}\n</histories>\n\n",
+  "system_prompt_orders": [
+    "context_prompt",
+    "pre_prompt",
+    "histories_prompt"
+  ],
+  "query_prompt": "\n\nHuman: {{#query#}}\n\nAssistant: ",
+  "stops": ["\nHuman:", "</histories>"]
+}
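
Editor's note: these rule files are consumed by SimplePromptTransform below (via its `_get_prompt_rule` helper): `system_prompt_orders` controls concatenation order and `stops` become generation stop words. A quick sanity check of the loaded structure, assuming the file is read as plain JSON from the repo root:

    import json

    with open("api/core/prompt/prompt_templates/common_chat.json") as f:
        rules = json.load(f)
    assert rules["system_prompt_orders"] == ["context_prompt", "pre_prompt", "histories_prompt"]
    assert "\nHuman:" in rules["stops"]
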
api/core/prompt/prompt_templates/common_completion.json ADDED
@@ -0,0 +1,9 @@
+{
+  "context_prompt": "Use the following context as your learned knowledge, inside <context></context> XML tags.\n\n<context>\n{{#context#}}\n</context>\n\nWhen answer to user:\n- If you don't know, just say that you don't know.\n- If you don't know when you are not sure, ask for clarification.\nAvoid mentioning that you obtained the information from the context.\nAnd answer according to the language of the user's question.\n\n",
+  "system_prompt_orders": [
+    "context_prompt",
+    "pre_prompt"
+  ],
+  "query_prompt": "{{#query#}}",
+  "stops": null
+}
api/core/prompt/prompt_transform.py ADDED
@@ -0,0 +1,90 @@
+from typing import Any, Optional
+
+from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
+from core.memory.token_buffer_memory import TokenBufferMemory
+from core.model_manager import ModelInstance
+from core.model_runtime.entities.message_entities import PromptMessage
+from core.model_runtime.entities.model_entities import ModelPropertyKey
+from core.prompt.entities.advanced_prompt_entities import MemoryConfig
+
+
+class PromptTransform:
+    def _append_chat_histories(
+        self,
+        memory: TokenBufferMemory,
+        memory_config: MemoryConfig,
+        prompt_messages: list[PromptMessage],
+        model_config: ModelConfigWithCredentialsEntity,
+    ) -> list[PromptMessage]:
+        rest_tokens = self._calculate_rest_token(prompt_messages, model_config)
+        histories = self._get_history_messages_list_from_memory(memory, memory_config, rest_tokens)
+        prompt_messages.extend(histories)
+
+        return prompt_messages
+
+    def _calculate_rest_token(
+        self, prompt_messages: list[PromptMessage], model_config: ModelConfigWithCredentialsEntity
+    ) -> int:
+        rest_tokens = 2000
+
+        model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
+        if model_context_tokens:
+            model_instance = ModelInstance(
+                provider_model_bundle=model_config.provider_model_bundle, model=model_config.model
+            )
+
+            curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages)
+
+            max_tokens = 0
+            for parameter_rule in model_config.model_schema.parameter_rules:
+                if parameter_rule.name == "max_tokens" or (
+                    parameter_rule.use_template and parameter_rule.use_template == "max_tokens"
+                ):
+                    max_tokens = (
+                        model_config.parameters.get(parameter_rule.name)
+                        or model_config.parameters.get(parameter_rule.use_template or "")
+                    ) or 0
+
+            rest_tokens = model_context_tokens - max_tokens - curr_message_tokens
+            rest_tokens = max(rest_tokens, 0)
+
+        return rest_tokens
+
+    def _get_history_messages_from_memory(
+        self,
+        memory: TokenBufferMemory,
+        memory_config: MemoryConfig,
+        max_token_limit: int,
+        human_prefix: Optional[str] = None,
+        ai_prefix: Optional[str] = None,
+    ) -> str:
+        """Get memory messages."""
+        kwargs: dict[str, Any] = {"max_token_limit": max_token_limit}
+
+        if human_prefix:
+            kwargs["human_prefix"] = human_prefix
+
+        if ai_prefix:
+            kwargs["ai_prefix"] = ai_prefix
+
+        if memory_config.window.enabled and memory_config.window.size is not None and memory_config.window.size > 0:
+            kwargs["message_limit"] = memory_config.window.size
+
+        return memory.get_history_prompt_text(**kwargs)
+
+    def _get_history_messages_list_from_memory(
+        self, memory: TokenBufferMemory, memory_config: MemoryConfig, max_token_limit: int
+    ) -> list[PromptMessage]:
+        """Get memory messages."""
+        return list(
+            memory.get_history_prompt_messages(
+                max_token_limit=max_token_limit,
+                message_limit=memory_config.window.size
+                if (
+                    memory_config.window.enabled
+                    and memory_config.window.size is not None
+                    and memory_config.window.size > 0
+                )
+                else None,
+            )
+        )
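
Editor's note: the budget arithmetic in `_calculate_rest_token` is simply context size minus the reserved completion tokens minus the tokens already in the prompt, floored at zero. A worked example with illustrative numbers:

    model_context_tokens = 8192   # ModelPropertyKey.CONTEXT_SIZE
    max_tokens = 1024             # reserved for the completion via the max_tokens rule
    curr_message_tokens = 5000    # tokens already used by the prompt messages
    rest_tokens = max(model_context_tokens - max_tokens - curr_message_tokens, 0)
    assert rest_tokens == 2168    # budget left for chat history
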
api/core/prompt/simple_prompt_transform.py ADDED
@@ -0,0 +1,327 @@
+ import enum
+ import json
+ import os
+ from collections.abc import Mapping, Sequence
+ from typing import TYPE_CHECKING, Any, Optional, cast
+
+ from core.app.app_config.entities import PromptTemplateEntity
+ from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
+ from core.file import file_manager
+ from core.memory.token_buffer_memory import TokenBufferMemory
+ from core.model_runtime.entities.message_entities import (
+     PromptMessage,
+     PromptMessageContent,
+     SystemPromptMessage,
+     TextPromptMessageContent,
+     UserPromptMessage,
+ )
+ from core.prompt.entities.advanced_prompt_entities import MemoryConfig
+ from core.prompt.prompt_transform import PromptTransform
+ from core.prompt.utils.prompt_template_parser import PromptTemplateParser
+ from models.model import AppMode
+
+ if TYPE_CHECKING:
+     from core.file.models import File
+
+
+ class ModelMode(enum.StrEnum):
+     COMPLETION = "completion"
+     CHAT = "chat"
+
+     @classmethod
+     def value_of(cls, value: str) -> "ModelMode":
+         """
+         Get value of given mode.
+
+         :param value: mode value
+         :return: mode
+         """
+         for mode in cls:
+             if mode.value == value:
+                 return mode
+         raise ValueError(f"invalid mode value {value}")
+
+
+ prompt_file_contents: dict[str, Any] = {}
+
+
+ class SimplePromptTransform(PromptTransform):
+     """
+     Simple Prompt Transform for Chatbot App Basic Mode.
+     """
+
+     def get_prompt(
+         self,
+         app_mode: AppMode,
+         prompt_template_entity: PromptTemplateEntity,
+         inputs: Mapping[str, str],
+         query: str,
+         files: Sequence["File"],
+         context: Optional[str],
+         memory: Optional[TokenBufferMemory],
+         model_config: ModelConfigWithCredentialsEntity,
+     ) -> tuple[list[PromptMessage], Optional[list[str]]]:
+         inputs = {key: str(value) for key, value in inputs.items()}
+
+         model_mode = ModelMode.value_of(model_config.mode)
+         if model_mode == ModelMode.CHAT:
+             prompt_messages, stops = self._get_chat_model_prompt_messages(
+                 app_mode=app_mode,
+                 pre_prompt=prompt_template_entity.simple_prompt_template or "",
+                 inputs=inputs,
+                 query=query,
+                 files=files,
+                 context=context,
+                 memory=memory,
+                 model_config=model_config,
+             )
+         else:
+             prompt_messages, stops = self._get_completion_model_prompt_messages(
+                 app_mode=app_mode,
+                 pre_prompt=prompt_template_entity.simple_prompt_template or "",
+                 inputs=inputs,
+                 query=query,
+                 files=files,
+                 context=context,
+                 memory=memory,
+                 model_config=model_config,
+             )
+
+         return prompt_messages, stops
+
+     def get_prompt_str_and_rules(
+         self,
+         app_mode: AppMode,
+         model_config: ModelConfigWithCredentialsEntity,
+         pre_prompt: str,
+         inputs: dict,
+         query: Optional[str] = None,
+         context: Optional[str] = None,
+         histories: Optional[str] = None,
+     ) -> tuple[str, dict]:
+         # get prompt template
+         prompt_template_config = self.get_prompt_template(
+             app_mode=app_mode,
+             provider=model_config.provider,
+             model=model_config.model,
+             pre_prompt=pre_prompt,
+             has_context=context is not None,
+             query_in_prompt=query is not None,
+             with_memory_prompt=histories is not None,
+         )
+
+         variables = {k: inputs[k] for k in prompt_template_config["custom_variable_keys"] if k in inputs}
+
+         for v in prompt_template_config["special_variable_keys"]:
+             # support #context#, #query# and #histories#
+             if v == "#context#":
+                 variables["#context#"] = context or ""
+             elif v == "#query#":
+                 variables["#query#"] = query or ""
+             elif v == "#histories#":
+                 variables["#histories#"] = histories or ""
+
+         prompt_template = prompt_template_config["prompt_template"]
+         prompt = prompt_template.format(variables)
+
+         return prompt, prompt_template_config["prompt_rules"]
+
+     def get_prompt_template(
+         self,
+         app_mode: AppMode,
+         provider: str,
+         model: str,
+         pre_prompt: str,
+         has_context: bool,
+         query_in_prompt: bool,
+         with_memory_prompt: bool = False,
+     ) -> dict:
+         prompt_rules = self._get_prompt_rule(app_mode=app_mode, provider=provider, model=model)
+
+         custom_variable_keys = []
+         special_variable_keys = []
+
+         prompt = ""
+         for order in prompt_rules["system_prompt_orders"]:
+             if order == "context_prompt" and has_context:
+                 prompt += prompt_rules["context_prompt"]
+                 special_variable_keys.append("#context#")
+             elif order == "pre_prompt" and pre_prompt:
+                 prompt += pre_prompt + "\n"
+                 pre_prompt_template = PromptTemplateParser(template=pre_prompt)
+                 custom_variable_keys = pre_prompt_template.variable_keys
+             elif order == "histories_prompt" and with_memory_prompt:
+                 prompt += prompt_rules["histories_prompt"]
+                 special_variable_keys.append("#histories#")
+
+         if query_in_prompt:
+             prompt += prompt_rules.get("query_prompt", "{{#query#}}")
+             special_variable_keys.append("#query#")
+
+         return {
+             "prompt_template": PromptTemplateParser(template=prompt),
+             "custom_variable_keys": custom_variable_keys,
+             "special_variable_keys": special_variable_keys,
+             "prompt_rules": prompt_rules,
+         }
+
+     def _get_chat_model_prompt_messages(
+         self,
+         app_mode: AppMode,
+         pre_prompt: str,
+         inputs: dict,
+         query: str,
+         context: Optional[str],
+         files: Sequence["File"],
+         memory: Optional[TokenBufferMemory],
+         model_config: ModelConfigWithCredentialsEntity,
+     ) -> tuple[list[PromptMessage], Optional[list[str]]]:
+         prompt_messages: list[PromptMessage] = []
+
+         # get prompt
+         prompt, _ = self.get_prompt_str_and_rules(
+             app_mode=app_mode,
+             model_config=model_config,
+             pre_prompt=pre_prompt,
+             inputs=inputs,
+             query=None,
+             context=context,
+         )
+
+         if prompt and query:
+             prompt_messages.append(SystemPromptMessage(content=prompt))
+
+         if memory:
+             prompt_messages = self._append_chat_histories(
+                 memory=memory,
+                 memory_config=MemoryConfig(
+                     window=MemoryConfig.WindowConfig(
+                         enabled=False,
+                     )
+                 ),
+                 prompt_messages=prompt_messages,
+                 model_config=model_config,
+             )
+
+         if query:
+             prompt_messages.append(self.get_last_user_message(query, files))
+         else:
+             prompt_messages.append(self.get_last_user_message(prompt, files))
+
+         return prompt_messages, None
+
+     def _get_completion_model_prompt_messages(
+         self,
+         app_mode: AppMode,
+         pre_prompt: str,
+         inputs: dict,
+         query: str,
+         context: Optional[str],
+         files: Sequence["File"],
+         memory: Optional[TokenBufferMemory],
+         model_config: ModelConfigWithCredentialsEntity,
+     ) -> tuple[list[PromptMessage], Optional[list[str]]]:
+         # get prompt
+         prompt, prompt_rules = self.get_prompt_str_and_rules(
+             app_mode=app_mode,
+             model_config=model_config,
+             pre_prompt=pre_prompt,
+             inputs=inputs,
+             query=query,
+             context=context,
+         )
+
+         if memory:
+             tmp_human_message = UserPromptMessage(content=prompt)
+
+             rest_tokens = self._calculate_rest_token([tmp_human_message], model_config)
+             histories = self._get_history_messages_from_memory(
+                 memory=memory,
+                 memory_config=MemoryConfig(
+                     window=MemoryConfig.WindowConfig(
+                         enabled=False,
+                     )
+                 ),
+                 max_token_limit=rest_tokens,
+                 human_prefix=prompt_rules.get("human_prefix", "Human"),
+                 ai_prefix=prompt_rules.get("assistant_prefix", "Assistant"),
+             )
+
+             # get prompt
+             prompt, prompt_rules = self.get_prompt_str_and_rules(
+                 app_mode=app_mode,
+                 model_config=model_config,
+                 pre_prompt=pre_prompt,
+                 inputs=inputs,
+                 query=query,
+                 context=context,
+                 histories=histories,
+             )
+
+         stops = prompt_rules.get("stops")
+         if stops is not None and len(stops) == 0:
+             stops = None
+
+         return [self.get_last_user_message(prompt, files)], stops
+
+     def get_last_user_message(self, prompt: str, files: Sequence["File"]) -> UserPromptMessage:
+         if files:
+             prompt_message_contents: list[PromptMessageContent] = []
+             prompt_message_contents.append(TextPromptMessageContent(data=prompt))
+             for file in files:
+                 prompt_message_contents.append(file_manager.to_prompt_message_content(file))
+
+             prompt_message = UserPromptMessage(content=prompt_message_contents)
+         else:
+             prompt_message = UserPromptMessage(content=prompt)
+
+         return prompt_message
+
+     def _get_prompt_rule(self, app_mode: AppMode, provider: str, model: str) -> dict:
+         """
+         Get simple prompt rule.
+         :param app_mode: app mode
+         :param provider: model provider
+         :param model: model name
+         :return:
+         """
+         prompt_file_name = self._prompt_file_name(app_mode=app_mode, provider=provider, model=model)
+
+         # Check if the prompt file is already loaded
+         if prompt_file_name in prompt_file_contents:
+             return cast(dict, prompt_file_contents[prompt_file_name])
+
+         # Get the absolute path of the subdirectory
+         prompt_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "prompt_templates")
+         json_file_path = os.path.join(prompt_path, f"{prompt_file_name}.json")
+
+         # Open the JSON file and read its content
+         with open(json_file_path, encoding="utf-8") as json_file:
+             content = json.load(json_file)
+
+         # Store the content of the prompt file
+         prompt_file_contents[prompt_file_name] = content
+
+         return cast(dict, content)
+
+     def _prompt_file_name(self, app_mode: AppMode, provider: str, model: str) -> str:
+         # baichuan
+         is_baichuan = False
+         if provider == "baichuan":
+             is_baichuan = True
+         else:
+             baichuan_supported_providers = ["huggingface_hub", "openllm", "xinference"]
+             if provider in baichuan_supported_providers and "baichuan" in model.lower():
+                 is_baichuan = True
+
+         if is_baichuan:
+             if app_mode == AppMode.COMPLETION:
+                 return "baichuan_completion"
+             else:
+                 return "baichuan_chat"
+
+         # common
+         if app_mode == AppMode.COMPLETION:
+             return "common_completion"
+         else:
+             return "common_chat"
api/core/prompt/utils/__init__.py ADDED
File without changes
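For reference, a minimal sketch of exercising the new transform's template assembly in isolation. The provider and model names are placeholders, and the expected outputs assume the bundled common_chat rules order context before the pre-prompt:

    from core.prompt.simple_prompt_transform import SimplePromptTransform
    from models.model import AppMode

    transform = SimplePromptTransform()
    config = transform.get_prompt_template(
        app_mode=AppMode.CHAT,
        provider="openai",  # placeholder provider
        model="gpt-3.5-turbo",  # placeholder model name
        pre_prompt="You are a helpful assistant for {{product}}.",
        has_context=True,
        query_in_prompt=True,
    )
    # {{...}} placeholders in the pre-prompt become custom variable keys;
    # #context#/#query#/#histories# are collected separately.
    print(config["custom_variable_keys"])   # expected: ["product"]
    print(config["special_variable_keys"])  # expected: ["#context#", "#query#"]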