File size: 2,029 Bytes
1a5d173
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import gradio as gr
import cv2
import numpy as np
import torch
from model import DropoutNet

# Probe once, at import time, whether a CUDA device is usable.
is_cuda = torch.cuda.is_available()

# Build the network and restore the trained weights. The checkpoint is mapped
# onto the CPU so it deserializes even on machines without a GPU; the model is
# moved to the GPU afterwards if one is available.
# NOTE(review): torch.load unpickles the file, so only load trusted checkpoints
# (consider weights_only=True on torch versions that support it).
model = DropoutNet()
model.load_state_dict(
    torch.load('final_model.pth', map_location=torch.device('cpu'))
)
# Switch to inference mode (disables training-only behaviour such as dropout).
model.eval()

if not is_cuda:
    print("Running on the CPU")
else:
    print("Running on the GPU")
    model = model.to('cuda')

# Class names, indexed by the position of the network's output score.
_CLASS_NAMES = (
    'choroidal neovascularization',
    'diabetic macular edema',
    'drusen',
    'normal',
)

# Side length (pixels) of the square single-channel image fed to the network.
_INPUT_SIZE = 28


def _to_grayscale_28(image):
    """Resize *image* to 28x28 and collapse it to one grayscale channel.

    Accepts a 2-D grayscale array or an H x W x C array with 1, 2, 3 (RGB)
    or 4 (RGBA) channels, and returns a 2-D 28x28 array.
    """
    resized = cv2.resize(image, (_INPUT_SIZE, _INPUT_SIZE))
    if resized.ndim == 3 and resized.shape[2] == 4:
        return cv2.cvtColor(resized, cv2.COLOR_RGBA2GRAY)
    if resized.ndim == 3 and resized.shape[2] == 3:
        # Gradio delivers RGB arrays; cv2's BGR2GRAY would swap the red and
        # blue luminance weights.
        return cv2.cvtColor(resized, cv2.COLOR_RGB2GRAY)
    if resized.ndim == 3:
        # Luminance(+alpha) or a single stacked channel: keep the first one.
        return resized[:, :, 0]
    return resized


def predict(image):
    """Classify an uploaded retinal image and return the predicted class name.

    Args:
        image: Numpy array from Gradio's ``"image"`` input (H x W x 3, RGB
            with Gradio's default settings). 2-D grayscale and RGBA arrays
            are accepted as well.

    Returns:
        One of the names in ``_CLASS_NAMES``.

    Raises:
        ValueError: if no image was provided (``image`` is ``None``).
    """
    if image is None:
        raise ValueError("No image was provided; please upload a retinal image.")

    gray = np.ascontiguousarray(_to_grayscale_28(image))
    # (1, 1, 28, 28): one single-channel image in PyTorch's NCHW layout.
    # NOTE(review): pixels are fed as raw 0-255 floats, exactly as before;
    # confirm this matches the preprocessing used during training.
    batch = torch.from_numpy(gray).float().reshape(1, 1, _INPUT_SIZE, _INPUT_SIZE)
    # Follow the model onto whatever device it lives on (CPU or CUDA);
    # a CPU tensor fed to a CUDA model raises a device-mismatch error.
    batch = batch.to(next(model.parameters()).device)

    with torch.no_grad():  # inference only: skip autograd bookkeeping
        output = model(batch)

    class_index = int(output.argmax(dim=1).item())
    return _CLASS_NAMES[class_index]

# Short blurb rendered under the interface title.
description_html = """
    <p>This model predicts the disease based on the retinal image.</p>
"""

# Longer explanation rendered below the interface.
# NOTE(review): the four classes match MedMNIST's OCTMNIST subset (retinal OCT
# scans); confirm and consider naming the subset explicitly.
article_html = """

    <h3>How does it work?</h3>
    <p>The model is a Convolutional Neural Network (CNN) trained on retinal images to classify retinal conditions.</p>
    <p>It was trained on the MedMNIST dataset, using retinal images from 4 classes: three retinal conditions and normal.</p>
    <p>It uses the PyTorch framework for training and prediction.</p>

    <h3>How to use it</h3>
    <p>Upload an image of the retina and click 'Submit' to get the prediction.</p>
    <p>The predicted class for the input image will be shown.</p>
    <p>It can predict one of the following classes:</p>
    <ul>
        <li>Choroidal Neovascularization</li>
        <li>Diabetic Macular Edema</li>
        <li>Drusen</li>
        <li>Normal</li>
    </ul>

    <h3>How accurate is it?</h3>
    <p>The model has an accuracy of 75% on the test dataset.</p>

"""

# Wire the classifier into an upload-and-predict UI and start the server.
# The description comes from description_html instead of duplicating its text.
demo = gr.Interface(
    fn=predict,
    inputs="image",
    outputs="label",
    title="Retinal Disease Prediction",
    description=description_html,
    article=article_html,
)
demo.launch()