# truens66's picture
# Create app.py
# ae2818c verified
import os
import torch
import gradio as gr
from PIL import Image
from lavis.models import load_model_and_preprocess, model_zoo
# Pin CUDA visibility: expose only GPU 0 when CUDA is usable, otherwise hide every GPU.
# Author's stated intent: "Fix CUDA plugin registration errors".
# NOTE(review): torch is already imported and torch.cuda.is_available() has just run,
# so this probably cannot change which devices torch itself sees; it mostly affects
# libraries/subprocesses that initialise CUDA later. It also overwrites any value the
# launcher set — confirm that is intended.
if torch.cuda.is_available():
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"
else:
    os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
class InstructBLIP:
    """Wrapper around a loaded LAVIS model that asks a yes/no question about an image.

    A "yes" answer is reported as "Real" and a "no" answer as "Fake".
    """

    def __init__(self):
        # Populated later by load_models(); query() refuses to run until then.
        self.model = None
        self.vis_processors = None
        self.txt_processors = None  # kept for completeness; query() does not use it
        self.device = "cpu"

    def load_models(self, model, vis_processors, txt_processors, device):
        """Attach an already-loaded model and its processors.

        Args:
            model: object exposing ``predict_class(samples=..., candidates=...)``.
            vis_processors: mapping whose ``"eval"`` entry turns an image into a tensor.
            txt_processors: text processors loaded alongside the model (unused by query).
            device: torch device string the image tensor is moved to, e.g. ``"cpu"``.
        """
        self.model = model
        self.vis_processors = vis_processors
        self.txt_processors = txt_processors
        self.device = device

    def query(self, image, question):
        """Classify *image* as real or fake by asking the model *question*.

        Args:
            image: image accepted by ``vis_processors["eval"]`` (a PIL image in this app).
            question: prompt text passed to the model as ``samples["prompt"]``.

        Returns:
            ``(label, confidence)``. ``label`` is ``"Real"`` when the "yes" candidate
            scores strictly higher than "no", otherwise ``"Fake"``. ``confidence`` is
            the winning candidate's softmax probability in percent, rounded to 2 places.

        Raises:
            RuntimeError: if ``load_models`` has not been called.
        """
        if self.model is None or self.vis_processors is None:
            raise RuntimeError("Model not loaded; call load_models() first")

        # The eval processor normalises 3 channels; RGBA / greyscale / palette uploads
        # would fail there, so coerce PIL-style images (anything with a .mode) to RGB.
        if getattr(image, "mode", "RGB") != "RGB":
            image = image.convert("RGB")

        candidates = ["yes", "no"]
        with torch.no_grad():  # inference only: no autograd graph needed
            pixels = self.vis_processors["eval"](image).unsqueeze(0).to(self.device)
            samples = {"image": pixels, "prompt": question}
            ans = self.model.predict_class(samples=samples, candidates=candidates)

        # NOTE(review): ans[0] is treated as per-candidate scores (higher = more likely),
        # ordered like `candidates`. Some LAVIS predict_class implementations return
        # argsort ranks of candidate losses instead — confirm for the loaded model.
        scores = torch.as_tensor(ans[0]).detach().float().cpu().flatten()[:2]
        # Softmax over both candidates gives a real distribution (yes + no == 100%);
        # independent sigmoids did not, so the old "confidence" was not a probability.
        yes_prob, no_prob = (torch.softmax(scores, dim=0) * 100).tolist()

        result = "Real" if yes_prob > no_prob else "Fake"
        confidence = max(yes_prob, no_prob)
        return result, round(confidence, 2)
