import streamlit as st
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.llms import Ollama
from langgraph.graph import StateGraph, END
from langchain.chains import RetrievalQA
import base64
import requests
from typing import TypedDict, Annotated, List
from langchain_core.messages import HumanMessage, AIMessage
import operator
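
# Runtime assumptions (not pinned in the original file): requires streamlit,
# langchain, langchain-community, langgraph, faiss-cpu, sentence-transformers,
# and requests, plus a local Ollama server with the llama3.2:1b model pulled
# (e.g. `ollama pull llama3.2:1b`).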

# Function to fetch GitHub repo data
def fetch_github_data(repo_url):
    # Tolerate a trailing slash or '.git' suffix on the URL
    parts = repo_url.rstrip('/').removesuffix('.git').split('/')
    owner, repo = parts[-2], parts[-1]
    
    headers = {'Accept': 'application/vnd.github.v3+json'}
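    # NOTE: unauthenticated GitHub API requests are rate-limited (about 60/hour
    # per IP); add an 'Authorization: Bearer <token>' header for heavier use.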
    base_url = 'https://api.github.com'
    
    content = ""
    repo_response = requests.get(f"{base_url}/repos/{owner}/{repo}", headers=headers)
    if repo_response.status_code == 200:
        repo_data = repo_response.json()
        content += f"Description: {repo_data.get('description', '')}\n"
    
    readme_response = requests.get(f"{base_url}/repos/{owner}/{repo}/readme", headers=headers)
    if readme_response.status_code == 200:
        readme_data = readme_response.json()
        content += base64.b64decode(readme_data['content']).decode('utf-8') + "\n"
    
    return content

# Function to create vector store
def create_vector_store(text_data):
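    # Split the repo text into overlapping chunks (the 200-char overlap keeps
    # context across boundaries), embed them with all-MiniLM-L6-v2 (a small,
    # CPU-friendly sentence-transformer), and index them in an in-memory FAISS
    # store; nothing is persisted between runs.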
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=1000,
        chunk_overlap=200
    )
    chunks = text_splitter.split_text(text_data)
    
    embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
    vector_store = FAISS.from_texts(chunks, embeddings)
    return vector_store

# Define the state for LangGraph
class GraphState(TypedDict):
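    # operator.add makes LangGraph append new messages instead of replacing
    # them; 'generation' holds the running draft, 'search_count' caps the loop.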
    messages: Annotated[List[HumanMessage | AIMessage], operator.add]
    generation: str
    search_count: int

# Node to perform initial/additional vector store search
def search_vector_store(state: GraphState):
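    # The LLM and vector store are read from Streamlit session state because
    # LangGraph nodes receive only the graph state dict.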
    llm = st.session_state.llm
    vector_store = st.session_state.vector_store
    question = state["messages"][0].content  # Original question
    current_generation = state["generation"]
    search_count = state["search_count"] + 1
    
    # Modify query slightly for additional searches
    if search_count > 1:
        query = f"{question} (additional context for: {current_generation})"
    else:
        query = question
    
    retriever = vector_store.as_retriever()
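    # chain_type="stuff" concatenates all retrieved chunks into a single prompt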
    qa_chain = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=retriever
    )
    # .invoke replaces the deprecated Chain.run API; the answer is under "result"
    response = qa_chain.invoke({"query": query})["result"]
    
    # Append new info to existing generation
    new_generation = f"{current_generation}\nAdditional Info: {response}" if current_generation else response
    
    return {
        "messages": [AIMessage(content=new_generation)],
        "generation": new_generation,
        "search_count": search_count
    }

# Node to evaluate sufficiency of the answer
def evaluate_sufficiency(state: GraphState):
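    # Ask the model to judge its own draft; route_next_step parses the Yes/No.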
    llm = st.session_state.llm
    question = state["messages"][0].content  # Original question
    current_generation = state["generation"]
    
    prompt = (
        f"Given the question '{question}' and the current information:\n'{current_generation}'\n"
        f"Is this sufficient to fully answer the question? Respond with 'Yes' or 'No'."
    )
    decision = llm.invoke(prompt).strip()
    
    return {
        "messages": [AIMessage(content=f"Sufficiency check: {decision}")],
        "generation": current_generation,
        "search_count": state["search_count"]
    }

# Node to finalize the answer
def finalize_answer(state: GraphState):
    llm = st.session_state.llm
    current_info = state["generation"]
    question = state["messages"][0].content  # Original question
    prompt = (
        f"Given the question '{question}' and the gathered information:\n'{current_info}'\n"
        f"Write a single, self-contained answer, as if responding for the first time."
    )
    final_answer = llm.invoke(prompt).strip()
    return {
        "messages": [AIMessage(content=f"Final Answer: {final_answer}")],
        "generation": final_answer,
        "search_count": state["search_count"]
    }

# Function to decide next step
def route_next_step(state: GraphState):
    last_message = state["messages"][-1].content
    search_count = state["search_count"]
    
    if "Sufficiency check: Yes" in last_message:
        return "finalize_answer"
    elif search_count >= 5:
        return "finalize_answer"  # Max 5 iterations
    else:
        return "search_vector_store"

# Build the LangGraph workflow
def build_graph():
    workflow = StateGraph(GraphState)
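    # Flow: search -> evaluate -> (search again, up to 5 passes | finalize) -> END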
    
    workflow.add_node("search_vector_store", search_vector_store)
    workflow.add_node("evaluate_sufficiency", evaluate_sufficiency)
    workflow.add_node("finalize_answer", finalize_answer)
    
    workflow.set_entry_point("search_vector_store")
    workflow.add_edge("search_vector_store", "evaluate_sufficiency")
    workflow.add_conditional_edges(
        "evaluate_sufficiency",
        route_next_step,
        {
            "search_vector_store": "search_vector_store",
            "finalize_answer": "finalize_answer"
        }
    )
    workflow.add_edge("finalize_answer", END)
    
    return workflow.compile()

# Streamlit app
def main():
    st.title("Project Resilience Q&A Assistant")
    st.write("Ask anything about Project Resilience - answers always come from repo data!")
    
    # Hardcoded GitHub URL
    github_url = 'https://github.com/Project-Resilience/platform'

    # Build everything once per session; fetching the repo on every Streamlit
    # rerun would be slow and would quickly exhaust the GitHub rate limit
    if 'vector_store' not in st.session_state:
        repo_data = fetch_github_data(github_url)
        st.session_state.vector_store = create_vector_store(repo_data)
        st.session_state.llm = Ollama(model="llama3.2:1b", temperature=0.7)
        st.session_state.graph = build_graph()

    # Question input
    question = st.text_input("Ask a question about Project Resilience")
    
    # Get and display answer
    if question and st.session_state.graph:
        with st.spinner("Generating answer..."):
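            # Seed the graph with the question; nodes append messages and
            # build up the draft answer as the loop runs.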
            initial_state = {
                "messages": [HumanMessage(content=question)],
                "generation": "",
                "search_count": 0
            }
            result = st.session_state.graph.invoke(initial_state)
            final_answer = result["generation"]
            st.write("**Answer:**")
            st.write(final_answer)

    # Sidebar with additional info
    st.sidebar.header("Project Resilience Assistant")
    st.sidebar.write("""
    Project Resilience is a platform for decision makers, data scientists, and the public.

    Project Resilience, initiated under the Global Initiative on AI and Data Commons, is a collaborative effort to build a public AI utility that could inform and help address global decision-augmentation challenges.

    The project empowers a global community of innovators, thought leaders, and the public to enhance and use a shared collection of data and AI tools, improving preparedness, intervention, and response to environmental, health, information, or economic threats in our communities. It also supports broader efforts toward achieving the Sustainable Development Goals (SDGs).
    """)

if __name__ == "__main__":
    main()