File size: 2,338 Bytes
4bb166c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import os
import boto3
import streamlit as st
import faiss
import pandas as pd
from PIL import Image
from model_prediction import Ranker
from io import BytesIO


def load_model():
    """Construct the ``Ranker`` used to embed the uploaded query image.

    The function is wrapped below with a Streamlit cache so the model is
    built once and reused across script reruns instead of on every upload.
    """
    return Ranker()


# ``st.cache`` is deprecated (Streamlit >= 1.18) and, for a returned model
# object, re-hashes the output on every rerun to detect mutation. Use
# ``st.cache_resource`` (meant for shared, unserialisable resources such as
# models) when available; on older Streamlit fall back to the legacy API with
# output hashing disabled.
if hasattr(st, "cache_resource"):
    load_model = st.cache_resource(load_model)
else:
    load_model = st.cache(allow_output_mutation=True)(load_model)


def load_faiss_index():
    """Read the FAISS index of catalogue embeddings from ``embeddings.index``.

    Streamlit re-executes the whole script on every interaction, so
    class-level memoisation elsewhere in this script does not survive a rerun.
    The function is therefore wrapped below with a Streamlit cache, so the
    index file is read from disk only once.
    """
    return faiss.read_index('embeddings.index')


# FAISS index objects are native (SWIG) handles: cache them as a shared
# resource, never hash or copy them. The legacy fallback disables output
# hashing for the same reason.
if hasattr(st, "cache_resource"):
    load_faiss_index = st.cache_resource(load_faiss_index)
else:
    load_faiss_index = st.cache(allow_output_mutation=True)(load_faiss_index)


def load_labels():
    """Read the label table from ``labels.csv``.

    Its ``path`` column is looked up with FAISS search result ids, so row
    order is assumed to match the order vectors were added to the index
    (NOTE(review): confirm against the index-building code).

    The function is wrapped below with a Streamlit cache so the CSV is not
    re-parsed on every script rerun.
    """
    return pd.read_csv("labels.csv")


# ``st.cache_data`` hands each caller its own copy of the DataFrame, so
# accidental mutation cannot corrupt the cached table. Older Streamlit
# releases only have the legacy ``st.cache``.
if hasattr(st, "cache_data"):
    load_labels = st.cache_data(load_labels)
else:
    load_labels = st.cache(load_labels)


class ModelLoader:
    model = None
    index = None
    labels = None

    @classmethod
    def get_model(cls):
        if cls.model is None:
            cls.model = load_model()
        return cls.model

    @classmethod
    def get_index(cls):
        if cls.index is None:
            cls.index = load_faiss_index()
        return cls.index

    @classmethod
    def get_labels(cls):
        if cls.labels is None:
            cls.labels = load_labels()
        return cls.labels


# (width, height) that retrieved product images are resized to before display.
target_size = (224, 224)
st.set_page_config(page_title="Product Retrieval App")
st.title("Product Retrieval App")

uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
loading_text = st.empty()

# Credentials come from boto3's default provider chain: the environment
# variables AWS_ACCESS_KEY_ID / AWS_SECRET_ACCESS_KEY, the shared credentials
# file, or an attached IAM role. Never hard-code access keys in source.
s3 = boto3.client('s3', region_name='eu-west-1')

bucket_name = "product-retrieval"

if uploaded_file is not None:
    image = Image.open(uploaded_file)
    st.image(image, caption="Uploaded image", use_column_width=True)

    loading_text.text("Loading predictions...")

    model = ModelLoader.get_model()
    index = ModelLoader.get_index()
    labels = ModelLoader.get_labels()

    image_embedding = model.predict(image)
    _, indices = index.search(image_embedding, 12)
    # FAISS pads the result with -1 when the index holds fewer than k
    # vectors; drop those ids, then look rows up by position (the ids are
    # positions of vectors in the index, not DataFrame labels).
    hits = indices[0]
    predicted_images = labels["path"].iloc[hits[hits >= 0]].to_list()
    loading_text.empty()

    # Lay results out in a 4-wide grid, filling columns left to right.
    columns = st.columns(4)

    for i, img_path in enumerate(predicted_images):
        column = columns[i % 4]
        # Objects are stored flat in the bucket under the file's base name.
        key = img_path.split("/")[-1]
        try:
            response = s3.get_object(Bucket=bucket_name, Key=key)
            image_data = response['Body'].read()
        except s3.exceptions.ClientError as err:
            # One missing or forbidden object should not blank the whole page.
            column.warning(f"Could not fetch {key}: {err}")
            continue
        img = Image.open(BytesIO(image_data)).resize(target_size)

        with column:
            st.image(img, caption=f"Predicted image {i+1}", use_column_width=True)