Spaces:
Runtime error
Runtime error
import os | |
import gradio as gr | |
import timm | |
from huggingface_hub import login | |
from torch import no_grad, softmax, topk | |
# Configuration is read from the environment (e.g. Space variables / secrets).
MODEL_NAME = os.getenv("MODEL_NAME")
HF_TOKEN = os.getenv("HF_TOKEN")

if not MODEL_NAME:
    # Fail fast with a clear message instead of asking timm to load "hf_hub:None".
    raise RuntimeError(
        "The MODEL_NAME environment variable must be set to the Hugging Face Hub "
        "repo id of a timm model."
    )

# Only authenticate when a token is actually provided: login(token=None) falls
# back to an interactive prompt, which cannot be answered in a headless container.
if HF_TOKEN:
    login(token=HF_TOKEN)

model = timm.create_model(f"hf_hub:{MODEL_NAME}", pretrained=True)
model.eval()  # inference mode for the forward passes in classify_image

# Build the preprocessing transform from the model's pretrained config.
# Passing the model by keyword lets timm read model.pretrained_cfg itself;
# previously the config was passed positionally into the `args` override slot.
data_cfg = timm.data.resolve_data_config({}, model=model)
transform = timm.data.create_transform(**data_cfg)
def _label_for(label_names, class_index):
    """Return a display label for ``class_index`` looked up in ``label_names``.

    ``label_names`` comes from ``model.pretrained_cfg`` and may be a dict keyed
    by the class index (as a string, e.g. when loaded from a JSON config, or as
    an int) or a list/tuple indexed by position. When no name is found, the
    bare index is returned as a string so the UI still shows something.
    """
    name = None
    if isinstance(label_names, dict):
        name = label_names.get(str(class_index), label_names.get(class_index))
    elif isinstance(label_names, (list, tuple)) and 0 <= class_index < len(label_names):
        name = label_names[class_index]
    return str(class_index) if name is None else str(name).title()


def classify_image(input):
    """Classify an image and return its most probable labels.

    Args:
        input: Image from the ``gr.Image(type="pil")`` component, or ``None``
            when the user submits without an image. (The parameter name
            shadows the builtin but is kept for backward compatibility.)

    Returns:
        A ``{label: probability}`` dict for the top-3 classes (fewer if the
        model has fewer classes), with probabilities as plain Python floats,
        or ``None`` when no image was provided.
    """
    if input is None:
        # Nothing to classify; returning None leaves the Label output empty
        # instead of crashing inside the transform.
        return None

    inp = transform(input)
    with no_grad():
        # Add a batch dimension for the model, then take the single result row.
        output = model(inp.unsqueeze(0))
    probabilities = softmax(output[0], dim=0)

    # topk raises if k exceeds the number of classes, so clamp it.
    k = min(3, probabilities.numel())
    values, indices = topk(probabilities, k)

    label_names = model.pretrained_cfg.get("label_names")
    # .item() converts 0-d tensors to Python scalars for the Label component.
    return {
        _label_for(label_names, idx.item()): prob.item()
        for idx, prob in zip(indices, values)
    }
# Input component: an image taken from a file upload or the clipboard,
# delivered to the handler as a PIL image.
image_input = gr.Image(type="pil", sources=["upload", "clipboard"])
# Output component: a label widget showing at most three classes.
label_output = gr.Label(num_top_classes=3)

# Wire classify_image into a single-input / single-output interface.
# `examples` is given as a path (presumably a folder of sample images
# shipped alongside this app — confirm), and flagging is disabled.
demo = gr.Interface(
    fn=classify_image,
    inputs=image_input,
    outputs=label_output,
    allow_flagging="never",
    examples="examples",
)

# Enable request queuing, then start the server with debug=True.
demo.queue()
demo.launch(debug=True)