# File size: 2,029 Bytes
# 1a5d173
import gradio as gr
import cv2
import numpy as np
import torch
from model import DropoutNet
# Load the trained network once at import time. The checkpoint is always
# mapped to the CPU first so it can be read on machines without a GPU; the
# model is moved to CUDA afterwards when one is available.
is_cuda = torch.cuda.is_available()

model = DropoutNet()
# NOTE(review): torch.load unpickles the file, so it must only ever point at
# a trusted checkpoint.
model.load_state_dict(
    torch.load("final_model.pth", map_location=torch.device("cpu"))
)
# Evaluation mode: changes the behavior of train-only layers such as dropout.
model.eval()

print("Running on the GPU" if is_cuda else "Running on the CPU")
if is_cuda:
    model = model.to("cuda")
def predict(image):
    """Classify an uploaded retinal image into one of four classes.

    Args:
        image: numpy image array from the Gradio ``"image"`` input
            (RGB channel order). 2-D grayscale and 4-channel RGBA arrays
            are accepted as well.

    Returns:
        The lowercase name of the highest-scoring class, e.g. ``"drusen"``.
    """
    # Class names in the order of the model's output units.
    labels = (
        "choroidal neovascularization",
        "diabetic macular edema",
        "drusen",
        "normal",
    )

    # The network takes a single-channel 28x28 image.
    image = cv2.resize(image, (28, 28))
    if image.ndim == 3:
        # Gradio delivers RGB(A), not OpenCV's BGR; the BGR code would swap
        # the red/blue weights in the grayscale conversion.
        if image.shape[2] == 4:
            code = cv2.COLOR_RGBA2GRAY
        else:
            code = cv2.COLOR_RGB2GRAY
        image = cv2.cvtColor(image, code)

    # Shape (batch=1, channels=1, 28, 28). Pixel values are fed unnormalised,
    # exactly as before. NOTE(review): confirm this matches training
    # preprocessing.
    batch = torch.from_numpy(image).float().reshape(1, 1, 28, 28)

    # Put the input on the model's device: a CPU tensor fed to a model that
    # was moved to CUDA raises a device-mismatch error.
    device = next(model.parameters()).device
    with torch.no_grad():  # inference only: skip autograd bookkeeping
        scores = model(batch.to(device))
    return labels[int(scores.argmax(dim=1).item())]
# Short HTML blurb summarising what the demo does.
description_html = (
    "\n"
    "<p>This model predicts the disease based on the retinal image.</p>\n"
)
# HTML shown below the Gradio interface (its ``article``). It explains how
# the model works, how to use the demo, and the reported test accuracy.
article_html = """
<h3>How does it work?</h3>
<p>The model is a Convolutional Neural Network (CNN) trained on retinal images to predict the disease.</p>
<p>It was trained on the MedMNIST dataset, using retinal images from 4 classes: three diseases and normal.</p>
<p>It uses the PyTorch framework for training and prediction.</p>
<h3>How to use it?</h3>
<p>Upload an image of the retina and click 'Submit' to get the prediction.</p>
<p>It will show the predicted class based on the input image.</p>
<p>It can predict one of the following classes:</p>
<ul>
<li>Choroidal Neovascularization</li>
<li>Diabetic Macular Edema</li>
<li>Drusen</li>
<li>Normal</li>
</ul>
<h3>How accurate is it?</h3>
<p>The model has an accuracy of 75% on the test dataset.</p>
"""
# Build the web UI: the uploaded image reaches `predict` as a numpy array, and
# its returned class name is rendered as a label. The description reuses the
# HTML blurb defined above instead of duplicating its text. The app launches
# as soon as this module runs, as before.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="numpy"),
    outputs=gr.Label(),
    title="Retinal Disease Prediction",
    description=description_html,
    article=article_html,
)
demo.launch()
|