# Source: Hugging Face Space file app.py
# Author: Hyma7 -- commit "Create app.py" (e354c74, verified)
# Import necessary libraries
import streamlit as st
import pandas as pd
from sentence_transformers import SentenceTransformer
from langchain.vectorstores import FAISS
from langchain.embeddings.base import Embeddings
from transformers import pipeline
from langchain.llms.huggingface_pipeline import HuggingFacePipeline
from langchain.chains import RetrievalQA
# Define a LangChain-compatible wrapper for SentenceTransformer
class SentenceTransformerEmbeddings(Embeddings):
    """
    Wrapper for SentenceTransformer to integrate with LangChain.

    Adapts ``SentenceTransformer.encode`` to the two methods required by
    LangChain's ``Embeddings`` interface. Results are converted to plain
    Python lists of floats, which is the return type that interface
    specifies, rather than leaking NumPy arrays to callers.
    """

    def __init__(self, model_name: str):
        """
        Loads the underlying SentenceTransformer model.
        Args:
            model_name (str): Model name or path handed to SentenceTransformer.
        """
        self.model = SentenceTransformer(model_name)

    def embed_documents(self, texts):
        """
        Generates embeddings for a list of documents.
        Args:
            texts (list): List of strings to embed.
        Returns:
            list: One embedding (a list of floats) per input text; an empty
            list when no texts are given.
        """
        # Materialize once so any iterable of strings is accepted and can be
        # checked for emptiness before the model is invoked.
        documents = list(texts)
        if not documents:
            return []
        vectors = self.model.encode(documents, show_progress_bar=False)
        # encode() yields a NumPy array; LangChain expects nested lists.
        return vectors.tolist()

    def embed_query(self, text):
        """
        Generates an embedding for a single query.
        Args:
            text (str): Query string to embed.
        Returns:
            list: Embedding vector as a list of floats.
        """
        # Encode as a one-element batch and unwrap the single resulting row.
        vector = self.model.encode([text], show_progress_bar=False)[0]
        return vector.tolist()
# Initialize the embedding model once at import time; this shared instance is
# what main() passes to FAISS.from_texts when building the vector store.
embedding_model = SentenceTransformerEmbeddings('all-MiniLM-L6-v2')
# Preprocess data into descriptive text entries
def preprocess_data(data):
    """
    Turns every dataset row into one multi-line textual description.

    Args:
        data (pd.DataFrame): Survey dataset; must provide the columns ID,
            Q1.AI_knowledge, Q2.AI_sources, Q3#1.AI_dehumanization,
            Q3#2.Job_replacement, Q6.Domains, Q7.Utility_grade and Q16.GPA.
    Returns:
        list: One description string per row, in row order.
    """

    def describe(record):
        # Fragments are concatenated as-is; each carries its own newline.
        fragments = (
            f"Participant {record['ID']}:\n",
            f"- AI Knowledge Level: {record['Q1.AI_knowledge']}\n",
            f"- Sources of AI Knowledge: {record['Q2.AI_sources']}\n",
            f"- Perspectives on AI: Dehumanization ({record['Q3#1.AI_dehumanization']}), ",
            f"Job Replacement ({record['Q3#2.Job_replacement']})\n",
            f"- Domains Impacted by AI: {record['Q6.Domains']}\n",
            f"- Utility Grade for AI: {record['Q7.Utility_grade']}\n",
            f"- GPA: {record['Q16.GPA']}\n",
        )
        return "".join(fragments)

    return [describe(record) for _, record in data.iterrows()]
# App logic
def main():
    """
    Renders the Streamlit RAG chatbot.

    Loads the survey CSV, converts its rows to text with preprocess_data,
    indexes them in a FAISS vector store using the module-level
    embedding_model, and answers questions through a RetrievalQA chain
    backed by a local flan-t5-base pipeline. Shows an error and stops
    early when the dataset file is missing or contains no rows.
    """
    # Set up the Streamlit UI
    st.title("RAG Chatbot")
    st.write("This chatbot answers questions based on the dataset.")

    # Load the dataset directly from the space directory
    dataset_path = "Survey_AI.csv"
    # Only the read is guarded: a FileNotFoundError raised later (for example
    # while loading a model) must not be reported as a missing dataset.
    try:
        data = pd.read_csv(dataset_path)
    except FileNotFoundError:
        st.error(
            f"Dataset file not found. Please ensure the file is named "
            f"'{dataset_path}' and uploaded to the root directory."
        )
        return
    st.write("Dataset successfully loaded!")

    # Preprocess data and create vector store
    combined_texts = preprocess_data(data)
    if not combined_texts:
        # An index cannot be built from zero texts, so stop with a clear message.
        st.error("The dataset contains no rows, so there is nothing to answer questions from.")
        return
    # NOTE(review): Streamlit re-runs this function on every widget
    # interaction, so the index and the model below are rebuilt each time.
    # Consider caching them (e.g. st.cache_resource) if the installed
    # Streamlit version provides it.
    vector_store = FAISS.from_texts(combined_texts, embedding_model)
    retriever = vector_store.as_retriever()

    # Set up QA chain
    flan_t5 = pipeline("text2text-generation", model="google/flan-t5-base", device=-1)  # CPU mode
    llm = HuggingFacePipeline(pipeline=flan_t5)
    qa_chain = RetrievalQA.from_chain_type(llm=llm, retriever=retriever)

    def show_answer(question):
        # Shared by the sample-question and custom-query paths.
        response = qa_chain.run(question)
        st.write("Question:", question)
        st.write("Answer:", response)

    # Default sample questions
    sample_questions = [
        "What are the sources of AI knowledge for participants?",
        "Which domains are impacted by AI?",
        "What are participants' perspectives on job replacement due to AI?",
        "What is the average GPA of participants?",
        "What is the utility grade for AI?",
        "Which participants view AI as highly beneficial in their domain?",
    ]
    st.subheader("Sample Questions")
    # The leading empty option means nothing is answered until the user picks.
    selected_question = st.selectbox("Select a question to see the response:", [""] + sample_questions)
    if selected_question:
        show_answer(selected_question)

    # Custom user query
    st.subheader("Custom Query")
    query = st.text_input("Or, enter your own question:")
    if query:
        show_answer(query)
# Run the app only when this file is executed as a script, not when imported
if __name__ == "__main__":
    main()