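"""Streamlit Q&A assistant for the Project Resilience GitHub repository.

Pipeline: fetch the repo description and README, chunk and embed them into a
FAISS vector store, then answer questions with a LangGraph loop that retrieves,
checks whether the answer is sufficient, and retries (up to 5 passes) before
finalizing with a local Ollama model.
"""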
import base64
import operator
from typing import Annotated, List, TypedDict

import requests
import streamlit as st
from langchain.chains import RetrievalQA
from langchain.text_splitter import RecursiveCharacterTextSplitter
# Community imports assume langchain >= 0.1; older releases exposed these
# under langchain.embeddings and langchain.vectorstores instead.
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.llms import Ollama
from langchain_community.vectorstores import FAISS
from langchain_core.messages import AIMessage, HumanMessage
from langgraph.graph import END, StateGraph
# Fetch a GitHub repo's description and README via the GitHub REST API
def fetch_github_data(repo_url):
    parts = repo_url.rstrip('/').split('/')
    owner, repo = parts[-2], parts[-1]
    headers = {'Accept': 'application/vnd.github.v3+json'}
    base_url = 'https://api.github.com'
    content = ""
    repo_response = requests.get(f"{base_url}/repos/{owner}/{repo}", headers=headers)
    if repo_response.status_code == 200:
        repo_data = repo_response.json()
        content += f"Description: {repo_data.get('description') or ''}\n"
    readme_response = requests.get(f"{base_url}/repos/{owner}/{repo}/readme", headers=headers)
    if readme_response.status_code == 200:
        readme_data = readme_response.json()
        # README content arrives base64-encoded in the JSON payload
        content += base64.b64decode(readme_data['content']).decode('utf-8') + "\n"
    return content
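# Example (hypothetical output, assuming network access to a public repo):
#   text = fetch_github_data("https://github.com/Project-Resilience/platform")
#   # -> "Description: ...\n# Project Resilience\n..."
# Note: unauthenticated GitHub API calls are rate-limited (60 requests/hour
# per IP at the time of writing); add an Authorization header if that bites.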
# Split the repo text into overlapping chunks and index them in a FAISS store
def create_vector_store(text_data):
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=1000,
        chunk_overlap=200
    )
    chunks = text_splitter.split_text(text_data)
    embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
    vector_store = FAISS.from_texts(chunks, embeddings)
    return vector_store
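# all-MiniLM-L6-v2 is a small sentence-transformer producing 384-dimensional
# vectors; it is downloaded from the Hugging Face Hub on first use, so the
# initial run needs network access.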
# Shared state passed between LangGraph nodes
class GraphState(TypedDict):
    # operator.add tells LangGraph to append each node's messages to the list
    # instead of overwriting it (the `|` union syntax requires Python 3.10+)
    messages: Annotated[List[HumanMessage | AIMessage], operator.add]
    generation: str    # answer text accumulated across search passes
    search_count: int  # number of retrieval passes performed so far
# Node: perform an initial or follow-up search over the vector store
def search_vector_store(state: GraphState):
    llm = st.session_state.llm
    vector_store = st.session_state.vector_store
    question = state["messages"][0].content  # original question
    current_generation = state["generation"]
    search_count = state["search_count"] + 1
    # Reword the query on follow-up passes so retrieval surfaces new context
    if search_count > 1:
        query = f"{question} (additional context for: {current_generation})"
    else:
        query = question
    retriever = vector_store.as_retriever()
    qa_chain = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=retriever
    )
    response = qa_chain.run(query)
    # Append the new information to whatever has been gathered so far
    new_generation = f"{current_generation}\nAdditional Info: {response}" if current_generation else response
    return {
        "messages": [AIMessage(content=new_generation)],
        "generation": new_generation,
        "search_count": search_count
    }
# Node: ask the LLM whether the gathered information answers the question
def evaluate_sufficiency(state: GraphState):
    llm = st.session_state.llm
    question = state["messages"][0].content  # original question
    current_generation = state["generation"]
    prompt = (
        f"Given the question '{question}' and the current information:\n'{current_generation}'\n"
        f"Is this sufficient to fully answer the question? Respond with 'Yes' or 'No'."
    )
    decision = llm.invoke(prompt).strip()
    return {
        "messages": [AIMessage(content=f"Sufficiency check: {decision}")],
        "generation": current_generation,
        "search_count": state["search_count"]
    }
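# The decision is wrapped in an AIMessage because route_next_step (below)
# reads the latest message in state to pick the next node.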
# Node: produce the final answer from the accumulated information
def finalize_answer(state: GraphState):
    llm = st.session_state.llm
    current_info = state["generation"]
    question = state["messages"][0].content  # original question
    prompt = (
        f"Given the question '{question}' and the current information:\n'{current_info}'\n"
        f"Answer the question as if you were answering it for the first time."
    )
    final_answer = llm.invoke(prompt).strip()
    return {
        "messages": [AIMessage(content=f"Final Answer: {final_answer}")],
        "generation": final_answer,
        "search_count": state["search_count"]
    }
# Conditional edge: decide whether to search again or finalize
def route_next_step(state: GraphState):
    last_message = state["messages"][-1].content
    search_count = state["search_count"]
    if "Sufficiency check: Yes" in last_message:
        return "finalize_answer"
    elif search_count >= 5:
        return "finalize_answer"  # cap at 5 retrieval passes
    else:
        return "search_vector_store"
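# Note: any reply that does not contain "Sufficiency check: Yes" (including a
# malformed LLM response) falls through to another search, so the search_count
# cap is what guarantees the loop terminates.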
# Build the LangGraph workflow
def build_graph():
    workflow = StateGraph(GraphState)
    workflow.add_node("search_vector_store", search_vector_store)
    workflow.add_node("evaluate_sufficiency", evaluate_sufficiency)
    workflow.add_node("finalize_answer", finalize_answer)
    workflow.set_entry_point("search_vector_store")
    workflow.add_edge("search_vector_store", "evaluate_sufficiency")
    workflow.add_conditional_edges(
        "evaluate_sufficiency",
        route_next_step,
        {
            "search_vector_store": "search_vector_store",
            "finalize_answer": "finalize_answer"
        }
    )
    workflow.add_edge("finalize_answer", END)
    return workflow.compile()
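# Resulting topology:
#   search_vector_store -> evaluate_sufficiency --(insufficient)--> search_vector_store
#                                               --(sufficient, or 5 passes)--> finalize_answer -> END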
# Streamlit app
def main():
    st.title("Project Resilience Q&A Assistant")
    st.write("Ask anything about Project Resilience - answers always come from repo data!")

    # Hardcoded GitHub URL
    github_url = 'https://github.com/Project-Resilience/platform'

    # Initialize session state once; fetching and re-indexing the repo on
    # every Streamlit rerun would be wasteful
    if 'vector_store' not in st.session_state:
        repo_data = fetch_github_data(github_url)
        st.session_state.vector_store = create_vector_store(repo_data)
        st.session_state.llm = Ollama(model="llama3.2:1b", temperature=0.7)
        st.session_state.graph = build_graph()

    # Question input
    question = st.text_input("Ask a question about Project Resilience")

    # Get and display answer
    if question and st.session_state.graph:
        with st.spinner("Generating answer..."):
            initial_state = {
                "messages": [HumanMessage(content=question)],
                "generation": "",
                "search_count": 0
            }
            result = st.session_state.graph.invoke(initial_state)
            final_answer = result["generation"]
            st.write("**Answer:**")
            st.write(final_answer)

    # Sidebar with additional info
    st.sidebar.header("Project Resilience Assistant")
    st.sidebar.write("""
    Project Resilience is a platform for decision makers, data scientists, and the public.
    Project Resilience, initiated under the Global Initiative on AI and Data Commons, is a collaborative effort to build a public AI utility that could inform and help address global decision-augmentation challenges.
    The project empowers a global community of innovators, thought leaders, and the public to enhance and use a shared collection of data and AI tools, improving preparedness, intervention, and response to environmental, health, information, or economic threats in our communities. It also supports broader efforts toward achieving the Sustainable Development Goals (SDGs).
    """)
if __name__ == "__main__":
    main()