"""Streamlit app that retrieves catalogue images similar to an uploaded one.

The uploaded image is embedded with ``Ranker.predict``, the embedding is
searched in the FAISS index read from ``embeddings.index``, and the matching
rows of ``labels.csv`` (column ``path``) give the S3 object keys of the
images that are downloaded and displayed in a grid.
"""
import os
from io import BytesIO

import boto3
import faiss
import pandas as pd
import streamlit as st
from PIL import Image

from model_prediction import Ranker

# Streamlit re-executes this script on every interaction, so anything that
# has to survive a rerun must go through Streamlit's cache. ``st.cache`` is
# deprecated in recent releases; prefer ``st.cache_resource`` when it exists
# and fall back to the legacy decorator on older installations.
if hasattr(st, "cache_resource"):
    _cache_resource = st.cache_resource
else:
    _cache_resource = st.cache(allow_output_mutation=True)


@_cache_resource
def load_model():
    """Build the ``Ranker`` used to embed uploaded images (cached)."""
    return Ranker()


@_cache_resource
def load_faiss_index():
    """Read the FAISS index from ``embeddings.index`` (cached)."""
    return faiss.read_index('embeddings.index')


@_cache_resource
def load_labels():
    """Read ``labels.csv`` into a DataFrame (cached).

    The cached DataFrame is shared between reruns, so callers must treat it
    as read-only.
    """
    return pd.read_csv("labels.csv")


class ModelLoader:
    """Lazy accessors for the model, the FAISS index and the label table.

    The class attributes only memoise within a single script run, because
    the class is redefined each time Streamlit reruns the script. Reuse
    across reruns comes from the cached ``load_*`` functions.
    """

    model = None
    index = None
    labels = None

    @classmethod
    def get_model(cls):
        """Return the ranking model, loading it on first use."""
        if cls.model is None:
            cls.model = load_model()
        return cls.model

    @classmethod
    def get_index(cls):
        """Return the FAISS index, loading it on first use."""
        if cls.index is None:
            cls.index = load_faiss_index()
        return cls.index

    @classmethod
    def get_labels(cls):
        """Return the labels DataFrame, loading it on first use."""
        if cls.labels is None:
            cls.labels = load_labels()
        return cls.labels


# Size every retrieved image is resized to before display.
target_size = (224, 224)
# Number of nearest neighbours requested from the index.
NUM_RESULTS = 12
# Number of columns in the result grid.
NUM_COLUMNS = 4

st.set_page_config(page_title="Product Retrieval App")
st.title("Product Retrieval App")

uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
loading_text = st.empty()

# SECURITY: credentials must never be embedded in source code. boto3 resolves
# them from its default chain (AWS_ACCESS_KEY_ID / AWS_SECRET_ACCESS_KEY
# environment variables, the shared credentials file, or an attached IAM
# role). Any key pair that was ever committed to this file must be treated
# as compromised: revoke it and issue a new one.
s3 = boto3.client('s3', region_name='eu-west-1')
bucket_name = "product-retrieval"

if uploaded_file is not None:
    image = Image.open(uploaded_file)
    st.image(image, caption="Uploaded image", use_column_width=True)
    loading_text.text("Loading predictions...")

    model = ModelLoader.get_model()
    index = ModelLoader.get_index()
    labels = ModelLoader.get_labels()

    image_embedding = model.predict(image)
    _distances, indices = index.search(image_embedding, NUM_RESULTS)

    # FAISS pads the result with -1 when fewer than NUM_RESULTS neighbours
    # are available; drop those slots instead of looking them up.
    neighbor_rows = [int(row) for row in indices[0] if row >= 0]
    # Positional lookup: this assumes row i of labels.csv describes vector i
    # of the index -- TODO confirm against the code that builds both files.
    predicted_images = labels["path"].iloc[neighbor_rows].to_list()
    loading_text.empty()

    columns = st.columns(NUM_COLUMNS)
    for i, img_path in enumerate(predicted_images):
        # Only the final path component is used as the S3 object key.
        response = s3.get_object(Bucket=bucket_name, Key=img_path.split("/")[-1])
        image_data = response['Body'].read()
        img = Image.open(BytesIO(image_data)).resize(target_size)
        # Fill the grid left to right, wrapping to a new row every
        # NUM_COLUMNS images.
        with columns[i % NUM_COLUMNS]:
            st.image(img, caption=f"Predicted image {i+1}", use_column_width=True)