import gradio as gr
from PIL import Image
from joblib import load
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import efficientnet_b0, EfficientNet_B0_Weights
import torchvision.transforms as transforms


class MultiModalClassifier(nn.Module):
    def __init__(self, num_classes, num_features):
        super().__init__()
        # Load a pre-trained EfficientNet-B0 backbone (the pretrained=True flag is deprecated)
        efficientnet = efficientnet_b0(weights=EfficientNet_B0_Weights.DEFAULT)
        
        # Drop the final classifier head; keep the conv features and global average pool
        self.efficientnet_features = nn.Sequential(*list(efficientnet.children())[:-1])
        
        # Define additional feature dimensions
        self.age_dim = 1  # assuming age is a single scalar value
        self.anatom_site_dim = 1  # assuming anatomical site is a single scalar value
        self.sex_dim = 1  # assuming sex is a single scalar value
        
        # Fully connected layers for classification
        self.fc1 = nn.Linear(num_features + self.age_dim + self.anatom_site_dim + self.sex_dim, 256)
        self.fc2 = nn.Linear(256, num_classes)
        
        # Dropout layer.
        self.dropout = nn.Dropout(p=0.5)
        
    def forward(self, image, age, anatom_site, sex):
        # Backbone features: (batch, 1280, 1, 1), since the backbone ends in global average pooling
        image_features = self.efficientnet_features(image)
        # The extra pooling is effectively a no-op here; flatten to (batch, 1280)
        image_features = F.avg_pool2d(image_features, image_features.size()[2:]).view(image.size(0), -1)
        
        # Reshape the scalar metadata to [batch_size, 1]
        age = age.view(-1, 1)
        anatom_site = anatom_site.view(-1, 1)
        sex = sex.view(-1, 1)
        # Concatenate image features with additional features
        additional_features = torch.cat((age, anatom_site, sex), dim=1)
        combined_features = torch.cat((image_features, additional_features), dim=1)
        
        # Fully connected layers for classification
        combined_features = F.relu(self.fc1(combined_features))
        combined_features = self.dropout(combined_features)
        output = self.fc2(combined_features)  # Raw logit; sigmoid is applied at inference
        
        return output
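
# A minimal shape sanity check (illustrative sketch; not executed by the app):
#   m = MultiModalClassifier(num_classes=1, num_features=1280)
#   logits = m(torch.randn(2, 3, 224, 224), torch.randn(2), torch.randn(2), torch.randn(2))
#   assert logits.shape == (2, 1)  # 1280 image dims + 3 metadata dims -> fc1 -> one logit per sample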

# Initialize the model
num_classes = 1  # Single output logit for binary (benign vs. malignant) classification
num_features = 1280  # Feature dimension produced by the EfficientNet-B0 backbone
model = MultiModalClassifier(num_classes, num_features)

# Load the trained weights (mapped to CPU so the app runs without a GPU)
model.load_state_dict(torch.load('best_epoch_weights.pth', map_location=torch.device('cpu')))

# Evaluation mode: disables dropout so predictions are deterministic
model.eval()

# Load the age scaler
age_scaler = load(r'age_approx_scaler.joblib')
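# The joblib file is assumed to hold a fitted sklearn scaler (e.g. a StandardScaler)
# whose transform() expects a 2-D array, e.g. age_scaler.transform([[35]]).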

# Preprocessing for inference; there is no Normalize step, so the model is
# assumed to have been trained on raw [0, 1] tensors (adjust as necessary)
test_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

diagnosis_map = {0: 'benign', 1: 'malignant'}

# Encode sex the same way it was encoded during training.
sexes_mapping = {'male': 0, 'female': 1}

# Define mapping dictionary for anatom_site_general.
anatom_site_mapping = {
    'torso': 0,
    'lower extremity': 1,
    'head/neck': 2,
    'upper extremity': 3,
    'palms/soles': 4,
    'oral/genital': 5,
}

def predict(image, age, gender, anatom_site):
    # Preprocess the image: PIL conversion, resize/to-tensor, add batch dimension
    image = Image.fromarray(image)
    image = test_transform(image)
    image = image.float()
    image = image.unsqueeze(0)

    # Encode the categorical inputs
    sex = torch.tensor([[sexes_mapping[gender.lower()]]], dtype=torch.float32)
    anatom_site = torch.tensor([[anatom_site_mapping[anatom_site]]], dtype=torch.float32)

    # Scale the age with the fitted scaler and convert it to a tensor
    scaled_age = age_scaler.transform([[age]])
    age_tensor = torch.tensor(scaled_age, dtype=torch.float32)

    # Forward pass; no gradients are needed at inference time
    with torch.no_grad():
        output = model(image, age_tensor, anatom_site, sex)

    # Sigmoid maps the single logit to a probability; threshold at 0.5
    probability = torch.sigmoid(output)
    predicted_class = int(probability.item() > 0.5)

    return f"The lesion is predicted to be {diagnosis_map[predicted_class]}."
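
# Illustrative call (filename taken from the examples below; the returned text
# depends on the trained weights):
#   img = np.array(Image.open('ISIC_0000060_downsampled.jpg').convert('RGB'))
#   print(predict(img, 35, 'Female', 'torso'))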


description_html = """
Fill in the required parameters and click 'Submit'.
"""

example_data = [
    ["ISIC_0000060_downsampled.jpg", 35, "Female", "torso"],
    ["ISIC_0068279.jpg", 45.0, "Female", "head/neck"]
]

inputs = [
    "image",
    gr.Number(label="Age", minimum=0, maximum=120),
    gr.Dropdown(['Male', 'Female'], label="Gender"),
    gr.Dropdown(['torso', 'lower extremity', 'head/neck', 'upper extremity', 'palms/soles', 'oral/genital'], label="Anatomical Site")
]

gr.Interface(
    fn=predict,
    inputs=inputs,
    outputs=gr.Textbox(label="Output", type="text"),
    title="Skin Cancer Diagnosis",
    description=description_html,
    allow_flagging='never',
    examples=example_data
).launch()