import streamlit as st
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.llms import Ollama
from langgraph.graph import StateGraph, END
from langchain.chains import RetrievalQA
import requests
from typing import TypedDict, Annotated, List
from langchain_core.messages import HumanMessage, AIMessage
import operator

# Function to fetch GitHub repo data
def fetch_github_data(repo_url, *, timeout=10):
    """Fetch a GitHub repository's description and README as plain text.

    Args:
        repo_url: Repository URL of the form ``https://github.com/<owner>/<repo>``.
            A trailing slash and/or a ``.git`` suffix is tolerated.
        timeout: Seconds to wait on each GitHub API request before
            ``requests`` raises a timeout error.

    Returns:
        A string with a ``Description: ...`` line (when the repository
        metadata request returns 200) followed by the decoded README (when
        the README request returns 200). Non-200 responses are skipped, so
        the result may be an empty string.
    """
    import base64

    # Tolerate ".../repo/" and ".../repo.git" so that owner and repo are
    # always the last two path segments.
    parts = repo_url.rstrip('/').removesuffix('.git').split('/')
    owner, repo = parts[-2], parts[-1]

    headers = {'Accept': 'application/vnd.github.v3+json'}
    repo_api_url = f"https://api.github.com/repos/{owner}/{repo}"

    content = ""
    # Without a timeout, requests may wait indefinitely on a stalled
    # connection and freeze the app.
    repo_response = requests.get(repo_api_url, headers=headers, timeout=timeout)
    if repo_response.status_code == 200:
        # The API sends "description": null for repos without one; using
        # .get(key, '') would then yield None and render "Description: None".
        description = repo_response.json().get('description') or ''
        content += f"Description: {description}\n"

    readme_response = requests.get(f"{repo_api_url}/readme", headers=headers, timeout=timeout)
    if readme_response.status_code == 200:
        readme_data = readme_response.json()
        # The README body is base64-encoded. Decode leniently so a stray
        # non-UTF-8 byte does not abort the whole fetch.
        readme_text = base64.b64decode(readme_data['content']).decode('utf-8', errors='replace')
        content += readme_text + "\n"

    return content

# Function to create vector store
def create_vector_store(
    text_data,
    *,
    chunk_size=1000,
    chunk_overlap=200,
    model_name="sentence-transformers/all-MiniLM-L6-v2",
):
    """Split *text_data* into overlapping chunks and index them in a FAISS store.

    Args:
        text_data: Raw text to index.
        chunk_size: Maximum characters per chunk passed to the splitter.
        chunk_overlap: Characters shared between consecutive chunks.
        model_name: HuggingFace embedding model used to embed the chunks.

    Returns:
        A FAISS vector store built from the chunks.

    Raises:
        ValueError: If *text_data* produces no chunks (e.g. it is empty or
            whitespace-only). FAISS cannot build an index from zero texts
            and would otherwise fail with an unhelpful error.
    """
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
    )
    chunks = text_splitter.split_text(text_data)
    if not chunks:
        raise ValueError("No text to index: text_data produced no chunks")

    embeddings = HuggingFaceEmbeddings(model_name=model_name)
    return FAISS.from_texts(chunks, embeddings)

# Define the state for LangGraph
class GraphState(TypedDict):
    """State shared by the LangGraph nodes of the Q&A workflow.

    ``messages`` is annotated with ``operator.add`` as its reducer, so the
    list a node returns under ``messages`` is concatenated onto the existing
    history rather than replacing it. ``generation`` and ``search_count``
    carry no reducer; each node returns a full value for them.
    """

    # Conversation history. messages[0] is the user's question (main() seeds
    # the initial state with it); each node appends an AIMessage.
    messages: Annotated[List[HumanMessage | AIMessage], operator.add]
    # Answer text accumulated so far; finalize_answer sets it to the final answer.
    generation: str
    # Number of vector-store searches performed so far (capped by route_next_step).
    search_count: int

