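"""Gradio demo that flags an uploaded image as real or AI-generated.

The image is paired with a caption (the user's, or a BLIP-generated one when
none is given), and three fine-tuned CLIP-based binary classifiers vote on
whether the image is fake.
"""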
import os

# Install OpenAI's CLIP package at runtime (handy on hosted demos where it is not preinstalled).
os.system('pip install git+https://github.com/openai/CLIP.git')

import datetime

import gradio as gr
from PIL import Image
from transformers import BlipProcessor, BlipForConditionalGeneration
# The BLIP-2 imports are only needed for the commented-out alternative in get_blip_model().
from transformers import Blip2Processor, Blip2ForConditionalGeneration

import torch
import torch.nn as nn
import clip
from torchvision import transforms

# Run on GPU when available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class OpenAIClip(nn.Module):
    """CLIP RN50 feature extractor returning image-only or concatenated image+text embeddings."""

    def __init__(self, arch="resnet50", modality="image"):
        super().__init__()
        self.modality = modality
        if arch == "resnet50":
            # Load the CLIP RN50 weights in fp32 on CPU; get_model() moves the wrapper to `device`.
            self.model, _ = clip.load("RN50", device="cpu")
        else:
            raise ValueError(f"Unsupported architecture: {arch}")

        if self.modality == "image":
            # Fine-tune only the visual encoder; keep the text tower frozen.
            for name, param in self.model.named_parameters():
                param.requires_grad = "visual" in name

        # Placeholder head; replaced with an MLP classifier in get_model().
        self.fc = nn.Identity()

    def forward(self, image, text=None):
        image_features = self.model.encode_image(image)

        if self.modality == "image+text":
            # Tokenize the caption and concatenate its embedding with the image embedding.
            text = clip.tokenize(text, truncate=True).to(image.device)
            text_features = self.model.encode_text(text)
            combined_features = torch.cat((image_features, text_features), dim=1)
            return self.fc(combined_features)

        return self.fc(image_features)
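
# A minimal standalone sketch of the wrapper (illustrative only; the tensor and caption
# below are made up). With the default nn.Identity() head, the image+text output is the
# 2048-dim concatenation of the 1024-dim RN50 image and text embeddings:
#
#   encoder = OpenAIClip(arch="resnet50", modality="image+text")
#   features = encoder(torch.randn(1, 3, 224, 224), ["a photo of a cat"])
#   features.shape  # torch.Size([1, 2048])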


def preprocessing(img, size):
    """Resize, tensorize, and normalize a PIL image with ImageNet statistics."""
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    data_transforms = transforms.Compose([
        transforms.Resize((size, size)),
        transforms.ToTensor(),
        normalize])
    # Force three channels so RGBA or grayscale uploads do not break the model.
    img = data_transforms(img.convert("RGB"))
    return img

def get_model(model_path, modality):
    """Build a CLIP-based binary classifier and load a fine-tuned checkpoint."""
    if modality == "Image":
        model = OpenAIClip(arch="resnet50", modality="image")
        dim_mlp, fc_units = 1024, 512    # RN50 image embedding size
    elif modality == "Image+Text":
        model = OpenAIClip(arch="resnet50", modality="image+text")
        dim_mlp, fc_units = 2048, 1024   # image (1024) + text (1024) embeddings
    else:
        raise ValueError(f"Unsupported modality: {modality}")

    # Small MLP head mapping the CLIP features to a fake-image probability.
    model.fc = nn.Sequential(nn.Linear(dim_mlp, fc_units), nn.ReLU(),
                             nn.Linear(fc_units, 1), nn.Sigmoid())

    checkpoint_dict = torch.load(model_path, map_location=torch.device('cpu'))
    model.load_state_dict(checkpoint_dict['state_dict'])
    model.to(device)
    model.eval()
    return model

def get_blip_model():
    """Load the BLIP captioning model used when the user provides no caption."""
    processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
    model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large")
    # BLIP-2 alternative:
    # processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
    # model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b")

    model.to(device)
    model.eval()
    return processor, model

def get_caption(image):
    """Generate a caption for the image with BLIP."""
    processor, model = get_blip_model()
    inputs = processor(images=image, return_tensors="pt").to(device)
    outputs = model.generate(**inputs)
    caption = processor.decode(outputs[0], skip_special_tokens=True)
    return caption

def predict(img, caption):
    """Classify an uploaded image as real or fake by majority vote of three classifiers."""
    now = datetime.datetime.now()
    print(now)
    if img is None:
        print("Alert: Input image missing")
        return "Alert: Input image missing"

    print(caption)
    if caption is None or caption == "":
        # No user caption: fall back to a BLIP-generated one.
        caption = get_caption(img)
        print("Generated caption-->", caption)
    else:
        print("User input caption-->", caption)

    # Keep a copy of the upload, using a filesystem-friendly timestamp as the name.
    img.save("models/" + now.strftime("%Y-%m-%d_%H-%M-%S") + ".png")

    predictions = []
    # Three fine-tuned classifiers (the names suggest Stable Diffusion, GLIDE, and
    # Latent Diffusion training data); each outputs a fake-image probability.
    models_list = ['models/clip-sd.pth', 'models/clip-glide.pth', 'models/clip-ld.pth']
    modality = "Image+Text"
    for model_path in models_list:
        model = get_model(model_path, modality)
        input_tensor = preprocessing(img, 224).view(1, 3, 224, 224).to(device)
        with torch.no_grad():
            out = model(input_tensor, caption)
            print(model_path, ' ----> ', out)
            predictions.append(out.item())

    # Majority vote: a score >= 0.5 counts as a vote for "fake".
    fake_votes = sum(1 for p in predictions if p >= 0.5)
    if fake_votes > len(predictions) / 2:
        return "Fake Image"
    else:
        return "Real Image"


# Build the Gradio interface: an image upload, an optional caption box, and a label output.
image_input = gr.Image(type="pil", label="Input Image")
text_input = gr.Textbox(label="Caption for image (Optional)")

iface = gr.Interface(fn=predict,
                     inputs=[image_input, text_input],
                     outputs=gr.Label(),
                     examples=[["examples/trump-fake.jpeg", "Donald Trump being arrested by authorities."],
                               ["examples/astronaut_space.png", "An astronaut playing basketball with a cat in space, digital art"]])
iface.launch()