from time import time
from pprint import pprint
from typing import Dict, Literal

import huggingface_hub
import streamlit as st
from typing_extensions import TypedDict

from langchain.agents import AgentExecutor
from langchain_community.chat_models import ChatOllama
from langgraph.graph import StateGraph

from agents import get_agents, tools_dict
from config import config
from logger import logger
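
# A three-node LangGraph pipeline behind a Streamlit UI: rephrase the user's
# question, answer it with a tool-calling agent, then draft the final reply.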
|
|
class GraphState(TypedDict):
    """Represents the state of the graph."""

    question: str  # raw user input
    rephrased_question: str  # question reworded for the function agent
    function_agent_output: str  # intermediate answer from the tool-calling agent
    generation: str  # final answer presented to the user
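
# Only `question` is supplied when the graph is invoked (see `flow.stream`
# below); the other keys are filled in by the nodes as the graph runs.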
|
|
@st.cache_resource(show_spinner="Loading model..")
def init_agents() -> dict[str, AgentExecutor]:
    """Log in to Hugging Face and build the agent executors."""
    huggingface_hub.login(token=config.hf_token, new_session=False)
    llm = ChatOllama(model=config.ollama_model, temperature=0.8)
    return get_agents(llm)
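
# `st.cache_resource` keeps one shared copy of the agents across Streamlit
# reruns and sessions, so the login and model setup happen once per process.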
|
|
def question_node(state: GraphState) -> Dict[str, str]:
    """Rephrase the user's question for the function agent."""
    logger.info("Generating question for function agent")

    question = state["question"]
    logger.info(f"Original question: {question}")
    rephrased_question = agents["rephrase_agent"].invoke({"question": question})
    logger.info(f"Rephrased question: {rephrased_question}")
    return {"rephrased_question": rephrased_question}
|
|
def function_agent_node(state: GraphState) -> Dict[str, str]:
    """Call the function (tool-calling) agent on the rephrased question."""
    logger.info("Calling function agent")
    question = state["rephrased_question"]
    response = agents["function_agent"].invoke(
        {"input": question, "tools": tools_dict}
    ).get("output")

    logger.info(f"Function agent output: {response}")
    return {"function_agent_output": response}
|
|
def output_node(state: GraphState) -> Dict[str, str]:
    """Generate the final answer from the function agent's output."""
    logger.info("Generating output")

    generation = agents["output_agent"].invoke(
        {
            "context": state["function_agent_output"],
            "question": state["rephrased_question"],
        }
    )
    return {"generation": generation}
|
|
def route_question(state: GraphState) -> Literal["vectorstore", "websearch"]:
    """Route the question to web search or the vector store (RAG)."""
    logger.info("Routing question")

    question = state["question"]
    logger.info(f"Question: {question}")
    source = agents["router_agent"].invoke({"question": question})
    logger.info(source)
    logger.info(source["datasource"])
    if source["datasource"] == "vectorstore":
        return "vectorstore"
    # Fall back to web search for any other datasource value.
    return "websearch"
|
|
# Wire the graph: rephrase -> function agent -> final output.
workflow = StateGraph(GraphState)
workflow.add_node("question_rephrase", question_node)
workflow.add_node("function_agent", function_agent_node)
workflow.add_node("output_node", output_node)

workflow.set_entry_point("question_rephrase")
workflow.add_edge("question_rephrase", "function_agent")
workflow.add_edge("function_agent", "output_node")
workflow.set_finish_point("output_node")
|
|
flow = workflow.compile() |
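
# Quick sanity check for the compiled graph (assumes an Ollama server is
# running and the agents are initialised); each streamed chunk maps a node
# name to the state update that node returned:
#
#     for chunk in flow.stream({"question": "hello"}):
#         print(chunk)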
|
|
progress_map = {
    "question_rephrase": ":mag: Collecting data",
    "function_agent": ":bulb: Preparing response",
    "output_node": ":bulb: Done!",
}
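
# The keys above must match the node names registered on the workflow;
# `get_response` looks labels up by the node name streamed back from the graph.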
|
|
def main():
    st.title("LLM-ADE 9B Demo")

    input_text = st.text_area("Enter your text here:", value="", height=200)

    def get_response(input_text: str, depth: int = 1) -> str:
        try:
            value = {}
            for output in flow.stream({"question": input_text}):
                for key, value in output.items():
                    config.status.update(label=progress_map[key])
                    pprint(f"Finished running: {key}")
            # Only the final node ("output_node") returns a "generation" key,
            # so read it after the stream is exhausted rather than mid-loop.
            return value["generation"]
        except Exception as e:
            logger.error(e)
            logger.info("Retrying..")
            if depth < 5:
                return get_response(input_text, depth + 1)
            raise
|
    if st.button("Generate") or input_text:
        start = time()
        if input_text:
            with st.status("Generating response...") as status:
                config.status = status
                config.status.update(label=":question: Breaking down question")
                response = get_response(input_text)
                # Escape "$" so Streamlit's markdown renderer doesn't treat it
                # as a LaTeX delimiter.
                response = response.replace("$", "\\$")
                st.info(response)
                config.status.update(
                    label=f"Finished! ({time() - start:.2f}s)",
                    state="complete",
                    expanded=True,
                )
        else:
            st.warning("Please enter some text to generate a response.")
|
|
def main_headless(prompt: str):
    """Run the graph once from the command line and print the final answer."""
    start = time()
    value = {}
    for output in flow.stream({"question": prompt}):
        for key, value in output.items():
            pprint(f"Finished running: {key}")
    print("\033[94m" + value["generation"] + "\033[0m")  # answer in blue
    print(f"Time taken: {time() - start:.2f}s\n" + "-" * 20)
|
|
agents = init_agents()  # shared by the node functions above
|
|
if __name__ == "__main__":
    if config.headless:
        import fire

        fire.Fire(main_headless)
    else:
        main()