ric9176 committed
Commit 704be27 · 1 Parent(s): 0e4dc68

Add basic agent with Tavily search

Files changed (1): app.py +78 -118
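
In short: this commit removes the aimakerspace retrieval-augmented QA pipeline (file upload, vector store, prompt templates) from app.py and replaces it with a LangGraph agent. A call_model node invokes gpt-4o with TavilySearchResults bound as a tool, a ToolNode executes any tool calls, and a conditional edge (should_continue) loops between the two until the model returns a response with no tool calls. A minimal usage sketch follows the diff below.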
app.py CHANGED
@@ -1,139 +1,99 @@
-import os
-from typing import List
-from chainlit.types import AskFileResponse
-from aimakerspace.text_utils import CharacterTextSplitter, TextFileLoader, PDFLoader
-from aimakerspace.openai_utils.prompts import (
-    UserRolePrompt,
-    SystemRolePrompt,
-    AssistantRolePrompt,
-)
-from aimakerspace.openai_utils.embedding import EmbeddingModel
-from aimakerspace.vectordatabase import VectorDatabase
-from aimakerspace.openai_utils.chatmodel import ChatOpenAI
 import chainlit as cl

-system_template = """\
-Use the following context to answer a users question. If you cannot find the answer in the context, say you don't know the answer."""
-system_role_prompt = SystemRolePrompt(system_template)

-user_prompt_template = """\
-Context:
-{context}

-Question:
-{question}
-"""
-user_role_prompt = UserRolePrompt(user_prompt_template)

-class RetrievalAugmentedQAPipeline:
-    def __init__(self, llm: ChatOpenAI(), vector_db_retriever: VectorDatabase) -> None:
-        self.llm = llm
-        self.vector_db_retriever = vector_db_retriever

-    async def arun_pipeline(self, user_query: str):
-        context_list = self.vector_db_retriever.search_by_text(user_query, k=4)

-        context_prompt = ""
-        for context in context_list:
-            context_prompt += context[0] + "\n"

-        formatted_system_prompt = system_role_prompt.create_message()

-        formatted_user_prompt = user_role_prompt.create_message(question=user_query, context=context_prompt)

-        async def generate_response():
-            async for chunk in self.llm.astream([formatted_system_prompt, formatted_user_prompt]):
-                yield chunk

-        return {"response": generate_response(), "context": context_list}

-text_splitter = CharacterTextSplitter()


-def process_file(file: AskFileResponse):
-    import tempfile
-    import shutil
-
-    print(f"Processing file: {file.name}")
-
-    # Create a temporary file with the correct extension
-    suffix = f".{file.name.split('.')[-1]}"
-    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
-        # Copy the uploaded file content to the temporary file
-        shutil.copyfile(file.path, temp_file.name)
-        print(f"Created temporary file at: {temp_file.name}")
-
-    # Create appropriate loader
-    if file.name.lower().endswith('.pdf'):
-        loader = PDFLoader(temp_file.name)
-    else:
-        loader = TextFileLoader(temp_file.name)
-
-    try:
-        # Load and process the documents
-        documents = loader.load_documents()
-        texts = text_splitter.split_texts(documents)
-        return texts
-    finally:
-        # Clean up the temporary file
-        try:
-            os.unlink(temp_file.name)
-        except Exception as e:
-            print(f"Error cleaning up temporary file: {e}")


 @cl.on_chat_start
 async def on_chat_start():
-    files = None
-
-    # Wait for the user to upload a file
-    while files == None:
-        files = await cl.AskFileMessage(
-            content="Please upload a Text or PDF file to begin!",
-            accept=["text/plain", "application/pdf"],
-            max_size_mb=2,
-            timeout=180,
-        ).send()
-
-    file = files[0]
-
-    msg = cl.Message(
-        content=f"Processing `{file.name}`..."
-    )
-    await msg.send()
-
-    # load the file
-    texts = process_file(file)

-    print(f"Processing {len(texts)} text chunks")
-
-    # Create a dict vector store
-    vector_db = VectorDatabase()
-    vector_db = await vector_db.abuild_from_list(texts)

-    chat_openai = ChatOpenAI()
-
-    # Create a chain
-    retrieval_augmented_qa_pipeline = RetrievalAugmentedQAPipeline(
-        vector_db_retriever=vector_db,
-        llm=chat_openai
-    )

-    # Let the user know that the system is ready
-    msg.content = f"Processing `{file.name}` done. You can now ask questions!"
-    await msg.update()
-
-    cl.user_session.set("chain", retrieval_augmented_qa_pipeline)
-
-
-@cl.on_message
-async def main(message):
-    chain = cl.user_session.get("chain")
-
-    msg = cl.Message(content="")
-    result = await chain.arun_pipeline(message.content)
-
-    async for stream_resp in result["response"]:
-        await msg.stream_token(stream_resp)
-
-    await msg.send()

+from typing import Annotated, TypedDict, Literal
+from langchain_openai import ChatOpenAI
+from langgraph.graph import StateGraph, START, END
+from langgraph.graph.message import MessagesState
+from langgraph.prebuilt import ToolNode
+from langgraph.graph.message import add_messages
+
+from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
+from langchain.schema.runnable.config import RunnableConfig
+from langchain_community.tools.tavily_search import TavilySearchResults
+
 import chainlit as cl

+class AgentState(TypedDict):
+    messages: Annotated[list, add_messages]

+tavily_tool = TavilySearchResults(max_results=5)
+tool_belt = [tavily_tool]
+# Initialize the language models
+# llm = ChatOpenAI(model_name="gpt-4o-mini", temperature=0)
+# final_llm = ChatOpenAI(model_name="gpt-4o-mini", temperature=0).with_config(tags=["final_node"])
+model = ChatOpenAI(model="gpt-4o", temperature=0)
+model = model.bind_tools(tool_belt)

+# Define system prompt
+SYSTEM_PROMPT = SystemMessage(content="""
+You are a helpful AI assistant that answers questions clearly and concisely.
+If you don't know something, simply say you don't know.
+Be engaging and professional in your responses.
+""")


+def call_model(state: AgentState):
+    messages = state["messages"]
+    response = model.invoke(messages)
+    return {"messages": [response]}

+tool_node = ToolNode(tool_belt)


+# Route to the tool node when the model requested a tool call; otherwise end
+def should_continue(state):
+    last_message = state["messages"][-1]

+    if last_message.tool_calls:
+        return "action"

+    return END

+# Create the graph
+builder = StateGraph(AgentState)

+builder.set_entry_point("agent")
+builder.add_node("agent", call_model)
+builder.add_node("action", tool_node)
+# Add edges
+builder.add_conditional_edges(
+    "agent",
+    should_continue,
+)

+builder.add_edge("action", "agent")

+# Compile the graph
+graph = builder.compile()

 @cl.on_chat_start
 async def on_chat_start():
+    await cl.Message("Hello! I'm your AI assistant. How can I help you today?").send()

+@cl.on_message
+async def on_message(message: cl.Message):
+    # Create configuration with thread ID
+    config = {
+        "configurable": {
+            "thread_id": cl.context.session.id,
+            "checkpoint_ns": "default_namespace"
+        }
+    }

+    # Setup callback handler and final answer message
+    cb = cl.LangchainCallbackHandler()
+    final_answer = cl.Message(content="")
+    await final_answer.send()

+    # Stream the response
+    async for chunk in graph.astream(
+        {"messages": [HumanMessage(content=message.content)]},
+        config=RunnableConfig(callbacks=[cb], **config)
+    ):
+        for node, values in chunk.items():
+            if values.get("messages"):
+                last_message = values["messages"][-1]
+
+                # Only stream AI messages, skip tool outputs
+                if isinstance(last_message, AIMessage):
+                    await final_answer.stream_token(last_message.content)
+
+    await final_answer.send()
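
For reference, a minimal smoke test of the compiled graph outside the Chainlit UI might look like the sketch below (the UI itself is served with chainlit run app.py). This is a sketch under assumptions, not part of the commit: it assumes app.py is importable from the working directory and that OPENAI_API_KEY and TAVILY_API_KEY are set, since ChatOpenAI and TavilySearchResults read those variables by convention; the question string is illustrative only.

# Hypothetical smoke test; not part of this commit.
# Assumes OPENAI_API_KEY and TAVILY_API_KEY are exported in the environment.
from langchain_core.messages import HumanMessage
from app import graph  # the compiled StateGraph from this commit

if __name__ == "__main__":
    # One full agent turn: the model may emit tool calls, the "action" node
    # runs the Tavily search, and control loops back to "agent" until the
    # model produces a final AIMessage with no tool calls.
    result = graph.invoke({"messages": [HumanMessage(content="Who won the 2022 World Cup final?")]})
    print(result["messages"][-1].content)

One note on the new code: SYSTEM_PROMPT is defined but never added to the messages passed to model.invoke, so it currently has no effect; a follow-up change would need to prepend it inside call_model.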