Spaces:
Runtime error
Runtime error
File size: 2,620 Bytes
25bf539 1234c7c 135173a 25bf539 1234c7c 0f6cc97 1234c7c 92e317c 8713bc4 92e317c 8713bc4 92e317c 8713bc4 92e317c 8713bc4 92e317c 8713bc4 92e317c 8713bc4 92e317c 8713bc4 92e317c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 |
import io
import timm
import torch
import streamlit as st
from PIL import Image
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
class ImageClassifier(object):
    """Run a timm image-classification model and map its outputs to label names."""

    def __init__(self, model, labels):
        # model: a timm model; create_image_transform() reads its `pretrained_cfg`.
        self.model = model
        # labels: sequence indexed by the model's output class index.
        self.labels = labels

    def get_top_5_predictions(self, image):
        """Return up to five highest-probability predictions for *image*.

        Returns:
            A list of ``{'label': ..., 'score': float}`` dicts ordered from most
            to least probable. Fewer than five entries are returned when the
            model has fewer than five output classes.
        """
        probabilities = self.get_output_probabilities(image)
        # torch.topk raises if k exceeds the number of elements, so clamp it.
        k = min(5, probabilities.numel())
        values, indices = torch.topk(probabilities, k)
        return [
            {'label': self.labels[index], 'score': score}
            for index, score in zip(indices.tolist(), values.tolist())
        ]

    def get_output_probabilities(self, image):
        """Return the softmax class probabilities for the single input *image*."""
        output = self.classify_image(image)
        # output has a leading batch dimension of 1 (see classify_image); take row 0.
        return torch.nn.functional.softmax(output[0], dim=0)

    def classify_image(self, image):
        """Preprocess *image* and return the model's raw output for a batch of one."""
        self.model.eval()
        transform = self.create_image_transform()
        # Pure inference: skip autograd bookkeeping to save memory and time.
        with torch.no_grad():
            return self.model(transform(image).unsqueeze(0))

    def create_image_transform(self):
        """Build the preprocessing pipeline described by the model's pretrained config."""
        return create_transform(**resolve_data_config(
            self.model.pretrained_cfg, model=self.model))
class ImageClassificationApp(object):
    """Streamlit page: upload an image and list the classifier's top-5 predictions."""

    def __init__(self, title, classifier):
        # title: heading shown at the top of the page by render().
        self.title = title
        # classifier: object exposing get_top_5_predictions(image).
        self.classifier = classifier

    def render(self):
        """Draw the title and, once a file has been uploaded, the image and results."""
        st.title(self.title)
        uploaded_image = self.get_uploaded_image()
        if uploaded_image is not None:
            self.show_image_and_results(uploaded_image)

    def get_uploaded_image(self):
        """Return the file from Streamlit's uploader, or None when nothing is uploaded."""
        return st.file_uploader('Choose an image...', type=['jpg', 'png', 'jpeg'])

    def show_image_and_results(self, uploaded_image):
        """Display the uploaded file, then classify it and display the predictions."""
        self.show_uploaded_image(uploaded_image)
        # getvalue() returns the whole buffer regardless of the current stream
        # position, so it does not matter whether st.image() already read the file.
        self.show_classification_results(self.get_image(uploaded_image.getvalue()))

    def show_uploaded_image(self, uploaded_image):
        """Render the uploaded image scaled to the column width."""
        st.image(uploaded_image, caption='Uploaded Image', use_column_width=True)

    def show_classification_results(self, image):
        """Render a subheader followed by the top-5 predictions for *image*."""
        st.subheader('Classification Results:')
        self.write_top_5_predictions(image)

    def write_top_5_predictions(self, image):
        """Write one bullet line per prediction as 'label: score' with 4 decimals."""
        for prediction in self.classifier.get_top_5_predictions(image):
            st.write(f"- {prediction['label']}: {prediction['score']:.4f}")

    def get_image(self, image_data):
        """Decode raw image bytes into an RGB PIL image.

        PNG uploads may be RGBA or palette-based, and JPEGs may be grayscale or
        CMYK. timm image models normally expect 3-channel RGB input, so the
        image is normalised to RGB before it reaches the preprocessing pipeline.
        """
        return Image.open(io.BytesIO(image_data)).convert('RGB')
if __name__ == '__main__':
    # Streamlit re-executes this whole script on every user interaction, so
    # cache the expensive pretrained-model load across reruns when the
    # installed Streamlit provides st.cache_resource (otherwise, no caching).
    _cache_resource = getattr(st, 'cache_resource', lambda func: func)

    @_cache_resource
    def load_model(model_name):
        """Create the pretrained timm model *model_name* (cached across reruns)."""
        return timm.create_model(model_name, pretrained=True)

    model = load_model('hf-hub:nateraw/resnet50-oxford-iiit-pet')
    # Class names are read from the model's pretrained config.
    labels = model.pretrained_cfg['label_names']
    classifier = ImageClassifier(model, labels)
    ImageClassificationApp(
        'Pet Image Classification App',
        classifier
    ).render()
|