|
import gradio as gr |
|
from PIL import Image |
|
from joblib import load |
|
import numpy as np |
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from torchvision.models import efficientnet_b0 |
|
import torchvision.transforms as transforms |
|
|
|
|
|
class MultiModalClassifier(nn.Module):
    """Classifier fusing EfficientNet-B0 image features with three scalar metadata inputs.

    The image is encoded by an EfficientNet-B0 backbone. The pooled feature
    vector is concatenated with age, anatomical-site and sex values (one scalar
    each). The result goes through ``fc1`` -> ReLU -> dropout -> ``fc2``.

    Args:
        num_classes: Number of output logits produced by ``fc2``.
        num_features: Width of the pooled image feature vector. It must match
            the backbone's output width, because ``fc1`` expects
            ``num_features + 3`` inputs.
        pretrained: Forwarded to ``efficientnet_b0``. Defaults to ``True``
            (the original behaviour). Pass ``False`` when a complete state
            dict is loaded right afterwards; this skips downloading backbone
            weights that would be overwritten anyway.
    """

    def __init__(self, num_classes, num_features, pretrained=True):
        super(MultiModalClassifier, self).__init__()

        efficientnet = efficientnet_b0(pretrained=pretrained)

        # Keep every top-level child except the last one (presumably the
        # torchvision classification head). The attribute name and the child
        # ordering must stay unchanged so saved state dicts keep loading.
        self.efficientnet_features = nn.Sequential(*list(efficientnet.children())[:-1])

        # Each metadata input contributes exactly one feature column.
        self.age_dim = 1
        self.anatom_site_dim = 1
        self.sex_dim = 1

        self.fc1 = nn.Linear(num_features + self.age_dim + self.anatom_site_dim + self.sex_dim, 256)
        self.fc2 = nn.Linear(256, num_classes)

        self.dropout = nn.Dropout(p=0.5)

    def forward(self, image, age, anatom_site, sex):
        """Return logits of shape ``(batch, num_classes)``.

        Args:
            image: Image batch for the backbone; assumed ``(batch, 3, H, W)``
                (TODO confirm against the training pipeline).
            age: One scalar per sample; reshaped to ``(batch, 1)``.
            anatom_site: One numeric site code per sample; reshaped to ``(batch, 1)``.
            sex: One numeric sex code per sample; reshaped to ``(batch, 1)``.
        """
        image_features = self.efficientnet_features(image)
        # Global average pool over the remaining spatial dims, then flatten to
        # (batch, channels). If the retained children already end in a pooling
        # layer, this is effectively a no-op.
        image_features = F.avg_pool2d(image_features, image_features.size()[2:]).view(image.size(0), -1)

        # Metadata may arrive as (batch,) or (batch, 1); normalise to columns.
        age = age.view(-1, 1)
        anatom_site = anatom_site.view(-1, 1)
        sex = sex.view(-1, 1)

        additional_features = torch.cat((age, anatom_site, sex), dim=1)
        combined_features = torch.cat((image_features, additional_features), dim=1)

        combined_features = F.relu(self.fc1(combined_features))
        combined_features = self.dropout(combined_features)
        output = self.fc2(combined_features)

        return output
|
|
|
|
|
# Model hyper-parameters; these must match the checkpoint being loaded.
num_classes = 1  # a single output logit
num_features = 1280  # must equal the width of the pooled backbone features fed to fc1

model = MultiModalClassifier(num_classes, num_features)

# Restore the trained weights, mapping every tensor onto the CPU so the app
# runs on machines without a GPU.
# NOTE(review): torch.load unpickles the file; only ship trusted checkpoints.
model.load_state_dict(
    torch.load("best_epoch_weights.pth", map_location="cpu")
)

# Switch to inference behaviour (disables the dropout layer).
model.eval()

# Fitted scaler applied to the raw age value before it is fed to the model.
age_scaler = load("age_approx_scaler.joblib")
|
|
|
|
|
# Inference-time preprocessing applied to each uploaded PIL image.
# NOTE(review): no mean/std normalisation is applied here; confirm this matches
# the transform used when the checkpoint was trained.
test_transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),  # fixed spatial size, aspect ratio not preserved
        transforms.ToTensor(),  # PIL image -> float tensor (C, H, W)
    ]
)
|
|
|
# Predicted class index -> human-readable diagnosis label.
diagnosis_map = dict(enumerate(("benign", "malignant")))

