# MEDIGPT / app.py
# Author: FawadHaider2
# Last commit: "Update app.py" (75c3719, verified)
# -*- coding: utf-8 -*-
"""app.ipynb
Automatically generated by Colab.
Original file is located at
https://colab.research.google.com/drive/1sjyLFLqBccpUzaUi4eyyP3NYE3gDtHfs
"""
import streamlit as st
from streamlit_option_menu import option_menu
import tensorflow as tf
from tensorflow import keras
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
import os
import time
from efficientnet_pytorch import EfficientNet
from fastai.vision.all import load_learner
# Force CPU execution: an invalid device index such as "-1" hides every CUDA
# device from TensorFlow and PyTorch.
# NOTE(review): this runs after `tensorflow` and `torch` are imported above; it
# only takes effect if neither framework has enumerated GPUs yet — consider
# moving it before those imports.
os.environ.update({"CUDA_VISIBLE_DEVICES": "-1"})
# Cached with st.cache_resource so Streamlit reruns reuse the loaded model
# instead of rebuilding it.
@st.cache_resource
def load_skin_model():
    """Build the skin-lesion classifier and load its trained weights.

    Returns:
        MelanomaModel: a 9-output model in evaluation mode whose weights come
        from the ``model_state_dict`` entry of ``multi_weight.pth``, mapped
        onto the CPU.
    """
    net = MelanomaModel(out_size=9)
    checkpoint_path = "multi_weight.pth"
    # weights_only=False makes torch.load unpickle arbitrary objects; this is
    # only acceptable because the checkpoint is a local, trusted file.
    saved = torch.load(
        checkpoint_path,
        map_location=torch.device('cpu'),
        weights_only=False,
    )
    net.load_state_dict(saved["model_state_dict"])
    # nn.Module.eval() returns the module itself.
    return net.eval()
# Preprocessing for skin-lesion images: resize the shorter side to 256 px,
# centre-crop to 224x224, convert to a float tensor, then normalise each RGB
# channel with the commonly used ImageNet mean/std values.
transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
# Diagnosis map for the skin disease model: output index -> diagnosis label.
# The list position IS the model's output index, so do not reorder entries.
DIAGNOSIS_MAP = dict(enumerate([
    'Melanoma',
    'Melanocytic nevus',
    'Basal cell carcinoma',
    'Actinic keratosis',
    'Benign keratosis',
    'Dermatofibroma',
    'Vascular lesion',
    'Squamous cell carcinoma',
    'Unknown',
]))
# Model for skin lesion classification.
class MelanomaModel(nn.Module):
    """EfficientNet-B0 backbone followed by a three-layer fully connected head.

    Args:
        out_size: number of output logits.
        dropout_prob: dropout probability applied after each hidden layer.

    NOTE: the attribute names ``efficient_net``, ``fc1``, ``fc2`` and ``fc3``
    become the keys of the saved state dict loaded by ``load_state_dict``, so
    they must not be renamed.
    """

    def __init__(self, out_size, dropout_prob=0.5):
        super().__init__()
        # Pretrained EfficientNet-B0 (efficientnet_pytorch may download the
        # weights on first use); its final classifier is replaced by an
        # identity so the backbone returns features rather than logits.
        self.efficient_net = EfficientNet.from_pretrained('efficientnet-b0')
        self.efficient_net._fc = nn.Identity()
        # 1280 is the feature width the head expects from the B0 backbone.
        self.fc1 = nn.Linear(1280, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, out_size)
        # A single Dropout module is reused after both hidden layers.
        self.dropout = nn.Dropout(dropout_prob)

    def forward(self, x):
        """Return raw (pre-softmax) logits of shape ``(batch, out_size)``.

        ``x`` is presumably a ``(batch, 3, H, W)`` image tensor — see the
        ``transform`` pipeline used by the prediction helper.
        """
        features = self.efficient_net(x)
        # Flatten to (batch, features) whatever shape the backbone returns.
        hidden = features.view(features.size(0), -1)
        for layer in (self.fc1, self.fc2):
            hidden = self.dropout(F.relu(layer(hidden)))
        return self.fc3(hidden)
# Alzheimer's prediction model (Keras).
@st.cache_resource
def load_alzheimer_model():
    """Load the Keras Alzheimer's-stage classifier from ``alzheimer_99.5.h5``.

    Cached with ``st.cache_resource`` so Streamlit reruns reuse the model.
    """
    model_file = 'alzheimer_99.5.h5'
    return keras.models.load_model(model_file)
