# Import necessary libraries
import streamlit as st
import pandas as pd
from sentence_transformers import SentenceTransformer
from langchain.vectorstores import FAISS
from langchain.embeddings.base import Embeddings
from transformers import pipeline
from langchain.llms.huggingface_pipeline import HuggingFacePipeline
from langchain.chains import RetrievalQA


# Define a LangChain-compatible wrapper for SentenceTransformer
class SentenceTransformerEmbeddings(Embeddings):
    """
    Wrapper for SentenceTransformer to integrate with LangChain.
    """
    def __init__(self, model_name: str):
        self.model = SentenceTransformer(model_name)

    def embed_documents(self, texts):
        """
        Generates embeddings for a list of documents.

        Args:
            texts (list): List of strings to embed.

        Returns:
            np.ndarray: Embedding vectors.
        """
        return self.model.encode(texts, show_progress_bar=False)

    def embed_query(self, text):
        """
        Generates an embedding for a single query.

        Args:
            text (str): Query string to embed.

        Returns:
            np.ndarray: Embedding vector.
        """
        return self.model.encode([text], show_progress_bar=False)[0]


# Initialize the embedding model
embedding_model = SentenceTransformerEmbeddings('all-MiniLM-L6-v2')
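
# Illustrative usage (a minimal sketch, not part of the app flow; the query
# string below is hypothetical). embed_query returns a single vector; for
# all-MiniLM-L6-v2 each embedding has 384 dimensions.
#
#   example_vector = embedding_model.embed_query("How do students learn about AI?")
#   print(len(example_vector))  # 384 for all-MiniLM-L6-v2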

# Preprocess data into descriptive text entries
def preprocess_data(data):
    """
    Combines multiple dataset columns into descriptive text entries for embedding.

    Args:
        data (pd.DataFrame): The input dataset containing participant details.

    Returns:
        list: A list of combined textual descriptions for each row in the dataset.
    """
    combined_entries = []
    for _, row in data.iterrows():
        entry = f"Participant {row['ID']}:\n"
        entry += f"- AI Knowledge Level: {row['Q1.AI_knowledge']}\n"
        entry += f"- Sources of AI Knowledge: {row['Q2.AI_sources']}\n"
        entry += f"- Perspectives on AI: Dehumanization ({row['Q3#1.AI_dehumanization']}), "
        entry += f"Job Replacement ({row['Q3#2.Job_replacement']})\n"
        entry += f"- Domains Impacted by AI: {row['Q6.Domains']}\n"
        entry += f"- Utility Grade for AI: {row['Q7.Utility_grade']}\n"
        entry += f"- GPA: {row['Q16.GPA']}\n"
        combined_entries.append(entry)
    return combined_entries

# App logic
def main():
    # Set up the Streamlit UI
    st.title("RAG Chatbot")
    st.write("This chatbot answers questions based on the dataset.")

    # Load the dataset directly from the Space's root directory
    dataset_path = "Survey_AI.csv"
    try:
        data = pd.read_csv(dataset_path)
        st.write("Dataset successfully loaded!")

        # Preprocess data and create the vector store
        combined_texts = preprocess_data(data)
        vector_store = FAISS.from_texts(combined_texts, embedding_model)
        retriever = vector_store.as_retriever()

        # Set up the QA chain
        flan_t5 = pipeline("text2text-generation", model="google/flan-t5-base", device=-1)  # device=-1 runs on CPU
        llm = HuggingFacePipeline(pipeline=flan_t5)
        qa_chain = RetrievalQA.from_chain_type(llm=llm, retriever=retriever)

        # Default sample questions
        sample_questions = [
            "What are the sources of AI knowledge for participants?",
            "Which domains are impacted by AI?",
            "What are participants' perspectives on job replacement due to AI?",
            "What is the average GPA of participants?",
            "What is the utility grade for AI?",
            "Which participants view AI as highly beneficial in their domain?"
        ]

        st.subheader("Sample Questions")
        selected_question = st.selectbox("Select a question to see the response:", [""] + sample_questions)
        if selected_question:
            response = qa_chain.run(selected_question)
            st.write("Question:", selected_question)
            st.write("Answer:", response)

        # Custom user query
        st.subheader("Custom Query")
        query = st.text_input("Or, enter your own question:")
        if query:
            response = qa_chain.run(query)
            st.write("Question:", query)
            st.write("Answer:", response)
    except FileNotFoundError:
        st.error("Dataset file not found. Please ensure the file is named 'Survey_AI.csv' and uploaded to the root directory.")

# Run the app
if __name__ == "__main__":
    main()
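
# Running the app locally (a sketch, assuming this file is saved as app.py,
# the conventional entry point for a Streamlit Space; the package list below
# reflects the imports above, pin versions as needed for your environment):
#
#   pip install streamlit pandas sentence-transformers langchain faiss-cpu transformers torch
#   streamlit run app.py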