# Author: Théo Rousseaux
# Source: Hugging Face, commit 6eebc5e (4.32 kB)
from langchain.tools import tool
from langchain.agents import AgentExecutor, create_tool_calling_agent
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.messages import HumanMessage
from langchain_mistralai.chat_models import ChatMistralAI
import torch
import os
import sys
import json
sys.path.append(os.getcwd())
from Modules.PoseEstimation.pose_estimator import model, compute_right_knee_angle, moving_average, save_knee_angle_fig
# Credentials must never be hard-coded in source: when `api_key` is not passed,
# ChatMistralAI falls back to the `MISTRAL_API_KEY` environment variable.
# Any key that was ever committed to this file should be revoked and rotated.
llm = ChatMistralAI(model='mistral-large-latest')
# Run pose estimation on the GPU when one is available, otherwise on the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
@tool
def get_keypoints_from_keypoints(video_path: str) -> str:
    """
    Extracts keypoints from a video file and saves them as JSON.
    Args:
        video_path (str): path to the video file
    Returns:
        file_path (str): path to the JSON file containing the keypoints (tmp/keypoints.json)
    """
    # NOTE: despite its name, this tool reads a *video*; the name is kept so
    # existing callers keep working.
    save_folder = 'tmp'
    os.makedirs(save_folder, exist_ok=True)
    keypoints = []
    # stream=True yields one result per frame instead of accumulating the
    # results of every frame of the video in memory before returning.
    results = model(
        video_path,
        stream=True,
        save=True,
        show_conf=False,
        show_boxes=False,
        device=device,
    )
    for frame_index, frame in enumerate(results):
        frame_keypoints = frame.keypoints
        # Skip frames with no detection: indexing [0] into an empty detection
        # set would raise IndexError. 'frame' still records the true frame index.
        if frame_keypoints is None or len(frame_keypoints.xy) == 0:
            continue
        # Only the first detection's (x, y) keypoints are kept.
        keypoints.append({
            'frame': frame_index,
            'keypoints': frame_keypoints.xy[0].tolist(),
        })
    file_path = os.path.join(save_folder, 'keypoints.json')
    with open(file_path, 'w', encoding='utf-8') as f:
        json.dump(keypoints, f)
    return file_path
def compute_right_knee_angle_list(json_path: str, window_size: int = 10) -> list[float]:
    """
    Computes the smoothed right-knee angle for every frame stored in a keypoints JSON file.

    The smoothed series is also saved as a figure via ``save_knee_angle_fig``.

    Args:
        json_path (str): path to a JSON file holding a list of
            ``{'frame': ..., 'keypoints': [...]}`` entries
        window_size (int): window passed to ``moving_average`` (default 10)
    Returns:
        right_knee_angle_list (list[float]): smoothed knee angles, as returned by
            ``moving_average`` (exact container type depends on that helper)
    """
    # Context manager so the file handle is closed deterministically.
    with open(json_path, encoding='utf-8') as f:
        keypoints_list = json.load(f)
    raw_angles = [compute_right_knee_angle(entry['keypoints']) for entry in keypoints_list]
    right_knee_angle_list = moving_average(raw_angles, window_size)
    save_knee_angle_fig(right_knee_angle_list)
    return right_knee_angle_list
def check_knee_angle(json_path: str, threshold: float = 90) -> bool:
    """
    Checks if the minimum knee angle is smaller than a threshold.
    If the minimum knee angle is smaller than the threshold (90 degrees by default),
    the squat is considered correct, because it means the user is going deep enough.
    Args:
        json_path (str): path to the JSON file containing the keypoints
        threshold (float): angle in degrees the knee must go below (default 90)
    Returns:
        is_correct (bool): True if any smoothed knee angle is smaller than the threshold,
            False otherwise (including when there are no angles at all)
    """
    angles_list = compute_right_knee_angle_list(json_path)
    # "minimum < threshold" is the same as "some angle < threshold"; any() also
    # handles an empty series without raising, unlike min().
    return any(angle < threshold for angle in angles_list)
@tool
def check_squat(file_name: str) -> str:
    """
    Checks if the squat in an uploaded video is correct.
    The video is looked up in the 'uploaded' folder. The squat is considered correct if the
    knee angle drops below 90 degrees at some point, because it means the user is going deep enough.
    Args:
        file_name (str): name of the video file inside the 'uploaded' folder
    Returns:
        message (str): a sentence saying whether the squat is correct, or that the video file does not exist
    """
    video_path = os.path.join('uploaded', file_name)
    # file_name is chosen by the LLM (i.e. indirectly by the user): refuse absolute
    # paths or '..' segments that would resolve outside the upload folder.
    upload_dir = os.path.realpath('uploaded')
    try:
        inside_upload_dir = os.path.commonpath([upload_dir, os.path.realpath(video_path)]) == upload_dir
    except ValueError:  # e.g. paths on different drives on Windows
        inside_upload_dir = False
    if not inside_upload_dir or not os.path.exists(video_path):
        return "The video file does not exist."
    # get_keypoints_from_keypoints is a LangChain tool object; invoke it explicitly
    # instead of calling it directly (BaseTool.__call__ is deprecated).
    json_path = get_keypoints_from_keypoints.invoke(video_path)
    if check_knee_angle(json_path):
        return "The squat is correct because your knee angle is smaller than 90 degrees."
    return "The squat is incorrect because your knee angle is greater than 90 degrees."
# Only check_squat is exposed to the agent; the keypoint extraction is called by it internally.
tools = [check_squat]

_SYSTEM_MESSAGE = "You are a helpful assistant. Make sure to use the check_squat tool if the user wants to check his movement. Also explain your response"

_prompt_messages = [
    ("system", _SYSTEM_MESSAGE),
    # Earlier conversation messages, supplied by the caller as `chat_history`.
    ("placeholder", "{chat_history}"),
    ("human", "{input}"),
    # Filled by the agent with its intermediate tool calls and their results.
    ("placeholder", "{agent_scratchpad}"),
]
prompt = ChatPromptTemplate.from_messages(_prompt_messages)

# Wire the LLM, the tools and the prompt into a tool-calling agent, then wrap it in an executor.
agent = create_tool_calling_agent(llm, tools, prompt)
agent_executor = AgentExecutor(
    agent=agent,
    tools=tools,
    verbose=True,
)