|
import logging

import pandas as pd

from config import settings
from context import ollama_context_query, summarised_output
from data_section import data_bifercation
from tools import MenuTool, CartTool, OrderTool, greetings_function
from tools.prompts import tool_prompt_function
from utils import client
|
|
|
|
|
class ReactAgent:
    """Answer restaurant queries by letting an LLM pick menu/cart/order tools.

    At construction the store's menu is loaded into a DataFrame that is shared
    by the three tools. For each query the agent asks the LLM which tools to
    call, runs them, and hands their outputs to ``summarised_output``.
    """

    # Debug tracing of queries, raw LLM responses and tool outputs.
    _logger = logging.getLogger(__name__)

    # When ``ollama_context_query`` returns one of these labels, the user's raw
    # query is sent to the LLM; otherwise its rewritten context query is sent.
    _TOOL_LABELS = ("MenuTool", "CartTool", "OrderTool")

    # Tool names ``_run_tool`` knows how to execute (must match ``_TOOLS``).
    _TOOL_NAMES = ("menu_tool", "cart_tool", "order_tool")

    # Function-calling schemas advertised to the LLM; built once, not per query.
    _TOOLS = [
        {
            "type": "function",
            "function": {
                "name": "menu_tool",
                "description": "Fetch the restaurant menu based on user input",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "query": {
                            "type": "string",
                            "description": "User's natural language query for the menu",
                        }
                    },
                    "required": ["query"],
                },
            },
        },
        {
            "type": "function",
            "function": {
                "name": "cart_tool",
                "description": "Manage the cart based on user input (add/remove/view)",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "query": {
                            "type": "string",
                            "description": "User's cart query to add/remove/view items",
                        },
                        "session_id": {
                            "type": "string",
                            "description": "current session id",
                        },
                    },
                    "required": ["query", "session_id"],
                },
            },
        },
        {
            "type": "function",
            "function": {
                "name": "order_tool",
                "description": "Handle order and checkout functionality",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "query": {
                            "type": "string",
                            "description": "User's request to place an order",
                        },
                    },
                    "required": ["query"],
                },
            },
        },
    ]

    def __init__(
        self,
        store_id="66dff7a04b17303d454d4bbc",
        brand_id="66cec85093c5b0896c9125c5",
    ):
        """Load the store's menu and build the tools that operate on it.

        Args:
            store_id: Store identifier passed to ``data_bifercation``; defaults
                to the previously hard-coded store.
            brand_id: Brand identifier passed to ``data_bifercation``; defaults
                to the previously hard-coded brand.
        """
        self.store_id = store_id
        self.brand_id = brand_id

        main_data, category, items = data_bifercation(self.store_id, self.brand_id)
        self.items = items
        self.category = category

        # NOTE(review): assumes each main_data row is (category, item, price);
        # assigning three column names raises ValueError if the width differs.
        df = pd.DataFrame(main_data)
        df.columns = ["category", "item", "price"]
        # Lower-case the text columns, presumably so tool lookups can match
        # case-insensitively — confirm against the tool implementations.
        df["item"] = df["item"].str.lower()
        df["category"] = df["category"].str.lower()
        self.df = df

        self.menu_tool = MenuTool(df)
        self.cart_tool = CartTool(df)
        self.order_tool = OrderTool(df)
        self.llm_client = client

    @staticmethod
    def _extract_tool_calls(response):
        """Return ``(name, arguments)`` for every tool call in a chat response, in order.

        A missing or ``None`` ``tool_calls`` entry yields ``[]``; a missing or
        ``None`` ``arguments`` entry yields ``{}``.
        """
        calls = response["message"].get("tool_calls") or []
        return [
            (call["function"]["name"], call["function"].get("arguments") or {})
            for call in calls
        ]

    def _run_tool(self, name, arguments, session_id, fallback_query):
        """Execute one tool call and return the tool's output.

        Args:
            name: One of ``_TOOL_NAMES``.
            arguments: Argument mapping produced by the LLM for this call.
            session_id: Session forwarded to the tool.
            fallback_query: Used when the LLM omitted the ``query`` argument.

        Raises:
            ValueError: If ``name`` is not a known tool.
        """
        tool_query = arguments.get("query", fallback_query)
        if name == "menu_tool":
            return self.menu_tool.run(tool_query, session_id)
        if name == "cart_tool":
            return self.cart_tool.run(tool_query, session_id=session_id)
        if name == "order_tool":
            # order_tool.run takes no query; the LLM's arguments are not used.
            return self.order_tool.run(
                df=self.df,
                session_id=session_id,
                category=self.category,
                items=self.items,
                store_id=self.store_id,
                brand_id=self.brand_id,
            )
        raise ValueError(f"Unknown tool: {name}")

    def handle_query(self, session_id, query, chat_history):
        """Answer one user query, running whichever tools the LLM selects.

        Args:
            session_id: Forwarded to the prompt builder and the tools.
            query: The user's raw message.
            chat_history: Prior conversation, forwarded to the context helpers.

        Returns:
            ``greetings_function(query)`` when ``ollama_context_query`` returns a
            falsy ``greet_bool``; otherwise ``summarised_output`` over the
            outputs of the tool calls.
        """
        prompt = tool_prompt_function(current_query=query, session_id=session_id)

        context_query, greet_bool = ollama_context_query(
            chat_history=chat_history, user_query=query
        )
        # NOTE(review): a *falsy* greet_bool routes to the greeting handler; the
        # flag name reads inverted — confirm against ollama_context_query.
        if not greet_bool:
            return greetings_function(query)

        # A bare tool label carries no query text, so send the raw user query.
        if context_query in self._TOOL_LABELS:
            user_message_content = query
        else:
            user_message_content = context_query

        messages = [
            {"role": "system", "content": prompt},
            {"role": "user", "content": user_message_content},
        ]
        llm_response = self.llm_client.chat(
            model=settings.MODEL_NAME,
            messages=messages,
            tools=self._TOOLS,
        )
        self._logger.debug("query: %s", query)
        self._logger.debug("LLM response: %s", llm_response)

        tool_calls = self._extract_tool_calls(llm_response)
        self._logger.debug("tool calls: %s", tool_calls)

        # Each call runs with its own arguments, so repeated calls to the same
        # tool (e.g. two cart additions) are all honoured in order.
        tool_responses = []
        for name, arguments in tool_calls:
            if name not in self._TOOL_NAMES:
                self._logger.warning("Ignoring unknown tool call: %s", name)
                continue
            tool_response = self._run_tool(
                name, arguments, session_id, user_message_content
            )
            self._logger.debug("%s response: %s", name, tool_response)
            tool_responses.append(tool_response)

        # NOTE(review): if the LLM answers without calling any tool, its text
        # content is not forwarded and tool_responses is empty — confirm
        # summarised_output handles that case.
        return summarised_output(
            messages=tool_responses,
            chat_history=chat_history,
            context_query=context_query,
            user_query=query,
        )
|
|