|
import gradio as gr |
|
import cv2 |
|
import numpy as np |
|
import torch |
|
from model import DropoutNet |
|
|
|
# Probe once for a CUDA device; the result decides where the model is placed below.
is_cuda = torch.cuda.is_available()

# Build the network and restore its trained weights. The checkpoint is mapped
# onto the CPU first so loading also works on machines without a GPU.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints
# (consider weights_only=True on torch >= 1.13).
model = DropoutNet()
model.load_state_dict(
    torch.load('final_model.pth', map_location=torch.device('cpu'))
)
# Switch layers such as dropout to inference behaviour.
model.eval()

if not is_cuda:
    print("Running on the CPU")
else:
    print("Running on the GPU")
    model = model.to('cuda')
|
|
|
# Human-readable class names, indexed by the position of the model's output
# logit (the same mapping the original string-keyed dict encoded).
_LABELS = (
    'choroidal neovascularization',
    'diabetic macular edema',
    'drusen',
    'normal',
)


def predict(image):
    """Classify a retinal image and return the predicted class name.

    The image is resized to 28x28, reduced to a single grey channel and fed to
    the module-level ``model`` as a ``(1, 1, 28, 28)`` float32 tensor with the
    raw 0-255 pixel values (no extra normalisation, as before).

    Args:
        image: Numpy array supplied by the Gradio ``"image"`` input. Gradio
            delivers colour images as RGB; 2-D (already greyscale) and
            4-channel (RGBA) arrays are accepted as well.

    Returns:
        One of the class names in ``_LABELS``.

    Raises:
        ValueError: If no image was supplied (e.g. Submit with no upload).
    """
    if image is None:
        raise ValueError("No image provided; please upload a retinal image.")

    image = cv2.resize(image, (28, 28))

    # Gradio hands over RGB arrays (not OpenCV's BGR order), so use the RGB
    # weights; also tolerate RGBA and single-channel inputs instead of crashing.
    if image.ndim == 3:
        channels = image.shape[2]
        if channels == 4:
            image = cv2.cvtColor(image, cv2.COLOR_RGBA2GRAY)
        elif channels == 3:
            image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
        else:
            image = image[:, :, 0]

    tensor = torch.from_numpy(np.ascontiguousarray(image, dtype=np.float32))
    tensor = tensor.reshape(1, 1, 28, 28)

    # The model is moved to CUDA when available, so the input must follow it.
    device = next(model.parameters()).device
    with torch.no_grad():
        output = model(tensor.to(device))

    index = int(output.argmax(dim=1).item())
    return _LABELS[index]
|
|
|
# Short HTML blurb describing the app; intended for the Gradio interface's `description`.
description_html = """

<p>This model predicts the disease based on the retinal image.</p>

"""
|
|
|
# Long-form HTML explaining the model, its usage and its accuracy; passed to
# the Gradio interface as its `article`.
article_html = """
<h3>How does it work?</h3>
<p>The model is a Convolutional Neural Network (CNN) trained on retinal images to predict the disease.</p>
<p>It was trained on the MedMNIST dataset, using retinal images from 4 classes: three diseases plus normal retinas.</p>
<p>It uses the PyTorch framework for training and prediction.</p>

<h3>How to use it?</h3>
<p>Upload an image of the retina and click 'Submit' to get the prediction.</p>
<p>The predicted class for the input image will be shown.</p>
<p>It predicts one of the following classes:</p>
<ul>
<li>Choroidal Neovascularization</li>
<li>Diabetic Macular Edema</li>
<li>Drusen</li>
<li>Normal</li>
</ul>

<h3>How accurate is it?</h3>
<p>The model has an accuracy of 75% on the test dataset.</p>
"""
|
|
|
# Web UI around `predict`. The description reuses `description_html` rather than
# repeating the same sentence inline; the interface is bound to `demo` so it
# can be imported (e.g. by `gradio app.py` reload mode) without launching.
demo = gr.Interface(
    fn=predict,
    inputs="image",
    outputs="label",
    title="Retinal Disease Prediction",
    description=description_html,
    article=article_html,
)

if __name__ == "__main__":
    demo.launch()
|
|