import json
import logging

import pandas as pd

from tools import MenuTool, CartTool, OrderTool, greetings_function
from data_section import data_bifercation
from tools.prompts import tool_prompt_function
from config import settings
from utils import client
from context import ollama_context_query, summarised_output

logger = logging.getLogger(__name__)

# Identifiers used when ReactAgent() is constructed without arguments
# (these were previously hard-coded inside __init__).
_DEFAULT_STORE_ID = "66dff7a04b17303d454d4bbc"
_DEFAULT_BRAND_ID = "66cec85093c5b0896c9125c5"

# Column names assigned, in order, to the rows returned by data_bifercation.
_MENU_COLUMNS = ("category", "item", "price")

# Values of ``context_query`` for which the raw user query (rather than the
# context query itself) is sent to the LLM.
_TOOL_CONTEXT_NAMES = ("MenuTool", "CartTool", "OrderTool")

# Function-calling schemas passed to ``llm_client.chat(tools=...)``. Built once
# at import time instead of on every request.
_TOOL_SPECS = [
    {
        "type": "function",
        "function": {
            "name": "menu_tool",
            "description": "Fetch the restaurant menu based on user input",
            "parameters": {
                "type": "object",
                "properties": {
                    "query": {
                        "type": "string",
                        "description": "User's natural language query for the menu",
                    }
                },
                "required": ["query"],
            },
        },
    },
    {
        "type": "function",
        "function": {
            "name": "cart_tool",
            "description": "Manage the cart based on user input (add/remove/view)",
            "parameters": {
                "type": "object",
                "properties": {
                    "query": {
                        "type": "string",
                        "description": "User's cart query to add/remove/view items",
                    },
                    "session_id": {
                        "type": "string",
                        "description": "current session id",
                    },
                },
                "required": ["query", "session_id"],
            },
        },
    },
    {
        "type": "function",
        "function": {
            "name": "order_tool",
            "description": "Handle order and checkout functionality",
            "parameters": {
                "type": "object",
                "properties": {
                    "query": {
                        "type": "string",
                        "description": "User's request to place an order",
                    },
                },
                "required": ["query"],
            },
        },
    },
]


class ReactAgent:
    """Answer a chat message by letting the LLM pick menu/cart/order tools.

    The menu for one store/brand is loaded into a DataFrame once at
    construction. Each ``handle_query`` call sends the query to the LLM with
    the tool schemas, runs every tool call the model returns, and passes the
    collected tool outputs to ``summarised_output`` for the final reply.
    """

    def __init__(self, store_id=_DEFAULT_STORE_ID, brand_id=_DEFAULT_BRAND_ID):
        """Load the menu for ``store_id``/``brand_id`` and build the tools.

        Args:
            store_id: Store identifier passed to ``data_bifercation`` and to
                ``OrderTool.run``. Defaults to the previously hard-coded id.
            brand_id: Brand identifier passed to ``data_bifercation`` and to
                ``OrderTool.run``. Defaults to the previously hard-coded id.
        """
        self.store_id = store_id
        self.brand_id = brand_id
        main_data, category, items = data_bifercation(self.store_id, self.brand_id)
        self.items = items
        self.category = category

        df = pd.DataFrame(main_data)
        # Rename positionally; assumes main_data rows have exactly three
        # fields in category/item/price order — TODO confirm with data_bifercation.
        df.columns = list(_MENU_COLUMNS)
        # Lower-cased, presumably so tools can match user text
        # case-insensitively — confirm in MenuTool/CartTool.
        df["item"] = df["item"].str.lower()
        df["category"] = df["category"].str.lower()
        self.df = df

        self.menu_tool = MenuTool(df)
        self.cart_tool = CartTool(df)
        self.order_tool = OrderTool(df)
        self.llm_client = client

    def handle_query(self, session_id, query, chat_history):
        """Run the tools the LLM selects for ``query`` and summarise the results.

        Args:
            session_id: Current chat session id, forwarded to the prompt
                builder and to every tool.
            query: The user's latest message.
            chat_history: Prior conversation, forwarded to
                ``ollama_context_query`` and ``summarised_output``.

        Returns:
            ``greetings_function(query)`` when the context step flags the
            message for the greeting path. Otherwise, the value of
            ``summarised_output`` over the list of tool results, in the order
            the model issued the calls. The list is empty if the model called
            no known tools.
        """
        prompt = tool_prompt_function(current_query=query, session_id=session_id)
        context_query, greet_bool = ollama_context_query(
            chat_history=chat_history, user_query=query
        )
        # NOTE(review): a *falsy* greet_bool triggers the greeting reply; the
        # flag name reads inverted — confirm against ollama_context_query.
        if not greet_bool:
            return greetings_function(query)

        # When the context step yields a bare tool name, send the raw query
        # instead (presumably the rewrite added nothing — confirm).
        if context_query in _TOOL_CONTEXT_NAMES:
            user_message_content = query
        else:
            user_message_content = context_query

        messages = [
            {"role": "system", "content": prompt},
            {"role": "user", "content": user_message_content},
        ]
        llm_response = self.llm_client.chat(
            model=settings.MODEL_NAME,
            messages=messages,
            tools=_TOOL_SPECS,
        )
        logger.debug("query=%r llm_response=%s", query, llm_response)

        # ``or []`` also covers a present-but-null "tool_calls" entry.
        tool_calls = llm_response["message"].get("tool_calls") or []
        logger.debug("tool_calls=%s", tool_calls)

        handlers = {
            "menu_tool": self._run_menu_tool,
            "cart_tool": self._run_cart_tool,
            "order_tool": self._run_order_tool,
        }
        tool_responses = []
        # Iterate the calls themselves, so repeated calls to the same tool
        # each use their own arguments.
        for call in tool_calls:
            function = call.get("function") or {}
            tool_name = function.get("name")
            handler = handlers.get(tool_name)
            if handler is None:
                logger.warning("Ignoring unknown tool call: %r", tool_name)
                continue
            tool_args = self._parse_tool_arguments(function.get("arguments"))
            result = handler(tool_args, session_id, query)
            # Logged after the tool runs, so this is the tool's own result.
            logger.debug("%s response :: %s", tool_name, result)
            tool_responses.append(result)

        return summarised_output(
            messages=tool_responses,
            chat_history=chat_history,
            context_query=context_query,
            user_query=query,
        )

    @staticmethod
    def _parse_tool_arguments(raw_arguments):
        """Return tool-call arguments as a mapping ({} if absent or undecodable)."""
        if raw_arguments is None:
            return {}
        if isinstance(raw_arguments, str):
            # Some clients deliver arguments as a JSON-encoded string.
            try:
                decoded = json.loads(raw_arguments)
            except json.JSONDecodeError:
                logger.warning("Undecodable tool arguments: %r", raw_arguments)
                return {}
            return decoded if isinstance(decoded, dict) else {}
        return raw_arguments

    def _run_menu_tool(self, tool_args, session_id, query):
        """Run MenuTool with the model's query (falls back to the user's query)."""
        return self.menu_tool.run(tool_args.get("query", query), session_id)

    def _run_cart_tool(self, tool_args, session_id, query):
        """Run CartTool with the model's query (falls back to the user's query)."""
        # The real session_id is used; any session_id the model supplies is ignored.
        return self.cart_tool.run(tool_args.get("query", query), session_id=session_id)

    def _run_order_tool(self, tool_args, session_id, query):
        """Run OrderTool for this session; the model's arguments are not used."""
        return self.order_tool.run(
            df=self.df,
            session_id=session_id,
            category=self.category,
            items=self.items,
            store_id=self.store_id,
            brand_id=self.brand_id,
        )