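"""Streamlit chat app for question answering over LTU programme data via a RAG pipeline."""
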
import streamlit as st
import os
import logging
from typing import List

from haystack import Document  # Haystack 2.x Document objects returned by the pipeline

from data_processor import load_json_data, process_documents, split_documents
from rag_pipeline import RAGPipeline
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Constants
DATA_PATH = "ltu_programme_data.json"
QDRANT_PATH = "./qdrant_data"
EMBEDDING_MODEL = "BAAI/bge-en-icl"
LLM_MODEL = "meta-llama/Llama-3.3-70B-Instruct"
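# Note: the embedding_dim=4096 used for the Qdrant store below must match the
# output dimension of EMBEDDING_MODEL (BAAI/bge-en-icl produces 4096-d vectors).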
# Initialize session state
if "messages" not in st.session_state:
    st.session_state.messages = []
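
# st.cache_resource keeps a single RAGPipeline instance alive per server process,
# so the pipeline (models and document store) is created once and then shared
# across reruns and user sessions.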
@st.cache_resource
def get_rag_pipeline():
    """Create and cache the shared RAGPipeline instance."""
    return RAGPipeline(
        embedding_model_name=EMBEDDING_MODEL,
        llm_model_name=LLM_MODEL,
        qdrant_path=QDRANT_PATH
    )

def load_and_index_documents(rag_pipeline: RAGPipeline) -> bool:
    """Load and index documents"""
    if not os.path.exists(DATA_PATH):
        st.error(f"Data file not found: {DATA_PATH}")
        return False

    with st.spinner("Loading and processing documents..."):
        # Load data
        data = load_json_data(DATA_PATH)
        if not data:
            st.error("Failed to load data")
            return False

        # Process documents
        processed_docs = process_documents(data)
        if not processed_docs:
            st.error("Failed to process documents")
            return False

        # Split documents
        chunked_docs = split_documents(processed_docs, chunk_size=1000, overlap=100)
        if not chunked_docs:
            st.error("Failed to split documents")
            return False

    # Index documents
    with st.spinner(f"Indexing {len(chunked_docs)} document chunks..."):
        rag_pipeline.index_documents(chunked_docs)

    return True

def display_document_sources(documents: List[Document]):
    """Display the sources of the retrieved documents."""
    if documents:
        with st.expander("View Sources"):
            for i, doc in enumerate(documents):
                url = doc.meta.get("url", "#")
                st.markdown(f"**Source {i + 1}**: [{doc.meta.get('url', 'Unknown')}]({url})")
                st.markdown(f"**Excerpt**: {doc.content[:200]}...")
                st.markdown("---")

def check_documents_indexed(qdrant_path: str) -> int:
    """Return the number of documents already indexed in Qdrant (0 if none)."""
    try:
        from haystack_integrations.document_stores.qdrant import QdrantDocumentStore

        # Open the existing document store without recreating the index
        document_store = QdrantDocumentStore(
            path=qdrant_path,
            embedding_dim=4096,  # must match the embedding model's output dimension
            recreate_index=False,
            index="ltu_documents"
        )
        # Get the document count
        return document_store.count_documents()
    except Exception:
        # If there's an error (e.g., Qdrant not initialized), return 0
        return 0
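
# Note: check_documents_indexed() inspects the Qdrant store directly and is not
# called by main() below, which uses rag_pipeline.get_document_count() instead;
# it is left here as a standalone way to verify the index outside the pipeline.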

def main():
    # Set page config
    st.set_page_config(
        page_title="LTU Chat - QA App",
        page_icon="🎓",  # original icon character was garbled; graduation cap assumed
        layout="wide"
    )

    # Header
    st.title("🎓 LTU Chat - QA App")
    st.markdown(
        """
        Ask questions about LTU programmes and get answers powered by AI.
        This app uses Retrieval-Augmented Generation (RAG) to ground its answers
        in indexed LTU programme pages.
        """
    )

    rag_pipeline = get_rag_pipeline()

    # Sidebar
    with st.sidebar:
        st.header("Settings")

        # Check if documents are already indexed
        documents_indexed = rag_pipeline.get_document_count()
        if not documents_indexed:
            if st.button("Index Documents"):
                success = load_and_index_documents(rag_pipeline)
                if success:
                    st.success("Documents indexed successfully!")
                    # Refresh the indexed-document count
                    documents_indexed = rag_pipeline.get_document_count()
                    st.info(f"Indexed {documents_indexed} documents in the vector store.")
        else:
            st.success(f"{documents_indexed} documents are indexed and ready!")

        top_k = st.slider("Number of documents to retrieve", min_value=1, max_value=10, value=5)

        # Work in progress
        st.header("Work in progress")
        st.toggle("Hybrid retrieval", disabled=True)
        st.toggle("Self RAG", disabled=True)
        st.toggle("Query Expansion", disabled=True)
        st.toggle("Graph RAG", disabled=True)
        st.toggle("Prompt engineering (CoT, Step-Back Prompt, Active Prompt)", disabled=True)

    # Display chat messages
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])
            if message.get("documents"):
                display_document_sources(message["documents"])
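
    # Streamlit reruns this script from the top on every interaction, so the
    # full chat history above is re-rendered from session state each time.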
    # Chat input
    if prompt := st.chat_input("Ask a question about LTU programmes"):
        # Add user message to chat history
        st.session_state.messages.append({"role": "user", "content": prompt})

        # Display user message
        with st.chat_message("user"):
            st.markdown(prompt)

        # Generate response
        if rag_pipeline and documents_indexed:
            with st.chat_message("assistant"):
                with st.spinner("Thinking..."):
                    # Query the RAG pipeline
                    result = rag_pipeline.query(prompt, top_k=top_k)

                    # Display the answer
                    st.markdown(result["answer"])

                    # Display sources
                    if result.get("documents"):
                        display_document_sources(result["documents"])

                    # Add assistant message to chat history
                    st.session_state.messages.append({
                        "role": "assistant",
                        "content": result["answer"],
                        "documents": result.get("documents", [])
                    })
        else:
            with st.chat_message("assistant"):
                if not rag_pipeline:
                    error_message = "Please initialize the RAG pipeline first."
                else:
                    error_message = "Please index documents first."
                st.error(error_message)
                st.session_state.messages.append({"role": "assistant", "content": error_message})

if __name__ == "__main__":
    main()
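
# To launch locally (assuming this file is saved as app.py; the filename is an
# assumption, not shown in the source):
#   streamlit run app.py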