# Node to perform initial/additional vector store search
def search_vector_store(state: GraphState):
    """Run one retrieval-augmented QA pass over the repository vector store.

    On the first pass the user's question is sent verbatim. On later passes
    the question is augmented with the answer gathered so far, and the new
    response is appended to that answer.

    Args:
        state: Current graph state. ``messages[0]`` is taken to be the
            user's original question.

    Returns:
        A partial state update with:
        - a new AIMessage holding the accumulated generation;
        - the accumulated ``generation``;
        - the incremented ``search_count``.
    """
    llm = st.session_state.llm
    vector_store = st.session_state.vector_store
    question = state["messages"][0].content  # Original question
    current_generation = state["generation"]
    search_count = state["search_count"] + 1

    # Modify the query on additional searches so retrieval can surface
    # context related to what has been answered so far.
    if search_count > 1:
        query = f"{question} (additional context for: {current_generation})"
    else:
        query = question

    qa_chain = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=vector_store.as_retriever(),
    )
    # Chain.run is deprecated. invoke() takes the input dict (RetrievalQA's
    # input key is "query") and returns the outputs; the answer is under "result".
    response = qa_chain.invoke({"query": query})["result"]

    # Append new info to the existing generation.
    new_generation = f"{current_generation}\nAdditional Info: {response}" if current_generation else response

    return {
        "messages": [AIMessage(content=new_generation)],
        "generation": new_generation,
        "search_count": search_count,
    }

# Node to evaluate sufficiency of the answer
def evaluate_sufficiency(state: GraphState):
    """Ask the LLM whether the gathered information fully answers the question.

    The LLM's free-form reply is normalised to a strict ``Yes``/``No`` verdict
    before it is recorded. ``route_next_step`` checks for the exact,
    case-sensitive substring ``"Sufficiency check: Yes"``, so raw replies such
    as ``yes`` or ``**Yes**`` would otherwise be treated as "No" and force
    needless extra searches.

    Args:
        state: Current graph state. ``messages[0]`` is taken to be the
            user's original question.

    Returns:
        A partial state update whose new message is
        ``"Sufficiency check: Yes"`` or ``"Sufficiency check: No"``.
        ``generation`` and ``search_count`` are passed through unchanged.
    """
    import re

    llm = st.session_state.llm
    question = state["messages"][0].content  # Original question
    current_generation = state["generation"]

    prompt = (
        f"Given the question '{question}' and the current information:\n'{current_generation}'\n"
        f"Is this sufficient to fully answer the question? Respond with 'Yes' or 'No'."
    )
    decision = llm.invoke(prompt).strip()
    # Treat the reply as affirmative when its first word is "yes", ignoring
    # case and leading punctuation/markup (quotes, asterisks, underscores).
    # Anything else counts as "No".
    verdict = "Yes" if re.match(r"[\W_]*yes\b", decision, re.IGNORECASE) else "No"

    return {
        "messages": [AIMessage(content=f"Sufficiency check: {verdict}")],
        "generation": current_generation,
        "search_count": state["search_count"],
    }

# Node to finalize the answer
def finalize_answer(state: GraphState):
    """Have the LLM turn the accumulated information into one final answer.

    The original question (``messages[0]``) and the gathered text
    (``generation``) are placed in a prompt. The LLM's stripped reply becomes
    the new ``generation`` and is also recorded as a ``Final Answer:``
    message. ``search_count`` is passed through unchanged.
    """
    model = st.session_state.llm
    gathered_info = state["generation"]
    original_question = state["messages"][0].content

    final_prompt = (
        f"Given the question '{original_question}' and the current information:\n"
        f"'{gathered_info}'\n"
        "Answer the question as you are answering for the first time"
    )
    answer = model.invoke(final_prompt).strip()

    update = {}
    update["messages"] = [AIMessage(content="Final Answer: " + answer)]
    update["generation"] = answer
    update["search_count"] = state["search_count"]
    return update

