import gradio as gr
from PIL import Image
from joblib import load
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import efficientnet_b0
import torchvision.transforms as transforms


class MultiModalClassifier(nn.Module):
    def __init__(self, num_classes, num_features):
        super(MultiModalClassifier, self).__init__()
        # Load a pre-trained EfficientNet-B0 backbone (these ImageNet weights are
        # overwritten by load_state_dict below, so this mainly defines the architecture)
        efficientnet = efficientnet_b0(pretrained=True)
        # Remove the final classifier layer, keeping only the feature extractor
        self.efficientnet_features = nn.Sequential(*list(efficientnet.children())[:-1])

        # Dimensions of the tabular features (each is a single scalar per sample)
        self.age_dim = 1
        self.anatom_site_dim = 1
        self.sex_dim = 1

        # Fully connected layers for classification
        self.fc1 = nn.Linear(num_features + self.age_dim + self.anatom_site_dim + self.sex_dim, 256)
        self.fc2 = nn.Linear(256, num_classes)

        # Dropout for regularization
        self.dropout = nn.Dropout(p=0.5)

    def forward(self, image, age, anatom_site, sex):
        # Extract image features with the EfficientNet backbone
        image_features = self.efficientnet_features(image)
        # Average-pool the remaining spatial dimensions and flatten to [batch_size, num_features]
        image_features = F.avg_pool2d(image_features, image_features.size()[2:]).view(image.size(0), -1)

        # Reshape age, anatom_site, and sex to [batch_size, 1]
        age = age.view(-1, 1)
        anatom_site = anatom_site.view(-1, 1)
        sex = sex.view(-1, 1)

        # Concatenate image features with the tabular features
        additional_features = torch.cat((age, anatom_site, sex), dim=1)
        combined_features = torch.cat((image_features, additional_features), dim=1)

        # Classification head
        combined_features = F.relu(self.fc1(combined_features))
        combined_features = self.dropout(combined_features)
        output = self.fc2(combined_features)
        return output


# Initialize the model
num_classes = 1  # Single logit for binary (benign/malignant) classification
num_features = 1280  # Feature dimension of EfficientNet-B0
model = MultiModalClassifier(num_classes, num_features)

# Load the saved model state dictionary
model.load_state_dict(torch.load('best_epoch_weights.pth', map_location=torch.device('cpu')))

# Set the model to evaluation mode
model.eval()

# Load the fitted age scaler
age_scaler = load('age_approx_scaler.joblib')

# Transforms applied to incoming images (adjust as necessary)
test_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

diagnosis_map = {0: 'benign', 1: 'malignant'}

# Mapping dictionary for sex
sexes_mapping = {'male': 0, 'female': 1}
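
# A minimal smoke test (an optional sketch, not part of the original app): run one
# forward pass on dummy tensors to confirm the model wiring before serving requests.
# The 1x3x224x224 input shape is an assumption taken from test_transform above.
with torch.no_grad():
    _dummy_image = torch.randn(1, 3, 224, 224)
    _dummy_scalar = torch.zeros(1, 1)
    _logits = model(_dummy_image, _dummy_scalar, _dummy_scalar, _dummy_scalar)
    assert _logits.shape == (1, 1), f"unexpected output shape: {_logits.shape}"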
# Mapping dictionary for anatom_site_general
anatom_site_mapping = {
    'torso': 0,
    'lower extremity': 1,
    'head/neck': 2,
    'upper extremity': 3,
    'palms/soles': 4,
    'oral/genital': 5,
}


def predict(image, age, gender, anatom_site):
    # Gradio delivers the image as a numpy array; convert it to a PIL image
    image = Image.fromarray(image)
    # Apply the test-time transforms and add a batch dimension
    image = test_transform(image)
    image = image.float()
    image = image.unsqueeze(0)

    # Encode the categorical inputs as [1, 1] float tensors
    sex = torch.tensor([[sexes_mapping[gender.lower()]]], dtype=torch.float32)
    anatom_site = torch.tensor([[anatom_site_mapping[anatom_site]]], dtype=torch.float32)

    # Scale the age with the fitted scaler and convert it to a tensor
    scaled_age = age_scaler.transform([[age]])
    age_tensor = torch.tensor(scaled_age, dtype=torch.float32)

    # Forward pass (no gradient tracking needed at inference time)
    with torch.no_grad():
        output = model(image, age_tensor, anatom_site, sex)

    # Apply a sigmoid to the single logit (binary classification)
    output_sigmoid = torch.sigmoid(output)

    # Threshold at 0.5 to get the predicted class (0 or 1)
    predicted_class = (output_sigmoid > 0.5).float()
    return f"The predicted class is {diagnosis_map[int(predicted_class)]}."


description_html = """
Fill in the required parameters and click 'classify'.
"""

example_data = [
    ["ISIC_0000060_downsampled.jpg", 35, "Female", "torso"],
    ["ISIC_0068279.jpg", 45.0, "Female", "head/neck"],
]

inputs = [
    "image",
    gr.Number(label="Age", minimum=0, maximum=120),
    gr.Dropdown(['Male', 'Female'], label="Gender"),
    gr.Dropdown(['torso', 'lower extremity', 'head/neck', 'upper extremity',
                 'palms/soles', 'oral/genital'], label="Anatomical Site"),
]

gr.Interface(
    predict,
    inputs,
    outputs=gr.Textbox(label="Output", type="text"),
    title="Skin Cancer Diagnosis",
    description=description_html,
    allow_flagging='never',
    examples=example_data,
).launch()
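
# Example of calling predict() directly, outside the Gradio UI (a hedged sketch:
# it assumes one of the local example images listed in example_data is available).
# Kept as comments because launch() above blocks until the server is shut down:
#
#     img = np.array(Image.open('ISIC_0000060_downsampled.jpg').convert('RGB'))
#     print(predict(img, 35, 'Female', 'torso'))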