truens66 commited on
Commit
ae2818c
·
verified ·
1 Parent(s): a95db74

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +120 -0
app.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import gradio as gr
4
+ from PIL import Image
5
+ from lavis.models import load_model_and_preprocess, model_zoo
6
+
7
+ # Fix CUDA plugin registration errors
8
+ os.environ["CUDA_VISIBLE_DEVICES"] = "0" if torch.cuda.is_available() else "-1"
9
+
10
+ class InstructBLIP:
11
+ def __init__(self):
12
+ self.model = None
13
+ self.vis_processors = None
14
+ self.txt_processors = None
15
+ self.device = "cpu"
16
+
17
+ def load_models(self, model, vis_processors, txt_processors, device):
18
+ self.model = model
19
+ self.vis_processors = vis_processors
20
+ self.txt_processors = txt_processors
21
+ self.device = device
22
+
23
+ def query(self, image, question):
24
+ image = self.vis_processors["eval"](image).unsqueeze(0).to(self.device)
25
+ samples = {"image": image, "prompt": question}
26
+ candidates = ["yes", "no"]
27
+ ans = self.model.predict_class(samples=samples, candidates=candidates)
28
+
29
+ # Convert logits to probabilities
30
+ logits = ans[0]
31
+ yes_prob = torch.sigmoid(torch.tensor(logits[0])).item() * 100
32
+ no_prob = torch.sigmoid(torch.tensor(logits[1])).item() * 100
33
+
34
+ result = "Real" if yes_prob > no_prob else "Fake"
35
+ confidence = max(yes_prob, no_prob)
36
+
37
+ return result, round(confidence, 2)
38
+
39
+ def load_model(model_name="blip2_t5", model_type="pretrain_flant5xl"):
40
+ device = "cuda" if torch.cuda.is_available() else "cpu"
41
+ print(f"Using device: {device}")
42
+
43
+ try:
44
+ model, vis_processors, txt_processors = load_model_and_preprocess(
45
+ name=model_name,
46
+ model_type=model_type,
47
+ is_eval=True,
48
+ device=device
49
+ )
50
+ if model is None:
51
+ raise ValueError(f"Failed to load model '{model_name}' with type '{model_type}'")
52
+
53
+ instruct = InstructBLIP()
54
+ instruct.load_models(model, vis_processors, txt_processors, device)
55
+ return instruct
56
+ except Exception as e:
57
+ print(f"Error loading model: {e}")
58
+ return None
59
+
60
+
61
+ # Load the model once when the script starts
62
+ model_instance = load_model()
63
+
64
+ def predict_image(input_image, question="Is this photo real [*]?"):
65
+ if input_image is None:
66
+ return "No image provided", 0
67
+
68
+ try:
69
+ # Ensure input is a PIL Image
70
+ if not isinstance(input_image, Image.Image):
71
+ input_image = Image.fromarray(input_image)
72
+
73
+ # Run model inference
74
+ result, confidence = model_instance.query(input_image, question)
75
+ return result, confidence
76
+ except Exception as e:
77
+ return f"Error: {str(e)}", 0
78
+
79
+ # Create Gradio interface
80
+ def create_interface():
81
+ with gr.Blocks(title="Fake Image Detector") as app:
82
+ gr.Markdown("""
83
+ # Real vs Fake Image Detector
84
+ Upload an image to check if it's real or AI-generated. The model will classify the image and provide a confidence score.
85
+ Based on AntifakePrompt: https://github.com/nctu-eva-lab/AntifakePrompt
86
+ """)
87
+
88
+ with gr.Row():
89
+ with gr.Column():
90
+ input_image = gr.Image(type="pil", label="Upload Image")
91
+ question = gr.Textbox(label="Question Prompt", value="Is this photo real [*]?")
92
+ submit_btn = gr.Button("Analyze Image", variant="primary")
93
+
94
+ with gr.Column():
95
+ result_label = gr.Textbox(label="Classification Result")
96
+ confidence = gr.Number(label="Confidence Score (%)")
97
+
98
+ submit_btn.click(
99
+ fn=predict_image,
100
+ inputs=[input_image, question],
101
+ outputs=[result_label, confidence]
102
+ )
103
+
104
+ gr.Examples(
105
+ examples=[
106
+ ["example_real.jpg", "Is this photo real [*]?"],
107
+ ["example_fake.jpg", "Is this photo real [*]?"],
108
+ ],
109
+ inputs=[input_image, question],
110
+ outputs=[result_label, confidence],
111
+ fn=predict_image,
112
+ cache_examples=True,
113
+ )
114
+
115
+ return app
116
+
117
+ if __name__ == "__main__":
118
+ # Create and launch the Gradio interface
119
+ demo = create_interface()
120
+ demo.launch(share=True) # Set share=True to get a public link