{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "d602e75d",
   "metadata": {},
   "source": [
    "# 0. Import Dependencies"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "2d97e5ec",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "c:\\Users\\cpras\\anaconda3\\envs\\AItrainer\\lib\\site-packages\\numpy\\_distributor_init.py:30: UserWarning: loaded more than 1 DLL from .libs:\n",
      "c:\\Users\\cpras\\anaconda3\\envs\\AItrainer\\lib\\site-packages\\numpy\\.libs\\libopenblas.FB5AE2TYXYH2IJRDKGDGQ3XBKLKTF43H.gfortran-win_amd64.dll\n",
      "c:\\Users\\cpras\\anaconda3\\envs\\AItrainer\\lib\\site-packages\\numpy\\.libs\\libopenblas.WCDJNK7YVMPZQ2ME2ZZHJJRJ3JIKNDB7.gfortran-win_amd64.dll\n",
      "  warnings.warn(\"loaded more than 1 DLL from .libs:\"\n"
     ]
    }
   ],
   "source": [
    "import cv2\n",
    "import numpy as np\n",
    "import os\n",
    "from matplotlib import pyplot as plt\n",
    "import time\n",
    "import mediapipe as mp\n",
    "import tensorflow as tf\n",
    "import math\n",
    "\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.metrics import multilabel_confusion_matrix, accuracy_score, classification_report\n",
    "from tensorflow.keras.utils import to_categorical\n",
    "\n",
    "import tensorflow as tf\n",
    "from tensorflow.keras import backend as K\n",
    "from tensorflow.keras.callbacks import TensorBoard, EarlyStopping, ReduceLROnPlateau, ModelCheckpoint\n",
    "\n",
    "from tensorflow.keras.models import Sequential, Model\n",
    "\n",
    "from tensorflow.keras.layers import (LSTM, Dense, Concatenate, Attention, Dropout, Softmax,\n",
    "                                     Input, Flatten, Activation, Bidirectional, Permute, multiply, \n",
    "                                     ConvLSTM2D, MaxPooling3D, TimeDistributed, Conv2D, MaxPooling2D)\n",
    "\n",
    "from scipy import stats\n",
    "\n",
    "# disable some of the tf/keras training warnings \n",
    "os.environ['TF_CPP_MIN_LOG_LEVEL'] = \"3\"\n",
    "tf.get_logger().setLevel(\"ERROR\")\n",
    "tf.autograph.set_verbosity(1)\n",
    "\n",
    "# suppress untraced functions warning\n",
    "import absl.logging\n",
    "absl.logging.set_verbosity(absl.logging.ERROR)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "55f470e2",
   "metadata": {},
   "source": [
    "# 1. Keypoints using MP Pose"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "20cde117",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pre-trained pose estimation model from Google Mediapipe\n",
    "mp_pose = mp.solutions.pose\n",
    "\n",
    "# Supported Mediapipe visualization tools\n",
    "mp_drawing = mp.solutions.drawing_utils"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "716e9f8e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def mediapipe_detection(image, model):\n",
    "    \"\"\"\n",
    "    This function detects human pose estimation keypoints from webcam footage\n",
    "    \n",
    "    \"\"\"\n",
    "    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) # COLOR CONVERSION BGR 2 RGB\n",
    "    image.flags.writeable = False                  # Image is no longer writeable\n",
    "    results = model.process(image)                 # Make prediction\n",
    "    image.flags.writeable = True                   # Image is now writeable \n",
    "    image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) # COLOR COVERSION RGB 2 BGR\n",
    "    return image, results"
   ]
  },
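  {
   "cell_type": "markdown",
   "id": "3c1a9f2b",
   "metadata": {},
   "source": [
    "Note: OpenCV delivers frames in BGR channel order while Mediapipe expects RGB, hence the conversion on the way in and back out. Marking the image as not writeable before inference lets Mediapipe treat it as read-only (passed by reference), avoiding an extra copy per frame."
   ]
  },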
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "9bd7ba58",
   "metadata": {},
   "outputs": [],
   "source": [
    "def draw_landmarks(image, results):\n",
    "    \"\"\"\n",
    "    This function draws keypoints and landmarks detected by the human pose estimation model\n",
    "    \n",
    "    \"\"\"\n",
    "    mp_drawing.draw_landmarks(image, results.pose_landmarks, mp_pose.POSE_CONNECTIONS,\n",
    "                                mp_drawing.DrawingSpec(color=(245,117,66), thickness=2, circle_radius=2), \n",
    "                                mp_drawing.DrawingSpec(color=(245,66,230), thickness=2, circle_radius=2) \n",
    "                                 )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "c0ebe952",
   "metadata": {},
   "outputs": [],
   "source": [
    "cap = cv2.VideoCapture(0) # camera object\n",
    "HEIGHT = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) # webcam video frame height\n",
    "WIDTH = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) # webcam video frame width\n",
    "FPS = int(cap.get(cv2.CAP_PROP_FPS)) # webcam video fram rate \n",
    "\n",
    "# Set and test mediapipe model using webcam\n",
    "with mp_pose.Pose(min_detection_confidence=0.5, min_tracking_confidence=0.5) as pose:\n",
    "    while cap.isOpened():\n",
    "\n",
    "        # Read feed\n",
    "        ret, frame = cap.read()\n",
    "      \n",
    "        # Make detection\n",
    "        image, results = mediapipe_detection(frame, pose)\n",
    "        \n",
    "        # Extract landmarks\n",
    "        try:\n",
    "            landmarks = results.pose_landmarks.landmark\n",
    "        except:\n",
    "            pass\n",
    "        \n",
    "        # Render detections\n",
    "        draw_landmarks(image, results)               \n",
    "        \n",
    "        # Display frame on screen\n",
    "        cv2.imshow('OpenCV Feed', image)\n",
    "        \n",
    "        # Exit / break out logic\n",
    "        if cv2.waitKey(10) & 0xFF == ord('q'):\n",
    "            break\n",
    "\n",
    "    cap.release()\n",
    "    cv2.destroyAllWindows()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f9fb32f3",
   "metadata": {},
   "source": [
    "# 2. Extract Keypoints"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "a81823f6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Recollect and organize keypoints from the test\n",
    "pose = []\n",
    "for res in results.pose_landmarks.landmark:\n",
    "    test = np.array([res.x, res.y, res.z, res.visibility])\n",
    "    pose.append(test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "cd92eee3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 33 landmarks with 4 values (x, y, z, visibility)\n",
    "num_landmarks = len(landmarks)\n",
    "num_values = len(test)\n",
    "num_input_values = num_landmarks*num_values"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "9dad5b8c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# This is an example of what we would use as an input into our AI models\n",
    "pose = np.array([[res.x, res.y, res.z, res.visibility] for res in results.pose_landmarks.landmark]).flatten() if results.pose_landmarks else np.zeros(33*4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "6f4d1079",
   "metadata": {},
   "outputs": [],
   "source": [
    "def extract_keypoints(results):\n",
    "    \"\"\"\n",
    "    Processes and organizes the keypoints detected from the pose estimation model \n",
    "    to be used as inputs for the exercise decoder models\n",
    "    \n",
    "    \"\"\"\n",
    "    pose = np.array([[res.x, res.y, res.z, res.visibility] for res in results.pose_landmarks.landmark]).flatten() if results.pose_landmarks else np.zeros(33*4)\n",
    "    return pose"
   ]
  },
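  {
   "cell_type": "markdown",
   "id": "7e41c0aa",
   "metadata": {},
   "source": [
    "A quick, illustrative sanity check of `extract_keypoints` (not part of the original pipeline): with no detected landmarks it should fall back to a zero vector of length 33*4 = 132. `_EmptyResults` below is a hypothetical stand-in for a Mediapipe results object."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b2d4e6f8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity check (sketch): no detections -> zero vector of length 33*4 = 132\n",
    "class _EmptyResults:\n",
    "    pose_landmarks = None  # mimics a frame where no person was found\n",
    "\n",
    "assert extract_keypoints(_EmptyResults()).shape == (33*4,)"
   ]
  },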
  {
   "cell_type": "markdown",
   "id": "03b907e8",
   "metadata": {},
   "source": [
    "# 3. Setup Folders for Collection"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "ddcaecfd",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "c:\\Users\\cpras\\Documents\\GitHub\\AI_Personal_Trainer\\data\n"
     ]
    }
   ],
   "source": [
    "# Path for exported data, numpy arrays\n",
    "DATA_PATH = os.path.join(os. getcwd(),'data') \n",
    "print(DATA_PATH)\n",
    "\n",
    "# make directory if it does not exist yet\n",
    "if not os.path.exists(DATA_PATH):\n",
    "    os.makedirs(DATA_PATH)\n",
    "\n",
    "# Actions/exercises that we try to detect\n",
    "actions = np.array(['curl', 'press', 'squat'])\n",
    "num_classes = len(actions)\n",
    "\n",
    "# How many videos worth of data\n",
    "no_sequences = 50\n",
    "\n",
    "# Videos are going to be this many frames in length\n",
    "sequence_length = FPS*1\n",
    "\n",
    "# Folder start\n",
    "# Change this to collect more data and not lose previously collected data\n",
    "start_folder = 101"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 62,
   "id": "fed6b275",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build folder paths\n",
    "for action in actions:     \n",
    "    for sequence in range(start_folder,no_sequences+start_folder):\n",
    "        try: \n",
    "            os.makedirs(os.path.join(DATA_PATH, action, str(sequence)))  \n",
    "        except:\n",
    "            pass"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7622b573",
   "metadata": {},
   "source": [
    "# 4. Collect Keypoint Values for Training and Testing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "d224561f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Colors associated with each exercise (e.g., curls are denoted by blue, squats are denoted by orange, etc.)\n",
    "colors = [(245,117,16), (117,245,16), (16,117,245)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 64,
   "id": "41b81490",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Collect Training Data\n",
    "\n",
    "cap = cv2.VideoCapture(0)\n",
    "# Set mediapipe model \n",
    "with mp_pose.Pose(min_detection_confidence=0.5, min_tracking_confidence=0.5) as pose:\n",
    "    # Loop through actions\n",
    "    for idx, action in enumerate(actions):\n",
    "        # Loop through sequences (i.e., videos)\n",
    "        for sequence in range(start_folder, start_folder+no_sequences):\n",
    "            # Loop through video length (i.e, sequence length)\n",
    "            for frame_num in range(sequence_length):\n",
    "                # Read feed\n",
    "                ret, frame = cap.read()\n",
    "                \n",
    "                # Make detection\n",
    "                image, results = mediapipe_detection(frame, pose)\n",
    "\n",
    "                # Extract landmarks\n",
    "                try:\n",
    "                    landmarks = results.pose_landmarks.landmark\n",
    "                except:\n",
    "                    pass\n",
    "                \n",
    "                # Render detections\n",
    "                draw_landmarks(image, results) \n",
    "\n",
    "                # Apply visualization logic\n",
    "                if frame_num == 0: # If first frame in sequence, print that you're starting a new data collection and wait 500 ms\n",
    "                    cv2.putText(image, 'STARTING COLLECTION', (120,200), \n",
    "                            cv2.FONT_HERSHEY_SIMPLEX, 1, (0,255, 0), 4, cv2.LINE_AA)\n",
    "                    \n",
    "                    cv2.putText(image, 'Collecting {} Video # {}'.format(action, sequence), (15,30), \n",
    "                            cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 8, cv2.LINE_AA)\n",
    "                    cv2.putText(image, 'Collecting {} Video # {}'.format(action, sequence), (15,30), \n",
    "                            cv2.FONT_HERSHEY_SIMPLEX, 1, colors[idx], 4, cv2.LINE_AA)\n",
    "                    \n",
    "                    # Show to screen\n",
    "                    cv2.imshow('OpenCV Feed', image)\n",
    "                    cv2.waitKey(500)\n",
    "                else: \n",
    "                    cv2.putText(image, 'Collecting {} Video # {}'.format(action, sequence), (15,30), \n",
    "                            cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 8, cv2.LINE_AA)\n",
    "                    cv2.putText(image, 'Collecting {} Video # {}'.format(action, sequence), (15,30), \n",
    "                            cv2.FONT_HERSHEY_SIMPLEX, 1, colors[idx], 4, cv2.LINE_AA)\n",
    "                    \n",
    "                    # Show to screen\n",
    "                    cv2.imshow('OpenCV Feed', image)\n",
    "\n",
    "                # Export keypoints (sequence + pose landmarks)\n",
    "                keypoints = extract_keypoints(results)\n",
    "                npy_path = os.path.join(DATA_PATH, action, str(sequence), str(frame_num))\n",
    "                np.save(npy_path, keypoints)\n",
    "\n",
    "                # Break gracefully\n",
    "                if cv2.waitKey(10) & 0xFF == ord('q'):\n",
    "                    break\n",
    "                    \n",
    "    cap.release()\n",
    "    cv2.destroyAllWindows()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "4b016129",
   "metadata": {},
   "outputs": [],
   "source": [
    "cap.release()\n",
    "cv2.destroyAllWindows()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d01ec993",
   "metadata": {},
   "source": [
    "# 5. Preprocess Data and Create Labels/Features"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "cad528c8",
   "metadata": {},
   "outputs": [],
   "source": [
    "label_map = {label:num for num, label in enumerate(actions)}"
   ]
  },
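  {
   "cell_type": "markdown",
   "id": "5f8a2c1d",
   "metadata": {},
   "source": [
    "With the `actions` array defined above, this gives `{'curl': 0, 'press': 1, 'squat': 2}`: each exercise name maps to the integer class index used for the one-hot labels below."
   ]
  },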
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "a0add3fb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load and organize recorded training data\n",
    "sequences, labels = [], []\n",
    "for action in actions:\n",
    "    for sequence in np.array(os.listdir(os.path.join(DATA_PATH, action))).astype(int):\n",
    "        window = []\n",
    "        for frame_num in range(sequence_length):         \n",
    "            # LSTM input data\n",
    "            res = np.load(os.path.join(DATA_PATH, action, str(sequence), \"{}.npy\".format(frame_num)))\n",
    "            window.append(res)  \n",
    "            \n",
    "        sequences.append(window)\n",
    "        labels.append(label_map[action])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "ab459ce4",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(450, 30, 132) (450, 3)\n"
     ]
    }
   ],
   "source": [
    "# Make sure first dimensions of arrays match\n",
    "X = np.array(sequences)\n",
    "y = to_categorical(labels).astype(int)\n",
    "print(X.shape, y.shape)"
   ]
  },
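  {
   "cell_type": "markdown",
   "id": "9d3b7e55",
   "metadata": {},
   "source": [
    "The shapes line up with the collection settings: 450 sequences (3 actions x 150 recorded videos each), 30 frames per sequence, and 132 keypoint values per frame; `y` holds the matching one-hot labels over the 3 classes."
   ]
  },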
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "5ac49993",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(405, 30, 132) (405, 3)\n"
     ]
    }
   ],
   "source": [
    "# Split into training, validation, and testing datasets\n",
    "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.10, random_state=1)\n",
    "print(X_train.shape, y_train.shape)\n",
    "X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size=15/90, random_state=2)"
   ]
  },
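  {
   "cell_type": "markdown",
   "id": "c4e8a1b6",
   "metadata": {},
   "source": [
    "A quick check of the resulting split sizes (illustrative cell, output not recorded): with 450 sequences this works out to roughly 337 train / 68 validation / 45 test."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1a5c9e3f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative: confirm the ~75/15/10 train/val/test split\n",
    "print(f\"train: {len(X_train)}, val: {len(X_val)}, test: {len(X_test)}\")"
   ]
  },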
  {
   "cell_type": "markdown",
   "id": "e53ae03d",
   "metadata": {},
   "source": [
    "# 6. Build and Train Neural Networks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "912f3153",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Callbacks to be used during neural network training \n",
    "es_callback = EarlyStopping(monitor='val_loss', min_delta=5e-4, patience=10, verbose=0, mode='min')\n",
    "lr_callback = ReduceLROnPlateau(monitor='val_loss', factor=0.2, patience=5, min_lr=0.00001, verbose=0, mode='min')\n",
    "chkpt_callback = ModelCheckpoint(filepath=DATA_PATH, monitor='val_loss', verbose=0, save_best_only=True, \n",
    "                                 save_weights_only=False, mode='min', save_freq=1)\n",
    "\n",
    "# Optimizer\n",
    "opt = tf.keras.optimizers.Adam(learning_rate=0.01)\n",
    "\n",
    "# some hyperparamters\n",
    "batch_size = 32\n",
    "max_epochs = 500"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c1ca0cad",
   "metadata": {},
   "source": [
    "## 6a. LSTM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "730f987b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set up Tensorboard logging and callbacks\n",
    "NAME = f\"ExerciseRecognition-LSTM-{int(time.time())}\"\n",
    "log_dir = os.path.join(os.getcwd(), 'logs', NAME,'')\n",
    "tb_callback = TensorBoard(log_dir=log_dir)\n",
    "\n",
    "callbacks = [tb_callback, es_callback, lr_callback, chkpt_callback]"
   ]
  },
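  {
   "cell_type": "markdown",
   "id": "e7f2b4d0",
   "metadata": {},
   "source": [
    "Training can be monitored by pointing TensorBoard at the logs folder, e.g. `tensorboard --logdir logs` from the project root, then opening the printed local URL in a browser."
   ]
  },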
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "ae7595c1",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"sequential\"\n",
      "_________________________________________________________________\n",
      " Layer (type)                Output Shape              Param #   \n",
      "=================================================================\n",
      " lstm (LSTM)                 (None, 30, 128)           133632    \n",
      "                                                                 \n",
      " lstm_1 (LSTM)               (None, 30, 256)           394240    \n",
      "                                                                 \n",
      " lstm_2 (LSTM)               (None, 128)               197120    \n",
      "                                                                 \n",
      " dense (Dense)               (None, 128)               16512     \n",
      "                                                                 \n",
      " dense_1 (Dense)             (None, 64)                8256      \n",
      "                                                                 \n",
      " dense_2 (Dense)             (None, 3)                 195       \n",
      "                                                                 \n",
      "=================================================================\n",
      "Total params: 749,955\n",
      "Trainable params: 749,955\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n",
      "None\n"
     ]
    }
   ],
   "source": [
    "lstm = Sequential()\n",
    "lstm.add(LSTM(128, return_sequences=True, activation='relu', input_shape=(sequence_length, num_input_values)))\n",
    "lstm.add(LSTM(256, return_sequences=True, activation='relu'))\n",
    "lstm.add(LSTM(128, return_sequences=False, activation='relu'))\n",
    "lstm.add(Dense(128, activation='relu'))\n",
    "lstm.add(Dense(64, activation='relu'))\n",
    "lstm.add(Dense(actions.shape[0], activation='softmax'))\n",
    "print(lstm.summary())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "8a10e698",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/500\n",
      "11/11 [==============================] - 58s 5s/step - loss: 1.1959 - categorical_accuracy: 0.4718 - val_loss: 1.1662 - val_categorical_accuracy: 0.3529 - lr: 0.0010\n",
      "Epoch 2/500\n",
      "11/11 [==============================] - 54s 5s/step - loss: 0.9662 - categorical_accuracy: 0.3887 - val_loss: 0.7277 - val_categorical_accuracy: 0.5441 - lr: 0.0010\n",
      "Epoch 3/500\n",
      "11/11 [==============================] - 55s 5s/step - loss: 0.7514 - categorical_accuracy: 0.7092 - val_loss: 0.9325 - val_categorical_accuracy: 0.5882 - lr: 0.0010\n",
      "Epoch 4/500\n",
      "11/11 [==============================] - 54s 5s/step - loss: 0.8151 - categorical_accuracy: 0.7537 - val_loss: 4.7314 - val_categorical_accuracy: 0.4706 - lr: 0.0010\n",
      "Epoch 5/500\n",
      "11/11 [==============================] - 55s 5s/step - loss: 0.8428 - categorical_accuracy: 0.7567 - val_loss: 0.6531 - val_categorical_accuracy: 0.9706 - lr: 0.0010\n",
      "Epoch 6/500\n",
      "11/11 [==============================] - 55s 5s/step - loss: 0.6416 - categorical_accuracy: 0.8427 - val_loss: 0.7113 - val_categorical_accuracy: 0.6765 - lr: 0.0010\n",
      "Epoch 7/500\n",
      "11/11 [==============================] - 55s 5s/step - loss: 0.7782 - categorical_accuracy: 0.5964 - val_loss: 0.6241 - val_categorical_accuracy: 0.5882 - lr: 0.0010\n",
      "Epoch 8/500\n",
      "11/11 [==============================] - 62s 6s/step - loss: 0.4477 - categorical_accuracy: 0.8071 - val_loss: 0.3213 - val_categorical_accuracy: 0.9559 - lr: 0.0010\n",
      "Epoch 9/500\n",
      "11/11 [==============================] - 55s 5s/step - loss: 0.1791 - categorical_accuracy: 0.9733 - val_loss: 2.7620 - val_categorical_accuracy: 0.9412 - lr: 0.0010\n",
      "Epoch 10/500\n",
      "11/11 [==============================] - 55s 5s/step - loss: 1.7325 - categorical_accuracy: 0.6469 - val_loss: 0.8761 - val_categorical_accuracy: 0.8824 - lr: 0.0010\n",
      "Epoch 11/500\n",
      "11/11 [==============================] - 55s 5s/step - loss: 0.7803 - categorical_accuracy: 0.6558 - val_loss: 0.6232 - val_categorical_accuracy: 0.7647 - lr: 0.0010\n",
      "Epoch 12/500\n",
      "11/11 [==============================] - 55s 5s/step - loss: 0.5408 - categorical_accuracy: 0.8783 - val_loss: 0.2358 - val_categorical_accuracy: 0.8971 - lr: 0.0010\n",
      "Epoch 13/500\n",
      "11/11 [==============================] - 55s 5s/step - loss: 0.2337 - categorical_accuracy: 0.9496 - val_loss: 0.1032 - val_categorical_accuracy: 0.9853 - lr: 0.0010\n",
      "Epoch 14/500\n",
      "11/11 [==============================] - 55s 5s/step - loss: 1.0003 - categorical_accuracy: 0.7092 - val_loss: 1.2232 - val_categorical_accuracy: 0.2941 - lr: 0.0010\n",
      "Epoch 15/500\n",
      "11/11 [==============================] - 55s 5s/step - loss: 0.9862 - categorical_accuracy: 0.4896 - val_loss: 0.7796 - val_categorical_accuracy: 0.6029 - lr: 0.0010\n",
      "Epoch 16/500\n",
      "11/11 [==============================] - 56s 5s/step - loss: 0.7276 - categorical_accuracy: 0.6558 - val_loss: 0.6212 - val_categorical_accuracy: 0.6324 - lr: 0.0010\n",
      "Epoch 17/500\n",
      "11/11 [==============================] - 55s 5s/step - loss: 0.6449 - categorical_accuracy: 0.6588 - val_loss: 0.5822 - val_categorical_accuracy: 0.6176 - lr: 0.0010\n",
      "Epoch 18/500\n",
      "11/11 [==============================] - 56s 5s/step - loss: 0.5741 - categorical_accuracy: 0.6736 - val_loss: 0.5254 - val_categorical_accuracy: 0.6324 - lr: 0.0010\n",
      "Epoch 19/500\n",
      "11/11 [==============================] - 65s 6s/step - loss: 0.5246 - categorical_accuracy: 0.6617 - val_loss: 0.4942 - val_categorical_accuracy: 0.6324 - lr: 2.0000e-04\n",
      "Epoch 20/500\n",
      "11/11 [==============================] - 54s 5s/step - loss: 0.4960 - categorical_accuracy: 0.6736 - val_loss: 0.4694 - val_categorical_accuracy: 0.6324 - lr: 2.0000e-04\n",
      "Epoch 21/500\n",
      "11/11 [==============================] - 55s 5s/step - loss: 0.4588 - categorical_accuracy: 0.6766 - val_loss: 0.4269 - val_categorical_accuracy: 0.6471 - lr: 2.0000e-04\n",
      "Epoch 22/500\n",
      "11/11 [==============================] - 54s 5s/step - loss: 0.4117 - categorical_accuracy: 0.6825 - val_loss: 0.3713 - val_categorical_accuracy: 0.6471 - lr: 2.0000e-04\n",
      "Epoch 23/500\n",
      "11/11 [==============================] - 62s 6s/step - loss: 0.3329 - categorical_accuracy: 0.7122 - val_loss: 0.2746 - val_categorical_accuracy: 0.9118 - lr: 2.0000e-04\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<keras.callbacks.History at 0x1d83f889850>"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "lstm.compile(optimizer=opt, loss='categorical_crossentropy', metrics=['categorical_accuracy'])\n",
    "lstm.fit(X_train, y_train, batch_size=batch_size, epochs=max_epochs, validation_data=(X_val, y_val), callbacks=callbacks)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8f58c4d8",
   "metadata": {},
   "source": [
    "## 6b. LSTM + Attention"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "c6e12666",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set up Tensorboard logging and callbacks\n",
    "NAME = f\"ExerciseRecognition-AttnLSTM-{int(time.time())}\"\n",
    "log_dir = os.path.join(os.getcwd(), 'logs', NAME,'')\n",
    "tb_callback = TensorBoard(log_dir=log_dir)\n",
    "\n",
    "callbacks = [tb_callback, es_callback, lr_callback, chkpt_callback]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "07591dac",
   "metadata": {},
   "outputs": [],
   "source": [
    "def attention_block(inputs, time_steps):\n",
    "    \"\"\"\n",
    "    Attention layer for deep neural network\n",
    "    \n",
    "    \"\"\"\n",
    "    # Attention weights\n",
    "    a = Permute((2, 1))(inputs)\n",
    "    a = Dense(time_steps, activation='softmax')(a)\n",
    "    \n",
    "    # Attention vector\n",
    "    a_probs = Permute((2, 1), name='attention_vec')(a)\n",
    "    \n",
    "    # Luong's multiplicative score\n",
    "    output_attention_mul = multiply([inputs, a_probs], name='attention_mul') \n",
    "    \n",
    "    return output_attention_mul"
   ]
  },
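  {
   "cell_type": "markdown",
   "id": "8b6d4f21",
   "metadata": {},
   "source": [
    "Shape-wise, the block maps `(batch, time_steps, features)` to the same shape: the first `Permute` exposes the time axis to a `Dense(time_steps, softmax)` layer, which produces one attention distribution over time per feature; permuting back and multiplying re-weights each time step's features. The toy numpy cell below (illustrative only, with made-up sizes) mimics that re-weighting outside of Keras."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f3a7c5e9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Toy numpy sketch of the attention re-weighting (time_steps=4, features=3)\n",
    "h = np.random.rand(4, 3)                               # mock LSTM output for one sample\n",
    "scores = np.random.rand(4, 3)                          # mock unnormalized attention scores\n",
    "weights = np.exp(scores) / np.exp(scores).sum(axis=0)  # softmax over the time axis\n",
    "print((h * weights).shape)                             # (4, 3): same shape, re-weighted"
   ]
  },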
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "c5c2e2e0",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"model\"\n",
      "__________________________________________________________________________________________________\n",
      " Layer (type)                   Output Shape         Param #     Connected to                     \n",
      "==================================================================================================\n",
      " input_1 (InputLayer)           [(None, 30, 132)]    0           []                               \n",
      "                                                                                                  \n",
      " bidirectional (Bidirectional)  (None, 30, 512)      796672      ['input_1[0][0]']                \n",
      "                                                                                                  \n",
      " permute (Permute)              (None, 512, 30)      0           ['bidirectional[0][0]']          \n",
      "                                                                                                  \n",
      " dense_3 (Dense)                (None, 512, 30)      930         ['permute[0][0]']                \n",
      "                                                                                                  \n",
      " attention_vec (Permute)        (None, 30, 512)      0           ['dense_3[0][0]']                \n",
      "                                                                                                  \n",
      " attention_mul (Multiply)       (None, 30, 512)      0           ['bidirectional[0][0]',          \n",
      "                                                                  'attention_vec[0][0]']          \n",
      "                                                                                                  \n",
      " flatten (Flatten)              (None, 15360)        0           ['attention_mul[0][0]']          \n",
      "                                                                                                  \n",
      " dense_4 (Dense)                (None, 512)          7864832     ['flatten[0][0]']                \n",
      "                                                                                                  \n",
      " dropout (Dropout)              (None, 512)          0           ['dense_4[0][0]']                \n",
      "                                                                                                  \n",
      " dense_5 (Dense)                (None, 3)            1539        ['dropout[0][0]']                \n",
      "                                                                                                  \n",
      "==================================================================================================\n",
      "Total params: 8,663,973\n",
      "Trainable params: 8,663,973\n",
      "Non-trainable params: 0\n",
      "__________________________________________________________________________________________________\n",
      "None\n"
     ]
    }
   ],
   "source": [
    "HIDDEN_UNITS = 256\n",
    "\n",
    "# Input\n",
    "inputs = Input(shape=(sequence_length, num_input_values))\n",
    "\n",
    "# Bi-LSTM\n",
    "lstm_out = Bidirectional(LSTM(HIDDEN_UNITS, return_sequences=True))(inputs)\n",
    "\n",
    "# Attention\n",
    "attention_mul = attention_block(lstm_out, sequence_length)\n",
    "attention_mul = Flatten()(attention_mul)\n",
    "\n",
    "# Fully Connected Layer\n",
    "x = Dense(2*HIDDEN_UNITS, activation='relu')(attention_mul)\n",
    "x = Dropout(0.5)(x)\n",
    "\n",
    "# Output\n",
    "x = Dense(actions.shape[0], activation='softmax')(x)\n",
    "\n",
    "# Bring it all together\n",
    "AttnLSTM = Model(inputs=[inputs], outputs=x)\n",
    "print(AttnLSTM.summary())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "cf2f988d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/500\n",
      "11/11 [==============================] - 146s 13s/step - loss: 0.9463 - categorical_accuracy: 0.5134 - val_loss: 0.4818 - val_categorical_accuracy: 0.8529 - lr: 0.0010\n",
      "Epoch 2/500\n",
      "11/11 [==============================] - 142s 13s/step - loss: 0.3754 - categorical_accuracy: 0.8576 - val_loss: 0.0902 - val_categorical_accuracy: 1.0000 - lr: 0.0010\n",
      "Epoch 3/500\n",
      "11/11 [==============================] - 148s 14s/step - loss: 0.0666 - categorical_accuracy: 0.9852 - val_loss: 0.0059 - val_categorical_accuracy: 1.0000 - lr: 0.0010\n",
      "Epoch 4/500\n",
      "11/11 [==============================] - 144s 13s/step - loss: 0.0391 - categorical_accuracy: 0.9852 - val_loss: 6.4974e-04 - val_categorical_accuracy: 1.0000 - lr: 0.0010\n",
      "Epoch 5/500\n",
      "11/11 [==============================] - 145s 13s/step - loss: 0.0871 - categorical_accuracy: 0.9763 - val_loss: 0.0122 - val_categorical_accuracy: 1.0000 - lr: 0.0010\n",
      "Epoch 6/500\n",
      "11/11 [==============================] - 145s 13s/step - loss: 0.0815 - categorical_accuracy: 0.9644 - val_loss: 0.0237 - val_categorical_accuracy: 1.0000 - lr: 0.0010\n",
      "Epoch 7/500\n",
      "11/11 [==============================] - 150s 14s/step - loss: 0.0190 - categorical_accuracy: 0.9911 - val_loss: 0.0328 - val_categorical_accuracy: 0.9853 - lr: 0.0010\n",
      "Epoch 8/500\n",
      "11/11 [==============================] - 142s 13s/step - loss: 0.0249 - categorical_accuracy: 0.9941 - val_loss: 6.4866e-04 - val_categorical_accuracy: 1.0000 - lr: 0.0010\n",
      "Epoch 9/500\n",
      "11/11 [==============================] - 144s 13s/step - loss: 0.0101 - categorical_accuracy: 0.9941 - val_loss: 0.0058 - val_categorical_accuracy: 1.0000 - lr: 0.0010\n",
      "Epoch 10/500\n",
      "11/11 [==============================] - 143s 13s/step - loss: 0.0173 - categorical_accuracy: 0.9941 - val_loss: 0.0012 - val_categorical_accuracy: 1.0000 - lr: 2.0000e-04\n",
      "Epoch 11/500\n",
      "11/11 [==============================] - 154s 14s/step - loss: 0.0176 - categorical_accuracy: 0.9941 - val_loss: 0.0010 - val_categorical_accuracy: 1.0000 - lr: 2.0000e-04\n",
      "Epoch 12/500\n",
      "11/11 [==============================] - 144s 13s/step - loss: 0.0040 - categorical_accuracy: 0.9970 - val_loss: 5.7718e-04 - val_categorical_accuracy: 1.0000 - lr: 2.0000e-04\n",
      "Epoch 13/500\n",
      "11/11 [==============================] - 144s 13s/step - loss: 0.0015 - categorical_accuracy: 1.0000 - val_loss: 5.7380e-04 - val_categorical_accuracy: 1.0000 - lr: 2.0000e-04\n",
      "Epoch 14/500\n",
      "11/11 [==============================] - 145s 13s/step - loss: 0.0014 - categorical_accuracy: 1.0000 - val_loss: 5.2094e-04 - val_categorical_accuracy: 1.0000 - lr: 2.0000e-04\n",
      "Epoch 15/500\n",
      "11/11 [==============================] - 149s 14s/step - loss: 0.0015 - categorical_accuracy: 1.0000 - val_loss: 4.4772e-04 - val_categorical_accuracy: 1.0000 - lr: 2.0000e-04\n",
      "Epoch 16/500\n",
      "11/11 [==============================] - 144s 13s/step - loss: 0.0012 - categorical_accuracy: 1.0000 - val_loss: 3.9085e-04 - val_categorical_accuracy: 1.0000 - lr: 2.0000e-04\n",
      "Epoch 17/500\n",
      "11/11 [==============================] - 147s 13s/step - loss: 0.0010 - categorical_accuracy: 1.0000 - val_loss: 3.4933e-04 - val_categorical_accuracy: 1.0000 - lr: 2.0000e-04\n",
      "Epoch 18/500\n",
      "11/11 [==============================] - 144s 13s/step - loss: 8.0251e-04 - categorical_accuracy: 1.0000 - val_loss: 3.1589e-04 - val_categorical_accuracy: 1.0000 - lr: 2.0000e-04\n",
      "Epoch 19/500\n",
      "11/11 [==============================] - 149s 14s/step - loss: 6.4664e-04 - categorical_accuracy: 1.0000 - val_loss: 2.9034e-04 - val_categorical_accuracy: 1.0000 - lr: 2.0000e-04\n",
      "Epoch 20/500\n",
      "11/11 [==============================] - 143s 13s/step - loss: 7.9226e-04 - categorical_accuracy: 1.0000 - val_loss: 2.6785e-04 - val_categorical_accuracy: 1.0000 - lr: 2.0000e-04\n",
      "Epoch 21/500\n",
      "11/11 [==============================] - 144s 13s/step - loss: 6.2462e-04 - categorical_accuracy: 1.0000 - val_loss: 2.4908e-04 - val_categorical_accuracy: 1.0000 - lr: 2.0000e-04\n",
      "Epoch 22/500\n",
      "11/11 [==============================] - 142s 13s/step - loss: 6.9292e-04 - categorical_accuracy: 1.0000 - val_loss: 2.3473e-04 - val_categorical_accuracy: 1.0000 - lr: 2.0000e-04\n",
      "Epoch 23/500\n",
      "11/11 [==============================] - 157s 14s/step - loss: 5.5603e-04 - categorical_accuracy: 1.0000 - val_loss: 2.2057e-04 - val_categorical_accuracy: 1.0000 - lr: 2.0000e-04\n",
      "Epoch 24/500\n",
      "11/11 [==============================] - 143s 13s/step - loss: 4.8737e-04 - categorical_accuracy: 1.0000 - val_loss: 2.0835e-04 - val_categorical_accuracy: 1.0000 - lr: 2.0000e-04\n",
      "Epoch 25/500\n",
      "11/11 [==============================] - 150s 14s/step - loss: 5.3003e-04 - categorical_accuracy: 1.0000 - val_loss: 2.0614e-04 - val_categorical_accuracy: 1.0000 - lr: 4.0000e-05\n",
      "Epoch 26/500\n",
      "11/11 [==============================] - 143s 13s/step - loss: 5.3267e-04 - categorical_accuracy: 1.0000 - val_loss: 2.0380e-04 - val_categorical_accuracy: 1.0000 - lr: 4.0000e-05\n",
      "Epoch 27/500\n",
      "11/11 [==============================] - 147s 13s/step - loss: 5.8821e-04 - categorical_accuracy: 1.0000 - val_loss: 2.0116e-04 - val_categorical_accuracy: 1.0000 - lr: 4.0000e-05\n",
      "Epoch 28/500\n",
      "11/11 [==============================] - 145s 13s/step - loss: 5.7868e-04 - categorical_accuracy: 1.0000 - val_loss: 1.9871e-04 - val_categorical_accuracy: 1.0000 - lr: 4.0000e-05\n",
      "Epoch 29/500\n",
      "11/11 [==============================] - 144s 13s/step - loss: 4.5697e-04 - categorical_accuracy: 1.0000 - val_loss: 1.9647e-04 - val_categorical_accuracy: 1.0000 - lr: 4.0000e-05\n",
      "Epoch 30/500\n",
      "11/11 [==============================] - 144s 13s/step - loss: 5.0632e-04 - categorical_accuracy: 1.0000 - val_loss: 1.9593e-04 - val_categorical_accuracy: 1.0000 - lr: 1.0000e-05\n",
      "Epoch 31/500\n",
      "11/11 [==============================] - 149s 14s/step - loss: 6.3565e-04 - categorical_accuracy: 1.0000 - val_loss: 1.9525e-04 - val_categorical_accuracy: 1.0000 - lr: 1.0000e-05\n",
      "Epoch 32/500\n",
      "11/11 [==============================] - 151s 13s/step - loss: 5.2290e-04 - categorical_accuracy: 1.0000 - val_loss: 1.9461e-04 - val_categorical_accuracy: 1.0000 - lr: 1.0000e-05\n",
      "Epoch 33/500\n",
      "11/11 [==============================] - 147s 13s/step - loss: 5.2975e-04 - categorical_accuracy: 1.0000 - val_loss: 1.9395e-04 - val_categorical_accuracy: 1.0000 - lr: 1.0000e-05\n",
      "Epoch 34/500\n",
      "11/11 [==============================] - 145s 13s/step - loss: 5.6739e-04 - categorical_accuracy: 1.0000 - val_loss: 1.9334e-04 - val_categorical_accuracy: 1.0000 - lr: 1.0000e-05\n",
      "Epoch 35/500\n",
      "11/11 [==============================] - 147s 13s/step - loss: 5.2916e-04 - categorical_accuracy: 1.0000 - val_loss: 1.9273e-04 - val_categorical_accuracy: 1.0000 - lr: 1.0000e-05\n",
      "Epoch 36/500\n",
      "11/11 [==============================] - 156s 14s/step - loss: 6.5789e-04 - categorical_accuracy: 1.0000 - val_loss: 1.9208e-04 - val_categorical_accuracy: 1.0000 - lr: 1.0000e-05\n",
      "Epoch 37/500\n",
      "11/11 [==============================] - 146s 13s/step - loss: 5.9525e-04 - categorical_accuracy: 1.0000 - val_loss: 1.9143e-04 - val_categorical_accuracy: 1.0000 - lr: 1.0000e-05\n",
      "Epoch 38/500\n",
      "11/11 [==============================] - 146s 13s/step - loss: 5.2344e-04 - categorical_accuracy: 1.0000 - val_loss: 1.9073e-04 - val_categorical_accuracy: 1.0000 - lr: 1.0000e-05\n",
      "Epoch 39/500\n",
      "11/11 [==============================] - 145s 13s/step - loss: 5.9891e-04 - categorical_accuracy: 1.0000 - val_loss: 1.9004e-04 - val_categorical_accuracy: 1.0000 - lr: 1.0000e-05\n",
      "Epoch 40/500\n",
      "11/11 [==============================] - 152s 14s/step - loss: 4.6774e-04 - categorical_accuracy: 1.0000 - val_loss: 1.8929e-04 - val_categorical_accuracy: 1.0000 - lr: 1.0000e-05\n",
      "Epoch 41/500\n",
      "11/11 [==============================] - 147s 13s/step - loss: 5.1332e-04 - categorical_accuracy: 1.0000 - val_loss: 1.8856e-04 - val_categorical_accuracy: 1.0000 - lr: 1.0000e-05\n",
      "Epoch 42/500\n",
      "11/11 [==============================] - 148s 13s/step - loss: 5.2163e-04 - categorical_accuracy: 1.0000 - val_loss: 1.8792e-04 - val_categorical_accuracy: 1.0000 - lr: 1.0000e-05\n",
      "Epoch 43/500\n",
      "11/11 [==============================] - 146s 13s/step - loss: 5.1293e-04 - categorical_accuracy: 1.0000 - val_loss: 1.8728e-04 - val_categorical_accuracy: 1.0000 - lr: 1.0000e-05\n",
      "Epoch 44/500\n",
      "11/11 [==============================] - 152s 14s/step - loss: 4.9182e-04 - categorical_accuracy: 1.0000 - val_loss: 1.8657e-04 - val_categorical_accuracy: 1.0000 - lr: 1.0000e-05\n",
      "Epoch 45/500\n",
      "11/11 [==============================] - 146s 13s/step - loss: 4.7899e-04 - categorical_accuracy: 1.0000 - val_loss: 1.8585e-04 - val_categorical_accuracy: 1.0000 - lr: 1.0000e-05\n",
      "Epoch 46/500\n",
      "11/11 [==============================] - 144s 13s/step - loss: 5.3644e-04 - categorical_accuracy: 1.0000 - val_loss: 1.8514e-04 - val_categorical_accuracy: 1.0000 - lr: 1.0000e-05\n",
      "Epoch 47/500\n",
      "11/11 [==============================] - 146s 13s/step - loss: 6.9909e-04 - categorical_accuracy: 1.0000 - val_loss: 1.8448e-04 - val_categorical_accuracy: 1.0000 - lr: 1.0000e-05\n",
      "Epoch 48/500\n",
      "11/11 [==============================] - 155s 14s/step - loss: 4.5525e-04 - categorical_accuracy: 1.0000 - val_loss: 1.8381e-04 - val_categorical_accuracy: 1.0000 - lr: 1.0000e-05\n",
      "Epoch 49/500\n",
      "11/11 [==============================] - 146s 13s/step - loss: 5.4966e-04 - categorical_accuracy: 1.0000 - val_loss: 1.8309e-04 - val_categorical_accuracy: 1.0000 - lr: 1.0000e-05\n",
      "Epoch 50/500\n",
      "11/11 [==============================] - 152s 14s/step - loss: 4.9885e-04 - categorical_accuracy: 1.0000 - val_loss: 1.8239e-04 - val_categorical_accuracy: 1.0000 - lr: 1.0000e-05\n",
      "Epoch 51/500\n",
      "11/11 [==============================] - 146s 13s/step - loss: 4.9701e-04 - categorical_accuracy: 1.0000 - val_loss: 1.8164e-04 - val_categorical_accuracy: 1.0000 - lr: 1.0000e-05\n",
      "Epoch 52/500\n",
      "11/11 [==============================] - 153s 14s/step - loss: 5.7490e-04 - categorical_accuracy: 1.0000 - val_loss: 1.8081e-04 - val_categorical_accuracy: 1.0000 - lr: 1.0000e-05\n",
      "Epoch 53/500\n",
      "11/11 [==============================] - 147s 13s/step - loss: 6.1352e-04 - categorical_accuracy: 1.0000 - val_loss: 1.8002e-04 - val_categorical_accuracy: 1.0000 - lr: 1.0000e-05\n",
      "Epoch 54/500\n",
      "11/11 [==============================] - 148s 13s/step - loss: 5.3162e-04 - categorical_accuracy: 1.0000 - val_loss: 1.7926e-04 - val_categorical_accuracy: 1.0000 - lr: 1.0000e-05\n",
      "Epoch 55/500\n",
      "11/11 [==============================] - 145s 13s/step - loss: 4.3685e-04 - categorical_accuracy: 1.0000 - val_loss: 1.7852e-04 - val_categorical_accuracy: 1.0000 - lr: 1.0000e-05\n",
      "Epoch 56/500\n",
      "11/11 [==============================] - 153s 14s/step - loss: 4.6293e-04 - categorical_accuracy: 1.0000 - val_loss: 1.7782e-04 - val_categorical_accuracy: 1.0000 - lr: 1.0000e-05\n",
      "Epoch 57/500\n",
      "11/11 [==============================] - 145s 13s/step - loss: 4.9126e-04 - categorical_accuracy: 1.0000 - val_loss: 1.7709e-04 - val_categorical_accuracy: 1.0000 - lr: 1.0000e-05\n",
      "Epoch 58/500\n",
      "11/11 [==============================] - 146s 13s/step - loss: 5.6628e-04 - categorical_accuracy: 1.0000 - val_loss: 1.7628e-04 - val_categorical_accuracy: 1.0000 - lr: 1.0000e-05\n",
      "Epoch 59/500\n",
      "11/11 [==============================] - 147s 13s/step - loss: 4.5964e-04 - categorical_accuracy: 1.0000 - val_loss: 1.7552e-04 - val_categorical_accuracy: 1.0000 - lr: 1.0000e-05\n",
      "Epoch 60/500\n",
      "11/11 [==============================] - 157s 14s/step - loss: 6.6674e-04 - categorical_accuracy: 1.0000 - val_loss: 1.7479e-04 - val_categorical_accuracy: 1.0000 - lr: 1.0000e-05\n",
      "Epoch 61/500\n",
      "11/11 [==============================] - 145s 13s/step - loss: 5.1090e-04 - categorical_accuracy: 1.0000 - val_loss: 1.7403e-04 - val_categorical_accuracy: 1.0000 - lr: 1.0000e-05\n",
      "Epoch 62/500\n",
      "11/11 [==============================] - 145s 13s/step - loss: 4.9706e-04 - categorical_accuracy: 1.0000 - val_loss: 1.7322e-04 - val_categorical_accuracy: 1.0000 - lr: 1.0000e-05\n",
      "Epoch 63/500\n",
      "11/11 [==============================] - 147s 13s/step - loss: 5.4192e-04 - categorical_accuracy: 1.0000 - val_loss: 1.7244e-04 - val_categorical_accuracy: 1.0000 - lr: 1.0000e-05\n",
      "Epoch 64/500\n",
      "11/11 [==============================] - 152s 14s/step - loss: 5.9925e-04 - categorical_accuracy: 1.0000 - val_loss: 1.7167e-04 - val_categorical_accuracy: 1.0000 - lr: 1.0000e-05\n",
      "Epoch 65/500\n",
      "11/11 [==============================] - 145s 13s/step - loss: 4.5704e-04 - categorical_accuracy: 1.0000 - val_loss: 1.7089e-04 - val_categorical_accuracy: 1.0000 - lr: 1.0000e-05\n",
      "Epoch 66/500\n",
      "11/11 [==============================] - 148s 13s/step - loss: 4.4658e-04 - categorical_accuracy: 1.0000 - val_loss: 1.7017e-04 - val_categorical_accuracy: 1.0000 - lr: 1.0000e-05\n",
      "Epoch 67/500\n",
      "11/11 [==============================] - 145s 13s/step - loss: 4.1316e-04 - categorical_accuracy: 1.0000 - val_loss: 1.6946e-04 - val_categorical_accuracy: 1.0000 - lr: 1.0000e-05\n",
      "Epoch 68/500\n",
      "11/11 [==============================] - 154s 14s/step - loss: 4.0055e-04 - categorical_accuracy: 1.0000 - val_loss: 1.6878e-04 - val_categorical_accuracy: 1.0000 - lr: 1.0000e-05\n",
      "Epoch 69/500\n",
      "11/11 [==============================] - 160s 15s/step - loss: 5.0363e-04 - categorical_accuracy: 1.0000 - val_loss: 1.6799e-04 - val_categorical_accuracy: 1.0000 - lr: 1.0000e-05\n",
      "Epoch 70/500\n",
      "11/11 [==============================] - 164s 15s/step - loss: 5.1680e-04 - categorical_accuracy: 1.0000 - val_loss: 1.6720e-04 - val_categorical_accuracy: 1.0000 - lr: 1.0000e-05\n",
      "Epoch 71/500\n",
      "11/11 [==============================] - 147s 13s/step - loss: 4.6495e-04 - categorical_accuracy: 1.0000 - val_loss: 1.6639e-04 - val_categorical_accuracy: 1.0000 - lr: 1.0000e-05\n",
      "Epoch 72/500\n",
      "11/11 [==============================] - 158s 14s/step - loss: 3.7927e-04 - categorical_accuracy: 1.0000 - val_loss: 1.6566e-04 - val_categorical_accuracy: 1.0000 - lr: 1.0000e-05\n",
      "Epoch 73/500\n",
      "11/11 [==============================] - 152s 14s/step - loss: 4.0601e-04 - categorical_accuracy: 1.0000 - val_loss: 1.6494e-04 - val_categorical_accuracy: 1.0000 - lr: 1.0000e-05\n",
      "Epoch 74/500\n",
      "11/11 [==============================] - 155s 14s/step - loss: 4.0291e-04 - categorical_accuracy: 1.0000 - val_loss: 1.6418e-04 - val_categorical_accuracy: 1.0000 - lr: 1.0000e-05\n",
      "Epoch 75/500\n",
      "11/11 [==============================] - 146s 13s/step - loss: 4.8252e-04 - categorical_accuracy: 1.0000 - val_loss: 1.6345e-04 - val_categorical_accuracy: 1.0000 - lr: 1.0000e-05\n",
      "Epoch 76/500\n",
      "11/11 [==============================] - 153s 14s/step - loss: 4.6234e-04 - categorical_accuracy: 1.0000 - val_loss: 1.6267e-04 - val_categorical_accuracy: 1.0000 - lr: 1.0000e-05\n",
      "Epoch 77/500\n",
      "11/11 [==============================] - 147s 13s/step - loss: 4.7696e-04 - categorical_accuracy: 1.0000 - val_loss: 1.6190e-04 - val_categorical_accuracy: 1.0000 - lr: 1.0000e-05\n",
      "Epoch 78/500\n",
      "11/11 [==============================] - 147s 13s/step - loss: 4.7757e-04 - categorical_accuracy: 1.0000 - val_loss: 1.6107e-04 - val_categorical_accuracy: 1.0000 - lr: 1.0000e-05\n",
      "Epoch 79/500\n",
      "11/11 [==============================] - 147s 13s/step - loss: 4.7498e-04 - categorical_accuracy: 1.0000 - val_loss: 1.6023e-04 - val_categorical_accuracy: 1.0000 - lr: 1.0000e-05\n",
      "Epoch 80/500\n",
      "11/11 [==============================] - 151s 14s/step - loss: 4.5403e-04 - categorical_accuracy: 1.0000 - val_loss: 1.5941e-04 - val_categorical_accuracy: 1.0000 - lr: 1.0000e-05\n",
      "Epoch 81/500\n",
      "11/11 [==============================] - 145s 13s/step - loss: 4.4490e-04 - categorical_accuracy: 1.0000 - val_loss: 1.5862e-04 - val_categorical_accuracy: 1.0000 - lr: 1.0000e-05\n",
      "Epoch 82/500\n",
      "11/11 [==============================] - 146s 13s/step - loss: 4.3853e-04 - categorical_accuracy: 1.0000 - val_loss: 1.5785e-04 - val_categorical_accuracy: 1.0000 - lr: 1.0000e-05\n",
      "Epoch 83/500\n",
      "11/11 [==============================] - 146s 13s/step - loss: 4.5649e-04 - categorical_accuracy: 1.0000 - val_loss: 1.5709e-04 - val_categorical_accuracy: 1.0000 - lr: 1.0000e-05\n",
      "Epoch 84/500\n",
      "11/11 [==============================] - 157s 14s/step - loss: 4.5282e-04 - categorical_accuracy: 1.0000 - val_loss: 1.5631e-04 - val_categorical_accuracy: 1.0000 - lr: 1.0000e-05\n",
      "Epoch 85/500\n",
      "11/11 [==============================] - 145s 13s/step - loss: 4.4384e-04 - categorical_accuracy: 1.0000 - val_loss: 1.5551e-04 - val_categorical_accuracy: 1.0000 - lr: 1.0000e-05\n",
      "Epoch 86/500\n",
      "11/11 [==============================] - 146s 13s/step - loss: 4.5909e-04 - categorical_accuracy: 1.0000 - val_loss: 1.5476e-04 - val_categorical_accuracy: 1.0000 - lr: 1.0000e-05\n",
      "Epoch 87/500\n",
      "11/11 [==============================] - 146s 13s/step - loss: 4.7136e-04 - categorical_accuracy: 1.0000 - val_loss: 1.5396e-04 - val_categorical_accuracy: 1.0000 - lr: 1.0000e-05\n",
      "Epoch 88/500\n",
      "11/11 [==============================] - 153s 14s/step - loss: 5.0904e-04 - categorical_accuracy: 1.0000 - val_loss: 1.5321e-04 - val_categorical_accuracy: 1.0000 - lr: 1.0000e-05\n",
      "Epoch 89/500\n",
      "11/11 [==============================] - 146s 13s/step - loss: 4.3501e-04 - categorical_accuracy: 1.0000 - val_loss: 1.5245e-04 - val_categorical_accuracy: 1.0000 - lr: 1.0000e-05\n",
      "Epoch 90/500\n",
      "11/11 [==============================] - 150s 14s/step - loss: 3.7642e-04 - categorical_accuracy: 1.0000 - val_loss: 1.5165e-04 - val_categorical_accuracy: 1.0000 - lr: 1.0000e-05\n",
      "Epoch 91/500\n",
      "11/11 [==============================] - 146s 13s/step - loss: 4.0477e-04 - categorical_accuracy: 1.0000 - val_loss: 1.5089e-04 - val_categorical_accuracy: 1.0000 - lr: 1.0000e-05\n",
      "Epoch 92/500\n",
      "11/11 [==============================] - 153s 14s/step - loss: 4.2010e-04 - categorical_accuracy: 1.0000 - val_loss: 1.5009e-04 - val_categorical_accuracy: 1.0000 - lr: 1.0000e-05\n",
      "Epoch 93/500\n",
      "11/11 [==============================] - 147s 13s/step - loss: 4.4686e-04 - categorical_accuracy: 1.0000 - val_loss: 1.4930e-04 - val_categorical_accuracy: 1.0000 - lr: 1.0000e-05\n",
      "Epoch 94/500\n",
      "11/11 [==============================] - 147s 13s/step - loss: 4.8152e-04 - categorical_accuracy: 1.0000 - val_loss: 1.4845e-04 - val_categorical_accuracy: 1.0000 - lr: 1.0000e-05\n",
      "Epoch 95/500\n",
      "11/11 [==============================] - 148s 13s/step - loss: 4.2370e-04 - categorical_accuracy: 1.0000 - val_loss: 1.4769e-04 - val_categorical_accuracy: 1.0000 - lr: 1.0000e-05\n",
      "Epoch 96/500\n",
      "11/11 [==============================] - 157s 14s/step - loss: 3.8550e-04 - categorical_accuracy: 1.0000 - val_loss: 1.4696e-04 - val_categorical_accuracy: 1.0000 - lr: 1.0000e-05\n",
      "Epoch 97/500\n",
      "11/11 [==============================] - 145s 13s/step - loss: 3.4308e-04 - categorical_accuracy: 1.0000 - val_loss: 1.4615e-04 - val_categorical_accuracy: 1.0000 - lr: 1.0000e-05\n",
      "Epoch 98/500\n",
      "11/11 [==============================] - 146s 13s/step - loss: 4.1931e-04 - categorical_accuracy: 1.0000 - val_loss: 1.4544e-04 - val_categorical_accuracy: 1.0000 - lr: 1.0000e-05\n",
      "Epoch 99/500\n",
      "11/11 [==============================] - 152s 14s/step - loss: 3.9290e-04 - categorical_accuracy: 1.0000 - val_loss: 1.4467e-04 - val_categorical_accuracy: 1.0000 - lr: 1.0000e-05\n",
      "Epoch 100/500\n",
      "11/11 [==============================] - 152s 14s/step - loss: 3.4640e-04 - categorical_accuracy: 1.0000 - val_loss: 1.4395e-04 - val_categorical_accuracy: 1.0000 - lr: 1.0000e-05\n",
      "Epoch 101/500\n",
      "11/11 [==============================] - 146s 13s/step - loss: 3.9130e-04 - categorical_accuracy: 1.0000 - val_loss: 1.4314e-04 - val_categorical_accuracy: 1.0000 - lr: 1.0000e-05\n",
      "Epoch 102/500\n",
      "11/11 [==============================] - 152s 14s/step - loss: 3.4727e-04 - categorical_accuracy: 1.0000 - val_loss: 1.4239e-04 - val_categorical_accuracy: 1.0000 - lr: 1.0000e-05\n",
      "Epoch 103/500\n",
      "11/11 [==============================] - 146s 13s/step - loss: 4.5238e-04 - categorical_accuracy: 1.0000 - val_loss: 1.4166e-04 - val_categorical_accuracy: 1.0000 - lr: 1.0000e-05\n",
      "Epoch 104/500\n",
      " 1/11 [=>............................] - ETA: 2:14 - loss: 3.6348e-04 - categorical_accuracy: 1.0000"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "\u001b[1;32mc:\\Users\\cpras\\Documents\\GitHub\\AI_Personal_Trainer\\ExerciseDecoder.ipynb Cell 35\u001b[0m in \u001b[0;36m<cell line: 2>\u001b[1;34m()\u001b[0m\n\u001b[0;32m      <a href='vscode-notebook-cell:/c%3A/Users/cpras/Documents/GitHub/AI_Personal_Trainer/ExerciseDecoder.ipynb#ch0000034?line=0'>1</a>\u001b[0m AttnLSTM\u001b[39m.\u001b[39mcompile(optimizer\u001b[39m=\u001b[39m\u001b[39m'\u001b[39m\u001b[39mAdam\u001b[39m\u001b[39m'\u001b[39m, loss\u001b[39m=\u001b[39m\u001b[39m'\u001b[39m\u001b[39mcategorical_crossentropy\u001b[39m\u001b[39m'\u001b[39m, metrics\u001b[39m=\u001b[39m[\u001b[39m'\u001b[39m\u001b[39mcategorical_accuracy\u001b[39m\u001b[39m'\u001b[39m])\n\u001b[1;32m----> <a href='vscode-notebook-cell:/c%3A/Users/cpras/Documents/GitHub/AI_Personal_Trainer/ExerciseDecoder.ipynb#ch0000034?line=1'>2</a>\u001b[0m AttnLSTM\u001b[39m.\u001b[39;49mfit(X_train, y_train, batch_size\u001b[39m=\u001b[39;49mbatch_size, epochs\u001b[39m=\u001b[39;49mmax_epochs, validation_data\u001b[39m=\u001b[39;49m(X_val, y_val), callbacks\u001b[39m=\u001b[39;49mcallbacks)\n",
      "File \u001b[1;32m~\\AppData\\Roaming\\Python\\Python38\\site-packages\\keras\\utils\\traceback_utils.py:64\u001b[0m, in \u001b[0;36mfilter_traceback.<locals>.error_handler\u001b[1;34m(*args, **kwargs)\u001b[0m\n\u001b[0;32m     62\u001b[0m filtered_tb \u001b[39m=\u001b[39m \u001b[39mNone\u001b[39;00m\n\u001b[0;32m     63\u001b[0m \u001b[39mtry\u001b[39;00m:\n\u001b[1;32m---> 64\u001b[0m   \u001b[39mreturn\u001b[39;00m fn(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[0;32m     65\u001b[0m \u001b[39mexcept\u001b[39;00m \u001b[39mException\u001b[39;00m \u001b[39mas\u001b[39;00m e:  \u001b[39m# pylint: disable=broad-except\u001b[39;00m\n\u001b[0;32m     66\u001b[0m   filtered_tb \u001b[39m=\u001b[39m _process_traceback_frames(e\u001b[39m.\u001b[39m__traceback__)\n",
      "File \u001b[1;32m~\\AppData\\Roaming\\Python\\Python38\\site-packages\\keras\\engine\\training.py:1414\u001b[0m, in \u001b[0;36mModel.fit\u001b[1;34m(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_batch_size, validation_freq, max_queue_size, workers, use_multiprocessing)\u001b[0m\n\u001b[0;32m   1412\u001b[0m logs \u001b[39m=\u001b[39m tmp_logs  \u001b[39m# No error, now safe to assign to logs.\u001b[39;00m\n\u001b[0;32m   1413\u001b[0m end_step \u001b[39m=\u001b[39m step \u001b[39m+\u001b[39m data_handler\u001b[39m.\u001b[39mstep_increment\n\u001b[1;32m-> 1414\u001b[0m callbacks\u001b[39m.\u001b[39;49mon_train_batch_end(end_step, logs)\n\u001b[0;32m   1415\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mstop_training:\n\u001b[0;32m   1416\u001b[0m   \u001b[39mbreak\u001b[39;00m\n",
      "File \u001b[1;32m~\\AppData\\Roaming\\Python\\Python38\\site-packages\\keras\\callbacks.py:438\u001b[0m, in \u001b[0;36mCallbackList.on_train_batch_end\u001b[1;34m(self, batch, logs)\u001b[0m\n\u001b[0;32m    431\u001b[0m \u001b[39m\"\"\"Calls the `on_train_batch_end` methods of its callbacks.\u001b[39;00m\n\u001b[0;32m    432\u001b[0m \n\u001b[0;32m    433\u001b[0m \u001b[39mArgs:\u001b[39;00m\n\u001b[0;32m    434\u001b[0m \u001b[39m    batch: Integer, index of batch within the current epoch.\u001b[39;00m\n\u001b[0;32m    435\u001b[0m \u001b[39m    logs: Dict. Aggregated metric results up until this batch.\u001b[39;00m\n\u001b[0;32m    436\u001b[0m \u001b[39m\"\"\"\u001b[39;00m\n\u001b[0;32m    437\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_should_call_train_batch_hooks:\n\u001b[1;32m--> 438\u001b[0m   \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_call_batch_hook(ModeKeys\u001b[39m.\u001b[39;49mTRAIN, \u001b[39m'\u001b[39;49m\u001b[39mend\u001b[39;49m\u001b[39m'\u001b[39;49m, batch, logs\u001b[39m=\u001b[39;49mlogs)\n",
      "File \u001b[1;32m~\\AppData\\Roaming\\Python\\Python38\\site-packages\\keras\\callbacks.py:297\u001b[0m, in \u001b[0;36mCallbackList._call_batch_hook\u001b[1;34m(self, mode, hook, batch, logs)\u001b[0m\n\u001b[0;32m    295\u001b[0m   \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_call_batch_begin_hook(mode, batch, logs)\n\u001b[0;32m    296\u001b[0m \u001b[39melif\u001b[39;00m hook \u001b[39m==\u001b[39m \u001b[39m'\u001b[39m\u001b[39mend\u001b[39m\u001b[39m'\u001b[39m:\n\u001b[1;32m--> 297\u001b[0m   \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_call_batch_end_hook(mode, batch, logs)\n\u001b[0;32m    298\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[0;32m    299\u001b[0m   \u001b[39mraise\u001b[39;00m \u001b[39mValueError\u001b[39;00m(\n\u001b[0;32m    300\u001b[0m       \u001b[39mf\u001b[39m\u001b[39m'\u001b[39m\u001b[39mUnrecognized hook: \u001b[39m\u001b[39m{\u001b[39;00mhook\u001b[39m}\u001b[39;00m\u001b[39m. Expected values are [\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mbegin\u001b[39m\u001b[39m\"\u001b[39m\u001b[39m, \u001b[39m\u001b[39m\"\u001b[39m\u001b[39mend\u001b[39m\u001b[39m\"\u001b[39m\u001b[39m]\u001b[39m\u001b[39m'\u001b[39m)\n",
      "File \u001b[1;32m~\\AppData\\Roaming\\Python\\Python38\\site-packages\\keras\\callbacks.py:318\u001b[0m, in \u001b[0;36mCallbackList._call_batch_end_hook\u001b[1;34m(self, mode, batch, logs)\u001b[0m\n\u001b[0;32m    315\u001b[0m   batch_time \u001b[39m=\u001b[39m time\u001b[39m.\u001b[39mtime() \u001b[39m-\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_batch_start_time\n\u001b[0;32m    316\u001b[0m   \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_batch_times\u001b[39m.\u001b[39mappend(batch_time)\n\u001b[1;32m--> 318\u001b[0m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_call_batch_hook_helper(hook_name, batch, logs)\n\u001b[0;32m    320\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mlen\u001b[39m(\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_batch_times) \u001b[39m>\u001b[39m\u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_num_batches_for_timing_check:\n\u001b[0;32m    321\u001b[0m   end_hook_name \u001b[39m=\u001b[39m hook_name\n",
      "File \u001b[1;32m~\\AppData\\Roaming\\Python\\Python38\\site-packages\\keras\\callbacks.py:356\u001b[0m, in \u001b[0;36mCallbackList._call_batch_hook_helper\u001b[1;34m(self, hook_name, batch, logs)\u001b[0m\n\u001b[0;32m    354\u001b[0m \u001b[39mfor\u001b[39;00m callback \u001b[39min\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mcallbacks:\n\u001b[0;32m    355\u001b[0m   hook \u001b[39m=\u001b[39m \u001b[39mgetattr\u001b[39m(callback, hook_name)\n\u001b[1;32m--> 356\u001b[0m   hook(batch, logs)\n\u001b[0;32m    358\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_check_timing:\n\u001b[0;32m    359\u001b[0m   \u001b[39mif\u001b[39;00m hook_name \u001b[39mnot\u001b[39;00m \u001b[39min\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_hook_times:\n",
      "File \u001b[1;32m~\\AppData\\Roaming\\Python\\Python38\\site-packages\\keras\\callbacks.py:1380\u001b[0m, in \u001b[0;36mModelCheckpoint.on_train_batch_end\u001b[1;34m(self, batch, logs)\u001b[0m\n\u001b[0;32m   1378\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mon_train_batch_end\u001b[39m(\u001b[39mself\u001b[39m, batch, logs\u001b[39m=\u001b[39m\u001b[39mNone\u001b[39;00m):\n\u001b[0;32m   1379\u001b[0m   \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_should_save_on_batch(batch):\n\u001b[1;32m-> 1380\u001b[0m     \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_save_model(epoch\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_current_epoch, batch\u001b[39m=\u001b[39;49mbatch, logs\u001b[39m=\u001b[39;49mlogs)\n",
      "File \u001b[1;32m~\\AppData\\Roaming\\Python\\Python38\\site-packages\\keras\\callbacks.py:1458\u001b[0m, in \u001b[0;36mModelCheckpoint._save_model\u001b[1;34m(self, epoch, batch, logs)\u001b[0m\n\u001b[0;32m   1455\u001b[0m       \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mmodel\u001b[39m.\u001b[39msave_weights(\n\u001b[0;32m   1456\u001b[0m           filepath, overwrite\u001b[39m=\u001b[39m\u001b[39mTrue\u001b[39;00m, options\u001b[39m=\u001b[39m\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_options)\n\u001b[0;32m   1457\u001b[0m     \u001b[39melse\u001b[39;00m:\n\u001b[1;32m-> 1458\u001b[0m       \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mmodel\u001b[39m.\u001b[39;49msave(filepath, overwrite\u001b[39m=\u001b[39;49m\u001b[39mTrue\u001b[39;49;00m, options\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_options)\n\u001b[0;32m   1460\u001b[0m   \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_maybe_remove_file()\n\u001b[0;32m   1461\u001b[0m \u001b[39mexcept\u001b[39;00m \u001b[39mIsADirectoryError\u001b[39;00m \u001b[39mas\u001b[39;00m e:  \u001b[39m# h5py 3.x\u001b[39;00m\n",
      "File \u001b[1;32m~\\AppData\\Roaming\\Python\\Python38\\site-packages\\keras\\utils\\traceback_utils.py:64\u001b[0m, in \u001b[0;36mfilter_traceback.<locals>.error_handler\u001b[1;34m(*args, **kwargs)\u001b[0m\n\u001b[0;32m     62\u001b[0m filtered_tb \u001b[39m=\u001b[39m \u001b[39mNone\u001b[39;00m\n\u001b[0;32m     63\u001b[0m \u001b[39mtry\u001b[39;00m:\n\u001b[1;32m---> 64\u001b[0m   \u001b[39mreturn\u001b[39;00m fn(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[0;32m     65\u001b[0m \u001b[39mexcept\u001b[39;00m \u001b[39mException\u001b[39;00m \u001b[39mas\u001b[39;00m e:  \u001b[39m# pylint: disable=broad-except\u001b[39;00m\n\u001b[0;32m     66\u001b[0m   filtered_tb \u001b[39m=\u001b[39m _process_traceback_frames(e\u001b[39m.\u001b[39m__traceback__)\n",
      "File \u001b[1;32m~\\AppData\\Roaming\\Python\\Python38\\site-packages\\keras\\engine\\training.py:2435\u001b[0m, in \u001b[0;36mModel.save\u001b[1;34m(self, filepath, overwrite, include_optimizer, save_format, signatures, options, save_traces)\u001b[0m\n\u001b[0;32m   2393\u001b[0m \u001b[39m\"\"\"Saves the model to Tensorflow SavedModel or a single HDF5 file.\u001b[39;00m\n\u001b[0;32m   2394\u001b[0m \n\u001b[0;32m   2395\u001b[0m \u001b[39mPlease see `tf.keras.models.save_model` or the\u001b[39;00m\n\u001b[1;32m   (...)\u001b[0m\n\u001b[0;32m   2432\u001b[0m \u001b[39m```\u001b[39;00m\n\u001b[0;32m   2433\u001b[0m \u001b[39m\"\"\"\u001b[39;00m\n\u001b[0;32m   2434\u001b[0m \u001b[39m# pylint: enable=line-too-long\u001b[39;00m\n\u001b[1;32m-> 2435\u001b[0m save\u001b[39m.\u001b[39;49msave_model(\u001b[39mself\u001b[39;49m, filepath, overwrite, include_optimizer, save_format,\n\u001b[0;32m   2436\u001b[0m                 signatures, options, save_traces)\n",
      "File \u001b[1;32m~\\AppData\\Roaming\\Python\\Python38\\site-packages\\keras\\utils\\traceback_utils.py:64\u001b[0m, in \u001b[0;36mfilter_traceback.<locals>.error_handler\u001b[1;34m(*args, **kwargs)\u001b[0m\n\u001b[0;32m     62\u001b[0m filtered_tb \u001b[39m=\u001b[39m \u001b[39mNone\u001b[39;00m\n\u001b[0;32m     63\u001b[0m \u001b[39mtry\u001b[39;00m:\n\u001b[1;32m---> 64\u001b[0m   \u001b[39mreturn\u001b[39;00m fn(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[0;32m     65\u001b[0m \u001b[39mexcept\u001b[39;00m \u001b[39mException\u001b[39;00m \u001b[39mas\u001b[39;00m e:  \u001b[39m# pylint: disable=broad-except\u001b[39;00m\n\u001b[0;32m     66\u001b[0m   filtered_tb \u001b[39m=\u001b[39m _process_traceback_frames(e\u001b[39m.\u001b[39m__traceback__)\n",
      "File \u001b[1;32m~\\AppData\\Roaming\\Python\\Python38\\site-packages\\keras\\saving\\save.py:153\u001b[0m, in \u001b[0;36msave_model\u001b[1;34m(model, filepath, overwrite, include_optimizer, save_format, signatures, options, save_traces)\u001b[0m\n\u001b[0;32m    151\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[0;32m    152\u001b[0m   \u001b[39mwith\u001b[39;00m generic_utils\u001b[39m.\u001b[39mSharedObjectSavingScope():\n\u001b[1;32m--> 153\u001b[0m     saved_model_save\u001b[39m.\u001b[39;49msave(model, filepath, overwrite, include_optimizer,\n\u001b[0;32m    154\u001b[0m                           signatures, options, save_traces)\n",
      "File \u001b[1;32m~\\AppData\\Roaming\\Python\\Python38\\site-packages\\keras\\saving\\saved_model\\save.py:93\u001b[0m, in \u001b[0;36msave\u001b[1;34m(model, filepath, overwrite, include_optimizer, signatures, options, save_traces)\u001b[0m\n\u001b[0;32m     91\u001b[0m \u001b[39mwith\u001b[39;00m backend\u001b[39m.\u001b[39mdeprecated_internal_learning_phase_scope(\u001b[39m0\u001b[39m):\n\u001b[0;32m     92\u001b[0m   \u001b[39mwith\u001b[39;00m utils\u001b[39m.\u001b[39mkeras_option_scope(save_traces):\n\u001b[1;32m---> 93\u001b[0m     saved_nodes, node_paths \u001b[39m=\u001b[39m save_lib\u001b[39m.\u001b[39;49msave_and_return_nodes(\n\u001b[0;32m     94\u001b[0m         model, filepath, signatures, options)\n\u001b[0;32m     96\u001b[0m   \u001b[39m# Save all metadata to a separate file in the SavedModel directory.\u001b[39;00m\n\u001b[0;32m     97\u001b[0m   metadata \u001b[39m=\u001b[39m generate_keras_metadata(saved_nodes, node_paths)\n",
      "File \u001b[1;32m~\\AppData\\Roaming\\Python\\Python38\\site-packages\\tensorflow\\python\\saved_model\\save.py:1325\u001b[0m, in \u001b[0;36msave_and_return_nodes\u001b[1;34m(obj, export_dir, signatures, options, experimental_skip_checkpoint)\u001b[0m\n\u001b[0;32m   1321\u001b[0m saved_model \u001b[39m=\u001b[39m saved_model_pb2\u001b[39m.\u001b[39mSavedModel()\n\u001b[0;32m   1322\u001b[0m meta_graph_def \u001b[39m=\u001b[39m saved_model\u001b[39m.\u001b[39mmeta_graphs\u001b[39m.\u001b[39madd()\n\u001b[0;32m   1324\u001b[0m _, exported_graph, object_saver, asset_info, saved_nodes, node_paths \u001b[39m=\u001b[39m (\n\u001b[1;32m-> 1325\u001b[0m     _build_meta_graph(obj, signatures, options, meta_graph_def))\n\u001b[0;32m   1326\u001b[0m saved_model\u001b[39m.\u001b[39msaved_model_schema_version \u001b[39m=\u001b[39m (\n\u001b[0;32m   1327\u001b[0m     constants\u001b[39m.\u001b[39mSAVED_MODEL_SCHEMA_VERSION)\n\u001b[0;32m   1329\u001b[0m \u001b[39m# Write the checkpoint, copy assets into the assets directory, and write out\u001b[39;00m\n\u001b[0;32m   1330\u001b[0m \u001b[39m# the SavedModel proto itself.\u001b[39;00m\n",
      "File \u001b[1;32m~\\AppData\\Roaming\\Python\\Python38\\site-packages\\tensorflow\\python\\saved_model\\save.py:1491\u001b[0m, in \u001b[0;36m_build_meta_graph\u001b[1;34m(obj, signatures, options, meta_graph_def)\u001b[0m\n\u001b[0;32m   1466\u001b[0m \u001b[39m\"\"\"Creates a MetaGraph under a save context.\u001b[39;00m\n\u001b[0;32m   1467\u001b[0m \n\u001b[0;32m   1468\u001b[0m \u001b[39mArgs:\u001b[39;00m\n\u001b[1;32m   (...)\u001b[0m\n\u001b[0;32m   1487\u001b[0m \u001b[39m  asset_info: `_AssetInfo` tuple containing external assets in the `obj`.\u001b[39;00m\n\u001b[0;32m   1488\u001b[0m \u001b[39m\"\"\"\u001b[39;00m\n\u001b[0;32m   1490\u001b[0m \u001b[39mwith\u001b[39;00m save_context\u001b[39m.\u001b[39msave_context(options):\n\u001b[1;32m-> 1491\u001b[0m   \u001b[39mreturn\u001b[39;00m _build_meta_graph_impl(obj, signatures, options, meta_graph_def)\n",
      "File \u001b[1;32m~\\AppData\\Roaming\\Python\\Python38\\site-packages\\tensorflow\\python\\saved_model\\save.py:1443\u001b[0m, in \u001b[0;36m_build_meta_graph_impl\u001b[1;34m(obj, signatures, options, meta_graph_def)\u001b[0m\n\u001b[0;32m   1440\u001b[0m augmented_graph_view\u001b[39m.\u001b[39mset_signature(signature_map, wrapped_functions)\n\u001b[0;32m   1442\u001b[0m \u001b[39m# Use _SaveableView to provide a frozen listing of properties and functions.\u001b[39;00m\n\u001b[1;32m-> 1443\u001b[0m saveable_view \u001b[39m=\u001b[39m _SaveableView(augmented_graph_view, options)\n\u001b[0;32m   1444\u001b[0m object_saver \u001b[39m=\u001b[39m util\u001b[39m.\u001b[39mTrackableSaver(augmented_graph_view)\n\u001b[0;32m   1445\u001b[0m asset_info, exported_graph \u001b[39m=\u001b[39m _fill_meta_graph_def(\n\u001b[0;32m   1446\u001b[0m     meta_graph_def, saveable_view, signatures,\n\u001b[0;32m   1447\u001b[0m     options\u001b[39m.\u001b[39mnamespace_whitelist, options\u001b[39m.\u001b[39mexperimental_custom_gradients)\n",
      "File \u001b[1;32m~\\AppData\\Roaming\\Python\\Python38\\site-packages\\tensorflow\\python\\saved_model\\save.py:229\u001b[0m, in \u001b[0;36m_SaveableView.__init__\u001b[1;34m(self, augmented_graph_view, options)\u001b[0m\n\u001b[0;32m    224\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39maugmented_graph_view \u001b[39m=\u001b[39m augmented_graph_view\n\u001b[0;32m    225\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_options \u001b[39m=\u001b[39m options\n\u001b[0;32m    227\u001b[0m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_trackable_objects, \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mnode_paths, \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mnode_ids,\n\u001b[0;32m    228\u001b[0m  \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_slot_variables, \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mobject_names) \u001b[39m=\u001b[39m (\n\u001b[1;32m--> 229\u001b[0m      \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49maugmented_graph_view\u001b[39m.\u001b[39;49mobjects_ids_and_slot_variables_and_paths())\n\u001b[0;32m    231\u001b[0m untraced_functions \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39maugmented_graph_view\u001b[39m.\u001b[39muntraced_functions\n\u001b[0;32m    232\u001b[0m \u001b[39mif\u001b[39;00m untraced_functions:\n",
      "File \u001b[1;32m~\\AppData\\Roaming\\Python\\Python38\\site-packages\\tensorflow\\python\\training\\tracking\\graph_view.py:544\u001b[0m, in \u001b[0;36mObjectGraphView.objects_ids_and_slot_variables_and_paths\u001b[1;34m(self)\u001b[0m\n\u001b[0;32m    532\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mobjects_ids_and_slot_variables_and_paths\u001b[39m(\u001b[39mself\u001b[39m):\n\u001b[0;32m    533\u001b[0m   \u001b[39m\"\"\"Traverse the object graph and list all accessible objects.\u001b[39;00m\n\u001b[0;32m    534\u001b[0m \n\u001b[0;32m    535\u001b[0m \u001b[39m  Looks for `Trackable` objects which are dependencies of\u001b[39;00m\n\u001b[1;32m   (...)\u001b[0m\n\u001b[0;32m    542\u001b[0m \u001b[39m                object -> node id, slot variables, object_names)\u001b[39;00m\n\u001b[0;32m    543\u001b[0m \u001b[39m  \"\"\"\u001b[39;00m\n\u001b[1;32m--> 544\u001b[0m   trackable_objects, node_paths \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_breadth_first_traversal()\n\u001b[0;32m    545\u001b[0m   object_names \u001b[39m=\u001b[39m object_identity\u001b[39m.\u001b[39mObjectIdentityDictionary()\n\u001b[0;32m    546\u001b[0m   \u001b[39mfor\u001b[39;00m obj, path \u001b[39min\u001b[39;00m node_paths\u001b[39m.\u001b[39mitems():\n",
      "File \u001b[1;32m~\\AppData\\Roaming\\Python\\Python38\\site-packages\\tensorflow\\python\\training\\tracking\\graph_view.py:255\u001b[0m, in \u001b[0;36mObjectGraphView._breadth_first_traversal\u001b[1;34m(self)\u001b[0m\n\u001b[0;32m    253\u001b[0m current_trackable \u001b[39m=\u001b[39m to_visit\u001b[39m.\u001b[39mpopleft()\n\u001b[0;32m    254\u001b[0m bfs_sorted\u001b[39m.\u001b[39mappend(current_trackable)\n\u001b[1;32m--> 255\u001b[0m \u001b[39mfor\u001b[39;00m name, dependency \u001b[39min\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mlist_children(current_trackable):\n\u001b[0;32m    256\u001b[0m   \u001b[39mif\u001b[39;00m dependency \u001b[39mnot\u001b[39;00m \u001b[39min\u001b[39;00m node_paths:\n\u001b[0;32m    257\u001b[0m     node_paths[dependency] \u001b[39m=\u001b[39m (\n\u001b[0;32m    258\u001b[0m         node_paths[current_trackable] \u001b[39m+\u001b[39m (\n\u001b[0;32m    259\u001b[0m             base\u001b[39m.\u001b[39mTrackableReference(name, dependency),))\n",
      "File \u001b[1;32m~\\AppData\\Roaming\\Python\\Python38\\site-packages\\tensorflow\\python\\saved_model\\save.py:143\u001b[0m, in \u001b[0;36m_AugmentedGraphView.list_children\u001b[1;34m(self, obj)\u001b[0m\n\u001b[0;32m    140\u001b[0m \u001b[39mif\u001b[39;00m obj \u001b[39mnot\u001b[39;00m \u001b[39min\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_children_cache:\n\u001b[0;32m    141\u001b[0m   children \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_children_cache[obj] \u001b[39m=\u001b[39m {}\n\u001b[1;32m--> 143\u001b[0m   \u001b[39mfor\u001b[39;00m name, child \u001b[39min\u001b[39;00m \u001b[39msuper\u001b[39;49m(_AugmentedGraphView, \u001b[39mself\u001b[39;49m)\u001b[39m.\u001b[39;49mlist_children(\n\u001b[0;32m    144\u001b[0m       obj,\n\u001b[0;32m    145\u001b[0m       save_type\u001b[39m=\u001b[39;49mbase\u001b[39m.\u001b[39;49mSaveType\u001b[39m.\u001b[39;49mSAVEDMODEL,\n\u001b[0;32m    146\u001b[0m       cache\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_serialization_cache):\n\u001b[0;32m    147\u001b[0m     \u001b[39mif\u001b[39;00m \u001b[39misinstance\u001b[39m(child, defun\u001b[39m.\u001b[39mConcreteFunction):\n\u001b[0;32m    148\u001b[0m       child \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_maybe_uncache_variable_captures(child)\n",
      "File \u001b[1;32m~\\AppData\\Roaming\\Python\\Python38\\site-packages\\tensorflow\\python\\training\\tracking\\graph_view.py:203\u001b[0m, in \u001b[0;36mObjectGraphView.list_children\u001b[1;34m(self, obj, save_type, **kwargs)\u001b[0m\n\u001b[0;32m    200\u001b[0m \u001b[39m# pylint: disable=protected-access\u001b[39;00m\n\u001b[0;32m    201\u001b[0m obj\u001b[39m.\u001b[39m_maybe_initialize_trackable()\n\u001b[0;32m    202\u001b[0m children \u001b[39m=\u001b[39m [base\u001b[39m.\u001b[39mTrackableReference(name, ref) \u001b[39mfor\u001b[39;00m name, ref\n\u001b[1;32m--> 203\u001b[0m             \u001b[39min\u001b[39;00m obj\u001b[39m.\u001b[39;49m_trackable_children(save_type, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\u001b[39m.\u001b[39mitems()]\n\u001b[0;32m    204\u001b[0m \u001b[39m# pylint: enable=protected-access\u001b[39;00m\n\u001b[0;32m    205\u001b[0m \n\u001b[0;32m    206\u001b[0m \u001b[39m# GraphView objects may define children of the root object that are not\u001b[39;00m\n\u001b[0;32m    207\u001b[0m \u001b[39m# actually attached, e.g. a Checkpoint object's save_counter.\u001b[39;00m\n\u001b[0;32m    208\u001b[0m \u001b[39mif\u001b[39;00m obj \u001b[39mis\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mroot \u001b[39mand\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_attached_dependencies:\n",
      "File \u001b[1;32m~\\AppData\\Roaming\\Python\\Python38\\site-packages\\tensorflow\\python\\training\\tracking\\autotrackable.py:115\u001b[0m, in \u001b[0;36mAutoTrackable._trackable_children\u001b[1;34m(self, save_type, **kwargs)\u001b[0m\n\u001b[0;32m    113\u001b[0m \u001b[39mfor\u001b[39;00m fn \u001b[39min\u001b[39;00m functions\u001b[39m.\u001b[39mvalues():\n\u001b[0;32m    114\u001b[0m   \u001b[39mif\u001b[39;00m \u001b[39misinstance\u001b[39m(fn, core_types\u001b[39m.\u001b[39mGenericFunction):\n\u001b[1;32m--> 115\u001b[0m     fn\u001b[39m.\u001b[39;49m_list_all_concrete_functions_for_serialization()  \u001b[39m# pylint: disable=protected-access\u001b[39;00m\n\u001b[0;32m    117\u001b[0m \u001b[39m# Additional dependencies may have been generated during function tracing\u001b[39;00m\n\u001b[0;32m    118\u001b[0m \u001b[39m# (e.g. captured variables). Make sure we return those too.\u001b[39;00m\n\u001b[0;32m    119\u001b[0m children \u001b[39m=\u001b[39m {}\n",
      "File \u001b[1;32m~\\AppData\\Roaming\\Python\\Python38\\site-packages\\tensorflow\\python\\eager\\def_function.py:1184\u001b[0m, in \u001b[0;36mFunction._list_all_concrete_functions_for_serialization\u001b[1;34m(self)\u001b[0m\n\u001b[0;32m   1182\u001b[0m concrete_functions \u001b[39m=\u001b[39m []\n\u001b[0;32m   1183\u001b[0m \u001b[39mfor\u001b[39;00m args, kwargs \u001b[39min\u001b[39;00m seen_signatures:\n\u001b[1;32m-> 1184\u001b[0m   concrete_functions\u001b[39m.\u001b[39mappend(\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mget_concrete_function(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs))\n\u001b[0;32m   1185\u001b[0m \u001b[39mreturn\u001b[39;00m concrete_functions\n",
      "File \u001b[1;32m~\\AppData\\Roaming\\Python\\Python38\\site-packages\\tensorflow\\python\\eager\\def_function.py:1239\u001b[0m, in \u001b[0;36mFunction.get_concrete_function\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m   1237\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mget_concrete_function\u001b[39m(\u001b[39mself\u001b[39m, \u001b[39m*\u001b[39margs, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs):\n\u001b[0;32m   1238\u001b[0m   \u001b[39m# Implements GenericFunction.get_concrete_function.\u001b[39;00m\n\u001b[1;32m-> 1239\u001b[0m   concrete \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_get_concrete_function_garbage_collected(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[0;32m   1240\u001b[0m   concrete\u001b[39m.\u001b[39m_garbage_collector\u001b[39m.\u001b[39mrelease()  \u001b[39m# pylint: disable=protected-access\u001b[39;00m\n\u001b[0;32m   1241\u001b[0m   \u001b[39mreturn\u001b[39;00m concrete\n",
      "File \u001b[1;32m~\\AppData\\Roaming\\Python\\Python38\\site-packages\\tensorflow\\python\\eager\\def_function.py:1230\u001b[0m, in \u001b[0;36mFunction._get_concrete_function_garbage_collected\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m   1225\u001b[0m   \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_stateless_fn\u001b[39m.\u001b[39m_get_concrete_function_garbage_collected(  \u001b[39m# pylint: disable=protected-access\u001b[39;00m\n\u001b[0;32m   1226\u001b[0m       \u001b[39m*\u001b[39margs, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs)\n\u001b[0;32m   1227\u001b[0m \u001b[39melif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_stateful_fn \u001b[39mis\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mNone\u001b[39;00m:\n\u001b[0;32m   1228\u001b[0m   \u001b[39m# In this case we have not created variables on the first call. So we can\u001b[39;00m\n\u001b[0;32m   1229\u001b[0m   \u001b[39m# run the first trace but we should fail if variables are created.\u001b[39;00m\n\u001b[1;32m-> 1230\u001b[0m   concrete \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_stateful_fn\u001b[39m.\u001b[39;49m_get_concrete_function_garbage_collected(  \u001b[39m# pylint: disable=protected-access\u001b[39;49;00m\n\u001b[0;32m   1231\u001b[0m       \u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[0;32m   1232\u001b[0m   \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_created_variables:\n\u001b[0;32m   1233\u001b[0m     \u001b[39mraise\u001b[39;00m \u001b[39mValueError\u001b[39;00m(\u001b[39m\"\u001b[39m\u001b[39mCreating variables on a non-first call to a function\u001b[39m\u001b[39m\"\u001b[39m\n\u001b[0;32m   1234\u001b[0m                      \u001b[39m\"\u001b[39m\u001b[39m decorated with tf.function.\u001b[39m\u001b[39m\"\u001b[39m)\n",
      "File \u001b[1;32m~\\AppData\\Roaming\\Python\\Python38\\site-packages\\tensorflow\\python\\eager\\function.py:2533\u001b[0m, in \u001b[0;36mFunction._get_concrete_function_garbage_collected\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m   2531\u001b[0m   args, kwargs \u001b[39m=\u001b[39m \u001b[39mNone\u001b[39;00m, \u001b[39mNone\u001b[39;00m\n\u001b[0;32m   2532\u001b[0m \u001b[39mwith\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_lock:\n\u001b[1;32m-> 2533\u001b[0m   graph_function, _ \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_maybe_define_function(args, kwargs)\n\u001b[0;32m   2534\u001b[0m   seen_names \u001b[39m=\u001b[39m \u001b[39mset\u001b[39m()\n\u001b[0;32m   2535\u001b[0m   captured \u001b[39m=\u001b[39m object_identity\u001b[39m.\u001b[39mObjectIdentitySet(\n\u001b[0;32m   2536\u001b[0m       graph_function\u001b[39m.\u001b[39mgraph\u001b[39m.\u001b[39minternal_captures)\n",
      "File \u001b[1;32m~\\AppData\\Roaming\\Python\\Python38\\site-packages\\tensorflow\\python\\eager\\function.py:2711\u001b[0m, in \u001b[0;36mFunction._maybe_define_function\u001b[1;34m(self, args, kwargs)\u001b[0m\n\u001b[0;32m   2708\u001b[0m   cache_key \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_function_cache\u001b[39m.\u001b[39mgeneralize(cache_key)\n\u001b[0;32m   2709\u001b[0m   (args, kwargs) \u001b[39m=\u001b[39m cache_key\u001b[39m.\u001b[39m_placeholder_value()  \u001b[39m# pylint: disable=protected-access\u001b[39;00m\n\u001b[1;32m-> 2711\u001b[0m graph_function \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_create_graph_function(args, kwargs)\n\u001b[0;32m   2712\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_function_cache\u001b[39m.\u001b[39madd(cache_key, cache_key_deletion_observer,\n\u001b[0;32m   2713\u001b[0m                          graph_function)\n\u001b[0;32m   2715\u001b[0m \u001b[39mreturn\u001b[39;00m graph_function, filtered_flat_args\n",
      "File \u001b[1;32m~\\AppData\\Roaming\\Python\\Python38\\site-packages\\tensorflow\\python\\eager\\function.py:2627\u001b[0m, in \u001b[0;36mFunction._create_graph_function\u001b[1;34m(self, args, kwargs)\u001b[0m\n\u001b[0;32m   2622\u001b[0m missing_arg_names \u001b[39m=\u001b[39m [\n\u001b[0;32m   2623\u001b[0m     \u001b[39m\"\u001b[39m\u001b[39m%s\u001b[39;00m\u001b[39m_\u001b[39m\u001b[39m%d\u001b[39;00m\u001b[39m\"\u001b[39m \u001b[39m%\u001b[39m (arg, i) \u001b[39mfor\u001b[39;00m i, arg \u001b[39min\u001b[39;00m \u001b[39menumerate\u001b[39m(missing_arg_names)\n\u001b[0;32m   2624\u001b[0m ]\n\u001b[0;32m   2625\u001b[0m arg_names \u001b[39m=\u001b[39m base_arg_names \u001b[39m+\u001b[39m missing_arg_names\n\u001b[0;32m   2626\u001b[0m graph_function \u001b[39m=\u001b[39m ConcreteFunction(\n\u001b[1;32m-> 2627\u001b[0m     func_graph_module\u001b[39m.\u001b[39;49mfunc_graph_from_py_func(\n\u001b[0;32m   2628\u001b[0m         \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_name,\n\u001b[0;32m   2629\u001b[0m         \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_python_function,\n\u001b[0;32m   2630\u001b[0m         args,\n\u001b[0;32m   2631\u001b[0m         kwargs,\n\u001b[0;32m   2632\u001b[0m         \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49minput_signature,\n\u001b[0;32m   2633\u001b[0m         autograph\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_autograph,\n\u001b[0;32m   2634\u001b[0m         autograph_options\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_autograph_options,\n\u001b[0;32m   2635\u001b[0m         arg_names\u001b[39m=\u001b[39;49marg_names,\n\u001b[0;32m   2636\u001b[0m         capture_by_value\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_capture_by_value),\n\u001b[0;32m   2637\u001b[0m     \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_function_attributes,\n\u001b[0;32m   2638\u001b[0m     spec\u001b[39m=\u001b[39m\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mfunction_spec,\n\u001b[0;32m   2639\u001b[0m     \u001b[39m# Tell the ConcreteFunction to clean up its graph once it goes out of\u001b[39;00m\n\u001b[0;32m   2640\u001b[0m     \u001b[39m# scope. This is not the default behavior since it gets used in some\u001b[39;00m\n\u001b[0;32m   2641\u001b[0m     \u001b[39m# places (like Keras) where the FuncGraph lives longer than the\u001b[39;00m\n\u001b[0;32m   2642\u001b[0m     \u001b[39m# ConcreteFunction.\u001b[39;00m\n\u001b[0;32m   2643\u001b[0m     shared_func_graph\u001b[39m=\u001b[39m\u001b[39mFalse\u001b[39;00m)\n\u001b[0;32m   2644\u001b[0m \u001b[39mreturn\u001b[39;00m graph_function\n",
      "File \u001b[1;32m~\\AppData\\Roaming\\Python\\Python38\\site-packages\\tensorflow\\python\\framework\\func_graph.py:1141\u001b[0m, in \u001b[0;36mfunc_graph_from_py_func\u001b[1;34m(name, python_func, args, kwargs, signature, func_graph, autograph, autograph_options, add_control_dependencies, arg_names, op_return_value, collections, capture_by_value, acd_record_initial_resource_uses)\u001b[0m\n\u001b[0;32m   1138\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[0;32m   1139\u001b[0m   _, original_func \u001b[39m=\u001b[39m tf_decorator\u001b[39m.\u001b[39munwrap(python_func)\n\u001b[1;32m-> 1141\u001b[0m func_outputs \u001b[39m=\u001b[39m python_func(\u001b[39m*\u001b[39;49mfunc_args, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mfunc_kwargs)\n\u001b[0;32m   1143\u001b[0m \u001b[39m# invariant: `func_outputs` contains only Tensors, CompositeTensors,\u001b[39;00m\n\u001b[0;32m   1144\u001b[0m \u001b[39m# TensorArrays and `None`s.\u001b[39;00m\n\u001b[0;32m   1145\u001b[0m func_outputs \u001b[39m=\u001b[39m nest\u001b[39m.\u001b[39mmap_structure(\n\u001b[0;32m   1146\u001b[0m     convert, func_outputs, expand_composites\u001b[39m=\u001b[39m\u001b[39mTrue\u001b[39;00m)\n",
      "File \u001b[1;32m~\\AppData\\Roaming\\Python\\Python38\\site-packages\\tensorflow\\python\\eager\\def_function.py:677\u001b[0m, in \u001b[0;36mFunction._defun_with_scope.<locals>.wrapped_fn\u001b[1;34m(*args, **kwds)\u001b[0m\n\u001b[0;32m    673\u001b[0m \u001b[39mwith\u001b[39;00m default_graph\u001b[39m.\u001b[39m_variable_creator_scope(scope, priority\u001b[39m=\u001b[39m\u001b[39m50\u001b[39m):  \u001b[39m# pylint: disable=protected-access\u001b[39;00m\n\u001b[0;32m    674\u001b[0m   \u001b[39m# __wrapped__ allows AutoGraph to swap in a converted function. We give\u001b[39;00m\n\u001b[0;32m    675\u001b[0m   \u001b[39m# the function a weak reference to itself to avoid a reference cycle.\u001b[39;00m\n\u001b[0;32m    676\u001b[0m   \u001b[39mwith\u001b[39;00m OptionalXlaContext(compile_with_xla):\n\u001b[1;32m--> 677\u001b[0m     out \u001b[39m=\u001b[39m weak_wrapped_fn()\u001b[39m.\u001b[39;49m__wrapped__(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwds)\n\u001b[0;32m    678\u001b[0m   \u001b[39mreturn\u001b[39;00m out\n",
      "File \u001b[1;32m~\\AppData\\Roaming\\Python\\Python38\\site-packages\\keras\\saving\\saved_model\\save_impl.py:572\u001b[0m, in \u001b[0;36mlayer_call_wrapper.<locals>.wrapper\u001b[1;34m(*args, **kwargs)\u001b[0m\n\u001b[0;32m    567\u001b[0m \u001b[39mwith\u001b[39;00m base_layer_utils\u001b[39m.\u001b[39mcall_context()\u001b[39m.\u001b[39menter(\n\u001b[0;32m    568\u001b[0m     layer, inputs\u001b[39m=\u001b[39minputs, build_graph\u001b[39m=\u001b[39m\u001b[39mFalse\u001b[39;00m, training\u001b[39m=\u001b[39mtraining,\n\u001b[0;32m    569\u001b[0m     saving\u001b[39m=\u001b[39m\u001b[39mTrue\u001b[39;00m):\n\u001b[0;32m    570\u001b[0m   \u001b[39mwith\u001b[39;00m autocast_variable\u001b[39m.\u001b[39menable_auto_cast_variables(\n\u001b[0;32m    571\u001b[0m       layer\u001b[39m.\u001b[39m_compute_dtype_object):  \u001b[39m# pylint: disable=protected-access\u001b[39;00m\n\u001b[1;32m--> 572\u001b[0m     ret \u001b[39m=\u001b[39m method(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[0;32m    573\u001b[0m _restore_layer_losses(original_losses)\n\u001b[0;32m    574\u001b[0m \u001b[39mreturn\u001b[39;00m ret\n",
      "File \u001b[1;32m~\\AppData\\Roaming\\Python\\Python38\\site-packages\\keras\\saving\\saved_model\\utils.py:168\u001b[0m, in \u001b[0;36mmaybe_add_training_arg.<locals>.wrap_with_training_arg\u001b[1;34m(*args, **kwargs)\u001b[0m\n\u001b[0;32m    165\u001b[0m   set_training_arg(training, training_arg_index, args, kwargs)\n\u001b[0;32m    166\u001b[0m   \u001b[39mreturn\u001b[39;00m wrapped_call(\u001b[39m*\u001b[39margs, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs)\n\u001b[1;32m--> 168\u001b[0m \u001b[39mreturn\u001b[39;00m control_flow_util\u001b[39m.\u001b[39;49msmart_cond(\n\u001b[0;32m    169\u001b[0m     training, \u001b[39mlambda\u001b[39;49;00m: replace_training_and_call(\u001b[39mTrue\u001b[39;49;00m),\n\u001b[0;32m    170\u001b[0m     \u001b[39mlambda\u001b[39;49;00m: replace_training_and_call(\u001b[39mFalse\u001b[39;49;00m))\n",
      "File \u001b[1;32m~\\AppData\\Roaming\\Python\\Python38\\site-packages\\keras\\utils\\control_flow_util.py:105\u001b[0m, in \u001b[0;36msmart_cond\u001b[1;34m(pred, true_fn, false_fn, name)\u001b[0m\n\u001b[0;32m    102\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39misinstance\u001b[39m(pred, tf\u001b[39m.\u001b[39mVariable):\n\u001b[0;32m    103\u001b[0m   \u001b[39mreturn\u001b[39;00m tf\u001b[39m.\u001b[39mcond(\n\u001b[0;32m    104\u001b[0m       pred, true_fn\u001b[39m=\u001b[39mtrue_fn, false_fn\u001b[39m=\u001b[39mfalse_fn, name\u001b[39m=\u001b[39mname)\n\u001b[1;32m--> 105\u001b[0m \u001b[39mreturn\u001b[39;00m tf\u001b[39m.\u001b[39;49m__internal__\u001b[39m.\u001b[39;49msmart_cond\u001b[39m.\u001b[39;49msmart_cond(\n\u001b[0;32m    106\u001b[0m     pred, true_fn\u001b[39m=\u001b[39;49mtrue_fn, false_fn\u001b[39m=\u001b[39;49mfalse_fn, name\u001b[39m=\u001b[39;49mname)\n",
      "File \u001b[1;32m~\\AppData\\Roaming\\Python\\Python38\\site-packages\\tensorflow\\python\\framework\\smart_cond.py:53\u001b[0m, in \u001b[0;36msmart_cond\u001b[1;34m(pred, true_fn, false_fn, name)\u001b[0m\n\u001b[0;32m     51\u001b[0m \u001b[39mif\u001b[39;00m pred_value \u001b[39mis\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mNone\u001b[39;00m:\n\u001b[0;32m     52\u001b[0m   \u001b[39mif\u001b[39;00m pred_value:\n\u001b[1;32m---> 53\u001b[0m     \u001b[39mreturn\u001b[39;00m true_fn()\n\u001b[0;32m     54\u001b[0m   \u001b[39melse\u001b[39;00m:\n\u001b[0;32m     55\u001b[0m     \u001b[39mreturn\u001b[39;00m false_fn()\n",
      "File \u001b[1;32m~\\AppData\\Roaming\\Python\\Python38\\site-packages\\keras\\saving\\saved_model\\utils.py:169\u001b[0m, in \u001b[0;36mmaybe_add_training_arg.<locals>.wrap_with_training_arg.<locals>.<lambda>\u001b[1;34m()\u001b[0m\n\u001b[0;32m    165\u001b[0m   set_training_arg(training, training_arg_index, args, kwargs)\n\u001b[0;32m    166\u001b[0m   \u001b[39mreturn\u001b[39;00m wrapped_call(\u001b[39m*\u001b[39margs, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs)\n\u001b[0;32m    168\u001b[0m \u001b[39mreturn\u001b[39;00m control_flow_util\u001b[39m.\u001b[39msmart_cond(\n\u001b[1;32m--> 169\u001b[0m     training, \u001b[39mlambda\u001b[39;00m: replace_training_and_call(\u001b[39mTrue\u001b[39;49;00m),\n\u001b[0;32m    170\u001b[0m     \u001b[39mlambda\u001b[39;00m: replace_training_and_call(\u001b[39mFalse\u001b[39;00m))\n",
      "File \u001b[1;32m~\\AppData\\Roaming\\Python\\Python38\\site-packages\\keras\\saving\\saved_model\\utils.py:166\u001b[0m, in \u001b[0;36mmaybe_add_training_arg.<locals>.wrap_with_training_arg.<locals>.replace_training_and_call\u001b[1;34m(training)\u001b[0m\n\u001b[0;32m    164\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mreplace_training_and_call\u001b[39m(training):\n\u001b[0;32m    165\u001b[0m   set_training_arg(training, training_arg_index, args, kwargs)\n\u001b[1;32m--> 166\u001b[0m   \u001b[39mreturn\u001b[39;00m wrapped_call(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n",
      "File \u001b[1;32m~\\AppData\\Roaming\\Python\\Python38\\site-packages\\keras\\saving\\saved_model\\save_impl.py:634\u001b[0m, in \u001b[0;36m_wrap_call_and_conditional_losses.<locals>.call_and_return_conditional_losses\u001b[1;34m(*args, **kwargs)\u001b[0m\n\u001b[0;32m    632\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mcall_and_return_conditional_losses\u001b[39m(\u001b[39m*\u001b[39margs, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs):\n\u001b[0;32m    633\u001b[0m   \u001b[39m\"\"\"Returns layer (call_output, conditional losses) tuple.\"\"\"\u001b[39;00m\n\u001b[1;32m--> 634\u001b[0m   call_output \u001b[39m=\u001b[39m layer_call(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[0;32m    635\u001b[0m   \u001b[39mif\u001b[39;00m version_utils\u001b[39m.\u001b[39mis_v1_layer_or_model(layer):\n\u001b[0;32m    636\u001b[0m     conditional_losses \u001b[39m=\u001b[39m layer\u001b[39m.\u001b[39mget_losses_for(\n\u001b[0;32m    637\u001b[0m         _filtered_inputs([args, kwargs]))\n",
      "File \u001b[1;32m~\\AppData\\Roaming\\Python\\Python38\\site-packages\\keras\\layers\\rnn\\bidirectional.py:366\u001b[0m, in \u001b[0;36mBidirectional.call\u001b[1;34m(self, inputs, training, mask, initial_state, constants)\u001b[0m\n\u001b[0;32m    362\u001b[0m     forward_state, backward_state \u001b[39m=\u001b[39m \u001b[39mNone\u001b[39;00m, \u001b[39mNone\u001b[39;00m\n\u001b[0;32m    364\u001b[0m   y \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mforward_layer(forward_inputs,\n\u001b[0;32m    365\u001b[0m                          initial_state\u001b[39m=\u001b[39mforward_state, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs)\n\u001b[1;32m--> 366\u001b[0m   y_rev \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mbackward_layer(backward_inputs,\n\u001b[0;32m    367\u001b[0m                               initial_state\u001b[39m=\u001b[39;49mbackward_state, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[0;32m    368\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[0;32m    369\u001b[0m   y \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mforward_layer(inputs, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs)\n",
      "File \u001b[1;32m~\\AppData\\Roaming\\Python\\Python38\\site-packages\\keras\\layers\\rnn\\base_rnn.py:515\u001b[0m, in \u001b[0;36mRNN.__call__\u001b[1;34m(self, inputs, initial_state, constants, **kwargs)\u001b[0m\n\u001b[0;32m    511\u001b[0m inputs, initial_state, constants \u001b[39m=\u001b[39m rnn_utils\u001b[39m.\u001b[39mstandardize_args(\n\u001b[0;32m    512\u001b[0m     inputs, initial_state, constants, \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_num_constants)\n\u001b[0;32m    514\u001b[0m \u001b[39mif\u001b[39;00m initial_state \u001b[39mis\u001b[39;00m \u001b[39mNone\u001b[39;00m \u001b[39mand\u001b[39;00m constants \u001b[39mis\u001b[39;00m \u001b[39mNone\u001b[39;00m:\n\u001b[1;32m--> 515\u001b[0m   \u001b[39mreturn\u001b[39;00m \u001b[39msuper\u001b[39;49m(RNN, \u001b[39mself\u001b[39;49m)\u001b[39m.\u001b[39;49m\u001b[39m__call__\u001b[39;49m(inputs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[0;32m    517\u001b[0m \u001b[39m# If any of `initial_state` or `constants` are specified and are Keras\u001b[39;00m\n\u001b[0;32m    518\u001b[0m \u001b[39m# tensors, then add them to the inputs and temporarily modify the\u001b[39;00m\n\u001b[0;32m    519\u001b[0m \u001b[39m# input_spec to include them.\u001b[39;00m\n\u001b[0;32m    521\u001b[0m additional_inputs \u001b[39m=\u001b[39m []\n",
      "File \u001b[1;32m~\\AppData\\Roaming\\Python\\Python38\\site-packages\\keras\\utils\\traceback_utils.py:64\u001b[0m, in \u001b[0;36mfilter_traceback.<locals>.error_handler\u001b[1;34m(*args, **kwargs)\u001b[0m\n\u001b[0;32m     62\u001b[0m filtered_tb \u001b[39m=\u001b[39m \u001b[39mNone\u001b[39;00m\n\u001b[0;32m     63\u001b[0m \u001b[39mtry\u001b[39;00m:\n\u001b[1;32m---> 64\u001b[0m   \u001b[39mreturn\u001b[39;00m fn(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[0;32m     65\u001b[0m \u001b[39mexcept\u001b[39;00m \u001b[39mException\u001b[39;00m \u001b[39mas\u001b[39;00m e:  \u001b[39m# pylint: disable=broad-except\u001b[39;00m\n\u001b[0;32m     66\u001b[0m   filtered_tb \u001b[39m=\u001b[39m _process_traceback_frames(e\u001b[39m.\u001b[39m__traceback__)\n",
      "File \u001b[1;32m~\\AppData\\Roaming\\Python\\Python38\\site-packages\\keras\\engine\\base_layer.py:1014\u001b[0m, in \u001b[0;36mLayer.__call__\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m   1010\u001b[0m   inputs \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_maybe_cast_inputs(inputs, input_list)\n\u001b[0;32m   1012\u001b[0m \u001b[39mwith\u001b[39;00m autocast_variable\u001b[39m.\u001b[39menable_auto_cast_variables(\n\u001b[0;32m   1013\u001b[0m     \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_compute_dtype_object):\n\u001b[1;32m-> 1014\u001b[0m   outputs \u001b[39m=\u001b[39m call_fn(inputs, \u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[0;32m   1016\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_activity_regularizer:\n\u001b[0;32m   1017\u001b[0m   \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_handle_activity_regularization(inputs, outputs)\n",
      "File \u001b[1;32m~\\AppData\\Roaming\\Python\\Python38\\site-packages\\keras\\utils\\traceback_utils.py:92\u001b[0m, in \u001b[0;36minject_argument_info_in_traceback.<locals>.error_handler\u001b[1;34m(*args, **kwargs)\u001b[0m\n\u001b[0;32m     90\u001b[0m bound_signature \u001b[39m=\u001b[39m \u001b[39mNone\u001b[39;00m\n\u001b[0;32m     91\u001b[0m \u001b[39mtry\u001b[39;00m:\n\u001b[1;32m---> 92\u001b[0m   \u001b[39mreturn\u001b[39;00m fn(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[0;32m     93\u001b[0m \u001b[39mexcept\u001b[39;00m \u001b[39mException\u001b[39;00m \u001b[39mas\u001b[39;00m e:  \u001b[39m# pylint: disable=broad-except\u001b[39;00m\n\u001b[0;32m     94\u001b[0m   \u001b[39mif\u001b[39;00m \u001b[39mhasattr\u001b[39m(e, \u001b[39m'\u001b[39m\u001b[39m_keras_call_info_injected\u001b[39m\u001b[39m'\u001b[39m):\n\u001b[0;32m     95\u001b[0m     \u001b[39m# Only inject info for the innermost failing call\u001b[39;00m\n",
      "File \u001b[1;32m~\\AppData\\Roaming\\Python\\Python38\\site-packages\\keras\\layers\\rnn\\lstm.py:673\u001b[0m, in \u001b[0;36mLSTM.call\u001b[1;34m(self, inputs, mask, training, initial_state)\u001b[0m\n\u001b[0;32m    669\u001b[0m         last_output, outputs, new_h, new_c, runtime \u001b[39m=\u001b[39m standard_lstm(\n\u001b[0;32m    670\u001b[0m             \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mnormal_lstm_kwargs)\n\u001b[0;32m    671\u001b[0m     \u001b[39melse\u001b[39;00m:\n\u001b[0;32m    672\u001b[0m       (last_output, outputs, new_h, new_c,\n\u001b[1;32m--> 673\u001b[0m        runtime) \u001b[39m=\u001b[39m lstm_with_backend_selection(\u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mnormal_lstm_kwargs)\n\u001b[0;32m    675\u001b[0m   states \u001b[39m=\u001b[39m [new_h, new_c]\n\u001b[0;32m    677\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mstateful:\n",
      "File \u001b[1;32m~\\AppData\\Roaming\\Python\\Python38\\site-packages\\keras\\layers\\rnn\\lstm.py:1183\u001b[0m, in \u001b[0;36mlstm_with_backend_selection\u001b[1;34m(inputs, init_h, init_c, kernel, recurrent_kernel, bias, mask, time_major, go_backwards, sequence_lengths, zero_output_for_mask, return_sequences)\u001b[0m\n\u001b[0;32m   1177\u001b[0m   defun_gpu_lstm \u001b[39m=\u001b[39m gru_lstm_utils\u001b[39m.\u001b[39mgenerate_defun_backend(\n\u001b[0;32m   1178\u001b[0m       api_name, gru_lstm_utils\u001b[39m.\u001b[39mGPU_DEVICE_NAME, gpu_lstm_with_fallback,\n\u001b[0;32m   1179\u001b[0m       supportive_attribute)\n\u001b[0;32m   1181\u001b[0m   \u001b[39m# Call the normal LSTM impl and register the cuDNN impl function. The\u001b[39;00m\n\u001b[0;32m   1182\u001b[0m   \u001b[39m# grappler will kick in during session execution to optimize the graph.\u001b[39;00m\n\u001b[1;32m-> 1183\u001b[0m   last_output, outputs, new_h, new_c, runtime \u001b[39m=\u001b[39m defun_standard_lstm(\u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mparams)\n\u001b[0;32m   1184\u001b[0m   gru_lstm_utils\u001b[39m.\u001b[39mfunction_register(defun_gpu_lstm, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mparams)\n\u001b[0;32m   1186\u001b[0m \u001b[39mreturn\u001b[39;00m last_output, outputs, new_h, new_c, runtime\n",
      "File \u001b[1;32m~\\AppData\\Roaming\\Python\\Python38\\site-packages\\tensorflow\\python\\eager\\function.py:2452\u001b[0m, in \u001b[0;36mFunction.__call__\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m   2449\u001b[0m \u001b[39m\"\"\"Calls a graph function specialized to the inputs.\"\"\"\u001b[39;00m\n\u001b[0;32m   2450\u001b[0m \u001b[39mwith\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_lock:\n\u001b[0;32m   2451\u001b[0m   (graph_function,\n\u001b[1;32m-> 2452\u001b[0m    filtered_flat_args) \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_maybe_define_function(args, kwargs)\n\u001b[0;32m   2453\u001b[0m \u001b[39mreturn\u001b[39;00m graph_function\u001b[39m.\u001b[39m_call_flat(\n\u001b[0;32m   2454\u001b[0m     filtered_flat_args, captured_inputs\u001b[39m=\u001b[39mgraph_function\u001b[39m.\u001b[39mcaptured_inputs)\n",
      "File \u001b[1;32m~\\AppData\\Roaming\\Python\\Python38\\site-packages\\tensorflow\\python\\eager\\function.py:2711\u001b[0m, in \u001b[0;36mFunction._maybe_define_function\u001b[1;34m(self, args, kwargs)\u001b[0m\n\u001b[0;32m   2708\u001b[0m   cache_key \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_function_cache\u001b[39m.\u001b[39mgeneralize(cache_key)\n\u001b[0;32m   2709\u001b[0m   (args, kwargs) \u001b[39m=\u001b[39m cache_key\u001b[39m.\u001b[39m_placeholder_value()  \u001b[39m# pylint: disable=protected-access\u001b[39;00m\n\u001b[1;32m-> 2711\u001b[0m graph_function \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_create_graph_function(args, kwargs)\n\u001b[0;32m   2712\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_function_cache\u001b[39m.\u001b[39madd(cache_key, cache_key_deletion_observer,\n\u001b[0;32m   2713\u001b[0m                          graph_function)\n\u001b[0;32m   2715\u001b[0m \u001b[39mreturn\u001b[39;00m graph_function, filtered_flat_args\n",
      "File \u001b[1;32m~\\AppData\\Roaming\\Python\\Python38\\site-packages\\tensorflow\\python\\eager\\function.py:2627\u001b[0m, in \u001b[0;36mFunction._create_graph_function\u001b[1;34m(self, args, kwargs)\u001b[0m\n\u001b[0;32m   2622\u001b[0m missing_arg_names \u001b[39m=\u001b[39m [\n\u001b[0;32m   2623\u001b[0m     \u001b[39m\"\u001b[39m\u001b[39m%s\u001b[39;00m\u001b[39m_\u001b[39m\u001b[39m%d\u001b[39;00m\u001b[39m\"\u001b[39m \u001b[39m%\u001b[39m (arg, i) \u001b[39mfor\u001b[39;00m i, arg \u001b[39min\u001b[39;00m \u001b[39menumerate\u001b[39m(missing_arg_names)\n\u001b[0;32m   2624\u001b[0m ]\n\u001b[0;32m   2625\u001b[0m arg_names \u001b[39m=\u001b[39m base_arg_names \u001b[39m+\u001b[39m missing_arg_names\n\u001b[0;32m   2626\u001b[0m graph_function \u001b[39m=\u001b[39m ConcreteFunction(\n\u001b[1;32m-> 2627\u001b[0m     func_graph_module\u001b[39m.\u001b[39;49mfunc_graph_from_py_func(\n\u001b[0;32m   2628\u001b[0m         \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_name,\n\u001b[0;32m   2629\u001b[0m         \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_python_function,\n\u001b[0;32m   2630\u001b[0m         args,\n\u001b[0;32m   2631\u001b[0m         kwargs,\n\u001b[0;32m   2632\u001b[0m         \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49minput_signature,\n\u001b[0;32m   2633\u001b[0m         autograph\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_autograph,\n\u001b[0;32m   2634\u001b[0m         autograph_options\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_autograph_options,\n\u001b[0;32m   2635\u001b[0m         arg_names\u001b[39m=\u001b[39;49marg_names,\n\u001b[0;32m   2636\u001b[0m         capture_by_value\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_capture_by_value),\n\u001b[0;32m   2637\u001b[0m     \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_function_attributes,\n\u001b[0;32m   2638\u001b[0m     spec\u001b[39m=\u001b[39m\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mfunction_spec,\n\u001b[0;32m   2639\u001b[0m     \u001b[39m# Tell the ConcreteFunction to clean up its graph once it goes out of\u001b[39;00m\n\u001b[0;32m   2640\u001b[0m     \u001b[39m# scope. This is not the default behavior since it gets used in some\u001b[39;00m\n\u001b[0;32m   2641\u001b[0m     \u001b[39m# places (like Keras) where the FuncGraph lives longer than the\u001b[39;00m\n\u001b[0;32m   2642\u001b[0m     \u001b[39m# ConcreteFunction.\u001b[39;00m\n\u001b[0;32m   2643\u001b[0m     shared_func_graph\u001b[39m=\u001b[39m\u001b[39mFalse\u001b[39;00m)\n\u001b[0;32m   2644\u001b[0m \u001b[39mreturn\u001b[39;00m graph_function\n",
      "File \u001b[1;32m~\\AppData\\Roaming\\Python\\Python38\\site-packages\\tensorflow\\python\\framework\\func_graph.py:1141\u001b[0m, in \u001b[0;36mfunc_graph_from_py_func\u001b[1;34m(name, python_func, args, kwargs, signature, func_graph, autograph, autograph_options, add_control_dependencies, arg_names, op_return_value, collections, capture_by_value, acd_record_initial_resource_uses)\u001b[0m\n\u001b[0;32m   1138\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[0;32m   1139\u001b[0m   _, original_func \u001b[39m=\u001b[39m tf_decorator\u001b[39m.\u001b[39munwrap(python_func)\n\u001b[1;32m-> 1141\u001b[0m func_outputs \u001b[39m=\u001b[39m python_func(\u001b[39m*\u001b[39;49mfunc_args, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mfunc_kwargs)\n\u001b[0;32m   1143\u001b[0m \u001b[39m# invariant: `func_outputs` contains only Tensors, CompositeTensors,\u001b[39;00m\n\u001b[0;32m   1144\u001b[0m \u001b[39m# TensorArrays and `None`s.\u001b[39;00m\n\u001b[0;32m   1145\u001b[0m func_outputs \u001b[39m=\u001b[39m nest\u001b[39m.\u001b[39mmap_structure(\n\u001b[0;32m   1146\u001b[0m     convert, func_outputs, expand_composites\u001b[39m=\u001b[39m\u001b[39mTrue\u001b[39;00m)\n",
      "File \u001b[1;32m~\\AppData\\Roaming\\Python\\Python38\\site-packages\\keras\\layers\\rnn\\lstm.py:891\u001b[0m, in \u001b[0;36mstandard_lstm\u001b[1;34m(inputs, init_h, init_c, kernel, recurrent_kernel, bias, mask, time_major, go_backwards, sequence_lengths, zero_output_for_mask, return_sequences)\u001b[0m\n\u001b[0;32m    888\u001b[0m   h \u001b[39m=\u001b[39m o \u001b[39m*\u001b[39m tf\u001b[39m.\u001b[39mtanh(c)\n\u001b[0;32m    889\u001b[0m   \u001b[39mreturn\u001b[39;00m h, [h, c]\n\u001b[1;32m--> 891\u001b[0m last_output, outputs, new_states \u001b[39m=\u001b[39m backend\u001b[39m.\u001b[39;49mrnn(\n\u001b[0;32m    892\u001b[0m     step,\n\u001b[0;32m    893\u001b[0m     inputs, [init_h, init_c],\n\u001b[0;32m    894\u001b[0m     constants\u001b[39m=\u001b[39;49m\u001b[39mNone\u001b[39;49;00m,\n\u001b[0;32m    895\u001b[0m     unroll\u001b[39m=\u001b[39;49m\u001b[39mFalse\u001b[39;49;00m,\n\u001b[0;32m    896\u001b[0m     time_major\u001b[39m=\u001b[39;49mtime_major,\n\u001b[0;32m    897\u001b[0m     mask\u001b[39m=\u001b[39;49mmask,\n\u001b[0;32m    898\u001b[0m     go_backwards\u001b[39m=\u001b[39;49mgo_backwards,\n\u001b[0;32m    899\u001b[0m     input_length\u001b[39m=\u001b[39;49m(sequence_lengths\n\u001b[0;32m    900\u001b[0m                   \u001b[39mif\u001b[39;49;00m sequence_lengths \u001b[39mis\u001b[39;49;00m \u001b[39mnot\u001b[39;49;00m \u001b[39mNone\u001b[39;49;00m \u001b[39melse\u001b[39;49;00m timesteps),\n\u001b[0;32m    901\u001b[0m     zero_output_for_mask\u001b[39m=\u001b[39;49mzero_output_for_mask,\n\u001b[0;32m    902\u001b[0m     return_all_outputs\u001b[39m=\u001b[39;49mreturn_sequences)\n\u001b[0;32m    903\u001b[0m \u001b[39mreturn\u001b[39;00m (last_output, outputs, new_states[\u001b[39m0\u001b[39m], new_states[\u001b[39m1\u001b[39m],\n\u001b[0;32m    904\u001b[0m         gru_lstm_utils\u001b[39m.\u001b[39mruntime(gru_lstm_utils\u001b[39m.\u001b[39mRUNTIME_CPU))\n",
      "File \u001b[1;32m~\\AppData\\Roaming\\Python\\Python38\\site-packages\\tensorflow\\python\\util\\traceback_utils.py:150\u001b[0m, in \u001b[0;36mfilter_traceback.<locals>.error_handler\u001b[1;34m(*args, **kwargs)\u001b[0m\n\u001b[0;32m    148\u001b[0m filtered_tb \u001b[39m=\u001b[39m \u001b[39mNone\u001b[39;00m\n\u001b[0;32m    149\u001b[0m \u001b[39mtry\u001b[39;00m:\n\u001b[1;32m--> 150\u001b[0m   \u001b[39mreturn\u001b[39;00m fn(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[0;32m    151\u001b[0m \u001b[39mexcept\u001b[39;00m \u001b[39mException\u001b[39;00m \u001b[39mas\u001b[39;00m e:\n\u001b[0;32m    152\u001b[0m   filtered_tb \u001b[39m=\u001b[39m _process_traceback_frames(e\u001b[39m.\u001b[39m__traceback__)\n",
      "File \u001b[1;32m~\\AppData\\Roaming\\Python\\Python38\\site-packages\\tensorflow\\python\\util\\dispatch.py:1082\u001b[0m, in \u001b[0;36madd_dispatch_support.<locals>.decorator.<locals>.op_dispatch_handler\u001b[1;34m(*args, **kwargs)\u001b[0m\n\u001b[0;32m   1080\u001b[0m \u001b[39m# Fallback dispatch system (dispatch v1):\u001b[39;00m\n\u001b[0;32m   1081\u001b[0m \u001b[39mtry\u001b[39;00m:\n\u001b[1;32m-> 1082\u001b[0m   \u001b[39mreturn\u001b[39;00m dispatch_target(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[0;32m   1083\u001b[0m \u001b[39mexcept\u001b[39;00m (\u001b[39mTypeError\u001b[39;00m, \u001b[39mValueError\u001b[39;00m):\n\u001b[0;32m   1084\u001b[0m   \u001b[39m# Note: convert_to_eager_tensor currently raises a ValueError, not a\u001b[39;00m\n\u001b[0;32m   1085\u001b[0m   \u001b[39m# TypeError, when given unexpected types.  So we need to catch both.\u001b[39;00m\n\u001b[0;32m   1086\u001b[0m   result \u001b[39m=\u001b[39m dispatch(op_dispatch_handler, args, kwargs)\n",
      "File \u001b[1;32m~\\AppData\\Roaming\\Python\\Python38\\site-packages\\keras\\backend.py:4624\u001b[0m, in \u001b[0;36mrnn\u001b[1;34m(step_function, inputs, initial_states, go_backwards, mask, constants, unroll, input_length, time_major, zero_output_for_mask, return_all_outputs)\u001b[0m\n\u001b[0;32m   4620\u001b[0m input_time_zero \u001b[39m=\u001b[39m tf\u001b[39m.\u001b[39mnest\u001b[39m.\u001b[39mpack_sequence_as(inputs,\n\u001b[0;32m   4621\u001b[0m                                         [inp[\u001b[39m0\u001b[39m] \u001b[39mfor\u001b[39;00m inp \u001b[39min\u001b[39;00m flatted_inputs])\n\u001b[0;32m   4622\u001b[0m \u001b[39m# output_time_zero is used to determine the cell output shape and its dtype.\u001b[39;00m\n\u001b[0;32m   4623\u001b[0m \u001b[39m# the value is discarded.\u001b[39;00m\n\u001b[1;32m-> 4624\u001b[0m output_time_zero, _ \u001b[39m=\u001b[39m step_function(\n\u001b[0;32m   4625\u001b[0m     input_time_zero, \u001b[39mtuple\u001b[39;49m(initial_states) \u001b[39m+\u001b[39;49m \u001b[39mtuple\u001b[39;49m(constants))\n\u001b[0;32m   4627\u001b[0m output_ta_size \u001b[39m=\u001b[39m time_steps_t \u001b[39mif\u001b[39;00m return_all_outputs \u001b[39melse\u001b[39;00m \u001b[39m1\u001b[39m\n\u001b[0;32m   4628\u001b[0m output_ta \u001b[39m=\u001b[39m \u001b[39mtuple\u001b[39m(\n\u001b[0;32m   4629\u001b[0m     tf\u001b[39m.\u001b[39mTensorArray(\n\u001b[0;32m   4630\u001b[0m         dtype\u001b[39m=\u001b[39mout\u001b[39m.\u001b[39mdtype,\n\u001b[1;32m   (...)\u001b[0m\n\u001b[0;32m   4633\u001b[0m         tensor_array_name\u001b[39m=\u001b[39m\u001b[39m'\u001b[39m\u001b[39moutput_ta_\u001b[39m\u001b[39m%s\u001b[39;00m\u001b[39m'\u001b[39m \u001b[39m%\u001b[39m i)\n\u001b[0;32m   4634\u001b[0m     \u001b[39mfor\u001b[39;00m i, out \u001b[39min\u001b[39;00m \u001b[39menumerate\u001b[39m(tf\u001b[39m.\u001b[39mnest\u001b[39m.\u001b[39mflatten(output_time_zero)))\n",
      "File \u001b[1;32m~\\AppData\\Roaming\\Python\\Python38\\site-packages\\keras\\layers\\rnn\\lstm.py:878\u001b[0m, in \u001b[0;36mstandard_lstm.<locals>.step\u001b[1;34m(cell_inputs, cell_states)\u001b[0m\n\u001b[0;32m    875\u001b[0m c_tm1 \u001b[39m=\u001b[39m cell_states[\u001b[39m1\u001b[39m]  \u001b[39m# previous carry state\u001b[39;00m\n\u001b[0;32m    877\u001b[0m z \u001b[39m=\u001b[39m backend\u001b[39m.\u001b[39mdot(cell_inputs, kernel)\n\u001b[1;32m--> 878\u001b[0m z \u001b[39m+\u001b[39m\u001b[39m=\u001b[39m backend\u001b[39m.\u001b[39;49mdot(h_tm1, recurrent_kernel)\n\u001b[0;32m    879\u001b[0m z \u001b[39m=\u001b[39m backend\u001b[39m.\u001b[39mbias_add(z, bias)\n\u001b[0;32m    881\u001b[0m z0, z1, z2, z3 \u001b[39m=\u001b[39m tf\u001b[39m.\u001b[39msplit(z, \u001b[39m4\u001b[39m, axis\u001b[39m=\u001b[39m\u001b[39m1\u001b[39m)\n",
      "File \u001b[1;32m~\\AppData\\Roaming\\Python\\Python38\\site-packages\\tensorflow\\python\\util\\traceback_utils.py:150\u001b[0m, in \u001b[0;36mfilter_traceback.<locals>.error_handler\u001b[1;34m(*args, **kwargs)\u001b[0m\n\u001b[0;32m    148\u001b[0m filtered_tb \u001b[39m=\u001b[39m \u001b[39mNone\u001b[39;00m\n\u001b[0;32m    149\u001b[0m \u001b[39mtry\u001b[39;00m:\n\u001b[1;32m--> 150\u001b[0m   \u001b[39mreturn\u001b[39;00m fn(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[0;32m    151\u001b[0m \u001b[39mexcept\u001b[39;00m \u001b[39mException\u001b[39;00m \u001b[39mas\u001b[39;00m e:\n\u001b[0;32m    152\u001b[0m   filtered_tb \u001b[39m=\u001b[39m _process_traceback_frames(e\u001b[39m.\u001b[39m__traceback__)\n",
      "File \u001b[1;32m~\\AppData\\Roaming\\Python\\Python38\\site-packages\\tensorflow\\python\\util\\dispatch.py:1082\u001b[0m, in \u001b[0;36madd_dispatch_support.<locals>.decorator.<locals>.op_dispatch_handler\u001b[1;34m(*args, **kwargs)\u001b[0m\n\u001b[0;32m   1080\u001b[0m \u001b[39m# Fallback dispatch system (dispatch v1):\u001b[39;00m\n\u001b[0;32m   1081\u001b[0m \u001b[39mtry\u001b[39;00m:\n\u001b[1;32m-> 1082\u001b[0m   \u001b[39mreturn\u001b[39;00m dispatch_target(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[0;32m   1083\u001b[0m \u001b[39mexcept\u001b[39;00m (\u001b[39mTypeError\u001b[39;00m, \u001b[39mValueError\u001b[39;00m):\n\u001b[0;32m   1084\u001b[0m   \u001b[39m# Note: convert_to_eager_tensor currently raises a ValueError, not a\u001b[39;00m\n\u001b[0;32m   1085\u001b[0m   \u001b[39m# TypeError, when given unexpected types.  So we need to catch both.\u001b[39;00m\n\u001b[0;32m   1086\u001b[0m   result \u001b[39m=\u001b[39m dispatch(op_dispatch_handler, args, kwargs)\n",
      "File \u001b[1;32m~\\AppData\\Roaming\\Python\\Python38\\site-packages\\keras\\backend.py:2223\u001b[0m, in \u001b[0;36mdot\u001b[1;34m(x, y)\u001b[0m\n\u001b[0;32m   2221\u001b[0m   out \u001b[39m=\u001b[39m tf\u001b[39m.\u001b[39msparse\u001b[39m.\u001b[39msparse_dense_matmul(x, y)\n\u001b[0;32m   2222\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[1;32m-> 2223\u001b[0m   out \u001b[39m=\u001b[39m tf\u001b[39m.\u001b[39;49mmatmul(x, y)\n\u001b[0;32m   2224\u001b[0m \u001b[39mreturn\u001b[39;00m out\n",
      "File \u001b[1;32m~\\AppData\\Roaming\\Python\\Python38\\site-packages\\tensorflow\\python\\util\\traceback_utils.py:150\u001b[0m, in \u001b[0;36mfilter_traceback.<locals>.error_handler\u001b[1;34m(*args, **kwargs)\u001b[0m\n\u001b[0;32m    148\u001b[0m filtered_tb \u001b[39m=\u001b[39m \u001b[39mNone\u001b[39;00m\n\u001b[0;32m    149\u001b[0m \u001b[39mtry\u001b[39;00m:\n\u001b[1;32m--> 150\u001b[0m   \u001b[39mreturn\u001b[39;00m fn(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[0;32m    151\u001b[0m \u001b[39mexcept\u001b[39;00m \u001b[39mException\u001b[39;00m \u001b[39mas\u001b[39;00m e:\n\u001b[0;32m    152\u001b[0m   filtered_tb \u001b[39m=\u001b[39m _process_traceback_frames(e\u001b[39m.\u001b[39m__traceback__)\n",
      "File \u001b[1;32m~\\AppData\\Roaming\\Python\\Python38\\site-packages\\tensorflow\\python\\util\\dispatch.py:1082\u001b[0m, in \u001b[0;36madd_dispatch_support.<locals>.decorator.<locals>.op_dispatch_handler\u001b[1;34m(*args, **kwargs)\u001b[0m\n\u001b[0;32m   1080\u001b[0m \u001b[39m# Fallback dispatch system (dispatch v1):\u001b[39;00m\n\u001b[0;32m   1081\u001b[0m \u001b[39mtry\u001b[39;00m:\n\u001b[1;32m-> 1082\u001b[0m   \u001b[39mreturn\u001b[39;00m dispatch_target(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[0;32m   1083\u001b[0m \u001b[39mexcept\u001b[39;00m (\u001b[39mTypeError\u001b[39;00m, \u001b[39mValueError\u001b[39;00m):\n\u001b[0;32m   1084\u001b[0m   \u001b[39m# Note: convert_to_eager_tensor currently raises a ValueError, not a\u001b[39;00m\n\u001b[0;32m   1085\u001b[0m   \u001b[39m# TypeError, when given unexpected types.  So we need to catch both.\u001b[39;00m\n\u001b[0;32m   1086\u001b[0m   result \u001b[39m=\u001b[39m dispatch(op_dispatch_handler, args, kwargs)\n",
      "File \u001b[1;32m~\\AppData\\Roaming\\Python\\Python38\\site-packages\\tensorflow\\python\\ops\\math_ops.py:3713\u001b[0m, in \u001b[0;36mmatmul\u001b[1;34m(a, b, transpose_a, transpose_b, adjoint_a, adjoint_b, a_is_sparse, b_is_sparse, output_type, name)\u001b[0m\n\u001b[0;32m   3710\u001b[0m   \u001b[39mreturn\u001b[39;00m gen_math_ops\u001b[39m.\u001b[39mbatch_mat_mul_v3(\n\u001b[0;32m   3711\u001b[0m       a, b, adj_x\u001b[39m=\u001b[39madjoint_a, adj_y\u001b[39m=\u001b[39madjoint_b, Tout\u001b[39m=\u001b[39moutput_type, name\u001b[39m=\u001b[39mname)\n\u001b[0;32m   3712\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[1;32m-> 3713\u001b[0m   \u001b[39mreturn\u001b[39;00m gen_math_ops\u001b[39m.\u001b[39;49mmat_mul(\n\u001b[0;32m   3714\u001b[0m       a, b, transpose_a\u001b[39m=\u001b[39;49mtranspose_a, transpose_b\u001b[39m=\u001b[39;49mtranspose_b, name\u001b[39m=\u001b[39;49mname)\n",
      "File \u001b[1;32m~\\AppData\\Roaming\\Python\\Python38\\site-packages\\tensorflow\\python\\ops\\gen_math_ops.py:6033\u001b[0m, in \u001b[0;36mmat_mul\u001b[1;34m(a, b, transpose_a, transpose_b, name)\u001b[0m\n\u001b[0;32m   6031\u001b[0m   transpose_b \u001b[39m=\u001b[39m \u001b[39mFalse\u001b[39;00m\n\u001b[0;32m   6032\u001b[0m transpose_b \u001b[39m=\u001b[39m _execute\u001b[39m.\u001b[39mmake_bool(transpose_b, \u001b[39m\"\u001b[39m\u001b[39mtranspose_b\u001b[39m\u001b[39m\"\u001b[39m)\n\u001b[1;32m-> 6033\u001b[0m _, _, _op, _outputs \u001b[39m=\u001b[39m _op_def_library\u001b[39m.\u001b[39;49m_apply_op_helper(\n\u001b[0;32m   6034\u001b[0m       \u001b[39m\"\u001b[39;49m\u001b[39mMatMul\u001b[39;49m\u001b[39m\"\u001b[39;49m, a\u001b[39m=\u001b[39;49ma, b\u001b[39m=\u001b[39;49mb, transpose_a\u001b[39m=\u001b[39;49mtranspose_a, transpose_b\u001b[39m=\u001b[39;49mtranspose_b,\n\u001b[0;32m   6035\u001b[0m                 name\u001b[39m=\u001b[39;49mname)\n\u001b[0;32m   6036\u001b[0m _result \u001b[39m=\u001b[39m _outputs[:]\n\u001b[0;32m   6037\u001b[0m \u001b[39mif\u001b[39;00m _execute\u001b[39m.\u001b[39mmust_record_gradient():\n",
      "File \u001b[1;32m~\\AppData\\Roaming\\Python\\Python38\\site-packages\\tensorflow\\python\\framework\\op_def_library.py:797\u001b[0m, in \u001b[0;36m_apply_op_helper\u001b[1;34m(op_type_name, name, **keywords)\u001b[0m\n\u001b[0;32m    792\u001b[0m must_colocate_inputs \u001b[39m=\u001b[39m [val \u001b[39mfor\u001b[39;00m arg, val \u001b[39min\u001b[39;00m \u001b[39mzip\u001b[39m(op_def\u001b[39m.\u001b[39minput_arg, inputs)\n\u001b[0;32m    793\u001b[0m                         \u001b[39mif\u001b[39;00m arg\u001b[39m.\u001b[39mis_ref]\n\u001b[0;32m    794\u001b[0m \u001b[39mwith\u001b[39;00m _MaybeColocateWith(must_colocate_inputs):\n\u001b[0;32m    795\u001b[0m   \u001b[39m# Add Op to graph\u001b[39;00m\n\u001b[0;32m    796\u001b[0m   \u001b[39m# pylint: disable=protected-access\u001b[39;00m\n\u001b[1;32m--> 797\u001b[0m   op \u001b[39m=\u001b[39m g\u001b[39m.\u001b[39;49m_create_op_internal(op_type_name, inputs, dtypes\u001b[39m=\u001b[39;49m\u001b[39mNone\u001b[39;49;00m,\n\u001b[0;32m    798\u001b[0m                              name\u001b[39m=\u001b[39;49mscope, input_types\u001b[39m=\u001b[39;49minput_types,\n\u001b[0;32m    799\u001b[0m                              attrs\u001b[39m=\u001b[39;49mattr_protos, op_def\u001b[39m=\u001b[39;49mop_def)\n\u001b[0;32m    801\u001b[0m \u001b[39m# `outputs` is returned as a separate return value so that the output\u001b[39;00m\n\u001b[0;32m    802\u001b[0m \u001b[39m# tensors can the `op` per se can be decoupled so that the\u001b[39;00m\n\u001b[0;32m    803\u001b[0m \u001b[39m# `op_callbacks` can function properly. See framework/op_callbacks.py\u001b[39;00m\n\u001b[0;32m    804\u001b[0m \u001b[39m# for more details.\u001b[39;00m\n\u001b[0;32m    805\u001b[0m outputs \u001b[39m=\u001b[39m op\u001b[39m.\u001b[39moutputs\n",
      "File \u001b[1;32m~\\AppData\\Roaming\\Python\\Python38\\site-packages\\tensorflow\\python\\framework\\func_graph.py:694\u001b[0m, in \u001b[0;36mFuncGraph._create_op_internal\u001b[1;34m(self, op_type, inputs, dtypes, input_types, name, attrs, op_def, compute_device)\u001b[0m\n\u001b[0;32m    692\u001b[0m   inp \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mcapture(inp)\n\u001b[0;32m    693\u001b[0m   captured_inputs\u001b[39m.\u001b[39mappend(inp)\n\u001b[1;32m--> 694\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39msuper\u001b[39;49m(FuncGraph, \u001b[39mself\u001b[39;49m)\u001b[39m.\u001b[39;49m_create_op_internal(  \u001b[39m# pylint: disable=protected-access\u001b[39;49;00m\n\u001b[0;32m    695\u001b[0m     op_type, captured_inputs, dtypes, input_types, name, attrs, op_def,\n\u001b[0;32m    696\u001b[0m     compute_device)\n",
      "File \u001b[1;32m~\\AppData\\Roaming\\Python\\Python38\\site-packages\\tensorflow\\python\\framework\\ops.py:3754\u001b[0m, in \u001b[0;36mGraph._create_op_internal\u001b[1;34m(self, op_type, inputs, dtypes, input_types, name, attrs, op_def, compute_device)\u001b[0m\n\u001b[0;32m   3751\u001b[0m \u001b[39m# _create_op_helper mutates the new Operation. `_mutation_lock` ensures a\u001b[39;00m\n\u001b[0;32m   3752\u001b[0m \u001b[39m# Session.run call cannot occur between creating and mutating the op.\u001b[39;00m\n\u001b[0;32m   3753\u001b[0m \u001b[39mwith\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_mutation_lock():\n\u001b[1;32m-> 3754\u001b[0m   ret \u001b[39m=\u001b[39m Operation(\n\u001b[0;32m   3755\u001b[0m       node_def,\n\u001b[0;32m   3756\u001b[0m       \u001b[39mself\u001b[39;49m,\n\u001b[0;32m   3757\u001b[0m       inputs\u001b[39m=\u001b[39;49minputs,\n\u001b[0;32m   3758\u001b[0m       output_types\u001b[39m=\u001b[39;49mdtypes,\n\u001b[0;32m   3759\u001b[0m       control_inputs\u001b[39m=\u001b[39;49mcontrol_inputs,\n\u001b[0;32m   3760\u001b[0m       input_types\u001b[39m=\u001b[39;49minput_types,\n\u001b[0;32m   3761\u001b[0m       original_op\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_default_original_op,\n\u001b[0;32m   3762\u001b[0m       op_def\u001b[39m=\u001b[39;49mop_def)\n\u001b[0;32m   3763\u001b[0m   \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_create_op_helper(ret, compute_device\u001b[39m=\u001b[39mcompute_device)\n\u001b[0;32m   3764\u001b[0m \u001b[39mreturn\u001b[39;00m ret\n",
      "File \u001b[1;32m~\\AppData\\Roaming\\Python\\Python38\\site-packages\\tensorflow\\python\\framework\\ops.py:2129\u001b[0m, in \u001b[0;36mOperation.__init__\u001b[1;34m(self, node_def, g, inputs, output_types, control_inputs, input_types, original_op, op_def)\u001b[0m\n\u001b[0;32m   2127\u001b[0m   \u001b[39mif\u001b[39;00m op_def \u001b[39mis\u001b[39;00m \u001b[39mNone\u001b[39;00m:\n\u001b[0;32m   2128\u001b[0m     op_def \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_graph\u001b[39m.\u001b[39m_get_op_def(node_def\u001b[39m.\u001b[39mop)\n\u001b[1;32m-> 2129\u001b[0m   \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_c_op \u001b[39m=\u001b[39m _create_c_op(\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_graph, node_def, inputs,\n\u001b[0;32m   2130\u001b[0m                             control_input_ops, op_def)\n\u001b[0;32m   2131\u001b[0m   name \u001b[39m=\u001b[39m compat\u001b[39m.\u001b[39mas_str(node_def\u001b[39m.\u001b[39mname)\n\u001b[0;32m   2133\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_traceback \u001b[39m=\u001b[39m tf_stack\u001b[39m.\u001b[39mextract_stack_for_node(\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_c_op)\n",
      "File \u001b[1;32m~\\AppData\\Roaming\\Python\\Python38\\site-packages\\tensorflow\\python\\util\\traceback_utils.py:150\u001b[0m, in \u001b[0;36mfilter_traceback.<locals>.error_handler\u001b[1;34m(*args, **kwargs)\u001b[0m\n\u001b[0;32m    148\u001b[0m filtered_tb \u001b[39m=\u001b[39m \u001b[39mNone\u001b[39;00m\n\u001b[0;32m    149\u001b[0m \u001b[39mtry\u001b[39;00m:\n\u001b[1;32m--> 150\u001b[0m   \u001b[39mreturn\u001b[39;00m fn(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[0;32m    151\u001b[0m \u001b[39mexcept\u001b[39;00m \u001b[39mException\u001b[39;00m \u001b[39mas\u001b[39;00m e:\n\u001b[0;32m    152\u001b[0m   filtered_tb \u001b[39m=\u001b[39m _process_traceback_frames(e\u001b[39m.\u001b[39m__traceback__)\n",
      "File \u001b[1;32m~\\AppData\\Roaming\\Python\\Python38\\site-packages\\tensorflow\\python\\framework\\ops.py:1960\u001b[0m, in \u001b[0;36m_create_c_op\u001b[1;34m(graph, node_def, inputs, control_inputs, op_def)\u001b[0m\n\u001b[0;32m   1956\u001b[0m   pywrap_tf_session\u001b[39m.\u001b[39mTF_SetAttrValueProto(op_desc, compat\u001b[39m.\u001b[39mas_str(name),\n\u001b[0;32m   1957\u001b[0m                                          serialized)\n\u001b[0;32m   1959\u001b[0m \u001b[39mtry\u001b[39;00m:\n\u001b[1;32m-> 1960\u001b[0m   c_op \u001b[39m=\u001b[39m pywrap_tf_session\u001b[39m.\u001b[39;49mTF_FinishOperation(op_desc)\n\u001b[0;32m   1961\u001b[0m \u001b[39mexcept\u001b[39;00m errors\u001b[39m.\u001b[39mInvalidArgumentError \u001b[39mas\u001b[39;00m e:\n\u001b[0;32m   1962\u001b[0m   \u001b[39m# Convert to ValueError for backwards compatibility.\u001b[39;00m\n\u001b[0;32m   1963\u001b[0m   \u001b[39mraise\u001b[39;00m \u001b[39mValueError\u001b[39;00m(e\u001b[39m.\u001b[39mmessage)\n",
      "\u001b[1;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "AttnLSTM.compile(optimizer=opt, loss='categorical_crossentropy', metrics=['categorical_accuracy'])\n",
    "AttnLSTM.fit(X_train, y_train, batch_size=batch_size, epochs=max_epochs, validation_data=(X_val, y_val), callbacks=callbacks)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "b89f67cf",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model map\n",
    "models = {\n",
    "    'LSTM': lstm, \n",
    "    'LSTM_Attention_128HUs': AttnLSTM, \n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a928f612",
   "metadata": {},
   "source": [
    "# 7a. Save Weights"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "0a7647ad",
   "metadata": {},
   "outputs": [],
   "source": [
    "for model_name, model in models.items():\n",
    "    save_dir = os.path.join(os.getcwd(), f\"{model_name}.h5\")\n",
    "    model.save(save_dir)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "13fecf26",
   "metadata": {},
   "source": [
    "# 7b. Load Weights"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "ed0114a6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run model rebuild before doing this\n",
    "for model_name, model in models.items():\n",
    "    load_dir = os.path.join(os.getcwd(), f\"{model_name}.h5\")\n",
    "    model.load_weights(load_dir)"
   ]
  },
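  {
   "cell_type": "markdown",
   "id": "c1a9e0f2",
   "metadata": {},
   "source": [
    "Since `model.save()` in 7a stores the full model, an alternative to rebuilding the architectures is to restore a model directly from its HDF5 file. A minimal sketch, assuming the `.h5` files from 7a are in the working directory:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d2b8f1a3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch: restore a complete model (architecture + weights) saved in 7a.\n",
    "# Custom layers, if any, would need to be supplied via `custom_objects`.\n",
    "restored = tf.keras.models.load_model(os.path.join(os.getcwd(), 'LSTM.h5'))\n",
    "restored.summary()"
   ]
  },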
  {
   "cell_type": "markdown",
   "id": "1b7747c6",
   "metadata": {},
   "source": [
    "# 8. Make Predictions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "2101a592",
   "metadata": {},
   "outputs": [],
   "source": [
    "for model in models.values():\n",
    "    res = model.predict(X_test, verbose=0)   "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6b36c98e",
   "metadata": {},
   "source": [
    "# 9. Evaluations using Confusion Matrix and Accuracy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "ecf242d9",
   "metadata": {},
   "outputs": [],
   "source": [
    "eval_results = {}\n",
    "eval_results['confusion matrix'] = None\n",
    "eval_results['accuracy'] = None\n",
    "eval_results['precision'] = None\n",
    "eval_results['recall'] = None\n",
    "eval_results['f1 score'] = None\n",
    "\n",
    "confusion_matrices = {}\n",
    "classification_accuracies = {}   \n",
    "precisions = {}\n",
    "recalls = {}\n",
    "f1_scores = {} "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "74d6778f",
   "metadata": {},
   "source": [
    "## 9a. Confusion Matrices"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "fccbb90f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "LSTM confusion matrix: \n",
      "[[[27  5]\n",
      "  [ 1 12]]\n",
      "\n",
      " [[26  0]\n",
      "  [ 5 14]]\n",
      "\n",
      " [[31  1]\n",
      "  [ 0 13]]]\n",
      "LSTM_Attention_128HUs confusion matrix: \n",
      "[[[32  0]\n",
      "  [ 0 13]]\n",
      "\n",
      " [[26  0]\n",
      "  [ 0 19]]\n",
      "\n",
      " [[32  0]\n",
      "  [ 0 13]]]\n"
     ]
    }
   ],
   "source": [
    "for model_name, model in models.items():\n",
    "    yhat = model.predict(X_test, verbose=0)\n",
    "    \n",
    "    # Get list of classification predictions\n",
    "    ytrue = np.argmax(y_test, axis=1).tolist()\n",
    "    yhat = np.argmax(yhat, axis=1).tolist()\n",
    "    \n",
    "    # Confusion matrix\n",
    "    confusion_matrices[model_name] = multilabel_confusion_matrix(ytrue, yhat)\n",
    "    print(f\"{model_name} confusion matrix: {os.linesep}{confusion_matrices[model_name]}\")\n",
    "\n",
    "# Collect results \n",
    "eval_results['confusion matrix'] = confusion_matrices"
   ]
  },
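  {
   "cell_type": "markdown",
   "id": "e3c7a2b4",
   "metadata": {},
   "source": [
    "The raw arrays above can be hard to scan; a heatmap per class makes the one-vs-rest breakdown clearer. A minimal sketch using the already-imported `matplotlib`, assuming `actions` holds the class names and `confusion_matrices` was populated by the cell above:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f4d6b3c5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch: plot one binary (one-vs-rest) confusion matrix per action class\n",
    "for model_name, cms in confusion_matrices.items():\n",
    "    fig, axes = plt.subplots(1, len(cms), figsize=(4 * len(cms), 3))\n",
    "    for ax, cm, action in zip(axes, cms, actions):\n",
    "        ax.imshow(cm, cmap='Blues')\n",
    "        ax.set_title(f'{model_name}: {action}')\n",
    "        ax.set_xlabel('Predicted')\n",
    "        ax.set_ylabel('True')\n",
    "        for (i, j), v in np.ndenumerate(cm):\n",
    "            ax.text(j, i, str(v), ha='center', va='center')\n",
    "    plt.tight_layout()\n",
    "    plt.show()"
   ]
  },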
  {
   "cell_type": "markdown",
   "id": "b76c6dc5",
   "metadata": {},
   "source": [
    "## 9b. Accuracy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "id": "e36146f5",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "LSTM classification accuracy = 86.667%\n",
      "LSTM_Attention_128HUs classification accuracy = 100.0%\n"
     ]
    }
   ],
   "source": [
    "for model_name, model in models.items():\n",
    "    yhat = model.predict(X_test, verbose=0)\n",
    "    \n",
    "    # Get list of classification predictions\n",
    "    ytrue = np.argmax(y_test, axis=1).tolist()\n",
    "    yhat = np.argmax(yhat, axis=1).tolist()\n",
    "    \n",
    "    # Model accuracy\n",
    "    classification_accuracies[model_name] = accuracy_score(ytrue, yhat)    \n",
    "    print(f\"{model_name} classification accuracy = {round(classification_accuracies[model_name]*100,3)}%\")\n",
    "\n",
    "# Collect results \n",
    "eval_results['accuracy'] = classification_accuracies"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "33efa73a",
   "metadata": {},
   "source": [
    "## 9c. Precision, Recall, and F1 Score"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "id": "35067c48",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "LSTM weighted average precision = 0.894\n",
      "LSTM weighted average recall = 0.867\n",
      "LSTM weighted average f1-score = 0.868\n",
      "\n",
      "LSTM_Attention_128HUs weighted average precision = 1.0\n",
      "LSTM_Attention_128HUs weighted average recall = 1.0\n",
      "LSTM_Attention_128HUs weighted average f1-score = 1.0\n",
      "\n"
     ]
    }
   ],
   "source": [
    "for model_name, model in models.items():\n",
    "    yhat = model.predict(X_test, verbose=0)\n",
    "    \n",
    "    # Get list of classification predictions\n",
    "    ytrue = np.argmax(y_test, axis=1).tolist()\n",
    "    yhat = np.argmax(yhat, axis=1).tolist()\n",
    "    \n",
    "    # Precision, recall, and f1 score\n",
    "    report = classification_report(ytrue, yhat, target_names=actions, output_dict=True)\n",
    "    \n",
    "    precisions[model_name] = report['weighted avg']['precision']\n",
    "    recalls[model_name] = report['weighted avg']['recall']\n",
    "    f1_scores[model_name] = report['weighted avg']['f1-score'] \n",
    "   \n",
    "    print(f\"{model_name} weighted average precision = {round(precisions[model_name],3)}\")\n",
    "    print(f\"{model_name} weighted average recall = {round(recalls[model_name],3)}\")\n",
    "    print(f\"{model_name} weighted average f1-score = {round(f1_scores[model_name],3)}\\n\")\n",
    "\n",
    "# Collect results \n",
    "eval_results['precision'] = precisions\n",
    "eval_results['recall'] = recalls\n",
    "eval_results['f1 score'] = f1_scores"
   ]
  },
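  {
   "cell_type": "markdown",
   "id": "a5e8c4d6",
   "metadata": {},
   "source": [
    "The weighted averages above hide per-class differences; the default text form of `classification_report` gives the full per-class table. A quick sketch, reusing the same predictions:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b6f9d5e7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch: per-class precision/recall/f1 breakdown for each model\n",
    "ytrue = np.argmax(y_test, axis=1).tolist()\n",
    "for model_name, model in models.items():\n",
    "    yhat = np.argmax(model.predict(X_test, verbose=0), axis=1).tolist()\n",
    "    print(f'--- {model_name} ---')\n",
    "    print(classification_report(ytrue, yhat, target_names=actions))"
   ]
  },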
  {
   "cell_type": "markdown",
   "id": "e5d39476",
   "metadata": {},
   "source": [
    "# 10. Choose Model to Test in Real Time"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "id": "d72d0605",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = AttnLSTM\n",
    "model_name = 'AttnLSTM'"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9f0015ce",
   "metadata": {},
   "source": [
    "# 11. Calculate Joint Angles & Count Reps"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "id": "f172932f",
   "metadata": {},
   "outputs": [],
   "source": [
    "def calculate_angle(a,b,c):\n",
    "    \"\"\"\n",
    "    Computes 3D joint angle inferred by 3 keypoints and their relative positions to one another\n",
    "    \n",
    "    \"\"\"\n",
    "    a = np.array(a) # First\n",
    "    b = np.array(b) # Mid\n",
    "    c = np.array(c) # End\n",
    "    \n",
    "    radians = np.arctan2(c[1]-b[1], c[0]-b[0]) - np.arctan2(a[1]-b[1], a[0]-b[0])\n",
    "    angle = np.abs(radians*180.0/np.pi)\n",
    "    \n",
    "    if angle >180.0:\n",
    "        angle = 360-angle\n",
    "        \n",
    "    return angle "
   ]
  },
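  {
   "cell_type": "markdown",
   "id": "c7a0e6f8",
   "metadata": {},
   "source": [
    "A quick sanity check on hand-picked 2D points (a right angle and a straight line) confirms the sign handling and the fold into [0, 180]:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d8b1f7a9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Perpendicular segments meeting at b should give 90 degrees,\n",
    "# collinear points should give 180 degrees\n",
    "assert round(calculate_angle([1, 0], [0, 0], [0, 1])) == 90\n",
    "assert round(calculate_angle([-1, 0], [0, 0], [1, 0])) == 180\n",
    "print('calculate_angle sanity checks passed')"
   ]
  },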
  {
   "cell_type": "code",
   "execution_count": 46,
   "id": "26f357fe",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_coordinates(landmarks, mp_pose, side, joint):\n",
    "    \"\"\"\n",
    "    Retrieves x and y coordinates of a particular keypoint from the pose estimation model\n",
    "         \n",
    "     Args:\n",
    "         landmarks: processed keypoints from the pose estimation model\n",
    "         mp_pose: Mediapipe pose estimation model\n",
    "         side: 'left' or 'right'. Denotes the side of the body of the landmark of interest.\n",
    "         joint: 'shoulder', 'elbow', 'wrist', 'hip', 'knee', or 'ankle'. Denotes which body joint is associated with the landmark of interest.\n",
    "    \n",
    "    \"\"\"\n",
    "    coord = getattr(mp_pose.PoseLandmark,side.upper()+\"_\"+joint.upper())\n",
    "    x_coord_val = landmarks[coord.value].x\n",
    "    y_coord_val = landmarks[coord.value].y\n",
    "    return [x_coord_val, y_coord_val]            "
   ]
  },
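  {
   "cell_type": "markdown",
   "id": "e9c2a8b0",
   "metadata": {},
   "source": [
    "`get_coordinates` builds the enum member name from its `side` and `joint` arguments, so ('left', 'elbow') resolves to `mp_pose.PoseLandmark.LEFT_ELBOW`. A small illustration of that lookup:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f0d3b9c1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The string pair ('left', 'elbow') maps to the LEFT_ELBOW enum member;\n",
    "# .value is the landmark's integer index into results.pose_landmarks.landmark\n",
    "lm = getattr(mp_pose.PoseLandmark, 'LEFT' + '_' + 'ELBOW')\n",
    "print(lm, lm.value)"
   ]
  },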
  {
   "cell_type": "code",
   "execution_count": 47,
   "id": "f11273cf",
   "metadata": {},
   "outputs": [],
   "source": [
    "def viz_joint_angle(image, angle, joint):\n",
    "    \"\"\"\n",
    "    Displays the joint angle value near the joint within the image frame\n",
    "    \n",
    "    \"\"\"\n",
    "    cv2.putText(image, str(int(angle)), \n",
    "                   tuple(np.multiply(joint, [640, 480]).astype(int)), \n",
    "                   cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 2, cv2.LINE_AA\n",
    "                        )\n",
    "    return"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "id": "b64050d3",
   "metadata": {},
   "outputs": [],
   "source": [
    "def count_reps(image, current_action, landmarks, mp_pose):\n",
    "    \"\"\"\n",
    "    Counts repetitions of each exercise. Global count and stage (i.e., state) variables are updated within this function.\n",
    "    \n",
    "    \"\"\"\n",
    "\n",
    "    global curl_counter, press_counter, squat_counter, curl_stage, press_stage, squat_stage\n",
    "    \n",
    "    if current_action == 'curl':\n",
    "        # Get coords\n",
    "        shoulder = get_coordinates(landmarks, mp_pose, 'left', 'shoulder')\n",
    "        elbow = get_coordinates(landmarks, mp_pose, 'left', 'elbow')\n",
    "        wrist = get_coordinates(landmarks, mp_pose, 'left', 'wrist')\n",
    "        \n",
    "        # calculate elbow angle\n",
    "        angle = calculate_angle(shoulder, elbow, wrist)\n",
    "        \n",
    "        # curl counter logic\n",
    "        if angle < 30:\n",
    "            curl_stage = \"up\" \n",
    "        if angle > 140 and curl_stage =='up':\n",
    "            curl_stage=\"down\"  \n",
    "            curl_counter +=1\n",
    "        press_stage = None\n",
    "        squat_stage = None\n",
    "            \n",
    "        # Viz joint angle\n",
    "        viz_joint_angle(image, angle, elbow)\n",
    "        \n",
    "    elif current_action == 'press':\n",
    "        \n",
    "        # Get coords\n",
    "        shoulder = get_coordinates(landmarks, mp_pose, 'left', 'shoulder')\n",
    "        elbow = get_coordinates(landmarks, mp_pose, 'left', 'elbow')\n",
    "        wrist = get_coordinates(landmarks, mp_pose, 'left', 'wrist')\n",
    "\n",
    "        # Calculate elbow angle\n",
    "        elbow_angle = calculate_angle(shoulder, elbow, wrist)\n",
    "        \n",
    "        # Compute distances between joints\n",
    "        shoulder2elbow_dist = abs(math.dist(shoulder,elbow))\n",
    "        shoulder2wrist_dist = abs(math.dist(shoulder,wrist))\n",
    "        \n",
    "        # Press counter logic\n",
    "        if (elbow_angle > 130) and (shoulder2elbow_dist < shoulder2wrist_dist):\n",
    "            press_stage = \"up\"\n",
    "        if (elbow_angle < 50) and (shoulder2elbow_dist > shoulder2wrist_dist) and (press_stage =='up'):\n",
    "            press_stage='down'\n",
    "            press_counter += 1\n",
    "        curl_stage = None\n",
    "        squat_stage = None\n",
    "            \n",
    "        # Viz joint angle\n",
    "        viz_joint_angle(image, elbow_angle, elbow)\n",
    "        \n",
    "    elif current_action == 'squat':\n",
    "        # Get coords\n",
    "        # left side\n",
    "        left_shoulder = get_coordinates(landmarks, mp_pose, 'left', 'shoulder')\n",
    "        left_hip = get_coordinates(landmarks, mp_pose, 'left', 'hip')\n",
    "        left_knee = get_coordinates(landmarks, mp_pose, 'left', 'knee')\n",
    "        left_ankle = get_coordinates(landmarks, mp_pose, 'left', 'ankle')\n",
    "        # right side\n",
    "        right_shoulder = get_coordinates(landmarks, mp_pose, 'right', 'shoulder')\n",
    "        right_hip = get_coordinates(landmarks, mp_pose, 'right', 'hip')\n",
    "        right_knee = get_coordinates(landmarks, mp_pose, 'right', 'knee')\n",
    "        right_ankle = get_coordinates(landmarks, mp_pose, 'right', 'ankle')\n",
    "        \n",
    "        # Calculate knee angles\n",
    "        left_knee_angle = calculate_angle(left_hip, left_knee, left_ankle)\n",
    "        right_knee_angle = calculate_angle(right_hip, right_knee, right_ankle)\n",
    "        \n",
    "        # Calculate hip angles\n",
    "        left_hip_angle = calculate_angle(left_shoulder, left_hip, left_knee)\n",
    "        right_hip_angle = calculate_angle(right_shoulder, right_hip, right_knee)\n",
    "        \n",
    "        # Squat counter logic\n",
    "        thr = 165\n",
    "        if (left_knee_angle < thr) and (right_knee_angle < thr) and (left_hip_angle < thr) and (right_hip_angle < thr):\n",
    "            squat_stage = \"down\"\n",
    "        if (left_knee_angle > thr) and (right_knee_angle > thr) and (left_hip_angle > thr) and (right_hip_angle > thr) and (squat_stage =='down'):\n",
    "            squat_stage='up'\n",
    "            squat_counter += 1\n",
    "        curl_stage = None\n",
    "        press_stage = None\n",
    "            \n",
    "        # Viz joint angles\n",
    "        viz_joint_angle(image, left_knee_angle, left_knee)\n",
    "        viz_joint_angle(image, left_hip_angle, left_hip)\n",
    "        \n",
    "    else:\n",
    "        pass"
   ]
  },
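  {
   "cell_type": "markdown",
   "id": "a1e4c0d2",
   "metadata": {},
   "source": [
    "Each counter above is a two-state machine: a rep is credited only when the joint crosses the second threshold *after* the first was reached, which debounces jitter near either threshold. A standalone sketch of the curl logic on a hypothetical angle trace (independent of the globals used above):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b2f5d1e3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch: the curl counter as a state machine over synthetic elbow angles.\n",
    "# A rep = full flexion (< 30 deg) followed by full extension (> 140 deg).\n",
    "angles = [160, 90, 25, 28, 100, 150, 155, 135, 145]  # hypothetical trace\n",
    "stage, reps = None, 0\n",
    "for angle in angles:\n",
    "    if angle < 30:\n",
    "        stage = 'up'\n",
    "    if angle > 140 and stage == 'up':\n",
    "        stage = 'down'\n",
    "        reps += 1\n",
    "print(reps)  # 1: later wobble above 140 does not re-count"
   ]
  },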
  {
   "cell_type": "markdown",
   "id": "a5116ef6",
   "metadata": {},
   "source": [
    "# 12. Test in Real Time"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "id": "4775b75e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def prob_viz(res, actions, input_frame, colors):\n",
    "    \"\"\"\n",
    "    This function displays the model prediction probability distribution over the set of exercise classes\n",
    "    as a horizontal bar graph\n",
    "    \n",
    "    \"\"\"\n",
    "    output_frame = input_frame.copy()\n",
    "    for num, prob in enumerate(res):        \n",
    "        cv2.rectangle(output_frame, (0,60+num*40), (int(prob*100), 90+num*40), colors[num], -1)\n",
    "        cv2.putText(output_frame, actions[num], (0, 85+num*40), cv2.FONT_HERSHEY_SIMPLEX, 1, (255,255,255), 2, cv2.LINE_AA)\n",
    "        \n",
    "    return output_frame"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "id": "6332bf1a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 1. New detection variables\n",
    "sequence = []\n",
    "predictions = []\n",
    "res = []\n",
    "threshold = 0.5 # minimum confidence to classify as an action/exercise\n",
    "current_action = ''\n",
    "\n",
    "# Rep counter logic variables\n",
    "curl_counter = 0\n",
    "press_counter = 0\n",
    "squat_counter = 0\n",
    "curl_stage = None\n",
    "press_stage = None\n",
    "squat_stage = None\n",
    "\n",
    "# Camera object\n",
    "cap = cv2.VideoCapture(0)\n",
    "\n",
    "# Video writer object that saves a video of the real time test\n",
    "fourcc = cv2.VideoWriter_fourcc('M','J','P','G') # video compression format\n",
    "HEIGHT = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) # webcam video frame height\n",
    "WIDTH = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) # webcam video frame width\n",
    "FPS = int(cap.get(cv2.CAP_PROP_FPS)) # webcam video fram rate \n",
    "\n",
    "video_name = os.path.join(os.getcwd(),f\"{model_name}_real_time_test.avi\")\n",
    "out = cv2.VideoWriter(video_name, cv2.VideoWriter_fourcc(*\"MJPG\"), FPS, (WIDTH,HEIGHT))\n",
    "\n",
    "# Set mediapipe model \n",
    "with mp_pose.Pose(min_detection_confidence=0.5, min_tracking_confidence=0.5) as pose:\n",
    "    while cap.isOpened():\n",
    "\n",
    "        # Read feed\n",
    "        ret, frame = cap.read()\n",
    "\n",
    "        # Make detection\n",
    "        image, results = mediapipe_detection(frame, pose)\n",
    "        \n",
    "        # Draw landmarks\n",
    "        draw_landmarks(image, results)\n",
    "        \n",
    "        # 2. Prediction logic\n",
    "        keypoints = extract_keypoints(results)        \n",
    "        sequence.append(keypoints)      \n",
    "        sequence = sequence[-sequence_length:]\n",
    "              \n",
    "        if len(sequence) == sequence_length:\n",
    "            res = model.predict(np.expand_dims(sequence, axis=0), verbose=0)[0]           \n",
    "            predictions.append(np.argmax(res))\n",
    "            current_action = actions[np.argmax(res)]\n",
    "            confidence = np.max(res)\n",
    "            \n",
    "        #3. Viz logic\n",
    "            # Erase current action variable if no probability is above threshold\n",
    "            if confidence < threshold:\n",
    "                current_action = ''\n",
    "\n",
    "            # Viz probabilities\n",
    "            image = prob_viz(res, actions, image, colors)\n",
    "            \n",
    "            # Count reps\n",
    "            try:\n",
    "                landmarks = results.pose_landmarks.landmark\n",
    "                count_reps(\n",
    "                    image, current_action, landmarks, mp_pose)\n",
    "            except:\n",
    "                pass\n",
    "\n",
    "            # Display graphical information\n",
    "            cv2.rectangle(image, (0,0), (640, 40), colors[np.argmax(res)], -1)\n",
    "            cv2.putText(image, 'curl ' + str(curl_counter), (3,30), \n",
    "                           cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2, cv2.LINE_AA)\n",
    "            cv2.putText(image, 'press ' + str(press_counter), (240,30), \n",
    "                           cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2, cv2.LINE_AA)\n",
    "            cv2.putText(image, 'squat ' + str(squat_counter), (490,30), \n",
    "                           cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2, cv2.LINE_AA)\n",
    "         \n",
    "        # Show to screen\n",
    "        cv2.imshow('OpenCV Feed', image)\n",
    "        \n",
    "        # Write to video file\n",
    "        if ret == True:\n",
    "            out.write(image)\n",
    "\n",
    "        # Break gracefully\n",
    "        if cv2.waitKey(10) & 0xFF == ord('q'):\n",
    "            break\n",
    "    cap.release()\n",
    "    out.release()\n",
    "    cv2.destroyAllWindows()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "id": "af9980a1",
   "metadata": {},
   "outputs": [],
   "source": [
    "cap.release()\n",
    "out.release()\n",
    "cv2.destroyAllWindows()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.8.13 ('AItrainer')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  },
  "vscode": {
   "interpreter": {
    "hash": "80aa1d3f3a8cfb37a38c47373cc49a39149184c5fa770d709389b1b8782c1d85"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}