File size: 4,139 Bytes
ae2818c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import os
import torch
import gradio as gr
from PIL import Image
from lavis.models import load_model_and_preprocess, model_zoo

# Default GPU visibility to device 0 (or hide all GPUs when CUDA is unavailable),
# but respect a CUDA_VISIBLE_DEVICES value the user already set in the environment.
# NOTE(review): torch.cuda.is_available() has already queried CUDA by the time this
# runs, so this may be too late to change which GPUs torch sees in this process;
# it mainly affects child processes — confirm intent.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "0" if torch.cuda.is_available() else "-1")

class InstructBLIP:
    """Wrap a LAVIS vision-language model to answer a yes/no question about an image.

    The model, its processors and the target device are injected with
    ``load_models`` after construction; ``query`` cannot be used before that.
    """

    def __init__(self):
        self.model = None  # model object exposing ``predict_class``
        self.vis_processors = None  # mapping of image processors; query uses the "eval" entry
        self.txt_processors = None  # stored for completeness; not used by this class
        self.device = "cpu"  # device the preprocessed image tensor is moved to

    def load_models(self, model, vis_processors, txt_processors, device):
        """Attach a loaded model, its processors, and the device to run on."""
        self.model = model
        self.vis_processors = vis_processors
        self.txt_processors = txt_processors
        self.device = device

    @staticmethod
    def _pick_candidate(output):
        """Return ``(best_index, confidence_percent)`` for the first row of *output*.

        Two output conventions are supported:

        * floating-point per-candidate scores (logits): a softmax across the
          candidates yields mutually exclusive probabilities, so the winner's
          probability is a meaningful confidence;
        * integer candidate indices ordered best-first. LAVIS's instruct-model
          ``predict_class`` implementations return this (``torch.argsort`` over
          per-candidate losses) — confirm for the model in use. The losses are
          discarded there, so no confidence can be derived and ``None`` is
          returned for it.
        """
        row = torch.as_tensor(output)[0]
        if row.is_floating_point():
            probs = torch.softmax(row.float(), dim=-1)
            best = int(torch.argmax(probs))
            return best, probs[best].item() * 100
        # Integer output: indices sorted best-first, so the winner is the first entry.
        return int(row[0]), None

    def query(self, image, question):
        """Ask *question* about *image* and classify the image as "Real" or "Fake".

        Args:
            image: Input accepted by ``self.vis_processors["eval"]``.
            question: Prompt text passed to the model as ``samples["prompt"]``.

        Returns:
            ``(result, confidence)``. ``result`` is "Real" when the model prefers
            the "yes" candidate and "Fake" otherwise. ``confidence`` is the winning
            probability in percent, rounded to 2 decimals, or ``None`` when the
            model only reports a ranking.

        Raises:
            RuntimeError: if ``load_models`` has not been called.
        """
        if self.model is None or self.vis_processors is None:
            raise RuntimeError("Model not loaded; call load_models() first")

        pixels = self.vis_processors["eval"](image).unsqueeze(0).to(self.device)
        samples = {"image": pixels, "prompt": question}
        candidates = ["yes", "no"]
        # Inference only: skip autograd bookkeeping to save memory.
        with torch.no_grad():
            output = self.model.predict_class(samples=samples, candidates=candidates)

        best, confidence = self._pick_candidate(output)
        result = "Real" if candidates[best] == "yes" else "Fake"
        if confidence is not None:
            confidence = round(confidence, 2)
        return result, confidence

