import os
import traceback
import uuid

import gradio as gr
import torch
from PIL import Image
from tcvectordb.model.document import SearchParams
from transformers import AutoProcessor, CLIPModel

from vector_db.vector_db_client import VectorDB

LOCAL_MODEL_PATH = "download_model.local_model_path"
MODEL_NAME = "download_model.model_name"
LOCAL_GRAPH_PATH = "graph_upload.local_graph_path"


class ChatSearch:
    """Gradio front-end that finds images similar to an uploaded image.

    The uploaded image is embedded with a locally stored CLIP model. The
    embedding is then used to query a vector-DB collection, and the stored
    images referenced by the hits are returned to a ``gr.Gallery``.
    """

    def __init__(self, config, vdb: VectorDB):
        """Read model and upload locations from ``config``.

        Args:
            config: Object whose ``get`` is called with the dotted keys
                ``MODEL_NAME``, ``LOCAL_MODEL_PATH`` and ``LOCAL_GRAPH_PATH``.
                Presumably a project config wrapper; confirm against the caller.
            vdb: Vector DB client; ``get_collection()`` is used for searching.
        """
        self.vdb = vdb
        self.model_name = config.get(MODEL_NAME)
        self.local_model_path = config.get(LOCAL_MODEL_PATH)
        self.local_graph_path = config.get(LOCAL_GRAPH_PATH)
        # Relative config paths are resolved against the parent of this
        # file's directory.
        project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        self.model_cache_directory = os.path.join(
            project_root, self.local_model_path, self.model_name
        )
        self.graph_cache_directory = os.path.join(project_root, self.local_graph_path)
        # (model, processor) pair, loaded on the first search and then reused
        # so the model is not reloaded from disk for every query.
        self._model_bundle = None

    def initial_model(self):
        """Load the CLIP model and processor from ``model_cache_directory``.

        Returns:
            tuple: ``(model, processor)``.
        """
        model = CLIPModel.from_pretrained(self.model_cache_directory)
        processor = AutoProcessor.from_pretrained(self.model_cache_directory)
        return model, processor

    def _get_model(self):
        """Return the cached ``(model, processor)`` pair, loading it on first use."""
        bundle = getattr(self, "_model_bundle", None)
        if bundle is None:
            bundle = self.initial_model()
            self._model_bundle = bundle
        return bundle

    def search_result(self, image):
        """Return stored images similar to ``image``.

        Args:
            image: The uploaded image as a PIL image (the interface uses
                ``gr.Image(type="pil")``), or ``None`` if nothing was uploaded.

        Returns:
            list: PIL images for the search hits, in the order the collection
            returned them (at most 10 per query vector). Hits whose files
            cannot be opened are skipped.

        Raises:
            gr.Error: If no image was uploaded, the model directory is missing,
                or embedding/searching fails. Gradio shows the message to the
                user instead of feeding a non-list value to the Gallery.
        """
        if image is None:
            raise gr.Error("请先上传图片...")

        if not os.path.exists(self.model_cache_directory):
            raise gr.Error(f"缓存目录 {self.model_cache_directory} 不存在,无法初始化模型。")

        model, processor = self._get_model()
        try:
            query_path = self._save_upload(image)
            # Flatten the (1, dim) feature tensor into a plain list of floats.
            image_vector = self._process_image(query_path, model, processor).squeeze().tolist()

            # Assumes the collection stores CLIP image embeddings with a
            # 'local_graph_path' field pointing at the source image file.
            collection = self.vdb.get_collection()
            res = collection.search(
                vectors=[image_vector],
                params=SearchParams(ef=200),
                limit=10,
                output_fields=['local_graph_path'],
            )
            return self._load_result_images(res)
        except Exception as e:
            print(f"问题:{e}\n")
            print(traceback.format_exc())
            raise gr.Error(f"检索失败: {e}") from e

    def _save_upload(self, image):
        """Save ``image`` as a uniquely named PNG in the graph cache dir and return its path."""
        # The directory may not exist yet on a fresh deployment.
        os.makedirs(self.graph_cache_directory, exist_ok=True)
        upload_path = os.path.join(self.graph_cache_directory, f"{uuid.uuid4().hex}.png")
        image.save(upload_path)
        return upload_path

    @staticmethod
    def _load_result_images(res):
        """Open each hit's ``local_graph_path`` image and return them as a list.

        ``res`` is iterated as one list of hit documents per query vector.
        Unreadable files are reported and skipped (deliberate best-effort).
        """
        results = []
        for docs in res:
            for doc in docs:
                result_path = doc['local_graph_path']
                try:
                    # copy() forces the pixel data to load, so the file handle
                    # can be closed before Gradio consumes the image.
                    with Image.open(result_path) as img:
                        results.append(img.copy())
                except Exception as e:
                    print(f"无法加载图片 {result_path}: {e}")
        return results

    def _process_image(self, image_path, emb_model, processor):
        """Embed a single image file with the CLIP model.

        Args:
            image_path (str): Path to the image file.
            emb_model: CLIP model providing ``get_image_features``.
            processor: Matching processor that turns the image into model inputs.

        Returns:
            torch.Tensor: Image features returned by ``get_image_features``
            (presumably shape ``(1, dim)`` for one image; confirm).
        """
        # Close the file when done, and skip autograd bookkeeping because this
        # is inference only.
        with Image.open(image_path) as image, torch.no_grad():
            inputs = processor(images=image, return_tensors="pt")
            image_features = emb_model.get_image_features(**inputs)
        return image_features

    def get_chart(self):
        """Build the Gradio interface: image upload in, result gallery out."""
        return gr.Interface(
            fn=self.search_result,
            inputs=gr.Image(type="pil", label="上传图片"),
            outputs=gr.Gallery(label="检索结果"),
            theme="soft",
            description="上传图片进行检索",
            allow_flagging="never"
        )