# ResNet-18 / app.py
# Author: sam12-33 · last commit 71ca213 ("title added") · 1.53 kB
from functools import lru_cache
from io import BytesIO
from pathlib import Path

import numpy as np
import requests
import torch
from PIL import Image
from torchvision.models import resnet18, ResNet18_Weights
# Image classified when predict() is called without an input. Resolved
# against the current working directory, as in the original app.
_DEFAULT_EXAMPLE = "examples/steak.jpeg"


@lru_cache(maxsize=1)
def _load_model():
    """Build the pretrained ResNet-18 and its preprocessing transform once.

    Returns a ``(model, transform)`` pair. The result is cached so every
    prediction (e.g. each Gradio request) reuses the same network instead of
    re-creating it and reloading its weights per call.
    """
    weights = ResNet18_Weights.DEFAULT
    model = resnet18(weights=weights)
    model.eval()
    return model, weights.transforms()


def _to_pil_rgb(img_path):
    """Return *img_path* as an RGB PIL image.

    Accepts None (falls back to the bundled example), a NumPy array, an
    existing PIL image, or a str/Path pointing at an image file.

    Raises:
        TypeError: if *img_path* is none of the supported types.
    """
    if img_path is None:
        img_path = _DEFAULT_EXAMPLE
    if isinstance(img_path, np.ndarray):
        # Let Pillow infer the mode from the array shape (grayscale, RGB,
        # RGBA, ...) and then normalise to RGB, rather than forcing "RGB"
        # onto data that may have a different channel count.
        return Image.fromarray(img_path.astype("uint8")).convert("RGB")
    if isinstance(img_path, (str, Path)):
        # convert() materialises a new image, so the file can be closed here.
        with Image.open(img_path) as src:
            return src.convert("RGB")
    if isinstance(img_path, Image.Image):
        return img_path.convert("RGB")
    raise TypeError(
        f"Unsupported image input type: {type(img_path).__name__}"
    )


def predict(img_path=None) -> str:
    """Classify an image with the ImageNet-pretrained ResNet-18.

    Args:
        img_path: Image to classify. ``None`` falls back to
            ``examples/steak.jpeg``; a NumPy array, a PIL image, or a
            str/Path to an image file are also accepted.

    Returns:
        The name of the highest-scoring category from
        ``ResNet18_Weights.DEFAULT.meta["categories"]``.

    Raises:
        TypeError: if *img_path* is of an unsupported type.
    """
    model, transform = _load_model()
    batch = transform(_to_pil_rgb(img_path)).unsqueeze(0)  # add batch dim
    with torch.inference_mode():
        logits = model(batch)
    # softmax is monotonic, so argmax over the raw logits picks the same class.
    pred_class = logits.argmax(dim=1).item()
    predicted_label = ResNet18_Weights.DEFAULT.meta["categories"][pred_class]
    print(f"Predicted class: {predicted_label}")
    return predicted_label
import numpy as np
import gradio as gr
# Gradio UI: one image in, the predicted ImageNet category out. gr.Image is
# asked for a NumPy array explicitly, which is the input predict() converts.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="numpy"),
    outputs="label",
    # Title previously contained mojibake ("πŸš—") for the 🚗 emoji.
    title="ResNet-18_1K 🚗",
    # predict() returns only the top-1 label, not a probability distribution.
    description=(
        "Upload an image to see the top predicted class from ResNet-18 "
        "trained on ImageNet-1K (1,000 classes)."
    ),
)

if __name__ == "__main__":
    demo.launch()