Spaces:
Sleeping
Sleeping
File size: 2,905 Bytes
d998353 fe266b3 d998353 fe266b3 d998353 4b7a32d d998353 609518b d998353 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 |
import json
import os
from functools import lru_cache
from io import BytesIO

import gradio as gr
import matplotlib.pyplot as plt
import torch
from PIL import Image
from torchvision import transforms

from vit_model import vit_base_patch16_224_in21k as create_model
@lru_cache(maxsize=1)
def _load_resources():
    """Load the device, label map and model once and reuse them across calls.

    Returns:
        A ``(device, class_indict, model)`` tuple. ``class_indict`` is the
        parsed ``./class_indices.json``; ``model`` has the weights from
        ``./best_model.pth`` loaded, lives on ``device`` and is in eval mode.

    Raises:
        FileNotFoundError: If ``./class_indices.json`` does not exist.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # read class_indict
    json_path = './class_indices.json'
    # Raise explicitly rather than `assert`, which is stripped under `python -O`.
    if not os.path.exists(json_path):
        raise FileNotFoundError("file: '{}' dose not exist.".format(json_path))
    with open(json_path, "r") as f:
        class_indict = json.load(f)

    # create model and load its weights
    model = create_model(num_classes=370, has_logits=False).to(device)
    model_weight_path = "./best_model.pth"
    model.load_state_dict(torch.load(model_weight_path, map_location=device))
    model.eval()
    return device, class_indict, model


def classify_image(img):
    """Classify an image and report the five most probable classes.

    Args:
        img: A PIL image (the Gradio input is declared with ``type='pil'``),
            or ``None`` when no image was supplied.

    Returns:
        A string with one line per class, highest probability first, each
        formatted as ``"class: <name> prob: <p>"`` and terminated by a
        newline. A short message is returned when ``img`` is ``None``.

    Raises:
        FileNotFoundError: If ``./class_indices.json`` does not exist.
        KeyError: If a predicted class index has no entry in the label map.
    """
    if img is None:
        # Nothing to classify; avoid an opaque failure inside the transforms.
        return "No image provided."

    # Normalize below is configured for exactly three channels, so bring
    # grayscale / RGBA inputs to RGB first.
    if img.mode != "RGB":
        img = img.convert("RGB")

    data_transform = transforms.Compose(
        [transforms.Resize(256),
         transforms.CenterCrop(224),
         transforms.ToTensor(),
         transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
    # expand batch dimension -> [N, C, H, W]
    batch = torch.unsqueeze(data_transform(img), dim=0)

    # The model and label map are cached, so only the first call pays for
    # building the network and reading the weights from disk.
    device, class_indict, model = _load_resources()

    with torch.no_grad():
        output = torch.squeeze(model(batch.to(device))).cpu()
        predict = torch.softmax(output, dim=0)

    # Look labels up by class index (as string keys) instead of relying on
    # the JSON key order lining up with the output positions.
    top_k = min(5, predict.numel())
    probs, indices = torch.topk(predict, k=top_k)
    return "".join(
        "class: {:10} prob: {:.3}".format(class_indict[str(idx)], prob) + "\n"
        for idx, prob in zip(indices.tolist(), probs.tolist())
    )
# Assemble the web UI: one PIL-image input wired to classify_image, with the
# returned text shown in a text box.
_theme = gr.themes.Default(text_size="lg")
_image_input = gr.Image(type="pil")
_text_output = gr.Textbox()
iface = gr.Interface(
    fn=classify_image,
    theme=_theme,
    inputs=_image_input,
    outputs=_text_output,
    title="Mushroom Image Classification",
    description="Upload a mushroom image to classify.",
)

# Start the app. This runs unconditionally at module level (there is no
# `__main__` guard), so importing this module also calls launch().
iface.launch()
|