import streamlit as st
from dotenv import load_dotenv
import os
import traceback

# PDF and NLP Libraries
import PyPDF2
from langchain.text_splitter import RecursiveCharacterTextSplitter
from sentence_transformers import SentenceTransformer, util

# Embedding and Vector Store
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS

# LLM and Conversational Chain
from langchain_groq import ChatGroq
from langchain.memory import ConversationBufferMemory
from langchain.chains import ConversationalRetrievalChain
from langchain.prompts import PromptTemplate

# Custom Templates
from htmlTemplate import css, bot_template, user_template

# Load environment variables from a local .env file (if one exists) into
# os.environ at import time, so GROQ_API_KEY (used by ChatGroq) is available
# even when it is only defined in .env. Variables already set in the process
# environment are not overridden. Nothing is assigned from os.getenv() here:
# assigning None to os.environ raises TypeError when the key is missing.
load_dotenv()

# LLM prompt template intended for short, document-grounded answers.
# Its only placeholder is {question}.
# NOTE(review): this constant is not referenced anywhere in this file;
# get_conversation_chain builds the chain without passing a custom prompt.
# It also has no {context} placeholder for retrieved text, so it cannot be
# used as the answer prompt as-is. It also says "markdown document", although
# the app ingests PDFs.
llmtemplate = """You're an AI information specialist with a strong emphasis on extracting accurate information from markdown documents. Your expertise involves summarizing data succinctly while adhering to strict guidelines about neutrality and clarity.
Your task is to answer a specific question based on a provided markdown document. Here is the question you need to address:  
{question}
Keep in mind the following instructions:  
- Your response should be direct and factual, limited to 50 words and 2-3 sentences.  
- Avoid using introductory phrases like "yes" or "no."  
- Maintain an ethical and unbiased tone, steering clear of harmful or offensive content.  
- If the document lacks relevant information, respond with "I cannot provide an answer based on the provided document."  
- Do not fabricate information, include questions, or use confirmatory phrases.  
- Remember not to prompt for additional information or ask any questions.  
Ensure your response is strictly based on the content of the markdown document.
"""

def prepare_docs(pdf_docs):
    """Extract per-page text and a title for every page of the uploaded PDFs.

    Args:
        pdf_docs: Iterable of PDF file objects. Each must be readable by
            ``PyPDF2.PdfReader`` and expose a ``name`` attribute.

    Returns:
        A ``(content, metadata)`` tuple of parallel lists, one entry per page:
        ``content[i]`` is the page's extracted text (always a ``str``) and
        ``metadata[i]`` is ``{"title": "<file name> page <n>"}`` with a
        1-based page number.
    """
    content = []
    metadata = []

    for pdf in pdf_docs:
        pdf_reader = PyPDF2.PdfReader(pdf)
        for page_number, page in enumerate(pdf_reader.pages, start=1):
            # Coerce a falsy extraction result (e.g. a page with no text
            # layer) to "" so downstream splitting always receives strings.
            content.append(page.extract_text() or "")
            metadata.append({"title": f"{pdf.name} page {page_number}"})

    return content, metadata

def get_text_chunks(content, metadata, *, chunk_size=1024, chunk_overlap=256):
    """Split page texts into overlapping, token-sized LangChain documents.

    Args:
        content: List of page texts.
        metadata: List of metadata dicts parallel to ``content``; the entry
            for ``content[i]`` is attached to every chunk produced from it.
        chunk_size: Target chunk length in tokens, counted with a tiktoken
            encoder (default 1024).
        chunk_overlap: Number of tokens shared between neighbouring chunks
            (default 256).

    Returns:
        The list of split documents.

    NOTE(review): all-MiniLM-L6-v2, the embedding model used for indexing,
    is documented to truncate long inputs (about 256 word pieces). Much of a
    1024-token chunk may therefore not influence its embedding. Consider
    passing a smaller ``chunk_size``.
    """
    text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
    )
    split_docs = text_splitter.create_documents(content, metadatas=metadata)
    print(f"Split documents into {len(split_docs)} passages")
    return split_docs

