import base64
import operator
from typing import TypedDict, Annotated, List

import requests
import streamlit as st
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.chains import RetrievalQA
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_community.llms import Ollama
from langchain_core.messages import HumanMessage, AIMessage
from langgraph.graph import StateGraph, END
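# Approximate runtime dependencies (assumption, not pinned): streamlit, langchain,
# langchain-community, langgraph, faiss-cpu, sentence-transformers, requests,
# plus a local Ollama server with the "llama3.2:1b" model pulled.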
# Function to fetch GitHub repo data
def fetch_github_data(repo_url):
    parts = repo_url.split('/')
    owner, repo = parts[-2], parts[-1]
    headers = {'Accept': 'application/vnd.github.v3+json'}
    base_url = 'https://api.github.com'
    content = ""
    repo_response = requests.get(f"{base_url}/repos/{owner}/{repo}", headers=headers)
    if repo_response.status_code == 200:
        repo_data = repo_response.json()
        content += f"Description: {repo_data.get('description', '')}\n"
    readme_response = requests.get(f"{base_url}/repos/{owner}/{repo}/readme", headers=headers)
    if readme_response.status_code == 200:
        readme_data = readme_response.json()
        content += base64.b64decode(readme_data['content']).decode('utf-8') + "\n"
    return content
# Function to create vector store
def create_vector_store(text_data):
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=1000,
        chunk_overlap=200
    )
    chunks = text_splitter.split_text(text_data)
    embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
    vector_store = FAISS.from_texts(chunks, embeddings)
    return vector_store
# Define the state for LangGraph
class GraphState(TypedDict):
    # operator.add tells LangGraph to append messages returned by each node
    messages: Annotated[List[HumanMessage | AIMessage], operator.add]
    generation: str    # accumulated answer text
    search_count: int  # number of retrieval passes performed so far
# Node to perform initial/additional vector store search
def search_vector_store(state: GraphState):
    llm = st.session_state.llm
    vector_store = st.session_state.vector_store
    question = state["messages"][0].content  # Original question
    current_generation = state["generation"]
    search_count = state["search_count"] + 1
    # Modify query slightly for additional searches
    if search_count > 1:
        query = f"{question} (additional context for: {current_generation})"
    else:
        query = question
    retriever = vector_store.as_retriever()
    qa_chain = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=retriever
    )
    response = qa_chain.run(query)
    # Append new info to existing generation
    new_generation = f"{current_generation}\nAdditional Info: {response}" if current_generation else response
    return {
        "messages": [AIMessage(content=new_generation)],
        "generation": new_generation,
        "search_count": search_count
    }
# Node to evaluate sufficiency of the answer
def evaluate_sufficiency(state: GraphState):
    llm = st.session_state.llm
    question = state["messages"][0].content  # Original question
    current_generation = state["generation"]
    prompt = (
        f"Given the question '{question}' and the current information:\n'{current_generation}'\n"
        f"Is this sufficient to fully answer the question? Respond with 'Yes' or 'No'."
    )
    decision = llm.invoke(prompt).strip()
    return {
        "messages": [AIMessage(content=f"Sufficiency check: {decision}")],
        "generation": current_generation,
        "search_count": state["search_count"]
    }
# Node to finalize the answer
def finalize_answer(state: GraphState):
    llm = st.session_state.llm
    current_info = state["generation"]
    question = state["messages"][0].content  # Original question
    prompt = (
        f"Given the question '{question}' and the current information:\n'{current_info}'\n"
        f"Answer the question directly, as if you were answering it for the first time."
    )
    final_answer = llm.invoke(prompt).strip()
    return {
        "messages": [AIMessage(content=f"Final Answer: {final_answer}")],
        "generation": final_answer,
        "search_count": state["search_count"]
    }
# Function to decide next step
def route_next_step(state: GraphState):
    last_message = state["messages"][-1].content
    search_count = state["search_count"]
    if "Sufficiency check: Yes" in last_message:
        return "finalize_answer"
    elif search_count >= 5:
        return "finalize_answer"  # Max 5 iterations
    else:
        return "search_vector_store"
# Build the LangGraph workflow
def build_graph():
    workflow = StateGraph(GraphState)
    workflow.add_node("search_vector_store", search_vector_store)
    workflow.add_node("evaluate_sufficiency", evaluate_sufficiency)
    workflow.add_node("finalize_answer", finalize_answer)
    workflow.set_entry_point("search_vector_store")
    workflow.add_edge("search_vector_store", "evaluate_sufficiency")
    workflow.add_conditional_edges(
        "evaluate_sufficiency",
        route_next_step,
        {
            "search_vector_store": "search_vector_store",
            "finalize_answer": "finalize_answer"
        }
    )
    workflow.add_edge("finalize_answer", END)
    return workflow.compile()
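# Flow of the compiled graph: search_vector_store -> evaluate_sufficiency, then either
# loop back for another search or move on to finalize_answer (route_next_step caps the
# loop at five searches), and finalize_answer -> END.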
# Streamlit app
def main():
    st.title("Project Resilience Q&A Assistant")
    st.write("Ask anything about Project Resilience - answers always come from repo data!")
    # Hardcoded GitHub URL
    github_url = 'https://github.com/Project-Resilience/platform'
    # Initialize session state (fetch repo data only once per session)
    if 'vector_store' not in st.session_state:
        repo_data = fetch_github_data(github_url)
        st.session_state.vector_store = create_vector_store(repo_data)
        st.session_state.llm = Ollama(model="llama3.2:1b", temperature=0.7)
        st.session_state.graph = build_graph()
    # Question input
    question = st.text_input("Ask a question about the project")
    # Get and display answer
    if question and st.session_state.graph:
        with st.spinner("Generating answer..."):
            initial_state = {
                "messages": [HumanMessage(content=question)],
                "generation": "",
                "search_count": 0
            }
            result = st.session_state.graph.invoke(initial_state)
            final_answer = result["generation"]
        st.write("**Answer:**")
        st.write(final_answer)
    # Sidebar with additional info
    st.sidebar.header("Project Resilience Assistant")
    st.sidebar.write("""
    Project Resilience is a platform for decision makers, data scientists, and the public.
    Project Resilience, initiated under the Global Initiative on AI and Data Commons, is a collaborative effort to build a public AI utility that could inform and help address global decision-augmentation challenges.
    The project empowers a global community of innovators, thought leaders, and the public to enhance and use a shared collection of data and AI tools, improving preparedness, intervention, and response to environmental, health, information, or economic threats in our communities. It also supports broader efforts toward achieving the Sustainable Development Goals (SDGs).
    """)
if __name__ == "__main__":
    main()