{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "f56cc5ad",
   "metadata": {},
   "source": [
    "# NDIS Project - PBSP Scoring - Page 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a8d844ea",
   "metadata": {
    "hide_input": false
   },
   "outputs": [],
   "source": [
    "import os\n",
    "from ipywidgets import interact\n",
    "import ipywidgets as widgets\n",
    "from IPython.display import display, clear_output, Javascript, HTML, Markdown\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.ticker as mtick\n",
    "import json\n",
    "import spacy\n",
    "from spacy import displacy\n",
    "from flair.data import Corpus\n",
    "from flair.datasets import CSVClassificationCorpus\n",
    "from flair.models import TARSClassifier\n",
    "from flair.data import Sentence\n",
    "from flair.trainers import ModelTrainer\n",
    "from sklearn.feature_extraction import text\n",
    "from pprint import pprint\n",
    "import re\n",
    "import pandas as pd\n",
    "import argilla as rg\n",
    "from argilla.metrics.text_classification import f1\n",
    "from argilla.training import ArgillaTrainer\n",
    "import joblib\n",
    "import random\n",
    "from typing import Dict\n",
    "import warnings\n",
    "warnings.filterwarnings('ignore')\n",
    "import logging\n",
    "logging.getLogger().setLevel(logging.CRITICAL)\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "96b83a1d",
   "metadata": {},
   "outputs": [],
   "source": [
    "#initializations\n",
    "# load the trained TARS few-shot classifier from disk\n",
    "tars_model_path = 'few-shot-model-1'\n",
    "tars = TARSClassifier.load(tars_model_path + '/best-model.pt')\n",
    "\n",
    "# argilla: connect to the annotation server\n",
    "rg.init(\n",
    "    api_url=os.environ[\"ARGILLA_API_URL\"],\n",
    "    api_key=os.environ[\"ARGILLA_API_KEY\"]\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "72c2c6f9",
   "metadata": {
    "hide_input": false
   },
   "outputs": [],
   "source": [
    "#Text Preprocessing\n",
    "try:\n",
    "    nlp = spacy.load('en_core_web_sm')\n",
    "except OSError:\n",
    "    spacy.cli.download('en_core_web_sm')\n",
    "    nlp = spacy.load('en_core_web_sm')\n",
    "sw_lst = text.ENGLISH_STOP_WORDS\n",
    "def preprocess(onto_lst):\n",
    "    cleaned_onto_lst = []\n",
    "    pattern = re.compile(r'^[a-z ]*$')\n",
    "    for document in onto_lst:\n",
    "        tokens = []\n",
    "        doc = nlp(document)\n",
    "        # collect lemmas of PERSON entities so they can be excluded below\n",
    "        person_tokens = []\n",
    "        for w in doc:\n",
    "            if w.ent_type_ == 'PERSON':\n",
    "                person_tokens.append(w.lemma_)\n",
    "        # keep lemmas that are not stop words, punctuation, numbers, whitespace or person names\n",
    "        for w in doc:\n",
    "            if not w.is_stop and not w.is_punct and not w.like_num and len(w.text.strip()) > 0 and w.lemma_ not in person_tokens:\n",
    "                tokens.append(w.lemma_.lower())\n",
    "        texts = [t for t in tokens if len(t) > 1 and pattern.search(t) is not None and t not in sw_lst]\n",
    "        cleaned_onto_lst.append(\" \".join(texts))\n",
    "    return cleaned_onto_lst"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "40070313",
   "metadata": {},
   "outputs": [],
   "source": [
    "#sentence extraction\n",
    "def extract_sentences(query):\n",
    "    # Compile the regular expression pattern\n",
    "    pattern = re.compile(r'[.,;!?]')\n",
    "    # Split the sentences on the punctuation characters\n",
    "    sentences = [query]\n",
    "    split_sentences = [pattern.split(sentence) for sentence in sentences]\n",
    "    # Flatten the list of split sentences\n",
    "    flat_list = [item for sublist in split_sentences for item in sublist]\n",
    "    # Remove empty strings from the list\n",
    "    filtered_sentences = [sentence.strip() for sentence in flat_list if sentence.strip()]\n",
    "    return filtered_sentences"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1550437b",
   "metadata": {},
   "outputs": [],
   "source": [
    "#query and get predicted topic\n",
    "\n",
    "p_classes = {'speak_to_family': 0,\n",
    "             'stakeholder_interview': 1,\n",
    "             'case_file_review': 2,\n",
    "             'observation': 3,\n",
    "             'no_information_collection': 4}\n",
    "\n",
    "# predict a topic index for each sentence; default to 4 (no information collected)\n",
    "def get_topic(sentences):\n",
    "    preds = []\n",
    "    for t in sentences:\n",
    "        sentence = Sentence(t)\n",
    "        tars.predict(sentence)\n",
    "        try:\n",
    "            pred = p_classes[sentence.tag]\n",
    "        except Exception:\n",
    "            pred = 4\n",
    "        preds.append(pred)\n",
    "    return preds\n",
    "\n",
    "# confidence score of the predicted topic for each sentence; default to 0.75 when no label is returned\n",
    "def get_topic_scores(sentences):\n",
    "    preds = []\n",
    "    for t in sentences:\n",
    "        sentence = Sentence(t)\n",
    "        tars.predict(sentence)\n",
    "        try:\n",
    "            pred = sentence.score\n",
    "        except Exception:\n",
    "            pred = 0.75\n",
    "        preds.append(pred)\n",
    "    return preds"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7b73a6ab",
   "metadata": {},
   "outputs": [],
   "source": [
    "# convert a results dataframe into text/prediction pairs\n",
    "def convert_df(result_df):\n",
    "    new_df = pd.DataFrame(columns=['text', 'prediction'])\n",
    "    new_df['text'] = result_df['Phrase']\n",
    "    new_df['prediction'] = result_df.apply(lambda row: [[row['Topic'], row['Score']]], axis=1)\n",
    "    return new_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "990acf9f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# rebuild a Phrase/Topic/Score dataframe from predicted records\n",
    "def get_trainer_preds(trainer, records):\n",
    "    sentences = [records[x].text for x in range(0, len(records))]\n",
    "    topics = [records[x].prediction[0][0] for x in range(0, len(records))]\n",
    "    scores = [records[x].prediction[0][1] for x in range(0, len(records))]\n",
    "    result_df = pd.DataFrame({'Phrase': sentences, 'Topic': topics, 'Score': scores})\n",
    "    return result_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "973f46c4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# plot overall (macro) and per-label precision/recall/F1 as bar charts\n",
    "def custom_f1(data: Dict[str, float], title: str):\n",
    "    from plotly.subplots import make_subplots\n",
    "    import plotly.colors\n",
    "    import random\n",
    "\n",
    "    fig = make_subplots(\n",
    "        rows=2,\n",
    "        cols=1,\n",
    "        subplot_titles=[\"Overall Model Score\", \"Model Score By Category\"],\n",
    "    )\n",
    "\n",
    "    x = ['precision', 'recall', 'f1']\n",
    "    macro_data = [v for k, v in data.items() if \"macro\" in k]\n",
    "    fig.add_bar(\n",
    "        x=x,\n",
    "        y=macro_data,\n",
    "        row=1,\n",
    "        col=1,\n",
    "    )\n",
    "    per_label = {\n",
    "        k: v\n",
    "        for k, v in data.items()\n",
    "        if all(key not in k for key in [\"macro\", \"micro\", \"support\"])\n",
    "    }\n",
    "\n",
    "    num_labels = int(len(per_label.keys()) / 3)\n",
    "    fixed_colors = [str(color) for color in plotly.colors.qualitative.Plotly]\n",
    "    colors = random.sample(fixed_colors, num_labels)\n",
    "\n",
    "    fig.add_bar(\n",
    "        x=[k for k, v in per_label.items()],\n",
    "        y=[v for k, v in per_label.items()],\n",
    "        row=2,\n",
    "        col=1,\n",
    "        marker_color=[colors[int(i / 3)] for i in range(0, len(per_label.keys()))]\n",
    "    )\n",
    "    fig.update_layout(showlegend=False, title_text=title)\n",
    "\n",
    "    return fig"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "02fda761",
   "metadata": {},
   "outputs": [],
   "source": [
    "# format output\n",
    "ind_topic_dict = {\n",
    "    0: 'SPEAK-TO-FAMILY',\n",
    "    1: 'STAKEHOLDER-INTERVIEW',\n",
    "    2: 'FILE-REVIEW',\n",
    "    3: 'OBSERVATION/COMMUNICATION',\n",
    "    4: 'NO-COLLECTED-INFO'\n",
    "}\n",
    "\n",
    "topic_color_dict = {\n",
    "    'SPEAK-TO-FAMILY': '#FFCCCC',\n",
    "    'STAKEHOLDER-INTERVIEW': '#CCFFFF',\n",
    "    'FILE-REVIEW': '#FF69B4',\n",
    "    'OBSERVATION/COMMUNICATION': '#FFFF00',\n",
    "    'NO-COLLECTED-INFO': '#ECECEC'\n",
    "}\n",
    "\n",
    "passing_score = 0.25\n",
    "\n",
    "# render the Score column as a percentage with a coloured bar\n",
    "def color(df, color):\n",
    "    return df.style.format({'Score': '{:,.2%}'.format}).bar(subset=['Score'], color=color)\n",
    "\n",
    "# locate each highlighted phrase in the query and build displaCy-style entity spans\n",
    "def annotate_query(highlights, query, topics):\n",
    "    ents = []\n",
    "    for h, t in zip(highlights, topics):\n",
    "        ent_dict = {}\n",
    "        # take the first match of the phrase, escaped so regex metacharacters are treated literally\n",
    "        for match in re.finditer(re.escape(h), query):\n",
    "            ent_dict = {\"start\": match.start(), \"end\": match.end(), \"label\": t}\n",
    "            break\n",
    "        if len(ent_dict.keys()) > 0:\n",
    "            ents.append(ent_dict)\n",
    "    return ents"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "905eaf2a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# render an image path as an HTML <img> tag for DataFrame display\n",
    "def path_to_image_html(path):\n",
    "    return '<img src=\"' + path + '\"/>'\n",
    "\n",
    "final_passing = 0.0\n",
    "# show which information-gathering approaches were used, with tick/cross images\n",
    "def display_final_df(agg_df):\n",
    "    tags = []\n",
    "    crits = [\n",
    "        'SPEAK-TO-FAMILY',\n",
    "        'STAKEHOLDER-INTERVIEW',\n",
    "        'FILE-REVIEW',\n",
    "        'OBSERVATION/COMMUNICATION'\n",
    "    ]\n",
    "    orig_crits = crits\n",
    "    crits = [x for x in crits if x in agg_df.index.tolist()]\n",
    "    bools = [agg_df.loc[crit, 'Final_Score'] > final_passing for crit in crits]\n",
    "    paths = ['./tick_green.png' if x else './cross_red.png' for x in bools]\n",
    "    df = pd.DataFrame({'Approach': crits, 'USED': paths})\n",
    "    rem_crits = [x for x in orig_crits if x not in crits]\n",
    "    if len(rem_crits) > 0:\n",
    "        df2 = pd.DataFrame({'Approach': rem_crits, 'USED': ['./cross_red.png'] * len(rem_crits)})\n",
    "        df = pd.concat([df, df2])\n",
    "    df = df.set_index('Approach')\n",
    "    pd.set_option('display.max_colwidth', None)\n",
    "    display(HTML('