File size: 7,100 Bytes
4717959
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
863b3ac
4717959
 
 
 
 
 
 
 
 
 
 
 
 
8c83cf7
4717959
 
 
 
 
 
 
 
 
 
 
 
863b3ac
 
 
 
 
 
2614015
4717959
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8c83cf7
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
import streamlit as st
import os
import logging
from typing import List, Dict, Any
from data_processor import load_json_data, process_documents, split_documents
from rag_pipeline import RAGPipeline

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Constants
DATA_PATH = "ltu_programme_data.json"  # source JSON with LTU programme data
QDRANT_PATH = "./qdrant_data"  # local on-disk Qdrant vector store
EMBEDDING_MODEL = "BAAI/bge-en-icl"  # embedding model name passed to RAGPipeline
LLM_MODEL = "meta-llama/Llama-3.3-70B-Instruct"  # generator model name passed to RAGPipeline
qdrant = None  # NOTE(review): appears unused in this file — candidate for removal; confirm no importer relies on it

# Initialize session state
if "messages" not in st.session_state:
    st.session_state.messages = []  # chat history: dicts with "role", "content", optional "documents"

@st.cache_resource
def get_rag_pipeline():
    """Build and cache the app-wide RAGPipeline.

    ``st.cache_resource`` ensures the pipeline (models + Qdrant handle) is
    constructed once per server process and shared across reruns/sessions.

    Returns:
        RAGPipeline: the configured retrieval-augmented generation pipeline.
    """
    return RAGPipeline(
        embedding_model_name=EMBEDDING_MODEL,
        llm_model_name=LLM_MODEL,
        qdrant_path=QDRANT_PATH,  # PEP 8: no spaces around '=' for keyword arguments
    )

def load_and_index_documents(rag_pipeline: RAGPipeline) -> bool:
    """Load the programme JSON, chunk it, and index it into the vector store.

    Each stage (load, process, split, index) is validated; on any failure a
    Streamlit error is shown and False is returned. Returns True on success.
    """
    if not os.path.exists(DATA_PATH):
        st.error(f"Data file not found: {DATA_PATH}")
        return False

    with st.spinner("Loading and processing documents..."):
        raw_data = load_json_data(DATA_PATH)
        if not raw_data:
            st.error("Failed to load data")
            return False

        documents = process_documents(raw_data)
        if not documents:
            st.error("Failed to process documents")
            return False

        chunks = split_documents(documents, chunk_size=1000, overlap=100)
        if not chunks:
            st.error("Failed to split documents")
            return False

        # Indexing is the slow step; give it its own progress indicator.
        with st.spinner(f"Indexing {len(chunks)} document chunks..."):
            rag_pipeline.index_documents(chunks)

        return True

def display_document_sources(documents: List[Dict[str, Any]]):
    """Render an expandable "View Sources" panel for the retrieved documents.

    Shows each document's URL (as a markdown link) and the first 200
    characters of its content. Does nothing when *documents* is empty.
    """
    if not documents:
        return
    with st.expander("View Sources"):
        for index, doc in enumerate(documents, start=1):
            label = doc.meta.get('url', 'Unknown')
            target = doc.meta.get('url', '#')
            st.markdown(f"**Source {index}**: [{label}]({target})")
            st.markdown(f"**Excerpt**: {doc.content[:200]}...")
            st.markdown("---")

def check_documents_indexed(qdrant_path: str) -> int:
    """Return the number of documents already indexed in the local Qdrant store.

    This is a best-effort probe: if the store does not exist yet or cannot be
    opened (e.g. locked by another process), it returns 0 instead of raising.

    Args:
        qdrant_path: Filesystem path of the on-disk Qdrant store.

    Returns:
        The document count, or 0 when the store is unavailable.
    """
    try:
        from haystack_integrations.document_stores.qdrant import QdrantDocumentStore

        # Open the existing store without recreating the index.
        document_store = QdrantDocumentStore(
            path=qdrant_path,
            embedding_dim=4096,
            recreate_index=False,
            index="ltu_documents"
        )

        # count_documents() queries the store size directly instead of
        # materializing every document via filter_documents({}).
        return document_store.count_documents()
    except Exception as exc:
        # Deliberate best-effort: an uninitialized store simply means
        # "0 indexed" — but log the cause rather than swallow it silently.
        logger.debug("Could not read Qdrant document count: %s", exc)
        return 0

def main():
    """Streamlit entry point: page config, sidebar controls, and the chat loop."""
    st.set_page_config(
        page_title="LTU Chat - QA App",
        page_icon="🎓",
        layout="wide"
    )

    # Header
    st.title("🎓 LTU Chat - QA App")
    st.markdown("""
    Ask questions about LTU programmes and get answers powered by AI.
    This app uses RAG (Retrieval Augmented Generation) to provide accurate information.
    """)
    rag_pipeline = get_rag_pipeline()

    # Sidebar: indexing status/controls and retrieval settings
    with st.sidebar:
        st.header("Settings")  # fixed typo: was "Sett`ings"

        # Offer one-click indexing while the vector store is still empty.
        documents_indexed = rag_pipeline.get_document_count()
        if not documents_indexed:
            if st.button("Index Documents"):
                if load_and_index_documents(rag_pipeline):
                    st.success("Documents indexed successfully!")
                    # Refresh the status so the chat branch below works
                    # within this same script run.
                    documents_indexed = True
                    count = rag_pipeline.get_document_count()
                    # fixed duplicated word: "documents documents"
                    st.info(f"Indexed {count} documents in vector store.")
        else:
            st.success(f"{documents_indexed} documents are indexed and ready!")

        top_k = st.slider("Number of documents to retrieve", min_value=1, max_value=10, value=5)

        # Planned features, shown as disabled toggles for now.
        st.title("Work in progress")
        st.toggle("Hybrid retrieval", disabled=True)
        st.toggle("Self RAG", disabled=True)
        st.toggle("Query Expansion", disabled=True)
        st.toggle("Graph RAG", disabled=True)
        st.toggle("Prompt engineering (CoT, Step-Back Prompt, Active Prompt)", disabled=True)

    # Replay chat history (Streamlit reruns this script on every interaction).
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])
            if message.get("documents"):
                display_document_sources(message["documents"])

    # Handle a new user question.
    if prompt := st.chat_input("Ask a question about LTU programmes"):
        st.session_state.messages.append({"role": "user", "content": prompt})

        with st.chat_message("user"):
            st.markdown(prompt)

        if rag_pipeline and documents_indexed:
            with st.chat_message("assistant"):
                with st.spinner("Thinking..."):
                    # Query the RAG pipeline and render the answer + sources.
                    result = rag_pipeline.query(prompt, top_k=top_k)
                    st.markdown(result["answer"])
                    if result.get("documents"):
                        display_document_sources(result["documents"])

                    # Persist the assistant turn so it survives reruns.
                    st.session_state.messages.append({
                        "role": "assistant",
                        "content": result["answer"],
                        "documents": result.get("documents", [])
                    })
        else:
            with st.chat_message("assistant"):
                if not rag_pipeline:
                    error_message = "Please initialize the RAG pipeline first."
                else:
                    error_message = "Please index documents first."
                st.error(error_message)
                st.session_state.messages.append({"role": "assistant", "content": error_message})

if __name__ == "__main__":
    main()