# Lower-cased sex label -> numeric code fed to the model.
sexes_mapping = dict(male=0, female=1)

# Anatomical site label -> numeric code fed to the model. The position of each
# name in this tuple IS its code, so never reorder the entries.
anatom_site_mapping = {
    site_name: site_code
    for site_code, site_name in enumerate(
        (
            "torso",
            "lower extremity",
            "head/neck",
            "upper extremity",
            "palms/soles",
            "oral/genital",
        )
    )
}
|
|
|
def predict(image, age, gender, anatom_site):
    """Classify one lesion image together with its patient metadata.

    Args:
        image: Image array from the Gradio ``"image"`` input; it is passed to
            ``PIL.Image.fromarray``.
        age: Patient age; transformed with ``age_scaler`` before inference.
        gender: ``'Male'`` or ``'Female'`` (matched case-insensitively against
            ``sexes_mapping``).
        anatom_site: One of the keys of ``anatom_site_mapping``.

    Returns:
        A sentence naming the predicted diagnosis.

    Raises:
        gr.Error: If an input is missing or not a recognised option. Gradio
            shows this message in the UI instead of a raw traceback.
    """
    # Validate up front: empty Gradio fields arrive as None and would otherwise
    # fail deep inside PIL / dict lookups with opaque errors.
    if image is None:
        raise gr.Error("Please upload an image.")
    if age is None:
        raise gr.Error("Please enter the patient's age.")
    if gender is None or gender.lower() not in sexes_mapping:
        raise gr.Error("Please select a gender.")
    if anatom_site not in anatom_site_mapping:
        raise gr.Error("Please select an anatomical site.")

    # Force three channels: grayscale or RGBA uploads would otherwise produce a
    # tensor with the wrong channel count for the backbone.
    pil_image = Image.fromarray(image).convert("RGB")
    # ToTensor already yields float32; unsqueeze adds the batch dimension.
    image_tensor = test_transform(pil_image).unsqueeze(0)

    sex = torch.tensor([[sexes_mapping[gender.lower()]]], dtype=torch.float32)
    site = torch.tensor([[anatom_site_mapping[anatom_site]]], dtype=torch.float32)

    scaled_age = age_scaler.transform([[age]])
    age_tensor = torch.tensor(np.asarray(scaled_age), dtype=torch.float32)

    # Pure inference: skip building an autograd graph on every request.
    with torch.no_grad():
        output = model(image_tensor, age_tensor, site, sex)

    probability = torch.sigmoid(output)
    # Single logit -> binary decision at a 0.5 probability threshold.
    predicted_class = int((probability > 0.5).item())

    return f"The predicted_class is a {diagnosis_map[predicted_class]}."
|
|
|
|
|
# Usage hint shown under the interface title. Gradio's Interface labels its
# run button "Submit", so the hint must name that button rather than 'classify'.
description_html = """
Fill in the required parameters and click 'Submit'.
"""
|
|
|
# Clickable example rows for the Gradio interface. Each row follows the input
# order: [image path, age, gender, anatomical site].
example_data = [
    [
        "ISIC_0000060_downsampled.jpg",
        35,
        "Female",
        "torso",
    ],
    [
        "ISIC_0068279.jpg",
        45.0,
        "Female",
        "head/neck",
    ],
]
|
|
|
# Input components, listed in the same order as predict()'s parameters.
inputs = [
    "image",  # Gradio string shortcut for a default image component
    gr.Number(label="Age", minimum=0, maximum=120),
    gr.Dropdown(['Male', 'Female'], label="Gender"),
    # Take the choices straight from anatom_site_mapping (same labels, same
    # order) so the dropdown can never offer a site the lookup cannot resolve.
    gr.Dropdown(list(anatom_site_mapping), label="Anatomical Site"),
]
|
|
|
# Assemble the web UI around predict() and start serving it.
demo = gr.Interface(
    fn=predict,
    inputs=inputs,
    outputs=gr.Textbox(label="Output", type="text"),
    title="Skin Cancer Diagnosis",
    description=description_html,
    allow_flagging='never',
    examples=example_data,
)
demo.launch()