Upload 1285 files
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +6 -0
- api/core/model_runtime/README.md +70 -0
- api/core/model_runtime/README_CN.md +89 -0
- api/core/model_runtime/__init__.py +0 -0
- api/core/model_runtime/model_providers/__base/__init__.py +0 -0
- api/core/model_runtime/model_providers/__base/ai_model.py +334 -0
- api/core/model_runtime/model_providers/__base/audio.mp3 +3 -0
- api/core/model_runtime/model_providers/__base/large_language_model.py +904 -0
- api/core/model_runtime/model_providers/__base/model_provider.py +121 -0
- api/core/model_runtime/model_providers/__base/moderation_model.py +49 -0
- api/core/model_runtime/model_providers/__base/rerank_model.py +69 -0
- api/core/model_runtime/model_providers/__base/speech2text_model.py +59 -0
- api/core/model_runtime/model_providers/__base/text2img_model.py +54 -0
- api/core/model_runtime/model_providers/__base/text_embedding_model.py +111 -0
- api/core/model_runtime/model_providers/__base/tokenizers/gpt2/merges.txt +0 -0
- api/core/model_runtime/model_providers/__base/tokenizers/gpt2/special_tokens_map.json +23 -0
- api/core/model_runtime/model_providers/__base/tokenizers/gpt2/tokenizer_config.json +33 -0
- api/core/model_runtime/model_providers/__base/tokenizers/gpt2/vocab.json +0 -0
- api/core/model_runtime/model_providers/__base/tokenizers/gpt2_tokenzier.py +51 -0
- api/core/model_runtime/model_providers/__base/tts_model.py +179 -0
- api/core/model_runtime/model_providers/__init__.py +3 -0
- api/core/model_runtime/model_providers/_position.yaml +43 -0
- api/core/model_runtime/model_providers/anthropic/__init__.py +0 -0
- api/core/model_runtime/model_providers/anthropic/_assets/icon_l_en.svg +78 -0
- api/core/model_runtime/model_providers/anthropic/_assets/icon_s_en.svg +4 -0
- api/core/model_runtime/model_providers/anthropic/anthropic.py +28 -0
- api/core/model_runtime/model_providers/anthropic/anthropic.yaml +39 -0
- api/core/model_runtime/model_providers/anthropic/llm/__init__.py +0 -0
- api/core/model_runtime/model_providers/anthropic/llm/_position.yaml +10 -0
- api/core/model_runtime/model_providers/anthropic/llm/claude-2.1.yaml +36 -0
- api/core/model_runtime/model_providers/anthropic/llm/claude-2.yaml +37 -0
- api/core/model_runtime/model_providers/anthropic/llm/claude-3-5-haiku-20241022.yaml +38 -0
- api/core/model_runtime/model_providers/anthropic/llm/claude-3-5-sonnet-20240620.yaml +40 -0
- api/core/model_runtime/model_providers/anthropic/llm/claude-3-5-sonnet-20241022.yaml +40 -0
- api/core/model_runtime/model_providers/anthropic/llm/claude-3-haiku-20240307.yaml +39 -0
- api/core/model_runtime/model_providers/anthropic/llm/claude-3-opus-20240229.yaml +39 -0
- api/core/model_runtime/model_providers/anthropic/llm/claude-3-sonnet-20240229.yaml +39 -0
- api/core/model_runtime/model_providers/anthropic/llm/claude-instant-1.2.yaml +36 -0
- api/core/model_runtime/model_providers/anthropic/llm/claude-instant-1.yaml +36 -0
- api/core/model_runtime/model_providers/anthropic/llm/llm.py +654 -0
- api/core/model_runtime/model_providers/azure_ai_studio/__init__.py +0 -0
- api/core/model_runtime/model_providers/azure_ai_studio/_assets/icon_l_en.png +0 -0
- api/core/model_runtime/model_providers/azure_ai_studio/_assets/icon_s_en.png +0 -0
- api/core/model_runtime/model_providers/azure_ai_studio/azure_ai_studio.py +17 -0
- api/core/model_runtime/model_providers/azure_ai_studio/azure_ai_studio.yaml +99 -0
- api/core/model_runtime/model_providers/azure_ai_studio/llm/__init__.py +0 -0
- api/core/model_runtime/model_providers/azure_ai_studio/llm/llm.py +345 -0
- api/core/model_runtime/model_providers/azure_ai_studio/rerank/__init__.py +0 -0
- api/core/model_runtime/model_providers/azure_ai_studio/rerank/rerank.py +164 -0
- api/core/model_runtime/model_providers/azure_openai/__init__.py +0 -0
.gitattributes
CHANGED
@@ -26,3 +26,9 @@ api/core/model_runtime/docs/zh_Hans/images/index/image-20231210144229650.png fil
|
|
26 |
api/core/model_runtime/docs/zh_Hans/images/index/image-20231210144814617.png filter=lfs diff=lfs merge=lfs -text
|
27 |
api/core/model_runtime/docs/zh_Hans/images/index/image-20231210165243632.png filter=lfs diff=lfs merge=lfs -text
|
28 |
api/core/model_runtime/docs/zh_Hans/images/index/image.png filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
26 |
api/core/model_runtime/docs/zh_Hans/images/index/image-20231210144814617.png filter=lfs diff=lfs merge=lfs -text
|
27 |
api/core/model_runtime/docs/zh_Hans/images/index/image-20231210165243632.png filter=lfs diff=lfs merge=lfs -text
|
28 |
api/core/model_runtime/docs/zh_Hans/images/index/image.png filter=lfs diff=lfs merge=lfs -text
|
29 |
+
api/core/model_runtime/model_providers/__base/audio.mp3 filter=lfs diff=lfs merge=lfs -text
|
30 |
+
api/core/model_runtime/model_providers/gpustack/_assets/icon_l_en.png filter=lfs diff=lfs merge=lfs -text
|
31 |
+
api/core/model_runtime/model_providers/leptonai/_assets/icon_l_en.png filter=lfs diff=lfs merge=lfs -text
|
32 |
+
api/core/model_runtime/model_providers/mixedbread/_assets/icon_l_en.png filter=lfs diff=lfs merge=lfs -text
|
33 |
+
api/core/model_runtime/model_providers/nvidia_nim/_assets/icon_l_en.png filter=lfs diff=lfs merge=lfs -text
|
34 |
+
api/core/model_runtime/model_providers/nvidia/_assets/icon_l_en.png filter=lfs diff=lfs merge=lfs -text
|
api/core/model_runtime/README.md
ADDED
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Model Runtime
|
2 |
+
|
3 |
+
This module provides the interface for invoking and authenticating various models, and offers Dify a unified information and credentials form rule for model providers.
|
4 |
+
|
5 |
+
- On one hand, it decouples models from upstream and downstream processes, facilitating horizontal expansion for developers,
|
6 |
+
- On the other hand, it allows for direct display of providers and models in the frontend interface by simply defining them in the backend, eliminating the need to modify frontend logic.
|
7 |
+
|
8 |
+
## Features
|
9 |
+
|
10 |
+
- Supports capability invocation for 5 types of models
|
11 |
+
|
12 |
+
- `LLM` - LLM text completion, dialogue, pre-computed tokens capability
|
13 |
+
- `Text Embedding Model` - Text Embedding, pre-computed tokens capability
|
14 |
+
- `Rerank Model` - Segment Rerank capability
|
15 |
+
- `Speech-to-text Model` - Speech to text capability
|
16 |
+
- `Text-to-speech Model` - Text to speech capability
|
17 |
+
- `Moderation` - Moderation capability
|
18 |
+
|
19 |
+
- Model provider display
|
20 |
+
|
21 |
+

|
22 |
+
|
23 |
+
Displays a list of all supported providers, including provider names, icons, supported model types list, predefined model list, configuration method, and credentials form rules, etc. For detailed rule design, see: [Schema](./docs/en_US/schema.md).
|
24 |
+
|
25 |
+
- Selectable model list display
|
26 |
+
|
27 |
+

|
28 |
+
|
29 |
+
After configuring provider/model credentials, the dropdown (application orchestration interface/default model) allows viewing of the available LLM list. Greyed out items represent predefined model lists from providers without configured credentials, facilitating user review of supported models.
|
30 |
+
|
31 |
+
In addition, this list also returns configurable parameter information and rules for LLM, as shown below:
|
32 |
+
|
33 |
+

|
34 |
+
|
35 |
+
These parameters are all defined in the backend, allowing different settings for various parameters supported by different models, as detailed in: [Schema](./docs/en_US/schema.md#ParameterRule).
|
36 |
+
|
37 |
+
- Provider/model credential authentication
|
38 |
+
|
39 |
+

|
40 |
+
|
41 |
+

|
42 |
+
|
43 |
+
The provider list returns configuration information for the credentials form, which can be authenticated through Runtime's interface. The first image above is a provider credential DEMO, and the second is a model credential DEMO.
|
44 |
+
|
45 |
+
## Structure
|
46 |
+
|
47 |
+

|
48 |
+
|
49 |
+
Model Runtime is divided into three layers:
|
50 |
+
|
51 |
+
- The outermost layer is the factory method
|
52 |
+
|
53 |
+
It provides methods for obtaining all providers, all model lists, getting provider instances, and authenticating provider/model credentials.
|
54 |
+
|
55 |
+
- The second layer is the provider layer
|
56 |
+
|
57 |
+
It provides the current provider's model list, model instance obtaining, provider credential authentication, and provider configuration rule information, **allowing horizontal expansion** to support different providers.
|
58 |
+
|
59 |
+
- The bottom layer is the model layer
|
60 |
+
|
61 |
+
It offers direct invocation of various model types, predefined model configuration information, getting predefined/remote model lists, model credential authentication methods. Different models provide additional special methods, like LLM's pre-computed tokens method, cost information obtaining method, etc., **allowing horizontal expansion** for different models under the same provider (within supported model types).
|
62 |
+
|
63 |
+
|
64 |
+
|
65 |
+
## Next Steps
|
66 |
+
|
67 |
+
- Add new provider configuration: [Link](./docs/en_US/provider_scale_out.md)
|
68 |
+
- Add new models for existing providers: [Link](./docs/en_US/provider_scale_out.md#AddModel)
|
69 |
+
- View YAML configuration rules: [Link](./docs/en_US/schema.md)
|
70 |
+
- Implement interface methods: [Link](./docs/en_US/interfaces.md)
|
api/core/model_runtime/README_CN.md
ADDED
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Model Runtime
|
2 |
+
|
3 |
+
该模块提供了各模型的调用、鉴权接口,并为 Dify 提供了统一的模型供应商的信息和凭据表单规则。
|
4 |
+
|
5 |
+
- 一方面将模型和上下游解耦,方便开发者对模型横向扩展,
|
6 |
+
- 另一方面提供了只需在后端定义供应商和模型,即可在前端页面直接展示,无需修改前端逻辑。
|
7 |
+
|
8 |
+
## 功能介绍
|
9 |
+
|
10 |
+
- 支持 5 种模型类型的能力调用
|
11 |
+
|
12 |
+
- `LLM` - LLM 文本补全、对话,预计算 tokens 能力
|
13 |
+
- `Text Embedidng Model` - 文本 Embedding ,预计算 tokens 能力
|
14 |
+
- `Rerank Model` - 分段 Rerank 能力
|
15 |
+
- `Speech-to-text Model` - 语音转文本能力
|
16 |
+
- `Text-to-speech Model` - 文本转语音能力
|
17 |
+
- `Moderation` - Moderation 能力
|
18 |
+
|
19 |
+
- 模型供应商展示
|
20 |
+
|
21 |
+

|
22 |
+
|
23 |
+
展示所有已支持的供应商列表,除了返回供应商名称、图标之外,还提供了支持的模型类型列表,预定义模型列表、配置方式以及配置凭据的表单规则等等,规则设计详见:[Schema](./docs/zh_Hans/schema.md)。
|
24 |
+
|
25 |
+
- 可选择的模型列表展示
|
26 |
+
|
27 |
+

|
28 |
+
|
29 |
+
配置供应商/模型凭据后,可在此下拉(应用编排界面/默认模型)查看可用的 LLM 列表,其中灰色的为未配置凭据供应商的预定义模型列表,方便用户查看已支持的模型。
|
30 |
+
|
31 |
+
除此之外,该列表还返回了 LLM 可配置的参数信息和规则,如下图:
|
32 |
+
|
33 |
+

|
34 |
+
|
35 |
+
这里的参数均为后端定义,相比之前只有 5 种固定参数,这里可为不同模型设置所支持的各种参数,详见:[Schema](./docs/zh_Hans/schema.md#ParameterRule)。
|
36 |
+
|
37 |
+
- 供应商/模型凭据鉴权
|
38 |
+
|
39 |
+

|
40 |
+
|
41 |
+

|
42 |
+
|
43 |
+
供应商列表返回了凭据表单的配置信息,可通过 Runtime 提供的接口对凭据进行鉴权,上图 1 为供应商凭据 DEMO,上图 2 为模型凭据 DEMO。
|
44 |
+
|
45 |
+
## 结构
|
46 |
+
|
47 |
+

|
48 |
+
|
49 |
+
Model Runtime 分三层:
|
50 |
+
|
51 |
+
- 最外层为工厂方法
|
52 |
+
|
53 |
+
提供获取所有供应商、所有模型列表、获取供应商实例、供应商/模型凭据鉴权方法。
|
54 |
+
|
55 |
+
- 第二层为供应商层
|
56 |
+
|
57 |
+
提供获取当前供应商模型列表、获取模型实例、供应商凭据鉴权、供应商配置规则信息,**可横向扩展**以支持不同的供应商。
|
58 |
+
|
59 |
+
对于供应商/模型凭据,有两种情况
|
60 |
+
- 如OpenAI这类中心化供应商,需要定义如**api_key**这类的鉴权凭据
|
61 |
+
- 如[**Xinference**](https://github.com/xorbitsai/inference)这类本地部署的供应商,需要定义如**server_url**这类的地址凭据,有时候还需要定义**model_uid**之类的模型类型凭据,就像下面这样,当在供应商层定义了这些凭据后,就可以在前端页面上直接展示,无需修改前端逻辑。
|
62 |
+

|
63 |
+
|
64 |
+
当配置好凭据后,就可以通过DifyRuntime的外部接口直接获取到对应供应商所需要的**Schema**(凭据表单规则),从而在可以在不修改前端逻辑的情况下,提供新的供应商/模型的支持。
|
65 |
+
|
66 |
+
- 最底层为模型层
|
67 |
+
|
68 |
+
提供各种模型类型的直接调用、预定义模型配置信息、获取预定义/远程模型列表、模型凭据鉴权方法,不同模型额外提供了特殊方法,如 LLM 提供预计算 tokens 方法、获取费用信息方法等,**可横向扩展**同供应商下不同的模型(支持的模型类型下)。
|
69 |
+
|
70 |
+
在这里我们需要先区分模型参数与模型凭据。
|
71 |
+
|
72 |
+
- 模型参数(**在本层定义**):这是一类经常需要变动,随时调整的参数,如 LLM 的 **max_tokens**、**temperature** 等,这些参数是由用户在前端页面上进行调整的,因此需要在后端定义参数的规则,以便前端页面进行展示和调整。在DifyRuntime中,他们的参数名一般为**model_parameters: dict[str, any]**。
|
73 |
+
|
74 |
+
- 模型凭据(**在供应商层定义**):这是一类不经常变动,一般在配置好后就不会再变动的参数,如 **api_key**、**server_url** 等。在DifyRuntime中,他们的参数名一般为**credentials: dict[str, any]**,Provider层的credentials会直接被传递到这一层,不需要再单独定义。
|
75 |
+
|
76 |
+
## 下一步
|
77 |
+
|
78 |
+
### [增加新的供应商配置 👈🏻](./docs/zh_Hans/provider_scale_out.md)
|
79 |
+
当添加后,这里将会出现一个新的供应商
|
80 |
+
|
81 |
+

|
82 |
+
|
83 |
+
### [为已存在的供应商新增模型 👈🏻](./docs/zh_Hans/provider_scale_out.md#增加模型)
|
84 |
+
当添加后,对应供应商的模型列表中将会出现一个新的预定义模型供用户选择,如GPT-3.5 GPT-4 ChatGLM3-6b等,而对于支持自定义模型的供应商,则不需要新增模型。
|
85 |
+
|
86 |
+

|
87 |
+
|
88 |
+
### [接口的具体实现 👈🏻](./docs/zh_Hans/interfaces.md)
|
89 |
+
你可以在这里找到你想要查看的接口的具体实现,以及接口的参数和返回值的具体含义。
|
api/core/model_runtime/__init__.py
ADDED
File without changes
|
api/core/model_runtime/model_providers/__base/__init__.py
ADDED
File without changes
|
api/core/model_runtime/model_providers/__base/ai_model.py
ADDED
@@ -0,0 +1,334 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import decimal
|
2 |
+
import os
|
3 |
+
from abc import ABC, abstractmethod
|
4 |
+
from typing import Optional
|
5 |
+
|
6 |
+
from pydantic import ConfigDict
|
7 |
+
|
8 |
+
from core.helper.position_helper import get_position_map, sort_by_position_map
|
9 |
+
from core.model_runtime.entities.common_entities import I18nObject
|
10 |
+
from core.model_runtime.entities.defaults import PARAMETER_RULE_TEMPLATE
|
11 |
+
from core.model_runtime.entities.model_entities import (
|
12 |
+
AIModelEntity,
|
13 |
+
DefaultParameterName,
|
14 |
+
FetchFrom,
|
15 |
+
ModelType,
|
16 |
+
PriceConfig,
|
17 |
+
PriceInfo,
|
18 |
+
PriceType,
|
19 |
+
)
|
20 |
+
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
|
21 |
+
from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer
|
22 |
+
from core.tools.utils.yaml_utils import load_yaml_file
|
23 |
+
|
24 |
+
|
25 |
+
class AIModel(ABC):
|
26 |
+
"""
|
27 |
+
Base class for all models.
|
28 |
+
"""
|
29 |
+
|
30 |
+
model_type: ModelType
|
31 |
+
model_schemas: Optional[list[AIModelEntity]] = None
|
32 |
+
started_at: float = 0
|
33 |
+
|
34 |
+
# pydantic configs
|
35 |
+
model_config = ConfigDict(protected_namespaces=())
|
36 |
+
|
37 |
+
@abstractmethod
|
38 |
+
def validate_credentials(self, model: str, credentials: dict) -> None:
|
39 |
+
"""
|
40 |
+
Validate model credentials
|
41 |
+
|
42 |
+
:param model: model name
|
43 |
+
:param credentials: model credentials
|
44 |
+
:return:
|
45 |
+
"""
|
46 |
+
raise NotImplementedError
|
47 |
+
|
48 |
+
@property
|
49 |
+
@abstractmethod
|
50 |
+
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
51 |
+
"""
|
52 |
+
Map model invoke error to unified error
|
53 |
+
The key is the error type thrown to the caller
|
54 |
+
The value is the error type thrown by the model,
|
55 |
+
which needs to be converted into a unified error type for the caller.
|
56 |
+
|
57 |
+
:return: Invoke error mapping
|
58 |
+
"""
|
59 |
+
raise NotImplementedError
|
60 |
+
|
61 |
+
def _transform_invoke_error(self, error: Exception) -> InvokeError:
|
62 |
+
"""
|
63 |
+
Transform invoke error to unified error
|
64 |
+
|
65 |
+
:param error: model invoke error
|
66 |
+
:return: unified error
|
67 |
+
"""
|
68 |
+
provider_name = self.__class__.__module__.split(".")[-3]
|
69 |
+
|
70 |
+
for invoke_error, model_errors in self._invoke_error_mapping.items():
|
71 |
+
if isinstance(error, tuple(model_errors)):
|
72 |
+
if invoke_error == InvokeAuthorizationError:
|
73 |
+
return invoke_error(
|
74 |
+
description=(
|
75 |
+
f"[{provider_name}] Incorrect model credentials provided, please check and try again."
|
76 |
+
)
|
77 |
+
)
|
78 |
+
|
79 |
+
return invoke_error(description=f"[{provider_name}] {invoke_error.description}, {str(error)}")
|
80 |
+
|
81 |
+
return InvokeError(description=f"[{provider_name}] Error: {str(error)}")
|
82 |
+
|
83 |
+
def get_price(self, model: str, credentials: dict, price_type: PriceType, tokens: int) -> PriceInfo:
|
84 |
+
"""
|
85 |
+
Get price for given model and tokens
|
86 |
+
|
87 |
+
:param model: model name
|
88 |
+
:param credentials: model credentials
|
89 |
+
:param price_type: price type
|
90 |
+
:param tokens: number of tokens
|
91 |
+
:return: price info
|
92 |
+
"""
|
93 |
+
# get model schema
|
94 |
+
model_schema = self.get_model_schema(model, credentials)
|
95 |
+
|
96 |
+
# get price info from predefined model schema
|
97 |
+
price_config: Optional[PriceConfig] = None
|
98 |
+
if model_schema and model_schema.pricing:
|
99 |
+
price_config = model_schema.pricing
|
100 |
+
|
101 |
+
# get unit price
|
102 |
+
unit_price = None
|
103 |
+
if price_config:
|
104 |
+
if price_type == PriceType.INPUT:
|
105 |
+
unit_price = price_config.input
|
106 |
+
elif price_type == PriceType.OUTPUT and price_config.output is not None:
|
107 |
+
unit_price = price_config.output
|
108 |
+
|
109 |
+
if unit_price is None:
|
110 |
+
return PriceInfo(
|
111 |
+
unit_price=decimal.Decimal("0.0"),
|
112 |
+
unit=decimal.Decimal("0.0"),
|
113 |
+
total_amount=decimal.Decimal("0.0"),
|
114 |
+
currency="USD",
|
115 |
+
)
|
116 |
+
|
117 |
+
# calculate total amount
|
118 |
+
if not price_config:
|
119 |
+
raise ValueError(f"Price config not found for model {model}")
|
120 |
+
total_amount = tokens * unit_price * price_config.unit
|
121 |
+
total_amount = total_amount.quantize(decimal.Decimal("0.0000001"), rounding=decimal.ROUND_HALF_UP)
|
122 |
+
|
123 |
+
return PriceInfo(
|
124 |
+
unit_price=unit_price,
|
125 |
+
unit=price_config.unit,
|
126 |
+
total_amount=total_amount,
|
127 |
+
currency=price_config.currency,
|
128 |
+
)
|
129 |
+
|
130 |
+
def predefined_models(self) -> list[AIModelEntity]:
|
131 |
+
"""
|
132 |
+
Get all predefined models for given provider.
|
133 |
+
|
134 |
+
:return:
|
135 |
+
"""
|
136 |
+
if self.model_schemas:
|
137 |
+
return self.model_schemas
|
138 |
+
|
139 |
+
model_schemas = []
|
140 |
+
|
141 |
+
# get module name
|
142 |
+
model_type = self.__class__.__module__.split(".")[-1]
|
143 |
+
|
144 |
+
# get provider name
|
145 |
+
provider_name = self.__class__.__module__.split(".")[-3]
|
146 |
+
|
147 |
+
# get the path of current classes
|
148 |
+
current_path = os.path.abspath(__file__)
|
149 |
+
# get parent path of the current path
|
150 |
+
provider_model_type_path = os.path.join(
|
151 |
+
os.path.dirname(os.path.dirname(current_path)), provider_name, model_type
|
152 |
+
)
|
153 |
+
|
154 |
+
# get all yaml files path under provider_model_type_path that do not start with __
|
155 |
+
model_schema_yaml_paths = [
|
156 |
+
os.path.join(provider_model_type_path, model_schema_yaml)
|
157 |
+
for model_schema_yaml in os.listdir(provider_model_type_path)
|
158 |
+
if not model_schema_yaml.startswith("__")
|
159 |
+
and not model_schema_yaml.startswith("_")
|
160 |
+
and os.path.isfile(os.path.join(provider_model_type_path, model_schema_yaml))
|
161 |
+
and model_schema_yaml.endswith(".yaml")
|
162 |
+
]
|
163 |
+
|
164 |
+
# get _position.yaml file path
|
165 |
+
position_map = get_position_map(provider_model_type_path)
|
166 |
+
|
167 |
+
# traverse all model_schema_yaml_paths
|
168 |
+
for model_schema_yaml_path in model_schema_yaml_paths:
|
169 |
+
# read yaml data from yaml file
|
170 |
+
yaml_data = load_yaml_file(model_schema_yaml_path)
|
171 |
+
|
172 |
+
new_parameter_rules = []
|
173 |
+
for parameter_rule in yaml_data.get("parameter_rules", []):
|
174 |
+
if "use_template" in parameter_rule:
|
175 |
+
try:
|
176 |
+
default_parameter_name = DefaultParameterName.value_of(parameter_rule["use_template"])
|
177 |
+
default_parameter_rule = self._get_default_parameter_rule_variable_map(default_parameter_name)
|
178 |
+
copy_default_parameter_rule = default_parameter_rule.copy()
|
179 |
+
copy_default_parameter_rule.update(parameter_rule)
|
180 |
+
parameter_rule = copy_default_parameter_rule
|
181 |
+
except ValueError:
|
182 |
+
pass
|
183 |
+
|
184 |
+
if "label" not in parameter_rule:
|
185 |
+
parameter_rule["label"] = {"zh_Hans": parameter_rule["name"], "en_US": parameter_rule["name"]}
|
186 |
+
|
187 |
+
new_parameter_rules.append(parameter_rule)
|
188 |
+
|
189 |
+
yaml_data["parameter_rules"] = new_parameter_rules
|
190 |
+
|
191 |
+
if "label" not in yaml_data:
|
192 |
+
yaml_data["label"] = {"zh_Hans": yaml_data["model"], "en_US": yaml_data["model"]}
|
193 |
+
|
194 |
+
yaml_data["fetch_from"] = FetchFrom.PREDEFINED_MODEL.value
|
195 |
+
|
196 |
+
try:
|
197 |
+
# yaml_data to entity
|
198 |
+
model_schema = AIModelEntity(**yaml_data)
|
199 |
+
except Exception as e:
|
200 |
+
model_schema_yaml_file_name = os.path.basename(model_schema_yaml_path).rstrip(".yaml")
|
201 |
+
raise Exception(
|
202 |
+
f"Invalid model schema for {provider_name}.{model_type}.{model_schema_yaml_file_name}: {str(e)}"
|
203 |
+
)
|
204 |
+
|
205 |
+
# cache model schema
|
206 |
+
model_schemas.append(model_schema)
|
207 |
+
|
208 |
+
# resort model schemas by position
|
209 |
+
model_schemas = sort_by_position_map(position_map, model_schemas, lambda x: x.model)
|
210 |
+
|
211 |
+
# cache model schemas
|
212 |
+
self.model_schemas = model_schemas
|
213 |
+
|
214 |
+
return model_schemas
|
215 |
+
|
216 |
+
def get_model_schema(self, model: str, credentials: Optional[dict] = None) -> Optional[AIModelEntity]:
|
217 |
+
"""
|
218 |
+
Get model schema by model name and credentials
|
219 |
+
|
220 |
+
:param model: model name
|
221 |
+
:param credentials: model credentials
|
222 |
+
:return: model schema
|
223 |
+
"""
|
224 |
+
# Try to get model schema from predefined models
|
225 |
+
for predefined_model in self.predefined_models():
|
226 |
+
if model == predefined_model.model:
|
227 |
+
return predefined_model
|
228 |
+
|
229 |
+
# Try to get model schema from credentials
|
230 |
+
if credentials:
|
231 |
+
model_schema = self.get_customizable_model_schema_from_credentials(model, credentials)
|
232 |
+
if model_schema:
|
233 |
+
return model_schema
|
234 |
+
|
235 |
+
return None
|
236 |
+
|
237 |
+
def get_customizable_model_schema_from_credentials(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
|
238 |
+
"""
|
239 |
+
Get customizable model schema from credentials
|
240 |
+
|
241 |
+
:param model: model name
|
242 |
+
:param credentials: model credentials
|
243 |
+
:return: model schema
|
244 |
+
"""
|
245 |
+
return self._get_customizable_model_schema(model, credentials)
|
246 |
+
|
247 |
+
def _get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
|
248 |
+
"""
|
249 |
+
Get customizable model schema and fill in the template
|
250 |
+
"""
|
251 |
+
schema = self.get_customizable_model_schema(model, credentials)
|
252 |
+
|
253 |
+
if not schema:
|
254 |
+
return None
|
255 |
+
|
256 |
+
# fill in the template
|
257 |
+
new_parameter_rules = []
|
258 |
+
for parameter_rule in schema.parameter_rules:
|
259 |
+
if parameter_rule.use_template:
|
260 |
+
try:
|
261 |
+
default_parameter_name = DefaultParameterName.value_of(parameter_rule.use_template)
|
262 |
+
default_parameter_rule = self._get_default_parameter_rule_variable_map(default_parameter_name)
|
263 |
+
if not parameter_rule.max and "max" in default_parameter_rule:
|
264 |
+
parameter_rule.max = default_parameter_rule["max"]
|
265 |
+
if not parameter_rule.min and "min" in default_parameter_rule:
|
266 |
+
parameter_rule.min = default_parameter_rule["min"]
|
267 |
+
if not parameter_rule.default and "default" in default_parameter_rule:
|
268 |
+
parameter_rule.default = default_parameter_rule["default"]
|
269 |
+
if not parameter_rule.precision and "precision" in default_parameter_rule:
|
270 |
+
parameter_rule.precision = default_parameter_rule["precision"]
|
271 |
+
if not parameter_rule.required and "required" in default_parameter_rule:
|
272 |
+
parameter_rule.required = default_parameter_rule["required"]
|
273 |
+
if not parameter_rule.help and "help" in default_parameter_rule:
|
274 |
+
parameter_rule.help = I18nObject(
|
275 |
+
en_US=default_parameter_rule["help"]["en_US"],
|
276 |
+
)
|
277 |
+
if (
|
278 |
+
parameter_rule.help
|
279 |
+
and not parameter_rule.help.en_US
|
280 |
+
and ("help" in default_parameter_rule and "en_US" in default_parameter_rule["help"])
|
281 |
+
):
|
282 |
+
parameter_rule.help.en_US = default_parameter_rule["help"]["en_US"]
|
283 |
+
if (
|
284 |
+
parameter_rule.help
|
285 |
+
and not parameter_rule.help.zh_Hans
|
286 |
+
and ("help" in default_parameter_rule and "zh_Hans" in default_parameter_rule["help"])
|
287 |
+
):
|
288 |
+
parameter_rule.help.zh_Hans = default_parameter_rule["help"].get(
|
289 |
+
"zh_Hans", default_parameter_rule["help"]["en_US"]
|
290 |
+
)
|
291 |
+
except ValueError:
|
292 |
+
pass
|
293 |
+
|
294 |
+
new_parameter_rules.append(parameter_rule)
|
295 |
+
|
296 |
+
schema.parameter_rules = new_parameter_rules
|
297 |
+
|
298 |
+
return schema
|
299 |
+
|
300 |
+
def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
|
301 |
+
"""
|
302 |
+
Get customizable model schema
|
303 |
+
|
304 |
+
:param model: model name
|
305 |
+
:param credentials: model credentials
|
306 |
+
:return: model schema
|
307 |
+
"""
|
308 |
+
return None
|
309 |
+
|
310 |
+
def _get_default_parameter_rule_variable_map(self, name: DefaultParameterName) -> dict:
|
311 |
+
"""
|
312 |
+
Get default parameter rule for given name
|
313 |
+
|
314 |
+
:param name: parameter name
|
315 |
+
:return: parameter rule
|
316 |
+
"""
|
317 |
+
default_parameter_rule = PARAMETER_RULE_TEMPLATE.get(name)
|
318 |
+
|
319 |
+
if not default_parameter_rule:
|
320 |
+
raise Exception(f"Invalid model parameter rule name {name}")
|
321 |
+
|
322 |
+
return default_parameter_rule
|
323 |
+
|
324 |
+
def _get_num_tokens_by_gpt2(self, text: str) -> int:
|
325 |
+
"""
|
326 |
+
Get number of tokens for given prompt messages by gpt2
|
327 |
+
Some provider models do not provide an interface for obtaining the number of tokens.
|
328 |
+
Here, the gpt2 tokenizer is used to calculate the number of tokens.
|
329 |
+
This method can be executed offline, and the gpt2 tokenizer has been cached in the project.
|
330 |
+
|
331 |
+
:param text: plain text of prompt. You need to convert the original message to plain text
|
332 |
+
:return: number of tokens
|
333 |
+
"""
|
334 |
+
return GPT2Tokenizer.get_num_tokens(text)
|
api/core/model_runtime/model_providers/__base/audio.mp3
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:29b714073410fefc10ecb80526b5c7c33df73b0830ff0e7778d5065a6cfcae3e
|
3 |
+
size 218880
|
api/core/model_runtime/model_providers/__base/large_language_model.py
ADDED
@@ -0,0 +1,904 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import re
|
3 |
+
import time
|
4 |
+
from abc import abstractmethod
|
5 |
+
from collections.abc import Generator, Sequence
|
6 |
+
from typing import Optional, Union
|
7 |
+
|
8 |
+
from pydantic import ConfigDict
|
9 |
+
|
10 |
+
from configs import dify_config
|
11 |
+
from core.model_runtime.callbacks.base_callback import Callback
|
12 |
+
from core.model_runtime.callbacks.logging_callback import LoggingCallback
|
13 |
+
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
|
14 |
+
from core.model_runtime.entities.message_entities import (
|
15 |
+
AssistantPromptMessage,
|
16 |
+
PromptMessage,
|
17 |
+
PromptMessageContentType,
|
18 |
+
PromptMessageTool,
|
19 |
+
SystemPromptMessage,
|
20 |
+
UserPromptMessage,
|
21 |
+
)
|
22 |
+
from core.model_runtime.entities.model_entities import (
|
23 |
+
ModelPropertyKey,
|
24 |
+
ModelType,
|
25 |
+
ParameterRule,
|
26 |
+
ParameterType,
|
27 |
+
PriceType,
|
28 |
+
)
|
29 |
+
from core.model_runtime.model_providers.__base.ai_model import AIModel
|
30 |
+
|
31 |
+
logger = logging.getLogger(__name__)
|
32 |
+
|
33 |
+
|
34 |
+
class LargeLanguageModel(AIModel):
|
35 |
+
"""
|
36 |
+
Model class for large language model.
|
37 |
+
"""
|
38 |
+
|
39 |
+
model_type: ModelType = ModelType.LLM
|
40 |
+
|
41 |
+
# pydantic configs
|
42 |
+
model_config = ConfigDict(protected_namespaces=())
|
43 |
+
|
44 |
+
def invoke(
|
45 |
+
self,
|
46 |
+
model: str,
|
47 |
+
credentials: dict,
|
48 |
+
prompt_messages: list[PromptMessage],
|
49 |
+
model_parameters: Optional[dict] = None,
|
50 |
+
tools: Optional[list[PromptMessageTool]] = None,
|
51 |
+
stop: Optional[list[str]] = None,
|
52 |
+
stream: bool = True,
|
53 |
+
user: Optional[str] = None,
|
54 |
+
callbacks: Optional[list[Callback]] = None,
|
55 |
+
) -> Union[LLMResult, Generator]:
|
56 |
+
"""
|
57 |
+
Invoke large language model
|
58 |
+
|
59 |
+
:param model: model name
|
60 |
+
:param credentials: model credentials
|
61 |
+
:param prompt_messages: prompt messages
|
62 |
+
:param model_parameters: model parameters
|
63 |
+
:param tools: tools for tool calling
|
64 |
+
:param stop: stop words
|
65 |
+
:param stream: is stream response
|
66 |
+
:param user: unique user id
|
67 |
+
:param callbacks: callbacks
|
68 |
+
:return: full response or stream response chunk generator result
|
69 |
+
"""
|
70 |
+
# validate and filter model parameters
|
71 |
+
if model_parameters is None:
|
72 |
+
model_parameters = {}
|
73 |
+
|
74 |
+
model_parameters = self._validate_and_filter_model_parameters(model, model_parameters, credentials)
|
75 |
+
|
76 |
+
self.started_at = time.perf_counter()
|
77 |
+
|
78 |
+
callbacks = callbacks or []
|
79 |
+
|
80 |
+
if dify_config.DEBUG:
|
81 |
+
callbacks.append(LoggingCallback())
|
82 |
+
|
83 |
+
# trigger before invoke callbacks
|
84 |
+
self._trigger_before_invoke_callbacks(
|
85 |
+
model=model,
|
86 |
+
credentials=credentials,
|
87 |
+
prompt_messages=prompt_messages,
|
88 |
+
model_parameters=model_parameters,
|
89 |
+
tools=tools,
|
90 |
+
stop=stop,
|
91 |
+
stream=stream,
|
92 |
+
user=user,
|
93 |
+
callbacks=callbacks,
|
94 |
+
)
|
95 |
+
|
96 |
+
try:
|
97 |
+
if "response_format" in model_parameters and model_parameters["response_format"] in {"JSON", "XML"}:
|
98 |
+
result = self._code_block_mode_wrapper(
|
99 |
+
model=model,
|
100 |
+
credentials=credentials,
|
101 |
+
prompt_messages=prompt_messages,
|
102 |
+
model_parameters=model_parameters,
|
103 |
+
tools=tools,
|
104 |
+
stop=stop,
|
105 |
+
stream=stream,
|
106 |
+
user=user,
|
107 |
+
callbacks=callbacks,
|
108 |
+
)
|
109 |
+
else:
|
110 |
+
result = self._invoke(
|
111 |
+
model=model,
|
112 |
+
credentials=credentials,
|
113 |
+
prompt_messages=prompt_messages,
|
114 |
+
model_parameters=model_parameters,
|
115 |
+
tools=tools,
|
116 |
+
stop=stop,
|
117 |
+
stream=stream,
|
118 |
+
user=user,
|
119 |
+
)
|
120 |
+
except Exception as e:
|
121 |
+
self._trigger_invoke_error_callbacks(
|
122 |
+
model=model,
|
123 |
+
ex=e,
|
124 |
+
credentials=credentials,
|
125 |
+
prompt_messages=prompt_messages,
|
126 |
+
model_parameters=model_parameters,
|
127 |
+
tools=tools,
|
128 |
+
stop=stop,
|
129 |
+
stream=stream,
|
130 |
+
user=user,
|
131 |
+
callbacks=callbacks,
|
132 |
+
)
|
133 |
+
|
134 |
+
raise self._transform_invoke_error(e)
|
135 |
+
|
136 |
+
if stream and isinstance(result, Generator):
|
137 |
+
return self._invoke_result_generator(
|
138 |
+
model=model,
|
139 |
+
result=result,
|
140 |
+
credentials=credentials,
|
141 |
+
prompt_messages=prompt_messages,
|
142 |
+
model_parameters=model_parameters,
|
143 |
+
tools=tools,
|
144 |
+
stop=stop,
|
145 |
+
stream=stream,
|
146 |
+
user=user,
|
147 |
+
callbacks=callbacks,
|
148 |
+
)
|
149 |
+
elif isinstance(result, LLMResult):
|
150 |
+
self._trigger_after_invoke_callbacks(
|
151 |
+
model=model,
|
152 |
+
result=result,
|
153 |
+
credentials=credentials,
|
154 |
+
prompt_messages=prompt_messages,
|
155 |
+
model_parameters=model_parameters,
|
156 |
+
tools=tools,
|
157 |
+
stop=stop,
|
158 |
+
stream=stream,
|
159 |
+
user=user,
|
160 |
+
callbacks=callbacks,
|
161 |
+
)
|
162 |
+
|
163 |
+
return result
|
164 |
+
|
165 |
+
def _code_block_mode_wrapper(
|
166 |
+
self,
|
167 |
+
model: str,
|
168 |
+
credentials: dict,
|
169 |
+
prompt_messages: list[PromptMessage],
|
170 |
+
model_parameters: dict,
|
171 |
+
tools: Optional[list[PromptMessageTool]] = None,
|
172 |
+
stop: Optional[Sequence[str]] = None,
|
173 |
+
stream: bool = True,
|
174 |
+
user: Optional[str] = None,
|
175 |
+
callbacks: Optional[list[Callback]] = None,
|
176 |
+
) -> Union[LLMResult, Generator]:
|
177 |
+
"""
|
178 |
+
Code block mode wrapper, ensure the response is a code block with output markdown quote
|
179 |
+
|
180 |
+
:param model: model name
|
181 |
+
:param credentials: model credentials
|
182 |
+
:param prompt_messages: prompt messages
|
183 |
+
:param model_parameters: model parameters
|
184 |
+
:param tools: tools for tool calling
|
185 |
+
:param stop: stop words
|
186 |
+
:param stream: is stream response
|
187 |
+
:param user: unique user id
|
188 |
+
:param callbacks: callbacks
|
189 |
+
:return: full response or stream response chunk generator result
|
190 |
+
"""
|
191 |
+
|
192 |
+
block_prompts = """You should always follow the instructions and output a valid {{block}} object.
|
193 |
+
The structure of the {{block}} object you can found in the instructions, use {"answer": "$your_answer"} as the default structure
|
194 |
+
if you are not sure about the structure.
|
195 |
+
|
196 |
+
<instructions>
|
197 |
+
{{instructions}}
|
198 |
+
</instructions>
|
199 |
+
""" # noqa: E501
|
200 |
+
|
201 |
+
code_block = model_parameters.get("response_format", "")
|
202 |
+
if not code_block:
|
203 |
+
return self._invoke(
|
204 |
+
model=model,
|
205 |
+
credentials=credentials,
|
206 |
+
prompt_messages=prompt_messages,
|
207 |
+
model_parameters=model_parameters,
|
208 |
+
tools=tools,
|
209 |
+
stop=stop,
|
210 |
+
stream=stream,
|
211 |
+
user=user,
|
212 |
+
)
|
213 |
+
|
214 |
+
model_parameters.pop("response_format")
|
215 |
+
stop = list(stop) if stop is not None else []
|
216 |
+
stop.extend(["\n```", "```\n"])
|
217 |
+
block_prompts = block_prompts.replace("{{block}}", code_block)
|
218 |
+
|
219 |
+
# check if there is a system message
|
220 |
+
if len(prompt_messages) > 0 and isinstance(prompt_messages[0], SystemPromptMessage):
|
221 |
+
# override the system message
|
222 |
+
prompt_messages[0] = SystemPromptMessage(
|
223 |
+
content=block_prompts.replace("{{instructions}}", str(prompt_messages[0].content))
|
224 |
+
)
|
225 |
+
else:
|
226 |
+
# insert the system message
|
227 |
+
prompt_messages.insert(
|
228 |
+
0,
|
229 |
+
SystemPromptMessage(
|
230 |
+
content=block_prompts.replace("{{instructions}}", f"Please output a valid {code_block} object.")
|
231 |
+
),
|
232 |
+
)
|
233 |
+
|
234 |
+
if len(prompt_messages) > 0 and isinstance(prompt_messages[-1], UserPromptMessage):
|
235 |
+
# add ```JSON\n to the last text message
|
236 |
+
if isinstance(prompt_messages[-1].content, str):
|
237 |
+
prompt_messages[-1].content += f"\n```{code_block}\n"
|
238 |
+
elif isinstance(prompt_messages[-1].content, list):
|
239 |
+
for i in range(len(prompt_messages[-1].content) - 1, -1, -1):
|
240 |
+
if prompt_messages[-1].content[i].type == PromptMessageContentType.TEXT:
|
241 |
+
prompt_messages[-1].content[i].data += f"\n```{code_block}\n"
|
242 |
+
break
|
243 |
+
else:
|
244 |
+
# append a user message
|
245 |
+
prompt_messages.append(UserPromptMessage(content=f"```{code_block}\n"))
|
246 |
+
|
247 |
+
response = self._invoke(
|
248 |
+
model=model,
|
249 |
+
credentials=credentials,
|
250 |
+
prompt_messages=prompt_messages,
|
251 |
+
model_parameters=model_parameters,
|
252 |
+
tools=tools,
|
253 |
+
stop=stop,
|
254 |
+
stream=stream,
|
255 |
+
user=user,
|
256 |
+
)
|
257 |
+
|
258 |
+
if isinstance(response, Generator):
|
259 |
+
first_chunk = next(response)
|
260 |
+
|
261 |
+
def new_generator():
|
262 |
+
yield first_chunk
|
263 |
+
yield from response
|
264 |
+
|
265 |
+
if first_chunk.delta.message.content and first_chunk.delta.message.content.startswith("`"):
|
266 |
+
return self._code_block_mode_stream_processor_with_backtick(
|
267 |
+
model=model, prompt_messages=prompt_messages, input_generator=new_generator()
|
268 |
+
)
|
269 |
+
else:
|
270 |
+
return self._code_block_mode_stream_processor(
|
271 |
+
model=model, prompt_messages=prompt_messages, input_generator=new_generator()
|
272 |
+
)
|
273 |
+
|
274 |
+
return response
|
275 |
+
|
276 |
+
def _code_block_mode_stream_processor(
|
277 |
+
self, model: str, prompt_messages: list[PromptMessage], input_generator: Generator[LLMResultChunk, None, None]
|
278 |
+
) -> Generator[LLMResultChunk, None, None]:
|
279 |
+
"""
|
280 |
+
Code block mode stream processor, ensure the response is a code block with output markdown quote
|
281 |
+
|
282 |
+
:param model: model name
|
283 |
+
:param prompt_messages: prompt messages
|
284 |
+
:param input_generator: input generator
|
285 |
+
:return: output generator
|
286 |
+
"""
|
287 |
+
state = "normal"
|
288 |
+
backtick_count = 0
|
289 |
+
for piece in input_generator:
|
290 |
+
if piece.delta.message.content:
|
291 |
+
content = piece.delta.message.content
|
292 |
+
piece.delta.message.content = ""
|
293 |
+
yield piece
|
294 |
+
content_piece = content
|
295 |
+
else:
|
296 |
+
yield piece
|
297 |
+
continue
|
298 |
+
new_piece: str = ""
|
299 |
+
for char in content_piece:
|
300 |
+
char = str(char)
|
301 |
+
if state == "normal":
|
302 |
+
if char == "`":
|
303 |
+
state = "in_backticks"
|
304 |
+
backtick_count = 1
|
305 |
+
else:
|
306 |
+
new_piece += char
|
307 |
+
elif state == "in_backticks":
|
308 |
+
if char == "`":
|
309 |
+
backtick_count += 1
|
310 |
+
if backtick_count == 3:
|
311 |
+
state = "skip_content"
|
312 |
+
backtick_count = 0
|
313 |
+
else:
|
314 |
+
new_piece += "`" * backtick_count + char
|
315 |
+
state = "normal"
|
316 |
+
backtick_count = 0
|
317 |
+
elif state == "skip_content":
|
318 |
+
if char.isspace():
|
319 |
+
state = "normal"
|
320 |
+
|
321 |
+
if new_piece:
|
322 |
+
yield LLMResultChunk(
|
323 |
+
model=model,
|
324 |
+
prompt_messages=prompt_messages,
|
325 |
+
delta=LLMResultChunkDelta(
|
326 |
+
index=0,
|
327 |
+
message=AssistantPromptMessage(content=new_piece, tool_calls=[]),
|
328 |
+
),
|
329 |
+
)
|
330 |
+
|
331 |
+
def _code_block_mode_stream_processor_with_backtick(
|
332 |
+
self, model: str, prompt_messages: list, input_generator: Generator[LLMResultChunk, None, None]
|
333 |
+
) -> Generator[LLMResultChunk, None, None]:
|
334 |
+
"""
|
335 |
+
Code block mode stream processor, ensure the response is a code block with output markdown quote.
|
336 |
+
This version skips the language identifier that follows the opening triple backticks.
|
337 |
+
|
338 |
+
:param model: model name
|
339 |
+
:param prompt_messages: prompt messages
|
340 |
+
:param input_generator: input generator
|
341 |
+
:return: output generator
|
342 |
+
"""
|
343 |
+
state = "search_start"
|
344 |
+
backtick_count = 0
|
345 |
+
|
346 |
+
for piece in input_generator:
|
347 |
+
if piece.delta.message.content:
|
348 |
+
content = piece.delta.message.content
|
349 |
+
# Reset content to ensure we're only processing and yielding the relevant parts
|
350 |
+
piece.delta.message.content = ""
|
351 |
+
# Yield a piece with cleared content before processing it to maintain the generator structure
|
352 |
+
yield piece
|
353 |
+
content_piece = content
|
354 |
+
else:
|
355 |
+
# Yield pieces without content directly
|
356 |
+
yield piece
|
357 |
+
continue
|
358 |
+
|
359 |
+
if state == "done":
|
360 |
+
continue
|
361 |
+
|
362 |
+
new_piece: str = ""
|
363 |
+
for char in content_piece:
|
364 |
+
if state == "search_start":
|
365 |
+
if char == "`":
|
366 |
+
backtick_count += 1
|
367 |
+
if backtick_count == 3:
|
368 |
+
state = "skip_language"
|
369 |
+
backtick_count = 0
|
370 |
+
else:
|
371 |
+
backtick_count = 0
|
372 |
+
elif state == "skip_language":
|
373 |
+
# Skip everything until the first newline, marking the end of the language identifier
|
374 |
+
if char == "\n":
|
375 |
+
state = "in_code_block"
|
376 |
+
elif state == "in_code_block":
|
377 |
+
if char == "`":
|
378 |
+
backtick_count += 1
|
379 |
+
if backtick_count == 3:
|
380 |
+
state = "done"
|
381 |
+
break
|
382 |
+
else:
|
383 |
+
if backtick_count > 0:
|
384 |
+
# If backticks were counted but we're still collecting content, it was a false start
|
385 |
+
new_piece += "`" * backtick_count
|
386 |
+
backtick_count = 0
|
387 |
+
new_piece += str(char)
|
388 |
+
|
389 |
+
elif state == "done":
|
390 |
+
break
|
391 |
+
|
392 |
+
if new_piece:
|
393 |
+
# Only yield content collected within the code block
|
394 |
+
yield LLMResultChunk(
|
395 |
+
model=model,
|
396 |
+
prompt_messages=prompt_messages,
|
397 |
+
delta=LLMResultChunkDelta(
|
398 |
+
index=0,
|
399 |
+
message=AssistantPromptMessage(content=new_piece, tool_calls=[]),
|
400 |
+
),
|
401 |
+
)
|
402 |
+
|
403 |
+
def _wrap_thinking_by_reasoning_content(self, delta: dict, is_reasoning: bool) -> tuple[str, bool]:
|
404 |
+
"""
|
405 |
+
If the reasoning response is from delta.get("reasoning_content"), we wrap
|
406 |
+
it with HTML think tag.
|
407 |
+
|
408 |
+
:param delta: delta dictionary from LLM streaming response
|
409 |
+
:param is_reasoning: is reasoning
|
410 |
+
:return: tuple of (processed_content, is_reasoning)
|
411 |
+
"""
|
412 |
+
|
413 |
+
content = delta.get("content") or ""
|
414 |
+
reasoning_content = delta.get("reasoning_content")
|
415 |
+
|
416 |
+
if reasoning_content:
|
417 |
+
if not is_reasoning:
|
418 |
+
content = "<think>\n" + reasoning_content
|
419 |
+
is_reasoning = True
|
420 |
+
else:
|
421 |
+
content = reasoning_content
|
422 |
+
elif is_reasoning and content:
|
423 |
+
# do not end reasoning when content is empty
|
424 |
+
# there may be more reasoning_content later that follows previous reasoning closely
|
425 |
+
content = "\n</think>" + content
|
426 |
+
is_reasoning = False
|
427 |
+
return content, is_reasoning
|
428 |
+
|
429 |
+
def _invoke_result_generator(
|
430 |
+
self,
|
431 |
+
model: str,
|
432 |
+
result: Generator,
|
433 |
+
credentials: dict,
|
434 |
+
prompt_messages: list[PromptMessage],
|
435 |
+
model_parameters: dict,
|
436 |
+
tools: Optional[list[PromptMessageTool]] = None,
|
437 |
+
stop: Optional[Sequence[str]] = None,
|
438 |
+
stream: bool = True,
|
439 |
+
user: Optional[str] = None,
|
440 |
+
callbacks: Optional[list[Callback]] = None,
|
441 |
+
) -> Generator:
|
442 |
+
"""
|
443 |
+
Invoke result generator
|
444 |
+
|
445 |
+
:param result: result generator
|
446 |
+
:return: result generator
|
447 |
+
"""
|
448 |
+
callbacks = callbacks or []
|
449 |
+
prompt_message = AssistantPromptMessage(content="")
|
450 |
+
usage = None
|
451 |
+
system_fingerprint = None
|
452 |
+
real_model = model
|
453 |
+
|
454 |
+
try:
|
455 |
+
for chunk in result:
|
456 |
+
yield chunk
|
457 |
+
|
458 |
+
self._trigger_new_chunk_callbacks(
|
459 |
+
chunk=chunk,
|
460 |
+
model=model,
|
461 |
+
credentials=credentials,
|
462 |
+
prompt_messages=prompt_messages,
|
463 |
+
model_parameters=model_parameters,
|
464 |
+
tools=tools,
|
465 |
+
stop=stop,
|
466 |
+
stream=stream,
|
467 |
+
user=user,
|
468 |
+
callbacks=callbacks,
|
469 |
+
)
|
470 |
+
|
471 |
+
prompt_message.content += chunk.delta.message.content
|
472 |
+
real_model = chunk.model
|
473 |
+
if chunk.delta.usage:
|
474 |
+
usage = chunk.delta.usage
|
475 |
+
|
476 |
+
if chunk.system_fingerprint:
|
477 |
+
system_fingerprint = chunk.system_fingerprint
|
478 |
+
except Exception as e:
|
479 |
+
raise self._transform_invoke_error(e)
|
480 |
+
|
481 |
+
self._trigger_after_invoke_callbacks(
|
482 |
+
model=model,
|
483 |
+
result=LLMResult(
|
484 |
+
model=real_model,
|
485 |
+
prompt_messages=prompt_messages,
|
486 |
+
message=prompt_message,
|
487 |
+
usage=usage or LLMUsage.empty_usage(),
|
488 |
+
system_fingerprint=system_fingerprint,
|
489 |
+
),
|
490 |
+
credentials=credentials,
|
491 |
+
prompt_messages=prompt_messages,
|
492 |
+
model_parameters=model_parameters,
|
493 |
+
tools=tools,
|
494 |
+
stop=stop,
|
495 |
+
stream=stream,
|
496 |
+
user=user,
|
497 |
+
callbacks=callbacks,
|
498 |
+
)
|
499 |
+
|
500 |
+
@abstractmethod
|
501 |
+
def _invoke(
|
502 |
+
self,
|
503 |
+
model: str,
|
504 |
+
credentials: dict,
|
505 |
+
prompt_messages: list[PromptMessage],
|
506 |
+
model_parameters: dict,
|
507 |
+
tools: Optional[list[PromptMessageTool]] = None,
|
508 |
+
stop: Optional[Sequence[str]] = None,
|
509 |
+
stream: bool = True,
|
510 |
+
user: Optional[str] = None,
|
511 |
+
) -> Union[LLMResult, Generator]:
|
512 |
+
"""
|
513 |
+
Invoke large language model
|
514 |
+
|
515 |
+
:param model: model name
|
516 |
+
:param credentials: model credentials
|
517 |
+
:param prompt_messages: prompt messages
|
518 |
+
:param model_parameters: model parameters
|
519 |
+
:param tools: tools for tool calling
|
520 |
+
:param stop: stop words
|
521 |
+
:param stream: is stream response
|
522 |
+
:param user: unique user id
|
523 |
+
:return: full response or stream response chunk generator result
|
524 |
+
"""
|
525 |
+
raise NotImplementedError
|
526 |
+
|
527 |
+
@abstractmethod
|
528 |
+
def get_num_tokens(
|
529 |
+
self,
|
530 |
+
model: str,
|
531 |
+
credentials: dict,
|
532 |
+
prompt_messages: list[PromptMessage],
|
533 |
+
tools: Optional[list[PromptMessageTool]] = None,
|
534 |
+
) -> int:
|
535 |
+
"""
|
536 |
+
Get number of tokens for given prompt messages
|
537 |
+
|
538 |
+
:param model: model name
|
539 |
+
:param credentials: model credentials
|
540 |
+
:param prompt_messages: prompt messages
|
541 |
+
:param tools: tools for tool calling
|
542 |
+
:return:
|
543 |
+
"""
|
544 |
+
raise NotImplementedError
|
545 |
+
|
546 |
+
def enforce_stop_tokens(self, text: str, stop: list[str]) -> str:
|
547 |
+
"""Cut off the text as soon as any stop words occur."""
|
548 |
+
return re.split("|".join(stop), text, maxsplit=1)[0]
|
549 |
+
|
550 |
+
def get_parameter_rules(self, model: str, credentials: dict) -> list[ParameterRule]:
|
551 |
+
"""
|
552 |
+
Get parameter rules
|
553 |
+
|
554 |
+
:param model: model name
|
555 |
+
:param credentials: model credentials
|
556 |
+
:return: parameter rules
|
557 |
+
"""
|
558 |
+
model_schema = self.get_model_schema(model, credentials)
|
559 |
+
if model_schema:
|
560 |
+
return model_schema.parameter_rules
|
561 |
+
|
562 |
+
return []
|
563 |
+
|
564 |
+
def get_model_mode(self, model: str, credentials: Optional[dict] = None) -> LLMMode:
|
565 |
+
"""
|
566 |
+
Get model mode
|
567 |
+
|
568 |
+
:param model: model name
|
569 |
+
:param credentials: model credentials
|
570 |
+
:return: model mode
|
571 |
+
"""
|
572 |
+
model_schema = self.get_model_schema(model, credentials)
|
573 |
+
|
574 |
+
mode = LLMMode.CHAT
|
575 |
+
if model_schema and model_schema.model_properties.get(ModelPropertyKey.MODE):
|
576 |
+
mode = LLMMode.value_of(model_schema.model_properties[ModelPropertyKey.MODE])
|
577 |
+
|
578 |
+
return mode
|
579 |
+
|
580 |
+
def _calc_response_usage(
|
581 |
+
self, model: str, credentials: dict, prompt_tokens: int, completion_tokens: int
|
582 |
+
) -> LLMUsage:
|
583 |
+
"""
|
584 |
+
Calculate response usage
|
585 |
+
|
586 |
+
:param model: model name
|
587 |
+
:param credentials: model credentials
|
588 |
+
:param prompt_tokens: prompt tokens
|
589 |
+
:param completion_tokens: completion tokens
|
590 |
+
:return: usage
|
591 |
+
"""
|
592 |
+
# get prompt price info
|
593 |
+
prompt_price_info = self.get_price(
|
594 |
+
model=model,
|
595 |
+
credentials=credentials,
|
596 |
+
price_type=PriceType.INPUT,
|
597 |
+
tokens=prompt_tokens,
|
598 |
+
)
|
599 |
+
|
600 |
+
# get completion price info
|
601 |
+
completion_price_info = self.get_price(
|
602 |
+
model=model, credentials=credentials, price_type=PriceType.OUTPUT, tokens=completion_tokens
|
603 |
+
)
|
604 |
+
|
605 |
+
# transform usage
|
606 |
+
usage = LLMUsage(
|
607 |
+
            prompt_tokens=prompt_tokens,
            prompt_unit_price=prompt_price_info.unit_price,
            prompt_price_unit=prompt_price_info.unit,
            prompt_price=prompt_price_info.total_amount,
            completion_tokens=completion_tokens,
            completion_unit_price=completion_price_info.unit_price,
            completion_price_unit=completion_price_info.unit,
            completion_price=completion_price_info.total_amount,
            total_tokens=prompt_tokens + completion_tokens,
            total_price=prompt_price_info.total_amount + completion_price_info.total_amount,
            currency=prompt_price_info.currency,
            latency=time.perf_counter() - self.started_at,
        )

        return usage

    def _trigger_before_invoke_callbacks(
        self,
        model: str,
        credentials: dict,
        prompt_messages: list[PromptMessage],
        model_parameters: dict,
        tools: Optional[list[PromptMessageTool]] = None,
        stop: Optional[Sequence[str]] = None,
        stream: bool = True,
        user: Optional[str] = None,
        callbacks: Optional[list[Callback]] = None,
    ) -> None:
        """
        Trigger before invoke callbacks

        :param model: model name
        :param credentials: model credentials
        :param prompt_messages: prompt messages
        :param model_parameters: model parameters
        :param tools: tools for tool calling
        :param stop: stop words
        :param stream: is stream response
        :param user: unique user id
        :param callbacks: callbacks
        """
        if callbacks:
            for callback in callbacks:
                try:
                    callback.on_before_invoke(
                        llm_instance=self,
                        model=model,
                        credentials=credentials,
                        prompt_messages=prompt_messages,
                        model_parameters=model_parameters,
                        tools=tools,
                        stop=stop,
                        stream=stream,
                        user=user,
                    )
                except Exception as e:
                    if callback.raise_error:
                        raise e
                    else:
                        logger.warning(f"Callback {callback.__class__.__name__} on_before_invoke failed with error {e}")

    def _trigger_new_chunk_callbacks(
        self,
        chunk: LLMResultChunk,
        model: str,
        credentials: dict,
        prompt_messages: list[PromptMessage],
        model_parameters: dict,
        tools: Optional[list[PromptMessageTool]] = None,
        stop: Optional[Sequence[str]] = None,
        stream: bool = True,
        user: Optional[str] = None,
        callbacks: Optional[list[Callback]] = None,
    ) -> None:
        """
        Trigger new chunk callbacks

        :param chunk: chunk
        :param model: model name
        :param credentials: model credentials
        :param prompt_messages: prompt messages
        :param model_parameters: model parameters
        :param tools: tools for tool calling
        :param stop: stop words
        :param stream: is stream response
        :param user: unique user id
        """
        if callbacks:
            for callback in callbacks:
                try:
                    callback.on_new_chunk(
                        llm_instance=self,
                        chunk=chunk,
                        model=model,
                        credentials=credentials,
                        prompt_messages=prompt_messages,
                        model_parameters=model_parameters,
                        tools=tools,
                        stop=stop,
                        stream=stream,
                        user=user,
                    )
                except Exception as e:
                    if callback.raise_error:
                        raise e
                    else:
                        logger.warning(f"Callback {callback.__class__.__name__} on_new_chunk failed with error {e}")

    def _trigger_after_invoke_callbacks(
        self,
        model: str,
        result: LLMResult,
        credentials: dict,
        prompt_messages: list[PromptMessage],
        model_parameters: dict,
        tools: Optional[list[PromptMessageTool]] = None,
        stop: Optional[Sequence[str]] = None,
        stream: bool = True,
        user: Optional[str] = None,
        callbacks: Optional[list[Callback]] = None,
    ) -> None:
        """
        Trigger after invoke callbacks

        :param model: model name
        :param result: result
        :param credentials: model credentials
        :param prompt_messages: prompt messages
        :param model_parameters: model parameters
        :param tools: tools for tool calling
        :param stop: stop words
        :param stream: is stream response
        :param user: unique user id
        :param callbacks: callbacks
        """
        if callbacks:
            for callback in callbacks:
                try:
                    callback.on_after_invoke(
                        llm_instance=self,
                        result=result,
                        model=model,
                        credentials=credentials,
                        prompt_messages=prompt_messages,
                        model_parameters=model_parameters,
                        tools=tools,
                        stop=stop,
                        stream=stream,
                        user=user,
                    )
                except Exception as e:
                    if callback.raise_error:
                        raise e
                    else:
                        logger.warning(f"Callback {callback.__class__.__name__} on_after_invoke failed with error {e}")

    def _trigger_invoke_error_callbacks(
        self,
        model: str,
        ex: Exception,
        credentials: dict,
        prompt_messages: list[PromptMessage],
        model_parameters: dict,
        tools: Optional[list[PromptMessageTool]] = None,
        stop: Optional[Sequence[str]] = None,
        stream: bool = True,
        user: Optional[str] = None,
        callbacks: Optional[list[Callback]] = None,
    ) -> None:
        """
        Trigger invoke error callbacks

        :param model: model name
        :param ex: exception
        :param credentials: model credentials
        :param prompt_messages: prompt messages
        :param model_parameters: model parameters
        :param tools: tools for tool calling
        :param stop: stop words
        :param stream: is stream response
        :param user: unique user id
        :param callbacks: callbacks
        """
        if callbacks:
            for callback in callbacks:
                try:
                    callback.on_invoke_error(
                        llm_instance=self,
                        ex=ex,
                        model=model,
                        credentials=credentials,
                        prompt_messages=prompt_messages,
                        model_parameters=model_parameters,
                        tools=tools,
                        stop=stop,
                        stream=stream,
                        user=user,
                    )
                except Exception as e:
                    if callback.raise_error:
                        raise e
                    else:
                        logger.warning(f"Callback {callback.__class__.__name__} on_invoke_error failed with error {e}")

    def _validate_and_filter_model_parameters(self, model: str, model_parameters: dict, credentials: dict) -> dict:
        """
        Validate model parameters

        :param model: model name
        :param model_parameters: model parameters
        :param credentials: model credentials
        :return:
        """
        parameter_rules = self.get_parameter_rules(model, credentials)

        # validate model parameters
        filtered_model_parameters = {}
        for parameter_rule in parameter_rules:
            parameter_name = parameter_rule.name
            parameter_value = model_parameters.get(parameter_name)
            if parameter_value is None:
                if parameter_rule.use_template and parameter_rule.use_template in model_parameters:
                    # if parameter value is None, use template value variable name instead
                    parameter_value = model_parameters[parameter_rule.use_template]
                else:
                    if parameter_rule.required:
                        if parameter_rule.default is not None:
                            filtered_model_parameters[parameter_name] = parameter_rule.default
                            continue
                        else:
                            raise ValueError(f"Model Parameter {parameter_name} is required.")
                    else:
                        continue

            # validate parameter value type
            if parameter_rule.type == ParameterType.INT:
                if not isinstance(parameter_value, int):
                    raise ValueError(f"Model Parameter {parameter_name} should be int.")

                # validate parameter value range
                if parameter_rule.min is not None and parameter_value < parameter_rule.min:
                    raise ValueError(
                        f"Model Parameter {parameter_name} should be greater than or equal to {parameter_rule.min}."
                    )

                if parameter_rule.max is not None and parameter_value > parameter_rule.max:
                    raise ValueError(
                        f"Model Parameter {parameter_name} should be less than or equal to {parameter_rule.max}."
                    )
            elif parameter_rule.type == ParameterType.FLOAT:
                if not isinstance(parameter_value, float | int):
                    raise ValueError(f"Model Parameter {parameter_name} should be float.")

                # validate parameter value precision
                if parameter_rule.precision is not None:
                    if parameter_rule.precision == 0:
                        if parameter_value != int(parameter_value):
                            raise ValueError(f"Model Parameter {parameter_name} should be int.")
                    else:
                        if parameter_value != round(parameter_value, parameter_rule.precision):
                            raise ValueError(
                                f"Model Parameter {parameter_name} should be round to {parameter_rule.precision}"
                                f" decimal places."
                            )

                # validate parameter value range
                if parameter_rule.min is not None and parameter_value < parameter_rule.min:
                    raise ValueError(
                        f"Model Parameter {parameter_name} should be greater than or equal to {parameter_rule.min}."
                    )

                if parameter_rule.max is not None and parameter_value > parameter_rule.max:
                    raise ValueError(
                        f"Model Parameter {parameter_name} should be less than or equal to {parameter_rule.max}."
                    )
            elif parameter_rule.type == ParameterType.BOOLEAN:
                if not isinstance(parameter_value, bool):
                    raise ValueError(f"Model Parameter {parameter_name} should be bool.")
            elif parameter_rule.type == ParameterType.STRING:
                if not isinstance(parameter_value, str):
                    raise ValueError(f"Model Parameter {parameter_name} should be string.")

                # validate options
                if parameter_rule.options and parameter_value not in parameter_rule.options:
                    raise ValueError(f"Model Parameter {parameter_name} should be one of {parameter_rule.options}.")
            elif parameter_rule.type == ParameterType.TEXT:
                if not isinstance(parameter_value, str):
                    raise ValueError(f"Model Parameter {parameter_name} should be text.")

                # validate options
                if parameter_rule.options and parameter_value not in parameter_rule.options:
                    raise ValueError(f"Model Parameter {parameter_name} should be one of {parameter_rule.options}.")
            else:
                raise ValueError(f"Model Parameter {parameter_name} type {parameter_rule.type} is not supported.")

            filtered_model_parameters[parameter_name] = parameter_value

        return filtered_model_parameters
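The _trigger_*_callbacks hooks above all follow the same pattern: iterate the registered callbacks, call the matching handler, and either re-raise or only log depending on callback.raise_error. As a rough illustration (not part of this diff, and omitting any other members the Callback base class may require), a logging callback driven by these hooks could look like the sketch below; the handler names and keyword arguments mirror exactly what the hooks pass.

# Illustrative sketch only: a minimal Callback that the _trigger_* hooks above would drive.
from core.model_runtime.callbacks.base_callback import Callback


class PrintingCallback(Callback):
    raise_error = False  # let the hooks log failures instead of aborting the invocation

    def on_before_invoke(self, llm_instance, model, credentials, prompt_messages,
                         model_parameters, tools=None, stop=None, stream=True, user=None):
        print(f"invoking {model} with {len(prompt_messages)} prompt messages")

    def on_new_chunk(self, llm_instance, chunk, model, credentials, prompt_messages,
                     model_parameters, tools=None, stop=None, stream=True, user=None):
        print("received a streamed chunk")

    def on_after_invoke(self, llm_instance, result, model, credentials, prompt_messages,
                        model_parameters, tools=None, stop=None, stream=True, user=None):
        print(f"finished; usage: {result.usage}")

    def on_invoke_error(self, llm_instance, ex, model, credentials, prompt_messages,
                        model_parameters, tools=None, stop=None, stream=True, user=None):
        print(f"invocation failed: {ex}")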
api/core/model_runtime/model_providers/__base/model_provider.py
ADDED
@@ -0,0 +1,121 @@
import os
from abc import ABC, abstractmethod
from typing import Optional

from core.helper.module_import_helper import get_subclasses_from_module, import_module_from_source
from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
from core.model_runtime.entities.provider_entities import ProviderEntity
from core.model_runtime.model_providers.__base.ai_model import AIModel
from core.tools.utils.yaml_utils import load_yaml_file


class ModelProvider(ABC):
    provider_schema: Optional[ProviderEntity] = None
    model_instance_map: dict[str, AIModel] = {}

    @abstractmethod
    def validate_provider_credentials(self, credentials: dict) -> None:
        """
        Validate provider credentials
        You can choose any validate_credentials method of model type or implement validate method by yourself,
        such as: get model list api

        if validate failed, raise exception

        :param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
        """
        raise NotImplementedError

    def get_provider_schema(self) -> ProviderEntity:
        """
        Get provider schema

        :return: provider schema
        """
        if self.provider_schema:
            return self.provider_schema

        # get dirname of the current path
        provider_name = self.__class__.__module__.split(".")[-1]

        # get the path of the model_provider classes
        base_path = os.path.abspath(__file__)
        current_path = os.path.join(os.path.dirname(os.path.dirname(base_path)), provider_name)

        # read provider schema from yaml file
        yaml_path = os.path.join(current_path, f"{provider_name}.yaml")
        yaml_data = load_yaml_file(yaml_path)

        try:
            # yaml_data to entity
            provider_schema = ProviderEntity(**yaml_data)
        except Exception as e:
            raise Exception(f"Invalid provider schema for {provider_name}: {str(e)}")

        # cache schema
        self.provider_schema = provider_schema

        return provider_schema

    def models(self, model_type: ModelType) -> list[AIModelEntity]:
        """
        Get all models for given model type

        :param model_type: model type defined in `ModelType`
        :return: list of models
        """
        provider_schema = self.get_provider_schema()
        if model_type not in provider_schema.supported_model_types:
            return []

        # get model instance of the model type
        model_instance = self.get_model_instance(model_type)

        # get predefined models (predefined_models)
        models = model_instance.predefined_models()

        # return models
        return models

    def get_model_instance(self, model_type: ModelType) -> AIModel:
        """
        Get model instance

        :param model_type: model type defined in `ModelType`
        :return:
        """
        # get dirname of the current path
        provider_name = self.__class__.__module__.split(".")[-1]

        if f"{provider_name}.{model_type.value}" in self.model_instance_map:
            return self.model_instance_map[f"{provider_name}.{model_type.value}"]

        # get the path of the model type classes
        base_path = os.path.abspath(__file__)
        model_type_name = model_type.value.replace("-", "_")
        model_type_path = os.path.join(os.path.dirname(os.path.dirname(base_path)), provider_name, model_type_name)
        model_type_py_path = os.path.join(model_type_path, f"{model_type_name}.py")

        if not os.path.isdir(model_type_path) or not os.path.exists(model_type_py_path):
            raise Exception(f"Invalid model type {model_type} for provider {provider_name}")

        # Dynamic loading {model_type_name}.py file and find the subclass of AIModel
        parent_module = ".".join(self.__class__.__module__.split(".")[:-1])
        mod = import_module_from_source(
            module_name=f"{parent_module}.{model_type_name}.{model_type_name}", py_file_path=model_type_py_path
        )
        # FIXME "type" has no attribute "__abstractmethods__" ignore it for now fix it later
        model_class = next(
            filter(
                lambda x: x.__module__ == mod.__name__ and not x.__abstractmethods__,  # type: ignore
                get_subclasses_from_module(mod, AIModel),
            ),
            None,
        )
        if not model_class:
            raise Exception(f"Missing AIModel Class for model type {model_type} in {model_type_py_path}")

        model_instance_map = model_class()
        self.model_instance_map[f"{provider_name}.{model_type.value}"] = model_instance_map

        return model_instance_map
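A rough usage sketch (not part of the diff): given a concrete provider package laid out as this base class expects, that is a <provider>.yaml schema next to one module per model type, the base class resolves and caches model instances as shown below. The Anthropic provider is used only as an example, and the attribute names on the returned entities are assumptions based on the entities imported above.

# Illustrative only: how a caller might use ModelProvider.get_provider_schema / get_model_instance / models.
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.model_providers.anthropic.anthropic import AnthropicProvider

provider = AnthropicProvider()
schema = provider.get_provider_schema()           # parsed from anthropic/anthropic.yaml and cached on the class
llm = provider.get_model_instance(ModelType.LLM)  # dynamically imports anthropic/llm/llm.py and caches the instance
print(schema.provider, [m.model for m in provider.models(ModelType.LLM)])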
api/core/model_runtime/model_providers/__base/moderation_model.py
ADDED
@@ -0,0 +1,49 @@
import time
from abc import abstractmethod
from typing import Optional

from pydantic import ConfigDict

from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.model_providers.__base.ai_model import AIModel


class ModerationModel(AIModel):
    """
    Model class for moderation model.
    """

    model_type: ModelType = ModelType.MODERATION

    # pydantic configs
    model_config = ConfigDict(protected_namespaces=())

    def invoke(self, model: str, credentials: dict, text: str, user: Optional[str] = None) -> bool:
        """
        Invoke moderation model

        :param model: model name
        :param credentials: model credentials
        :param text: text to moderate
        :param user: unique user id
        :return: false if text is safe, true otherwise
        """
        self.started_at = time.perf_counter()

        try:
            return self._invoke(model, credentials, text, user)
        except Exception as e:
            raise self._transform_invoke_error(e)

    @abstractmethod
    def _invoke(self, model: str, credentials: dict, text: str, user: Optional[str] = None) -> bool:
        """
        Invoke moderation model

        :param model: model name
        :param credentials: model credentials
        :param text: text to moderate
        :param user: unique user id
        :return: false if text is safe, true otherwise
        """
        raise NotImplementedError
api/core/model_runtime/model_providers/__base/rerank_model.py
ADDED
@@ -0,0 +1,69 @@
import time
from abc import abstractmethod
from typing import Optional

from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.entities.rerank_entities import RerankResult
from core.model_runtime.model_providers.__base.ai_model import AIModel


class RerankModel(AIModel):
    """
    Base Model class for rerank model.
    """

    model_type: ModelType = ModelType.RERANK

    def invoke(
        self,
        model: str,
        credentials: dict,
        query: str,
        docs: list[str],
        score_threshold: Optional[float] = None,
        top_n: Optional[int] = None,
        user: Optional[str] = None,
    ) -> RerankResult:
        """
        Invoke rerank model

        :param model: model name
        :param credentials: model credentials
        :param query: search query
        :param docs: docs for reranking
        :param score_threshold: score threshold
        :param top_n: top n
        :param user: unique user id
        :return: rerank result
        """
        self.started_at = time.perf_counter()

        try:
            return self._invoke(model, credentials, query, docs, score_threshold, top_n, user)
        except Exception as e:
            raise self._transform_invoke_error(e)

    @abstractmethod
    def _invoke(
        self,
        model: str,
        credentials: dict,
        query: str,
        docs: list[str],
        score_threshold: Optional[float] = None,
        top_n: Optional[int] = None,
        user: Optional[str] = None,
    ) -> RerankResult:
        """
        Invoke rerank model

        :param model: model name
        :param credentials: model credentials
        :param query: search query
        :param docs: docs for reranking
        :param score_threshold: score threshold
        :param top_n: top n
        :param user: unique user id
        :return: rerank result
        """
        raise NotImplementedError
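As a sketch of how a provider plugs into this base class, the hypothetical subclass below scores each document, filters by score_threshold and keeps top_n. It is not from the diff: the scoring logic is made up, the RerankDocument fields are assumed from the rerank_entities module imported above, and the other abstract members of AIModel are omitted for brevity.

# Hypothetical subclass, for illustration only.
from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult
from core.model_runtime.model_providers.__base.rerank_model import RerankModel


class KeywordOverlapRerankModel(RerankModel):
    def _invoke(self, model, credentials, query, docs, score_threshold=None, top_n=None, user=None):
        query_terms = set(query.lower().split())
        scored = []
        for index, doc in enumerate(docs):
            overlap = len(query_terms & set(doc.lower().split()))
            score = overlap / (len(query_terms) or 1)
            if score_threshold is None or score >= score_threshold:
                scored.append(RerankDocument(index=index, text=doc, score=score))
        scored.sort(key=lambda d: d.score, reverse=True)
        return RerankResult(model=model, docs=scored[:top_n] if top_n else scored)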
api/core/model_runtime/model_providers/__base/speech2text_model.py
ADDED
@@ -0,0 +1,59 @@
import os
from abc import abstractmethod
from typing import IO, Optional

from pydantic import ConfigDict

from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.model_providers.__base.ai_model import AIModel


class Speech2TextModel(AIModel):
    """
    Model class for speech2text model.
    """

    model_type: ModelType = ModelType.SPEECH2TEXT

    # pydantic configs
    model_config = ConfigDict(protected_namespaces=())

    def invoke(self, model: str, credentials: dict, file: IO[bytes], user: Optional[str] = None) -> str:
        """
        Invoke speech2text model

        :param model: model name
        :param credentials: model credentials
        :param file: audio file
        :param user: unique user id
        :return: text for given audio file
        """
        try:
            return self._invoke(model, credentials, file, user)
        except Exception as e:
            raise self._transform_invoke_error(e)

    @abstractmethod
    def _invoke(self, model: str, credentials: dict, file: IO[bytes], user: Optional[str] = None) -> str:
        """
        Invoke speech2text model

        :param model: model name
        :param credentials: model credentials
        :param file: audio file
        :param user: unique user id
        :return: text for given audio file
        """
        raise NotImplementedError

    def _get_demo_file_path(self) -> str:
        """
        Get demo file for given model

        :return: demo file
        """
        # Get the directory of the current file
        current_dir = os.path.dirname(os.path.abspath(__file__))

        # Construct the path to the audio file
        return os.path.join(current_dir, "audio.mp3")
api/core/model_runtime/model_providers/__base/text2img_model.py
ADDED
@@ -0,0 +1,54 @@
from abc import abstractmethod
from typing import IO, Optional

from pydantic import ConfigDict

from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.model_providers.__base.ai_model import AIModel


class Text2ImageModel(AIModel):
    """
    Model class for text2img model.
    """

    model_type: ModelType = ModelType.TEXT2IMG

    # pydantic configs
    model_config = ConfigDict(protected_namespaces=())

    def invoke(
        self, model: str, credentials: dict, prompt: str, model_parameters: dict, user: Optional[str] = None
    ) -> list[IO[bytes]]:
        """
        Invoke Text2Image model

        :param model: model name
        :param credentials: model credentials
        :param prompt: prompt for image generation
        :param model_parameters: model parameters
        :param user: unique user id

        :return: image bytes
        """
        try:
            return self._invoke(model, credentials, prompt, model_parameters, user)
        except Exception as e:
            raise self._transform_invoke_error(e)

    @abstractmethod
    def _invoke(
        self, model: str, credentials: dict, prompt: str, model_parameters: dict, user: Optional[str] = None
    ) -> list[IO[bytes]]:
        """
        Invoke Text2Image model

        :param model: model name
        :param credentials: model credentials
        :param prompt: prompt for image generation
        :param model_parameters: model parameters
        :param user: unique user id

        :return: image bytes
        """
        raise NotImplementedError
api/core/model_runtime/model_providers/__base/text_embedding_model.py
ADDED
@@ -0,0 +1,111 @@
import time
from abc import abstractmethod
from typing import Optional

from pydantic import ConfigDict

from core.entities.embedding_type import EmbeddingInputType
from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
from core.model_runtime.model_providers.__base.ai_model import AIModel


class TextEmbeddingModel(AIModel):
    """
    Model class for text embedding model.
    """

    model_type: ModelType = ModelType.TEXT_EMBEDDING

    # pydantic configs
    model_config = ConfigDict(protected_namespaces=())

    def invoke(
        self,
        model: str,
        credentials: dict,
        texts: list[str],
        user: Optional[str] = None,
        input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
    ) -> TextEmbeddingResult:
        """
        Invoke text embedding model

        :param model: model name
        :param credentials: model credentials
        :param texts: texts to embed
        :param user: unique user id
        :param input_type: input type
        :return: embeddings result
        """
        self.started_at = time.perf_counter()

        try:
            return self._invoke(model, credentials, texts, user, input_type)
        except Exception as e:
            raise self._transform_invoke_error(e)

    @abstractmethod
    def _invoke(
        self,
        model: str,
        credentials: dict,
        texts: list[str],
        user: Optional[str] = None,
        input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
    ) -> TextEmbeddingResult:
        """
        Invoke text embedding model

        :param model: model name
        :param credentials: model credentials
        :param texts: texts to embed
        :param user: unique user id
        :param input_type: input type
        :return: embeddings result
        """
        raise NotImplementedError

    @abstractmethod
    def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
        """
        Get number of tokens for given prompt messages

        :param model: model name
        :param credentials: model credentials
        :param texts: texts to embed
        :return:
        """
        raise NotImplementedError

    def _get_context_size(self, model: str, credentials: dict) -> int:
        """
        Get context size for given embedding model

        :param model: model name
        :param credentials: model credentials
        :return: context size
        """
        model_schema = self.get_model_schema(model, credentials)

        if model_schema and ModelPropertyKey.CONTEXT_SIZE in model_schema.model_properties:
            content_size: int = model_schema.model_properties[ModelPropertyKey.CONTEXT_SIZE]
            return content_size

        return 1000

    def _get_max_chunks(self, model: str, credentials: dict) -> int:
        """
        Get max chunks for given embedding model

        :param model: model name
        :param credentials: model credentials
        :return: max chunks
        """
        model_schema = self.get_model_schema(model, credentials)

        if model_schema and ModelPropertyKey.MAX_CHUNKS in model_schema.model_properties:
            max_chunks: int = model_schema.model_properties[ModelPropertyKey.MAX_CHUNKS]
            return max_chunks

        return 1
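The _get_max_chunks helper above is what callers would use to batch texts before invoking the model. A minimal sketch, assuming `embedding_model`, `model` and `credentials` come from a concrete provider (none of these objects are defined in this diff):

# Illustrative only: batching texts with the helpers defined above before calling invoke().
max_chunks = embedding_model._get_max_chunks(model, credentials)
texts = ["first document", "second document", "third document"]
results = []
for i in range(0, len(texts), max_chunks):
    batch = texts[i : i + max_chunks]
    results.append(embedding_model.invoke(model=model, credentials=credentials, texts=batch))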
api/core/model_runtime/model_providers/__base/tokenizers/gpt2/merges.txt
ADDED
The diff for this file is too large to render.
See raw diff
api/core/model_runtime/model_providers/__base/tokenizers/gpt2/special_tokens_map.json
ADDED
@@ -0,0 +1,23 @@
{
  "bos_token": {
    "content": "<|endoftext|>",
    "lstrip": false,
    "normalized": true,
    "rstrip": false,
    "single_word": false
  },
  "eos_token": {
    "content": "<|endoftext|>",
    "lstrip": false,
    "normalized": true,
    "rstrip": false,
    "single_word": false
  },
  "unk_token": {
    "content": "<|endoftext|>",
    "lstrip": false,
    "normalized": true,
    "rstrip": false,
    "single_word": false
  }
}
api/core/model_runtime/model_providers/__base/tokenizers/gpt2/tokenizer_config.json
ADDED
@@ -0,0 +1,33 @@
{
  "add_bos_token": false,
  "add_prefix_space": false,
  "bos_token": {
    "__type": "AddedToken",
    "content": "<|endoftext|>",
    "lstrip": false,
    "normalized": true,
    "rstrip": false,
    "single_word": false
  },
  "clean_up_tokenization_spaces": true,
  "eos_token": {
    "__type": "AddedToken",
    "content": "<|endoftext|>",
    "lstrip": false,
    "normalized": true,
    "rstrip": false,
    "single_word": false
  },
  "errors": "replace",
  "model_max_length": 1024,
  "pad_token": null,
  "tokenizer_class": "GPT2Tokenizer",
  "unk_token": {
    "__type": "AddedToken",
    "content": "<|endoftext|>",
    "lstrip": false,
    "normalized": true,
    "rstrip": false,
    "single_word": false
  }
}
api/core/model_runtime/model_providers/__base/tokenizers/gpt2/vocab.json
ADDED
The diff for this file is too large to render.
See raw diff
api/core/model_runtime/model_providers/__base/tokenizers/gpt2_tokenzier.py
ADDED
@@ -0,0 +1,51 @@
import logging
from threading import Lock
from typing import Any

logger = logging.getLogger(__name__)

_tokenizer: Any = None
_lock = Lock()


class GPT2Tokenizer:
    @staticmethod
    def _get_num_tokens_by_gpt2(text: str) -> int:
        """
        use gpt2 tokenizer to get num tokens
        """
        _tokenizer = GPT2Tokenizer.get_encoder()
        tokens = _tokenizer.encode(text)
        return len(tokens)

    @staticmethod
    def get_num_tokens(text: str) -> int:
        # Because this process needs more cpu resource, we turn this back before we find a better way to handle it.
        #
        # future = _executor.submit(GPT2Tokenizer._get_num_tokens_by_gpt2, text)
        # result = future.result()
        # return cast(int, result)
        return GPT2Tokenizer._get_num_tokens_by_gpt2(text)

    @staticmethod
    def get_encoder() -> Any:
        global _tokenizer, _lock
        with _lock:
            if _tokenizer is None:
                # Try to use tiktoken to get the tokenizer because it is faster
                #
                try:
                    import tiktoken

                    _tokenizer = tiktoken.get_encoding("gpt2")
                except Exception:
                    from os.path import abspath, dirname, join

                    from transformers import GPT2Tokenizer as TransformerGPT2Tokenizer  # type: ignore

                    base_path = abspath(__file__)
                    gpt2_tokenizer_path = join(dirname(base_path), "gpt2")
                    _tokenizer = TransformerGPT2Tokenizer.from_pretrained(gpt2_tokenizer_path)
                    logger.info("Fallback to Transformers' GPT-2 tokenizer from tiktoken")

        return _tokenizer
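Usage is a single static call; the encoder is resolved lazily under a lock and cached module-wide, so repeated calls do not reload anything. A minimal sketch:

# Illustrative only: tiktoken is used if available, otherwise the bundled GPT-2 files are loaded via transformers.
from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer

num_tokens = GPT2Tokenizer.get_num_tokens("Hello, world!")
print(num_tokens)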
api/core/model_runtime/model_providers/__base/tts_model.py
ADDED
@@ -0,0 +1,179 @@
import logging
import re
from abc import abstractmethod
from collections.abc import Iterable
from typing import Any, Optional

from pydantic import ConfigDict

from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
from core.model_runtime.model_providers.__base.ai_model import AIModel

logger = logging.getLogger(__name__)


class TTSModel(AIModel):
    """
    Model class for TTS model.
    """

    model_type: ModelType = ModelType.TTS

    # pydantic configs
    model_config = ConfigDict(protected_namespaces=())

    def invoke(
        self,
        model: str,
        tenant_id: str,
        credentials: dict,
        content_text: str,
        voice: str,
        user: Optional[str] = None,
    ) -> Iterable[bytes]:
        """
        Invoke TTS model

        :param model: model name
        :param tenant_id: user tenant id
        :param credentials: model credentials
        :param voice: model timbre
        :param content_text: text content to be translated
        :param user: unique user id
        :return: translated audio file
        """
        try:
            return self._invoke(
                model=model,
                credentials=credentials,
                user=user,
                content_text=content_text,
                voice=voice,
                tenant_id=tenant_id,
            )
        except Exception as e:
            raise self._transform_invoke_error(e)

    @abstractmethod
    def _invoke(
        self,
        model: str,
        tenant_id: str,
        credentials: dict,
        content_text: str,
        voice: str,
        user: Optional[str] = None,
    ) -> Iterable[bytes]:
        """
        Invoke TTS model

        :param model: model name
        :param tenant_id: user tenant id
        :param credentials: model credentials
        :param voice: model timbre
        :param content_text: text content to be translated
        :param user: unique user id
        :return: translated audio file
        """
        raise NotImplementedError

    def get_tts_model_voices(self, model: str, credentials: dict, language: Optional[str] = None) -> list:
        """
        Retrieves the list of voices supported by a given text-to-speech (TTS) model.

        :param language: The language for which the voices are requested.
        :param model: The name of the TTS model.
        :param credentials: The credentials required to access the TTS model.
        :return: A list of voices supported by the TTS model.
        """
        model_schema = self.get_model_schema(model, credentials)

        if not model_schema or ModelPropertyKey.VOICES not in model_schema.model_properties:
            raise ValueError("this model does not support voice")

        voices = model_schema.model_properties[ModelPropertyKey.VOICES]
        if language:
            return [
                {"name": d["name"], "value": d["mode"]} for d in voices if language and language in d.get("language")
            ]
        else:
            return [{"name": d["name"], "value": d["mode"]} for d in voices]

    def _get_model_default_voice(self, model: str, credentials: dict) -> Any:
        """
        Get default voice for given tts model

        :param model: model name
        :param credentials: model credentials
        :return: voice
        """
        model_schema = self.get_model_schema(model, credentials)

        if model_schema and ModelPropertyKey.DEFAULT_VOICE in model_schema.model_properties:
            return model_schema.model_properties[ModelPropertyKey.DEFAULT_VOICE]

    def _get_model_audio_type(self, model: str, credentials: dict) -> str:
        """
        Get audio type for given tts model

        :param model: model name
        :param credentials: model credentials
        :return: audio type
        """
        model_schema = self.get_model_schema(model, credentials)

        if not model_schema or ModelPropertyKey.AUDIO_TYPE not in model_schema.model_properties:
            raise ValueError("this model does not support audio type")

        audio_type: str = model_schema.model_properties[ModelPropertyKey.AUDIO_TYPE]
        return audio_type

    def _get_model_word_limit(self, model: str, credentials: dict) -> int:
        """
        Get word limit for given tts model
        :return: word limit
        """
        model_schema = self.get_model_schema(model, credentials)

        if not model_schema or ModelPropertyKey.WORD_LIMIT not in model_schema.model_properties:
            raise ValueError("this model does not support word limit")
        word_limit: int = model_schema.model_properties[ModelPropertyKey.WORD_LIMIT]

        return word_limit

    def _get_model_workers_limit(self, model: str, credentials: dict) -> int:
        """
        Get max workers for given tts model
        :return: max workers
        """
        model_schema = self.get_model_schema(model, credentials)

        if not model_schema or ModelPropertyKey.MAX_WORKERS not in model_schema.model_properties:
            raise ValueError("this model does not support max workers")
        workers_limit: int = model_schema.model_properties[ModelPropertyKey.MAX_WORKERS]

        return workers_limit

    @staticmethod
    def _split_text_into_sentences(org_text, max_length=2000, pattern=r"[。.!?]"):
        match = re.compile(pattern)
        tx = match.finditer(org_text)
        start = 0
        result = []
        one_sentence = ""
        for i in tx:
            end = i.regs[0][1]
            tmp = org_text[start:end]
            if len(one_sentence + tmp) > max_length:
                result.append(one_sentence)
                one_sentence = ""
            one_sentence += tmp
            start = end
        last_sens = org_text[start:]
        if last_sens:
            one_sentence += last_sens
        if one_sentence != "":
            result.append(one_sentence)
        return result
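The _split_text_into_sentences helper groups sentence-ending matches into chunks no longer than max_length characters, which is how long input would typically be cut up before synthesis. A minimal sketch of calling it directly (the sample text and max_length are arbitrary):

# Illustrative only: sentences are grouped so each chunk stays within max_length characters.
chunks = TTSModel._split_text_into_sentences("第一句。Second sentence. Third one!", max_length=20)
print(chunks)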
api/core/model_runtime/model_providers/__init__.py
ADDED
@@ -0,0 +1,3 @@
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory

model_provider_factory = ModelProviderFactory()
api/core/model_runtime/model_providers/_position.yaml
ADDED
@@ -0,0 +1,43 @@
- openai
- deepseek
- anthropic
- azure_openai
- google
- vertex_ai
- nvidia
- nvidia_nim
- cohere
- upstage
- bedrock
- togetherai
- openrouter
- ollama
- mistralai
- groq
- replicate
- huggingface_hub
- xinference
- triton_inference_server
- zhipuai
- baichuan
- spark
- minimax
- tongyi
- wenxin
- moonshot
- tencent
- jina
- chatglm
- yi
- openllm
- localai
- volcengine_maas
- openai_api_compatible
- hunyuan
- siliconflow
- perfxcloud
- zhinao
- fireworks
- mixedbread
- nomic
- voyage
api/core/model_runtime/model_providers/anthropic/__init__.py
ADDED
File without changes
api/core/model_runtime/model_providers/anthropic/_assets/icon_l_en.svg
ADDED
api/core/model_runtime/model_providers/anthropic/_assets/icon_s_en.svg
ADDED
api/core/model_runtime/model_providers/anthropic/anthropic.py
ADDED
@@ -0,0 +1,28 @@
import logging

from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.model_provider import ModelProvider

logger = logging.getLogger(__name__)


class AnthropicProvider(ModelProvider):
    def validate_provider_credentials(self, credentials: dict) -> None:
        """
        Validate provider credentials

        if validate failed, raise exception

        :param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
        """
        try:
            model_instance = self.get_model_instance(ModelType.LLM)

            # Use `claude-3-opus-20240229` model for validate,
            model_instance.validate_credentials(model="claude-3-opus-20240229", credentials=credentials)
        except CredentialsValidateFailedError as ex:
            raise ex
        except Exception as ex:
            logger.exception(f"{self.get_provider_schema().provider} credentials validate failed")
            raise ex
api/core/model_runtime/model_providers/anthropic/anthropic.yaml
ADDED
@@ -0,0 +1,39 @@
provider: anthropic
label:
  en_US: Anthropic
description:
  en_US: Anthropic’s powerful models, such as Claude 3.
  zh_Hans: Anthropic 的强大模型,例如 Claude 3。
icon_small:
  en_US: icon_s_en.svg
icon_large:
  en_US: icon_l_en.svg
background: "#F0F0EB"
help:
  title:
    en_US: Get your API Key from Anthropic
    zh_Hans: 从 Anthropic 获取 API Key
  url:
    en_US: https://console.anthropic.com/account/keys
supported_model_types:
  - llm
configurate_methods:
  - predefined-model
provider_credential_schema:
  credential_form_schemas:
    - variable: anthropic_api_key
      label:
        en_US: API Key
      type: secret-input
      required: true
      placeholder:
        zh_Hans: 在此输入您的 API Key
        en_US: Enter your API Key
    - variable: anthropic_api_url
      label:
        en_US: API URL
      type: text-input
      required: false
      placeholder:
        zh_Hans: 在此输入您的 API URL
        en_US: Enter your API URL
|
api/core/model_runtime/model_providers/anthropic/llm/__init__.py
ADDED
File without changes
|
api/core/model_runtime/model_providers/anthropic/llm/_position.yaml
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
- claude-3-5-haiku-20241022
|
2 |
+
- claude-3-5-sonnet-20241022
|
3 |
+
- claude-3-5-sonnet-20240620
|
4 |
+
- claude-3-haiku-20240307
|
5 |
+
- claude-3-opus-20240229
|
6 |
+
- claude-3-sonnet-20240229
|
7 |
+
- claude-2.1
|
8 |
+
- claude-instant-1.2
|
9 |
+
- claude-2
|
10 |
+
- claude-instant-1
|
api/core/model_runtime/model_providers/anthropic/llm/claude-2.1.yaml
ADDED
@@ -0,0 +1,36 @@
model: claude-2.1
label:
  en_US: claude-2.1
model_type: llm
features:
  - agent-thought
model_properties:
  mode: chat
  context_size: 200000
parameter_rules:
  - name: temperature
    use_template: temperature
  - name: top_p
    use_template: top_p
  - name: top_k
    label:
      zh_Hans: 取样数量
      en_US: Top k
    type: int
    help:
      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
      en_US: Only sample from the top K options for each subsequent token.
    required: false
  - name: max_tokens_to_sample
    use_template: max_tokens
    required: true
    default: 4096
    min: 1
    max: 4096
  - name: response_format
    use_template: response_format
pricing:
  input: '8.00'
  output: '24.00'
  unit: '0.000001'
  currency: USD
api/core/model_runtime/model_providers/anthropic/llm/claude-2.yaml
ADDED
@@ -0,0 +1,37 @@
model: claude-2
label:
  en_US: claude-2
model_type: llm
features:
  - agent-thought
model_properties:
  mode: chat
  context_size: 100000
parameter_rules:
  - name: temperature
    use_template: temperature
  - name: top_p
    use_template: top_p
  - name: top_k
    label:
      zh_Hans: 取样数量
      en_US: Top k
    type: int
    help:
      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
      en_US: Only sample from the top K options for each subsequent token.
    required: false
  - name: max_tokens_to_sample
    use_template: max_tokens
    required: true
    default: 4096
    min: 1
    max: 4096
  - name: response_format
    use_template: response_format
pricing:
  input: '8.00'
  output: '24.00'
  unit: '0.000001'
  currency: USD
deprecated: true
api/core/model_runtime/model_providers/anthropic/llm/claude-3-5-haiku-20241022.yaml
ADDED
@@ -0,0 +1,38 @@
model: claude-3-5-haiku-20241022
label:
  en_US: claude-3-5-haiku-20241022
model_type: llm
features:
  - agent-thought
  - tool-call
  - stream-tool-call
model_properties:
  mode: chat
  context_size: 200000
parameter_rules:
  - name: temperature
    use_template: temperature
  - name: top_p
    use_template: top_p
  - name: top_k
    label:
      zh_Hans: 取样数量
      en_US: Top k
    type: int
    help:
      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
      en_US: Only sample from the top K options for each subsequent token.
    required: false
  - name: max_tokens
    use_template: max_tokens
    required: true
    default: 8192
    min: 1
    max: 8192
  - name: response_format
    use_template: response_format
pricing:
  input: '1.00'
  output: '5.00'
  unit: '0.000001'
  currency: USD
api/core/model_runtime/model_providers/anthropic/llm/claude-3-5-sonnet-20240620.yaml
ADDED
@@ -0,0 +1,40 @@
model: claude-3-5-sonnet-20240620
label:
  en_US: claude-3-5-sonnet-20240620
model_type: llm
features:
  - agent-thought
  - vision
  - tool-call
  - stream-tool-call
  - document
model_properties:
  mode: chat
  context_size: 200000
parameter_rules:
  - name: temperature
    use_template: temperature
  - name: top_p
    use_template: top_p
  - name: top_k
    label:
      zh_Hans: 取样数量
      en_US: Top k
    type: int
    help:
      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
      en_US: Only sample from the top K options for each subsequent token.
    required: false
  - name: max_tokens
    use_template: max_tokens
    required: true
    default: 8192
    min: 1
    max: 8192
  - name: response_format
    use_template: response_format
pricing:
  input: '3.00'
  output: '15.00'
  unit: '0.000001'
  currency: USD
api/core/model_runtime/model_providers/anthropic/llm/claude-3-5-sonnet-20241022.yaml
ADDED
@@ -0,0 +1,40 @@
model: claude-3-5-sonnet-20241022
label:
  en_US: claude-3-5-sonnet-20241022
model_type: llm
features:
  - agent-thought
  - vision
  - tool-call
  - stream-tool-call
  - document
model_properties:
  mode: chat
  context_size: 200000
parameter_rules:
  - name: temperature
    use_template: temperature
  - name: top_p
    use_template: top_p
  - name: top_k
    label:
      zh_Hans: 取样数量
      en_US: Top k
    type: int
    help:
      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
      en_US: Only sample from the top K options for each subsequent token.
    required: false
  - name: max_tokens
    use_template: max_tokens
    required: true
    default: 8192
    min: 1
    max: 8192
  - name: response_format
    use_template: response_format
pricing:
  input: '3.00'
  output: '15.00'
  unit: '0.000001'
  currency: USD
api/core/model_runtime/model_providers/anthropic/llm/claude-3-haiku-20240307.yaml
ADDED
@@ -0,0 +1,39 @@
model: claude-3-haiku-20240307
label:
  en_US: claude-3-haiku-20240307
model_type: llm
features:
  - agent-thought
  - vision
  - tool-call
  - stream-tool-call
model_properties:
  mode: chat
  context_size: 200000
parameter_rules:
  - name: temperature
    use_template: temperature
  - name: top_p
    use_template: top_p
  - name: top_k
    label:
      zh_Hans: 取样数量
      en_US: Top k
    type: int
    help:
      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
      en_US: Only sample from the top K options for each subsequent token.
    required: false
  - name: max_tokens
    use_template: max_tokens
    required: true
    default: 4096
    min: 1
    max: 4096
  - name: response_format
    use_template: response_format
pricing:
  input: '0.25'
  output: '1.25'
  unit: '0.000001'
  currency: USD
api/core/model_runtime/model_providers/anthropic/llm/claude-3-opus-20240229.yaml
ADDED
@@ -0,0 +1,39 @@
model: claude-3-opus-20240229
label:
  en_US: claude-3-opus-20240229
model_type: llm
features:
  - agent-thought
  - vision
  - tool-call
  - stream-tool-call
model_properties:
  mode: chat
  context_size: 200000
parameter_rules:
  - name: temperature
    use_template: temperature
  - name: top_p
    use_template: top_p
  - name: top_k
    label:
      zh_Hans: 取样数量
      en_US: Top k
    type: int
    help:
      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
      en_US: Only sample from the top K options for each subsequent token.
    required: false
  - name: max_tokens
    use_template: max_tokens
    required: true
    default: 4096
    min: 1
    max: 4096
  - name: response_format
    use_template: response_format
pricing:
  input: '15.00'
  output: '75.00'
  unit: '0.000001'
  currency: USD
api/core/model_runtime/model_providers/anthropic/llm/claude-3-sonnet-20240229.yaml
ADDED
@@ -0,0 +1,39 @@
model: claude-3-sonnet-20240229
label:
  en_US: claude-3-sonnet-20240229
model_type: llm
features:
  - agent-thought
  - vision
  - tool-call
  - stream-tool-call
model_properties:
  mode: chat
  context_size: 200000
parameter_rules:
  - name: temperature
    use_template: temperature
  - name: top_p
    use_template: top_p
  - name: top_k
    label:
      zh_Hans: 取样数量
      en_US: Top k
    type: int
    help:
      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
      en_US: Only sample from the top K options for each subsequent token.
    required: false
  - name: max_tokens
    use_template: max_tokens
    required: true
    default: 4096
    min: 1
    max: 4096
  - name: response_format
    use_template: response_format
pricing:
  input: '3.00'
  output: '15.00'
  unit: '0.000001'
  currency: USD
api/core/model_runtime/model_providers/anthropic/llm/claude-instant-1.2.yaml
ADDED
@@ -0,0 +1,36 @@
model: claude-instant-1.2
label:
  en_US: claude-instant-1.2
model_type: llm
features: [ ]
model_properties:
  mode: chat
  context_size: 100000
parameter_rules:
  - name: temperature
    use_template: temperature
  - name: top_p
    use_template: top_p
  - name: top_k
    label:
      zh_Hans: 取样数量
      en_US: Top k
    type: int
    help:
      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
      en_US: Only sample from the top K options for each subsequent token.
    required: false
  - name: max_tokens
    use_template: max_tokens
    required: true
    default: 4096
    min: 1
    max: 4096
  - name: response_format
    use_template: response_format
pricing:
  input: '1.63'
  output: '5.51'
  unit: '0.000001'
  currency: USD
deprecated: true
api/core/model_runtime/model_providers/anthropic/llm/claude-instant-1.yaml
ADDED
@@ -0,0 +1,36 @@
model: claude-instant-1
label:
  en_US: claude-instant-1
model_type: llm
features: [ ]
model_properties:
  mode: chat
  context_size: 100000
parameter_rules:
  - name: temperature
    use_template: temperature
  - name: top_p
    use_template: top_p
  - name: top_k
    label:
      zh_Hans: 取样数量
      en_US: Top k
    type: int
    help:
      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
      en_US: Only sample from the top K options for each subsequent token.
    required: false
  - name: max_tokens_to_sample
    use_template: max_tokens
    required: true
    default: 4096
    min: 1
    max: 4096
  - name: response_format
    use_template: response_format
pricing:
  input: '1.63'
  output: '5.51'
  unit: '0.000001'
  currency: USD
deprecated: true
api/core/model_runtime/model_providers/anthropic/llm/llm.py
ADDED
@@ -0,0 +1,654 @@
import base64
import json
from collections.abc import Generator, Sequence
from typing import Optional, Union, cast

import anthropic
import requests
from anthropic import Anthropic, Stream
from anthropic.types import (
    ContentBlockDeltaEvent,
    Message,
    MessageDeltaEvent,
    MessageStartEvent,
    MessageStopEvent,
    MessageStreamEvent,
    completion_create_params,
)
from anthropic.types.beta.tools import ToolsBetaMessage
from httpx import Timeout

from core.model_runtime.callbacks.base_callback import Callback
from core.model_runtime.entities import (
    AssistantPromptMessage,
    DocumentPromptMessageContent,
    ImagePromptMessageContent,
    PromptMessage,
    PromptMessageContentType,
    PromptMessageTool,
    SystemPromptMessage,
    TextPromptMessageContent,
    ToolPromptMessage,
    UserPromptMessage,
)
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.errors.invoke import (
    InvokeAuthorizationError,
    InvokeBadRequestError,
    InvokeConnectionError,
    InvokeError,
    InvokeRateLimitError,
    InvokeServerUnavailableError,
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel

ANTHROPIC_BLOCK_MODE_PROMPT = """You should always follow the instructions and output a valid {{block}} object.
The structure of the {{block}} object you can found in the instructions, use {"answer": "$your_answer"} as the default structure
if you are not sure about the structure.

<instructions>
{{instructions}}
</instructions>
"""  # noqa: E501


class AnthropicLargeLanguageModel(LargeLanguageModel):
    def _invoke(
        self,
        model: str,
        credentials: dict,
        prompt_messages: list[PromptMessage],
        model_parameters: dict,
        tools: Optional[list[PromptMessageTool]] = None,
        stop: Optional[list[str]] = None,
        stream: bool = True,
        user: Optional[str] = None,
    ) -> Union[LLMResult, Generator]:
        """
        Invoke large language model

        :param model: model name
        :param credentials: model credentials
        :param prompt_messages: prompt messages
        :param model_parameters: model parameters
        :param tools: tools for tool calling
        :param stop: stop words
        :param stream: is stream response
        :param user: unique user id
        :return: full response or stream response chunk generator result
        """
        # invoke model
        return self._chat_generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)

    def _chat_generate(
        self,
        model: str,
        credentials: dict,
        prompt_messages: Sequence[PromptMessage],
        model_parameters: dict,
        tools: Optional[list[PromptMessageTool]] = None,
        stop: Optional[Sequence[str]] = None,
        stream: bool = True,
        user: Optional[str] = None,
    ) -> Union[LLMResult, Generator]:
        """
        Invoke llm chat model

        :param model: model name
        :param credentials: credentials
        :param prompt_messages: prompt messages
        :param model_parameters: model parameters
        :param stop: stop words
        :param stream: is stream response
        :param user: unique user id
        :return: full response or stream response chunk generator result
        """
        # transform credentials to kwargs for model instance
        credentials_kwargs = self._to_credential_kwargs(credentials)

        # transform model parameters from completion api of anthropic to chat api
        if "max_tokens_to_sample" in model_parameters:
            model_parameters["max_tokens"] = model_parameters.pop("max_tokens_to_sample")

        # init model client
        client = Anthropic(**credentials_kwargs)

        extra_model_kwargs = {}
        if stop:
            extra_model_kwargs["stop_sequences"] = stop

        if user:
            extra_model_kwargs["metadata"] = completion_create_params.Metadata(user_id=user)

        system, prompt_message_dicts = self._convert_prompt_messages(prompt_messages)

        if system:
            extra_model_kwargs["system"] = system

        # Add the new header for claude-3-5-sonnet-20240620 model
        extra_headers = {}
        if model == "claude-3-5-sonnet-20240620":
            if model_parameters.get("max_tokens", 0) > 4096:
                extra_headers["anthropic-beta"] = "max-tokens-3-5-sonnet-2024-07-15"

        if any(
            isinstance(content, DocumentPromptMessageContent)
            for prompt_message in prompt_messages
            if isinstance(prompt_message.content, list)
            for content in prompt_message.content
        ):
            extra_headers["anthropic-beta"] = "pdfs-2024-09-25"

        if tools:
            extra_model_kwargs["tools"] = [self._transform_tool_prompt(tool) for tool in tools]
            response = client.beta.tools.messages.create(
                model=model,
                messages=prompt_message_dicts,
                stream=stream,
                extra_headers=extra_headers,
                **model_parameters,
                **extra_model_kwargs,
            )
        else:
            # chat model
            response = client.messages.create(
                model=model,
                messages=prompt_message_dicts,
                stream=stream,
                extra_headers=extra_headers,
                **model_parameters,
                **extra_model_kwargs,
            )

        if stream:
            return self._handle_chat_generate_stream_response(model, credentials, response, prompt_messages)

        return self._handle_chat_generate_response(model, credentials, response, prompt_messages)

    def _code_block_mode_wrapper(
        self,
        model: str,
        credentials: dict,
        prompt_messages: list[PromptMessage],
        model_parameters: dict,
        tools: Optional[list[PromptMessageTool]] = None,
        stop: Optional[list[str]] = None,
        stream: bool = True,
        user: Optional[str] = None,
        callbacks: Optional[list[Callback]] = None,
    ) -> Union[LLMResult, Generator]:
        """
        Code block mode wrapper for invoking large language model
        """
        if model_parameters.get("response_format"):
            stop = stop or []
            # chat model
            self._transform_chat_json_prompts(
                model=model,
                credentials=credentials,
                prompt_messages=prompt_messages,
                model_parameters=model_parameters,
                tools=tools,
                stop=stop,
                stream=stream,
                user=user,
                response_format=model_parameters["response_format"],
            )
            model_parameters.pop("response_format")

        return self._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)

    def _transform_tool_prompt(self, tool: PromptMessageTool) -> dict:
        return {"name": tool.name, "description": tool.description, "input_schema": tool.parameters}

    def _transform_chat_json_prompts(
        self,
        model: str,
        credentials: dict,
        prompt_messages: list[PromptMessage],
        model_parameters: dict,
        tools: list[PromptMessageTool] | None = None,
        stop: list[str] | None = None,
        stream: bool = True,
        user: str | None = None,
        response_format: str = "JSON",
    ) -> None:
        """
        Transform json prompts
        """
        if "```\n" not in stop:
            stop.append("```\n")
        if "\n```" not in stop:
            stop.append("\n```")

        # check if there is a system message
        if len(prompt_messages) > 0 and isinstance(prompt_messages[0], SystemPromptMessage):
            # override the system message
            prompt_messages[0] = SystemPromptMessage(
                content=ANTHROPIC_BLOCK_MODE_PROMPT.replace("{{instructions}}", prompt_messages[0].content).replace(
                    "{{block}}", response_format
                )
            )
            prompt_messages.append(AssistantPromptMessage(content=f"\n```{response_format}"))
        else:
            # insert the system message
            prompt_messages.insert(
                0,
                SystemPromptMessage(
                    content=ANTHROPIC_BLOCK_MODE_PROMPT.replace(
                        "{{instructions}}", f"Please output a valid {response_format} object."
                    ).replace("{{block}}", response_format)
                ),
            )
            prompt_messages.append(AssistantPromptMessage(content=f"\n```{response_format}"))

    def get_num_tokens(
        self,
        model: str,
        credentials: dict,
        prompt_messages: list[PromptMessage],
        tools: Optional[list[PromptMessageTool]] = None,
    ) -> int:
        """
        Get number of tokens for given prompt messages

        :param model: model name
        :param credentials: model credentials
        :param prompt_messages: prompt messages
        :param tools: tools for tool calling
        :return:
        """
        prompt = self._convert_messages_to_prompt_anthropic(prompt_messages)

        client = Anthropic(api_key="")
        tokens = client.count_tokens(prompt)

        tool_call_inner_prompts_tokens_map = {
            "claude-3-opus-20240229": 395,
            "claude-3-haiku-20240307": 264,
            "claude-3-sonnet-20240229": 159,
        }

        if model in tool_call_inner_prompts_tokens_map and tools:
            tokens += tool_call_inner_prompts_tokens_map[model]

        return tokens

    def validate_credentials(self, model: str, credentials: dict) -> None:
        """
        Validate model credentials

        :param model: model name
        :param credentials: model credentials
        :return:
        """
        try:
            self._chat_generate(
                model=model,
                credentials=credentials,
                prompt_messages=[
                    UserPromptMessage(content="ping"),
                ],
                model_parameters={
                    "temperature": 0,
                    "max_tokens": 20,
                },
                stream=False,
            )
        except Exception as ex:
            raise CredentialsValidateFailedError(str(ex))

    def _handle_chat_generate_response(
        self,
        model: str,
        credentials: dict,
        response: Union[Message, ToolsBetaMessage],
        prompt_messages: list[PromptMessage],
    ) -> LLMResult:
        """
        Handle llm chat response

        :param model: model name
        :param credentials: credentials
        :param response: response
        :param prompt_messages: prompt messages
        :return: llm response
        """
        # transform assistant message to prompt message
        assistant_prompt_message = AssistantPromptMessage(content="", tool_calls=[])

        for content in response.content:
            if content.type == "text":
                assistant_prompt_message.content += content.text
            elif content.type == "tool_use":
                tool_call = AssistantPromptMessage.ToolCall(
                    id=content.id,
                    type="function",
                    function=AssistantPromptMessage.ToolCall.ToolCallFunction(
                        name=content.name, arguments=json.dumps(content.input)
                    ),
                )
                assistant_prompt_message.tool_calls.append(tool_call)

        # calculate num tokens
        prompt_tokens = (response.usage and response.usage.input_tokens) or self.get_num_tokens(
            model, credentials, prompt_messages
        )

        completion_tokens = (response.usage and response.usage.output_tokens) or self.get_num_tokens(
            model, credentials, [assistant_prompt_message]
        )

        # transform usage
        usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)

        # transform response
        response = LLMResult(
            model=response.model, prompt_messages=prompt_messages, message=assistant_prompt_message, usage=usage
        )

        return response

    def _handle_chat_generate_stream_response(
        self, model: str, credentials: dict, response: Stream[MessageStreamEvent], prompt_messages: list[PromptMessage]
    ) -> Generator:
        """
        Handle llm chat stream response

        :param model: model name
        :param response: response
        :param prompt_messages: prompt messages
        :return: llm response chunk generator
        """
        full_assistant_content = ""
        return_model = None
        input_tokens = 0
        output_tokens = 0
        finish_reason = None
        index = 0

        tool_calls: list[AssistantPromptMessage.ToolCall] = []

        for chunk in response:
            if isinstance(chunk, MessageStartEvent):
                if hasattr(chunk, "content_block"):
                    content_block = chunk.content_block
                    if isinstance(content_block, dict):
                        if content_block.get("type") == "tool_use":
                            tool_call = AssistantPromptMessage.ToolCall(
                                id=content_block.get("id"),
                                type="function",
                                function=AssistantPromptMessage.ToolCall.ToolCallFunction(
                                    name=content_block.get("name"), arguments=""
                                ),
                            )
                            tool_calls.append(tool_call)
                elif hasattr(chunk, "delta"):
                    delta = chunk.delta
                    if isinstance(delta, dict) and len(tool_calls) > 0:
                        if delta.get("type") == "input_json_delta":
                            tool_calls[-1].function.arguments += delta.get("partial_json", "")
                elif chunk.message:
                    return_model = chunk.message.model
                    input_tokens = chunk.message.usage.input_tokens
            elif isinstance(chunk, MessageDeltaEvent):
                output_tokens = chunk.usage.output_tokens
                finish_reason = chunk.delta.stop_reason
            elif isinstance(chunk, MessageStopEvent):
                # transform usage
                usage = self._calc_response_usage(model, credentials, input_tokens, output_tokens)

                # transform empty tool call arguments to {}
                for tool_call in tool_calls:
                    if not tool_call.function.arguments:
                        tool_call.function.arguments = "{}"

                yield LLMResultChunk(
                    model=return_model,
                    prompt_messages=prompt_messages,
                    delta=LLMResultChunkDelta(
                        index=index + 1,
                        message=AssistantPromptMessage(content="", tool_calls=tool_calls),
                        finish_reason=finish_reason,
                        usage=usage,
                    ),
                )
            elif isinstance(chunk, ContentBlockDeltaEvent):
                chunk_text = chunk.delta.text or ""
                full_assistant_content += chunk_text

                # transform assistant message to prompt message
                assistant_prompt_message = AssistantPromptMessage(content=chunk_text)

                index = chunk.index

                yield LLMResultChunk(
                    model=return_model,
                    prompt_messages=prompt_messages,
                    delta=LLMResultChunkDelta(
                        index=chunk.index,
                        message=assistant_prompt_message,
                    ),
                )

    def _to_credential_kwargs(self, credentials: dict) -> dict:
        """
        Transform credentials to kwargs for model instance

        :param credentials:
        :return:
        """
        credentials_kwargs = {
            "api_key": credentials["anthropic_api_key"],
            "timeout": Timeout(315.0, read=300.0, write=10.0, connect=5.0),
            "max_retries": 1,
        }

        if credentials.get("anthropic_api_url"):
            credentials["anthropic_api_url"] = credentials["anthropic_api_url"].rstrip("/")
            credentials_kwargs["base_url"] = credentials["anthropic_api_url"]

        return credentials_kwargs

    def _convert_prompt_messages(self, prompt_messages: Sequence[PromptMessage]) -> tuple[str, list[dict]]:
        """
        Convert prompt messages to dict list and system
        """
        system = ""
        first_loop = True
        for message in prompt_messages:
            if isinstance(message, SystemPromptMessage):
                if isinstance(message.content, str):
                    message.content = message.content.strip()
                elif isinstance(message.content, list):
                    # System prompt only support text
                    message.content = "".join(
                        c.data.strip() for c in message.content if isinstance(c, TextPromptMessageContent)
                    )
                else:
                    raise ValueError(f"Unknown system prompt message content type {type(message.content)}")
                if first_loop:
                    system = message.content
                    first_loop = False
                else:
                    system += "\n"
                    system += message.content

        prompt_message_dicts = []
        for message in prompt_messages:
            if not isinstance(message, SystemPromptMessage):
                if isinstance(message, UserPromptMessage):
                    message = cast(UserPromptMessage, message)
                    if isinstance(message.content, str):
                        # handle empty user prompt see #10013 #10520
                        # responses, ignore user prompts containing only whitespace, the Claude API can't handle it.
                        if not message.content.strip():
                            continue
                        message_dict = {"role": "user", "content": message.content}
                        prompt_message_dicts.append(message_dict)
                    else:
                        sub_messages = []
                        for message_content in message.content:
                            if message_content.type == PromptMessageContentType.TEXT:
                                message_content = cast(TextPromptMessageContent, message_content)
                                sub_message_dict = {"type": "text", "text": message_content.data}
                                sub_messages.append(sub_message_dict)
                            elif message_content.type == PromptMessageContentType.IMAGE:
                                message_content = cast(ImagePromptMessageContent, message_content)
                                if not message_content.base64_data:
                                    # fetch image data from url
                                    try:
                                        image_content = requests.get(message_content.url).content
                                        base64_data = base64.b64encode(image_content).decode("utf-8")
                                    except Exception as ex:
                                        raise ValueError(
                                            f"Failed to fetch image data from url {message_content.data}, {ex}"
                                        )
                                else:
                                    base64_data = message_content.base64_data

                                mime_type = message_content.mime_type
                                if mime_type not in {"image/jpeg", "image/png", "image/gif", "image/webp"}:
                                    raise ValueError(
                                        f"Unsupported image type {mime_type}, "
                                        f"only support image/jpeg, image/png, image/gif, and image/webp"
                                    )

                                sub_message_dict = {
                                    "type": "image",
                                    "source": {"type": "base64", "media_type": mime_type, "data": base64_data},
                                }
                                sub_messages.append(sub_message_dict)
                            elif isinstance(message_content, DocumentPromptMessageContent):
                                if message_content.mime_type != "application/pdf":
                                    raise ValueError(
                                        f"Unsupported document type {message_content.mime_type}, "
                                        "only support application/pdf"
                                    )
                                sub_message_dict = {
                                    "type": "document",
                                    "source": {
                                        "type": "base64",
                                        "media_type": message_content.mime_type,
                                        "data": message_content.base64_data,
                                    },
                                }
                                sub_messages.append(sub_message_dict)
                        prompt_message_dicts.append({"role": "user", "content": sub_messages})
                elif isinstance(message, AssistantPromptMessage):
                    message = cast(AssistantPromptMessage, message)
                    content = []
                    if message.tool_calls:
                        for tool_call in message.tool_calls:
                            content.append(
                                {
                                    "type": "tool_use",
                                    "id": tool_call.id,
                                    "name": tool_call.function.name,
                                    "input": json.loads(tool_call.function.arguments),
                                }
                            )
                    if message.content:
                        content.append({"type": "text", "text": message.content})

                    if prompt_message_dicts[-1]["role"] == "assistant":
                        prompt_message_dicts[-1]["content"].extend(content)
                    else:
                        prompt_message_dicts.append({"role": "assistant", "content": content})
                elif isinstance(message, ToolPromptMessage):
                    message = cast(ToolPromptMessage, message)
                    message_dict = {
                        "role": "user",
                        "content": [
                            {"type": "tool_result", "tool_use_id": message.tool_call_id, "content": message.content}
                        ],
                    }
                    prompt_message_dicts.append(message_dict)
                else:
                    raise ValueError(f"Got unknown type {message}")

        return system, prompt_message_dicts

    def _convert_one_message_to_text(self, message: PromptMessage) -> str:
        """
        Convert a single message to a string.

        :param message: PromptMessage to convert.
        :return: String representation of the message.
        """
        human_prompt = "\n\nHuman:"
        ai_prompt = "\n\nAssistant:"
        content = message.content

        if isinstance(message, UserPromptMessage):
            message_text = f"{human_prompt} {content}"
            if not isinstance(message.content, list):
                message_text = f"{ai_prompt} {content}"
            else:
                message_text = ""
                for sub_message in message.content:
                    if sub_message.type == PromptMessageContentType.TEXT:
                        message_text += f"{human_prompt} {sub_message.data}"
                    elif sub_message.type == PromptMessageContentType.IMAGE:
                        message_text += f"{human_prompt} [IMAGE]"
        elif isinstance(message, AssistantPromptMessage):
            if not isinstance(message.content, list):
                message_text = f"{ai_prompt} {content}"
            else:
                message_text = ""
                for sub_message in message.content:
                    if sub_message.type == PromptMessageContentType.TEXT:
                        message_text += f"{ai_prompt} {sub_message.data}"
                    elif sub_message.type == PromptMessageContentType.IMAGE:
                        message_text += f"{ai_prompt} [IMAGE]"
        elif isinstance(message, SystemPromptMessage):
            message_text = content
        elif isinstance(message, ToolPromptMessage):
            message_text = f"{human_prompt} {message.content}"
        else:
            raise ValueError(f"Got unknown type {message}")

        return message_text

    def _convert_messages_to_prompt_anthropic(self, messages: list[PromptMessage]) -> str:
        """
        Format a list of messages into a full prompt for the Anthropic model

        :param messages: List of PromptMessage to combine.
        :return: Combined string with necessary human_prompt and ai_prompt tags.
        """
        if not messages:
            return ""

        messages = messages.copy()  # don't mutate the original list
        if not isinstance(messages[-1], AssistantPromptMessage):
            messages.append(AssistantPromptMessage(content=""))

        text = "".join(self._convert_one_message_to_text(message) for message in messages)

        # trim off the trailing ' ' that might come from the "Assistant: "
        return text.rstrip()

    @property
    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
        """
        Map model invoke error to unified error
        The key is the error type thrown to the caller
        The value is the error type thrown by the model,
        which needs to be converted into a unified error type for the caller.

        :return: Invoke error mapping
        """
        return {
            InvokeConnectionError: [anthropic.APIConnectionError, anthropic.APITimeoutError],
            InvokeServerUnavailableError: [anthropic.InternalServerError],
            InvokeRateLimitError: [anthropic.RateLimitError],
            InvokeAuthorizationError: [anthropic.AuthenticationError, anthropic.PermissionDeniedError],
            InvokeBadRequestError: [
                anthropic.BadRequestError,
                anthropic.NotFoundError,
                anthropic.UnprocessableEntityError,
                anthropic.APIError,
            ],
        }
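A rough usage sketch of the class above. It assumes the model class can be instantiated directly without constructor arguments and that the credentials dict carries `anthropic_api_key` (the key `_to_credential_kwargs` reads); in the full runtime the public `invoke` entry point on `LargeLanguageModel` would normally be used rather than `_invoke`:

# Illustrative only; direct instantiation and the model name are assumptions.
from core.model_runtime.entities import SystemPromptMessage, UserPromptMessage

llm = AnthropicLargeLanguageModel()
credentials = {"anthropic_api_key": "sk-ant-..."}  # optionally also "anthropic_api_url"

result = llm._invoke(
    model="claude-3-5-sonnet-20241022",
    credentials=credentials,
    prompt_messages=[
        SystemPromptMessage(content="You are a concise assistant."),
        UserPromptMessage(content="ping"),
    ],
    model_parameters={"temperature": 0, "max_tokens": 64},
    stream=False,  # returns an LLMResult; stream=True yields LLMResultChunk objects
)
print(result.message.content)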
api/core/model_runtime/model_providers/azure_ai_studio/__init__.py
ADDED
File without changes
api/core/model_runtime/model_providers/azure_ai_studio/_assets/icon_l_en.png
ADDED
api/core/model_runtime/model_providers/azure_ai_studio/_assets/icon_s_en.png
ADDED
api/core/model_runtime/model_providers/azure_ai_studio/azure_ai_studio.py
ADDED
@@ -0,0 +1,17 @@
import logging

from core.model_runtime.model_providers.__base.model_provider import ModelProvider

logger = logging.getLogger(__name__)


class AzureAIStudioProvider(ModelProvider):
    def validate_provider_credentials(self, credentials: dict) -> None:
        """
        Validate provider credentials

        if validate failed, raise exception

        :param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
        """
        pass
api/core/model_runtime/model_providers/azure_ai_studio/azure_ai_studio.yaml
ADDED
@@ -0,0 +1,99 @@
provider: azure_ai_studio
label:
  zh_Hans: Azure AI Studio
  en_US: Azure AI Studio
icon_small:
  en_US: icon_s_en.png
icon_large:
  en_US: icon_l_en.png
description:
  en_US: Azure AI Studio
  zh_Hans: Azure AI Studio
background: "#93c5fd"
help:
  title:
    en_US: How to deploy customized model on Azure AI Studio
    zh_Hans: 如何在Azure AI Studio上的私有化部署的模型
  url:
    en_US: https://learn.microsoft.com/en-us/azure/ai-studio/how-to/deploy-models
    zh_Hans: https://learn.microsoft.com/zh-cn/azure/ai-studio/how-to/deploy-models
supported_model_types:
  - llm
  - rerank
configurate_methods:
  - customizable-model
model_credential_schema:
  model:
    label:
      en_US: Model Name
      zh_Hans: 模型名称
    placeholder:
      en_US: Enter your model name
      zh_Hans: 输入模型名称
  credential_form_schemas:
    - variable: endpoint
      label:
        en_US: Azure AI Studio Endpoint
      type: text-input
      required: true
      placeholder:
        zh_Hans: 请输入你的Azure AI Studio推理端点
        en_US: 'Enter your API Endpoint, eg: https://example.com'
    - variable: api_key
      required: true
      label:
        en_US: API Key
        zh_Hans: API Key
      type: secret-input
      placeholder:
        en_US: Enter your Azure AI Studio API Key
        zh_Hans: 在此输入您的 Azure AI Studio API Key
      show_on:
        - variable: __model_type
          value: llm
    - variable: mode
      show_on:
        - variable: __model_type
          value: llm
      label:
        en_US: Completion mode
      type: select
      required: false
      default: chat
      placeholder:
        zh_Hans: 选择对话类型
        en_US: Select completion mode
      options:
        - value: completion
          label:
            en_US: Completion
            zh_Hans: 补全
        - value: chat
          label:
            en_US: Chat
            zh_Hans: 对话
    - variable: context_size
      label:
        zh_Hans: 模型上下文长度
        en_US: Model context size
      required: true
      show_on:
        - variable: __model_type
          value: llm
      type: text-input
      default: "4096"
      placeholder:
        zh_Hans: 在此输入您的模型上下文长度
        en_US: Enter your Model context size
    - variable: jwt_token
      required: true
      label:
        en_US: JWT Token
        zh_Hans: JWT令牌
      type: secret-input
      placeholder:
        en_US: Enter your Azure AI Studio JWT Token
        zh_Hans: 在此输入您的 Azure AI Studio 推理 API Key
      show_on:
        - variable: __model_type
          value: rerank
api/core/model_runtime/model_providers/azure_ai_studio/llm/__init__.py
ADDED
File without changes
api/core/model_runtime/model_providers/azure_ai_studio/llm/llm.py
ADDED
@@ -0,0 +1,345 @@
import logging
from collections.abc import Generator, Sequence
from typing import Any, Optional, Union

from azure.ai.inference import ChatCompletionsClient
from azure.ai.inference.models import StreamingChatCompletionsUpdate, SystemMessage, UserMessage
from azure.core.credentials import AzureKeyCredential
from azure.core.exceptions import (
    ClientAuthenticationError,
    DecodeError,
    DeserializationError,
    HttpResponseError,
    ResourceExistsError,
    ResourceModifiedError,
    ResourceNotFoundError,
    ResourceNotModifiedError,
    SerializationError,
    ServiceRequestError,
    ServiceResponseError,
)

from core.model_runtime.callbacks.base_callback import Callback
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
from core.model_runtime.entities.message_entities import (
    AssistantPromptMessage,
    PromptMessage,
    PromptMessageTool,
)
from core.model_runtime.entities.model_entities import (
    AIModelEntity,
    FetchFrom,
    I18nObject,
    ModelPropertyKey,
    ModelType,
    ParameterRule,
    ParameterType,
)
from core.model_runtime.errors.invoke import (
    InvokeAuthorizationError,
    InvokeBadRequestError,
    InvokeConnectionError,
    InvokeError,
    InvokeServerUnavailableError,
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel

logger = logging.getLogger(__name__)


class AzureAIStudioLargeLanguageModel(LargeLanguageModel):
    """
    Model class for Azure AI Studio large language model.
    """

    client: Any = None

    from azure.ai.inference.models import StreamingChatCompletionsUpdate

    def _invoke(
        self,
        model: str,
        credentials: dict,
        prompt_messages: Sequence[PromptMessage],
        model_parameters: dict,
        tools: Optional[Sequence[PromptMessageTool]] = None,
        stop: Optional[Sequence[str]] = None,
        stream: bool = True,
        user: Optional[str] = None,
    ) -> Union[LLMResult, Generator]:
        """
        Invoke large language model

        :param model: model name
        :param credentials: model credentials
        :param prompt_messages: prompt messages
        :param model_parameters: model parameters
        :param tools: tools for tool calling
        :param stop: stop words
        :param stream: is stream response
        :param user: unique user id
        :return: full response or stream response chunk generator result
        """

        if not self.client:
            endpoint = str(credentials.get("endpoint"))
            api_key = str(credentials.get("api_key"))
            self.client = ChatCompletionsClient(endpoint=endpoint, credential=AzureKeyCredential(api_key))

        messages = [{"role": msg.role.value, "content": msg.content} for msg in prompt_messages]

        payload = {
            "messages": messages,
            "max_tokens": model_parameters.get("max_tokens", 4096),
            "temperature": model_parameters.get("temperature", 0),
            "top_p": model_parameters.get("top_p", 1),
            "stream": stream,
            "model": model,
        }

        if stop:
            payload["stop"] = stop

        if tools:
            payload["tools"] = [tool.model_dump() for tool in tools]

        try:
            response = self.client.complete(**payload)

            if stream:
                return self._handle_stream_response(response, model, prompt_messages)
            else:
                return self._handle_non_stream_response(response, model, prompt_messages, credentials)
        except Exception as e:
            raise self._transform_invoke_error(e)

    def _handle_stream_response(self, response, model: str, prompt_messages: list[PromptMessage]) -> Generator:
        for chunk in response:
            if isinstance(chunk, StreamingChatCompletionsUpdate):
                if chunk.choices:
                    delta = chunk.choices[0].delta
                    if delta.content:
                        yield LLMResultChunk(
                            model=model,
                            prompt_messages=prompt_messages,
                            delta=LLMResultChunkDelta(
                                index=0,
                                message=AssistantPromptMessage(content=delta.content, tool_calls=[]),
                            ),
                        )

    def _handle_non_stream_response(
        self, response, model: str, prompt_messages: list[PromptMessage], credentials: dict
    ) -> LLMResult:
        assistant_text = response.choices[0].message.content
        assistant_prompt_message = AssistantPromptMessage(content=assistant_text)
        usage = self._calc_response_usage(
            model, credentials, response.usage.prompt_tokens, response.usage.completion_tokens
        )
        result = LLMResult(model=model, prompt_messages=prompt_messages, message=assistant_prompt_message, usage=usage)

        if hasattr(response, "system_fingerprint"):
            result.system_fingerprint = response.system_fingerprint

        return result

    def _invoke_result_generator(
        self,
        model: str,
        result: Generator,
        credentials: dict,
        prompt_messages: list[PromptMessage],
        model_parameters: dict,
        tools: Optional[list[PromptMessageTool]] = None,
        stop: Optional[list[str]] = None,
        stream: bool = True,
        user: Optional[str] = None,
        callbacks: Optional[list[Callback]] = None,
    ) -> Generator:
        """
        Invoke result generator

        :param result: result generator
        :return: result generator
        """
        callbacks = callbacks or []
        prompt_message = AssistantPromptMessage(content="")
        usage = None
        system_fingerprint = None
        real_model = model

        try:
            for chunk in result:
                if isinstance(chunk, dict):
                    content = chunk["choices"][0]["message"]["content"]
                    usage = chunk["usage"]
                    chunk = LLMResultChunk(
                        model=model,
                        prompt_messages=prompt_messages,
                        delta=LLMResultChunkDelta(
                            index=0,
                            message=AssistantPromptMessage(content=content, tool_calls=[]),
                        ),
                        system_fingerprint=chunk.get("system_fingerprint"),
                    )

                yield chunk

                self._trigger_new_chunk_callbacks(
                    chunk=chunk,
                    model=model,
                    credentials=credentials,
                    prompt_messages=prompt_messages,
                    model_parameters=model_parameters,
                    tools=tools,
                    stop=stop,
                    stream=stream,
                    user=user,
                    callbacks=callbacks,
                )

                prompt_message.content += chunk.delta.message.content
                real_model = chunk.model
                if hasattr(chunk.delta, "usage"):
                    usage = chunk.delta.usage

                if chunk.system_fingerprint:
                    system_fingerprint = chunk.system_fingerprint
        except Exception as e:
            raise self._transform_invoke_error(e)

        self._trigger_after_invoke_callbacks(
            model=model,
            result=LLMResult(
                model=real_model,
                prompt_messages=prompt_messages,
                message=prompt_message,
                usage=usage or LLMUsage.empty_usage(),
                system_fingerprint=system_fingerprint,
            ),
            credentials=credentials,
            prompt_messages=prompt_messages,
            model_parameters=model_parameters,
            tools=tools,
            stop=stop,
            stream=stream,
            user=user,
            callbacks=callbacks,
        )

    def get_num_tokens(
        self,
        model: str,
        credentials: dict,
        prompt_messages: list[PromptMessage],
        tools: Optional[list[PromptMessageTool]] = None,
    ) -> int:
        """
        Get number of tokens for given prompt messages

        :param model: model name
        :param credentials: model credentials
        :param prompt_messages: prompt messages
        :param tools: tools for tool calling
        :return:
        """
        # Implement token counting logic here
        # Might need to use a tokenizer specific to the Azure AI Studio model
        return 0

    def validate_credentials(self, model: str, credentials: dict) -> None:
        """
        Validate model credentials

        :param model: model name
        :param credentials: model credentials
        :return:
        """
        try:
            endpoint = str(credentials.get("endpoint"))
            api_key = str(credentials.get("api_key"))
            client = ChatCompletionsClient(endpoint=endpoint, credential=AzureKeyCredential(api_key))
            client.complete(
                messages=[
                    SystemMessage(content="I say 'ping', you say 'pong'"),
                    UserMessage(content="ping"),
                ],
                model=model,
            )
        except Exception as ex:
            raise CredentialsValidateFailedError(str(ex))

    @property
    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
        """
        Map model invoke error to unified error
        The key is the error type thrown to the caller
        The value is the error type thrown by the model,
        which needs to be converted into a unified error type for the caller.

        :return: Invoke error mapping
        """
        return {
            InvokeConnectionError: [
                ServiceRequestError,
            ],
            InvokeServerUnavailableError: [
                ServiceResponseError,
            ],
            InvokeAuthorizationError: [
                ClientAuthenticationError,
            ],
            InvokeBadRequestError: [
                HttpResponseError,
                DecodeError,
                ResourceExistsError,
                ResourceNotFoundError,
                ResourceModifiedError,
                ResourceNotModifiedError,
                SerializationError,
                DeserializationError,
            ],
        }

    def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
        """
        Used to define customizable model schema
        """
        rules = [
            ParameterRule(
                name="temperature",
                type=ParameterType.FLOAT,
                use_template="temperature",
                label=I18nObject(zh_Hans="温度", en_US="Temperature"),
            ),
            ParameterRule(
                name="top_p",
                type=ParameterType.FLOAT,
                use_template="top_p",
                label=I18nObject(zh_Hans="Top P", en_US="Top P"),
            ),
            ParameterRule(
                name="max_tokens",
                type=ParameterType.INT,
                use_template="max_tokens",
                min=1,
                default=512,
                label=I18nObject(zh_Hans="最大生成长度", en_US="Max Tokens"),
            ),
        ]

        entity = AIModelEntity(
            model=model,
            label=I18nObject(en_US=model),
            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
            model_type=ModelType.LLM,
            features=[],
            model_properties={
                ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size", "4096")),
                ModelPropertyKey.MODE: credentials.get("mode", LLMMode.CHAT),
            },
            parameter_rules=rules,
        )

        return entity
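A minimal standalone sketch of the connectivity check the class above performs in `validate_credentials`, using the same `azure.ai.inference` calls; the endpoint URL, API key, and model name are placeholders, not values from this commit:

from azure.ai.inference import ChatCompletionsClient
from azure.ai.inference.models import SystemMessage, UserMessage
from azure.core.credentials import AzureKeyCredential

# Placeholder values; in the provider these come from the model credential form above.
client = ChatCompletionsClient(
    endpoint="https://<your-deployment>.inference.ai.azure.com",
    credential=AzureKeyCredential("<api-key>"),
)

# Mirrors the ping/pong round trip used by validate_credentials.
response = client.complete(
    messages=[
        SystemMessage(content="I say 'ping', you say 'pong'"),
        UserMessage(content="ping"),
    ],
    model="<deployment-model-name>",
)
print(response.choices[0].message.content)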
api/core/model_runtime/model_providers/azure_ai_studio/rerank/__init__.py
ADDED
File without changes
api/core/model_runtime/model_providers/azure_ai_studio/rerank/rerank.py
ADDED
@@ -0,0 +1,164 @@
import json
import logging
import os
import ssl
import urllib.request
from typing import Optional

from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult
from core.model_runtime.errors.invoke import (
    InvokeAuthorizationError,
    InvokeBadRequestError,
    InvokeConnectionError,
    InvokeError,
    InvokeRateLimitError,
    InvokeServerUnavailableError,
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.rerank_model import RerankModel

logger = logging.getLogger(__name__)


class AzureRerankModel(RerankModel):
    """
    Model class for Azure AI Studio rerank model.
    """

    def _allow_self_signed_https(self, allowed):
        # bypass the server certificate verification on client side
        if allowed and not os.environ.get("PYTHONHTTPSVERIFY", "") and getattr(ssl, "_create_unverified_context", None):
            ssl._create_default_https_context = ssl._create_unverified_context

    def _azure_rerank(self, query_input: str, docs: list[str], endpoint: str, api_key: str):
        # self._allow_self_signed_https(True)  # Enable if using self-signed certificate

        data = {"inputs": query_input, "docs": docs}

        body = json.dumps(data).encode("utf-8")
        headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}

        req = urllib.request.Request(endpoint, body, headers)

        try:
            with urllib.request.urlopen(req) as response:
                result = response.read()
                return json.loads(result)
        except urllib.error.HTTPError as error:
            logger.exception(f"The request failed with status code: {error.code}")
            logger.exception(error.info())
            logger.exception(error.read().decode("utf8", "ignore"))
            raise

    def _invoke(
        self,
        model: str,
        credentials: dict,
        query: str,
        docs: list[str],
        score_threshold: Optional[float] = None,
        top_n: Optional[int] = None,
        user: Optional[str] = None,
    ) -> RerankResult:
        """
        Invoke rerank model

        :param model: model name
        :param credentials: model credentials
        :param query: search query
        :param docs: docs for reranking
        :param score_threshold: score threshold
        :param top_n: top n
        :param user: unique user id
        :return: rerank result
        """
        try:
            if len(docs) == 0:
                return RerankResult(model=model, docs=[])

            endpoint = credentials.get("endpoint")
            api_key = credentials.get("jwt_token")

            if not endpoint or not api_key:
                raise ValueError("Azure endpoint and API key must be provided in credentials")

            result = self._azure_rerank(query, docs, endpoint, api_key)
            logger.info(f"Azure rerank result: {result}")

            rerank_documents = []
            for idx, (doc, score_dict) in enumerate(zip(docs, result)):
                score = score_dict["score"]
                rerank_document = RerankDocument(index=idx, text=doc, score=score)

                if score_threshold is None or score >= score_threshold:
                    rerank_documents.append(rerank_document)

            rerank_documents.sort(key=lambda x: x.score, reverse=True)

            if top_n:
                rerank_documents = rerank_documents[:top_n]

            return RerankResult(model=model, docs=rerank_documents)

        except Exception as e:
            logger.exception(f"Failed to invoke rerank model, model: {model}")
            raise

    def validate_credentials(self, model: str, credentials: dict) -> None:
        """
        Validate model credentials

        :param model: model name
        :param credentials: model credentials
        :return:
        """
        try:
            self._invoke(
                model=model,
                credentials=credentials,
                query="What is the capital of the United States?",
                docs=[
                    "Carson City is the capital city of the American state of Nevada. At the 2010 United States "
                    "Census, Carson City had a population of 55,274.",
                    "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that "
                    "are a political division controlled by the United States. Its capital is Saipan.",
                ],
                score_threshold=0.8,
            )
        except Exception as ex:
            raise CredentialsValidateFailedError(str(ex))

    @property
    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
        """
        Map model invoke error to unified error
        The key is the error type thrown to the caller
        The value is the error type thrown by the model,
        which needs to be converted into a unified error type for the caller.

        :return: Invoke error mapping
        """
        return {
            InvokeConnectionError: [urllib.error.URLError],
            InvokeServerUnavailableError: [urllib.error.HTTPError],
            InvokeRateLimitError: [InvokeRateLimitError],
            InvokeAuthorizationError: [InvokeAuthorizationError],
            InvokeBadRequestError: [InvokeBadRequestError, KeyError, ValueError, json.JSONDecodeError],
        }

    def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
        """
        used to define customizable model schema
        """
        entity = AIModelEntity(
            model=model,
            label=I18nObject(en_US=model),
            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
            model_type=ModelType.RERANK,
            model_properties={},
            parameter_rules=[],
        )

        return entity
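For reference, `_azure_rerank` above posts a JSON body of the form {"inputs": <query>, "docs": [<doc>, ...]} with a Bearer token, and `_invoke` consumes the response as a list of objects carrying a `score` per document. A hedged sketch of calling such an endpoint directly; the endpoint URL, token, and the exact response shape are assumptions inferred from the code above:

import json
import urllib.request

endpoint = "https://<your-rerank-deployment>/score"  # placeholder
jwt_token = "<jwt-token>"  # placeholder

# Same payload shape _azure_rerank builds.
payload = {"inputs": "What is the capital of the United States?", "docs": ["doc one", "doc two"]}
req = urllib.request.Request(
    endpoint,
    json.dumps(payload).encode("utf-8"),
    {"Content-Type": "application/json", "Authorization": f"Bearer {jwt_token}"},
)
with urllib.request.urlopen(req) as resp:
    scores = json.loads(resp.read())  # expected (assumed): [{"score": 0.93}, {"score": 0.12}]
print(scores)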
api/core/model_runtime/model_providers/azure_openai/__init__.py
ADDED
File without changes