# Spaces: Configuration error
import streamlit as st | |
import pandas as pd | |
import numpy as np | |
import faiss | |
import re | |
import ast | |
import os | |
import urllib.request | |
from sentence_transformers import SentenceTransformer | |
from sentence_transformers.util import cos_sim | |
from langchain.chat_models import ChatOpenAI | |
from langchain.agents import initialize_agent, AgentType, tool | |
from streamlit_chat import message | |
# ---------------------------
# Page configuration
# ---------------------------
st.set_page_config(layout="wide", page_title="📱 AI Product Search Agent")
# ---------------------------
# Load model
# ---------------------------
# Streamlit re-executes this whole script on every user interaction, so the
# encoder is cached to be built once per process instead of on every rerun.
@st.cache_resource
def load_model():
    """Return the sentence-embedding model used to encode queries and titles.

    Returns:
        The ``all-MiniLM-L6-v2`` SentenceTransformer instance, shared across
        reruns by ``st.cache_resource``.
    """
    return SentenceTransformer("all-MiniLM-L6-v2")
# ---------------------------
# Load dataset and FAISS index
# ---------------------------
# Cached so the parquet file and the index are not re-read on every Streamlit
# rerun; cache_resource (not cache_data) because the FAISS index is a live
# native object.
@st.cache_resource
def load_data():
    """Load the product metadata table and its prebuilt FAISS index.

    The metadata is read from the first of seven parquet shards of the
    Amazon-Reviews-2023 Cell_Phones_and_Accessories metadata. The FAISS index
    is downloaded once to ``cellphones_index.faiss`` in the working directory
    and reused on later runs.

    Returns:
        (df, index): the metadata DataFrame and the FAISS index.
    """
    parquet_url = "https://huggingface.co/datasets/McAuley-Lab/Amazon-Reviews-2023/resolve/main/raw_meta_Cell_Phones_and_Accessories/full-00000-of-00007.parquet"
    df = pd.read_parquet(parquet_url)

    index_url = "https://huggingface.co/GovinKin/MGTA415database/resolve/main/cellphones_index.faiss"
    local_index_path = "cellphones_index.faiss"
    if not os.path.exists(local_index_path):
        # Download under a temporary name and rename only after success, so an
        # interrupted download never leaves a truncated file that the
        # existence check above would accept on every later run.
        partial_path = local_index_path + ".part"
        try:
            urllib.request.urlretrieve(index_url, partial_path)
            os.replace(partial_path, local_index_path)
        finally:
            if os.path.exists(partial_path):
                os.remove(partial_path)
    index = faiss.read_index(local_index_path)
    # NOTE(review): search() uses FAISS ids as row positions into df, which
    # assumes the index was built over exactly this shard's rows, in this
    # order — confirm against how cellphones_index.faiss was produced.
    return df, index
# ---------------------------
# Search functions
# ---------------------------
def search(query, model, df, index, top_k=10):
    """Return the rows of df nearest to query in the FAISS index.

    Args:
        query: free-text search string.
        model: encoder whose ``encode([query])`` returns an array of vectors.
        df: DataFrame whose row positions are the index's vector ids.
        index: FAISS index; ``index.search(vectors, k=...)`` returns
            ``(distances, ids)`` arrays with one row per query.
        top_k: number of neighbours to request.

    Returns:
        A copy of the matched rows in FAISS result order, with an added
        ``distance`` column. May hold fewer than top_k rows: FAISS pads
        missing results with id -1, and those slots are dropped here.
    """
    query_vector = model.encode([query]).astype("float32")
    distances, ids = index.search(query_vector, k=top_k)
    distances, ids = distances[0], ids[0]
    # An id of -1 means "no result"; passed to iloc it would silently select
    # the last row of df.
    found = ids >= 0
    results = df.iloc[ids[found]].copy()
    results["distance"] = distances[found]
    return results
# Price ceiling such as "under $30" or "below 24.99" in a lowercased query;
# group 2 is the amount. The leading \b keeps words like "thunder" from being
# read as "under".
_PRICE_LIMIT_RE = re.compile(r"\b(under|below)\s*\$?(\d+(?:\.\d+)?)")

# Filler words ignored when turning a query into title keywords.
_SEARCH_STOP_WORDS = frozenset({
    "i", "want", "need", "the", "a", "for", "with", "to", "is", "it",
    "on", "of", "buy", "and", "in",
})


def _parse_price_limit(text):
    """Return the price ceiling stated in lowercase *text*, or None if absent."""
    match = _PRICE_LIMIT_RE.search(text)
    return float(match.group(2)) if match else None


def _extract_keywords(text):
    """Return content words (longer than 2 chars) of lowercase *text*.

    The price phrase is removed first so "under"/"$30" do not become title
    keywords, and punctuation is stripped so "case?" still matches "case".
    """
    without_price = _PRICE_LIMIT_RE.sub(" ", text)
    return [
        word
        for word in re.findall(r"\w[\w-]*", without_price)
        if word not in _SEARCH_STOP_WORDS and len(word) > 2
    ]


