#
#  Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#
import base64
import io
import os
from abc import ABC
from io import BytesIO

from openai import OpenAI
from PIL import Image
from zhipuai import ZhipuAI

from api.utils import get_uuid
from api.utils.file_utils import get_project_base_directory


class Base(ABC):
    def __init__(self, key, model_name):
        pass

    def describe(self, image, max_tokens=300):
        raise NotImplementedError("Please implement describe method!")

    def image2base64(self, image):
        # Accept raw bytes, a BytesIO buffer, or a PIL image and return a base64 string.
        if isinstance(image, bytes):
            return base64.b64encode(image).decode("utf-8")
        if isinstance(image, BytesIO):
            return base64.b64encode(image.getvalue()).decode("utf-8")
        buffered = BytesIO()
        try:
            image.save(buffered, format="JPEG")
        except Exception:
            # Images that cannot be encoded as JPEG (e.g. with an alpha channel) fall back to PNG.
            image.save(buffered, format="PNG")
        return base64.b64encode(buffered.getvalue()).decode("utf-8")

    def prompt(self, b64):
        return [
            {
                "role": "user",
                "content": [
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:image/jpeg;base64,{b64}"
                        },
                    },
                    {
                        "type": "text",
                        # The Chinese prompt asks for a detailed description of the image
                        # (time, place, people, events, mood) and extraction of any numeric data.
                        "text": "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。"
                        if self.lang.lower() == "chinese"
                        else "Please describe the content of this picture, like where, when, who, what happened. If it has number data, please extract them out.",
                    },
                ],
            }
        ]


class GptV4(Base):
    def __init__(self, key, model_name="gpt-4-vision-preview", lang="Chinese", base_url="https://api.openai.com/v1"):
        if not base_url:
            base_url = "https://api.openai.com/v1"
        self.client = OpenAI(api_key=key, base_url=base_url)
        self.model_name = model_name
        self.lang = lang

    def describe(self, image, max_tokens=300):
        b64 = self.image2base64(image)
        res = self.client.chat.completions.create(
            model=self.model_name,
            messages=self.prompt(b64),
            max_tokens=max_tokens,
        )
        return res.choices[0].message.content.strip(), res.usage.total_tokens


class QWenCV(Base):
    def __init__(self, key, model_name="qwen-vl-chat-v1", lang="Chinese", **kwargs):
        import dashscope
        dashscope.api_key = key
        self.model_name = model_name
        self.lang = lang

    def prompt(self, binary):
        # DashScope's MultiModalConversation expects a local file URI rather than
        # inline image data, so the bytes are written to a temporary JPEG first.
        tmp_dir = get_project_base_directory("tmp")
        if not os.path.exists(tmp_dir):
            os.mkdir(tmp_dir)
        path = os.path.join(tmp_dir, "%s.jpg" % get_uuid())
        Image.open(io.BytesIO(binary)).save(path)
        return [
            {
                "role": "user",
                "content": [
                    {
                        "image": f"file://{path}"
                    },
                    {
                        "text": "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。"
                        if self.lang.lower() == "chinese"
                        else "Please describe the content of this picture, like where, when, who, what happened. If it has number data, please extract them out.",
                    },
                ],
            }
        ]

    def describe(self, image, max_tokens=300):
        from http import HTTPStatus

        from dashscope import MultiModalConversation

        response = MultiModalConversation.call(model=self.model_name, messages=self.prompt(image))
        if response.status_code == HTTPStatus.OK:
            return response.output.choices[0]["message"]["content"][0]["text"], response.usage.output_tokens
        return response.message, 0


class Zhipu4V(Base):
    def __init__(self, key, model_name="glm-4v", lang="Chinese", **kwargs):
        self.client = ZhipuAI(api_key=key)
        self.model_name = model_name
        self.lang = lang

    def describe(self, image, max_tokens=1024):
        b64 = self.image2base64(image)
        res = self.client.chat.completions.create(
            model=self.model_name,
            messages=self.prompt(b64),
            max_tokens=max_tokens,
        )
        return res.choices[0].message.content.strip(), res.usage.total_tokens


class LocalCV(Base):
    def __init__(self, key, model_name="glm-4v", lang="Chinese", **kwargs):
        pass

    def describe(self, image, max_tokens=1024):
        return "", 0
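

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustration only, not part of the module API).
# The API key and image path below are hypothetical placeholders; any of the
# vision classes above can be swapped in, since they share describe().
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # describe() accepts raw bytes, a BytesIO buffer, or a PIL image.
    with open("sample.jpg", "rb") as f:  # hypothetical local image
        image_bytes = f.read()

    cv_model = GptV4(key="YOUR_OPENAI_API_KEY",  # hypothetical credential
                     model_name="gpt-4-vision-preview",
                     lang="English")
    description, token_count = cv_model.describe(image_bytes, max_tokens=300)
    print(description)
    print("tokens used:", token_count)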