File size: 4,710 Bytes
c2f3c5f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
import streamlit as st
import pandas as pd
import numpy as np
import faiss
import re
import ast
import os
import urllib.request

from sentence_transformers import SentenceTransformer
from sentence_transformers.util import cos_sim
from langchain.chat_models import ChatOpenAI
from langchain.agents import initialize_agent, AgentType, tool
from streamlit_chat import message

# ---------------------------
# Configuration
# ---------------------------
# Browser-tab title and full-width page layout for the app.
# NOTE(review): older Streamlit releases require set_page_config to be the
# first Streamlit command in the script; keep it above any other st.* call.
st.set_page_config(page_title="📱 AI Product Search Agent", layout="wide")

# ---------------------------
# Load model
# ---------------------------
@st.cache_resource
def load_model():
    """Create the sentence-embedding model, cached as a shared Streamlit resource.

    Returns:
        The ``all-MiniLM-L6-v2`` SentenceTransformer used to embed text.
    """
    checkpoint = "all-MiniLM-L6-v2"
    encoder = SentenceTransformer(checkpoint)
    return encoder

# ---------------------------
# Load dataset and FAISS index
# ---------------------------
@st.cache_resource
def load_data():
    """Load the product metadata table and its FAISS index.

    The metadata is read straight from the Hugging Face parquet URL. The index
    file is downloaded once into the working directory and reused afterwards.

    Returns:
        (df, index): the metadata DataFrame and the FAISS index. FAISS ids are
        later mapped to rows positionally (``df.iloc``), so the index must have
        been built from exactly these rows, in this order.

    Raises:
        Any error from reading the parquet file, downloading the index or
        ``faiss.read_index`` propagates unchanged.
    """
    # cache_resource instead of cache_data: cache_data pickles the return value
    # and hands back a fresh copy on every call (i.e. every chat rerun), which
    # is expensive for this DataFrame and depends on the FAISS index being
    # picklable. The objects are only read by the rest of this script.
    parquet_url = "https://huggingface.co/datasets/McAuley-Lab/Amazon-Reviews-2023/resolve/main/raw_meta_Cell_Phones_and_Accessories/full-00000-of-00007.parquet"
    df = pd.read_parquet(parquet_url)
    # NOTE(review): this is only shard 00000 of 00007 — confirm the FAISS index
    # was built over this shard alone, otherwise df.iloc lookups misalign.

    index_url = "https://huggingface.co/GovinKin/MGTA415database/resolve/main/cellphones_index.faiss"
    local_index_path = "cellphones_index.faiss"
    if not os.path.exists(local_index_path):
        # Download under a temporary name and rename into place only once the
        # transfer has completed: an interrupted download must not leave a
        # truncated index that the exists() check would accept on every
        # subsequent run.
        tmp_path = local_index_path + ".part"
        try:
            urllib.request.urlretrieve(index_url, tmp_path)
            os.replace(tmp_path, local_index_path)
        finally:
            # Only present if the download or rename failed.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)

    index = faiss.read_index(local_index_path)
    return df, index

# ---------------------------
# Search functions
# ---------------------------
def search(query, model, df, index, top_k=10):
    """Return the rows of *df* whose vectors are nearest to *query* in *index*.

    Args:
        query: Free-text search string.
        model: Encoder exposing ``encode(list_of_str)`` that returns an array
            of embedding vectors (one row per input string).
        df: Product table. FAISS ids are used as *positional* row offsets
            (``df.iloc``), so row i of *df* must correspond to vector i of
            *index*.
        index: FAISS index built over the same rows as *df*.
        top_k: Maximum number of neighbours to request.

    Returns:
        A copy of the matching rows, in FAISS rank order, with an added
        ``distance`` column holding the raw score FAISS returned (L2 distance
        or inner product, depending on the index type). May contain fewer than
        ``top_k`` rows when FAISS could not fill every result slot.
    """
    query_vector = model.encode([query]).astype("float32")
    distances, ids = index.search(query_vector, k=top_k)
    # FAISS pads unfilled result slots with id -1; df.iloc[-1] would silently
    # return the last row of the table, so drop those slots instead.
    valid = ids[0] >= 0
    results = df.iloc[ids[0][valid]].copy()
    results["distance"] = distances[0][valid]
    return results

