moazx's picture
Update app.py
5860c3f verified
raw
history blame
4.89 kB
import gradio as gr
from PIL import Image
from joblib import load
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import efficientnet_b0
import torchvision.transforms as transforms
class MultiModalClassifier(nn.Module):
    """EfficientNet-B0 image encoder fused with three scalar metadata inputs.

    The pooled image embedding is concatenated with age, anatomical site and
    sex (one scalar column each) and passed through a two-layer MLP head.

    Args:
        num_classes: Width of the output layer (raw scores, no activation).
        num_features: Size of the flattened image embedding produced by the
            EfficientNet trunk (1280 for EfficientNet-B0).
    """

    def __init__(self, num_classes, num_features):
        super().__init__()
        # ImageNet-pretrained backbone with its last child (the classifier)
        # dropped; for torchvision's EfficientNet this presumably leaves the
        # convolutional trunk plus its pooling layer.
        backbone = efficientnet_b0(pretrained=True)
        trunk_layers = list(backbone.children())[:-1]
        self.efficientnet_features = nn.Sequential(*trunk_layers)

        # Each metadata input contributes exactly one scalar column.
        self.age_dim = 1
        self.anatom_site_dim = 1
        self.sex_dim = 1
        metadata_width = self.age_dim + self.anatom_site_dim + self.sex_dim

        # Classification head. Attribute names double as state_dict keys, so
        # they must stay in sync with saved checkpoints.
        self.fc1 = nn.Linear(num_features + metadata_width, 256)
        self.fc2 = nn.Linear(256, num_classes)
        self.dropout = nn.Dropout(p=0.5)

    def forward(self, image, age, anatom_site, sex):
        """Return scores of shape ``[batch, num_classes]``.

        ``image`` goes through the EfficientNet trunk; ``age``,
        ``anatom_site`` and ``sex`` are each reshaped to one column per sample
        before being appended to the image embedding.
        """
        feature_map = self.efficientnet_features(image)
        # Average over all spatial positions, then flatten per sample.
        pooled = F.avg_pool2d(feature_map, feature_map.size()[2:])
        embedding = pooled.view(image.size(0), -1)

        metadata = torch.cat(
            [column.view(-1, 1) for column in (age, anatom_site, sex)], dim=1
        )
        fused = torch.cat((embedding, metadata), dim=1)

        hidden = self.dropout(F.relu(self.fc1(fused)))
        return self.fc2(hidden)
# Build the classifier: a single output score (binary benign/malignant,
# thresholded in predict()) on top of the 1280-d EfficientNet-B0 embedding.
num_classes = 1
num_features = 1280
model = MultiModalClassifier(num_classes, num_features)

# Restore the trained weights onto the CPU, then switch to inference mode so
# dropout is inactive while serving predictions.
model.load_state_dict(
    torch.load(r'best_epoch_weights.pth', map_location=torch.device('cpu'))
)
model.eval()
# Preprocessing artifacts used by every call to predict().

# Age normaliser loaded from disk; presumably the scaler fitted on the
# training ages (TODO confirm).
age_scaler = load(r'age_approx_scaler.joblib')

# Image pipeline: resize to a fixed 224x224 input, then convert to a tensor.
# NOTE(review): there is no Normalize step; confirm this matches training.
test_transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ]
)
# Predicted class index -> human-readable diagnosis.
diagnosis_map = dict(enumerate(('benign', 'malignant')))

# Categorical metadata encodings: each label maps to the integer code that is
# fed to the model as a float feature. The code is the label's position in the
# tuple, so order matters (presumably it must match the training encoding).
sexes_mapping = {sex: code for code, sex in enumerate(('male', 'female'))}

anatom_site_mapping = {
    site: code
    for code, site in enumerate(
        (
            'torso',
            'lower extremity',
            'head/neck',
            'upper extremity',
            'palms/soles',
            'oral/genital',
        )
    )
}
def predict(image, age, gender, anatom_site):
    """Classify a skin-lesion image as benign or malignant.

    Args:
        image: Image array from the Gradio ``"image"`` input; it is passed to
            ``PIL.Image.fromarray``.
        age: Patient age; scaled with ``age_scaler`` before inference.
        gender: Sex label, matched case-insensitively against
            ``sexes_mapping`` (e.g. ``"Male"``/``"Female"``).
        anatom_site: Anatomical site; must be a key of ``anatom_site_mapping``.

    Returns:
        A sentence naming the predicted class from ``diagnosis_map``.

    Raises:
        gr.Error: If any input is missing, so the UI shows a readable message
            instead of an opaque traceback.
    """
    missing = [
        label
        for label, value in (
            ("image", image),
            ("age", age),
            ("gender", gender),
            ("anatomical site", anatom_site),
        )
        if value is None
    ]
    if missing:
        raise gr.Error(f"Please provide: {', '.join(missing)}.")

    # Force 3 channels: the EfficientNet stem takes RGB input, so grayscale or
    # RGBA arrays would otherwise fail inside the first convolution.
    pil_image = Image.fromarray(image).convert("RGB")
    image_tensor = test_transform(pil_image).float().unsqueeze(0)  # add batch dim

    # Encode categorical metadata as [1, 1] float tensors.
    sex = torch.tensor([[sexes_mapping[gender.lower()]]], dtype=torch.float32)
    site = torch.tensor([[anatom_site_mapping[anatom_site]]], dtype=torch.float32)

    # Scale the age with the loaded scaler before handing it to the model.
    scaled_age = age_scaler.transform([[age]])
    age_tensor = torch.tensor(np.asarray(scaled_age), dtype=torch.float32)

    # Inference only: skip autograd graph construction.
    with torch.no_grad():
        logits = model(image_tensor, age_tensor, site, sex)

    # Single output score -> sigmoid probability -> threshold at 0.5.
    probability = torch.sigmoid(logits)
    predicted_class = int(probability.item() > 0.5)
    return f"The predicted_class is a {diagnosis_map[predicted_class]}."
# Instructions shown above the form. gr.Interface's run button is labelled
# "Submit", so the text refers to that label.
description_html = """
Fill in the required parameters and click 'Submit'.
"""

# Example rows in the order of the Interface inputs:
# [image path, age, gender, anatomical site]. The image files are expected to
# sit next to this script.
example_data = [
    ["ISIC_0000060_downsampled.jpg", 35, "Female", "torso"],
    ["ISIC_0068279.jpg", 45.0, "Female", "head/neck"],
]
# Input widgets, in the same order as predict()'s parameters.
inputs = [
    "image",
    gr.Number(label="Age", minimum=0, maximum=120),
    # Dropdown choices are derived from the encoding dicts so the UI can never
    # offer a label that predict() fails to look up. predict() lower-cases the
    # gender, so the capitalised labels map back onto sexes_mapping's keys.
    gr.Dropdown([sex.capitalize() for sex in sexes_mapping], label="Gender"),
    gr.Dropdown(list(anatom_site_mapping), label="Anatomical Site"),
]

demo = gr.Interface(
    predict,
    inputs,
    outputs=gr.Textbox(label="Output", type="text"),
    title="Skin Cancer Diagnosis",
    description=description_html,
    allow_flagging='never',
    examples=example_data,
)
demo.launch()