黄腾 aopstudio committed on
Commit
5e7d900
·
1 Parent(s): 10534c3

add support for LM Studio (#1663)

Browse files

### What problem does this PR solve?

#1602

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

---------

Co-authored-by: Zhedong Cen <[email protected]>

api/apps/llm_app.py CHANGED
@@ -21,7 +21,7 @@ from api.db import StatusEnum, LLMType
21
  from api.db.db_models import TenantLLM
22
  from api.utils.api_utils import get_json_result
23
  from rag.llm import EmbeddingModel, ChatModel, RerankModel,CvModel
24
-
25
 
26
  @manager.route('/factories', methods=['GET'])
27
  @login_required
@@ -189,9 +189,13 @@ def add_llm():
189
  "ons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/256"
190
  "0px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
191
  )
192
- m, tc = mdl.describe(img_url)
193
- if not tc:
194
- raise Exception(m)
 
 
 
 
195
  except Exception as e:
196
  msg += f"\nFail to access model({llm['llm_name']})." + str(e)
197
  else:
 
21
  from api.db.db_models import TenantLLM
22
  from api.utils.api_utils import get_json_result
23
  from rag.llm import EmbeddingModel, ChatModel, RerankModel,CvModel
24
+ import requests
25
 
26
  @manager.route('/factories', methods=['GET'])
27
  @login_required
 
189
  "ons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/256"
190
  "0px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
191
  )
192
+ res = requests.get(img_url)
193
+ if res.status_code == 200:
194
+ m, tc = mdl.describe(res.content)
195
+ if not tc:
196
+ raise Exception(m)
197
+ else:
198
+ raise ConnectionError("fail to download the test picture")
199
  except Exception as e:
200
  msg += f"\nFail to access model({llm['llm_name']})." + str(e)
201
  else:
conf/llm_factories.json CHANGED
@@ -2208,6 +2208,13 @@
2208
  "model_type": "image2text"
2209
  }
2210
  ]
 
 
 
 
 
 
 
2211
  }
2212
  ]
2213
  }
 
2208
  "model_type": "image2text"
2209
  }
2210
  ]
2211
+ },
2212
+ {
2213
+ "name": "LM-Studio",
2214
+ "logo": "",
2215
+ "tags": "LLM,TEXT EMBEDDING,IMAGE2TEXT",
2216
+ "status": "1",
2217
+ "llm": []
2218
  }
2219
  ]
2220
  }
rag/llm/__init__.py CHANGED
@@ -34,8 +34,9 @@ EmbeddingModel = {
34
  "BAAI": DefaultEmbedding,
35
  "Mistral": MistralEmbed,
36
  "Bedrock": BedrockEmbed,
37
- "Gemini":GeminiEmbed,
38
- "NVIDIA":NvidiaEmbed
 
39
  }
40
 
41
 
@@ -47,10 +48,11 @@ CvModel = {
47
  "Tongyi-Qianwen": QWenCV,
48
  "ZHIPU-AI": Zhipu4V,
49
  "Moonshot": LocalCV,
50
- 'Gemini':GeminiCV,
51
- 'OpenRouter':OpenRouterCV,
52
- "LocalAI":LocalAICV,
53
- "NVIDIA":NvidiaCV
 
54
  }
55
 
56
 
@@ -69,12 +71,13 @@ ChatModel = {
69
  "MiniMax": MiniMaxChat,
70
  "Minimax": MiniMaxChat,
71
  "Mistral": MistralChat,
72
- 'Gemini' : GeminiChat,
73
  "Bedrock": BedrockChat,
74
  "Groq": GroqChat,
75
- 'OpenRouter':OpenRouterChat,
76
- "StepFun":StepFunChat,
77
- "NVIDIA":NvidiaChat
 
78
  }
79
 
80
 
@@ -83,7 +86,8 @@ RerankModel = {
83
  "Jina": JinaRerank,
84
  "Youdao": YoudaoRerank,
85
  "Xinference": XInferenceRerank,
86
- "NVIDIA":NvidiaRerank
 
87
  }
88
 
89
 
 
34
  "BAAI": DefaultEmbedding,
35
  "Mistral": MistralEmbed,
36
  "Bedrock": BedrockEmbed,
37
+ "Gemini": GeminiEmbed,
38
+ "NVIDIA": NvidiaEmbed,
39
+ "LM-Studio": LmStudioEmbed
40
  }
41
 
42
 
 
48
  "Tongyi-Qianwen": QWenCV,
49
  "ZHIPU-AI": Zhipu4V,
