# Source: Hugging Face upload by jayakrishna01, commit 1a5d173 (verified),
# commit message: "Upload folder using huggingface_hub".
import gradio as gr
import cv2
import numpy as np
import torch
from model import DropoutNet
# Restore the trained network. The checkpoint is always deserialised onto the
# CPU (map_location) so it opens on GPU-less hosts; the model is moved to CUDA
# afterwards only when a GPU is actually present.
is_cuda = torch.cuda.is_available()
model = DropoutNet()
model.load_state_dict(
    torch.load('final_model.pth', map_location=torch.device('cpu'))
)
model.eval()  # inference mode (disables dropout)
print("Running on the GPU" if is_cuda else "Running on the CPU")
if is_cuda:
    model = model.to('cuda')
def predict(image):
    """Classify a retinal image into one of four classes.

    Args:
        image: Array supplied by the Gradio ``"image"`` input, or ``None``
            when the user presses Submit without uploading anything. Gradio
            delivers colour images in RGB channel order; 2-D (grayscale) and
            4-channel (RGBA) arrays are accepted as well.

    Returns:
        The predicted class name: one of 'choroidal neovascularization',
        'diabetic macular edema', 'drusen' or 'normal'.

    Raises:
        gr.Error: if no image was provided (shown to the user by Gradio).
    """
    if image is None:
        raise gr.Error("Please upload a retinal image.")

    # The network expects a single-channel 28x28 input (see reshape below).
    image = cv2.resize(image, (28, 28))
    if image.ndim == 3:
        # Gradio arrays are RGB(A), not OpenCV's default BGR, so use the RGB
        # conversion codes to apply the luminance weights to the right channels.
        if image.shape[2] == 4:
            image = cv2.cvtColor(image, cv2.COLOR_RGBA2GRAY)
        else:
            image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)

    # NOTE(review): pixels are fed as raw 0-255 floats, as before; confirm this
    # matches the preprocessing used when the model was trained.
    tensor = torch.from_numpy(image).float().reshape(1, 1, 28, 28)
    # The model is moved to CUDA at start-up when available; the input must
    # live on the same device or the forward pass fails.
    tensor = tensor.to('cuda' if is_cuda else 'cpu')

    with torch.no_grad():  # inference only: skip building the autograd graph
        output = model(tensor)
    _, out_index = torch.max(output, 1)

    labels = (
        'choroidal neovascularization',
        'diabetic macular edema',
        'drusen',
        'normal',
    )
    return labels[int(out_index.item())]
description_html = """
<p>This model predicts the disease based on the retinal image.</p>
"""
article_html = """
<h3>How does it work?</h3>
<p>The model is a Convolutional Neural Network (CNN) trained on the retinal images to predict the disease.</p>
<p>Dataset used for training is MEDMNIST dataset which contains retinal images of 4 different diseases.</p>
<p>It uses PyTorch framework for training and prediction.</p>
<h3>How to use?</h3>
<p>Upload an image of the retina and click on 'Submit' to get the prediction.</p>
<p>It will show the predicted disease based on the input image.</p>
<p>It can predict one of the following diseases:</p>
<ul>
<li>Choroidal Neovascularization</li>
<li>Diabetic Macular Edema</li>
<li>Drusen</li>
<li>Normal</li>
</ul>
<h3>How accurate is it?</h3>
<p>The model has an accuracy of 75 on the test dataset.</p>
"""
gr.Interface(fn=predict, inputs="image", outputs="label", title="Retinal Disease Prediction", description="This model predicts the disease based on the retinal image.", article=article_html).launch()