Update app.py
Browse files
app.py
CHANGED
@@ -261,6 +261,8 @@ if st.button('Check for Infringement'):
|
|
261 |
with tab1:
|
262 |
with st.spinner('Processing...'):
|
263 |
|
|
|
|
|
264 |
if not os.path.exists('/home/user/app/embeddings'):
|
265 |
download_db()
|
266 |
print("\u2713 Downloaded Database\n\n")
|
@@ -300,54 +302,54 @@ if st.button('Check for Infringement'):
|
|
300 |
with col2:
|
301 |
st.markdown(f"*Similar Text:* {similar_text}")
|
302 |
|
303 |
-
|
304 |
-
|
305 |
-
|
306 |
-
|
307 |
-
|
308 |
|
309 |
-
|
310 |
-
|
311 |
-
|
312 |
-
|
313 |
|
314 |
-
|
315 |
-
|
316 |
-
|
317 |
-
|
318 |
-
|
319 |
-
|
320 |
-
|
321 |
|
322 |
-
|
323 |
-
|
324 |
|
325 |
-
|
326 |
-
|
327 |
|
328 |
|
329 |
|
330 |
-
|
331 |
-
|
332 |
-
|
333 |
|
334 |
-
|
335 |
|
336 |
-
|
337 |
-
|
338 |
|
339 |
-
|
340 |
-
|
341 |
-
|
342 |
|
343 |
-
|
344 |
|
345 |
-
|
346 |
-
|
347 |
-
|
348 |
-
|
349 |
|
350 |
-
|
351 |
-
|
352 |
|
353 |
|
|
|
261 |
with tab1:
|
262 |
with st.spinner('Processing...'):
|
263 |
|
264 |
+
for path in os.listdir('/home/user/app'):
|
265 |
+
print(path)
|
266 |
if not os.path.exists('/home/user/app/embeddings'):
|
267 |
download_db()
|
268 |
print("\u2713 Downloaded Database\n\n")
|
|
|
302 |
with col2:
|
303 |
st.markdown(f"*Similar Text:* {similar_text}")
|
304 |
|
305 |
+
if need_image == 'True':
|
306 |
+
with st.spinner('Processing Images...'):
|
307 |
+
emb_main , main_prod_imgs = get_image_embeddings(main_product)
|
308 |
+
similar_prod = extract_similar_products(main_product)[0]
|
309 |
+
emb_similar , similar_prod_imgs = get_image_embeddings(similar_prod)
|
310 |
|
311 |
+
similarity_matrix = np.zeros((5, 5))
|
312 |
+
for i in range(5):
|
313 |
+
for j in range(5):
|
314 |
+
similarity_matrix[i][j] = cosine_similarity([emb_main[i]], [emb_similar[j]])[0][0]
|
315 |
|
316 |
+
st.subheader("Image Similarity")
|
317 |
+
# Create an interactive heatmap
|
318 |
+
fig = px.imshow(similarity_matrix,
|
319 |
+
labels=dict(x=f"{similar_prod} Images", y=f"{main_product} Images", color="Similarity"),
|
320 |
+
x=[f"Image {i+1}" for i in range(5)],
|
321 |
+
y=[f"Image {i+1}" for i in range(5)],
|
322 |
+
color_continuous_scale="Viridis")
|
323 |
|
324 |
+
# Add title to the heatmap
|
325 |
+
fig.update_layout(title="Image Similarity Heatmap")
|
326 |
|
327 |
+
# Display the interactive heatmap
|
328 |
+
st.plotly_chart(fig)
|
329 |
|
330 |
|
331 |
|
332 |
+
@st.experimental_fragment
|
333 |
+
def image_viewer():
|
334 |
+
# Form to handle image selection
|
335 |
|
336 |
+
st.subheader("Image Viewer")
|
337 |
|
338 |
+
selected_row = st.selectbox('Select a row (Main Product Image)', [f'Image {i+1}' for i in range(5)])
|
339 |
+
selected_col = st.selectbox('Select a column (Similar Product Image)', [f'Image {i+1}' for i in range(5)])
|
340 |
|
341 |
+
# Get the selected indices from session state
|
342 |
+
row_idx = int(selected_row.split()[1]) - 1
|
343 |
+
col_idx = int(selected_col.split()[1]) - 1
|
344 |
|
345 |
+
col1, col2 = st.columns(2)
|
346 |
|
347 |
+
with col1:
|
348 |
+
st.image(main_prod_imgs[row_idx], caption=f'Main Product Image {row_idx+1}', use_column_width=True)
|
349 |
+
with col2:
|
350 |
+
st.image(similar_prod_imgs[col_idx], caption=f'Similar Product Image {col_idx+1}', use_column_width=True)
|
351 |
|
352 |
+
# Call the fragment
|
353 |
+
image_viewer()
|
354 |
|
355 |
|