File size: 2,905 Bytes
d998353
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fe266b3
d998353
 
 
 
 
 
 
 
fe266b3
 
d998353
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4b7a32d
d998353
 
609518b
d998353
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import functools
import json
import os
from io import BytesIO

import gradio as gr
import matplotlib.pyplot as plt
import torch
from PIL import Image
from torchvision import transforms

from vit_model import vit_base_patch16_224_in21k as create_model

# Number of highest-probability classes reported to the user.
_TOP_K = 5


@functools.lru_cache(maxsize=1)
def _load_class_indices():
    """Read and cache the class-index -> class-name mapping from disk."""
    json_path = './class_indices.json'
    # Explicit raise instead of ``assert`` so the check survives ``python -O``.
    if not os.path.exists(json_path):
        raise FileNotFoundError("file: '{}' dose not exist.".format(json_path))

    with open(json_path, "r", encoding="utf-8") as f:
        return json.load(f)


@functools.lru_cache(maxsize=1)
def _load_model(device):
    """Build the ViT classifier, load its weights onto *device*, and cache it."""
    model = create_model(num_classes=370, has_logits=False).to(device)
    model_weight_path = "./best_model.pth"
    model.load_state_dict(torch.load(model_weight_path, map_location=device))
    model.eval()
    return model


def classify_image(img):
    """Classify an image and return the top predictions as text.

    The image is resized to 256, center-cropped to 224, normalized with
    mean/std 0.5 per channel, and run through the ViT model.  The model and
    the class-name mapping are loaded on first use and then reused, so
    changes to ``./best_model.pth`` or ``./class_indices.json`` made while
    the process is running are not picked up.

    Args:
        img: A PIL image (the Gradio input is declared with ``type='pil'``),
            or ``None`` when no image was supplied.

    Returns:
        A string with up to five lines of the form
        ``"class: <name>   prob: <p>"``, ordered by descending probability
        and each terminated by a newline.  If ``img`` is ``None`` the string
        ``"No image provided."`` is returned instead.

    Raises:
        FileNotFoundError: If ``./class_indices.json`` does not exist.
        KeyError: If a predicted class index has no entry in the mapping.
    """
    if img is None:
        return "No image provided."

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    data_transform = transforms.Compose(
        [transforms.Resize(256),
         transforms.CenterCrop(224),
         transforms.ToTensor(),
         transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])

    # Normalize is configured for three channels, so force RGB in case a
    # grayscale or RGBA image is passed in.  Result is [C, H, W]; unsqueeze
    # adds the batch dimension -> [N, C, H, W].
    batch = torch.unsqueeze(data_transform(img.convert("RGB")), dim=0)

    class_indict = _load_class_indices()
    model = _load_model(device)

    with torch.no_grad():
        output = torch.squeeze(model(batch.to(device))).cpu()
        predict = torch.softmax(output, dim=0)

    # Rank by probability (stable sort keeps index order on ties) and look
    # each label up by its class index rather than relying on the order of
    # keys in the JSON file.
    ranked = sorted(enumerate(predict.tolist()),
                    key=lambda pair: pair[1],
                    reverse=True)

    return "".join(
        "class: {:10}   prob: {:.3}".format(class_indict[str(index)], prob) + "\n"
        for index, prob in ranked[:_TOP_K]
    )

# Web UI: a single image upload (delivered to the callback as a PIL image)
# mapped to a text box that shows whatever `classify_image` returns.
iface = gr.Interface(
    title="Mushroom Image Classification",
    description="Upload a mushroom image to classify.",
    fn=classify_image,
    theme=gr.themes.Default(text_size="lg"),
    inputs=gr.Image(type='pil'),
    outputs=gr.Textbox(),
)

# NOTE: launch is intentionally unguarded (no `if __name__ == '__main__':`),
# so the app starts as soon as this module is executed.
iface.launch()