# Scraped page header (not code): author Racso777, commit 4b7a32d "Update app.py".
import functools
import json
import os
from io import BytesIO

import gradio as gr
import matplotlib.pyplot as plt
import torch
from PIL import Image
from torchvision import transforms

from vit_model import vit_base_patch16_224_in21k as create_model
# Files and model configuration read by _load_resources().
_JSON_PATH = './class_indices.json'
_MODEL_WEIGHT_PATH = "./best_model.pth"
# Passed to create_model; presumably must match the saved checkpoint.
_NUM_CLASSES = 370


@functools.lru_cache(maxsize=1)
def _load_resources():
    """Build the device, preprocessing pipeline, label map and model once.

    The result is cached so that each Gradio request reuses the same model
    instead of re-reading the JSON file and checkpoint from disk. Exceptions
    are not cached, so a failed load is retried on the next call.

    Returns:
        A tuple ``(device, data_transform, class_indict, model)``, where
        ``class_indict`` maps ``str(class_index)`` to a class name.

    Raises:
        FileNotFoundError: if the class-index JSON file is missing.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    data_transform = transforms.Compose(
        [transforms.Resize(256),
         transforms.CenterCrop(224),
         transforms.ToTensor(),
         transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])

    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not os.path.exists(_JSON_PATH):
        raise FileNotFoundError("file: '{}' does not exist.".format(_JSON_PATH))
    with open(_JSON_PATH, "r", encoding="utf-8") as f:
        class_indict = json.load(f)

    model = create_model(num_classes=_NUM_CLASSES, has_logits=False).to(device)
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    model.load_state_dict(torch.load(_MODEL_WEIGHT_PATH, map_location=device))
    model.eval()
    return device, data_transform, class_indict, model


def classify_image(img, top_k=5):
    """Classify an image and return the most probable classes as text.

    Args:
        img: PIL image, as supplied by the ``gr.Image(type='pil')`` input.
            ``None`` (submit without an upload) yields a prompt message.
        top_k: how many of the highest-probability classes to report.
            Defaults to 5, the window the app has always shown. Clamped to
            the number of classes.

    Returns:
        One ``"class: <name> prob: <p>"`` line per prediction, in descending
        probability order. Each line ends with a newline.
    """
    if img is None:
        return "Please upload an image."

    device, data_transform, class_indict, model = _load_resources()

    # Normalize uses 3-channel mean/std, so RGBA/grayscale input must be
    # converted first (a no-op when the image is already RGB).
    tensor = data_transform(img.convert("RGB"))
    # [C, H, W] -> [1, C, H, W]: add the batch dimension.
    batch = torch.unsqueeze(tensor, dim=0)

    with torch.no_grad():
        output = torch.squeeze(model(batch.to(device))).cpu()
        predict = torch.softmax(output, dim=0)

    k = max(0, min(top_k, predict.numel()))
    # topk returns probabilities with their class indices, already sorted
    # descending, so labels are looked up by index rather than by the JSON
    # file's key order.
    top_probs, top_indices = torch.topk(predict, k)
    return "".join(
        "class: {:10} prob: {:.3}\n".format(class_indict[str(idx)], prob)
        for prob, idx in zip(top_probs.tolist(), top_indices.tolist())
    )
# Gradio UI: a single PIL-image input wired to a text box that shows the
# predictions returned by classify_image.
_image_input = gr.Image(type='pil')
_result_output = gr.Textbox()

iface = gr.Interface(
    fn=classify_image,
    inputs=_image_input,
    outputs=_result_output,
    theme=gr.themes.Default(text_size="lg"),
    title="Mushroom Image Classification",
    description="Upload a mushroom image to classify.",
)

# NOTE(review): launch is not behind an `if __name__ == "__main__":` guard,
# so merely importing this module also starts the server.
iface.launch()