# Hugging Face Space status at the time this file was captured: Runtime error
import gradio as gr | |
from transformers import AutoImageProcessor, AutoModelForImageClassification | |
import torch | |
from PIL import Image | |
# Fine-tuned ViT checkpoint that is loaded for inference below.
model_name = 'e1010101/vit-384-tongue-image'

# Label names in class-index order. num_labels, id2label and label2id are all
# derived from this single list so they cannot drift out of sync.
LABELS = ['Crack', 'Red-Dots', 'Toothmark']

# NOTE(review): the preprocessing config is taken from the base checkpoint,
# not from ``model_name``; presumably the fine-tuned model was trained with
# the same ViT-384 preprocessing — confirm.
processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-384")

model = AutoModelForImageClassification.from_pretrained(
    model_name,
    num_labels=len(LABELS),
    # Multi-label: each label is scored independently (the app applies a
    # sigmoid per logit) rather than as a single softmax choice.
    problem_type="multi_label_classification",
    # NOTE(review): if the checkpoint's classifier head does not have
    # len(LABELS) outputs, this replaces it with randomly initialised weights
    # (only a warning is logged) — confirm the checkpoint already has a
    # 3-way head.
    ignore_mismatched_sizes=True,
    id2label=dict(enumerate(LABELS)),
    label2id={label: idx for idx, label in enumerate(LABELS)},
)
def classify_image(image):
    """Score an image against every label of the multi-label classifier.

    Args:
        image: The image handed over by the Gradio ``Image`` input (a PIL
            image when the input is configured with ``type="pil"``), or
            ``None`` when the user submits without uploading anything.

    Returns:
        A dict mapping each label name to its sigmoid probability, ordered
        from most to least likely, or ``None`` if no image was provided.
    """
    if image is None:
        # Nothing to classify: returning None clears the Label output instead
        # of passing a missing image to the processor.
        return None

    inputs = processor(images=image, return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits

    # Multi-label task: each class gets an independent sigmoid probability
    # (they need not sum to 1), not a softmax over classes. [0] selects the
    # single image in the batch.
    probs = torch.sigmoid(logits)[0].tolist()

    # Look each label up by its class index rather than relying on the
    # iteration order of id2label.values() matching the logit order.
    id2label = model.config.id2label
    result = {id2label[idx]: float(prob) for idx, prob in enumerate(probs)}

    # Most likely labels first.
    return dict(sorted(result.items(), key=lambda item: item[1], reverse=True))
# Gradio UI: one PIL image in, a label/confidence widget out.
# The legacy ``gr.inputs`` / ``gr.outputs`` namespaces were deprecated in
# Gradio 3 and removed in Gradio 4; components now live directly on ``gr``.
interface = gr.Interface(
    fn=classify_image,
    inputs=gr.Image(type="pil"),
    # num_top_classes=None shows every label instead of only the top k.
    outputs=gr.Label(num_top_classes=None),
    title="Multi-Label Image Classification",
    description="Upload an image to get classification results.",
)
if __name__ == "__main__":
    # Start the Gradio server only when this file is executed as a script,
    # not when it is imported.
    interface.launch()