# Import necessary libraries
import streamlit as st
import pandas as pd
from sentence_transformers import SentenceTransformer
from langchain.vectorstores import FAISS
from langchain.embeddings.base import Embeddings
from transformers import pipeline
from langchain.llms.huggingface_pipeline import HuggingFacePipeline
from langchain.chains import RetrievalQA
# Define a LangChain-compatible wrapper for SentenceTransformer
class SentenceTransformerEmbeddings(Embeddings):
"""
Wrapper for SentenceTransformer to integrate with LangChain.
"""
def __init__(self, model_name: str):
self.model = SentenceTransformer(model_name)
def embed_documents(self, texts):
"""
Generates embeddings for a list of documents.
Args:
texts (list): List of strings to embed.
Returns:
np.ndarray: Embedding vectors.
"""
return self.model.encode(texts, show_progress_bar=False)
def embed_query(self, text):
"""
Generates an embedding for a single query.
Args:
text (str): Query string to embed.
Returns:
np.ndarray: Embedding vector.
"""
return self.model.encode([text], show_progress_bar=False)[0]
# Initialize the embedding model
embedding_model = SentenceTransformerEmbeddings('all-MiniLM-L6-v2')
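# Note: 'all-MiniLM-L6-v2' produces 384-dimensional sentence embeddings. Because
# Streamlit re-runs this script on every interaction, wrapping the model load in
# st.cache_resource is one optional way to avoid reloading it each time.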
# Preprocess data into descriptive text entries
def preprocess_data(data):
"""
Combines multiple dataset columns into descriptive text entries for embedding.
Args:
data (pd.DataFrame): The input dataset containing participant details.
Returns:
list: A list of combined textual descriptions for each row in the dataset.
"""
combined_entries = []
for _, row in data.iterrows():
entry = f"Participant {row['ID']}:\n"
entry += f"- AI Knowledge Level: {row['Q1.AI_knowledge']}\n"
entry += f"- Sources of AI Knowledge: {row['Q2.AI_sources']}\n"
entry += f"- Perspectives on AI: Dehumanization ({row['Q3#1.AI_dehumanization']}), "
entry += f"Job Replacement ({row['Q3#2.Job_replacement']})\n"
entry += f"- Domains Impacted by AI: {row['Q6.Domains']}\n"
entry += f"- Utility Grade for AI: {row['Q7.Utility_grade']}\n"
entry += f"- GPA: {row['Q16.GPA']}\n"
combined_entries.append(entry)
return combined_entries
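
# Illustrative example of one generated entry (the values below are hypothetical):
#   Participant 1:
#   - AI Knowledge Level: 4
#   - Sources of AI Knowledge: Online courses, news articles
#   - Perspectives on AI: Dehumanization (3), Job Replacement (4)
#   - Domains Impacted by AI: Education, Healthcare
#   - Utility Grade for AI: 5
#   - GPA: 8.5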
# App logic
def main():
    # Set up the Streamlit UI
    st.title("RAG Chatbot")
    st.write("This chatbot answers questions based on the dataset.")

    # Load the dataset directly from the space directory
    dataset_path = "Survey_AI.csv"
    try:
        data = pd.read_csv(dataset_path)
        st.write("Dataset successfully loaded!")

        # Preprocess data and create vector store
        combined_texts = preprocess_data(data)
        vector_store = FAISS.from_texts(combined_texts, embedding_model)
        retriever = vector_store.as_retriever()
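        # as_retriever() performs similarity search by default, returning the
        # top-matching participant entries for each query.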

        # Set up QA chain
        flan_t5 = pipeline("text2text-generation", model="google/flan-t5-base", device=-1)  # CPU mode
        llm = HuggingFacePipeline(pipeline=flan_t5)
        qa_chain = RetrievalQA.from_chain_type(llm=llm, retriever=retriever)
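        # RetrievalQA defaults to the "stuff" chain type: retrieved entries are
        # concatenated into the prompt sent to flan-t5. flan-t5-base accepts a
        # limited input length (512 tokens), so long contexts may be truncated.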

        # Default sample questions
        sample_questions = [
            "What are the sources of AI knowledge for participants?",
            "Which domains are impacted by AI?",
            "What are participants' perspectives on job replacement due to AI?",
            "What is the average GPA of participants?",
            "What is the utility grade for AI?",
            "Which participants view AI as highly beneficial in their domain?",
        ]
        st.subheader("Sample Questions")
        selected_question = st.selectbox("Select a question to see the response:", [""] + sample_questions)
        if selected_question:
            response = qa_chain.run(selected_question)
            st.write("Question:", selected_question)
            st.write("Answer:", response)

        # Custom user query
        st.subheader("Custom Query")
        query = st.text_input("Or, enter your own question:")
        if query:
            response = qa_chain.run(query)
            st.write("Question:", query)
            st.write("Answer:", response)
    except FileNotFoundError:
        st.error("Dataset file not found. Please ensure the file is named 'Survey_AI.csv' and uploaded to the root directory.")
# Run the app
if __name__ == "__main__":
    main()
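
# To launch locally (assuming this file is saved as app.py):
#     streamlit run app.py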