qcloud
1
5bfdfae
import gradio as gr
from vector_db.vector_db_client import VectorDB
from PIL import Image
from transformers import AutoProcessor, CLIPModel
import os
import uuid
from tcvectordb.model.document import SearchParams
import traceback
# Dotted configuration keys that ChatSearch looks up via ``config.get``.
_DOWNLOAD_MODEL_SECTION = "download_model"
# Directory (relative to the project root) that holds downloaded models.
LOCAL_MODEL_PATH = f"{_DOWNLOAD_MODEL_SECTION}.local_model_path"
# Name of the model sub-directory inside LOCAL_MODEL_PATH.
MODEL_NAME = f"{_DOWNLOAD_MODEL_SECTION}.model_name"
# Directory (relative to the project root) where uploaded query images are stored.
LOCAL_GRAPH_PATH = "graph_upload.local_graph_path"
class ChatSearch:
    """Gradio image-to-image search backed by a CLIP encoder and a vector DB.

    An uploaded image is saved under ``graph_cache_directory``, embedded with the
    CLIP model stored in ``model_cache_directory``, and used as the query vector
    for a search against ``vdb.get_collection()``. Each hit's
    ``local_graph_path`` field is opened as an image and shown in a gallery.
    """

    def __init__(self, config, vdb: VectorDB):
        """Resolve the model and upload directories from ``config``.

        Args:
            config: object whose ``get`` is called with the dotted keys
                ``MODEL_NAME``, ``LOCAL_MODEL_PATH`` and ``LOCAL_GRAPH_PATH``.
            vdb: client whose ``get_collection()`` returns the collection searched.
        """
        self.vdb = vdb
        self.model_name = config.get(MODEL_NAME)
        self.local_model_path = config.get(LOCAL_MODEL_PATH)
        self.local_graph_path = config.get(LOCAL_GRAPH_PATH)
        # Both directories are resolved relative to the parent of this file's directory.
        base_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        self.model_cache_directory = os.path.join(base_dir, self.local_model_path, self.model_name)
        self.graph_cache_directory = os.path.join(base_dir, self.local_graph_path)
        # The CLIP model/processor are loaded on the first search and reused afterwards.
        self._model = None
        self._processor = None

    def initial_model(self):
        """Load the CLIP model and processor from ``model_cache_directory``.

        Returns:
            tuple: ``(model, processor)`` freshly loaded from disk.
        """
        model = CLIPModel.from_pretrained(self.model_cache_directory)
        processor = AutoProcessor.from_pretrained(self.model_cache_directory)
        return model, processor

    def _get_model(self):
        """Return the cached ``(model, processor)``, loading them on first use."""
        # getattr keeps this working for instances that skipped __init__.
        model = getattr(self, "_model", None)
        processor = getattr(self, "_processor", None)
        if model is None or processor is None:
            model, processor = self.initial_model()
            self._model, self._processor = model, processor
        return model, processor

    def search_result(self, image):
        """Gradio callback: return images from the collection similar to ``image``.

        Args:
            image: the uploaded ``PIL.Image.Image``, or ``None`` when nothing
                was uploaded.

        Returns:
            list: ``PIL.Image.Image`` objects for the hits that could be loaded
            (at most 10 hits are requested); unreadable hits are skipped.

        Raises:
            gr.Error: if no image was uploaded, the model directory is missing,
                or embedding/searching fails. Gradio displays the message to the
                user, whereas returning a string to a Gallery output would fail.
        """
        if image is None:
            raise gr.Error("请先上传图片...")
        if not os.path.exists(self.model_cache_directory):
            raise gr.Error(f"缓存目录 {self.model_cache_directory} 不存在,无法初始化模型。")
        try:
            model, processor = self._get_model()
            upload_path = self._save_upload(image)
            # squeeze() drops size-1 dimensions so a flat (one-dimensional) list is sent to the DB.
            image_vector = self._process_image(upload_path, model, processor).squeeze().tolist()
            hits = self._search_similar(image_vector)
            return self._load_hit_images(hits)
        except Exception as e:
            print(f"问题:{e}\n")
            print(traceback.format_exc())
            raise gr.Error(f"检索失败:{e}") from e

    def _save_upload(self, image):
        """Save ``image`` as a uniquely named PNG in ``graph_cache_directory``; return its path."""
        os.makedirs(self.graph_cache_directory, exist_ok=True)
        upload_path = os.path.join(self.graph_cache_directory, f"{uuid.uuid4().hex}.png")
        image.save(upload_path)
        return upload_path

    def _search_similar(self, image_vector, limit=10):
        """Search the collection with one query vector and return the raw per-query hit lists."""
        collection = self.vdb.get_collection()
        return collection.search(
            vectors=[image_vector],
            params=SearchParams(ef=200),
            limit=limit,
            output_fields=['local_graph_path'],
        )

    @staticmethod
    def _load_hit_images(hits):
        """Open the ``local_graph_path`` image of every hit, skipping unreadable files."""
        results = []
        for docs in hits:
            for doc in docs:
                hit_path = doc['local_graph_path']
                try:
                    hit_image = Image.open(hit_path)
                    # load() decodes the pixels now (surfacing corrupt files here)
                    # and closes the file handle for single-frame images.
                    hit_image.load()
                except Exception as e:
                    # Best effort: one broken hit must not fail the whole search.
                    print(f"无法加载图片 {hit_path}: {e}")
                    continue
                results.append(hit_image)
        return results

    def _process_image(self, image_path, emb_model, processor):
        """Encode a single image file into a CLIP image embedding.

        Args:
            image_path (str): path of the image file.
            emb_model: CLIP model exposing ``get_image_features``.
            processor: processor matching ``emb_model``; builds the model inputs.

        Returns:
            torch.Tensor: the image features (presumably shaped (1, dim) for one
            image — confirm against the model config).
        """
        # Local import: return_tensors="pt" already requires PyTorch to be installed.
        import torch

        with Image.open(image_path) as image:
            inputs = processor(images=image, return_tensors="pt")
        # Inference only: avoid building an autograd graph.
        with torch.no_grad():
            image_features = emb_model.get_image_features(**inputs)
        return image_features

    def get_chart(self):
        """Build the Gradio ``Interface`` wiring an image upload to ``search_result``."""
        return gr.Interface(
            fn=self.search_result,
            inputs=gr.Image(type="pil", label="上传图片"),
            outputs=gr.Gallery(label="检索结果"),
            theme="soft",
            description="上传图片进行检索",
            allow_flagging="never",
        )