0-ma's picture
Update app.py
079d565 verified
raw
history blame
2.68 kB
import gradio as gr
import numpy as np
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForImageClassification
import os
# Hugging Face Hub repositories offered in the model dropdown, all published
# under the same namespace. The first entry is the dropdown's default value.
_HUB_NAMESPACE = "0-ma"
model_names = [
    f"{_HUB_NAMESPACE}/{repo}"
    for repo in (
        "swin-geometric-shapes-tiny",
        "mobilenet-v2-geometric-shapes",
        "focalnet-geometric-shapes-tiny",
        "efficientnet-b2-geometric-shapes",
        "beit-geometric-shapes-base",
        "mit-b0-geometric-shapes",
        "vit-geometric-shapes-base",
        "resnet-geometric-shapes",
        "vit-geometric-shapes-tiny",
    )
]
# Example images offered under the upload widget. Files in ./example are
# expected to be named "<index>_<Label>.jpg" (e.g. "1_None.jpg",
# "2_Circle.jpg", "6_Hexagone.jpg") — TODO confirm all files follow this.
example_dir = "./example"
# os.listdir returns entries in arbitrary order, so sort them to keep the
# examples gallery stable between restarts. Only regular files are kept, and a
# missing directory yields no examples instead of crashing the app at start-up.
if os.path.isdir(example_dir):
    example_images = [
        os.path.join(example_dir, entry)
        for entry in sorted(os.listdir(example_dir))
        if os.path.isfile(os.path.join(example_dir, entry))
    ]
else:
    example_images = []
# Class names in output order: `predict` reports labels[i] for logit i.
# NOTE(review): assumes every model in model_names was trained with this exact
# class order — TODO confirm against each model's config.id2label.
# The spellings "Pentagone"/"Hexagone" are the keys shown in the result widget.
labels = "None Circle Triangle Square Pentagone Hexagone".split()
# Load every image processor and classification model once, at import time,
# so `predict` only has to look them up. Both dicts are keyed by the Hub repo
# ids in `model_names` and are populated in that order (processors first).
feature_extractors = dict(
    zip(model_names, map(AutoImageProcessor.from_pretrained, model_names))
)
classification_models = dict(
    zip(
        model_names,
        map(AutoModelForImageClassification.from_pretrained, model_names),
    )
)
def predict(image, selected_model):
    """Classify *image* with the model selected in the dropdown.

    Each score is the model's raw logit with negatives clamped to zero,
    rescaled so the positive logits sum to 1. Only classes with a positive
    logit are reported.

    Args:
        image: The uploaded PIL image (the input component uses type="pil"),
            or None when the input is empty or cleared.
        selected_model: A repo id from ``model_names``; used as the key into
            ``feature_extractors`` and ``classification_models``.

    Returns:
        None when there is no image. Otherwise a ``{label: score}`` dict for
        ``gr.Label``, which is empty when no logit is positive.

    Raises:
        KeyError: if ``selected_model`` was not loaded at start-up.
    """
    # torch must already be installed for return_tensors="pt" to work, so a
    # local import adds no new dependency.
    import torch

    if image is None:
        return None
    feature_extractor = feature_extractors[selected_model]
    model = classification_models[selected_model]
    inputs = feature_extractor(images=[image], return_tensors="pt")
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        logits = model(**inputs)["logits"].cpu().numpy()[0]

    # Clamp into a fresh array. The original assigned `logits_positive = logits`
    # and then modified it in place, which silently changed `logits` as well.
    logits_positive = np.clip(logits, 0, None)
    total = float(logits_positive.sum())
    if total <= 0:
        # No positive logit: nothing to report, and dividing would give NaN.
        return {}
    # zip stops at the shorter sequence, so a model head whose size differs
    # from len(labels) cannot raise IndexError here.
    return {
        label: float(score / total)
        for label, score in zip(labels, logits_positive)
        if score > 0
    }
title = "Geometric Shape Classifier"
description = "Select a model and upload an image to classify geometric shapes."

# UI layout, top to bottom: heading, description, model picker, image upload,
# clickable examples, then the classification result.
with gr.Blocks() as demo:
    gr.Markdown("# " + title)
    gr.Markdown(description)
    model_dropdown = gr.Dropdown(
        choices=model_names,
        value=model_names[0],
        label="Select Model",
    )
    image_input = gr.Image(type="pil")
    gr.Examples(
        examples=example_images,
        inputs=image_input,
        label="Click on an example image to test",
    )
    output = gr.Label(label="Classification Result")

    # Re-run the classifier whenever the image or the selected model changes.
    # Both events pass (image, model id) to `predict` and render into `output`.
    image_input.change(
        fn=predict,
        inputs=[image_input, model_dropdown],
        outputs=output,
    )
    model_dropdown.change(
        fn=predict,
        inputs=[image_input, model_dropdown],
        outputs=output,
    )

demo.launch()