# PandasAI / app.py
# Last commit: "Update app.py" by soojeongcrystal (6843942, verified)
import os
import re

import pandas as pd
import plotly.express as px
import streamlit as st
from langchain.chains import RetrievalQA
from langchain.schema import Document
from langchain_community.embeddings.openai import OpenAIEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_openai import ChatOpenAI
from pandasai import Agent
# --- Page header ------------------------------------------------------------
st.title("PandasAI ๋ฐ์ดํ„ฐ ๋ถ„์„๊ธฐ with RAG")

# --- Credentials: both keys are entered in the sidebar as masked inputs -----
with st.sidebar:
    api_key = st.text_input("OpenAI API Key", type="password")
    pandasai_api_key = st.text_input("PandasAI API Key", type="password")

# --- Data source: a single Excel or CSV upload ------------------------------
uploaded_file = st.file_uploader(
    "์—‘์…€ ๋˜๋Š” CSV ํŒŒ์ผ์„ ์—…๋กœ๋“œํ•˜์„ธ์š”",
    type=["xlsx", "csv"],
)
# Matches the first ```python fenced code block in a text answer.
_CODE_FENCE_PATTERN = re.compile(r'```python\n(.*?)\n```', re.DOTALL)
# Suffixes of chart image files that can be displayed directly with st.image.
_CHART_IMAGE_SUFFIXES = (".png", ".jpg", ".jpeg")


def _load_dataframe(file) -> pd.DataFrame:
    """Parse an uploaded file: ``.xlsx`` via read_excel, anything else via read_csv."""
    # Case-insensitive so e.g. "REPORT.XLSX" is not handed to the CSV parser.
    if file.name.lower().endswith(".xlsx"):
        return pd.read_excel(file)
    return pd.read_csv(file)


@st.cache_resource
def _build_qa_chain(df: pd.DataFrame, openai_api_key: str) -> RetrievalQA:
    """Embed every row of *df* into a FAISS index and wrap it in a RetrievalQA chain.

    Streamlit reruns the whole script on every widget interaction; caching on
    (df contents, key) keeps those reruns from re-embedding the entire dataset
    through the OpenAI API each time. The key is an argument so that entering
    a different key builds a fresh chain instead of reusing the old one.
    """
    # One document per row: "col1: v1, col2: v2, ...", tagged with the row index.
    documents = [
        Document(
            page_content=", ".join(f"{col}: {row[col]}" for col in df.columns),
            metadata={"index": index},
        )
        for index, row in df.iterrows()
    ]
    embeddings = OpenAIEmbeddings(openai_api_key=openai_api_key)
    vectorstore = FAISS.from_documents(documents, embeddings)
    return RetrievalQA.from_chain_type(
        llm=ChatOpenAI(openai_api_key=openai_api_key),
        chain_type="stuff",
        retriever=vectorstore.as_retriever(),
    )


def _render_visualization(result, data: pd.DataFrame) -> None:
    """Display PandasAI's answer to a chart request, or a fallback message.

    Handles two shapes of *result*:

    * a path to an existing chart image file -> shown with ``st.image``
      (NOTE(review): PandasAI 2.x returns the saved chart's file path for plot
      answers; confirm for the installed version);
    * text containing a ```python block -> the matplotlib-style code is
      mechanically rewritten to plotly express and executed; it must leave a
      ``fig`` variable behind.

    Any other result shows the "could not create a graph" message.
    """
    if (
        isinstance(result, str)
        and result.lower().endswith(_CHART_IMAGE_SUFFIXES)
        and os.path.isfile(result)
    ):
        st.image(result)
        return

    # Non-string answers (numbers, DataFrames) cannot contain a code block.
    code_match = _CODE_FENCE_PATTERN.search(result) if isinstance(result, str) else None
    if code_match is None:
        st.write("๊ทธ๋ž˜ํ”„๋ฅผ ์ƒ์„ฑํ•˜์ง€ ๋ชปํ–ˆ์Šต๋‹ˆ๋‹ค. ๋‹ค๋ฅธ ์งˆ๋ฌธ์„ ํ•ด๋ณด์„ธ์š”.")
        return

    viz_code = code_match.group(1)
    # Rewrite plt.show() BEFORE the generic 'plt.' -> 'px.' pass; in the
    # opposite order it becomes 'px.show()' and the fig assignment is never inserted.
    viz_code = viz_code.replace('plt.show()', 'fig = px.scatter(df, x=x, y=y)')
    viz_code = viz_code.replace('plt.', 'px.')

    # SECURITY: exec() runs LLM-generated code with this server's privileges.
    # A fresh namespace stops it from rebinding the app's own globals (df,
    # agent, ...), but this is NOT a sandbox; do not expose publicly without
    # real isolation.
    namespace = {"df": data, "pd": pd, "px": px}
    exec(viz_code, namespace)
    fig = namespace.get("fig")
    if fig is None:
        st.write("๊ทธ๋ž˜ํ”„๋ฅผ ์ƒ์„ฑํ•˜์ง€ ๋ชปํ–ˆ์Šต๋‹ˆ๋‹ค. ๋‹ค๋ฅธ ์งˆ๋ฌธ์„ ํ•ด๋ณด์„ธ์š”.")
    else:
        st.plotly_chart(fig)


