yungongzi committed
Commit dea956e · 1 Parent(s): 9d69040

Add support for VolcEngine - the current version supports SDK2 (#885)


- The main idea is to assemble **ak**, **sk**, and **ep_id** into a
dictionary and store it in the database **api_key** field.
- I am not very familiar with the front-end, so I modeled it on the
Ollama integration, which may contain some redundancy.

### Configuration method

- Model name
  - Format requirements: `{"VolcEngine model name":"endpoint_id"}`
  - For example: `{"Skylark2-pro-32k":"ep-xxxxxxxxx"}` (matching the model names registered in `api/db/init_data.py`)
- Volcano ACCESS_KEY
  - Format requirements: the VOLC_ACCESSKEY of the VolcEngine account corresponding to the model
- Volcano SECRET_KEY
  - Format requirements: the VOLC_SECRETKEY of the VolcEngine account corresponding to the model
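
Put together, these three values are what the backend packs into the stored **api_key**. A minimal sketch of the round trip, with all credential values as hypothetical placeholders:

```python
# Hypothetical credentials, for illustration only.
volc_ak, volc_sk, endpoint_id = "AKLTxxxx", "U0t4eHh4", "ep-xxxxxxxxx"

# Assembled the same way the add_llm() change below does it:
api_key = '{' + f'"volc_ak": "{volc_ak}", ' \
                f'"volc_sk": "{volc_sk}", ' \
                f'"ep_id": "{endpoint_id}", ' + '}'

# The result is not strictly valid JSON (trailing comma), but the chat
# model recovers the parts with eval(), which accepts the dict literal.
creds = eval(api_key)
assert creds["ep_id"] == "ep-xxxxxxxxx"
```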

### What problem does this PR solve?

Adds VolcEngine as a new model provider. Its credentials (ak/sk) and endpoint id do not fit the existing single api_key column, so they are packed into that field as a dictionary string, avoiding any database schema change.

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

api/apps/llm_app.py CHANGED
@@ -96,16 +96,29 @@ def set_api_key():
 @validate_request("llm_factory", "llm_name", "model_type")
 def add_llm():
     req = request.json
+    factory = req["llm_factory"]
+    # For VolcEngine, due to its special authentication method
+    # Assemble volc_ak, volc_sk, endpoint_id into api_key
+    if factory == "VolcEngine":
+        temp = list(eval(req["llm_name"]).items())[0]
+        llm_name = temp[0]
+        endpoint_id = temp[1]
+        api_key = '{' + f'"volc_ak": "{req.get("volc_ak", "")}", ' \
+                        f'"volc_sk": "{req.get("volc_sk", "")}", ' \
+                        f'"ep_id": "{endpoint_id}", ' + '}'
+    else:
+        llm_name = req["llm_name"]
+        api_key = "xxxxxxxxxxxxxxx"
+
     llm = {
         "tenant_id": current_user.id,
-        "llm_factory": req["llm_factory"],
+        "llm_factory": factory,
         "model_type": req["model_type"],
-        "llm_name": req["llm_name"],
+        "llm_name": llm_name,
         "api_base": req.get("api_base", ""),
-        "api_key": "xxxxxxxxxxxxxxx"
+        "api_key": api_key
     }
 
-    factory = req["llm_factory"]
     msg = ""
     if llm["model_type"] == LLMType.EMBEDDING.value:
         mdl = EmbeddingModel[factory](
@@ -118,7 +131,10 @@ def add_llm():
         msg += f"\nFail to access embedding model({llm['llm_name']})." + str(e)
     elif llm["model_type"] == LLMType.CHAT.value:
         mdl = ChatModel[factory](
-            key=None, model_name=llm["llm_name"], base_url=llm["api_base"])
+            key=llm['api_key'] if factory == "VolcEngine" else None,
+            model_name=llm["llm_name"],
+            base_url=llm["api_base"]
+        )
         try:
             m, tc = mdl.chat(None, [{"role": "user", "content": "Hello! How are you doing!"}], {
                 "temperature": 0.9})
@@ -134,7 +150,6 @@ def add_llm():
     if msg:
         return get_data_error_result(retmsg=msg)
 
-
     if not TenantLLMService.filter_update(
             [TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == factory, TenantLLM.llm_name == llm["llm_name"]], llm):
         TenantLLMService.save(**llm)
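
Since the assembled string is later consumed with `eval()` in `rag/llm/chat_model.py`, an equivalent and somewhat safer sketch of the VolcEngine branch using the standard `json` module (not what this commit does, just an illustration under the same request shape) would be:

```python
import json

def build_volc_llm_fields(req: dict) -> tuple[str, str]:
    """Sketch only: a json-based variant of the VolcEngine branch above.

    Assumes the same request fields as the commit: llm_name carries a
    {"model name": "endpoint_id"} mapping, plus volc_ak / volc_sk.
    """
    llm_name, endpoint_id = next(iter(json.loads(req["llm_name"]).items()))
    api_key = json.dumps({
        "volc_ak": req.get("volc_ak", ""),
        "volc_sk": req.get("volc_sk", ""),
        "ep_id": endpoint_id,
    })
    return llm_name, api_key
```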
api/db/init_data.py CHANGED
@@ -132,7 +132,12 @@ factory_infos = [{
     "logo": "",
     "tags": "LLM",
     "status": "1",
-},
+},{
+    "name": "VolcEngine",
+    "logo": "",
+    "tags": "LLM, TEXT EMBEDDING",
+    "status": "1",
+}
 # {
 #     "name": "文心一言",
 #     "logo": "",
@@ -372,6 +377,21 @@ def init_llm_factory():
         "max_tokens": 16385,
         "model_type": LLMType.CHAT.value
     },
+    # ------------------------ VolcEngine -----------------------
+    {
+        "fid": factory_infos[9]["name"],
+        "llm_name": "Skylark2-pro-32k",
+        "tags": "LLM,CHAT,32k",
+        "max_tokens": 32768,
+        "model_type": LLMType.CHAT.value
+    },
+    {
+        "fid": factory_infos[9]["name"],
+        "llm_name": "Skylark2-pro-4k",
+        "tags": "LLM,CHAT,4k",
+        "max_tokens": 4096,
+        "model_type": LLMType.CHAT.value
+    },
 ]
 for info in factory_infos:
     try:
rag/llm/chat_model.py CHANGED
@@ -19,6 +19,7 @@ from abc import ABC
 from openai import OpenAI
 import openai
 from ollama import Client
+from volcengine.maas.v2 import MaasService
 from rag.nlp import is_english
 from rag.utils import num_tokens_from_string
 
@@ -315,3 +316,71 @@ class LocalLLM(Base):
             yield answer + "\n**ERROR**: " + str(e)
 
         yield token_count
+
+
+class VolcEngineChat(Base):
+    def __init__(self, key, model_name, base_url):
+        """
+        Since we do not want to modify the original database fields, and the
+        VolcEngine authentication method is quite special, ak, sk and ep_id are
+        assembled into api_key, stored as a dictionary string, and parsed here
+        for use. model_name is for display only.
+        """
+        self.client = MaasService('maas-api.ml-platform-cn-beijing.volces.com', 'cn-beijing')
+        self.volc_ak = eval(key).get('volc_ak', '')
+        self.volc_sk = eval(key).get('volc_sk', '')
+        self.client.set_ak(self.volc_ak)
+        self.client.set_sk(self.volc_sk)
+        self.model_name = eval(key).get('ep_id', '')
+
+    def chat(self, system, history, gen_conf):
+        if system:
+            history.insert(0, {"role": "system", "content": system})
+        try:
+            req = {
+                "parameters": {
+                    "min_new_tokens": gen_conf.get("min_new_tokens", 1),
+                    "top_k": gen_conf.get("top_k", 0),
+                    "max_prompt_tokens": gen_conf.get("max_prompt_tokens", 30000),
+                    "temperature": gen_conf.get("temperature", 0.1),
+                    "max_new_tokens": gen_conf.get("max_tokens", 1000),
+                    "top_p": gen_conf.get("top_p", 0.3),
+                },
+                "messages": history
+            }
+            response = self.client.chat(self.model_name, req)
+            ans = response.choices[0].message.content.strip()
+            if response.choices[0].finish_reason == "length":
+                ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
+                    [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
+            return ans, response.usage.total_tokens
+        except Exception as e:
+            return "**ERROR**: " + str(e), 0
+
+    def chat_streamly(self, system, history, gen_conf):
+        if system:
+            history.insert(0, {"role": "system", "content": system})
+        ans = ""
+        try:
+            req = {
+                "parameters": {
+                    "min_new_tokens": gen_conf.get("min_new_tokens", 1),
+                    "top_k": gen_conf.get("top_k", 0),
+                    "max_prompt_tokens": gen_conf.get("max_prompt_tokens", 30000),
+                    "temperature": gen_conf.get("temperature", 0.1),
+                    "max_new_tokens": gen_conf.get("max_tokens", 1000),
+                    "top_p": gen_conf.get("top_p", 0.3),
+                },
+                "messages": history
+            }
+            stream = self.client.stream_chat(self.model_name, req)
+            for resp in stream:
+                if not resp.choices[0].message.content:
+                    continue
+                ans += resp.choices[0].message.content
+                yield ans
+                if resp.choices[0].finish_reason == "stop":
+                    return resp.usage.total_tokens
+
+        except Exception as e:
+            yield ans + "\n**ERROR**: " + str(e)
+        yield 0
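
For context, here is a minimal usage sketch of the new class; the `key` string mirrors what `add_llm()` stores, and every credential value is a hypothetical placeholder:

```python
from rag.llm.chat_model import VolcEngineChat

# Hypothetical assembled key, as produced by add_llm() above.
key = '{"volc_ak": "AKLTxxxx", "volc_sk": "U0t4eHh4", "ep_id": "ep-xxxxxxxxx"}'
mdl = VolcEngineChat(key=key, model_name="Skylark2-pro-32k", base_url="")

answer, tokens = mdl.chat(
    None,
    [{"role": "user", "content": "Hello! How are you doing!"}],
    {"temperature": 0.9},
)
print(answer, tokens)
```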
web/src/assets/svg/llm/volc_engine.svg ADDED
web/src/locales/en.ts CHANGED
@@ -477,6 +477,11 @@ The above is the content you need to summarize.`,
     baseUrlNameMessage: 'Please input your base url!',
     vision: 'Does it support Vision?',
     ollamaLink: 'How to integrate {{name}}',
+    volcModelNameMessage: 'Please input your model name! Format: {"ModelName":"EndpointID"}',
+    addVolcEngineAK: 'VOLC ACCESS_KEY',
+    volcAKMessage: 'Please input your VOLC_ACCESS_KEY',
+    addVolcEngineSK: 'VOLC SECRET_KEY',
+    volcSKMessage: 'Please input your SECRET_KEY',
   },
   message: {
     registered: 'Registered!',
web/src/locales/zh-traditional.ts CHANGED
@@ -440,7 +440,12 @@ export default {
     modelNameMessage: '請輸入模型名稱!',
     modelTypeMessage: '請輸入模型類型!',
     baseUrlNameMessage: '請輸入基礎 Url!',
-    ollamaLink: '如何集成Ollama',
+    ollamaLink: '如何集成 {{name}}',
+    volcModelNameMessage: '請輸入模型名稱!格式:{"模型名稱":"EndpointID"}',
+    addVolcEngineAK: '火山 ACCESS_KEY',
+    volcAKMessage: '請輸入VOLC_ACCESS_KEY',
+    addVolcEngineSK: '火山 SECRET_KEY',
+    volcSKMessage: '請輸入VOLC_SECRET_KEY',
   },
   message: {
     registered: '註冊成功',
web/src/locales/zh.ts CHANGED
@@ -458,6 +458,11 @@ export default {
     modelTypeMessage: '请输入模型类型!',
     baseUrlNameMessage: '请输入基础 Url!',
     ollamaLink: '如何集成 {{name}}',
+    volcModelNameMessage: '请输入模型名称!格式:{"模型名称":"EndpointID"}',
+    addVolcEngineAK: '火山 ACCESS_KEY',
+    volcAKMessage: '请输入VOLC_ACCESS_KEY',
+    addVolcEngineSK: '火山 SECRET_KEY',
+    volcSKMessage: '请输入VOLC_SECRET_KEY',
   },
   message: {
     registered: '注册成功',
web/src/pages/user-setting/setting-model/hooks.ts CHANGED
@@ -166,6 +166,41 @@ export const useSubmitOllama = () => {
   };
 };
 
+export const useSubmitVolcEngine = () => {
+  const loading = useOneNamespaceEffectsLoading('settingModel', ['add_llm']);
+  const [selectedVolcFactory, setSelectedVolcFactory] = useState<string>('');
+  const addLlm = useAddLlm();
+  const {
+    visible: volcAddingVisible,
+    hideModal: hideVolcAddingModal,
+    showModal: showVolcAddingModal,
+  } = useSetModalState();
+
+  const onVolcAddingOk = useCallback(
+    async (payload: IAddLlmRequestBody) => {
+      const ret = await addLlm(payload);
+      if (ret === 0) {
+        hideVolcAddingModal();
+      }
+    },
+    [hideVolcAddingModal, addLlm],
+  );
+
+  const handleShowVolcAddingModal = (llmFactory: string) => {
+    setSelectedVolcFactory(llmFactory);
+    showVolcAddingModal();
+  };
+
+  return {
+    volcAddingLoading: loading,
+    onVolcAddingOk,
+    volcAddingVisible,
+    hideVolcAddingModal,
+    showVolcAddingModal: handleShowVolcAddingModal,
+    selectedVolcFactory,
+  };
+};
+
 export const useHandleDeleteLlm = (llmFactory: string) => {
   const deleteLlm = useDeleteLlm();
   const showDeleteConfirm = useShowDeleteConfirm();
web/src/pages/user-setting/setting-model/index.tsx CHANGED
@@ -37,10 +37,12 @@ import {
   useSelectModelProvidersLoading,
   useSubmitApiKey,
   useSubmitOllama,
+  useSubmitVolcEngine,
   useSubmitSystemModelSetting,
 } from './hooks';
 import styles from './index.less';
 import OllamaModal from './ollama-modal';
+import VolcEngineModal from "./volcengine-model";
 import SystemModelSettingModal from './system-model-setting-modal';
 
 const IconMap = {
@@ -52,6 +54,7 @@ const IconMap = {
   Ollama: 'ollama',
   Xinference: 'xinference',
   DeepSeek: 'deepseek',
+  VolcEngine: 'volc_engine',
 };
 
 const LlmIcon = ({ name }: { name: string }) => {
@@ -165,6 +168,15 @@ const UserSettingModel = () => {
     selectedLlmFactory,
   } = useSubmitOllama();
 
+  const {
+    volcAddingVisible,
+    hideVolcAddingModal,
+    showVolcAddingModal,
+    onVolcAddingOk,
+    volcAddingLoading,
+    selectedVolcFactory,
+  } = useSubmitVolcEngine();
+
   const handleApiKeyClick = useCallback(
     (llmFactory: string) => {
       if (isLocalLlmFactory(llmFactory)) {
@@ -179,6 +191,8 @@ const UserSettingModel = () => {
   const handleAddModel = (llmFactory: string) => () => {
     if (isLocalLlmFactory(llmFactory)) {
       showLlmAddingModal(llmFactory);
+    } else if (llmFactory === 'VolcEngine') {
+      showVolcAddingModal('VolcEngine');
     } else {
       handleApiKeyClick(llmFactory);
     }
@@ -270,6 +284,13 @@ const UserSettingModel = () => {
         loading={llmAddingLoading}
         llmFactory={selectedLlmFactory}
       ></OllamaModal>
+      <VolcEngineModal
+        visible={volcAddingVisible}
+        hideModal={hideVolcAddingModal}
+        onOk={onVolcAddingOk}
+        loading={volcAddingLoading}
+        llmFactory={selectedVolcFactory}
+      ></VolcEngineModal>
     </section>
   );
 };
web/src/pages/user-setting/setting-model/volcengine-model/index.tsx ADDED
import { useTranslate } from '@/hooks/commonHooks';
import { IModalProps } from '@/interfaces/common';
import { IAddLlmRequestBody } from '@/interfaces/request/llm';
import { Flex, Form, Input, Modal, Select, Space, Switch } from 'antd';
import omit from 'lodash/omit';

type FieldType = IAddLlmRequestBody & { vision: boolean };

const { Option } = Select;

const VolcEngineModal = ({
  visible,
  hideModal,
  onOk,
  loading,
  llmFactory,
}: IModalProps<IAddLlmRequestBody> & { llmFactory: string }) => {
  const [form] = Form.useForm<FieldType>();

  const { t } = useTranslate('setting');

  const handleOk = async () => {
    const values = await form.validateFields();
    const modelType =
      values.model_type === 'chat' && values.vision
        ? 'image2text'
        : values.model_type;

    const data = {
      ...omit(values, ['vision']),
      model_type: modelType,
      llm_factory: llmFactory,
    };
    console.info(data);

    onOk?.(data);
  };

  return (
    <Modal
      title={t('addLlmTitle', { name: llmFactory })}
      open={visible}
      onOk={handleOk}
      onCancel={hideModal}
      okButtonProps={{ loading }}
      footer={(originNode: React.ReactNode) => {
        return (
          <Flex justify={'space-between'}>
            <a
              href="https://www.volcengine.com/docs/82379/1095322"
              target="_blank"
              rel="noreferrer"
            >
              {t('ollamaLink', { name: llmFactory })}
            </a>
            <Space>{originNode}</Space>
          </Flex>
        );
      }}
    >
      <Form
        name="basic"
        style={{ maxWidth: 600 }}
        autoComplete="off"
        layout={'vertical'}
        form={form}
      >
        <Form.Item<FieldType>
          label={t('modelType')}
          name="model_type"
          initialValue={'chat'}
          rules={[{ required: true, message: t('modelTypeMessage') }]}
        >
          <Select placeholder={t('modelTypeMessage')}>
            <Option value="chat">chat</Option>
            <Option value="embedding">embedding</Option>
          </Select>
        </Form.Item>
        <Form.Item<FieldType>
          label={t('modelName')}
          name="llm_name"
          rules={[{ required: true, message: t('volcModelNameMessage') }]}
        >
          <Input placeholder={t('volcModelNameMessage')} />
        </Form.Item>
        <Form.Item<FieldType>
          label={t('addVolcEngineAK')}
          name="volc_ak"
          rules={[{ required: true, message: t('volcAKMessage') }]}
        >
          <Input placeholder={t('volcAKMessage')} />
        </Form.Item>
        <Form.Item<FieldType>
          label={t('addVolcEngineSK')}
          name="volc_sk"
          rules={[{ required: true, message: t('volcSKMessage') }]}
        >
          <Input placeholder={t('volcSKMessage')} />
        </Form.Item>
        <Form.Item noStyle dependencies={['model_type']}>
          {({ getFieldValue }) =>
            getFieldValue('model_type') === 'chat' && (
              <Form.Item
                label={t('vision')}
                valuePropName="checked"
                name={'vision'}
              >
                <Switch />
              </Form.Item>
            )
          }
        </Form.Item>
      </Form>
    </Modal>
  );
};

export default VolcEngineModal;