ragflow / rag /llm /cv_model.py
KevinHuSh
add base url for OpenAI (#166)
e06e08c
raw
history blame
5.36 kB
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from zhipuai import ZhipuAI
import io
from abc import ABC
from PIL import Image
from openai import OpenAI
import os
import base64
from io import BytesIO
from api.utils import get_uuid
from api.utils.file_utils import get_project_base_directory
class Base(ABC):
    """Common interface and helpers for vision models that describe an image.

    Subclasses override :meth:`describe` and must set ``self.lang`` before
    :meth:`prompt` is called, because the prompt language is chosen from it.
    """

    def __init__(self, key, model_name):
        """Accept the API key and model name; the base class stores neither."""
        pass

    def describe(self, image, max_tokens=300):
        """Describe ``image``; subclasses must override this.

        Raises:
            NotImplementedError: always, in the base class.
        """
        raise NotImplementedError("Please implement describe method!")

    def image2base64(self, image):
        """Return ``image`` encoded as a base64 text string.

        Args:
            image: raw ``bytes``, a ``BytesIO``, or any object exposing
                ``save(fp, format=...)`` (presumably a PIL image; confirm
                against callers).

        Returns:
            The base64 encoding, decoded as UTF-8 text.
        """
        if isinstance(image, bytes):
            return base64.b64encode(image).decode("utf-8")
        if isinstance(image, BytesIO):
            return base64.b64encode(image.getvalue()).decode("utf-8")

        buffered = BytesIO()
        try:
            image.save(buffered, format="JPEG")
        except Exception:
            # The JPEG attempt may already have written bytes before failing;
            # start from a clean buffer so the PNG data is not prefixed by them.
            buffered = BytesIO()
            image.save(buffered, format="PNG")
        return base64.b64encode(buffered.getvalue()).decode("utf-8")

    def prompt(self, b64):
        """Build the single-turn chat message asking for an image description.

        Args:
            b64: base64-encoded image data, embedded in a ``data:`` URL.

        Returns:
            A one-element list holding a ``user`` message whose content is an
            image part followed by a text instruction. The instruction is in
            Chinese when ``self.lang`` equals "chinese" (case-insensitive),
            otherwise in English.
        """
        if self.lang.lower() == "chinese":
            instruction = "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。"
        else:
            instruction = "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out."
        return [
            {
                "role": "user",
                "content": [
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:image/jpeg;base64,{b64}"
                        },
                    },
                    {
                        # Every content part carries a "type" discriminator,
                        # matching the image part above.
                        "type": "text",
                        "text": instruction,
                    },
                ],
            }
        ]
class GptV4(Base):
    """Image describer that talks to an OpenAI-compatible chat completions API."""

    def __init__(self, key, model_name="gpt-4-vision-preview", lang="Chinese", base_url="https://api.openai.com/v1"):
        """Create the client; a falsy ``base_url`` falls back to the default endpoint."""
        endpoint = base_url or "https://api.openai.com/v1"
        self.client = OpenAI(api_key=key, base_url=endpoint)
        self.model_name = model_name
        self.lang = lang

    def describe(self, image, max_tokens=300):
        """Ask the model to describe ``image``.

        Returns:
            A ``(description, total_tokens)`` tuple taken from the first
            choice and the usage block of the completion.
        """
        encoded = self.image2base64(image)
        completion = self.client.chat.completions.create(
            model=self.model_name,
            messages=self.prompt(encoded),
            max_tokens=max_tokens,
        )
        description = completion.choices[0].message.content.strip()
        return description, completion.usage.total_tokens
class QWenCV(Base):
    """Image describer backed by DashScope's ``MultiModalConversation`` API."""

    def __init__(self, key, model_name="qwen-vl-chat-v1", lang="Chinese", **kwargs):
        """Set the module-level DashScope API key and remember model and language."""
        import dashscope
        dashscope.api_key = key
        self.model_name = model_name
        self.lang = lang

    def prompt(self, binary):
        """Write the image to a temp ``.jpg`` and build the chat message for it.

        Args:
            binary: encoded image bytes readable by ``PIL.Image.open``.

        Returns:
            A one-element list holding a ``user`` message whose content is a
            ``file://`` reference to the written image followed by the text
            instruction (Chinese when ``self.lang`` equals "chinese"
            case-insensitively, otherwise English).
        """
        # The image is handed over as a local file reference, so it has to be
        # materialised on disk first.
        tmp_dir = get_project_base_directory("tmp")
        # exist_ok avoids the check-then-create race between concurrent callers.
        os.makedirs(tmp_dir, exist_ok=True)
        path = os.path.join(tmp_dir, "%s.jpg" % get_uuid())
        img = Image.open(io.BytesIO(binary))
        # JPEG cannot store alpha or palette modes (e.g. RGBA, LA, P); convert
        # those to RGB so the save below does not raise.
        if img.mode not in ("L", "RGB", "CMYK"):
            img = img.convert("RGB")
        img.save(path)
        return [
            {
                "role": "user",
                "content": [
                    {
                        "image": f"file://{path}"
                    },
                    {
                        "text": "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。" if self.lang.lower() == "chinese" else
                        "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out.",
                    },
                ],
            }
        ]

    def describe(self, image, max_tokens=300):
        """Ask the model to describe ``image``.

        Args:
            image: encoded image bytes, passed straight to :meth:`prompt`.
            max_tokens: accepted for interface parity with the other
                describers; it is not forwarded to the API call.

        Returns:
            ``(description, output_tokens)`` on an OK response, otherwise
            ``(error_message, 0)``.
        """
        from http import HTTPStatus
        from dashscope import MultiModalConversation
        messages = self.prompt(image)
        # Capture the temp path before the call, in case the SDK rewrites the
        # message contents while processing them.
        tmp_path = messages[0]["content"][0]["image"][len("file://"):]
        try:
            response = MultiModalConversation.call(model=self.model_name,
                                                   messages=messages)
        finally:
            # Best-effort cleanup so temp images do not pile up per request.
            try:
                os.remove(tmp_path)
            except OSError:
                pass
        if response.status_code == HTTPStatus.OK:
            return response.output.choices[0]['message']['content'][0]["text"], response.usage.output_tokens
        return response.message, 0
class Zhipu4V(Base):
    """Image describer that talks to the ZhipuAI chat completions API."""

    def __init__(self, key, model_name="glm-4v", lang="Chinese", **kwargs):
        """Create the ZhipuAI client and remember the model name and language."""
        self.client = ZhipuAI(api_key=key)
        self.model_name, self.lang = model_name, lang

    def describe(self, image, max_tokens=1024):
        """Ask the model to describe ``image``.

        Returns:
            A ``(description, total_tokens)`` tuple taken from the first
            choice and the usage block of the completion.
        """
        encoded = self.image2base64(image)
        completion = self.client.chat.completions.create(
            model=self.model_name,
            messages=self.prompt(encoded),
            max_tokens=max_tokens,
        )
        description = completion.choices[0].message.content.strip()
        return description, completion.usage.total_tokens
class LocalCV(Base):
    """Placeholder describer that performs no model call."""

    def __init__(self, key, model_name="glm-4v", lang="Chinese", **kwargs):
        """Accept the common constructor arguments and store none of them."""

    def describe(self, image, max_tokens=1024):
        """Return an empty description together with a zero token count."""
        description, token_count = "", 0
        return description, token_count