def load_model(model_name="blip2_t5", model_type="pretrain_flant5xl"):
    """Load a LAVIS model with its processors and wrap it in an InstructBLIP.

    Args:
        model_name: LAVIS architecture name, forwarded as ``name=``.
        model_type: model variant, forwarded as ``model_type=``.

    Returns:
        A ready ``InstructBLIP`` placed on CUDA when available (else CPU), or
        ``None`` if any step fails. Failures are printed, not raised.
    """
    target = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {target}")
    try:
        model, vis_procs, txt_procs = load_model_and_preprocess(
            name=model_name,
            model_type=model_type,
            is_eval=True,
            device=target,
        )
        if model is None:
            raise ValueError(f"Failed to load model '{model_name}' with type '{model_type}'")
        wrapper = InstructBLIP()
        wrapper.load_models(model, vis_procs, txt_procs, target)
    except Exception as err:
        # Best-effort startup: report the problem and hand back None instead of crashing.
        print(f"Error loading model: {err}")
        return None
    else:
        return wrapper
# Load the model once at import time so every request reuses the same instance.
# load_model() returns None if loading fails, so model_instance may be None.
model_instance = load_model(model_name="blip2_t5", model_type="pretrain_flant5xl")
def predict_image(input_image, question="Is this photo real [*]?"):
    """Gradio callback: classify an uploaded image as real or fake.

    Args:
        input_image: PIL image from the UI, or an array accepted by
            ``Image.fromarray``; ``None`` when nothing was uploaded.
        question: prompt forwarded to the model.

    Returns:
        ``(label, confidence)``. If no image was given, the model failed to load,
        or inference raised, ``label`` is a human-readable message and
        ``confidence`` is ``0``.
    """
    if input_image is None:
        return "No image provided", 0
    # load_model() yields None on failure; report that clearly instead of the
    # opaque "'NoneType' object has no attribute 'query'" on every request.
    if model_instance is None:
        return "Error: model failed to load; check the server logs", 0
    try:
        # Ensure input is a PIL Image
        if not isinstance(input_image, Image.Image):
            input_image = Image.fromarray(input_image)
        result, confidence = model_instance.query(input_image, question)
        return result, confidence
    except Exception as e:  # UI boundary: surface the error instead of crashing the callback
        return f"Error: {e}", 0
# Gradio UI construction
def create_interface():
    """Build (but do not launch) the Gradio Blocks UI for the detector.

    Layout: a left column with the image upload, an editable prompt and the submit
    button; a right column with the label text box and confidence number. The
    button and the example gallery both route through ``predict_image``.

    Returns:
        The ``gr.Blocks`` app.
    """
    default_prompt = "Is this photo real [*]?"
    intro = """
# Real vs Fake Image Detector
Upload an image to check if it's real or AI-generated. The model will classify the image and provide a confidence score.
Based on AntifakePrompt: https://github.com/nctu-eva-lab/AntifakePrompt
"""
    with gr.Blocks(title="Fake Image Detector") as app:
        gr.Markdown(intro)
        with gr.Row():
            with gr.Column():
                image_in = gr.Image(type="pil", label="Upload Image")
                prompt_in = gr.Textbox(label="Question Prompt", value=default_prompt)
                analyze = gr.Button("Analyze Image", variant="primary")
            with gr.Column():
                label_out = gr.Textbox(label="Classification Result")
                score_out = gr.Number(label="Confidence Score (%)")

        io_inputs = [image_in, prompt_in]
        io_outputs = [label_out, score_out]
        analyze.click(fn=predict_image, inputs=io_inputs, outputs=io_outputs)

        # NOTE(review): with cache_examples=True the examples are run through
        # predict_image ahead of time, so both image files must exist next to this
        # script — confirm they are committed.
        gr.Examples(
            examples=[
                ["example_real.jpg", default_prompt],
                ["example_fake.jpg", default_prompt],
            ],
            inputs=io_inputs,
            outputs=io_outputs,
            fn=predict_image,
            cache_examples=True,
        )
    return app
if __name__ == "__main__":
    # Script entry point: build the Blocks UI, then serve it.
    # share=True additionally requests a temporary public Gradio share link.
    demo = create_interface()
    demo.launch(
        share=True,
    )