import os

# Install OpenAI's CLIP package at startup (common pattern on Hugging Face Spaces).
os.system('pip install git+https://github.com/openai/CLIP.git')

import gradio as gr
import datetime
from PIL import Image
from transformers import BlipProcessor, BlipForConditionalGeneration
from transformers import Blip2Processor, Blip2ForConditionalGeneration
import torch
import clip
import torch.nn as nn
from torchvision.transforms import transforms

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class OpenaAIClip(nn.Module):
    """CLIP RN50 backbone with a replaceable classification head (`self.fc`)."""

    def __init__(self, arch="resnet50", modality="image"):
        super().__init__()
        self.model = None
        self.modality = modality
        if arch == "resnet50":
            self.model, _ = clip.load("RN50")
        if self.modality == "image":
            # Image-only mode: fine-tune the visual encoder, freeze everything else.
            for name, param in self.model.named_parameters():
                if "visual" in name:
                    param.requires_grad = True
                else:
                    param.requires_grad = False
        self.fc = nn.Identity()

    def forward(self, image, text=None):
        image_features = self.model.encode_image(image)
        if self.modality == "image+text":
            text = clip.tokenize(text, truncate=True).to(device)
            text_features = self.model.encode_text(text)
        else:
            return self.fc(image_features)
        # Concatenate image and text embeddings before the classification head.
        combined_features = torch.cat((image_features, text_features), dim=1)
        return self.fc(combined_features)
def preprocessing(img, size):
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    data_transforms = transforms.Compose([
        transforms.Resize((size, size)),
        transforms.ToTensor(),
        normalize])
    img = data_transforms(img)
    return img
def get_model(model_path, modality):
    """Build the CLIP-based classifier and load its fine-tuned checkpoint."""
    if modality == "Image":
        model = OpenaAIClip(arch="resnet50", modality="image")
        dim_mlp = 1024  # RN50 image embedding size
        fc_units = [512]
        model.fc = nn.Sequential(nn.Linear(dim_mlp, fc_units[0]), nn.ReLU(),
                                 nn.Linear(fc_units[0], 1), nn.Sigmoid())
    elif modality == "Image+Text":
        model = OpenaAIClip(arch="resnet50", modality="image+text")
        dim_mlp = 2048  # concatenated image + text embeddings
        fc_units = [1024]
        model.fc = nn.Sequential(nn.Linear(dim_mlp, fc_units[0]), nn.ReLU(),
                                 nn.Linear(fc_units[0], 1), nn.Sigmoid())
    checkpoint_dict = torch.load(model_path, map_location=torch.device('cpu'))
    model.load_state_dict(checkpoint_dict['state_dict'])
    model.eval()
    return model
def get_blip_model():
    """Load the BLIP captioning model used when the user provides no caption."""
    processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
    model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large")
    # Alternative (heavier) captioner:
    # processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
    # model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b")
    model.eval()
    return processor, model
def get_caption(image):
    processor, model = get_blip_model()
    model = model.to(device)  # keep the captioner on the same device as its inputs
    inputs = processor(images=image, return_tensors="pt").to(device)
    outputs = model.generate(**inputs)
    caption = processor.decode(outputs[0], skip_special_tokens=True)
    return caption
def predict(img, caption):
    now = datetime.datetime.now()
    print(now)
    if img is not None:
        print(caption)
        if caption is None or caption == "":
            # No user caption: generate one with BLIP so the image+text models get a text input.
            caption = get_caption(img)
            print("Generated caption-->", caption)
        else:
            print("User input caption-->", caption)
        img.save("models/" + str(now) + ".png")
        prediction = []
        # One detector checkpoint per generator family: Stable Diffusion, GLIDE, Latent Diffusion.
        models_list = ['models/clip-sd.pth', 'models/clip-glide.pth', 'models/clip-ld.pth']
        modality = "Image+Text"
        for i, model_path in enumerate(models_list):
            model = get_model(model_path, modality)
            tensor = preprocessing(img, 224)
            input_tensor = tensor.view(1, 3, 224, 224)
            with torch.no_grad():
                out = model(input_tensor, caption)
            print(models_list[i], ' ----> ', out)
            prediction.append(out.item())
        # Majority vote: count the predictions that are greater than or equal to 0.5.
        count_ones = sum(1 for p in prediction if p >= 0.5)
        if count_ones > len(prediction) / 2:
            return "Fake Image"
        else:
            return "Real Image"
    else:
        print("Alert: Input image missing")
        return "Alert: Input image missing"
# Create the Gradio interface
image_input = gr.Image(type="pil", label="Input Image")
text_input = gr.Textbox(label="Caption for image (Optional)")
iface = gr.Interface(fn=predict,
                     inputs=[image_input, text_input],
                     outputs=gr.Label(),
                     examples=[["examples/trump-fake.jpeg", "Donald Trump being arrested by authorities."],
                               ["examples/astronaut_space.png", "An astronaut playing basketball with a cat in space, digital art"]])
iface.launch()
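# Minimal programmatic usage sketch (an illustrative assumption, not part of the app:
# it presumes the checkpoints under models/, the example images, and a writable models/
# directory for the saved copy of the input image all exist):
#   from PIL import Image
#   result = predict(Image.open("examples/astronaut_space.png").convert("RGB"),
#                    "An astronaut playing basketball with a cat in space, digital art")
#   print(result)  # "Fake Image" or "Real Image"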