File size: 3,592 Bytes
418196b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9c43502
418196b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90f97af
418196b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import numpy as np
import gradio as gr
from base import TrainingObjective, ModelBackbone, model_input_dict, TransformationType
import os 
import json
from utils import get_model_architecture, get_transforms_to_apply_, get_transforms_to_apply
import torch
from PIL import Image
from torchvision import transforms
from matplotlib import pyplot as plt
from dataset import AcneDataset
import config
from torch.utils.data import DataLoader
from tqdm import tqdm


# Resolve all model artifacts relative to this file so the script works
# regardless of the current working directory.
curr_dir = os.path.dirname(os.path.abspath(__file__))
#go up one level
# curr_dir = os.path.dirname(curr_dir)
model_dir = os.path.join(curr_dir, 'model')
model_name = 'model_0'
model_dir = os.path.join(model_dir, model_name)
# model_dir = '/Users/suyashharlalka/Documents/workspace/gabit/acne_classification/model/model_0'
model_config = os.path.join(model_dir, 'config.json')

# The saved config.json describes the architecture used at training time.
with open(model_config, 'r') as f:
    config_json = json.load(f)

# Rebuild the architecture from config, then load the trained weights.
# map_location forces CPU so the script runs on machines without a GPU.
model = get_model_architecture(config_json)
model_path = os.path.join(model_dir, 'model.pth')
model.load_state_dict(torch.load(model_path,map_location=torch.device('cpu')))

# Reconstruct the evaluation-time preprocessing pipeline from the training
# config, dropping RANDOM_HORIZONTAL_FLIP since augmentation is train-only.
testing_transform = config_json['TRANSFORMS_TO_APPLY']
testing_transform = [TransformationType[transform] for transform in testing_transform if transform != 'RANDOM_HORIZONTAL_FLIP']
testing_transform = [get_transforms_to_apply_(transform) for transform in testing_transform]
testing_transform = transforms.Compose(testing_transform)

# data_dir = config.DATASET_PATH
# isLimited = config.IS_LIMITED
# dataset = AcneDataset(data_dir, limit=isLimited)
# dataset.transform = testing_transform
# dataloader = DataLoader(dataset, batch_size=32, shuffle=False, num_workers=0)
device = torch.device('cpu')

def predict_acne(img):
    """Predict the acne severity score for a single image.

    Applies the evaluation-time preprocessing pipeline, runs the model on
    CPU, and rounds the scalar model output to the nearest integer.

    Parameters
    ----------
    img : PIL.Image.Image
        Input image; the transform pipeline handles resizing/normalization.
        (Assumes an RGB image — TODO confirm non-RGB inputs are converted
        by the caller, as the commented-out driver below does.)

    Returns
    -------
    numpy scalar
        Rounded model output for the single input image.
    """
    batch = testing_transform(img)
    batch = batch.unsqueeze(0)  # add the batch dimension expected by the model
    batch = batch.to(device)
    # Inference only: disable autograd so no graph is built (saves memory
    # and avoids the deprecated `.data` access the original relied on).
    with torch.no_grad():
        output = model(batch)
    # The model emits one regression-style score per image; round it to
    # the nearest severity label.
    predicted = torch.round(output).squeeze(1)
    return predicted.cpu().numpy()[0]


# img_dir = '/Users/suyash.harlalka/Desktop/personal/acne_classification/dataset/Classification/JPEGImages'
# img_names = os.listdir(img_dir)
# for name in img_names:
#     if name.endswith('.jpg'):
#         img_path = os.path.join(img_dir, name)
#         img = Image.open(img_path)
#         img = img.convert('RGB')
#         print(predict_acne(img))

# Running tallies for the (currently commented-out) evaluation loops below.
correct, total = 0, 0
mainLabel = []
predictedLabel = []

# Inference setup: place the model on the CPU device and switch to eval
# mode (disables dropout, uses batch-norm running statistics).
model.to(device)
model.eval()

# for name in tqdm(img_names):
#     if name.endswith('.jpg'):
#         img_path = os.path.join(img_dir, name)
#         img = Image.open(img_path)
#         predicted = predict_acne(img)
#         label = int(name.split('_')[0][-1])
#         label = int(label)
#         correct += (predicted == label).sum().item()
#         mainLabel.append(label)
#         predictedLabel.append(predicted)
#         total += 1


# NOTE: the duplicate `import gradio as gr` was removed here — gradio is
# already imported at the top of the file.
# Launch a minimal web UI: upload an image, receive the rounded severity
# score. `.launch()` blocks and serves until interrupted.
gr.Interface(fn=predict_acne,
             inputs=gr.Image(type="pil"),
             outputs="number",
             ).launch()


# for images, labels in tqdm(dataloader, desc="Processing", unit="batch"):
#     images = images.to(device)
#     labels = labels.to(device)
#     outputs = model(images)
#     # print('Outputs:{outputs}, label: {label}'.format(outputs=outputs, label=labels))

#     predicted = torch.round(outputs.data)
#     predicted = predicted.squeeze(1)
#     total += labels.size(0)
#     correct += (predicted == labels).sum().item()

#     mainLabel.extend(labels.cpu().numpy())
#     predictedLabel.extend(predicted.cpu().numpy())
    
# print(f'Accuracy of the network on the {total} test images: {100 * correct / total}%')