# Order-Bot / llm_agent.py
# Initial commit 3618a4d (verified) by Viraj2307.
import logging

import pandas as pd

from tools import MenuTool, CartTool, OrderTool, greetings_function
from data_section import data_bifercation
from tools.prompts import tool_prompt_function
from config import settings
from utils import client
from context import ollama_context_query, summarised_output

logger = logging.getLogger(__name__)


class ReactAgent:
    """LLM tool-calling agent for the Order-Bot menu, cart and order tools.

    On construction the store menu is fetched with ``data_bifercation`` into a
    ``category``/``item``/``price`` DataFrame (``category`` and ``item``
    lower-cased) that is shared by ``MenuTool``, ``CartTool`` and ``OrderTool``.
    ``handle_query`` lets the LLM choose which of those tools to call, runs
    them, and passes their outputs to ``summarised_output`` for the reply.
    """

    # Tool schemas advertised to the LLM on every request. Built once at class
    # level instead of on every call; treat as read-only. The ``name`` values
    # must match the dispatch in ``handle_query``.
    _TOOLS = [
        {
            "type": "function",
            "function": {
                "name": "menu_tool",
                "description": "Fetch the restaurant menu based on user input",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "query": {
                            "type": "string",
                            "description": "User's natural language query for the menu",
                        }
                    },
                    "required": ["query"],
                },
            },
        },
        {
            "type": "function",
            "function": {
                "name": "cart_tool",
                "description": "Manage the cart based on user input (add/remove/view)",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "query": {
                            "type": "string",
                            "description": "User's cart query to add/remove/view items",
                        },
                        "session_id": {
                            "type": "string",
                            "description": "current session id",
                        },
                    },
                    "required": ["query", "session_id"],
                },
            },
        },
        {
            "type": "function",
            "function": {
                "name": "order_tool",
                "description": "Handle order and checkout functionality",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "query": {
                            "type": "string",
                            "description": "User's request to place an order",
                        },
                    },
                    "required": ["query"],
                },
            },
        },
    ]

    def __init__(
        self,
        store_id="66dff7a04b17303d454d4bbc",
        brand_id="66cec85093c5b0896c9125c5",
    ):
        """Load the store menu and build the tools that operate on it.

        Args:
            store_id: Store whose menu is loaded via ``data_bifercation``; also
                forwarded to ``OrderTool.run``. Defaults to the id that used to
                be hard-coded here, so ``ReactAgent()`` behaves as before.
            brand_id: Brand id passed alongside ``store_id`` to
                ``data_bifercation`` and ``OrderTool.run``. Defaults to the
                previously hard-coded id.
        """
        self.store_id = store_id
        self.brand_id = brand_id
        main_data, category, items = data_bifercation(self.store_id, self.brand_id)
        self.items = items
        self.category = category

        df = pd.DataFrame(main_data)
        # Assumes every row of main_data has exactly these three fields in this
        # order (pandas raises on a column-count mismatch) — TODO confirm
        # against data_bifercation.
        df.columns = ["category", "item", "price"]
        # Lower-cased, presumably so the tools can match user text
        # case-insensitively.
        df["item"] = df["item"].str.lower()
        df["category"] = df["category"].str.lower()
        self.df = df

        self.menu_tool = MenuTool(df)
        self.cart_tool = CartTool(df)
        self.order_tool = OrderTool(df)
        self.llm_client = client

    @staticmethod
    def _tool_query(function, fallback):
        """Return the ``query`` argument from a tool call's ``function`` entry.

        Falls back to *fallback* (the user's own message) when the model sent
        no arguments or omitted ``query``, instead of raising ``KeyError``.
        """
        arguments = function.get("arguments") or {}
        return arguments.get("query", fallback)

    def handle_query(self, session_id, query, chat_history):
        """Answer one user turn, running the menu/cart/order tools the LLM picks.

        Args:
            session_id: Chat session identifier; forwarded to the tool prompt
                and to every tool.
            query: The user's raw message.
            chat_history: Prior conversation, passed through to
                ``ollama_context_query`` and ``summarised_output``.

        Returns:
            ``greetings_function(query)`` when the context step's flag is
            falsy; otherwise whatever ``summarised_output`` returns for the
            collected tool outputs (an empty list of outputs if the LLM called
            no known tool).
        """
        context_query, greet_bool = ollama_context_query(
            chat_history=chat_history, user_query=query
        )
        # NOTE(review): a *falsy* flag routes to the greeting handler, which is
        # the opposite of what the name ``greet_bool`` suggests — confirm
        # against ollama_context_query.
        if not greet_bool:
            return greetings_function(query)

        # Built only after the greeting short-circuit, which never uses it.
        prompt = tool_prompt_function(current_query=query, session_id=session_id)

        # If the context step answered with a bare tool class name, send the
        # user's original text; otherwise send the context step's query
        # (presumably a history-aware rewrite of the user's message).
        if context_query in ("MenuTool", "CartTool", "OrderTool"):
            user_message_content = query
        else:
            user_message_content = context_query

        llm_response = self.llm_client.chat(
            model=settings.MODEL_NAME,
            messages=[
                {"role": "system", "content": prompt},
                {"role": "user", "content": user_message_content},
            ],
            tools=self._TOOLS,
        )
        # ``tool_calls`` may be missing or None when the model calls no tool.
        tool_calls = llm_response["message"].get("tool_calls") or []
        logger.debug("query: %s", query)
        logger.debug("llm response: %s", llm_response)
        logger.debug("tool calls: %s", tool_calls)

        tool_responses = []
        for call in tool_calls:
            # Dispatch each call with its own arguments, so a tool requested
            # twice gets each call's query rather than the first one's twice.
            function = call["function"]
            name = function["name"]
            if name == "menu_tool":
                result = self.menu_tool.run(
                    self._tool_query(function, query), session_id
                )
            elif name == "cart_tool":
                result = self.cart_tool.run(
                    self._tool_query(function, query), session_id=session_id
                )
            elif name == "order_tool":
                # OrderTool.run takes the menu data and ids, not a query; the
                # model's ``query`` argument is not forwarded.
                result = self.order_tool.run(
                    df=self.df,
                    session_id=session_id,
                    category=self.category,
                    items=self.items,
                    store_id=self.store_id,
                    brand_id=self.brand_id,
                )
            else:
                logger.warning("ignoring unknown tool call: %s", name)
                continue
            # Logged after the tool has run, so the value shown is its result.
            logger.debug("%s response :: %s", name, result)
            tool_responses.append(result)

        return summarised_output(
            messages=tool_responses,
            chat_history=chat_history,
            context_query=context_query,
            user_query=query,
        )