def ingest_into_vectordb(split_docs, *, db_path='vectorstore/db_faiss'):
    """Embed document chunks, index them in FAISS, and save the index locally.

    Args:
        split_docs: LangChain documents to embed, as produced by
            ``get_text_chunks``.
        db_path: Directory passed to ``FAISS.save_local``
            (default ``'vectorstore/db_faiss'``).

    Returns:
        The FAISS vector store built from ``split_docs``.

    Raises:
        ValueError: If ``split_docs`` is empty, e.g. when no text could be
            extracted from the uploaded PDFs.

    NOTE(review): the default path is shared by every caller, so concurrent
    sessions overwrite each other's saved index.
    """
    if not split_docs:
        # FAISS cannot build an index from zero vectors. Fail early with an
        # actionable message instead of an opaque library error.
        raise ValueError("No text could be extracted from the uploaded documents.")

    # Embeddings are computed on the CPU.
    embeddings = HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-MiniLM-L6-v2",
        model_kwargs={'device': 'cpu'}
    )
    db = FAISS.from_documents(split_docs, embeddings)
    db.save_local(db_path)
    return db

def get_conversation_chain(vectordb):
    """Build a conversational retrieval chain over ``vectordb``.

    The chain is assembled from these parts:
      * an LLM: Groq's ``llama3-70b-8192`` model at temperature 0.25;
      * a retriever: ``vectordb.as_retriever()``;
      * memory: a ``ConversationBufferMemory`` that exposes past turns as
        message objects under ``'chat_history'``.

    The chain also returns the source documents, so the memory is told
    (``output_key='answer'``) which output to record.

    Returns:
        The configured ``ConversationalRetrievalChain``.
    """
    chain_kwargs = dict(
        llm=ChatGroq(model="llama3-70b-8192", temperature=0.25),
        retriever=vectordb.as_retriever(),
        memory=ConversationBufferMemory(
            memory_key='chat_history',
            return_messages=True,
            output_key='answer',
        ),
        return_source_documents=True,
    )
    chain = ConversationalRetrievalChain.from_llm(**chain_kwargs)

    print("Conversational Chain created for the LLM using the vector store")
    return chain

def validate_answer_against_sources(response_answer, source_documents):
    """Validate AI's response against source documents"""
    model = SentenceTransformer('all-MiniLM-L6-v2')
    similarity_threshold = 0.5  
    source_texts = [doc.page_content for doc in source_documents]

    answer_embedding = model.encode(response_answer, convert_to_tensor=True)
    source_embeddings = model.encode(source_texts, convert_to_tensor=True)

    cosine_scores = util.pytorch_cos_sim(answer_embedding, source_embeddings)

    return any(score.item() > similarity_threshold for score in cosine_scores[0])

def handle_userinput(user_question):
    """Send ``user_question`` to the conversation chain and render the chat.

    Calls ``st.session_state.conversation`` with the question and stores the
    returned ``'chat_history'`` in ``st.session_state.chat_history``. It then
    writes every message through the HTML chat templates: even positions use
    ``user_template`` and odd positions ``bot_template``. This assumes the
    history alternates human and AI turns.

    Message text is HTML-escaped before substitution because it is rendered
    with ``unsafe_allow_html=True``. It contains user input and model output
    derived from uploaded documents, so raw markup in it would otherwise be
    interpreted by the browser.
    """
    import html  # local import keeps this block self-contained

    response = st.session_state.conversation({'question': user_question})
    st.session_state.chat_history = response['chat_history']

    for i, message in enumerate(st.session_state.chat_history):
        template = user_template if i % 2 == 0 else bot_template
        st.write(
            template.replace("{{MSG}}", html.escape(message.content)),
            unsafe_allow_html=True,
        )