def search_plus(query, model, df, index, top_k=20):
    """Semantic search followed by an optional price cap and a title keyword filter.

    1. Retrieve ``top_k`` candidates with ``search``.
    2. If the query contains "under", "below" or "less than" followed by a
       number (optionally ``$``-prefixed, decimals allowed), keep only rows
       whose ``price`` parses to a number strictly below it. Rows with a
       missing or non-numeric price are dropped because they cannot be shown
       to satisfy the cap. Skipped when there is no ``price`` column.
    3. Keep rows whose lower-cased ``title`` contains at least one query word.
       Stop words, words of two characters or fewer, and the price phrase
       itself are ignored; surrounding punctuation is stripped from words.

    Args:
        query: Free-text user query.
        model: Encoder passed through to ``search``.
        df: Product table passed through to ``search``.
        index: FAISS index passed through to ``search``.
        top_k: Number of candidates retrieved before filtering.

    Returns:
        The surviving candidate rows (possibly empty).
    """
    lowered = query.lower()
    results = search(query, model, df, index, top_k=top_k)

    # \b keeps words such as "thunder" from matching "under".
    price_match = re.search(r"\b(under|below|less than)\s*\$?(\d+(?:\.\d+)?)", lowered)
    price_under = float(price_match.group(2)) if price_match else None

    # `is not None` so that an explicit cap of 0 is still honoured.
    if price_under is not None and "price" in results.columns:
        # astype(float) raises on the first unparseable value, which used to
        # disable the whole filter; coercing turns bad values into NaN, and
        # NaN fails the comparison so those rows are excluded.
        # NOTE(review): assumes prices are plain numbers / numeric strings
        # (no currency symbols) — confirm against the dataset.
        prices = pd.to_numeric(results["price"], errors="coerce")
        results = results[prices < price_under]

    # Remove the price phrase so words like "under" or "$30" don't become
    # title keywords.
    keyword_text = lowered
    if price_match:
        keyword_text = lowered[:price_match.start()] + " " + lowered[price_match.end():]

    stop_words = {"i", "want", "need", "the", "a", "for", "with", "to", "is", "it", "on", "of", "buy", "and", "in"}
    words = (token.strip(".,!?;:'\"()") for token in keyword_text.split())
    keywords = [word for word in words if word not in stop_words and len(word) > 2]

    if not results.empty and keywords:
        # Alternation pattern: a single matching keyword is enough to keep a row.
        pattern = "|".join(map(re.escape, keywords))
        results = results[results["title"].str.lower().str.contains(pattern, na=False)]

    return results

def rerank_by_similarity(query, results, model, top_n=5):
    """Re-order candidate rows by cosine similarity between *query* and each title.

    Args:
        query: Free-text user query.
        results: Candidate rows with a ``title`` column. Not modified.
        model: SentenceTransformer used to embed both the query and the titles.
        top_n: Number of best-scoring rows to keep.

    Returns:
        A new DataFrame with an added ``similarity`` column, sorted by it in
        descending order and truncated to ``top_n`` rows. An empty *results*
        is returned as-is.
    """
    if results.empty:
        return results
    query_vec = model.encode([query], convert_to_tensor=True)
    titles = results["title"].astype(str).tolist()
    title_vecs = model.encode(titles, convert_to_tensor=True)
    # cos_sim yields a (num_queries, num_titles) matrix; take the single query row.
    scores = cos_sim(query_vec, title_vecs)[0].cpu().numpy()
    # assign() builds a new frame. Writing the column in place would mutate the
    # caller's DataFrame and, when *results* is a filtered slice (as produced
    # by boolean indexing), trigger pandas' SettingWithCopyWarning.
    return (
        results.assign(similarity=scores)
        .sort_values("similarity", ascending=False)
        .head(top_n)
    )

# ---------------------------
# Agent Tool: wraps search_plus
# ---------------------------
@tool
def product_search_tool(query: str) -> str:
    """Search for cellphone accessories using a natural query."""
    # The docstring above is the tool description the LLM sees; keep it as is
    # unless the tool's contract changes.
    # `model`, `df_all` and `index` are module globals assigned later in the
    # script, before the agent can invoke this tool.
    matches = search_plus(query, model, df_all, index, top_k=10)
    if not matches.empty:
        top_titles = matches["title"].head(5).tolist()
        return "\n".join(top_titles)
    return "No results found."

# ---------------------------
# Load all resources
# ---------------------------
# Both loaders are Streamlit-cached, so reruns of this script (one per chat
# interaction) reuse the loaded model, table and index instead of reloading.
# `model`, `df_all` and `index` are read as globals by product_search_tool.
model = load_model()
df_all, index = load_data()

# ---------------------------
# Agent setup
# ---------------------------
# `os` is already imported at the top of the file.
openai_cfg = st.secrets["openai"]
# ChatOpenAI is constructed below without explicit credentials, so it relies
# on these environment variables for the key and the API endpoint.
os.environ["OPENAI_API_KEY"] = openai_cfg["api_key"]
os.environ["OPENAI_API_BASE"] = openai_cfg.get("base_url", "https://api.openai.com/v1")

chat_model_kwargs = {"model": "gpt-3.5-turbo", "temperature": 0.3}
llm = ChatOpenAI(**chat_model_kwargs)

# Function-calling agent whose only tool is the product search.
agent = initialize_agent(
    tools=[product_search_tool],
    llm=llm,
    agent=AgentType.OPENAI_FUNCTIONS,
    verbose=True,
)

# ---------------------------
# Streamlit Chat Interface
# ---------------------------
st.title("🤖 AI Product Search Agent")
st.markdown("Ask natural questions like 'cheap rugged iPhone case under $30'")

# Conversation log kept across reruns: a list of (role, text) tuples where
# role is "user" or "agent".
if "chat_history" not in st.session_state:
    st.session_state.chat_history = []

user_input = st.chat_input("Ask about cellphone accessories...")

if user_input:
    st.session_state.chat_history.append(("user", user_input))
    with st.spinner("Agent is thinking..."):
        try:
            reply = agent.run(user_input)
        except Exception as e:
            # Top-level UI boundary: show any agent/LLM failure in the chat
            # instead of crashing the page.
            reply = f"⚠️ Agent error: {e}"
    st.session_state.chat_history.append(("agent", reply))

# streamlit_chat.message is a custom component. Without an explicit key, two
# messages with identical text and role get the same auto-generated widget ID
# and Streamlit raises DuplicateWidgetID. The position in the history is a
# stable, unique key.
for position, (role, msg) in enumerate(st.session_state.chat_history):
    message(msg, is_user=(role == "user"), key=f"chat_msg_{position}")