Spaces:
Runtime error
Runtime error
import os | |
import torch | |
import gradio as gr | |
from PIL import Image | |
from lavis.models import load_model_and_preprocess, model_zoo | |
# Limit which GPU(s) CUDA-aware code may see: GPU 0 when PyTorch reports CUDA,
# none ("-1") otherwise. setdefault() keeps a value the operator already
# exported (e.g. to pin a different GPU) instead of silently overwriting it.
# NOTE(review): torch and lavis are imported above and torch.cuda.is_available()
# has already been called, so this likely does not change which devices torch
# itself uses in this process; it mainly affects libraries or subprocesses that
# initialise CUDA later. Confirm this is still needed.
os.environ.setdefault(
    "CUDA_VISIBLE_DEVICES", "0" if torch.cuda.is_available() else "-1"
)
class InstructBLIP:
    """Ask a LAVIS vision-language model a yes/no question about an image.

    Instances start empty and are populated via :meth:`load_models`. The
    "yes" answer is reported as "Real" and "no" as "Fake".
    """

    def __init__(self):
        self.model = None  # object exposing predict_class(samples=..., candidates=...)
        self.vis_processors = None  # dict of image processors; query() uses the "eval" entry
        self.txt_processors = None  # stored for completeness; query() does not use it
        self.device = "cpu"  # device the processed image tensor is moved to

    def load_models(self, model, vis_processors, txt_processors, device):
        """Attach an already-loaded model, its processors and the target device."""
        self.model = model
        self.vis_processors = vis_processors
        self.txt_processors = txt_processors
        self.device = device

    def query(self, image, question):
        """Classify *image* as "Real" or "Fake" by asking the model *question*.

        Args:
            image: Image accepted by ``vis_processors["eval"]`` (a PIL image
                in this app).
            question: Prompt text, passed to the model as ``samples["prompt"]``.

        Returns:
            ``(label, confidence)``. ``label`` is "Real" when the "yes"
            candidate scores strictly higher than "no", otherwise "Fake".
            ``confidence`` is the winning class's probability in percent,
            rounded to 2 decimals.

        Raises:
            RuntimeError: if :meth:`load_models` has not been called.
        """
        if self.model is None or self.vis_processors is None:
            raise RuntimeError("Model not loaded; call load_models() first")

        pixel_values = self.vis_processors["eval"](image).unsqueeze(0).to(self.device)
        samples = {"image": pixel_values, "prompt": question}
        candidates = ["yes", "no"]

        # Pure inference: no autograd graph is needed.
        with torch.no_grad():
            ans = self.model.predict_class(samples=samples, candidates=candidates)

        # ans[0] is treated as one score per candidate, aligned with
        # `candidates` (higher = more likely). A softmax across the two scores
        # gives a proper distribution summing to 100%. Independent sigmoids do
        # not, so the old "confidence" was not the probability of the answer.
        # The Real/Fake decision is unchanged, since both maps are monotonic.
        # NOTE(review): some LAVIS BLIP-2 variants return candidate *ranks*
        # (argsort of losses) from predict_class rather than scores; if the
        # loaded model does, this mapping is wrong. Confirm.
        scores = torch.as_tensor(ans[0], dtype=torch.float32).reshape(-1)[: len(candidates)]
        yes_prob, no_prob = (torch.softmax(scores, dim=0) * 100).tolist()

        result = "Real" if yes_prob > no_prob else "Fake"
        confidence = max(yes_prob, no_prob)
        return result, round(confidence, 2)
