Spaces:
Configuration error
Configuration error
File size: 4,710 Bytes
c2f3c5f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 |
import streamlit as st
import pandas as pd
import numpy as np
import faiss
import re
import ast
import os
import urllib.request
from sentence_transformers import SentenceTransformer
from sentence_transformers.util import cos_sim
from langchain.chat_models import ChatOpenAI
from langchain.agents import initialize_agent, AgentType, tool
from streamlit_chat import message
# ---------------------------
# Configuration
# ---------------------------
# Page-level settings: browser-tab title and a full-width layout.
# NOTE(review): Streamlit documents set_page_config as needing to be the
# first st.* command of the script, so keep it above any other UI call.
st.set_page_config(
    layout="wide",
    page_title="📱 AI Product Search Agent",
)
# ---------------------------
# Load model
# ---------------------------
@st.cache_resource
def load_model():
    """Construct the sentence-embedding model used for queries and titles.

    Wrapped in ``st.cache_resource`` so Streamlit reruns reuse the same
    model object instead of rebuilding it.

    NOTE(review): query vectors are searched against a prebuilt FAISS index,
    so this must be the same encoder the index was built with — confirm.
    """
    embedder_name = "all-MiniLM-L6-v2"
    embedder = SentenceTransformer(embedder_name)
    return embedder
# ---------------------------
# Load dataset and FAISS index
# ---------------------------
# cache_resource (not cache_data): cache_data pickles the result and returns a
# fresh deserialized copy on every rerun — i.e. on every chat message — which
# is expensive for a whole parquet shard and also requires the FAISS index to
# be picklable. Callers treat the frame as read-only (search() copies rows),
# so sharing one instance is safe.
@st.cache_resource
def load_data():
    """Load the product metadata table and the matching FAISS index.

    The parquet shard is read straight from the Hugging Face URL. The FAISS
    index is downloaded once to ``cellphones_index.faiss`` in the working
    directory and read from there on later calls.

    NOTE(review): only shard ``00000`` of 7 is loaded; search() treats FAISS
    vector ids as row positions of this frame, so the index is assumed to
    have been built from exactly this shard — confirm.

    Returns:
        ``(df, index)``: the metadata DataFrame and the loaded FAISS index.
    """
    parquet_url = "https://huggingface.co/datasets/McAuley-Lab/Amazon-Reviews-2023/resolve/main/raw_meta_Cell_Phones_and_Accessories/full-00000-of-00007.parquet"
    df = pd.read_parquet(parquet_url)

    index_url = "https://huggingface.co/GovinKin/MGTA415database/resolve/main/cellphones_index.faiss"
    local_index_path = "cellphones_index.faiss"
    if not os.path.exists(local_index_path):
        # Download to a temporary name and rename only once complete: an
        # interrupted download must not leave a truncated index behind, since
        # the exists() check above would then skip re-downloading forever.
        partial_path = local_index_path + ".part"
        try:
            urllib.request.urlretrieve(index_url, partial_path)
            os.replace(partial_path, local_index_path)
        finally:
            if os.path.exists(partial_path):
                os.remove(partial_path)
    index = faiss.read_index(local_index_path)
    return df, index
# ---------------------------
# Search functions
# ---------------------------
def search(query, model, df, index, top_k=10):
    """Return the rows of *df* whose vectors FAISS ranks closest to *query*.

    Args:
        query: Free-text search string.
        model: Embedding model exposing ``encode(list_of_str)``; its output is
            cast to float32 before searching.
        df: Product table; FAISS vector ids are used as positional row
            indices into it (NOTE(review): assumes the index was built from
            this exact frame, in row order).
        index: FAISS index to search.
        top_k: Number of neighbours to request.

    Returns:
        A copy of the matched rows, in the order FAISS returned them, with an
        added ``distance`` column holding FAISS's score for each hit. Fewer
        than ``top_k`` rows come back when FAISS reports empty slots.
    """
    query_vector = model.encode([query]).astype("float32")
    distances, indices = index.search(query_vector, k=top_k)
    # FAISS pads the result with id -1 when it has fewer than k neighbours.
    # Passed to iloc, -1 would silently select the *last* row, so drop them.
    valid = indices[0] >= 0
    results = df.iloc[indices[0][valid]].copy()
    results["distance"] = distances[0][valid]
    return results
