# Spaces: Sleeping (hosting-page status header captured with the source; not part of the program)
import functools
import json
import os
from io import BytesIO

import gradio as gr
import matplotlib.pyplot as plt
import torch
from PIL import Image
from torchvision import transforms

from vit_model import vit_base_patch16_224_in21k as create_model
_CLASS_INDEX_PATH = './class_indices.json'
_MODEL_WEIGHT_PATH = "./best_model.pth"
_NUM_CLASSES = 370
_TOP_K = 5


@functools.lru_cache(maxsize=1)
def _load_resources():
    """Load the device, class-index mapping and model once, then reuse them.

    Cached so that the weights and the JSON file are read only on the first
    request instead of on every call to ``classify_image``.

    Returns:
        A ``(device, class_indict, model)`` tuple; ``model`` is already moved
        to ``device`` and switched to eval mode.

    Raises:
        FileNotFoundError: If the class-index JSON file does not exist.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Explicit raise instead of ``assert``: asserts are stripped under -O.
    if not os.path.exists(_CLASS_INDEX_PATH):
        raise FileNotFoundError(
            "file: '{}' does not exist.".format(_CLASS_INDEX_PATH))
    with open(_CLASS_INDEX_PATH, "r", encoding="utf-8") as f:
        class_indict = json.load(f)

    model = create_model(num_classes=_NUM_CLASSES, has_logits=False).to(device)
    model.load_state_dict(torch.load(_MODEL_WEIGHT_PATH, map_location=device))
    model.eval()
    return device, class_indict, model


def _preprocess(img):
    """Convert a PIL image into a normalized single-image batch tensor."""
    data_transform = transforms.Compose(
        [transforms.Resize(256),
         transforms.CenterCrop(224),
         transforms.ToTensor(),
         transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
    # Normalize is configured for three channels, so force RGB for
    # grayscale / RGBA / palette inputs.
    if img.mode != "RGB":
        img = img.convert("RGB")
    # [C, H, W] -> [N, C, H, W] with N == 1
    return torch.unsqueeze(data_transform(img), dim=0)


def classify_image(img):
    """Classify an image and report the most probable classes.

    Args:
        img: A PIL image, or ``None`` when no image was supplied.

    Returns:
        A string with one ``"class: <name> prob: <p>"`` line (each terminated
        by a newline) for the ``_TOP_K`` highest-probability classes, most
        probable first. If ``img`` is ``None``, a short notice is returned
        instead and the model is not loaded.

    Raises:
        FileNotFoundError: If the class-index JSON file does not exist.
    """
    if img is None:
        return "No image provided."

    device, class_indict, model = _load_resources()
    batch = _preprocess(img)

    with torch.no_grad():
        output = torch.squeeze(model(batch.to(device))).cpu()
        predict = torch.softmax(output, dim=0)

    # topk returns values sorted in descending order together with their
    # class indices, so names are looked up by index rather than by relying
    # on the key order of the JSON mapping.
    k = min(_TOP_K, predict.numel())
    top_probs, top_indices = torch.topk(predict, k)

    return "".join(
        "class: {:10} prob: {:.3}".format(class_indict[str(idx)], prob) + "\n"
        for prob, idx in zip(top_probs.tolist(), top_indices.tolist())
    )
# Gradio UI: a single PIL image input wired to a plain text output.
_image_input = gr.Image(type='pil')
_text_output = gr.Textbox()

iface = gr.Interface(
    fn=classify_image,
    inputs=_image_input,
    outputs=_text_output,
    title="Mushroom Image Classification",
    description="Upload a mushroom image to classify.",
    theme=gr.themes.Default(text_size="lg"),
)

# NOTE: launch() is called unconditionally (no `__main__` guard), so the app
# starts as soon as this module is executed or imported.
iface.launch()