Spaces:
Sleeping
Sleeping
from langchain.tools import tool | |
from langchain.agents import AgentExecutor, create_tool_calling_agent | |
from langchain_core.prompts import ChatPromptTemplate | |
from langchain_core.messages import HumanMessage | |
from langchain_mistralai.chat_models import ChatMistralAI | |
import torch | |
import os | |
import sys | |
import json | |
sys.path.append(os.getcwd()) | |
from Modules.PoseEstimation.pose_estimator import model, compute_right_knee_angle, moving_average, save_knee_angle_fig | |
# Credentials must not live in source control. When `api_key` is not passed,
# ChatMistralAI falls back to the `MISTRAL_API_KEY` environment variable, so
# export that variable before starting the app.
# NOTE(review): the key that was previously committed here should be revoked.
llm = ChatMistralAI(model='mistral-large-latest')
# Run pose inference on the GPU when one is available, otherwise on the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
def get_keypoints_from_keypoints(video_path: str) -> str:
    """
    Runs the pose model on a video and stores the per-frame keypoints as JSON.

    One entry is written per result returned by the model: the frame index and
    the first row of that frame's `keypoints.xy` (presumably the first detected
    person - confirm against the pose model).
    Args:
        video_path (str): path to the video file
    Returns:
        file_path (str): path to the JSON file containing the keypoints
            (`keypoints.json` inside the `tmp` folder)
    """
    output_dir = 'tmp'
    os.makedirs(output_dir, exist_ok=True)
    predictions = model(video_path, save=True, show_conf=False, show_boxes=False, device=device)
    per_frame = [
        {'frame': index, 'keypoints': prediction.keypoints.xy[0].tolist()}
        for index, prediction in enumerate(predictions)
    ]
    file_path = os.path.join(output_dir, 'keypoints.json')
    with open(file_path, 'w') as handle:
        json.dump(per_frame, handle)
    return file_path
def compute_right_knee_angle_list(json_path: str) -> list[float]:
    """
    Computes the right knee angle for every frame stored in a keypoints file.

    The per-frame angles are passed through `moving_average` (second argument
    10, presumably the window size - confirm against its definition) and the
    result is handed to `save_knee_angle_fig` before being returned.
    Args:
        json_path (str): path to the JSON file containing the keypoints
    Returns:
        right_knee_angle_list (list[float]): list of knee angles after the moving average
    """
    # Use a context manager so the file handle is closed deterministically
    # instead of being left to the garbage collector.
    with open(json_path) as f:
        keypoints_list = json.load(f)
    right_knee_angle_list = [
        compute_right_knee_angle(frame['keypoints']) for frame in keypoints_list
    ]
    right_knee_angle_list = moving_average(right_knee_angle_list, 10)
    save_knee_angle_fig(right_knee_angle_list)
    return right_knee_angle_list
def check_knee_angle(json_path: str) -> bool:
    """
    Tells whether the knee angle ever drops below 90 degrees.

    A knee angle under 90 degrees means the user went deep enough, so the squat
    is considered correct as soon as one such angle is found.
    Args:
        json_path (str): path to the JSON file containing the keypoints
    Returns:
        is_correct (bool): True if at least one knee angle is smaller than 90 degrees, False otherwise
    """
    knee_angles = compute_right_knee_angle_list(json_path)
    return any(value < 90 for value in knee_angles)
def check_squat(file_name: str) -> str:
    """
    Checks if the squat in an uploaded video is correct.

    The video is looked up inside the 'uploaded' folder. The squat is considered
    correct if the knee angle gets smaller than 90 degrees, because it means the
    user is going deep enough.
    Args:
        file_name (str): name of the video file inside the 'uploaded' folder
    Returns:
        message (str): sentence telling whether the squat is correct and why, or that the video file does not exist
    """
    upload_dir = 'uploaded'
    video_path = os.path.join(upload_dir, file_name)

    # file_name comes from the conversation, so refuse anything that resolves
    # outside the upload folder ('../x', absolute paths) or to the folder itself.
    upload_root = os.path.abspath(upload_dir)
    resolved_path = os.path.abspath(video_path)
    if not resolved_path.startswith(upload_root + os.sep):
        return "The video file does not exist."

    # isfile (rather than exists) also rejects directories.
    if not os.path.isfile(video_path):
        return "The video file does not exist."

    json_path = get_keypoints_from_keypoints(video_path)
    if check_knee_angle(json_path):
        return "The squat is correct because your knee angle is smaller than 90 degrees."
    return "The squat is incorrect because your knee angle is greater than 90 degrees."
# AgentExecutor works with tool objects rather than bare functions, so the plain
# `check_squat` function is wrapped with the `tool` helper imported at the top of
# the file. The function's name and docstring become the tool's name and
# description shown to the model.
tools = [tool(check_squat)]

# Prompt layout: system instructions, optional prior conversation, the user's
# message, then the scratchpad the agent uses for its tool calls and results.
prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            "You are a helpful assistant. Make sure to use the check_squat tool if the user wants to check his movement. Also explain your response",
        ),
        ("placeholder", "{chat_history}"),
        ("human", "{input}"),
        ("placeholder", "{agent_scratchpad}"),
    ]
)

# Construct the Tools agent
agent = create_tool_calling_agent(llm, tools, prompt)
agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)