# Price cap such as "under $30", "below 25" or "under $19.99". The leading
# \b stops words like "thunder 5" from being read as a cap.
_PRICE_CAP_PATTERN = re.compile(r"\b(under|below)\s*\$?(\d+(?:\.\d+)?)")

# Filler words ignored when picking title keywords out of the query.
_QUERY_STOP_WORDS = frozenset({
    "i", "want", "need", "the", "a", "for", "with", "to", "is", "it",
    "on", "of", "buy", "and", "in",
})

# Punctuation trimmed from both ends of a query token ("iphone," -> "iphone").
_TOKEN_EDGE_PUNCTUATION = "\"'.,!?;:()[]{}"


def _coerce_price(prices):
    """Return *prices* as floats, with unparseable entries as NaN.

    Non-numeric columns are first stripped of ``$`` signs, thousands
    separators and surrounding whitespace so values like ``"$1,299.00"``
    still parse.
    """
    if not pd.api.types.is_numeric_dtype(prices):
        prices = prices.astype(str).str.replace(r"[$,]", "", regex=True).str.strip()
    return pd.to_numeric(prices, errors="coerce")


def search_plus(query, model, df, index, top_k=20):
    """Run ``search`` and narrow its hits with two query heuristics.

    * Price cap: if the query contains "under N" / "below N" (optional ``$``,
      optional decimals), only rows whose ``price`` is strictly below N are
      kept. Rows with a missing or unparseable price are dropped as well,
      since they cannot be shown to meet the cap. In that case the returned
      ``price`` column is numeric. No cap is applied if the frame has no
      ``price`` column.
    * Keywords: the remaining query words (lower-cased, edge punctuation
      trimmed, stop words, words of two characters or fewer, and the price
      phrase removed) are OR-ed together. A row is kept if its lower-cased
      title contains any of them as a substring.

    Args:
        query: Natural-language query.
        model, df, index: Passed through to ``search``.
        top_k: Number of vector-search candidates to filter.

    Returns:
        The filtered subset of the search results (possibly empty).
    """
    results = search(query, model, df, index, top_k=top_k)
    query_lower = query.lower()

    price_match = _PRICE_CAP_PATTERN.search(query_lower)
    if price_match is not None and "price" in results.columns:
        price_under = float(price_match.group(2))
        results = results.assign(price=_coerce_price(results["price"]))
        # NaN prices compare False and are therefore filtered out here.
        results = results[results["price"] < price_under]

    # The price phrase itself ("under $30") is not a title keyword.
    keyword_text = query_lower
    if price_match is not None:
        keyword_text = f"{query_lower[:price_match.start()]} {query_lower[price_match.end():]}"
    tokens = (tok.strip(_TOKEN_EDGE_PUNCTUATION) for tok in keyword_text.split())
    keywords = [tok for tok in tokens if tok not in _QUERY_STOP_WORDS and len(tok) > 2]

    if not results.empty and keywords:
        pattern = "|".join(map(re.escape, keywords))
        results = results[results["title"].str.lower().str.contains(pattern, na=False)]
    return results
def rerank_by_similarity(query, results, model, top_n=5):
    """Re-order *results* by cosine similarity between *query* and each title.

    Titles are stringified, embedded with *model* and scored against the
    query embedding with ``cos_sim``.

    Args:
        query: Query text.
        results: Candidate rows with a ``title`` column.
        model: Embedding model exposing ``encode(..., convert_to_tensor=True)``.
        top_n: Number of rows to keep.

    Returns:
        *results* unchanged if it is empty. Otherwise a NEW DataFrame with an
        added ``similarity`` column, sorted by it in descending order and cut
        to ``top_n`` rows. The caller's DataFrame is not modified.
    """
    if results.empty:
        return results
    query_vec = model.encode([query], convert_to_tensor=True)
    titles = results["title"].astype(str).tolist()
    title_vecs = model.encode(titles, convert_to_tensor=True)
    scores = cos_sim(query_vec, title_vecs)[0].cpu().numpy()
    # assign() builds a new frame. Writing results["similarity"] would mutate
    # the caller's object, and search_plus hands back filtered slices, for
    # which that write triggers pandas' SettingWithCopyWarning.
    reranked = results.assign(similarity=scores)
    return reranked.sort_values("similarity", ascending=False).head(top_n)
