add support for Gemini (#1465)
Browse files### What problem does this PR solve?
#1036
### Type of change
- [x] New Feature (non-breaking change which adds functionality)
Co-authored-by: Zhedong Cen <[email protected]>
- api/db/init_data.py +36 -1
- rag/llm/chat_model.py +61 -0
- rag/llm/cv_model.py +23 -0
- rag/llm/embedding_model.py +25 -1
- requirements.txt +1 -0
- requirements_arm.txt +1 -0
- requirements_dev.txt +1 -0
- web/src/assets/svg/llm/gemini.svg +114 -0
- web/src/pages/user-setting/setting-model/index.tsx +1 -0
api/db/init_data.py
CHANGED
@@ -175,6 +175,11 @@ factory_infos = [{
|
|
175 |
"logo": "",
|
176 |
"tags": "LLM,TEXT EMBEDDING",
|
177 |
"status": "1",
|
|
|
|
|
|
|
|
|
|
|
178 |
}
|
179 |
# {
|
180 |
# "name": "文心一言",
|
@@ -898,7 +903,37 @@ def init_llm_factory():
|
|
898 |
"tags": "TEXT EMBEDDING",
|
899 |
"max_tokens": 2048,
|
900 |
"model_type": LLMType.EMBEDDING.value
|
901 |
-
},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
902 |
]
|
903 |
for info in factory_infos:
|
904 |
try:
|
|
|
175 |
"logo": "",
|
176 |
"tags": "LLM,TEXT EMBEDDING",
|
177 |
"status": "1",
|
178 |
+
},{
|
179 |
+
"name": "Gemini",
|
180 |
+
"logo": "",
|
181 |
+
"tags": "LLM,TEXT EMBEDDING,IMAGE2TEXT",
|
182 |
+
"status": "1",
|
183 |
}
|
184 |
# {
|
185 |
# "name": "文心一言",
|
|
|
903 |
"tags": "TEXT EMBEDDING",
|
904 |
"max_tokens": 2048,
|
905 |
"model_type": LLMType.EMBEDDING.value
|
906 |
+
}, {
|
907 |
+
"fid": factory_infos[17]["name"],
|
908 |
+
"llm_name": "gemini-1.5-pro-latest",
|
909 |
+
"tags": "LLM,CHAT,1024K",
|
910 |
+
"max_tokens": 1024*1024,
|
911 |
+
"model_type": LLMType.CHAT.value
|
912 |
+
}, {
|
913 |
+
"fid": factory_infos[17]["name"],
|
914 |
+
"llm_name": "gemini-1.5-flash-latest",
|
915 |
+
"tags": "LLM,CHAT,1024K",
|
916 |
+
"max_tokens": 1024*1024,
|
917 |
+
"model_type": LLMType.CHAT.value
|
918 |
+
}, {
|
919 |
+
"fid": factory_infos[17]["name"],
|
920 |
+
"llm_name": "gemini-1.0-pro",
|
921 |
+
"tags": "LLM,CHAT,30K",
|
922 |
+
"max_tokens": 30*1024,
|
923 |
+
"model_type": LLMType.CHAT.value
|
924 |
+
}, {
|
925 |
+
"fid": factory_infos[17]["name"],
|
926 |
+
"llm_name": "gemini-1.0-pro-vision-latest",
|
927 |
+
"tags": "LLM,IMAGE2TEXT,12K",
|
928 |
+
"max_tokens": 12*1024,
|
929 |
+
"model_type": LLMType.IMAGE2TEXT.value
|
930 |
+
}, {
|
931 |
+
"fid": factory_infos[17]["name"],
|
932 |
+
"llm_name": "text-embedding-004",
|
933 |
+
"tags": "TEXT EMBEDDING",
|
934 |
+
"max_tokens": 2048,
|
935 |
+
"model_type": LLMType.EMBEDDING.value
|
936 |
+
}
|
937 |
]
|
938 |
for info in factory_infos:
|
939 |
try:
|
rag/llm/chat_model.py
CHANGED
@@ -621,3 +621,64 @@ class BedrockChat(Base):
|
|
621 |
yield ans + f"ERROR: Can't invoke '{self.model_name}'. Reason: {e}"
|
622 |
|
623 |
yield num_tokens_from_string(ans)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
621 |
yield ans + f"ERROR: Can't invoke '{self.model_name}'. Reason: {e}"
|
622 |
|
623 |
yield num_tokens_from_string(ans)
|
624 |
+
|
625 |
+
class GeminiChat(Base):
|
626 |
+
|
627 |
+
def __init__(self, key, model_name,base_url=None):
|
628 |
+
from google.generativeai import client,GenerativeModel
|
629 |
+
|
630 |
+
client.configure(api_key=key)
|
631 |
+
_client = client.get_default_generative_client()
|
632 |
+
self.model_name = 'models/' + model_name
|
633 |
+
self.model = GenerativeModel(model_name=self.model_name)
|
634 |
+
self.model._client = _client
|
635 |
+
|
636 |
+
def chat(self,system,history,gen_conf):
|
637 |
+
if system:
|
638 |
+
history.insert(0, {"role": "user", "parts": system})
|
639 |
+
if 'max_tokens' in gen_conf:
|
640 |
+
gen_conf['max_output_tokens'] = gen_conf['max_tokens']
|
641 |
+
for k in list(gen_conf.keys()):
|
642 |
+
if k not in ["temperature", "top_p", "max_output_tokens"]:
|
643 |
+
del gen_conf[k]
|
644 |
+
for item in history:
|
645 |
+
if 'role' in item and item['role'] == 'assistant':
|
646 |
+
item['role'] = 'model'
|
647 |
+
if 'content' in item :
|
648 |
+
item['parts'] = item.pop('content')
|
649 |
+
|
650 |
+
try:
|
651 |
+
response = self.model.generate_content(
|
652 |
+
history,
|
653 |
+
generation_config=gen_conf)
|
654 |
+
ans = response.text
|
655 |
+
return ans, response.usage_metadata.total_token_count
|
656 |
+
except Exception as e:
|
657 |
+
return "**ERROR**: " + str(e), 0
|
658 |
+
|
659 |
+
def chat_streamly(self, system, history, gen_conf):
|
660 |
+
if system:
|
661 |
+
history.insert(0, {"role": "user", "parts": system})
|
662 |
+
if 'max_tokens' in gen_conf:
|
663 |
+
gen_conf['max_output_tokens'] = gen_conf['max_tokens']
|
664 |
+
for k in list(gen_conf.keys()):
|
665 |
+
if k not in ["temperature", "top_p", "max_output_tokens"]:
|
666 |
+
del gen_conf[k]
|
667 |
+
for item in history:
|
668 |
+
if 'role' in item and item['role'] == 'assistant':
|
669 |
+
item['role'] = 'model'
|
670 |
+
if 'content' in item :
|
671 |
+
item['parts'] = item.pop('content')
|
672 |
+
ans = ""
|
673 |
+
try:
|
674 |
+
response = self.model.generate_content(
|
675 |
+
history,
|
676 |
+
generation_config=gen_conf,stream=True)
|
677 |
+
for resp in response:
|
678 |
+
ans += resp.text
|
679 |
+
yield ans
|
680 |
+
|
681 |
+
except Exception as e:
|
682 |
+
yield ans + "\n**ERROR**: " + str(e)
|
683 |
+
|
684 |
+
yield response._chunks[-1].usage_metadata.total_token_count
|
rag/llm/cv_model.py
CHANGED
@@ -203,6 +203,29 @@ class XinferenceCV(Base):
|
|
203 |
)
|
204 |
return res.choices[0].message.content.strip(), res.usage.total_tokens
|
205 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
206 |
|
207 |
class LocalCV(Base):
|
208 |
def __init__(self, key, model_name="glm-4v", lang="Chinese", **kwargs):
|
|
|
203 |
)
|
204 |
return res.choices[0].message.content.strip(), res.usage.total_tokens
|
205 |
|
206 |
+
class GeminiCV(Base):
|
207 |
+
def __init__(self, key, model_name="gemini-1.0-pro-vision-latest", lang="Chinese", **kwargs):
|
208 |
+
from google.generativeai import client,GenerativeModel
|
209 |
+
client.configure(api_key=key)
|
210 |
+
_client = client.get_default_generative_client()
|
211 |
+
self.model_name = model_name
|
212 |
+
self.model = GenerativeModel(model_name=self.model_name)
|
213 |
+
self.model._client = _client
|
214 |
+
self.lang = lang
|
215 |
+
|
216 |
+
def describe(self, image, max_tokens=2048):
|
217 |
+
from PIL.Image import open
|
218 |
+
gen_config = {'max_output_tokens':max_tokens}
|
219 |
+
prompt = "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。" if self.lang.lower() == "chinese" else \
|
220 |
+
"Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out."
|
221 |
+
b64 = self.image2base64(image)
|
222 |
+
img = open(BytesIO(base64.b64decode(b64)))
|
223 |
+
input = [prompt,img]
|
224 |
+
res = self.model.generate_content(
|
225 |
+
input,
|
226 |
+
generation_config=gen_config,
|
227 |
+
)
|
228 |
+
return res.text,res.usage_metadata.total_token_count
|
229 |
|
230 |
class LocalCV(Base):
|
231 |
def __init__(self, key, model_name="glm-4v", lang="Chinese", **kwargs):
|
rag/llm/embedding_model.py
CHANGED
@@ -31,7 +31,7 @@ import numpy as np
|
|
31 |
import asyncio
|
32 |
from api.utils.file_utils import get_home_cache_dir
|
33 |
from rag.utils import num_tokens_from_string, truncate
|
34 |
-
|
35 |
|
36 |
class Base(ABC):
|
37 |
def __init__(self, key, model_name):
|
@@ -419,3 +419,27 @@ class BedrockEmbed(Base):
|
|
419 |
|
420 |
return np.array(embeddings), token_count
|
421 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
31 |
import asyncio
|
32 |
from api.utils.file_utils import get_home_cache_dir
|
33 |
from rag.utils import num_tokens_from_string, truncate
|
34 |
+
import google.generativeai as genai
|
35 |
|
36 |
class Base(ABC):
|
37 |
def __init__(self, key, model_name):
|
|
|
419 |
|
420 |
return np.array(embeddings), token_count
|
421 |
|
422 |
+
class GeminiEmbed(Base):
|
423 |
+
def __init__(self, key, model_name='models/text-embedding-004',
|
424 |
+
**kwargs):
|
425 |
+
genai.configure(api_key=key)
|
426 |
+
self.model_name = 'models/' + model_name
|
427 |
+
|
428 |
+
def encode(self, texts: list, batch_size=32):
|
429 |
+
texts = [truncate(t, 2048) for t in texts]
|
430 |
+
token_count = sum(num_tokens_from_string(text) for text in texts)
|
431 |
+
result = genai.embed_content(
|
432 |
+
model=self.model_name,
|
433 |
+
content=texts,
|
434 |
+
task_type="retrieval_document",
|
435 |
+
title="Embedding of list of strings")
|
436 |
+
return np.array(result['embedding']),token_count
|
437 |
+
|
438 |
+
def encode_queries(self, text):
|
439 |
+
result = genai.embed_content(
|
440 |
+
model=self.model_name,
|
441 |
+
content=truncate(text,2048),
|
442 |
+
task_type="retrieval_document",
|
443 |
+
title="Embedding of single string")
|
444 |
+
token_count = num_tokens_from_string(text)
|
445 |
+
return np.array(result['embedding']),token_count
|
requirements.txt
CHANGED
@@ -147,3 +147,4 @@ markdown==3.6
|
|
147 |
mistralai==0.4.2
|
148 |
boto3==1.34.140
|
149 |
duckduckgo_search==6.1.9
|
|
|
|
147 |
mistralai==0.4.2
|
148 |
boto3==1.34.140
|
149 |
duckduckgo_search==6.1.9
|
150 |
+
google-generativeai==0.7.2
|
requirements_arm.txt
CHANGED
@@ -148,3 +148,4 @@ markdown==3.6
|
|
148 |
mistralai==0.4.2
|
149 |
boto3==1.34.140
|
150 |
duckduckgo_search==6.1.9
|
|
|
|
148 |
mistralai==0.4.2
|
149 |
boto3==1.34.140
|
150 |
duckduckgo_search==6.1.9
|
151 |
+
google-generativeai==0.7.2
|
requirements_dev.txt
CHANGED
@@ -133,3 +133,4 @@ markdown==3.6
|
|
133 |
mistralai==0.4.2
|
134 |
boto3==1.34.140
|
135 |
duckduckgo_search==6.1.9
|
|
|
|
133 |
mistralai==0.4.2
|
134 |
boto3==1.34.140
|
135 |
duckduckgo_search==6.1.9
|
136 |
+
google-generativeai==0.7.2
|
web/src/assets/svg/llm/gemini.svg
ADDED
|
web/src/pages/user-setting/setting-model/index.tsx
CHANGED
@@ -61,6 +61,7 @@ const IconMap = {
|
|
61 |
Mistral: 'mistral',
|
62 |
'Azure-OpenAI': 'azure',
|
63 |
Bedrock: 'bedrock',
|
|
|
64 |
};
|
65 |
|
66 |
const LlmIcon = ({ name }: { name: string }) => {
|
|
|
61 |
Mistral: 'mistral',
|
62 |
'Azure-OpenAI': 'azure',
|
63 |
Bedrock: 'bedrock',
|
64 |
+
Gemini:'gemini',
|
65 |
};
|
66 |
|
67 |
const LlmIcon = ({ name }: { name: string }) => {
|