# Function to decide next step
def route_next_step(state: GraphState):
    """Pick the node that follows ``evaluate_sufficiency``.

    Returns ``"finalize_answer"`` when the latest message carries a "Yes"
    sufficiency verdict, or once 5 searches have been performed. Otherwise
    returns ``"search_vector_store"`` to gather more context.
    """
    latest = state["messages"][-1].content
    answered = "Sufficiency check: Yes" in latest
    out_of_budget = state["search_count"] >= 5  # Max 5 iterations

    if not (answered or out_of_budget):
        return "search_vector_store"
    return "finalize_answer"

# Build the LangGraph workflow
def build_graph():
    """Assemble and compile the Q&A workflow.

    Flow: search_vector_store -> evaluate_sufficiency, then (via
    route_next_step) either another search_vector_store pass or
    finalize_answer, which ends the run.
    """
    graph = StateGraph(GraphState)

    nodes = (
        ("search_vector_store", search_vector_store),
        ("evaluate_sufficiency", evaluate_sufficiency),
        ("finalize_answer", finalize_answer),
    )
    for node_name, node_fn in nodes:
        graph.add_node(node_name, node_fn)

    graph.set_entry_point("search_vector_store")
    graph.add_edge("search_vector_store", "evaluate_sufficiency")

    # The router returns node names directly, so each route maps to itself.
    route_map = {target: target for target in ("search_vector_store", "finalize_answer")}
    graph.add_conditional_edges("evaluate_sufficiency", route_next_step, route_map)
    graph.add_edge("finalize_answer", END)

    return graph.compile()

# Streamlit app
def main():
    """Streamlit entry point: build the Q&A pipeline once per session, then answer questions.

    Streamlit re-runs the script on every widget interaction. The GitHub
    fetch, vector store, LLM client and compiled graph are therefore created
    only when ``st.session_state`` does not hold them yet.
    """
    st.title("Project Resilience Q&A Assistant")
    st.write("Ask anything about Project Resilience - answers always come from repo data!")
    # Rendered up front so the sidebar still appears if initialisation
    # below stops the script.
    _render_sidebar()

    # Initialize session state
    if 'vector_store' not in st.session_state:
        # Hardcoded GitHub URL
        github_url = 'https://github.com/Project-Resilience/platform'
        # Fetch only during first-time initialisation. Each fetch costs two
        # GitHub API calls, which are rate-limited for unauthenticated
        # clients, and the data is only needed to build the vector store.
        repo_data = fetch_github_data(github_url)
        if not repo_data.strip():
            # Nothing to index: a vector store cannot be built from zero chunks.
            st.error("Could not load any data from the Project Resilience repository. Please try again later.")
            st.stop()

        # Build everything before touching session state, so a failure
        # part-way through cannot leave a half-initialised session that
        # later reruns would skip re-initialising.
        vector_store = create_vector_store(repo_data)
        llm = Ollama(model="llama3.2:1b", temperature=0.7)
        graph = build_graph()
        st.session_state.vector_store = vector_store
        st.session_state.llm = llm
        st.session_state.graph = graph

    # Question input
    question = st.text_input("Ask a question about Project Resilience")

    # Get and display answer
    if question and st.session_state.graph:
        with st.spinner("Generating answer..."):
            initial_state = {
                "messages": [HumanMessage(content=question)],
                "generation": "",
                "search_count": 0,
            }
            result = st.session_state.graph.invoke(initial_state)
            final_answer = result["generation"]
            st.write("**Answer:**")
            st.write(final_answer)


def _render_sidebar():
    """Render the static Project Resilience description in the sidebar."""
    st.sidebar.header("Project Resilience Assistant")
    st.sidebar.write("""
    Project Resilience's platform for decision makers, data scientists and the public.

    Project Resilience, initiated under the Global Initiative on AI and Data Commons, is a collaborative effort to build a public AI utility that could inform and help address global decision-augmentation challenges.

    The project empowers a global community of innovators, thought leaders, and the public to enhance and use a shared collection of data and AI tools, improving preparedness, intervention, and response to environmental, health, information, or economic threats in our communities. It also supports broader efforts toward achieving the Sustainable Development Goals (SDGs).
    """)

# Render the app when this file is executed as the main script
# (as happens under `streamlit run`).
if __name__ == "__main__":
    main()