# Brain tumour prediction models, one file per supported class granularity.
@st.cache_resource
def load_brain_tumor_model(classes):
    """Load the brain-tumour classifier that matches ``classes``.

    Args:
        classes: ``'44 Classes'``, ``'17 Classes'`` or ``'15 Classes'``; any
            other value falls back to the 2-class model.

    Returns:
        The Keras model loaded from the corresponding ``.h5`` file.
    """
    choices = (
        ('44 Classes', '44class_96.5.h5'),
        ('17 Classes', '17class_98.1.h5'),
        ('15 Classes', '15class_99.8.h5'),
    )
    for label, model_file in choices:
        if classes == label:
            return keras.models.load_model(model_file)
    # Default: 2-class model. The file name's spelling must match the file on
    # disk exactly, so it is kept as-is.
    return keras.models.load_model('2calss_lagre_dataset_99.1.h5')
# Prediction for skin disease.
def predict_skin_lesion(img: Image.Image, model: nn.Module, top_k: int = 3):
    """Return the ``top_k`` most likely skin-lesion diagnoses for ``img``.

    Args:
        img: PIL image of the lesion. Any PIL mode is accepted; the image is
            converted to RGB before preprocessing.
        model: classifier expected to return ``(batch, n_classes)`` logits
            whose indices are keys of ``DIAGNOSIS_MAP``.
        top_k: number of ranked predictions to return (default 3). It is
            capped at the number of model outputs.

    Returns:
        List of ``(label, probability_percent)`` tuples, most likely first.
        Indices missing from ``DIAGNOSIS_MAP`` are reported as "Unknown".
    """
    # PNG uploads are often RGBA or palette images and some JPEGs are
    # greyscale. ``transform`` normalises exactly three channels, so any other
    # channel count would fail in Normalize or in the model's first layer.
    img_tensor = transform(img.convert("RGB")).unsqueeze(0)  # add batch dim
    with torch.no_grad():
        probs = F.softmax(model(img_tensor), dim=1)
    # torch.topk raises if k exceeds the number of classes.
    k = min(top_k, probs.size(1))
    top_probs, top_idxs = torch.topk(probs, k, dim=1)
    return [
        (DIAGNOSIS_MAP.get(idx.item(), "Unknown"), prob.item() * 100)
        for prob, idx in zip(top_probs[0], top_idxs[0])
    ]
# Prediction for the brain tumour and Alzheimer's models (Keras).
def predict(img_path, model, result_classes):
    """Classify one image with a Keras model and return the winning label.

    Args:
        img_path: anything ``tf.keras.utils.load_img`` can open; the app
            passes the Streamlit upload object directly.
        model: Keras model whose ``predict`` output is indexed as
            ``[batch, class]``.
        result_classes: label for each model output position.

    Returns:
        ``result_classes[i]`` for the highest-scoring output ``i`` of the
        single image.
    """
    # Resized to 224x224; pixel values are passed on without rescaling.
    pil_img = tf.keras.utils.load_img(img_path, target_size=(224, 224))
    batch = np.array(pil_img).reshape(-1, 224, 224, 3)
    scores = model.predict(batch)
    winner = np.argmax(scores, axis=1)[0]
    return result_classes[winner]
# Navigation menu for choosing a disease category (rendered horizontally).
def spr_sidebar():
    """Render the navigation menu and return the selected option's label."""
    pages = ["Brain Tumor", "Alzheimer", "Skin Disease", "Eye Disease", "About"]
    page_icons = ["house", "brain", "microscope", "eye", "info-square"]
    return option_menu(
        menu_title="Navigation",
        options=pages,
        icons=page_icons,
        menu_icon="cast",
        default_index=0,
        orientation="horizontal",
    )
