"""Gradio demo that asks a LAVIS BLIP-2 model whether an image is real or fake."""

import os

import gradio as gr
import torch
from PIL import Image
from lavis.models import load_model_and_preprocess, model_zoo

# Fix CUDA plugin registration errors
# NOTE(review): this runs after ``import torch`` and unconditionally overwrites
# any CUDA_VISIBLE_DEVICES value already present in the environment -- confirm
# that is intended.
os.environ["CUDA_VISIBLE_DEVICES"] = "0" if torch.cuda.is_available() else "-1"

# Prompt used as the default question in the UI, the examples and predict_image.
DEFAULT_QUESTION = "Is this photo real [*]?"


class InstructBLIP:
    """Thin holder for a loaded model, its processors and the target device."""

    def __init__(self):
        # Everything stays unset until load_models() is called.
        self.model = None
        self.vis_processors = None
        self.txt_processors = None
        self.device = "cpu"

    def load_models(self, model, vis_processors, txt_processors, device):
        """Store an already-loaded model, its processors and its device."""
        self.model = model
        self.vis_processors = vis_processors
        self.txt_processors = txt_processors
        self.device = device

    def query(self, image, question):
        """Classify ``image`` as "Real" or "Fake" for the given prompt.

        Args:
            image: Image accepted by ``self.vis_processors["eval"]``.
            question: Prompt passed to the model as ``samples["prompt"]``.

        Returns:
            Tuple ``(label, confidence)`` where ``label`` is ``"Real"`` or
            ``"Fake"`` and ``confidence`` is a percentage rounded to two
            decimals.

        Raises:
            RuntimeError: If ``load_models`` has not been called yet.
        """
        if self.model is None or self.vis_processors is None:
            raise RuntimeError("InstructBLIP.query() called before load_models()")

        # Preprocess, add a batch dimension and move to the model's device.
        image = self.vis_processors["eval"](image).unsqueeze(0).to(self.device)
        samples = {"image": image, "prompt": question}
        candidates = ["yes", "no"]

        # Pure inference: no autograd graph is needed.
        with torch.no_grad():
            ans = self.model.predict_class(samples=samples, candidates=candidates)

        # NOTE(review): assumes ans[0] holds one score per candidate, in the
        # same order as ``candidates`` -- confirm against predict_class.
        logits = ans[0]

        # as_tensor accepts both plain numbers and tensors without the
        # copy-construct warning torch.tensor() emits for tensor inputs.
        # NOTE(review): the two sigmoids are independent, so the values need
        # not sum to 100; kept as-is because the score semantics are not
        # visible here.
        yes_prob = torch.sigmoid(torch.as_tensor(logits[0])).item() * 100
        no_prob = torch.sigmoid(torch.as_tensor(logits[1])).item() * 100

        result = "Real" if yes_prob > no_prob else "Fake"
        confidence = max(yes_prob, no_prob)
        return result, round(confidence, 2)


def load_model(model_name="blip2_t5", model_type="pretrain_flant5xl"):
    """Load the model and wrap it in an ``InstructBLIP`` instance.

    Args:
        model_name: Value passed as ``name`` to ``load_model_and_preprocess``.
        model_type: Value passed as ``model_type`` to
            ``load_model_and_preprocess``.

    Returns:
        A ready ``InstructBLIP`` instance, or ``None`` if loading failed (the
        error is printed rather than raised so the UI can still start).
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    try:
        model, vis_processors, txt_processors = load_model_and_preprocess(
            name=model_name,
            model_type=model_type,
            is_eval=True,
            device=device,
        )

        if model is None:
            raise ValueError(
                f"Failed to load model '{model_name}' with type '{model_type}'"
            )

        instruct = InstructBLIP()
        instruct.load_models(model, vis_processors, txt_processors, device)
        return instruct
    except Exception as e:
        print(f"Error loading model: {e}")
        return None


# Load the model once when the script starts
model_instance = load_model()


def predict_image(input_image, question=DEFAULT_QUESTION):
    """Run the detector on one image and return ``(label, confidence)``.

    Args:
        input_image: A PIL image, or an array accepted by ``Image.fromarray``.
            ``None`` yields a "No image provided" result.
        question: Prompt forwarded to the model.

    Returns:
        Tuple ``(text, confidence)``. On failure ``text`` starts with
        ``"Error: "`` and ``confidence`` is ``0``.
    """
    if input_image is None:
        return "No image provided", 0

    # load_model() returns None on failure; report that explicitly instead of
    # surfacing an AttributeError message to the user.
    if model_instance is None:
        return "Error: model is not loaded (see startup log for details)", 0

    try:
        # Ensure input is a PIL Image
        if not isinstance(input_image, Image.Image):
            input_image = Image.fromarray(input_image)

        # Uploads may be RGBA, palette or grayscale; normalise to three
        # channels before handing the image to the visual processor.
        input_image = input_image.convert("RGB")

        # Run model inference
        result, confidence = model_instance.query(input_image, question)
        return result, confidence
    except Exception as e:
        # UI boundary: show the failure in the result box instead of crashing.
        return f"Error: {str(e)}", 0


# Create Gradio interface
def create_interface():
    """Build and return the Gradio Blocks app (does not launch it)."""
    with gr.Blocks(title="Fake Image Detector") as app:
        gr.Markdown("""
        # Real vs Fake Image Detector
        Upload an image to check if it's real or AI-generated.
        The model will classify the image and provide a confidence score.

        Based on AntifakePrompt: https://github.com/nctu-eva-lab/AntifakePrompt
        """)

        with gr.Row():
            with gr.Column():
                input_image = gr.Image(type="pil", label="Upload Image")
                question = gr.Textbox(label="Question Prompt", value=DEFAULT_QUESTION)
                submit_btn = gr.Button("Analyze Image", variant="primary")

            with gr.Column():
                result_label = gr.Textbox(label="Classification Result")
                confidence = gr.Number(label="Confidence Score (%)")

        submit_btn.click(
            fn=predict_image,
            inputs=[input_image, question],
            outputs=[result_label, confidence],
        )

        # Only offer examples whose image file is actually present; a missing
        # file would otherwise break example caching while the UI is built.
        example_rows = [
            ["example_real.jpg", DEFAULT_QUESTION],
            ["example_fake.jpg", DEFAULT_QUESTION],
        ]
        available_examples = [row for row in example_rows if os.path.isfile(row[0])]

        if available_examples:
            gr.Examples(
                examples=available_examples,
                inputs=[input_image, question],
                outputs=[result_label, confidence],
                fn=predict_image,
                # Caching runs predict_image up front; skip it when the model
                # failed to load so error strings are not cached as results.
                cache_examples=model_instance is not None,
            )

    return app


if __name__ == "__main__":
    # Create and launch the Gradio interface
    demo = create_interface()
    demo.launch(share=True)  # Set share=True to get a public link