# Spaces: Runtime error
# (This appears to be a page-status banner captured together with the source; it is not part of the program.)
import os
import traceback
import uuid

import gradio as gr
import torch
from PIL import Image
from tcvectordb.model.document import SearchParams
from transformers import AutoProcessor, CLIPModel

from vector_db.vector_db_client import VectorDB
# Keys passed to ``config.get(...)`` by ChatSearch.__init__.
# LOCAL_MODEL_PATH and MODEL_NAME are joined to form the model directory;
# LOCAL_GRAPH_PATH gives the directory where uploaded query images are saved.
LOCAL_MODEL_PATH = "download_model.local_model_path"
MODEL_NAME = "download_model.model_name"
LOCAL_GRAPH_PATH = "graph_upload.local_graph_path"
class ChatSearch:
    """Image-to-image search UI backed by a CLIP model and a vector database.

    An uploaded picture is embedded with ``CLIPModel.get_image_features``,
    the embedding is used to query the collection returned by
    ``vdb.get_collection()``, and the images referenced by the hits are
    returned for display in a Gradio gallery.
    """

    def __init__(self, config, vdb: VectorDB):
        """Resolve the model and image directories from *config*.

        Args:
            config: Object exposing ``get(key)``; queried with the
                ``MODEL_NAME``, ``LOCAL_MODEL_PATH`` and ``LOCAL_GRAPH_PATH``
                keys.
            vdb: Vector database client; only ``get_collection()`` is used.
        """
        self.vdb = vdb
        self.model_name = config.get(MODEL_NAME)
        self.local_model_path = config.get(LOCAL_MODEL_PATH)
        self.local_graph_path = config.get(LOCAL_GRAPH_PATH)

        # Both directories are resolved relative to the parent of the
        # directory that contains this file.
        base_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        self.model_cache_directory = os.path.join(
            base_dir, self.local_model_path, self.model_name
        )
        self.graph_cache_directory = os.path.join(base_dir, self.local_graph_path)

        # Populated on the first search so the weights are read from disk
        # once instead of on every request.
        self._model = None
        self._processor = None

    def initial_model(self):
        """Load the CLIP model and its processor from the model directory.

        Returns:
            A ``(model, processor)`` tuple freshly loaded from
            ``self.model_cache_directory``.
        """
        model = CLIPModel.from_pretrained(self.model_cache_directory)
        processor = AutoProcessor.from_pretrained(self.model_cache_directory)
        return model, processor

    def _load_model_once(self):
        """Return the cached ``(model, processor)``, loading them if needed."""
        # getattr keeps this usable on instances created without __init__.
        if getattr(self, "_model", None) is None or getattr(self, "_processor", None) is None:
            self._model, self._processor = self.initial_model()
        return self._model, self._processor

    def search_result(self, image):
        """Find stored images similar to the uploaded *image*.

        Args:
            image: PIL image supplied by the Gradio input, or ``None`` when
                nothing was uploaded.

        Returns:
            A list of PIL images for the hits whose files could be opened.
            An empty list is returned when the search itself fails; the
            failure is printed together with its traceback.

        Raises:
            gr.Error: If no image was uploaded or the model directory does
                not exist. The output component is a gallery, which cannot
                display a plain message string, so the message is surfaced
                through Gradio's error mechanism instead of being returned.
        """
        if image is None:
            raise gr.Error("请先上传图片...")
        if not os.path.exists(self.model_cache_directory):
            raise gr.Error(f"缓存目录 {self.model_cache_directory} 不存在,无法初始化模型。")

        model, processor = self._load_model_once()
        try:
            # The upload directory may not exist yet on a fresh deployment.
            os.makedirs(self.graph_cache_directory, exist_ok=True)

            # Save the upload under a unique file name.
            unique_filename = f"{uuid.uuid4().hex}.png"
            query_path = os.path.join(self.graph_cache_directory, unique_filename)
            image.save(query_path)

            # Flatten the embedding into a plain one-dimensional list.
            image_vector = self._process_image(query_path, model, processor).squeeze().tolist()

            collection = self.vdb.get_collection()
            res = collection.search(
                vectors=[image_vector],
                params=SearchParams(ef=200),
                limit=10,
                output_fields=['local_graph_path']
            )

            results = []
            for docs in res:
                for doc in docs:
                    hit_path = doc['local_graph_path']
                    try:
                        # copy() forces the pixel data to load and detaches
                        # the result from the file, so the handle can be
                        # closed immediately.
                        with Image.open(hit_path) as hit:
                            results.append(hit.copy())
                    except Exception as e:
                        # Best effort: one unreadable file must not hide
                        # the remaining hits.
                        print(f"无法加载图片 {hit_path}: {e}")
            return results
        except Exception as e:
            print(f"问题:{e}\n")
            error_trace = traceback.format_exc()
            print(error_trace)
            return []

    def _process_image(self, image_path, emb_model, processor):
        """Convert a single image file into its embedding.

        Args:
            image_path (str): Path of the image file.
            emb_model: Model exposing ``get_image_features(**inputs)``.
            processor: Callable that turns a PIL image into model inputs
                when called with ``images=...`` and ``return_tensors="pt"``.

        Returns:
            The value returned by ``emb_model.get_image_features``.
        """
        # The processor reads the pixel data while the file is still open.
        with Image.open(image_path) as image:
            inputs = processor(images=image, return_tensors="pt")
        # Inference only: skip autograd bookkeeping.
        with torch.no_grad():
            image_features = emb_model.get_image_features(**inputs)
        return image_features

    def get_chart(self):
        """Build the Gradio interface that wires an image upload to the search.

        Returns:
            A ``gr.Interface`` with a PIL image input and a gallery output
            that calls ``search_result``.
        """
        return gr.Interface(
            fn=self.search_result,
            inputs=gr.Image(type="pil", label="上传图片"),
            outputs=gr.Gallery(label="检索结果"),
            theme="soft",
            description="上传图片进行检索",
            allow_flagging="never"
        )