# -*- coding: utf-8 -*-
"""app.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/1sjyLFLqBccpUzaUi4eyyP3NYE3gDtHfs

Gradio dashboard exposing four image classifiers (skin disease, brain tumor,
Alzheimer's, eye disease), one per tab. All models are loaded once at import
time from files expected in the working directory.
"""

from typing import Optional

import gradio as gr
from fastai.vision.all import load_learner
from PIL import Image
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms

# Model paths for all disease types
model_path_skin_disease = 'multi_weight.pth'  # Skin Disease Model
model_path_brain_tumor = 'brain_tumor_model.pkl'
model_path_alzheimers = 'alzheimers_model.pkl'
model_path_eye_disease = 'eye_disease_model.pkl'

# Load models.
# The skin model is mapped to CPU explicitly: the input tensor built in
# predict_skin_disease always lives on the CPU, so a checkpoint restored onto
# a GPU (or loaded on a GPU-less host) would otherwise fail.
# NOTE(review): recent PyTorch releases default torch.load to weights_only=True,
# which rejects fully pickled modules - confirm against the deployed version.
skin_disease_model = torch.load(model_path_skin_disease, map_location='cpu')
# NOTE(review): the file name suggests this may be a state_dict rather than a
# pickled module; if so it must be loaded into the matching architecture
# before it can be called - confirm how the checkpoint was saved.
if isinstance(skin_disease_model, nn.Module):
    # torch.no_grad() does not disable dropout / batch-norm updates.
    skin_disease_model.eval()

brain_tumor_model = load_learner(model_path_brain_tumor)
alzheimers_model = load_learner(model_path_alzheimers)
eye_disease_model = load_learner(model_path_eye_disease)

# Diagnosis Map for Skin Disease Model (class index -> display label)
DIAGNOSIS_MAP = {
    0: 'Melanoma',
    1: 'Melanocytic nevus',
    2: 'Basal cell carcinoma',
    3: 'Actinic keratosis',
    4: 'Benign keratosis',
    5: 'Dermatofibroma',
    6: 'Vascular lesion',
    7: 'Squamous cell carcinoma',
    8: 'Unknown',
}

# Image Preprocessing for Skin Disease Model
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

# Shown when a prediction callback fires without an image (e.g. the user
# cleared the upload widget, which triggers the `change` event with None).
_NO_IMAGE_MESSAGE = "No image provided."

# Number of highest-probability classes reported by the skin disease model.
_SKIN_TOP_K = 3


# Skin Disease Prediction Function
def predict_skin_disease(img: Optional[Image.Image]) -> str:
    """Return the top skin-disease predictions for *img*, one per line.

    Args:
        img: PIL image to classify, or None when no image is available.

    Returns:
        Lines of the form ``"<label>: <confidence>%"`` for the most probable
        classes (at most three), or a short notice when *img* is None.
    """
    if img is None:
        return _NO_IMAGE_MESSAGE

    # Normalize() above is defined for three channels, so force RGB in case a
    # caller passes a grayscale or RGBA image.
    img_tensor = transform(img.convert("RGB")).unsqueeze(0)

    with torch.no_grad():
        outputs = skin_disease_model(img_tensor)
        probs = F.softmax(outputs, dim=1)
        # Never request more classes than the model actually outputs.
        k = min(_SKIN_TOP_K, probs.shape[1])
        top_probs, top_idxs = torch.topk(probs, k, dim=1)

    return "\n".join(
        f"{DIAGNOSIS_MAP.get(idx.item(), 'Unknown')}: {prob.item() * 100:.2f}%"
        for prob, idx in zip(top_probs[0], top_idxs[0])
    )


def _predict_with_learner(learner, image) -> str:
    """Run a fastai learner on *image* and format its top prediction."""
    if image is None:
        return _NO_IMAGE_MESSAGE
    pred, _, prob = learner.predict(image)
    return f"Prediction: {pred}, Probability: {prob.max():.2f}"


# Brain Tumor Prediction Function
def predict_brain_tumor(image) -> str:
    """Classify a brain scan with the brain tumor model."""
    return _predict_with_learner(brain_tumor_model, image)


# Alzheimer's Prediction Function
def predict_alzheimers(image) -> str:
    """Classify a brain image with the Alzheimer's model."""
    return _predict_with_learner(alzheimers_model, image)


# Eye Disease Prediction Function
def predict_eye_disease(image) -> str:
    """Classify an eye image with the eye disease model."""
    return _predict_with_learner(eye_disease_model, image)


def _build_tab(title: str, description: str, image_label: str, predict_fn) -> None:
    """Add one tab with an image upload wired to *predict_fn*.

    Must be called inside an active ``gr.Blocks`` context.
    """
    with gr.Tab(title):
        with gr.Column():
            gr.Markdown(description)
            image_input = gr.Image(type="pil", label=image_label)
            output = gr.Textbox(label="Prediction Results")
            # Re-run the prediction whenever the uploaded image changes.
            image_input.change(predict_fn, inputs=image_input, outputs=output)


# Gradio Interface Function
def main() -> None:
    """Build the tabbed dashboard and launch the Gradio server."""
    # The legacy `gr.inputs.Image` / `gr.inputs.Dropdown` components that used
    # to be created here were never attached to the UI, and the `gr.inputs`
    # namespace no longer exists in current Gradio releases, so they are gone.
    tabs = (
        (
            "Skin Disease Prediction",
            "Upload a skin lesion image for diagnosis prediction.",
            "Upload Skin Lesion Image",
            predict_skin_disease,
        ),
        (
            "Brain Tumor Prediction",
            "Upload a brain scan image for tumor classification.",
            "Upload Brain Scan Image",
            predict_brain_tumor,
        ),
        (
            "Alzheimer's Prediction",
            "Upload a brain image for Alzheimer's detection.",
            "Upload Alzheimer's Image",
            predict_alzheimers,
        ),
        (
            "Eye Disease Prediction",
            "Upload an image for eye disease classification.",
            "Upload Eye Disease Image",
            predict_eye_disease,
        ),
    )

    # Gradio tabs for each category
    with gr.Blocks() as demo:
        gr.Markdown("# Medical Image Classifier Dashboard")
        for title, description, image_label, predict_fn in tabs:
            _build_tab(title, description, image_label, predict_fn)

    demo.launch()


# Run the Gradio app
if __name__ == "__main__":
    main()