def search_plus(query, model, df, index, top_k=20):
    """Semantic search, then an optional price cap and a title keyword filter.

    Args:
        query: free-text search string, e.g. "rugged iphone case under $30".
        model, df, index, top_k: forwarded unchanged to ``search``.

    Returns:
        A subset of ``search`` results.
        - If the query states "under/below <amount>", only rows whose price is
          strictly below the amount are kept. Prices that cannot be parsed as
          numbers are treated as unknown and dropped. The ``price`` column of
          the result then holds the parsed floats.
        - If any keywords remain, only rows whose title contains at least one
          of them (case-insensitive substring) are kept.
    """
    results = search(query, model, df, index, top_k=top_k)
    text = query.lower()

    price_under = _parse_price_limit(text)
    if price_under is not None and "price" in results.columns:
        # Coerce value by value: with astype(float) a single unparseable
        # entry raised and silently disabled the whole price filter.
        prices = pd.to_numeric(results["price"], errors="coerce")
        results = results.assign(price=prices)[prices < price_under]

    keywords = _extract_keywords(text)
    if not results.empty and keywords:
        pattern = "|".join(map(re.escape, keywords))
        titles = results["title"].str.lower()
        results = results[titles.str.contains(pattern, na=False)]
    return results
def rerank_by_similarity(query, results, model, top_n=5):
    """Re-order results by cosine similarity between query and each title.

    Args:
        query: free-text search string.
        results: DataFrame with a ``title`` column (e.g. from search_plus).
        model: encoder accepting ``convert_to_tensor=True`` in ``encode``.
        top_n: number of rows to keep.

    Returns:
        A new DataFrame with an added ``similarity`` column, sorted by
        descending similarity and cut to top_n rows. The input frame is not
        modified. An empty input is returned as-is.
    """
    if results.empty:
        return results
    query_vec = model.encode([query], convert_to_tensor=True)
    titles = results["title"].astype(str).tolist()
    title_vecs = model.encode(titles, convert_to_tensor=True)
    scores = cos_sim(query_vec, title_vecs)[0].cpu().numpy()
    # assign() returns a copy: writing the column in place mutated the
    # caller's frame and, for filtered slices, raised SettingWithCopyWarning.
    ranked = results.assign(similarity=scores)
    return ranked.sort_values("similarity", ascending=False).head(top_n)
# ---------------------------
# Agent Tool: wraps search_plus
# ---------------------------
# initialize_agent() expects LangChain tool objects, not bare functions; @tool
# builds one, using the function name as the tool name and the docstring below
# as the description shown to the LLM (so the docstring is runtime text).
@tool
def product_search_tool(query: str) -> str:
    """Search for cellphone accessories using a natural query."""
    # model, df_all and index are module-level globals assigned further down
    # the script; they are looked up when the agent calls the tool.
    results = search_plus(query, model, df_all, index, top_k=10)
    if results.empty:
        return "No results found."
    # Up to five matching titles, one per line.
    return "\n".join(results["title"].head(5).astype(str).tolist())
# ---------------------------
# Load all resources
# ---------------------------
# Module-level globals read by product_search_tool when the agent invokes it;
# load_model() runs first, then load_data().
model, (df_all, index) = load_model(), load_data()
# ---------------------------
# Agent setup
# ---------------------------
import os
# Read the OpenAI settings from Streamlit secrets. If they are missing, show a
# readable error in the page and halt this run instead of a raw traceback.
try:
    openai_secrets = st.secrets["openai"]
    openai_api_key = openai_secrets["api_key"]
except Exception as err:  # top-level boundary: report and stop the script
    st.error(
        "OpenAI settings not found in Streamlit secrets "
        f"({err}). Add an [openai] section with api_key "
        "(and optionally base_url)."
    )
    st.stop()
os.environ["OPENAI_API_KEY"] = openai_api_key
os.environ["OPENAI_API_BASE"] = openai_secrets.get("base_url", "https://api.openai.com/v1")
llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0.3)
agent = initialize_agent(
    tools=[product_search_tool],
    llm=llm,
    agent=AgentType.OPENAI_FUNCTIONS,
    verbose=True,
)
# ---------------------------
# Streamlit Chat Interface
# ---------------------------
st.title("🤖 AI Product Search Agent")
st.markdown("Ask natural questions like 'cheap rugged iPhone case under $30'")
if "chat_history" not in st.session_state:
    # List of (role, text) tuples, role being "user" or "agent"; stored in
    # session state so the conversation persists across script reruns.
    st.session_state.chat_history = []
user_input = st.chat_input("Ask about cellphone accessories...")
if user_input:
    st.session_state.chat_history.append(("user", user_input))
    with st.spinner("Agent is thinking..."):
        try:
            reply = agent.run(user_input)
        except Exception as e:
            # Surface agent/LLM failures as a chat message instead of crashing.
            reply = f"⚠️ Agent error: {e}"
    st.session_state.chat_history.append(("agent", reply))
# Each streamlit_chat message needs its own key: without one, two identical
# messages (same question asked twice, repeated "No results found.") get the
# same auto-generated element ID and Streamlit rejects the duplicate. The
# history position is unique and stable because the list is append-only.
for position, (role, msg) in enumerate(st.session_state.chat_history):
    message(msg, is_user=(role == "user"), key=f"chat_msg_{position}")