"""Streamlit chat app: an LLM agent that searches cellphone accessories.

The agent is given one tool, ``product_search_tool``, which embeds the
query with a SentenceTransformer model, looks up nearest neighbours in a
prebuilt FAISS index, and post-filters the hits by price and title keywords.
"""

import ast
import os
import re
import urllib.request

import faiss
import numpy as np
import pandas as pd
import streamlit as st
from langchain.agents import AgentType, initialize_agent, tool
from langchain.chat_models import ChatOpenAI
from sentence_transformers import SentenceTransformer
from sentence_transformers.util import cos_sim
from streamlit_chat import message

# ---------------------------
# Configuration
# ---------------------------
st.set_page_config(page_title="📱 AI Product Search Agent", layout="wide")

PARQUET_URL = (
    "https://huggingface.co/datasets/McAuley-Lab/Amazon-Reviews-2023/resolve/main/"
    "raw_meta_Cell_Phones_and_Accessories/full-00000-of-00007.parquet"
)
INDEX_URL = "https://huggingface.co/GovinKin/MGTA415database/resolve/main/cellphones_index.faiss"
LOCAL_INDEX_PATH = "cellphones_index.faiss"

# Words ignored when turning a query into title keywords.
_STOP_WORDS = frozenset(
    {"i", "want", "need", "the", "a", "for", "with", "to", "is", "it", "on", "of", "buy", "and", "in"}
)

# Matches "under $30", "below 29.99", ... in an already lower-cased query.
_PRICE_LIMIT_RE = re.compile(r"(under|below)\s*\$?(\d+(?:\.\d+)?)")


# ---------------------------
# Load model
# ---------------------------
@st.cache_resource
def load_model():
    """Return the sentence-embedding model (cached once per server process)."""
    return SentenceTransformer("all-MiniLM-L6-v2")


# ---------------------------
# Load dataset and FAISS index
# ---------------------------
# cache_resource (not cache_data): cache_data serializes and copies the return
# value on every call, which is wasteful for a large DataFrame plus a FAISS
# index. The cached objects are shared, so callers must not mutate them in
# place (search() below works on a copy).
@st.cache_resource
def load_data():
    """Load the product metadata and the FAISS index.

    The index file is downloaded on first use and reused from disk afterwards.

    Returns:
        A ``(df, index)`` tuple: the metadata DataFrame and the FAISS index.
    """
    df = pd.read_parquet(PARQUET_URL)

    if not os.path.exists(LOCAL_INDEX_PATH):
        # Download to a temporary name and rename atomically so an interrupted
        # download never leaves a truncated file that later runs would accept.
        partial_path = LOCAL_INDEX_PATH + ".part"
        urllib.request.urlretrieve(INDEX_URL, partial_path)
        os.replace(partial_path, LOCAL_INDEX_PATH)

    index = faiss.read_index(LOCAL_INDEX_PATH)
    return df, index


# ---------------------------
# Search functions
# ---------------------------
def search(query: str, model, df: pd.DataFrame, index, top_k: int = 10) -> pd.DataFrame:
    """Return the rows of ``df`` nearest to ``query`` in the FAISS index.

    Args:
        query: Free-text search query.
        model: Embedding model exposing ``encode``.
        df: Metadata frame; index labels are used as row positions in it.
        index: FAISS index to search.
        top_k: Maximum number of neighbours to retrieve.

    Returns:
        A copy of the matching rows with an added ``distance`` column, in the
        order returned by the index. May contain fewer than ``top_k`` rows.
    """
    query_vector = model.encode([query]).astype("float32")
    distances, indices = index.search(query_vector, k=top_k)
    hit_ids = indices[0]
    hit_distances = distances[0]

    # FAISS pads missing neighbours with label -1, which iloc would silently
    # interpret as "last row"; labels beyond the frame would raise IndexError.
    valid = (hit_ids >= 0) & (hit_ids < len(df))

    results = df.iloc[hit_ids[valid]].copy()
    results["distance"] = hit_distances[valid]
    return results


