import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
from transformers import BertTokenizer, BertModel, ViTModel
import gradio as gr
# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Define the VQA Model
class VQAModel(nn.Module):
    def __init__(self, vit_model, bert_model, num_classes, hidden_size=768):
        super(VQAModel, self).__init__()
        self.vit_model = vit_model
        self.bert_model = bert_model
        self.fc = nn.Linear(768 + hidden_size, hidden_size)  # Input size matches the concatenated features
        self.classifier = nn.Linear(hidden_size, num_classes)  # num_classes is dynamically determined

    def forward(self, image, question):
        # Extract image features
        with torch.no_grad():
            image_features = self.vit_model(image).last_hidden_state[:, 0, :]  # [CLS] token, shape: (batch_size, 768)
        # Extract text features
        with torch.no_grad():
            question_encoded = self.bert_model(question).last_hidden_state[:, 0, :]  # [CLS] token, shape: (batch_size, 768)
        # Concatenate image and text features
        combined_features = torch.cat((image_features, question_encoded), dim=1)  # Shape: (batch_size, 1536)
        # Pass through the fully connected layer
        combined_features = self.fc(combined_features)  # Shape: (batch_size, hidden_size)
        # Classify
        output = self.classifier(combined_features)  # Shape: (batch_size, num_classes)
        return output
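# Illustrative input/output shapes for VQAModel.forward (assuming a single image and
# the 224x224 / max_length=32 preprocessing defined below):
#   image:    (1, 3, 224, 224) pixel tensor
#   question: (1, 32) tensor of BERT input_ids
#   returns:  (1, num_classes) logits over the candidate answers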
# Load the saved model checkpoint
checkpoint_path = 'vqa_vit_best_model.pth' # Path to the saved model
checkpoint = torch.load(checkpoint_path, map_location=device)
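# The checkpoint is expected to be a dict holding the trained weights plus the label
# metadata used below. A minimal sketch of how it might have been saved during
# training (the training script is not part of this file):
#
#   torch.save({
#       'model_state_dict': model.state_dict(),
#       'num_classes': len(answer_to_label),   # size of the answer vocabulary
#       'answer_to_label': answer_to_label,    # dict: answer string -> integer label
#   }, 'vqa_vit_best_model.pth')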
# Load ViT and BERT models
vit_model = ViTModel.from_pretrained('google/vit-base-patch16-224-in21k').to(device)
bert_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
bert_model = BertModel.from_pretrained('bert-base-uncased').to(device)
# Initialize the VQA model with the correct number of classes
model = VQAModel(vit_model, bert_model, num_classes=checkpoint['num_classes']).to(device)
# Load the model state dict
model.load_state_dict(checkpoint['model_state_dict'])
# Load the answer-to-label mapping
answer_to_label = checkpoint['answer_to_label']
label_to_answer = {v: k for k, v in answer_to_label.items()} # Reverse mapping for inference
# Set the model to evaluation mode
model.eval()
# Define transformations for the image
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize to 224x224 as required by ViT
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # Normalize for ViT
])
# Function to preprocess and predict
def predict(image_path, question):
    # Load and transform the image
    image = Image.open(image_path).convert('RGB')
    image = transform(image).unsqueeze(0).to(device)  # Add batch dimension and move to device
    # Tokenize the question
    question_encoded = bert_tokenizer(
        question,
        return_tensors='pt',
        padding='max_length',  # Pad to the maximum length
        truncation=True,       # Truncate if the question is too long
        max_length=32          # Set a maximum sequence length
    ).to(device)
    # Perform inference
    with torch.no_grad():
        output = model(image, question_encoded['input_ids'])
    # Get the predicted label
    _, predicted_label = torch.max(output, 1)
    predicted_label = predicted_label.item()
    # Map the label back to the answer
    predicted_answer = label_to_answer[predicted_label]
    return predicted_answer
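# Example usage (assuming one of the example images listed below is available locally):
#   predict("02_uml.png", "What is the overall complexity of this model?")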
# Define the fixed question used for every prediction
question = "What is the overall complexity of this model?"
# Define the Gradio interface function
def vqa_interface(image):
    # Predict the answer using the provided image and the predefined question
    predicted_answer = predict(image, question)
    return predicted_answer
# Create the Gradio interface
iface = gr.Interface(
    fn=vqa_interface,                  # Function to call
    inputs=gr.Image(type="filepath"),  # Input type: image file path
    outputs="text",                    # Output type: text (predicted answer)
    title="Visual Question Answering (VQA) System",
    description="Upload an image, and the system will answer the question: 'What is the overall complexity of this model?'",
    examples=[
        ["02_uml.png"], ["2ndIterationClassDiagram.png"], ["4-gameUML.png"]
    ]
)
# Launch the Gradio interface
iface.launch()