import io
import os
import torch
import streamlit as st
from PyPDF2 import PdfReader
from langchain.text_splitter import CharacterTextSplitter
from langchain.chains.question_answering import load_qa_chain
from langchain_huggingface import HuggingFaceEmbeddings, HuggingFacePipeline
from langchain_community.vectorstores import FAISS
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
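# GemmaPaperQA: a Streamlit app that answers questions about an uploaded PDF.
# Pipeline: extract text with PyPDF2 -> split into chunks -> embed into a FAISS
# vector store -> answer questions with a local google/gemma-2-2b pipeline
# wired into a LangChain QA chain.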


# Shared state (the knowledge base and QA chain) is kept in st.session_state rather than module-level globals

# Load a PDF file and extract its text
def load_pdf(pdf_file):
    pdf_reader = PdfReader(pdf_file)
    # extract_text() can return None for pages with no extractable text
    text = "".join(page.extract_text() or "" for page in pdf_reader.pages)
    return text

# Split the text into 1000-character chunks with 200-character overlap
def split_text(text):
    text_splitter = CharacterTextSplitter(
        separator="\n", 
        chunk_size=1000, 
        chunk_overlap=200, 
        length_function=len
    )
    return text_splitter.split_text(text)

# Build a FAISS vector store from the text chunks
def create_knowledge_base(chunks):
    model_name = "sentence-transformers/all-mpnet-base-v2"  # sentence embedding model
    embeddings = HuggingFaceEmbeddings(model_name=model_name)
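    # from_texts embeds every chunk and builds an in-memory FAISS index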
    return FAISS.from_texts(chunks, embeddings)

# Load the Hugging Face model
def load_model():
    model_name = "google/gemma-2-2b"  # Hugging Face model ID
    access_token = os.getenv("HF_TOKEN")
    
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_name, token=access_token, clean_up_tokenization_spaces=False)
        model = AutoModelForCausalLM.from_pretrained(model_name, token=access_token)
        
        # Run on the GPU (device 0) if available, otherwise on the CPU
        device = 0 if torch.cuda.is_available() else -1
        
        # Enable sampling so the low temperature setting actually takes effect
        return pipeline(
            "text-generation", 
            model=model, 
            tokenizer=tokenizer, 
            max_new_tokens=150, 
            temperature=0.1, 
            do_sample=True,
            device=device
        )
    
    except Exception as e:
        print(f"Error loading model: {e}")
        return None

# Answer a question using the retrieved context and the QA chain
def get_response_from_model(prompt):
    try:
        if "knowledge_base" not in st.session_state:
            return "No PDF has been uploaded yet."
        if "qa_chain" not in st.session_state:
            return "QA chain is not initialized."

        # Retrieve the chunks most similar to the question from the FAISS index
        docs = st.session_state.knowledge_base.similarity_search(prompt)
        print("docs:", docs)      # debug: retrieved context
        print("prompt:", prompt)  # debug: user question
        # Call the chain's invoke() method, passing the retrieved docs as input_documents
        result = st.session_state.qa_chain.invoke({
            "input_documents": docs,
            "question": prompt
        })
        # invoke() returns a dict; the generated answer is stored under "output_text"
        response = result.get("output_text", "") if isinstance(result, dict) else str(result)
        if "Helpful Answer:" in response:
            response = response.split("Helpful Answer:", 1)[1].strip()
        return response
    except Exception as e:
        return f"Error: {str(e)}"


# Page UI
def main():
    st.title("Welcome to GemmaPaperQA")
    
    # PDF upload section
    with st.expander("Upload Your Paper", expanded=True):
        paper = st.file_uploader("Upload Here!", type="pdf", label_visibility="hidden")
        
        if paper:
            st.write(f"Upload complete! File name: {paper.name}")
            
            # Check the file size (available without moving the file pointer)
            file_size = paper.size
            if file_size > 10 * 1024 * 1024:  # 10 MB limit
                st.error("File is too large! Please upload a file smaller than 10MB.")
                return

            # Preview the extracted PDF text
            with st.spinner('Processing PDF...'):
                try:
                    paper.seek(0)
                    contents = paper.read()
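                    # Wrap the raw bytes in a file-like object for PdfReader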
                    pdf_file = io.BytesIO(contents)
                    text = load_pdf(pdf_file)

                    if len(text.strip()) == 0:
                        st.error("The PDF appears to have no extractable text. Please check the file and try again.")
                        return

                    st.text_area("Preview of extracted text", text[:1000], height=200)
                    st.write(f"Total characters extracted: {len(text)}")

                    if st.button("Create Knowledge Base"):
                        chunks = split_text(text)
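                        # Keep the knowledge base in st.session_state so it survives Streamlit reruns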
                        st.session_state.knowledge_base = create_knowledge_base(chunks)
                        print("knowledge_base:", st.session_state.knowledge_base)

                        if st.session_state.knowledge_base is None:
                            st.error("Failed to create knowledge base.")
                            return

                        # load_model() handles its own exceptions and returns None on failure
                        pipe = load_model()
                        if pipe is None:
                            st.error("Failed to load the model. Check the console log and your HF_TOKEN.")
                            return
                        llm = HuggingFacePipeline(pipeline=pipe)
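                        # map_rerank prompts the LLM on each retrieved document separately,
                        # scores each candidate answer, and returns the highest-scoring one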
                        st.session_state.qa_chain = load_qa_chain(llm, chain_type="map_rerank")

                        st.success("Knowledge base created! You can now ask questions.")

                except Exception as e:
                    st.error(f"Failed to process the PDF: {str(e)}")

    # Question-and-answer section
    if "knowledge_base" in st.session_state and "qa_chain" in st.session_state:
        with st.expander("Ask Questions", expanded=True):
            prompt = st.text_input("Chat here!")
            if prompt:
                response = get_response_from_model(prompt)
                if response:
                    st.write(f"**Assistant**: {response}")

# Run the app
if __name__ == "__main__":
    main()