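# NOTE: the lines below are web-page header text captured together with the
# file; they are kept as comments so the module remains valid Python.
# ideityfy / app.py
# Yegiiii's picture
# Update app.py
# 092b75f verified
# raw
# history blame
# 2.43 kB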
import streamlit as st
import torch
import torchvision.transforms as T
from PIL import Image
from module import myModule
# Class labels. Despite the name, this is a list used as an index -> label
# lookup: the position of each label is the class index it is retrieved by.
CLASS_TO_IDX = [
    'AMMAN',
    'AYYAPPA',
    'BHAIRAV',
    'BRAHMA',
    'BUDDHA',
    'DURGA',
    'GANESHA',
    'HANUMAN',
    'KAALI',
    'KRISHNA',
    'KURMA',
    'LAKSHMI',
    'LINGA',
    'MATSYA',
    'MURUGA',
    'NARASIMHA',
    'NATARAJA',
    'PARASURAMA',
    'RAMA',
    'SARASWATI',
    'SHIVA',
    'THIRTHANKARA',
    'VAMANA',
    'VARAHA',
    'VISHNU',
]

# Target (height, width) every input image is resized to.
IMG_SIZE = (224, 224)

# Per-channel mean / std handed to the normalisation step of the transform.
STATS = {
    "mean": [0.485, 0.456, 0.406],
    "std": [0.229, 0.224, 0.225],
}
# Preprocessing applied to every input image: resize to IMG_SIZE, apply an
# AutoAugment policy, convert to a tensor, then normalise with STATS.
TTA_TRANSFORM = T.Compose(
    transforms=[
        T.Resize(size=IMG_SIZE),
        T.AutoAugment(),
        T.ToTensor(),
        T.Normalize(mean=STATS["mean"], std=STATS["std"]),
    ]
)
# Entries passed to the page's menu configuration.
# NOTE(review): these URLs / About text look like placeholder values —
# confirm whether project-specific links should be used instead.
_MENU_ITEMS = {
    'Get Help': 'https://www.extremelycoolapp.com/help',
    'Report a bug': "https://www.extremelycoolapp.com/bug",
    'About': "# This is an *extremely* cool app!",
}

# Page-level configuration, followed by the title and header of the app.
st.set_page_config(
    page_title="Identify the deity using Computer Vision.",
    layout="centered",
    initial_sidebar_state="collapsed",
    menu_items=_MENU_ITEMS,
)

st.title(":sparkles: I:orange[deity]fy")
st.header("Discover the deity with a snap.")
@st.cache_resource
def _load_model():
    """Load the classifier checkpoint once and reuse it across script reruns.

    Streamlit re-executes this script from top to bottom on every user
    interaction; without caching, the checkpoint would be read from disk and
    the model rebuilt each time a widget changes.

    Returns:
        The model restored from ``vit_base_clip_rank4.ckpt``, moved to the
        CPU and switched to evaluation mode.
    """
    loaded = myModule.load_from_checkpoint("vit_base_clip_rank4.ckpt")
    loaded.to("cpu")
    loaded.eval()
    return loaded


# Module-level handle used by the prediction code.
model = _load_model()
# Function to make predictions
def predict(image, num_views=10, top_k=3):
    """Classify an image using test-time augmentation (TTA).

    The image is decoded, ``num_views`` augmented copies are produced with
    ``TTA_TRANSFORM``, the model is run on all of them as one batch, and the
    per-view class probabilities are averaged.

    Args:
        image: Anything ``PIL.Image.open`` accepts (a path or a binary
            file-like object such as a Streamlit upload).
        num_views: Number of augmented views to average over (default 10).
        top_k: Number of best classes to return (default 3); capped at the
            number of classes the model outputs.

    Returns:
        A ``(values, indices)`` tuple of two lists of equal length: the
        averaged probabilities of the best classes in descending order, and
        the corresponding class indices.

    Raises:
        ValueError: If ``num_views`` is smaller than 1.
    """
    if num_views < 1:
        raise ValueError("num_views must be at least 1")

    # Close the image returned by Image.open once decoding is done; convert()
    # returns an independent, fully loaded copy that stays usable afterwards.
    with Image.open(image) as raw:
        rgb = raw.convert('RGB')

    # TTA_TRANSFORM contains AutoAugment, so each call yields a different view.
    views = torch.stack([TTA_TRANSFORM(rgb) for _ in range(num_views)])

    # Make a prediction for every view in a single batched forward pass.
    with torch.no_grad():
        logits = model(views)

    # Average the *predictions*, not the pixels: blending differently
    # augmented images into one produces a ghosted input unlike anything the
    # model has seen, whereas averaging per-view probabilities is the
    # standard TTA ensemble.
    probs = torch.softmax(logits, dim=1).mean(dim=0)

    # Get the top predictions and their probabilities as plain Python lists.
    best = torch.topk(probs, k=min(top_k, probs.numel()))
    values = best.values.cpu().tolist()
    indices = best.indices.cpu().tolist()
    return values, indices
# Let the user upload an image through Streamlit
img = st.file_uploader(
    label='choose a file',
    type=['png', 'jpg', 'jpeg'],
    label_visibility="hidden",
)

# The short-circuit keeps the "Predict" button from being created until a
# file has been uploaded; predictions run only when the button is clicked.
if img is not None and st.button("Predict"):
    values, indices = predict(img)
    classes = [CLASS_TO_IDX[idx] for idx in indices]
    # Display the returned predictions as a bar chart (label -> probability)
    st.bar_chart(dict(zip(classes, values)), color="#FFC101")