# ---------------------------
# Agent Tool: wraps search_plus
# ---------------------------
@tool
def product_search_tool(query: str) -> str:
    """Search for cellphone accessories using a natural query."""
    # The docstring above is passed to LangChain's @tool and becomes the tool
    # description the LLM sees, so implementation notes live in comments.
    #
    # Uses the module-level model / df_all / index bound in the
    # "Load all resources" section, which runs before the agent is built.
    results = search_plus(query, model, df_all, index, top_k=10)
    # Missing titles (NaN/None) would make str.join raise TypeError.
    # search_plus only drops them when the query yields keywords, so drop
    # them here and stringify whatever remains.
    titles = results["title"].dropna().astype(str).head(5).tolist()
    if not titles:
        return "No results found."
    return "\n".join(titles)
# ---------------------------
# Load all resources
# ---------------------------
# Both loaders are wrapped in Streamlit cache decorators, so after the first
# run these calls are answered from the cache. The right-hand side is
# evaluated left to right: model first, then dataset + index.
model, (df_all, index) = load_model(), load_data()
# ---------------------------
# Agent setup
# ---------------------------
# Credentials come from Streamlit secrets: an [openai] table with api_key and
# an optional base_url. Show a readable message and stop the script instead
# of crashing with a raw traceback when they are not configured.
try:
    openai_secrets = st.secrets["openai"]
    openai_api_key = openai_secrets["api_key"]
except (KeyError, FileNotFoundError) as err:
    st.error(
        "OpenAI credentials are not configured: add an [openai] section "
        f"with api_key (and optionally base_url) to the app secrets. ({err})"
    )
    st.stop()

# ChatOpenAI below is built without explicit credentials; it is expected to
# pick these environment variables up.
os.environ["OPENAI_API_KEY"] = openai_api_key
os.environ["OPENAI_API_BASE"] = openai_secrets.get("base_url", "https://api.openai.com/v1")

llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0.3)
# A single-tool agent: the LLM decides when to call product_search_tool via
# OpenAI function calling.
agent = initialize_agent(
    tools=[product_search_tool],
    llm=llm,
    agent=AgentType.OPENAI_FUNCTIONS,
    verbose=True,
)
# ---------------------------
# Streamlit Chat Interface
# ---------------------------
st.title("🤖 AI Product Search Agent")
st.markdown("Ask natural questions like 'cheap rugged iPhone case under $30'")

# Conversation as a list of (role, text) tuples; role is "user" or "agent".
# Kept in session_state so it survives Streamlit's rerun on every input.
if "chat_history" not in st.session_state:
    st.session_state.chat_history = []

user_input = st.chat_input("Ask about cellphone accessories...")
if user_input:
    st.session_state.chat_history.append(("user", user_input))
    with st.spinner("Agent is thinking..."):
        try:
            reply = agent.run(user_input)
        except Exception as e:  # UI boundary: report any agent failure in-chat
            reply = f"⚠️ Agent error: {e}"
    st.session_state.chat_history.append(("agent", reply))

# streamlit_chat's message() is a Streamlit component. Without an explicit
# key, two messages with identical text and role (e.g. the same question
# asked twice) produce the same auto-generated element id and collide. The
# position in the history is unique and stable across reruns.
for position, (role, msg) in enumerate(st.session_state.chat_history):
    message(msg, is_user=(role == "user"), key=f"chat_message_{position}")
|