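"""Streamlit front end for the LTU Chat QA app.

Loads LTU programme data from JSON, indexes it into a local Qdrant vector
store, and answers user questions with a RAG (Retrieval Augmented Generation)
pipeline.
"""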
import logging
import os
from typing import Any, List

import streamlit as st

from data_processor import load_json_data, process_documents, split_documents
from rag_pipeline import RAGPipeline
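# NOTE: the local helper modules are not shown here; based on how they are
# used below, they are assumed to provide:
#   load_json_data(path)                       -> parsed JSON data (falsy on failure)
#   process_documents(data)                    -> list of documents
#   split_documents(docs, chunk_size, overlap) -> list of chunked documents
#   RAGPipeline(embedding_model_name, llm_model_name, qdrant_path) with
#       .index_documents(docs), .get_document_count() -> int, and
#       .query(question, top_k) -> {"answer": str, "documents": [...]}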
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Constants
DATA_PATH = "ltu_programme_data.json"
QDRANT_PATH = "./qdrant_data"
EMBEDDING_MODEL = "BAAI/bge-en-icl"
LLM_MODEL = "meta-llama/Llama-3.3-70B-Instruct"
# Initialize session state
if "messages" not in st.session_state:
    st.session_state.messages = []
@st.cache_resource
def get_rag_pipeline():
    # Cache the pipeline across Streamlit reruns so the embedding model and
    # Qdrant store are not re-initialized on every user interaction.
    return RAGPipeline(
        embedding_model_name=EMBEDDING_MODEL,
        llm_model_name=LLM_MODEL,
        qdrant_path=QDRANT_PATH,
    )
def load_and_index_documents(rag_pipeline: RAGPipeline) -> bool:
    """Load, process, chunk, and index the programme data; return True on success."""
    if not os.path.exists(DATA_PATH):
        st.error(f"Data file not found: {DATA_PATH}")
        return False

    with st.spinner("Loading and processing documents..."):
        # Load data
        data = load_json_data(DATA_PATH)
        if not data:
            st.error("Failed to load data")
            return False

        # Process documents
        processed_docs = process_documents(data)
        if not processed_docs:
            st.error("Failed to process documents")
            return False

        # Split documents into overlapping chunks
        chunked_docs = split_documents(processed_docs, chunk_size=1000, overlap=100)
        if not chunked_docs:
            st.error("Failed to split documents")
            return False

    # Index documents
    with st.spinner(f"Indexing {len(chunked_docs)} document chunks..."):
        rag_pipeline.index_documents(chunked_docs)

    return True
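# The chunking parameters above (chunk_size=1000, overlap=100) are assumed to
# be measured in characters, the common convention for simple splitters; tune
# them to the embedding model's context window if retrieval quality suffers.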
def display_document_sources(documents: List[Any]):
    """Display the sources of the retrieved documents."""
    if documents:
        with st.expander("View Sources"):
            for i, doc in enumerate(documents):
                st.markdown(f"**Source {i+1}**: [{doc.meta.get('url', 'Unknown')}]({doc.meta.get('url', '#')})")
                st.markdown(f"**Excerpt**: {doc.content[:200]}...")
                st.markdown("---")
def check_documents_indexed(qdrant_path: str) -> int:
    """Check whether documents are already indexed by returning the document count in Qdrant."""
    try:
        from haystack_integrations.document_stores.qdrant import QdrantDocumentStore

        # Initialize the document store against the existing on-disk index
        document_store = QdrantDocumentStore(
            path=qdrant_path,
            embedding_dim=4096,
            recreate_index=False,
            index="ltu_documents",
        )
        # An empty filter matches every document, so its length is the count
        return len(document_store.filter_documents({}))
    except Exception:
        # If there's an error (e.g., Qdrant not initialized yet), return 0
        return 0
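# NOTE: embedding_dim=4096 appears to match the output size of BAAI/bge-en-icl;
# if the EMBEDDING_MODEL constant changes, this dimension must change with it.
# This helper is currently unused by main(), which relies on
# RAGPipeline.get_document_count() instead.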
def main():
    # Set page config
    st.set_page_config(
        page_title="LTU Chat - QA App",
        page_icon="π",
        layout="wide",
    )

    # Header
    st.title("π LTU Chat - QA App")
    st.markdown("""
    Ask questions about LTU programmes and get answers powered by AI.
    This app uses RAG (Retrieval Augmented Generation) to provide accurate information.
    """)

    rag_pipeline = get_rag_pipeline()
# Sidebar | |
with st.sidebar: | |
st.header("Sett`ings") | |
# Initialize RAG pipeline if not already done | |
# if st.session_state.rag_pipeline is None: | |
# if st.button("Initialize RAG Pipeline"): | |
# st.session_state.rag_pipeline = get_rag_pipeline() | |
# st.success("RAG pipeline initialized successfully!") | |
# else: | |
# st.success("RAG pipeline is ready!") | |
# Check if documents are already indexed | |
documents_indexed = rag_pipeline.get_document_count() | |
if not documents_indexed: | |
if st.button("Index Documents"): | |
success = load_and_index_documents(rag_pipeline) | |
if success: | |
st.success("Documents indexed successfully!") | |
# Refresh the documents_indexed status | |
documents_indexed = True | |
# Get document counts | |
count = rag_pipeline.get_document_count() | |
st.info(f"Indexed {count} documents documents in vector store.") | |
else: | |
st.success(f"{documents_indexed} documents are indexed and ready!") | |
        top_k = st.slider("Number of documents to retrieve", min_value=1, max_value=10, value=5)

        # Work in progress
        st.title("Work in progress")
        st.toggle("Hybrid retrieval", disabled=True)
        st.toggle("Self RAG", disabled=True)
        st.toggle("Query Expansion", disabled=True)
        st.toggle("Graph RAG", disabled=True)
        st.toggle("Prompt engineering (CoT, Step-Back Prompt, Active Prompt)", disabled=True)
    # Display chat messages
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])
            if message.get("documents"):
                display_document_sources(message["documents"])
    # Chat input
    if prompt := st.chat_input("Ask a question about LTU programmes"):
        # Add user message to chat history
        st.session_state.messages.append({"role": "user", "content": prompt})

        # Display user message
        with st.chat_message("user"):
            st.markdown(prompt)
        # Generate response
        if rag_pipeline and documents_indexed:
            with st.chat_message("assistant"):
                with st.spinner("Thinking..."):
                    # Query the RAG pipeline
                    result = rag_pipeline.query(prompt, top_k=top_k)

                    # Display the answer
                    st.markdown(result["answer"])

                    # Display sources
                    if result.get("documents"):
                        display_document_sources(result["documents"])

                    # Add assistant message to chat history
                    st.session_state.messages.append({
                        "role": "assistant",
                        "content": result["answer"],
                        "documents": result.get("documents", []),
                    })
        else:
            with st.chat_message("assistant"):
                if not rag_pipeline:
                    error_message = "Please initialize the RAG pipeline first."
                else:
                    error_message = "Please index documents first."
                st.error(error_message)
            st.session_state.messages.append({"role": "assistant", "content": error_message})
if __name__ == "__main__":
    main()
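# Launched with the standard Streamlit CLI, e.g.:
#   streamlit run app.py
# (assuming this file is saved as app.py)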