黄腾 aopstudio committed on
Commit 5bd5c21 · 1 Parent(s): cab96b4

add support for LocalAI (#1608)


### What problem does this PR solve?

#762

### Type of change
- [x] New Feature (non-breaking change which adds functionality)

---------

Co-authored-by: Zhedong Cen <[email protected]>

api/apps/llm_app.py CHANGED

```diff
@@ -20,7 +20,7 @@ from api.utils.api_utils import server_error_response, get_data_error_result, va
 from api.db import StatusEnum, LLMType
 from api.db.db_models import TenantLLM
 from api.utils.api_utils import get_json_result
-from rag.llm import EmbeddingModel, ChatModel, RerankModel
+from rag.llm import EmbeddingModel, ChatModel, RerankModel,CvModel
 
 
 @manager.route('/factories', methods=['GET'])
@@ -126,6 +126,9 @@ def add_llm():
         api_key = '{' + f'"bedrock_ak": "{req.get("bedrock_ak", "")}", ' \
             f'"bedrock_sk": "{req.get("bedrock_sk", "")}", ' \
             f'"bedrock_region": "{req.get("bedrock_region", "")}", ' + '}'
+    elif factory == "LocalAI":
+        llm_name = req["llm_name"]+"___LocalAI"
+        api_key = "xxxxxxxxxxxxxxx"
     else:
         llm_name = req["llm_name"]
         api_key = "xxxxxxxxxxxxxxx"
@@ -176,6 +179,21 @@ def add_llm():
         except Exception as e:
             msg += f"\nFail to access model({llm['llm_name']})." + str(
                 e)
+    elif llm["model_type"] == LLMType.IMAGE2TEXT.value:
+        mdl = CvModel[factory](
+            key=None, model_name=llm["llm_name"], base_url=llm["api_base"]
+        )
+        try:
+            img_url = (
+                "https://upload.wikimedia.org/wikipedia/comm"
+                "ons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/256"
+                "0px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
+            )
+            m, tc = mdl.describe(img_url)
+            if not tc:
+                raise Exception(m)
+        except Exception as e:
+            msg += f"\nFail to access model({llm['llm_name']})." + str(e)
     else:
         # TODO: check other type of models
         pass
```
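A minimal sketch of exercising the new `IMAGE2TEXT` health-check path outside the API layer. The server address and model name below are assumptions for illustration; the image URL is the one used in the diff.

```python
# Sketch only: assumes a LocalAI server at http://localhost:8080/v1 and a
# vision model added as "llava___LocalAI" (hypothetical name; add_llm()
# appends "___LocalAI" to whatever the user typed).
from rag.llm import CvModel

mdl = CvModel["LocalAI"](
    key=None,                        # LocalAI needs no real API key
    model_name="llava___LocalAI",
    base_url="http://localhost:8080/v1",
)
img_url = (
    "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/"
    "Gfp-wisconsin-madison-the-nature-boardwalk.jpg/"
    "2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
)
m, tc = mdl.describe(img_url)
if not tc:  # zero token usage is treated as a failed connectivity check
    raise Exception(m)
```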
conf/llm_factories.json CHANGED

```diff
@@ -157,6 +157,13 @@
         "status": "1",
         "llm": []
     },
+    {
+        "name": "LocalAI",
+        "logo": "",
+        "tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
+        "status": "1",
+        "llm": []
+    },
     {
         "name": "Moonshot",
         "logo": "",
```
rag/llm/__init__.py CHANGED

```diff
@@ -21,6 +21,7 @@ from .rerank_model import *
 
 EmbeddingModel = {
     "Ollama": OllamaEmbed,
+    "LocalAI": LocalAIEmbed,
     "OpenAI": OpenAIEmbed,
     "Azure-OpenAI": AzureEmbed,
     "Xinference": XinferenceEmbed,
@@ -46,7 +47,8 @@ CvModel = {
     "ZHIPU-AI": Zhipu4V,
     "Moonshot": LocalCV,
     'Gemini':GeminiCV,
-    'OpenRouter':OpenRouterCV
+    'OpenRouter':OpenRouterCV,
+    "LocalAI":LocalAICV
 }
 
 
@@ -56,6 +58,7 @@ ChatModel = {
     "ZHIPU-AI": ZhipuChat,
     "Tongyi-Qianwen": QWenChat,
     "Ollama": OllamaChat,
+    "LocalAI": LocalAIChat,
     "Xinference": XinferenceChat,
     "Moonshot": MoonshotChat,
     "DeepSeek": DeepSeekChat,
@@ -67,7 +70,7 @@ ChatModel = {
     'Gemini' : GeminiChat,
     "Bedrock": BedrockChat,
     "Groq": GroqChat,
-    'OpenRouter':OpenRouterChat
+    'OpenRouter':OpenRouterChat,
 }
```
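These registries are plain dicts keyed by factory name, so each capability a provider supports costs one entry. A sketch of the lookup pattern the API layer relies on (base URLs and model names below are assumptions):

```python
from rag.llm import ChatModel, EmbeddingModel

# Factory name -> implementation class; construction is uniform across providers.
chat = ChatModel["LocalAI"](
    key=None,                          # LocalAI ignores the key
    model_name="mistral___LocalAI",    # hypothetical; the class strips the suffix
    base_url="http://localhost:8080",  # LocalAIChat appends "/v1/chat/completions"
)
embed = EmbeddingModel["LocalAI"](
    key=None,
    model_name="bert-embeddings___LocalAI",
    base_url="http://localhost:8080/v1",  # LocalAIEmbed appends only "/embeddings"
)
```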
rag/llm/chat_model.py CHANGED

```diff
@@ -348,6 +348,82 @@ class OllamaChat(Base):
             yield 0
 
 
+class LocalAIChat(Base):
+    def __init__(self, key, model_name, base_url):
+        if base_url[-1] == "/":
+            base_url = base_url[:-1]
+        self.base_url = base_url + "/v1/chat/completions"
+        self.model_name = model_name.split("___")[0]
+
+    def chat(self, system, history, gen_conf):
+        if system:
+            history.insert(0, {"role": "system", "content": system})
+        for k in list(gen_conf.keys()):
+            if k not in ["temperature", "top_p", "max_tokens"]:
+                del gen_conf[k]
+        headers = {
+            "Content-Type": "application/json",
+        }
+        payload = json.dumps(
+            {"model": self.model_name, "messages": history, **gen_conf}
+        )
+        try:
+            response = requests.request(
+                "POST", url=self.base_url, headers=headers, data=payload
+            )
+            response = response.json()
+            ans = response["choices"][0]["message"]["content"].strip()
+            if response["choices"][0]["finish_reason"] == "length":
+                ans += (
+                    "...\nFor the content length reason, it stopped, continue?"
+                    if is_english([ans])
+                    else "······\n由于长度的原因,回答被截断了,要继续吗?"
+                )
+            return ans, response["usage"]["total_tokens"]
+        except Exception as e:
+            return "**ERROR**: " + str(e), 0
+
+    def chat_streamly(self, system, history, gen_conf):
+        if system:
+            history.insert(0, {"role": "system", "content": system})
+        ans = ""
+        total_tokens = 0
+        try:
+            headers = {
+                "Content-Type": "application/json",
+            }
+            payload = json.dumps(
+                {
+                    "model": self.model_name,
+                    "messages": history,
+                    "stream": True,
+                    **gen_conf,
+                }
+            )
+            response = requests.request(
+                "POST",
+                url=self.base_url,
+                headers=headers,
+                data=payload,
+            )
+            for resp in response.content.decode("utf-8").split("\n\n"):
+                if "choices" not in resp:
+                    continue
+                resp = json.loads(resp[6:])
+                if "delta" in resp["choices"][0]:
+                    text = resp["choices"][0]["delta"]["content"]
+                else:
+                    continue
+                ans += text
+                total_tokens += 1
+                yield ans
+
+        except Exception as e:
+            yield ans + "\n**ERROR**: " + str(e)
+
+        yield total_tokens
+
+
 class LocalLLM(Base):
     class RPCProxy:
         def __init__(self, host, port):
```
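A sketch of driving `LocalAIChat` directly, assuming a LocalAI server on localhost:8080 (hypothetical address; the class itself appends `/v1/chat/completions`):

```python
from rag.llm.chat_model import LocalAIChat

mdl = LocalAIChat(key=None, model_name="mistral___LocalAI",
                  base_url="http://localhost:8080")

ans, tokens = mdl.chat(
    system="You are a helpful assistant.",
    history=[{"role": "user", "content": "Say hello in one word."}],
    gen_conf={"temperature": 0.2, "max_tokens": 32},  # unsupported keys are dropped
)

# chat_streamly() yields the accumulated answer string after each SSE chunk,
# then yields an integer as its final value. Note the counter is incremented
# once per chunk, so it approximates, rather than measures, token usage.
for chunk in mdl.chat_streamly(
    "You are a helpful assistant.",
    [{"role": "user", "content": "Count to three."}],
    {"temperature": 0.2},
):
    if isinstance(chunk, int):
        approx_tokens = chunk
    else:
        print(chunk)
```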
rag/llm/cv_model.py CHANGED

```diff
@@ -189,6 +189,35 @@ class OllamaCV(Base):
         return "**ERROR**: " + str(e), 0
 
 
+class LocalAICV(Base):
+    def __init__(self, key, model_name, base_url, lang="Chinese"):
+        self.client = OpenAI(api_key="empty", base_url=base_url)
+        self.model_name = model_name.split("___")[0]
+        self.lang = lang
+
+    def describe(self, image, max_tokens=300):
+        if not isinstance(image, bytes) and not isinstance(
+            image, BytesIO
+        ):  # if url string
+            prompt = self.prompt(image)
+            for i in range(len(prompt)):
+                prompt[i]["content"]["image_url"]["url"] = image
+        else:
+            b64 = self.image2base64(image)
+            prompt = self.prompt(b64)
+            for i in range(len(prompt)):
+                for c in prompt[i]["content"]:
+                    if "text" in c:
+                        c["type"] = "text"
+
+        res = self.client.chat.completions.create(
+            model=self.model_name,
+            messages=prompt,
+            max_tokens=max_tokens,
+        )
+        return res.choices[0].message.content.strip(), res.usage.total_tokens
+
+
 class XinferenceCV(Base):
     def __init__(self, key, model_name="", lang="Chinese", base_url=""):
         self.client = OpenAI(api_key="xxx", base_url=base_url)
```
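A sketch of the two input forms `LocalAICV.describe()` accepts. The server address and model name are assumptions; `"empty"` is the placeholder API key the class passes to the OpenAI client.

```python
from rag.llm.cv_model import LocalAICV

cv = LocalAICV(key=None, model_name="llava___LocalAI",
               base_url="http://localhost:8080/v1", lang="English")

# URL input: written straight into the prompt's image_url field.
desc, tokens = cv.describe(
    "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/"
    "Gfp-wisconsin-madison-the-nature-boardwalk.jpg/"
    "2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
)

# Bytes input: base64-encoded via the inherited image2base64() helper.
with open("photo.jpg", "rb") as f:  # hypothetical local file
    desc, tokens = cv.describe(f.read(), max_tokens=120)
```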
rag/llm/embedding_model.py CHANGED

```diff
@@ -111,6 +111,24 @@ class OpenAIEmbed(Base):
         return np.array(res.data[0].embedding), res.usage.total_tokens
 
 
+class LocalAIEmbed(Base):
+    def __init__(self, key, model_name, base_url):
+        self.base_url = base_url + "/embeddings"
+        self.headers = {
+            "Content-Type": "application/json",
+        }
+        self.model_name = model_name.split("___")[0]
+
+    def encode(self, texts: list, batch_size=None):
+        data = {"model": self.model_name, "input": texts, "encoding_type": "float"}
+        res = requests.post(self.base_url, headers=self.headers, json=data).json()
+
+        return np.array([d["embedding"] for d in res["data"]]), 1024
+
+    def encode_queries(self, text):
+        embds, cnt = self.encode([text])
+        return np.array(embds[0]), cnt
+
 class AzureEmbed(OpenAIEmbed):
     def __init__(self, key, model_name, **kwargs):
         self.client = AzureOpenAI(api_key=key, azure_endpoint=kwargs["base_url"], api_version="2024-02-01")
@@ -443,4 +461,4 @@ class GeminiEmbed(Base):
             task_type="retrieval_document",
             title="Embedding of single string")
         token_count = num_tokens_from_string(text)
-        return np.array(result['embedding']),token_count
\ No newline at end of file
+        return np.array(result['embedding']),token_count
```
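A usage sketch, assuming an embedding model served by LocalAI at an assumed local address. Note the class appends `/embeddings` itself, so the base URL must already include the `/v1` prefix.

```python
from rag.llm.embedding_model import LocalAIEmbed

emb = LocalAIEmbed(key=None, model_name="bert-embeddings___LocalAI",
                   base_url="http://localhost:8080/v1")

vecs, tc = emb.encode(["hello", "world"])  # vecs: one row per input text
qvec, _ = emb.encode_queries("hello")
# Caveat visible in the diff: encode() returns a constant 1024 as the token
# count instead of reading usage information from the response.
```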
rag/llm/rerank_model.py CHANGED

```diff
@@ -135,7 +135,7 @@ class YoudaoRerank(DefaultRerank):
             if isinstance(scores, float): res.append(scores)
             else: res.extend(scores)
         return np.array(res), token_count
-
+
 
 class XInferenceRerank(Base):
     def __init__(self, key="xxxxxxx", model_name="", base_url=""):
@@ -156,3 +156,11 @@ class XInferenceRerank(Base):
         }
         res = requests.post(self.base_url, headers=self.headers, json=data).json()
         return np.array([d["relevance_score"] for d in res["results"]]), res["meta"]["tokens"]["input_tokens"]+res["meta"]["tokens"]["output_tokens"]
+
+
+class LocalAIRerank(Base):
+    def __init__(self, key, model_name, base_url):
+        pass
+
+    def similarity(self, query: str, texts: list):
+        raise NotImplementedError("The LocalAIRerank has not been implement")
```
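The diff deliberately leaves `LocalAIRerank` as a stub. Purely as a hypothetical sketch, not part of this PR: if the target server exposed a Jina-style rerank endpoint, the body could mirror `XInferenceRerank` above.

```python
# Hypothetical sketch only — assumes the server offers a /rerank endpoint
# returning "results" entries with a "relevance_score" field, as
# XInferenceRerank expects. This PR does not confirm such an endpoint.
import numpy as np
import requests

class LocalAIRerankSketch:
    def __init__(self, key, model_name, base_url):
        self.base_url = base_url.rstrip("/") + "/rerank"  # assumed endpoint
        self.headers = {"Content-Type": "application/json"}
        self.model_name = model_name.split("___")[0]

    def similarity(self, query: str, texts: list):
        data = {"model": self.model_name, "query": query, "documents": texts}
        res = requests.post(self.base_url, headers=self.headers, json=data).json()
        # Token accounting is unknown for this assumed API, so report 0.
        return np.array([d["relevance_score"] for d in res["results"]]), 0
```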
web/src/pages/user-setting/constants.tsx CHANGED

```diff
@@ -17,4 +17,4 @@ export const UserSettingIconMap = {
 
 export * from '@/constants/setting';
 
-export const LocalLlmFactories = ['Ollama', 'Xinference'];
+export const LocalLlmFactories = ['Ollama', 'Xinference','LocalAI'];
```
web/src/pages/user-setting/setting-model/ollama-modal/index.tsx CHANGED

```diff
@@ -75,6 +75,7 @@ const OllamaModal = ({
           <Option value="chat">chat</Option>
           <Option value="embedding">embedding</Option>
           <Option value="rerank">rerank</Option>
+          <Option value="image2text">image2text</Option>
         </Select>
       </Form.Item>
       <Form.Item<FieldType>
```