# Spaces: Sleeping
import os | |
import boto3 | |
import streamlit as st | |
import faiss | |
import pandas as pd | |
from PIL import Image | |
from model_prediction import Ranker | |
from io import BytesIO | |
@st.cache_resource
def load_model():
    """Construct the ``Ranker`` used to embed query images.

    Streamlit re-executes this script from top to bottom on every user
    interaction, so without caching the model would be rebuilt on each rerun.
    ``st.cache_resource`` (Streamlit >= 1.18) keeps one shared instance
    across reruns and sessions instead.
    """
    return Ranker()
@st.cache_resource
def load_faiss_index(path="embeddings.index"):
    """Read and return the FAISS index stored at *path*.

    Args:
        path: Location of the serialized index. The default,
            ``"embeddings.index"``, is resolved relative to the current
            working directory.

    Cached with ``st.cache_resource`` (Streamlit >= 1.18) so the index is
    read from disk once rather than on every script rerun.
    """
    return faiss.read_index(path)
@st.cache_data
def load_labels(path="labels.csv"):
    """Read the label table from the CSV file at *path* into a DataFrame.

    Args:
        path: CSV location. The default, ``"labels.csv"``, is resolved
            relative to the current working directory.

    Returns:
        A ``pandas.DataFrame``. This app reads its ``path`` column.

    ``st.cache_data`` (Streamlit >= 1.18) avoids re-parsing the CSV on every
    rerun. It hands each caller its own copy, so mutating the result cannot
    corrupt the cached table.
    """
    return pd.read_csv(path)
class ModelLoader:
    """Lazily load and memoize the ranking model, FAISS index and label table.

    Each resource is created on first request through the matching
    ``load_*`` helper and then stored on the class, so later requests reuse
    it instead of loading it again.

    NOTE: Streamlit re-executes the whole script on every rerun, which
    redefines this class. On its own, this memoization therefore only lasts
    for a single run. Persistence across reruns has to come from caching in
    the ``load_*`` helpers (e.g. ``st.cache_resource``).
    """

    # Cached resources; each stays ``None`` until first requested.
    model = None
    index = None
    labels = None

    # These are classmethods because callers invoke them on the class itself
    # (``ModelLoader.get_model()``). Without the decorator, ``cls`` would be
    # an unfilled positional argument and the call would raise TypeError.
    @classmethod
    def get_model(cls):
        """Return the ranking model, loading it on first use."""
        if cls.model is None:
            cls.model = load_model()
        return cls.model

    @classmethod
    def get_index(cls):
        """Return the FAISS index, loading it on first use."""
        if cls.index is None:
            cls.index = load_faiss_index()
        return cls.index

    @classmethod
    def get_labels(cls):
        """Return the label DataFrame, loading it on first use."""
        if cls.labels is None:
            cls.labels = load_labels()
        return cls.labels
# (width, height) in pixels that every retrieved image is resized to for display.
target_size = (224, 224)
# Number of nearest neighbours requested from the FAISS index per query.
num_results = 12
# Number of columns in the results grid.
num_columns = 4

st.set_page_config(page_title="Product Retrieval App")
st.title("Product Retrieval App")
uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
# Placeholder used to show a status message while predictions are computed.
loading_text = st.empty()

# SECURITY: never hard-code AWS credentials in source. boto3 resolves them
# from its default provider chain: AWS_ACCESS_KEY_ID / AWS_SECRET_ACCESS_KEY
# environment variables (e.g. deployment secrets), the shared credentials
# file, or an attached role. Any key that was ever committed to this file
# must be treated as leaked and rotated.
s3 = boto3.client(
    's3',
    region_name='eu-west-1'
)
bucket_name = "product-retrieval"

if uploaded_file is not None:
    image = Image.open(uploaded_file)
    st.image(image, caption="Uploaded image", use_column_width=True)

    loading_text.text("Loading predictions...")
    try:
        model = ModelLoader.get_model()
        index = ModelLoader.get_index()
        labels = ModelLoader.get_labels()

        image_embedding = model.predict(image)
        # index.search returns (distances, ids), each with one row per query.
        # Only one query is issued here, so row 0 holds the hits. FAISS pads
        # with id -1 when the index holds fewer than `num_results` vectors;
        # drop those so the label lookup below does not fail.
        _, indices = index.search(image_embedding, num_results)
        hit_ids = [row_id for row_id in indices[0] if row_id >= 0]
        # FAISS ids are row positions, hence positional .iloc. This assumes
        # the index was built in the same row order as labels.csv (TODO confirm).
        predicted_images = labels["path"].iloc[hit_ids].to_list()
    finally:
        # Clear the status message even if loading or prediction fails.
        loading_text.empty()

    columns = st.columns(num_columns)
    for i, img_path in enumerate(predicted_images):
        # The S3 object key is the file-name component of the stored path.
        response = s3.get_object(Bucket=bucket_name, Key=img_path.split("/")[-1])
        image_data = response['Body'].read()
        img = Image.open(BytesIO(image_data)).resize(target_size)
        # Fill the grid left to right, wrapping to a new row every num_columns images.
        with columns[i % num_columns]:
            st.image(img, caption=f"Predicted image {i+1}", use_column_width=True)