import gradio as gr
from PIL import Image
from transformers import (
    BlipProcessor,
    BlipForConditionalGeneration,
    CLIPProcessor,
    CLIPModel,
    AutoProcessor,
    AutoModelForImageClassification,
)
import torch
import matplotlib.pyplot as plt
import numpy as np
from openai import OpenAI
from huggingface_hub import hf_hub_download
from segment_anything import SamPredictor, sam_model_registry
from yolo_world.models.detectors import build_detector
from mmcv import Config

# Initialize models
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")

# Download the SAM checkpoint (public repo, no authentication required)
sam_checkpoint = hf_hub_download(
    repo_id="facebook/sam-vit-large",
    filename="model.safetensors",
)
# Use the "vit_l" registry entry to match the ViT-Large checkpoint downloaded above.
# NOTE: sam_model_registry expects an original SAM .pth state dict, so the file
# above may need converting before it can be loaded here.
sam = sam_model_registry["vit_l"](checkpoint=sam_checkpoint)
sam_predictor = SamPredictor(sam)

# Download YOLO-World weights from Hugging Face (public repo, no authentication required)
yolo_checkpoint = hf_hub_download(
    repo_id="stevengrove/YOLO-World",
    filename="yolo_world_v2_xl_obj365v1_goldg_cc3mlite_pretrain.pth",
)

# Load the YOLO-World config file (replace with the actual config path)
yolo_config = Config.fromfile('path/to/yolo_world_config.py')

# Build the YOLO-World model
yolo_model = build_detector(yolo_config.model)

# Load the weights on CPU; the model can be moved to GPU later
checkpoint = torch.load(yolo_checkpoint, map_location="cpu")
yolo_model.load_state_dict(checkpoint["state_dict"])
yolo_model.eval()  # evaluation mode

# WD tagger, used both for anime/photo classification and anime tag prediction
wd_processor = AutoProcessor.from_pretrained("SmilingWolf/wd-vit-tagger-v3")
wd_model = AutoModelForImageClassification.from_pretrained("SmilingWolf/wd-vit-tagger-v3")


# Automatically classify the image type (anime vs. photographic)
def classify_image_type(image):
    inputs = wd_processor(images=image, return_tensors="pt")
    outputs = wd_model(**inputs)
    scores = torch.softmax(outputs.logits, dim=1)[0]
    # The label-to-id mapping lives on the model config, not the processor
    anime_score = scores[wd_model.config.label2id["anime"]].item()
    return "anime" if anime_score > 0.5 else "real"


# Segment objects in the image with SAM, one mask per bounding box
def segment_objects(image, boxes):
    image_np = np.array(image)
    sam_predictor.set_image(image_np)
    masks = []
    for box in boxes:
        mask, _, _ = sam_predictor.predict(
            point_coords=None,
            point_labels=None,
            box=np.array(box),
            multimask_output=False
        )
        masks.append(mask)
    return masks


# Detect objects: YOLO-World for photographic images, the WD tagger for anime images
def detect_objects(image, image_type):
    if image_type == "real":
        # NOTE: the exact inference call and result format depend on the YOLO-World
        # installation; this assumes a predict() helper that returns dicts with
        # "class", "bbox" and "confidence" keys, and may need adapting.
        results = yolo_model.predict(np.array(image), conf=0.25)
        objects = [{"label": r["class"], "box": r["bbox"], "confidence": r["confidence"]} for r in results]
    else:
        inputs = wd_processor(images=image, return_tensors="pt")
        outputs = wd_model(**inputs)
        scores = torch.softmax(outputs.logits, dim=1)[0]
        top_k = torch.topk(scores, k=5)
        objects = [
            {"label": wd_model.config.id2label[top_k.indices[i].item()], "confidence": top_k.values[i].item()}
            for i in range(5)
        ]
    return objects


# Generate a caption for each detected object (cropped to its box when available)
def generate_object_descriptions(image, objects):
    descriptions = []
    for obj in objects:
        box = obj.get("box", None)
        cropped = image.crop(box) if box else image
        inputs = blip_processor(cropped, return_tensors="pt")
        caption = blip_model.generate(**inputs, max_length=128, num_beams=5, no_repeat_ngram_size=2)
        description = blip_processor.decode(caption[0], skip_special_tokens=True)
        descriptions.append({"label": obj["label"], "description": description})
    return descriptions
# Visualize feature differences
def plot_feature_differences(latent_diff, descriptions, prefix):
    diff_magnitude = [abs(x) for x in latent_diff[0]]
    indices = range(len(diff_magnitude))
    top_indices = np.argsort(diff_magnitude)[-10:][::-1]

    plt.figure(figsize=(8, 4))
    plt.bar(indices, diff_magnitude, alpha=0.7)
    plt.xlabel("Feature Index")
    plt.ylabel("Magnitude of Difference")
    plt.title("Feature Differences (Bar Chart)")
    bar_chart_path = f"{prefix}_bar_chart.png"
    plt.savefig(bar_chart_path)
    plt.close()

    plt.figure(figsize=(6, 6))
    # CLIP feature indices do not map one-to-one onto the detected-object
    # descriptions, so label each slice by its feature index instead.
    plt.pie(
        [diff_magnitude[i] for i in top_indices],
        labels=[f"Feature {i}" for i in top_indices],
        autopct="%1.1f%%",
        startangle=140
    )
    plt.title("Top 10 Feature Differences (Pie Chart)")
    pie_chart_path = f"{prefix}_pie_chart.png"
    plt.savefig(pie_chart_path)
    plt.close()

    return bar_chart_path, pie_chart_path


# Generate a detailed comparative analysis via an LLM API
def generate_text_analysis(api_key, api_type, caption_a, caption_b):
    if api_type == "DeepSeek":
        client = OpenAI(api_key=api_key, base_url="https://api.deepseek.com")
    else:
        client = OpenAI(api_key=api_key)

    response = client.chat.completions.create(
        model="gpt-4" if api_type == "GPT" else "deepseek-chat",
        messages=[
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": f"Image A is described as: {caption_a}.\nImage B is described as: {caption_b}.\nPlease give a detailed comparative analysis of the two images."}
        ]
    )
    return response.choices[0].message.content.strip()


# Analyze a single pair of images
def analyze_images(img_a, img_b, api_key, api_type, prefix):
    type_a = classify_image_type(img_a)
    type_b = classify_image_type(img_b)

    objects_a = detect_objects(img_a, type_a)
    objects_b = detect_objects(img_b, type_b)

    descriptions_a = generate_object_descriptions(img_a, objects_a)
    descriptions_b = generate_object_descriptions(img_b, objects_b)

    inputs = clip_processor(images=img_a, return_tensors="pt")
    features_a = clip_model.get_image_features(**inputs).detach().numpy()
    inputs = clip_processor(images=img_b, return_tensors="pt")
    features_b = clip_model.get_image_features(**inputs).detach().numpy()
    latent_diff = np.abs(features_a - features_b).tolist()

    bar_chart, pie_chart = plot_feature_differences(latent_diff, [d['label'] for d in descriptions_a], prefix)
    text_analysis = generate_text_analysis(api_key, api_type, descriptions_a, descriptions_b)

    return {
        "bar_chart": bar_chart,
        "pie_chart": pie_chart,
        "text_analysis": text_analysis
    }


# Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Comprehensive Image Comparison Tool")
    api_key_input = gr.Textbox(label="API Key", placeholder="Enter your API key", type="password")
    api_type_input = gr.Radio(label="API Type", choices=["GPT", "DeepSeek"], value="GPT")
    images_a_input = gr.File(label="Upload images from folder A", file_types=[".png", ".jpg"], file_count="multiple")
    images_b_input = gr.File(label="Upload images from folder B", file_types=[".png", ".jpg"], file_count="multiple")
    analyze_button = gr.Button("Start Analysis")
    result_gallery = gr.Gallery(label="Difference Visualizations")
    result_text = gr.Textbox(label="Analysis Results", lines=5)

    def process_batch(images_a, images_b, api_key, api_type):
        images_a = [Image.open(img).convert("RGB") for img in images_a]
        images_b = [Image.open(img).convert("RGB") for img in images_b]
        results = [
            analyze_images(img_a, img_b, api_key, api_type, f"comparison_{i+1}")
            for i, (img_a, img_b) in enumerate(zip(images_a, images_b))
        ]
        # Flatten the per-pair results into what the outputs expect:
        # the Gallery takes a list of chart image paths, the Textbox a single string.
        gallery_items = [path for r in results for path in (r["bar_chart"], r["pie_chart"])]
        text_output = "\n\n".join(
            f"Pair {i+1}:\n{r['text_analysis']}" for i, r in enumerate(results)
        )
        return gallery_items, text_output

    analyze_button.click(
        process_batch,
        inputs=[images_a_input, images_b_input, api_key_input, api_type_input],
        outputs=[result_gallery, result_text]
    )

demo.launch()
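
# Example usage (a minimal sketch, kept as comments so it never runs with the app):
# it assumes two local files "a.jpg" and "b.jpg" and a valid API key, which are
# illustrative placeholders, not files shipped with this script.
#
# img_a = Image.open("a.jpg").convert("RGB")
# img_b = Image.open("b.jpg").convert("RGB")
# result = analyze_images(img_a, img_b, api_key="sk-...", api_type="GPT", prefix="demo")
# print(result["text_analysis"])                    # LLM comparison text
# print(result["bar_chart"], result["pie_chart"])   # paths of the saved charts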