# Home page: upload an image and classify it for the selected category.
def home_page(selected_category):
    """Render the upload-and-classify page for one disease category.

    Args:
        selected_category: navigation label from ``spr_sidebar`` — one of
            "Brain Tumor", "Alzheimer", "Skin Disease" or "Eye Disease".
    """
    st.title("Disease Detection Web App")
    uploaded_file = st.file_uploader("Upload an image", type=["png", "jpg", "jpeg"])
    if uploaded_file is None:
        return
    st.image(uploaded_file, caption="Uploaded Image", use_column_width=True)

    # Options must be rendered BEFORE the button. Streamlit re-runs the script
    # on every widget interaction and st.button() is True only on the run
    # triggered by its own click. A selectbox created inside the button branch
    # therefore could never be changed before classification and would always
    # use its first option.
    classes = None
    if selected_category == "Brain Tumor":
        classes = st.selectbox(
            "Select Number of Classes",
            ['44 Classes', '17 Classes', '15 Classes', '2 Classes'],
        )

    if not st.button("Classify"):
        return

    if selected_category == "Brain Tumor":
        model = load_brain_tumor_model(classes)
        # NOTE(review): only these four label names are known, and they are
        # unlikely to be right for every model; the real per-model label lists
        # still need to be supplied. Remaining output positions get generic
        # names so a prediction never indexes past the end of the list.
        known_labels = ['Astrocitoma', 'Carcinoma', 'Ependimoma', '_NORMAL']
        n_classes = int(classes.split()[0])  # e.g. '44 Classes' -> 44
        result_classes = [
            known_labels[i] if i < len(known_labels) else f"Class {i}"
            for i in range(n_classes)
        ]
        result = predict(uploaded_file, model, result_classes)
        st.success(f"Prediction: {result}")
    elif selected_category == "Alzheimer":
        model = load_alzheimer_model()
        result_classes = ['Mild_Demented', 'Moderate_Demented', 'Non_Demented']
        result = predict(uploaded_file, model, result_classes)
        st.success(f"Prediction: {result}")
    elif selected_category == "Skin Disease":
        model = load_skin_model()
        img = Image.open(uploaded_file)
        predictions = predict_skin_lesion(img, model)
        for rank, (label, confidence) in enumerate(predictions, 1):
            st.write(f"{rank}. {label}: {confidence:.2f}%")
    elif selected_category == "Eye Disease":
        # TODO: implement Eye Disease prediction (similar to the others).
        # Until then, tell the user instead of silently doing nothing.
        st.warning("Eye disease detection is not implemented yet.")
# About page content.
def about_page():
    """Render the static "About the Project" page.

    Writes a header and an introduction, then a subheader and description for
    each disease category, and finally the project credits.
    """
    intro = (
        "\n"
        "This web app detects different diseases using machine learning models.\n"
        "The diseases it covers include:\n"
    )
    sections = (
        (
            "1. Brain Tumor Detection",
            "\n"
            "**Brain Tumor Detection** involves classifying various types of brain tumors using image analysis.\n"
            "The model can identify different classes of brain tumors such as Astrocytomas, Gliomas, and more.\n"
            "It aids doctors in diagnosing brain tumors early for timely intervention.\n",
        ),
        (
            "2. Alzheimer’s Disease Detection",
            "\n"
            "**Alzheimer's Disease Detection** focuses on predicting the stages of Alzheimer's disease from brain scans.\n"
            "It classifies the brain images into stages such as Mild Demented, Moderate Demented, and Non-Demented.\n"
            "Early detection of Alzheimer's can help in planning the appropriate treatment for patients.\n",
        ),
        (
            "3. Skin Disease Classification",
            "\n"
            "**Skin Disease Classification** uses deep learning models to detect and classify different types of skin lesions.\n"
            "This includes conditions like Melanoma, Basal Cell Carcinoma, and various benign lesions.\n"
            "Early detection of skin cancer, like Melanoma, can significantly improve the survival rate of patients.\n",
        ),
        (
            "4. Eye Disease Detection",
            "\n"
            "**Eye Disease Detection** (TBD) focuses on diagnosing eye conditions from images of the eye.\n"
            "This could include diseases such as diabetic retinopathy, cataracts, and glaucoma, which can affect vision and lead to blindness if untreated.\n",
        ),
    )
    credits = (
        "\n"
        "This is a project by **Fawad Haider** and **Sameer Ahmed**.\n"
        "We aim to assist healthcare professionals in diagnosing various diseases early, improving the accuracy of predictions using AI and deep learning models.\n"
    )

    st.header('About the Project')
    st.write(intro)
    for title, description in sections:
        st.subheader(title)
        st.write(description)
    st.write(credits)
# Main function: route the selected navigation entry to its page renderer.
def main():
    """Render the navigation menu, then the page it selects."""
    page = spr_sidebar()
    if page == "About":
        about_page()
    elif page in ("Brain Tumor", "Alzheimer", "Skin Disease", "Eye Disease"):
        # Every disease category shares the same upload-and-classify page.
        home_page(page)


# Run the app when this file is executed as the main script.
if __name__ == "__main__":
    main()