import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
from transformers import BertTokenizer, BertModel, ViTModel
import gradio as gr

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Define the VQA model
class VQAModel(nn.Module):
    def __init__(self, vit_model, bert_model, num_classes, hidden_size=768):
        super(VQAModel, self).__init__()
        self.vit_model = vit_model
        self.bert_model = bert_model
        self.fc = nn.Linear(768 + hidden_size, hidden_size)    # Input size matches the concatenated image + text features
        self.classifier = nn.Linear(hidden_size, num_classes)  # num_classes is determined dynamically

    def forward(self, image, question):
        # Extract image features
        with torch.no_grad():
            image_features = self.vit_model(image).last_hidden_state[:, 0, :]  # [CLS] token, shape: (batch_size, 768)

        # Extract text features
        with torch.no_grad():
            question_encoded = self.bert_model(question).last_hidden_state[:, 0, :]  # [CLS] token, shape: (batch_size, 768)

        # Concatenate image and text features
        combined_features = torch.cat((image_features, question_encoded), dim=1)  # Shape: (batch_size, 1536)

        # Pass through the fully connected layer
        combined_features = self.fc(combined_features)  # Shape: (batch_size, hidden_size)

        # Classify
        output = self.classifier(combined_features)  # Shape: (batch_size, num_classes)
        return output
# Load the saved model checkpoint
checkpoint_path = 'vqa_vit_best_model.pth'  # Path to the saved model
checkpoint = torch.load(checkpoint_path, map_location=device)

# Load the pretrained ViT and BERT backbones
vit_model = ViTModel.from_pretrained('google/vit-base-patch16-224-in21k').to(device)
bert_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
bert_model = BertModel.from_pretrained('bert-base-uncased').to(device)

# Initialize the VQA model with the correct number of classes
model = VQAModel(vit_model, bert_model, num_classes=checkpoint['num_classes']).to(device)

# Load the model state dict
model.load_state_dict(checkpoint['model_state_dict'])

# Load the answer-to-label mapping
answer_to_label = checkpoint['answer_to_label']
label_to_answer = {v: k for k, v in answer_to_label.items()}  # Reverse mapping for inference

# Set the model to evaluation mode
model.eval()

# Define transformations for the image
transform = transforms.Compose([
    transforms.Resize((224, 224)),                                   # Resize to 224x224 as required by ViT
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # Normalize for ViT
])
# Preprocess the inputs and run inference
def predict(image_path, question):
    # Load and transform the image
    image = Image.open(image_path).convert('RGB')
    image = transform(image).unsqueeze(0).to(device)  # Add batch dimension and move to device

    # Tokenize the question
    question_encoded = bert_tokenizer(
        question,
        return_tensors='pt',
        padding='max_length',  # Pad to the maximum length
        truncation=True,       # Truncate if the question is too long
        max_length=32          # Maximum sequence length
    ).to(device)

    # Perform inference (the model's forward() expects only the input_ids)
    with torch.no_grad():
        output = model(image, question_encoded['input_ids'])

    # Get the predicted label
    _, predicted_label = torch.max(output, 1)
    predicted_label = predicted_label.item()

    # Map the label back to the answer
    predicted_answer = label_to_answer[predicted_label]
    return predicted_answer
# Define the fixed question
question = "What is the overall complexity of this model?"

# Define the Gradio interface function
def vqa_interface(image):
    # Predict the answer for the uploaded image and the predefined question
    predicted_answer = predict(image, question)
    return predicted_answer

# Create the Gradio interface
iface = gr.Interface(
    fn=vqa_interface,                   # Function to call
    inputs=gr.Image(type="filepath"),   # Input type: image file path
    outputs="text",                     # Output type: text (predicted answer)
    title="Visual Question Answering (VQA) System",
    description="Upload an image, and the system will answer the question: 'What is the overall complexity of this model?'",
    examples=[
        ["02_uml.png"],
        ["2ndIterationClassDiagram.png"],
        ["4-gameUML.png"]
    ]
)

# Launch the Gradio interface
iface.launch()