def load_model(model_name="blip2_t5", model_type="pretrain_flant5xl"):
    """Load a LAVIS model plus processors and wrap it in an InstructBLIP.

    Uses CUDA when PyTorch reports it, otherwise the CPU, and prints which one.

    Returns:
        A ready-to-query ``InstructBLIP``, or ``None`` if anything raised
        while loading (the error message is printed, not re-raised).

    NOTE(review): InstructBLIP.query() calls ``predict_class`` on the model;
    confirm the class selected by (model_name, model_type) implements it.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")
    try:
        loaded = load_model_and_preprocess(
            name=model_name,
            model_type=model_type,
            is_eval=True,
            device=device,
        )
        model, vis_processors, txt_processors = loaded
        if model is None:
            raise ValueError(
                f"Failed to load model '{model_name}' with type '{model_type}'"
            )
        wrapper = InstructBLIP()
        wrapper.load_models(model, vis_processors, txt_processors, device)
    except Exception as err:
        print(f"Error loading model: {err}")
        return None
    return wrapper
# Build the detector a single time, at import, so every request reuses the same
# loaded weights. load_model() prints the failure and yields None if loading
# raised, so code reading this name must be prepared for None.
model_instance = load_model()
def predict_image(input_image, question="Is this photo real [*]?"):
    """Classify an uploaded image as real or fake.

    Args:
        input_image: A PIL image (what ``gr.Image(type="pil")`` delivers), an
            array accepted by ``Image.fromarray``, or ``None`` when nothing
            was uploaded.
        question: Prompt forwarded to the model.

    Returns:
        ``(label, confidence)``: the model's "Real"/"Fake" label and its
        confidence in percent. If there is no image, the model is unavailable,
        or inference raises, returns an explanatory message and ``0`` instead.
    """
    if input_image is None:
        return "No image provided", 0
    # load_model() returns None when loading failed; say so plainly instead of
    # surfacing "'NoneType' object has no attribute 'query'".
    if model_instance is None:
        return "Error: model failed to load (see server logs)", 0
    try:
        # Ensure input is a PIL Image
        if not isinstance(input_image, Image.Image):
            input_image = Image.fromarray(input_image)
        # LAVIS eval processors normalise 3 colour channels (their examples
        # call .convert("RGB")). RGBA, greyscale and palette uploads would
        # otherwise fail, so force RGB.
        if input_image.mode != "RGB":
            input_image = input_image.convert("RGB")
        result, confidence = model_instance.query(input_image, question)
        return result, confidence
    except Exception as e:
        return f"Error: {e}", 0
# Create Gradio interface
def create_interface():
    """Build the Gradio Blocks UI for the real-vs-fake detector.

    Returns:
        The ``gr.Blocks`` app; the caller is responsible for ``launch()``.
    """
    default_prompt = "Is this photo real [*]?"
    with gr.Blocks(title="Fake Image Detector") as app:
        gr.Markdown("""
# Real vs Fake Image Detector
Upload an image to check if it's real or AI-generated. The model will classify the image and provide a confidence score.
Based on AntifakePrompt: https://github.com/nctu-eva-lab/AntifakePrompt
""")
        with gr.Row():
            with gr.Column():
                input_image = gr.Image(type="pil", label="Upload Image")
                question = gr.Textbox(label="Question Prompt", value=default_prompt)
                submit_btn = gr.Button("Analyze Image", variant="primary")
            with gr.Column():
                result_label = gr.Textbox(label="Classification Result")
                confidence = gr.Number(label="Confidence Score (%)")
        submit_btn.click(
            fn=predict_image,
            inputs=[input_image, question],
            outputs=[result_label, confidence],
        )
        # Offer only example images that actually ship with the app. Gradio
        # loads example files (and with cache_examples=True may run
        # predict_image on them), so a missing file can abort startup.
        example_rows = [
            [path, default_prompt]
            for path in ("example_real.jpg", "example_fake.jpg")
            if os.path.exists(path)
        ]
        if example_rows:
            gr.Examples(
                examples=example_rows,
                inputs=[input_image, question],
                outputs=[result_label, confidence],
                fn=predict_image,
                cache_examples=True,
            )
    return app
if __name__ == "__main__":
    # Script entry point: build the Blocks UI, bind it to the module-level name
    # `demo`, and serve it. share=True also requests a public Gradio link.
    demo = create_interface()
    demo.launch(share=True)