def load_model(model_name="blip2_t5", model_type="pretrain_flant5xl"):
    """Load a LAVIS model with its processors and wrap it in ``InstructBLIP``.

    Uses CUDA when available, otherwise the CPU, and prints the chosen device.

    Args:
        model_name: LAVIS architecture name passed to ``load_model_and_preprocess``.
        model_type: LAVIS model variant passed to ``load_model_and_preprocess``.

    Returns:
        A ready ``InstructBLIP`` instance, or ``None`` if anything went wrong.
        The error is printed rather than raised.
    """
    target = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {target}")

    try:
        loaded = load_model_and_preprocess(
            name=model_name,
            model_type=model_type,
            is_eval=True,
            device=target,
        )
        net, vis_proc, txt_proc = loaded
        if net is None:
            raise ValueError(f"Failed to load model '{model_name}' with type '{model_type}'")

        wrapper = InstructBLIP()
        wrapper.load_models(net, vis_proc, txt_proc, target)
    except Exception as err:
        # Best-effort: report and return None so the UI can still start.
        print(f"Error loading model: {err}")
        return None

    return wrapper


# Load the model once at import time so every request reuses the same instance.
# load_model() returns None if loading fails; calls to predict_image then come
# back as error messages instead of predictions.
model_instance = load_model()

def predict_image(input_image, question="Is this photo real [*]?"):
    """Classify an uploaded image as real or fake; serves as the Gradio callback.

    Args:
        input_image: A PIL image, or an array accepted by ``Image.fromarray``.
            ``None`` when nothing was uploaded.
        question: Prompt passed to the model along with the image.

    Returns:
        ``(result, confidence)`` from ``model_instance.query``. Otherwise an
        explanatory message with confidence 0 when there is no image, the
        model failed to load, or inference raised.
    """
    if input_image is None:
        return "No image provided", 0
    # load_model() returns None on failure. Say so explicitly instead of
    # surfacing an opaque "'NoneType' object has no attribute 'query'" error.
    if model_instance is None:
        return "Error: model failed to load; check the server logs", 0

    try:
        # Ensure input is a PIL Image
        if not isinstance(input_image, Image.Image):
            input_image = Image.fromarray(input_image)
        # Uploads may be RGBA, grayscale or palette images. Normalize to the
        # 3-channel RGB input that LAVIS's own examples feed the image processor.
        input_image = input_image.convert("RGB")

        # Run model inference
        result, confidence = model_instance.query(input_image, question)
        return result, confidence
    except Exception as e:
        return f"Error: {str(e)}", 0

# Create Gradio interface
def create_interface():
    """Build the Gradio Blocks UI for the real-vs-fake image detector.

    Layout: an input column (image upload, editable prompt, analyze button)
    next to an output column (classification text and confidence number).
    Both the button and the example rows call ``predict_image``.

    Returns:
        The constructed, not-yet-launched ``gr.Blocks`` app.
    """
    prompt_text = "Is this photo real [*]?"
    example_rows = [[path, prompt_text] for path in ("example_real.jpg", "example_fake.jpg")]

    with gr.Blocks(title="Fake Image Detector") as app:
        gr.Markdown("""
        # Real vs Fake Image Detector
        Upload an image to check if it's real or AI-generated. The model will classify the image and provide a confidence score.
        Based on AntifakePrompt: https://github.com/nctu-eva-lab/AntifakePrompt
        """)

        with gr.Row():
            # Left column: user inputs.
            with gr.Column():
                image_in = gr.Image(type="pil", label="Upload Image")
                prompt_in = gr.Textbox(label="Question Prompt", value=prompt_text)
                analyze = gr.Button("Analyze Image", variant="primary")

            # Right column: model outputs.
            with gr.Column():
                verdict_out = gr.Textbox(label="Classification Result")
                score_out = gr.Number(label="Confidence Score (%)")

        analyze.click(
            fn=predict_image,
            inputs=[image_in, prompt_in],
            outputs=[verdict_out, score_out],
        )

        # cache_examples=True runs predict_image on these rows when the UI is built.
        gr.Examples(
            examples=example_rows,
            inputs=[image_in, prompt_in],
            outputs=[verdict_out, score_out],
            fn=predict_image,
            cache_examples=True,
        )

    return app

if __name__ == "__main__":
    # Build the UI and start the server. The module-level name `demo` is kept
    # as-is; it is a conventional handle for the app object.
    demo = create_interface()
    demo.launch(share=True)  # share=True also requests a public share link; NOTE(review): exposes the app publicly — confirm intended