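# ---------------------------
# AI Product Search Agent
# ---------------------------
# Streamlit chat app for searching Amazon cellphone-accessory listings:
# user queries are embedded with SentenceTransformers (all-MiniLM-L6-v2),
# matched against a prebuilt FAISS index over the product metadata, filtered
# by price/keywords, and exposed to a LangChain OpenAI-functions agent as a tool.
# Run with `streamlit run <this file>`; OpenAI credentials are read from st.secrets.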
import streamlit as st
import pandas as pd
import numpy as np
import faiss
import re
import ast
import os
import urllib.request
from sentence_transformers import SentenceTransformer
from sentence_transformers.util import cos_sim
from langchain.chat_models import ChatOpenAI
from langchain.agents import initialize_agent, AgentType, tool
from streamlit_chat import message
# ---------------------------
# Configuration
# ---------------------------
st.set_page_config(page_title="📱 AI Product Search Agent", layout="wide")
# ---------------------------
# Load model
# ---------------------------
@st.cache_resource
def load_model():
    return SentenceTransformer("all-MiniLM-L6-v2")
# ---------------------------
# Load dataset and FAISS index
# ---------------------------
@st.cache_resource  # cache_resource: a FAISS index is not picklable, so cache_data cannot serialize it
def load_data():
    # Product metadata from the Amazon Reviews 2023 "Cell Phones and Accessories" split
    parquet_url = "https://huggingface.co/datasets/McAuley-Lab/Amazon-Reviews-2023/resolve/main/raw_meta_Cell_Phones_and_Accessories/full-00000-of-00007.parquet"
    df = pd.read_parquet(parquet_url)
    # Prebuilt FAISS index hosted on the Hub; download once and reuse the local copy
    index_url = "https://huggingface.co/GovinKin/MGTA415database/resolve/main/cellphones_index.faiss"
    local_index_path = "cellphones_index.faiss"
    if not os.path.exists(local_index_path):
        urllib.request.urlretrieve(index_url, local_index_path)
    index = faiss.read_index(local_index_path)
    return df, index
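# The search() helper below maps FAISS result IDs back to dataframe rows with
# df.iloc, so the downloaded index is assumed to hold one vector per parquet row,
# in the same order and produced by the same embedding model. As an illustration
# only, an L2 index with that layout could be built offline roughly like the
# sketch below; the actual cellphones_index.faiss on the Hub may have been built
# differently, and this helper is never called by the app.
def _build_index_sketch(df, model, out_path="cellphones_index.faiss"):
    titles = df["title"].astype(str).tolist()
    vectors = model.encode(titles).astype("float32")
    idx = faiss.IndexFlatL2(vectors.shape[1])  # exact L2 search over title embeddings
    idx.add(vectors)
    faiss.write_index(idx, out_path)
    return idx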
# ---------------------------
# Search functions
# ---------------------------
def search(query, model, df, index, top_k=10):
    # Embed the query and pull the top_k nearest neighbours from the FAISS index
    query_vector = model.encode([query]).astype("float32")
    distances, indices = index.search(query_vector, k=top_k)
    results = df.iloc[indices[0]].copy()
    results["distance"] = distances[0]
    return results
def search_plus(query, model, df, index, top_k=20):
    results = search(query, model, df, index, top_k=top_k)
    # Pull an upper price bound out of phrases like "under $30" or "below 30"
    price_match = re.search(r"(under|below)\s*\$?(\d+)", query.lower())
    price_under = float(price_match.group(2)) if price_match else None
    if price_under:
        try:
            results["price"] = results["price"].astype(float)
            results = results[results["price"] < price_under]
        except (ValueError, TypeError):
            # Prices that cannot be parsed as numbers: skip the price filter
            pass
    # Keep only results whose title mentions at least one meaningful query keyword
    stop_words = {"i", "want", "need", "the", "a", "for", "with", "to", "is", "it", "on", "of", "buy", "and", "in"}
    keywords = [kw for kw in query.lower().split() if kw not in stop_words and len(kw) > 2]
    if not results.empty and keywords:
        pattern = "|".join(map(re.escape, keywords))
        results = results[results["title"].str.lower().str.contains(pattern, na=False)]
    return results
def rerank_by_similarity(query, results, model, top_n=5):
    # Re-score candidates by cosine similarity between the query and product titles
    # (defined for finer ranking; not called by the agent tool below)
    if results.empty:
        return results
    query_vec = model.encode([query], convert_to_tensor=True)
    titles = results["title"].astype(str).tolist()
    title_vecs = model.encode(titles, convert_to_tensor=True)
    scores = cos_sim(query_vec, title_vecs)[0].cpu().numpy()
    results["similarity"] = scores
    return results.sort_values("similarity", ascending=False).head(top_n)
# ---------------------------
# Agent Tool: wraps search_plus
# ---------------------------
@tool
def product_search_tool(query: str) -> str:
    """Search for cellphone accessories using a natural-language query."""
    # Relies on the module-level model, df_all, and index loaded below
    results = search_plus(query, model, df_all, index, top_k=10)
    if results.empty:
        return "No results found."
    return "\n".join(results["title"].head(5).tolist())
# ---------------------------
# Load all resources
# ---------------------------
model = load_model()
df_all, index = load_data()
# ---------------------------
# Agent setup
# ---------------------------
os.environ["OPENAI_API_KEY"] = st.secrets["openai"]["api_key"]
os.environ["OPENAI_API_BASE"] = st.secrets["openai"].get("base_url", "https://api.openai.com/v1")
llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0.3)
# OpenAI function-calling agent that can decide when to call product_search_tool
agent = initialize_agent(
    tools=[product_search_tool],
    llm=llm,
    agent=AgentType.OPENAI_FUNCTIONS,
    verbose=True
)
# ---------------------------
# Streamlit Chat Interface
# ---------------------------
st.title("🤖 AI Product Search Agent")
st.markdown("Ask natural questions like 'cheap rugged iPhone case under $30'")
if "chat_history" not in st.session_state:
st.session_state.chat_history = []
user_input = st.chat_input("Ask about cellphone accessories...")
if user_input:
st.session_state.chat_history.append(("user", user_input))
with st.spinner("Agent is thinking..."):
try:
reply = agent.run(user_input)
except Exception as e:
reply = f"⚠️ Agent error: {e}"
st.session_state.chat_history.append(("agent", reply))
for role, msg in st.session_state.chat_history:
message(msg, is_user=(role == "user"))