import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
from transformers import BertTokenizer, BertModel, ViTModel
import gradio as gr

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Define the VQA Model
class VQAModel(nn.Module):
    def __init__(self, vit_model, bert_model, num_classes, hidden_size=768):
        super().__init__()
        self.vit_model = vit_model
        self.bert_model = bert_model
        self.fc = nn.Linear(768 + hidden_size, hidden_size)  # Fuse the concatenated image + text features (768 + hidden_size)
        self.classifier = nn.Linear(hidden_size, num_classes)  # num_classes is taken from the checkpoint at load time

    def forward(self, image, question):
        # Extract image features
        with torch.no_grad():
            image_features = self.vit_model(image).last_hidden_state[:, 0, :]  # [CLS] token, Shape: (batch_size, 768)

        # Extract text features
        with torch.no_grad():
            question_features = self.bert_model(question).last_hidden_state[:, 0, :]  # [CLS] token, Shape: (batch_size, 768)

        # Concatenate image and text features
        combined_features = torch.cat((image_features, question_features), dim=1)  # Shape: (batch_size, 1536)

        # Pass through fully connected layer
        combined_features = self.fc(combined_features)  # Shape: (batch_size, hidden_size)

        # Classify
        output = self.classifier(combined_features)  # Shape: (batch_size, num_classes)
        return output

# Load the saved model checkpoint
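# Expected checkpoint keys (as used below): 'model_state_dict', 'num_classes',
# and 'answer_to_label' (a mapping from answer string to integer label).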
checkpoint_path = 'vqa_vit_best_model.pth'  # Path to the saved model
checkpoint = torch.load(checkpoint_path, map_location=device)

# Load ViT and BERT models
vit_model = ViTModel.from_pretrained('google/vit-base-patch16-224-in21k').to(device)
bert_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
bert_model = BertModel.from_pretrained('bert-base-uncased').to(device)

# Initialize the VQA model with the correct number of classes
model = VQAModel(vit_model, bert_model, num_classes=checkpoint['num_classes']).to(device)

# Load the model state dict
model.load_state_dict(checkpoint['model_state_dict'])

# Load the answer-to-label mapping
answer_to_label = checkpoint['answer_to_label']
label_to_answer = {v: k for k, v in answer_to_label.items()}  # Reverse mapping for inference

# Set the model to evaluation mode
model.eval()

# Define transformations for the image
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize to 224x224 as required by ViT
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # Normalize for ViT
])

# Function to preprocess and predict
def predict(image_path, question):
    # Load and transform the image
    image = Image.open(image_path).convert('RGB')
    image = transform(image).unsqueeze(0).to(device)  # Add batch dimension and move to device

    # Tokenize the question
    question_encoded = bert_tokenizer(
        question,
        return_tensors='pt',
        padding='max_length',  # Pad to the maximum length
        truncation=True,       # Truncate if the question is too long
        max_length=32          # Set a maximum sequence length
    ).to(device)

    # Perform inference
    with torch.no_grad():
        output = model(image, question_encoded['input_ids'])

    # Get the predicted label
    _, predicted_label = torch.max(output, 1)
    predicted_label = predicted_label.item()

    # Map the label back to the answer
    predicted_answer = label_to_answer[predicted_label]

    return predicted_answer

# Fixed question asked for every uploaded image
question = "What is the overall complexity of this model?"

# Define the Gradio interface function
def vqa_interface(image):
    # Guard against submissions without an uploaded image
    if image is None:
        return "Please upload an image."
    # Predict the answer for the uploaded image and the predefined question
    return predict(image, question)

# Create the Gradio interface
iface = gr.Interface(
    fn=vqa_interface,  # Function to call
    inputs=gr.Image(type="filepath"),  # Input type: image file path
    outputs="text",  # Output type: text (predicted answer)
    title="Visual Question Answering (VQA) System",
    description="Upload an image, and the system will answer the question: 'What is the overall complexity of this model?'",
    examples=[
        ["02_uml.png"],
        ["2ndIterationClassDiagram.png"],
        ["4-gameUML.png"],
    ],
)

# Launch the Gradio interface
iface.launch()