def search_plus(query: str, model, df: pd.DataFrame, index, top_k: int = 20) -> pd.DataFrame:
    """Run :func:`search`, then filter by price limit and title keywords.

    A clause such as ``under $30`` / ``below 29.99`` in the query keeps only
    rows whose ``price`` is strictly below that amount. Remaining query words
    (minus stop words and words of two characters or fewer) are matched
    against the lower-cased ``title``; a row is kept if it contains any of them.

    Args:
        query: Free-text search query.
        model: Embedding model exposing ``encode``.
        df: Metadata frame passed through to :func:`search`.
        index: FAISS index passed through to :func:`search`.
        top_k: Number of neighbours to retrieve before filtering.

    Returns:
        The filtered result rows (possibly empty).
    """
    results = search(query, model, df, index, top_k=top_k)
    keyword_text = query.lower()

    price_match = _PRICE_LIMIT_RE.search(keyword_text)
    if price_match is not None:
        price_under = float(price_match.group(2))
        if "price" in results.columns:
            # Prices may be non-numeric strings (presumably values like "None"
            # or "$12.99" - verify against the dataset). Coerce per value so a
            # single bad entry cannot disable the whole filter.
            cleaned = results["price"].astype(str).str.replace(r"[$,]", "", regex=True)
            prices = pd.to_numeric(cleaned, errors="coerce")
            results = results.assign(price=prices)
            # NaN compares False, so rows without a usable price are dropped.
            results = results[(prices < price_under).to_numpy()]
        # The price clause is a constraint, not a product keyword: keep
        # "under" / "$30" out of the title filter below.
        keyword_text = keyword_text[: price_match.start()] + " " + keyword_text[price_match.end():]

    keywords = [kw for kw in keyword_text.split() if kw not in _STOP_WORDS and len(kw) > 2]
    if not results.empty and keywords:
        pattern = "|".join(map(re.escape, keywords))
        results = results[results["title"].str.lower().str.contains(pattern, na=False)]

    return results


def rerank_by_similarity(query: str, results: pd.DataFrame, model, top_n: int = 5) -> pd.DataFrame:
    """Return the ``top_n`` rows whose titles are most similar to ``query``.

    Similarity is the cosine similarity between the query embedding and each
    title embedding. The input frame is left unmodified.

    Args:
        query: Free-text search query.
        results: Candidate rows; must have a ``title`` column.
        model: Embedding model exposing ``encode``.
        top_n: Number of rows to keep.

    Returns:
        A new frame with a ``similarity`` column, sorted by it in descending
        order. An empty input is returned as is.
    """
    if results.empty:
        return results

    query_vec = model.encode([query], convert_to_tensor=True)
    titles = results["title"].astype(str).tolist()
    title_vecs = model.encode(titles, convert_to_tensor=True)
    scores = cos_sim(query_vec, title_vecs)[0].cpu().numpy()

    # assign() builds a new frame instead of adding a column to the caller's.
    ranked = results.assign(similarity=scores)
    return ranked.sort_values("similarity", ascending=False).head(top_n)


# ---------------------------
# Agent Tool: wraps search_plus
# ---------------------------
# NOTE: the docstring below is passed to the agent as the tool description,
# so it is part of runtime behaviour. `model`, `df_all` and `index` are
# module globals assigned further down, before the tool can be invoked.
@tool
def product_search_tool(query: str) -> str:
    """Search for cellphone accessories using a natural query."""
    results = search_plus(query, model, df_all, index, top_k=10)
    if results.empty:
        return "No results found."
    return "\n".join(results["title"].head(5).tolist())


# ---------------------------
# Load all resources
# ---------------------------
model = load_model()
df_all, index = load_data()

# ---------------------------
# Agent setup
# ---------------------------
os.environ["OPENAI_API_KEY"] = st.secrets["openai"]["api_key"]
os.environ["OPENAI_API_BASE"] = st.secrets["openai"].get("base_url", "https://api.openai.com/v1")

llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0.3)
agent = initialize_agent(
    tools=[product_search_tool],
    llm=llm,
    agent=AgentType.OPENAI_FUNCTIONS,
    verbose=True,
)

# ---------------------------
# Streamlit Chat Interface
# ---------------------------
st.title("🤖 AI Product Search Agent")
st.markdown("Ask natural questions like 'cheap rugged iPhone case under $30'")

if "chat_history" not in st.session_state:
    st.session_state.chat_history = []

user_input = st.chat_input("Ask about cellphone accessories...")
if user_input:
    st.session_state.chat_history.append(("user", user_input))
    with st.spinner("Agent is thinking..."):
        try:
            reply = agent.run(user_input)
        except Exception as e:  # UI boundary: show the failure in the chat instead of crashing
            reply = f"⚠️ Agent error: {e}"
    st.session_state.chat_history.append(("agent", reply))

# Give every message an explicit key derived from its position: without one,
# two identical messages in the history would produce colliding component IDs.
for position, (role, msg) in enumerate(st.session_state.chat_history):
    message(msg, is_user=(role == "user"), key=f"chat_{position}")