# Hugging Face Spaces status: Sleeping
from pprint import pprint
from time import time
from typing import Dict, Literal, Optional

import huggingface_hub
import langchain
import streamlit as st
from langchain_community.chat_models import ChatOllama
from langgraph.graph import END, StateGraph
from typing_extensions import TypedDict

from agents import get_agents, tools_dict
from config import config
from logger import logger
class GraphState(TypedDict):
    """State dictionary shared by the nodes of the agent graph.

    Each node returns a partial dict containing only the keys it fills in.
    """

    # The user's question exactly as submitted (the graph's input).
    question: str
    # The question after rewriting by the rephrase agent.
    rephrased_question: str
    # The "output" field of the function agent's result.
    function_agent_output: str
    # The final answer produced by the output agent.
    generation: str
def init_agents() -> dict[str, langchain.agents.AgentExecutor]:
    """Log in to Hugging Face and build the project's agents on a local Ollama model.

    The login uses ``config.hf_token`` with ``new_session=False``, so an
    already-saved session may be reused. The chat model is ``config.ollama_model``
    at temperature 0.8.

    Returns:
        Whatever ``get_agents`` builds for that model: a mapping from agent
        name to agent, per the annotation.
    """
    huggingface_hub.login(token=config.hf_token, new_session=False)
    chat_model = ChatOllama(model=config.ollama_model, temperature=0.8)
    agent_map = get_agents(chat_model)
    return agent_map
# Nodes ----------------------------------------------------------------------- | |
def question_node(state: GraphState) -> Dict[str, str]:
    """
    Rewrite the user's question into a form suited to the function agent.

    Args:
        state: Current graph state; only ``question`` is read.

    Returns:
        Partial state update ``{"rephrased_question": ...}`` holding whatever
        the rephrase agent's ``invoke`` returned.
    """
    logger.info("Generating question for function agent")
    original = state["question"]
    logger.info(f"Original question: {original}")
    rephrased = agents["rephrase_agent"].invoke({"question": original})
    logger.info(f"Rephrased question: {rephrased}")
    return {"rephrased_question": rephrased}
def function_agent_node(state: GraphState) -> Dict[str, str]:
    """
    Call the function agent on the rephrased question.

    Args:
        state: Current graph state; ``rephrased_question`` must already be set
            (it is filled in by ``question_node``).

    Returns:
        Partial state update ``{"function_agent_output": ...}`` holding the
        ``"output"`` field of the agent's result. The value is ``None`` if the
        result has no ``"output"`` key.
    """
    logger.info("Calling function agent")
    question = state["rephrased_question"]
    # NOTE(review): the tool registry is passed as an input variable, presumably
    # so the agent prompt can reference it; verify against the prompt template.
    result = agents["function_agent"].invoke({"input": question, "tools": tools_dict})
    response = result.get("output")
    logger.info(f"Function agent output: {response}")
    return {"function_agent_output": response}
def output_node(state: GraphState) -> Dict[str, str]:
    """
    Produce the final answer from the function agent's output.

    Args:
        state: Current graph state; reads ``function_agent_output`` (passed as
            ``context``) and ``rephrased_question`` (passed as ``question``).

    Returns:
        Partial state update ``{"generation": ...}`` holding whatever the
        output agent's ``invoke`` returned.
    """
    logger.info("Generating output")
    prompt_inputs = {
        "context": state["function_agent_output"],
        "question": state["rephrased_question"],
    }
    answer = agents["output_agent"].invoke(prompt_inputs)
    return {"generation": answer}
# Conditional Edge ------------------------------------------------------------ | |
def route_question(state: GraphState) -> Literal["vectorstore", "websearch"]:
    """
    Route the question to web search or to RAG (the vector store).

    NOTE(review): the graph assembled in this file does not reference this
    function; it is presumably intended for a conditional edge.

    Args:
        state: Current graph state; only ``question`` is read.

    Returns:
        ``"vectorstore"`` or ``"websearch"``, as chosen by the router agent.

    Raises:
        ValueError: If the router agent returns any other datasource, because
            ``None`` or an arbitrary string is not a valid route.
    """
    logger.info("Routing question")
    question = state["question"]
    logger.info(f"Question: {question}")
    source = agents["router_agent"].invoke({"question": question})
    logger.info(source)
    datasource = source["datasource"]
    logger.info(datasource)
    if datasource in ("vectorstore", "websearch"):
        return datasource
    raise ValueError(f"Router agent returned unknown datasource: {datasource!r}")
