黄腾 aopstudio committed on
Commit 1f5bc27 · 1 Parent(s): 1164cba

add support for Gemini (#1465)


### What problem does this PR solve?

#1036

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

Co-authored-by: Zhedong Cen <[email protected]>

api/db/init_data.py CHANGED
@@ -175,6 +175,11 @@ factory_infos = [{
     "logo": "",
     "tags": "LLM,TEXT EMBEDDING",
     "status": "1",
+},{
+    "name": "Gemini",
+    "logo": "",
+    "tags": "LLM,TEXT EMBEDDING,IMAGE2TEXT",
+    "status": "1",
 }
 # {
 #     "name": "文心一言",
@@ -898,7 +903,37 @@ def init_llm_factory():
     "tags": "TEXT EMBEDDING",
     "max_tokens": 2048,
     "model_type": LLMType.EMBEDDING.value
-},
+}, {
+    "fid": factory_infos[17]["name"],
+    "llm_name": "gemini-1.5-pro-latest",
+    "tags": "LLM,CHAT,1024K",
+    "max_tokens": 1024*1024,
+    "model_type": LLMType.CHAT.value
+}, {
+    "fid": factory_infos[17]["name"],
+    "llm_name": "gemini-1.5-flash-latest",
+    "tags": "LLM,CHAT,1024K",
+    "max_tokens": 1024*1024,
+    "model_type": LLMType.CHAT.value
+}, {
+    "fid": factory_infos[17]["name"],
+    "llm_name": "gemini-1.0-pro",
+    "tags": "LLM,CHAT,30K",
+    "max_tokens": 30*1024,
+    "model_type": LLMType.CHAT.value
+}, {
+    "fid": factory_infos[17]["name"],
+    "llm_name": "gemini-1.0-pro-vision-latest",
+    "tags": "LLM,IMAGE2TEXT,12K",
+    "max_tokens": 12*1024,
+    "model_type": LLMType.IMAGE2TEXT.value
+}, {
+    "fid": factory_infos[17]["name"],
+    "llm_name": "text-embedding-004",
+    "tags": "TEXT EMBEDDING",
+    "max_tokens": 2048,
+    "model_type": LLMType.EMBEDDING.value
+}
 ]
 for info in factory_infos:
     try:
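
The registration above follows the file's existing pattern: the factory gets one entry in `factory_infos`, and each model row points back to it through `fid`; `factory_infos[17]` resolves to the `"Gemini"` entry just appended. A minimal standalone sketch of that pattern (the two tables below are illustrative stand-ins, not the real ones):

```python
# Sketch of the factory/model registration pattern used in init_data.py.
# These tables are illustrative stand-ins; in the real file the Gemini
# factory sits at index 17 of factory_infos.
factory_infos = [
    {"name": "Gemini", "tags": "LLM,TEXT EMBEDDING,IMAGE2TEXT", "status": "1"},
]

llm_infos = [{
    "fid": factory_infos[0]["name"],   # model row references its factory by name
    "llm_name": "gemini-1.5-pro-latest",
    "tags": "LLM,CHAT,1024K",          # the "1024K" tag advertises the context size,
    "max_tokens": 1024 * 1024,         # which matches 1024*1024 = 1048576 tokens
}]

for info in llm_infos:
    print(info["fid"], info["llm_name"], info["max_tokens"])
```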
rag/llm/chat_model.py CHANGED
@@ -621,3 +621,64 @@ class BedrockChat(Base):
             yield ans + f"ERROR: Can't invoke '{self.model_name}'. Reason: {e}"
 
         yield num_tokens_from_string(ans)
+
+class GeminiChat(Base):
+
+    def __init__(self, key, model_name, base_url=None):
+        from google.generativeai import client, GenerativeModel
+
+        client.configure(api_key=key)
+        _client = client.get_default_generative_client()
+        self.model_name = 'models/' + model_name
+        self.model = GenerativeModel(model_name=self.model_name)
+        self.model._client = _client
+
+    def chat(self, system, history, gen_conf):
+        if system:
+            history.insert(0, {"role": "user", "parts": system})
+        if 'max_tokens' in gen_conf:
+            gen_conf['max_output_tokens'] = gen_conf['max_tokens']
+        for k in list(gen_conf.keys()):
+            if k not in ["temperature", "top_p", "max_output_tokens"]:
+                del gen_conf[k]
+        for item in history:
+            if 'role' in item and item['role'] == 'assistant':
+                item['role'] = 'model'
+            if 'content' in item:
+                item['parts'] = item.pop('content')
+
+        try:
+            response = self.model.generate_content(
+                history,
+                generation_config=gen_conf)
+            ans = response.text
+            return ans, response.usage_metadata.total_token_count
+        except Exception as e:
+            return "**ERROR**: " + str(e), 0
+
+    def chat_streamly(self, system, history, gen_conf):
+        if system:
+            history.insert(0, {"role": "user", "parts": system})
+        if 'max_tokens' in gen_conf:
+            gen_conf['max_output_tokens'] = gen_conf['max_tokens']
+        for k in list(gen_conf.keys()):
+            if k not in ["temperature", "top_p", "max_output_tokens"]:
+                del gen_conf[k]
+        for item in history:
+            if 'role' in item and item['role'] == 'assistant':
+                item['role'] = 'model'
+            if 'content' in item:
+                item['parts'] = item.pop('content')
+        ans = ""
+        try:
+            response = self.model.generate_content(
+                history,
+                generation_config=gen_conf, stream=True)
+            for resp in response:
+                ans += resp.text
+                yield ans
+
+        except Exception as e:
+            yield ans + "\n**ERROR**: " + str(e)
+
+        yield response._chunks[-1].usage_metadata.total_token_count
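
For context, a minimal usage sketch of the new `GeminiChat` class; the API key is a placeholder, and `chat`/`chat_streamly` themselves rewrite the OpenAI-style history (`assistant` → `model`, `content` → `parts`) before calling `google.generativeai`:

```python
# Hypothetical usage of the GeminiChat class added above;
# "YOUR_API_KEY" is a placeholder, not a real credential.
from rag.llm.chat_model import GeminiChat

mdl = GeminiChat("YOUR_API_KEY", "gemini-1.5-flash-latest")

# chat() returns (answer, total_token_count) and maps max_tokens
# onto Gemini's max_output_tokens internally.
ans, tokens = mdl.chat(
    "You are a concise assistant.",
    [{"role": "user", "content": "Summarize RAG in one sentence."}],
    {"temperature": 0.2, "max_tokens": 256},
)
print(tokens, ans)

# chat_streamly() yields the accumulated answer as it grows,
# then the total token count as the final item.
for chunk in mdl.chat_streamly(
        "You are a concise assistant.",
        [{"role": "user", "content": "Summarize RAG in one sentence."}],
        {"temperature": 0.2}):
    print(chunk)
```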
rag/llm/cv_model.py CHANGED
@@ -203,6 +203,29 @@ class XinferenceCV(Base):
         )
         return res.choices[0].message.content.strip(), res.usage.total_tokens
 
+class GeminiCV(Base):
+    def __init__(self, key, model_name="gemini-1.0-pro-vision-latest", lang="Chinese", **kwargs):
+        from google.generativeai import client, GenerativeModel
+        client.configure(api_key=key)
+        _client = client.get_default_generative_client()
+        self.model_name = model_name
+        self.model = GenerativeModel(model_name=self.model_name)
+        self.model._client = _client
+        self.lang = lang
+
+    def describe(self, image, max_tokens=2048):
+        from PIL.Image import open
+        gen_config = {'max_output_tokens': max_tokens}
+        prompt = "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。" if self.lang.lower() == "chinese" else \
+            "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out."
+        b64 = self.image2base64(image)
+        img = open(BytesIO(base64.b64decode(b64)))
+        input = [prompt, img]
+        res = self.model.generate_content(
+            input,
+            generation_config=gen_config,
+        )
+        return res.text, res.usage_metadata.total_token_count
 
 class LocalCV(Base):
     def __init__(self, key, model_name="glm-4v", lang="Chinese", **kwargs):
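
`GeminiCV.describe` builds a `[prompt, image]` pair and caps the output with `max_output_tokens`; here is a minimal sketch of the same multimodal call made directly against `google-generativeai` (the key and image path are placeholders):

```python
# Sketch of the raw multimodal call that GeminiCV.describe wraps.
# "YOUR_API_KEY" and "photo.jpg" are placeholders.
import google.generativeai as genai
from PIL import Image

genai.configure(api_key="YOUR_API_KEY")
model = genai.GenerativeModel(model_name="gemini-1.0-pro-vision-latest")

prompt = ("Please describe the content of this picture, like where, when, "
          "who, what happen. If it has number data, please extract them out.")
img = Image.open("photo.jpg")

# Same shape as describe(): a [prompt, image] list plus an output-token cap.
res = model.generate_content([prompt, img],
                             generation_config={"max_output_tokens": 2048})
print(res.text, res.usage_metadata.total_token_count)
```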
rag/llm/embedding_model.py CHANGED
@@ -31,7 +31,7 @@ import numpy as np
 import asyncio
 from api.utils.file_utils import get_home_cache_dir
 from rag.utils import num_tokens_from_string, truncate
-
+import google.generativeai as genai
 
 class Base(ABC):
     def __init__(self, key, model_name):
@@ -419,3 +419,27 @@ class BedrockEmbed(Base):
 
         return np.array(embeddings), token_count
 
+class GeminiEmbed(Base):
+    def __init__(self, key, model_name='models/text-embedding-004',
+                 **kwargs):
+        genai.configure(api_key=key)
+        self.model_name = 'models/' + model_name
+
+    def encode(self, texts: list, batch_size=32):
+        texts = [truncate(t, 2048) for t in texts]
+        token_count = sum(num_tokens_from_string(text) for text in texts)
+        result = genai.embed_content(
+            model=self.model_name,
+            content=texts,
+            task_type="retrieval_document",
+            title="Embedding of list of strings")
+        return np.array(result['embedding']), token_count
+
+    def encode_queries(self, text):
+        result = genai.embed_content(
+            model=self.model_name,
+            content=truncate(text, 2048),
+            task_type="retrieval_document",
+            title="Embedding of single string")
+        token_count = num_tokens_from_string(text)
+        return np.array(result['embedding']), token_count
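
`GeminiEmbed` is a thin wrapper over `genai.embed_content`, which accepts a single string or a list of strings and returns a dict whose `'embedding'` key holds the vector(s). A minimal standalone sketch (placeholder key):

```python
# Sketch of the genai.embed_content call that GeminiEmbed wraps.
# "YOUR_API_KEY" is a placeholder.
import numpy as np
import google.generativeai as genai

genai.configure(api_key="YOUR_API_KEY")

# Batch path, mirroring GeminiEmbed.encode: documents in, a matrix out.
result = genai.embed_content(
    model="models/text-embedding-004",
    content=["first passage", "second passage"],
    task_type="retrieval_document",
    title="Embedding of list of strings",
)
vectors = np.array(result["embedding"])
print(vectors.shape)  # (2, 768): text-embedding-004 produces 768-dim vectors
```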
requirements.txt CHANGED
@@ -147,3 +147,4 @@ markdown==3.6
 mistralai==0.4.2
 boto3==1.34.140
 duckduckgo_search==6.1.9
+google-generativeai==0.7.2
requirements_arm.txt CHANGED
@@ -148,3 +148,4 @@ markdown==3.6
 mistralai==0.4.2
 boto3==1.34.140
 duckduckgo_search==6.1.9
+google-generativeai==0.7.2
requirements_dev.txt CHANGED
@@ -133,3 +133,4 @@ markdown==3.6
 mistralai==0.4.2
 boto3==1.34.140
 duckduckgo_search==6.1.9
+google-generativeai==0.7.2
web/src/assets/svg/llm/gemini.svg ADDED
web/src/pages/user-setting/setting-model/index.tsx CHANGED
@@ -61,6 +61,7 @@ const IconMap = {
   Mistral: 'mistral',
   'Azure-OpenAI': 'azure',
   Bedrock: 'bedrock',
+  Gemini: 'gemini',
 };
 
 const LlmIcon = ({ name }: { name: string }) => {