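"""Gradio demo: classify an uploaded image as Real or Fake.

Image features come from OpenAI CLIP (RN50); an optional caption is either
supplied by the user or generated with BLIP-2, and a small MLP head on top of
the (concatenated) CLIP features produces the real/fake score.
"""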
import os
os.system('pip install git+https://github.com/openai/CLIP.git')
import gradio as gr
from PIL import Image
from transformers import BlipProcessor, BlipForConditionalGeneration
from transformers import Blip2Processor, Blip2ForConditionalGeneration
import torch
import clip
import torch.nn as nn
from torchvision import transforms
# Run on GPU when available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class OpenaAIClip(nn.Module):
    """CLIP (RN50) feature extractor with a configurable classification head."""

    def __init__(self, arch="resnet50", modality="image"):
        super().__init__()
        self.model = None
        self.modality = modality
        if arch == "resnet50":
            self.model, _ = clip.load("RN50")
        if self.modality == "image":
            # Image-only mode: keep the visual encoder trainable, freeze everything else.
            for name, param in self.model.named_parameters():
                if "visual" in name:
                    param.requires_grad = True
                else:
                    param.requires_grad = False
        # Placeholder head; replaced with an MLP classifier in get_model().
        self.fc = nn.Identity()

    def forward(self, image, text=None):
        image_features = self.model.encode_image(image)
        if self.modality == "image+text":
            text = clip.tokenize(text, truncate=True).to(device)
            text_features = self.model.encode_text(text)
        else:
            return self.fc(image_features)
        combined_features = torch.cat((image_features, text_features), dim=1)
        return self.fc(combined_features)

def preprocessing(img, size):
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    data_transforms = transforms.Compose([
        transforms.Resize((size, size)),
        transforms.ToTensor(),
        normalize])
    img = data_transforms(img)
    return img

def get_model(model_path, modality):
    if modality == "Image":
        # Image-only: 1024-d CLIP image embedding -> MLP -> sigmoid score.
        model = OpenaAIClip(arch="resnet50", modality="image")
        dim_mlp = 1024
        fc_units = [512]
        model.fc = nn.Sequential(nn.Linear(dim_mlp, fc_units[0]), nn.ReLU(),
                                 nn.Linear(fc_units[0], 1), nn.Sigmoid())
    elif modality == "Image+Text":
        # Image+text: concatenated image and text embeddings (2048-d) -> MLP -> sigmoid score.
        model = OpenaAIClip(arch="resnet50", modality="image+text")
        dim_mlp = 2048
        fc_units = [1024]
        model.fc = nn.Sequential(nn.Linear(dim_mlp, fc_units[0]), nn.ReLU(),
                                 nn.Linear(fc_units[0], 1), nn.Sigmoid())
    checkpoint_dict = torch.load(model_path, map_location=torch.device('cpu'))
    model.load_state_dict(checkpoint_dict['state_dict'])
    model.eval()
    return model

def get_blip_model():
    # BLIP-2 (OPT-2.7B) captioner; the model is kept on CPU.
    #processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
    #model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large")
    processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
    model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b")
    model.eval()
    return processor, model

def get_caption(image):
    processor, model = get_blip_model()
    inputs = processor(images=image, return_tensors="pt")
    # The BLIP-2 model stays on CPU, so keep the inputs on CPU as well.
    outputs = model.generate(**inputs)
    caption = processor.decode(outputs[0], skip_special_tokens=True)
    return caption

def predict(img, caption):
    # If no caption was supplied, generate one with BLIP-2.
    if not caption:
        caption = get_caption(img)
    print(caption)
    prediction = []
    # One detector checkpoint per generator (sd / glide / ld).
    models_list = ['clip-sd.pth', 'clip-glide.pth', 'clip-ld.pth']
    modality = "Image+Text"
    for model_path in models_list:
        model = get_model(model_path, modality)
        tensor = preprocessing(img, 224)
        input_tensor = tensor.view(1, 3, 224, 224)
        with torch.no_grad():
            out = model(input_tensor, caption)
        print('------------>', out)
        prediction.append(out.item())
    # Only the first detector's score is used for the final decision.
    if prediction[0] > 0.5:
        return "Fake"
    else:
        return "Real"

# Create Gradio interface
image_input = gr.Image(type="pil", label="Upload Image")
text_input = gr.Textbox(label="Caption for image (Optional)")
iface = gr.Interface(fn=predict,
                     inputs=[image_input, text_input],
                     outputs=gr.Label(),
                     examples=[["trump-fake.jpeg", "Donald Trump being arrested by authorities."]])
iface.launch()