# Graph ----------------------------------------------------------------------- | |
def _build_workflow() -> StateGraph:
    """Assemble the linear graph: rephrase -> function agent -> output."""
    graph = StateGraph(GraphState)
    steps = (
        ("question_rephrase", question_node),
        ("function_agent", function_agent_node),
        ("output_node", output_node),
    )
    for step_name, step_fn in steps:
        graph.add_node(step_name, step_fn)
    graph.set_entry_point("question_rephrase")
    # Chain each step to the next one in declaration order.
    for (src, _), (dst, _) in zip(steps, steps[1:]):
        graph.add_edge(src, dst)
    graph.set_finish_point("output_node")
    return graph


workflow = _build_workflow()
flow = workflow.compile()

# Status label shown once the keyed node has finished, i.e. it describes the
# step that comes next.
progress_map = dict(
    question_rephrase=":mag: Collecting data",
    function_agent=":bulb: Preparing response",
    output_node=":bulb: Done!",
)
def main():
    """Render the Streamlit demo UI and answer the submitted question.

    The question is streamed through ``flow``. After each graph node finishes,
    the ``st.status`` widget stored on ``config.status`` is relabelled using
    ``progress_map``. Failed runs are retried up to ``max_attempts`` times in
    total. If all attempts fail, an error is shown instead of the answer.
    """
    st.title("LLM-ADE 9B Demo")
    input_text = st.text_area("Enter your text here:", value="", height=200)

    # Same attempt budget as the former recursive retry (depth 1..5).
    max_attempts = 5

    def get_response(input_text: str) -> Optional[str]:
        """Run the graph on ``input_text`` and return the final ``generation``.

        Returns ``None`` if every attempt raised.
        """
        for attempt in range(1, max_attempts + 1):
            try:
                value = None
                for output in flow.stream({"question": input_text}):
                    for key, value in output.items():
                        # A key missing from progress_map must not abort the
                        # run (some langgraph versions may emit a final
                        # "__end__" event).
                        label = progress_map.get(key)
                        if label is not None:
                            config.status.update(label=label)
                        pprint(f"Finished running: {key}")
                if value is None:
                    raise RuntimeError("Graph stream produced no output")
                return value["generation"]
            except Exception:
                # Log the full traceback, not just the message.
                logger.exception(f"Attempt {attempt}/{max_attempts} failed")
                if attempt < max_attempts:
                    logger.info("Retrying..")
        return None

    if st.button("Generate"):
        if not input_text.strip():
            st.warning("Please enter some text to generate a response.")
            return
        with st.status("Generating response...") as status:
            # The graph loop relabels this widget via config.status.
            config.status = status
            config.status.update(label=":question: Breaking down question")
            response = get_response(input_text)
            if response is None:
                st.error("Failed to generate a response. Please try again.")
                config.status.update(label="Failed", state="error", expanded=True)
            else:
                st.write(response)
                config.status.update(label="Finished!", state="complete", expanded=True)
def main_headless(prompt: str):
    """Run the graph on ``prompt`` from the command line and print the answer.

    Prints the name of each finished node, then the final ``generation``
    wrapped in ANSI colour code 94 (bright blue), then the elapsed wall-clock
    time.

    Raises:
        RuntimeError: If the graph stream yields no output at all. Without this
            check, ``value`` would be unbound.
    """
    start = time()
    value = None
    for output in flow.stream({"question": prompt}):
        for key, value in output.items():
            pprint(f"Finished running: {key}")
    if value is None:
        raise RuntimeError("Graph stream produced no output")
    # Format via f-string so a non-str generation does not break concatenation.
    print(f"\033[94m{value['generation']}\033[0m")
    print(f"Time taken: {time() - start:.2f}s\n" + "-" * 20)
# Created at import time: the node functions look up ``agents`` as a module
# global when the graph runs.
# NOTE(review): Streamlit re-executes the script on each interaction, so this
# presumably rebuilds the agents every rerun; consider st.cache_resource.
agents = init_agents()

if __name__ == "__main__":
    if not config.headless:
        # Interactive Streamlit UI.
        main()
    else:
        # Command-line mode: python-fire maps CLI arguments onto main_headless.
        import fire

        fire.Fire(main_headless)