def _render_chat_history(chat_history):
    """Write ``chat_history`` with the HTML chat templates without querying the LLM.

    Even positions use ``user_template`` and odd positions ``bot_template``.
    This assumes the history alternates human and AI turns. Message text is
    HTML-escaped because it is rendered with ``unsafe_allow_html=True``.
    """
    import html  # local import keeps this block self-contained

    for i, message in enumerate(chat_history):
        template = user_template if i % 2 == 0 else bot_template
        st.write(
            template.replace("{{MSG}}", html.escape(message.content)),
            unsafe_allow_html=True,
        )

def main():
    """Run the Streamlit PDF question-answering app.

    Streamlit re-executes the script on widget interactions. The conversation
    chain, the chat history and the last question sent to the chain are
    therefore kept in ``st.session_state``. An unchanged question is only
    re-rendered, never re-sent to the LLM.
    """
    load_dotenv()

    st.set_page_config(
        page_title="PDF Insights AI",
        page_icon=":books:",
        layout="wide"
    )
    st.write(css, unsafe_allow_html=True)

    # Welcome section
    st.title("📚 PDF Insights AI")
    st.markdown("""
    ### Unlock the Knowledge in Your PDFs
    - 🤖 AI-powered document analysis
    - 💬 Ask questions about your uploaded documents
    - 📄 Support for multiple PDF files
    """)

    # Initialize session state
    if "conversation" not in st.session_state:
        st.session_state.conversation = None
    if "chat_history" not in st.session_state:
        st.session_state.chat_history = []
    if "last_question" not in st.session_state:
        # The question most recently answered by the chain; used to avoid
        # re-sending it on every rerun.
        st.session_state.last_question = None

    # File upload section
    with st.sidebar:
        st.header("📤 Upload Documents")
        uploaded_files = st.file_uploader(
            "Upload your PDFs here",
            type=['pdf'],
            accept_multiple_files=True,
            help="Upload PDF files to analyze. Max file size: 200MB"
        )

        # File validation: build a filtered list. Removing items from the
        # list being iterated would skip the file after each removed one.
        max_file_bytes = 200 * 1024 * 1024  # 200 MB
        pdf_docs = []
        for doc in uploaded_files or []:
            if doc.size > max_file_bytes:
                st.error(f"File {doc.name} is too large. Maximum file size is 200MB.")
            else:
                pdf_docs.append(doc)

        if st.button("Process Documents", type="primary"):
            if not pdf_docs:
                st.warning("Please upload at least one PDF file.")
            else:
                with st.spinner("Processing your documents..."):
                    try:
                        # Process documents
                        content, metadata = prepare_docs(pdf_docs)
                        split_docs = get_text_chunks(content, metadata)
                        vectorstore = ingest_into_vectordb(split_docs)
                        st.session_state.conversation = get_conversation_chain(vectorstore)
                        # The new chain starts with fresh memory. Drop the old
                        # history so the current question is asked again,
                        # this time against the new documents.
                        st.session_state.chat_history = []
                        st.session_state.last_question = None

                        st.success("Documents processed successfully! You can now ask questions.")
                    except Exception as e:
                        # UI boundary: keep the full traceback in the server
                        # log and show the user a short message.
                        traceback.print_exc()
                        st.error(f"An error occurred while processing documents: {str(e)}")

    # Question input section
    user_question = st.text_input(
        "📝 Ask a question about your documents",
        placeholder="What insights can you provide from these documents?"
    )

    if user_question:
        if st.session_state.conversation is None:
            st.warning("Please upload and process documents first.")
        elif user_question != st.session_state.last_question:
            # New question: query the chain (handle_userinput also renders the
            # chat). Record it only after success, so a failed call is
            # retried on the next rerun.
            handle_userinput(user_question)
            st.session_state.last_question = user_question
        else:
            # Rerun with an unchanged question (e.g. another widget changed):
            # re-render without re-querying the LLM or duplicating the turn
            # in memory.
            _render_chat_history(st.session_state.chat_history)

if __name__ == '__main__':
    main()