50
  "Moonshot": LocalCV,
51
+ "Gemini": GeminiCV,
52
+ "OpenRouter": OpenRouterCV,
53
+ "LocalAI": LocalAICV,
54
+ "NVIDIA": NvidiaCV,
55
+ "LM-Studio": LmStudioCV
56
  }
57
 
58
 
 
71
  "MiniMax": MiniMaxChat,
72
  "Minimax": MiniMaxChat,
73
  "Mistral": MistralChat,
74
+ "Gemini": GeminiChat,
75
  "Bedrock": BedrockChat,
76
  "Groq": GroqChat,
77
+ "OpenRouter": OpenRouterChat,
78
+ "StepFun": StepFunChat,
79
+ "NVIDIA": NvidiaChat,
80
+ "LM-Studio": LmStudioChat
81
  }
82
 
83
 
 
86
  "Jina": JinaRerank,
87
  "Youdao": YoudaoRerank,
88
  "Xinference": XInferenceRerank,
89
+ "NVIDIA": NvidiaRerank,
90
+ "LM-Studio": LmStudioRerank
91
  }
92
 
93
 
rag/llm/chat_model.py CHANGED
@@ -976,3 +976,15 @@ class NvidiaChat(Base):
976
  yield ans + "\n**ERROR**: " + str(e)
977
 
978
  yield total_tokens
 
 
 
 
 
 
 
 
 
 
 
 
 
976
  yield ans + "\n**ERROR**: " + str(e)
977
 
978
  yield total_tokens
979
+
980
+
981
class LmStudioChat(Base):
    """Chat model served by a local LM Studio instance via its
    OpenAI-compatible HTTP API.

    LM Studio ignores the API key, so a fixed placeholder ("lm-studio")
    is sent.  Chat behavior itself is inherited from ``Base``.
    """

    def __init__(self, key, model_name, base_url):
        if not base_url:
            raise ValueError("Local llm url cannot be None")
        # Ensure the endpoint targets the OpenAI-compatible "/v1" API.
        # NOTE: the URL must be built with plain string joining, not
        # os.path.join, which would emit backslashes on Windows.  Also
        # assign self.base_url unconditionally — the original only set it
        # when the suffix was missing, leaving it undefined otherwise.
        if base_url.split("/")[-1] != "v1":
            base_url = base_url.rstrip("/") + "/v1"
        self.base_url = base_url
        self.client = OpenAI(api_key="lm-studio", base_url=self.base_url)
        self.model_name = model_name
rag/llm/cv_model.py CHANGED
@@ -440,15 +440,8 @@ class LocalAICV(Base):
440
  self.lang = lang
441
 
442
  def describe(self, image, max_tokens=300):
443
- if not isinstance(image, bytes) and not isinstance(
444
- image, BytesIO
445
- ): # if url string
446
- prompt = self.prompt(image)
447
- for i in range(len(prompt)):
448
- prompt[i]["content"]["image_url"]["url"] = image
449
- else:
450
- b64 = self.image2base64(image)
451
- prompt = self.prompt(b64)
452
  for i in range(len(prompt)):
453
  for c in prompt[i]["content"]:
454
  if "text" in c:
@@ -680,3 +673,14 @@ class NvidiaCV(Base):
680
  "content": text + f' <img src="data:image/jpeg;base64,{b64}"/>',
681
  }
682
  ]
 
 
 
 
 
 
 
 
 
 
 
 
440
  self.lang = lang
441
 
442
  def describe(self, image, max_tokens=300):
443
+ b64 = self.image2base64(image)
444
+ prompt = self.prompt(b64)
 
 
 
 
 
 
 
445
  for i in range(len(prompt)):
446
  for c in prompt[i]["content"]:
447
  if "text" in c:
 
673
  "content": text + f' <img src="data:image/jpeg;base64,{b64}"/>',
674
  }
675
  ]
676
+
677
+
678
class LmStudioCV(LocalAICV):
    """Image-to-text model served by a local LM Studio instance through
    its OpenAI-compatible HTTP API.

    Inherits ``describe``/prompt handling from ``LocalAICV``; only the
    client construction differs (fixed placeholder API key "lm-studio").
    """

    def __init__(self, key, model_name, base_url, lang="Chinese"):
        if not base_url:
            raise ValueError("Local llm url cannot be None")
        # Ensure the endpoint targets the OpenAI-compatible "/v1" API.
        # Build the URL with string joining (os.path.join would produce
        # "\\" separators on Windows) and assign self.base_url on every
        # path — the original left it undefined when base_url already
        # ended in "v1", crashing the OpenAI(...) call below.
        if base_url.split("/")[-1] != "v1":
            base_url = base_url.rstrip("/") + "/v1"
        self.base_url = base_url
        self.client = OpenAI(api_key="lm-studio", base_url=self.base_url)
        self.model_name = model_name
        self.lang = lang
