import base64
import operator
from typing import TypedDict, Annotated, List

import requests
import streamlit as st
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.llms import Ollama
from langchain.chains import RetrievalQA
from langchain_core.messages import HumanMessage, AIMessage
from langgraph.graph import StateGraph, END
# Function to fetch a GitHub repo's description and README via the GitHub REST API
def fetch_github_data(repo_url):
    parts = repo_url.rstrip('/').split('/')
    owner, repo = parts[-2], parts[-1]
    headers = {'Accept': 'application/vnd.github.v3+json'}
    base_url = 'https://api.github.com'
    content = ""

    repo_response = requests.get(f"{base_url}/repos/{owner}/{repo}", headers=headers)
    if repo_response.status_code == 200:
        repo_data = repo_response.json()
        content += f"Description: {repo_data.get('description', '')}\n"

    readme_response = requests.get(f"{base_url}/repos/{owner}/{repo}/readme", headers=headers)
    if readme_response.status_code == 200:
        readme_data = readme_response.json()
        content += base64.b64decode(readme_data['content']).decode('utf-8') + "\n"

    return content
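
# Note: unauthenticated GitHub API calls are rate-limited (60 requests/hour
# per IP). A minimal sketch of optional token auth, assuming a GITHUB_TOKEN
# environment variable (not part of the original app):
#
#   import os
#   token = os.environ.get("GITHUB_TOKEN")
#   if token:
#       headers['Authorization'] = f"Bearer {token}"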
# Function to split the repo text into overlapping chunks, embed them, and index with FAISS
def create_vector_store(text_data):
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=1000,
        chunk_overlap=200
    )
    chunks = text_splitter.split_text(text_data)
    embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
    vector_store = FAISS.from_texts(chunks, embeddings)
    return vector_store
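
# The index is rebuilt on every cold start. A minimal persistence sketch using
# LangChain's FAISS save/load helpers (the "faiss_index" path is an assumption):
#
#   vector_store.save_local("faiss_index")
#   vector_store = FAISS.load_local(
#       "faiss_index", embeddings, allow_dangerous_deserialization=True
#   )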
# Define the state for LangGraph
class GraphState(TypedDict):
    messages: Annotated[List[HumanMessage | AIMessage], operator.add]
    generation: str
    search_count: int
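
# The Annotated reducer (operator.add) tells LangGraph to concatenate the
# message lists returned by nodes rather than overwrite them: if one node
# returns {"messages": [m1]} and a later node returns {"messages": [m2]},
# the accumulated state holds [m1, m2]. Plain keys like "generation" are
# simply replaced by each node's return value.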
# Node to perform the initial/additional vector store search
def search_vector_store(state: GraphState):
    llm = st.session_state.llm
    vector_store = st.session_state.vector_store
    question = state["messages"][0].content  # Original question
    current_generation = state["generation"]
    search_count = state["search_count"] + 1

    # Rephrase the query on repeat visits to pull in context earlier passes missed
    if search_count > 1:
        query = f"{question} (additional context for: {current_generation})"
    else:
        query = question

    retriever = vector_store.as_retriever()
    qa_chain = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=retriever
    )
    response = qa_chain.invoke({"query": query})["result"]

    # Append the new information to the existing generation
    new_generation = f"{current_generation}\nAdditional Info: {response}" if current_generation else response
    return {
        "messages": [AIMessage(content=new_generation)],
        "generation": new_generation,
        "search_count": search_count
    }
# Node to evaluate whether the accumulated answer is sufficient
def evaluate_sufficiency(state: GraphState):
    llm = st.session_state.llm
    question = state["messages"][0].content  # Original question
    current_generation = state["generation"]

    prompt = (
        f"Given the question '{question}' and the current information:\n'{current_generation}'\n"
        f"Is this sufficient to fully answer the question? Respond with 'Yes' or 'No'."
    )
    # Normalize the reply so the router's string match is robust to answers
    # like "Yes." or "yes, because ..."
    decision = "Yes" if llm.invoke(prompt).strip().lower().startswith("yes") else "No"

    return {
        "messages": [AIMessage(content=f"Sufficiency check: {decision}")],
        "generation": current_generation,
        "search_count": state["search_count"]
    }
# Node to finalize the answer
def finalize_answer(state: GraphState):
    llm = st.session_state.llm
    current_info = state["generation"]
    question = state["messages"][0].content  # Original question

    prompt = (
        f"Given the question '{question}' and the current information:\n'{current_info}'\n"
        f"Answer the question as if you are answering it for the first time."
    )
    final_answer = llm.invoke(prompt).strip()

    return {
        "messages": [AIMessage(content=f"Final Answer: {final_answer}")],
        "generation": final_answer,
        "search_count": state["search_count"]
    }
# Conditional edge: finalize once the answer is judged sufficient or after
# at most 5 search iterations; otherwise search again
def route_next_step(state: GraphState):
    last_message = state["messages"][-1].content
    search_count = state["search_count"]
    if "Sufficiency check: Yes" in last_message:
        return "finalize_answer"
    elif search_count >= 5:
        return "finalize_answer"  # Cap at 5 iterations
    else:
        return "search_vector_store"
# Build the LangGraph workflow
def build_graph():
    workflow = StateGraph(GraphState)
    workflow.add_node("search_vector_store", search_vector_store)
    workflow.add_node("evaluate_sufficiency", evaluate_sufficiency)
    workflow.add_node("finalize_answer", finalize_answer)

    workflow.set_entry_point("search_vector_store")
    workflow.add_edge("search_vector_store", "evaluate_sufficiency")
    workflow.add_conditional_edges(
        "evaluate_sufficiency",
        route_next_step,
        {
            "search_vector_store": "search_vector_store",
            "finalize_answer": "finalize_answer"
        }
    )
    workflow.add_edge("finalize_answer", END)
    return workflow.compile()
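
# The compiled graph forms a retrieve-and-check loop:
#
#   search_vector_store --> evaluate_sufficiency --(Yes, or 5th pass)--> finalize_answer --> END
#            ^                        |
#            +----------(No)---------+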
# Streamlit app
def main():
    st.title("Project Resilience Q&A Assistant")
    st.write("Ask anything about Project Resilience - answers always come from repo data!")

    # Hardcoded GitHub URL
    github_url = 'https://github.com/Project-Resilience/platform'

    # Initialize session state once; fetching the repo and rebuilding the index
    # on every Streamlit rerun would be slow and waste API calls
    if 'vector_store' not in st.session_state:
        repo_data = fetch_github_data(github_url)
        st.session_state.vector_store = create_vector_store(repo_data)
        st.session_state.llm = Ollama(model="llama3.2:1b", temperature=0.7)
        st.session_state.graph = build_graph()

    # Question input
    question = st.text_input("Ask a question about Project Resilience")

    # Get and display answer
    if question and st.session_state.graph:
        with st.spinner("Generating answer..."):
            initial_state = {
                "messages": [HumanMessage(content=question)],
                "generation": "",
                "search_count": 0
            }
            result = st.session_state.graph.invoke(initial_state)
            final_answer = result["generation"]
            st.write("**Answer:**")
            st.write(final_answer)
    # Sidebar with additional info
    st.sidebar.header("Project Resilience Assistant")
    st.sidebar.write("""
    Project Resilience is a platform for decision makers, data scientists, and the public.

    Initiated under the Global Initiative on AI and Data Commons, Project Resilience is a collaborative effort to build a public AI utility that can inform and help address global decision-augmentation challenges.

    The project empowers a global community of innovators, thought leaders, and the public to enhance and use a shared collection of data and AI tools, improving preparedness, intervention, and response to environmental, health, information, or economic threats in our communities. It also supports broader efforts toward achieving the Sustainable Development Goals (SDGs).
    """)
if __name__ == "__main__":
    main()
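
# To run locally (assuming an Ollama server with the llama3.2:1b model pulled):
#   streamlit run app.py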