if uploaded_file is not None and api_key and pandasai_api_key:
    # Expose the keys through the environment for libraries that read them there.
    os.environ["OPENAI_API_KEY"] = api_key
    os.environ["PANDASAI_API_KEY"] = pandasai_api_key

    # Load the data and show a preview.
    df = _load_dataframe(uploaded_file)
    st.write("๋ฐ์ดํ„ฐ ๋ฏธ๋ฆฌ๋ณด๊ธฐ:")
    st.write(df.head())
    if df.empty:
        # Nothing to analyse or embed (FAISS cannot build an index from zero documents).
        st.warning("The uploaded file contains no data rows.")
        st.stop()

    # PandasAI agent over the uploaded DataFrame.
    agent = Agent(df)

    # One tab per feature: PandasAI analysis, RAG Q&A, visualization.
    tab1, tab2, tab3 = st.tabs(["PandasAI ๋ถ„์„", "RAG ์งˆ๋ฌธ๋‹ต๋ณ€", "๋ฐ์ดํ„ฐ ์‹œ๊ฐํ™”"])

    with tab1:
        st.header("PandasAI๋ฅผ ์‚ฌ์šฉํ•œ ๋ฐ์ดํ„ฐ ๋ถ„์„")
        pandas_question = st.text_input("๋ฐ์ดํ„ฐ์— ๋Œ€ํ•ด ์งˆ๋ฌธํ•˜์„ธ์š” (PandasAI):")
        if pandas_question:
            result = agent.chat(pandas_question)
            st.write("PandasAI ๋‹ต๋ณ€:", result)

    with tab2:
        st.header("RAG๋ฅผ ์‚ฌ์šฉํ•œ ์งˆ๋ฌธ๋‹ต๋ณ€")
        rag_question = st.text_input("๋ฐ์ดํ„ฐ์— ๋Œ€ํ•ด ์งˆ๋ฌธํ•˜์„ธ์š” (RAG):")
        if rag_question:
            # Built lazily (and cached) so the other tabs work without first
            # embedding the whole dataset.
            qa_chain = _build_qa_chain(df, api_key)
            # RetrievalQA takes its input under "query" and returns the answer under "result".
            result = qa_chain.invoke({"query": rag_question})["result"]
            st.write("RAG ๋‹ต๋ณ€:", result)

    with tab3:
        st.header("๋ฐ์ดํ„ฐ ์‹œ๊ฐํ™”")
        viz_question = st.text_input("์–ด๋–ค ๊ทธ๋ž˜ํ”„๋ฅผ ๊ทธ๋ฆฌ๊ณ  ์‹ถ์œผ์‹ ๊ฐ€์š”? (์˜ˆ: '์—ฐ๋ด‰๊ณผ ๊ฒฝ๋ ฅ์˜ ๊ด€๊ณ„๋ฅผ ์‚ฐ์ ๋„๋กœ ๋ณด์—ฌ์ค˜')")
        if viz_question:
            try:
                _render_visualization(agent.chat(viz_question), df)
            except Exception as e:
                # UI boundary: show the failure to the user instead of a raw traceback.
                st.write(f"์˜ค๋ฅ˜๊ฐ€ ๋ฐœ์ƒํ–ˆ์Šต๋‹ˆ๋‹ค: {str(e)}")
                st.write("๋‹ค๋ฅธ ๋ฐฉ์‹์œผ๋กœ ์งˆ๋ฌธํ•ด๋ณด์„ธ์š”.")
elif not api_key:
    st.warning("OpenAI API ํ‚ค๋ฅผ ์ž…๋ ฅํ•ด์ฃผ์„ธ์š”.")
elif not pandasai_api_key:
    st.warning("PandasAI API ํ‚ค๋ฅผ ์ž…๋ ฅํ•ด์ฃผ์„ธ์š”.")