yurii_l
uploaded baseline app
4bb166c
raw
history blame
2.34 kB
import os
import boto3
import streamlit as st
import faiss
import pandas as pd
from PIL import Image
from model_prediction import Ranker
from io import BytesIO
# Requires Streamlit >= 1.18 (st.cache_resource replaces the deprecated st.cache).
@st.cache_resource
def load_model():
    """Build the ``Ranker`` used to embed uploaded images.

    Cached with ``st.cache_resource`` so the same model object is reused across
    script reruns instead of being rebuilt; the returned object is neither
    hashed nor copied on cache hits.
    """
    return Ranker()
@st.cache_resource
def load_faiss_index(path="embeddings.index"):
    """Read the FAISS index from *path* (default ``embeddings.index``).

    Cached with ``st.cache_resource``: Streamlit re-executes the script on every
    interaction, so without this the index file would be re-read from disk each
    time (ModelLoader's class attributes are reset on every rerun).
    """
    return faiss.read_index(path)
@st.cache_data
def load_labels(path="labels.csv"):
    """Read the labels CSV from *path* (default ``labels.csv``) into a DataFrame.

    The app reads its ``path`` column to map FAISS result ids to image paths.
    Cached with ``st.cache_data`` so the CSV is parsed once rather than on
    every Streamlit rerun; each call receives its own copy of the cached frame.
    """
    return pd.read_csv(path)
class ModelLoader:
    """Lazy, class-level holder for the ranker, the FAISS index and the labels.

    Each ``get_*`` classmethod loads its resource on first use through the
    matching module-level loader, stores it on the class, and returns the stored
    object on every later call.

    NOTE(review): Streamlit re-runs the script top to bottom on each
    interaction, which re-creates this class with every attribute back at None;
    any reuse across reruns must come from the loader functions themselves.
    """

    model = None  # filled by get_model() from load_model()
    index = None  # filled by get_index() from load_faiss_index()
    labels = None  # filled by get_labels() from load_labels()

    @classmethod
    def _cached(cls, attr, factory):
        """Return class attribute *attr*, first setting it from ``factory()`` if it is None."""
        if getattr(cls, attr) is None:
            setattr(cls, attr, factory())
        return getattr(cls, attr)

    # The lambdas below defer looking up the loader until a load is actually
    # needed, so an already-populated attribute never touches the loader.

    @classmethod
    def get_model(cls):
        """Return the ranker, building it on first use."""
        return cls._cached("model", lambda: load_model())

    @classmethod
    def get_index(cls):
        """Return the FAISS index, reading it on first use."""
        return cls._cached("index", lambda: load_faiss_index())

    @classmethod
    def get_labels(cls):
        """Return the labels table, reading it on first use."""
        return cls._cached("labels", lambda: load_labels())
# (width, height) each retrieved product image is resized to before display.
target_size = (224, 224)

st.set_page_config(page_title="Product Retrieval App")
st.title("Product Retrieval App")
uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
# Placeholder that temporarily shows a status message while predictions run.
loading_text = st.empty()

# S3 client for fetching the retrieved product images. Never hard-code keys in
# source: boto3 resolves credentials from its default chain (for example the
# AWS_ACCESS_KEY_ID / AWS_SECRET_ACCESS_KEY environment variables, the shared
# credentials file, or an attached IAM role).
s3 = boto3.client("s3", region_name="eu-west-1")
bucket_name = "product-retrieval"
def _fetch_thumbnail(img_path):
    """Download one retrieved image from S3 and resize it for the results grid.

    The S3 key is the last ``/``-separated segment of *img_path* (a value from
    the labels' ``path`` column). Returns a PIL image resized to
    ``target_size``, or ``None`` when S3 rejects the request (e.g. missing key
    or denied access), so one bad entry does not abort the whole page.
    """
    key = img_path.split("/")[-1]
    try:
        response = s3.get_object(Bucket=bucket_name, Key=key)
    except s3.exceptions.ClientError:
        return None
    image_data = response["Body"].read()
    return Image.open(BytesIO(image_data)).resize(target_size)


if uploaded_file is not None:
    image = Image.open(uploaded_file)
    st.image(image, caption="Uploaded image", use_column_width=True)
    loading_text.text("Loading predictions...")

    model = ModelLoader.get_model()
    index = ModelLoader.get_index()
    labels = ModelLoader.get_labels()

    # Search the index for the 12 closest matches to the query embedding.
    image_embedding = model.predict(image)
    _, indices = index.search(image_embedding, 12)
    # FAISS pads results with -1 when fewer than k matches exist; drop those
    # slots instead of letting them reach the label lookup.
    hits = [idx for idx in indices[0] if idx >= 0]
    # NOTE(review): assumes FAISS id i corresponds to row i of labels.csv.
    predicted_images = labels["path"].iloc[hits].to_list()
    loading_text.empty()

    # Fill a 4-wide grid row by row.
    columns = st.columns(4)
    for i, img_path in enumerate(predicted_images):
        img = _fetch_thumbnail(img_path)
        with columns[i % len(columns)]:
            if img is None:
                st.warning(f"Could not load predicted image {i+1}")
            else:
                st.image(img, caption=f"Predicted image {i+1}", use_column_width=True)