# Captured from a Hugging Face Space whose status page read "Runtime error".
import gradio as gr
import torch
from PIL import Image
from transformers import AutoModelForImageClassification, AutoProcessor
# Hub checkpoint used by this app; the matching preprocessing pipeline and the
# classification model are both loaded once, at import time.
model_name = "e1010101/vit-384-tongue-image"
processor, model = (
    AutoProcessor.from_pretrained(model_name),
    AutoModelForImageClassification.from_pretrained(model_name),
)
def classify_image(image): | |
inputs = processor(images=image, return_tensors="pt") | |
with torch.no_grad(): | |
outputs = model(**inputs) | |
logits = outputs.logits | |
# Apply sigmoid for multi-label classification | |
probs = torch.sigmoid(logits)[0].numpy() | |
# Get label names | |
labels = model.config.id2label.values() | |
# Create a dictionary of labels and probabilities | |
result = {label: float(prob) for label, prob in zip(labels, probs)} | |
# Sort results by probability | |
result = dict(sorted(result.items(), key=lambda item: item[1], reverse=True)) | |
return result | |
# Gradio 4 removed the ``gr.inputs`` / ``gr.outputs`` namespaces (deprecated in
# 3.x), most likely the cause of the Space's "Runtime error" at startup; the
# top-level ``gr.Image`` / ``gr.Label`` components replace them.
interface = gr.Interface(
    fn=classify_image,
    inputs=gr.Image(type="pil"),  # hand classify_image a PIL.Image
    outputs=gr.Label(num_top_classes=None),  # None: show every label
    title="Multi-Label Image Classification",
    description="Upload an image to get classification results.",
)
def _main():
    """Start the Gradio server for ``interface``."""
    interface.launch()


if __name__ == "__main__":
    _main()