rag/llm/embedding_model.py CHANGED
@@ -500,3 +500,24 @@ class NvidiaEmbed(Base):
500
  def encode_queries(self, text):
501
  embds, cnt = self.encode([text])
502
  return np.array(embds[0]), cnt
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
500
  def encode_queries(self, text):
501
  embds, cnt = self.encode([text])
502
  return np.array(embds[0]), cnt
503
+
504
+
505
class LmStudioEmbed(Base):
    """Embedding model served by a local LM Studio instance via its
    OpenAI-compatible HTTP API.

    LM Studio does not report token usage, so both encode methods return
    a fixed token count of 1024 to keep callers' accounting working.
    """

    def __init__(self, key, model_name, base_url):
        if not base_url:
            raise ValueError("Local llm url cannot be None")
        # Ensure the endpoint targets the OpenAI-compatible "/v1" API.
        # Use string joining (os.path.join yields "\\" on Windows) and
        # assign self.base_url unconditionally — the original left it
        # undefined when base_url already ended in "v1".
        if base_url.split("/")[-1] != "v1":
            base_url = base_url.rstrip("/") + "/v1"
        self.base_url = base_url
        self.client = OpenAI(api_key="lm-studio", base_url=self.base_url)
        self.model_name = model_name

    def encode(self, texts: list, batch_size=32):
        """Embed a batch of texts; returns (ndarray of embeddings, token count)."""
        res = self.client.embeddings.create(input=texts, model=self.model_name)
        # Local embedding for LM Studio does not count tokens.
        return np.array([d.embedding for d in res.data]), 1024

    def encode_queries(self, text):
        """Embed a single query string; returns (1-D ndarray, token count)."""
        # The OpenAI SDK's embeddings.create takes keyword-only arguments;
        # the original passed `text` positionally, which raises TypeError.
        res = self.client.embeddings.create(input=[text], model=self.model_name)
        return np.array(res.data[0].embedding), 1024
rag/llm/rerank_model.py CHANGED
@@ -202,3 +202,11 @@ class NvidiaRerank(Base):
202
  }
203
  res = requests.post(self.base_url, headers=self.headers, json=data).json()
204
  return (np.array([d["logit"] for d in res["rankings"]]), token_count)
 
 
 
 
 
 
 
 
 
202
  }
203
  res = requests.post(self.base_url, headers=self.headers, json=data).json()
204
  return (np.array([d["logit"] for d in res["rankings"]]), token_count)
205
+
206
+
207
class LmStudioRerank(Base):
    """Placeholder rerank model for LM Studio.

    LM Studio does not currently expose a rerank endpoint, so this class
    exists only to satisfy the factory registry; calling it raises.
    """

    def __init__(self, key, model_name, base_url):
        pass

    def similarity(self, query: str, texts: list):
        # Fixed grammar in the error message ("implement" -> "implemented").
        raise NotImplementedError("The LmStudioRerank has not been implemented")
web/src/assets/svg/llm/lm-studio.svg ADDED
web/src/pages/user-setting/constants.tsx CHANGED
@@ -17,4 +17,4 @@ export const UserSettingIconMap = {
17
 
18
  export * from '@/constants/setting';
19
 
20
- export const LocalLlmFactories = ['Ollama', 'Xinference','LocalAI'];
 
17
 
18
  export * from '@/constants/setting';
19
 
20
+ export const LocalLlmFactories = ['Ollama', 'Xinference','LocalAI','LM-Studio'];
web/src/pages/user-setting/setting-model/constant.ts CHANGED
@@ -20,7 +20,8 @@ export const IconMap = {
20
  OpenRouter: 'open-router',
21
  LocalAI: 'local-ai',
22
  StepFun: 'stepfun',
23
- NVIDIA:'nvidia'
 
24
  };
25
 
26
  export const BedrockRegionList = [
 
20
  OpenRouter: 'open-router',
21
  LocalAI: 'local-ai',
22
  StepFun: 'stepfun',
23
+ NVIDIA:'nvidia',
24
+ 'LM-Studio':'lm-studio'
25
  };
26
 
27
  export const BedrockRegionList = [