{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "f56cc5ad",
   "metadata": {},
   "source": [
    "# NDIS Project - PBSP Scoring - Page 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a8d844ea",
   "metadata": {
    "hide_input": false
   },
   "outputs": [],
   "source": [
    "# Imports: every dependency of the notebook is declared in this single cell.\n",
    "import os\n",
    "from ipywidgets import interact\n",
    "import ipywidgets as widgets\n",
    "from IPython.display import display, clear_output, Javascript, HTML, Markdown\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.ticker as mtick\n",
    "import json\n",
    "import spacy\n",
    "from spacy import displacy\n",
    "from flair.data import Corpus\n",
    "from flair.datasets import CSVClassificationCorpus\n",
    "from flair.models import TARSClassifier\n",
    "from flair.data import Sentence\n",
    "from flair.trainers import ModelTrainer\n",
    "from sklearn.feature_extraction import text\n",
    "from pprint import pprint\n",
    "import re\n",
    "from html import escape\n",
    "import pandas as pd\n",
    "import argilla as rg\n",
    "from argilla.metrics.text_classification import f1\n",
    "from argilla.training import ArgillaTrainer\n",
    "import joblib\n",
    "import random\n",
    "from typing import Dict\n",
    "from plotly.subplots import make_subplots\n",
    "import plotly.colors\n",
    "import warnings\n",
    "import logging\n",
    "\n",
    "# The notebook is served as a Voila demo: keep library warnings and logs out of the UI.\n",
    "warnings.filterwarnings('ignore')\n",
    "logging.getLogger().setLevel(logging.CRITICAL)\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "96b83a1d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initialisation: few-shot TARS classifier and Argilla client\n",
    "tars_model_path = 'few-shot-model-1'\n",
    "# load() is called on the class itself so that no throwaway model is built first.\n",
    "tars = TARSClassifier.load(os.path.join(tars_model_path, 'best-model.pt'))\n",
    "\n",
    "# Argilla connection settings are read from the environment (no secret in the notebook).\n",
    "rg.init(\n",
    "    api_url=os.environ['ARGILLA_API_URL'],\n",
    "    api_key=os.environ['ARGILLA_API_KEY']\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "72c2c6f9",
   "metadata": {
    "hide_input": false
   },
   "outputs": [],
   "source": [
    "# Text preprocessing\n",
    "try:\n",
    "    nlp = spacy.load('en_core_web_sm')\n",
    "except OSError:\n",
    "    # The spaCy model is not installed yet: download it once, then load it.\n",
    "    spacy.cli.download('en_core_web_sm')\n",
    "    nlp = spacy.load('en_core_web_sm')\n",
    "sw_lst = text.ENGLISH_STOP_WORDS\n",
    "# A kept token may only contain lower-case ASCII letters and spaces.\n",
    "WORD_PATTERN = re.compile(r'[a-z ]*')\n",
    "\n",
    "\n",
    "def preprocess(onto_lst):\n",
    "    \"\"\"Normalise raw phrases before they are sent to the TARS classifier.\n",
    "\n",
    "    Every phrase is lemmatised and lower-cased. Stop words (spaCy and\n",
    "    scikit-learn lists), punctuation, number-like tokens, whitespace tokens,\n",
    "    tokens tagged as PERSON entities, single-character lemmas and lemmas\n",
    "    containing anything other than a-z or space are dropped.\n",
    "\n",
    "    Args:\n",
    "        onto_lst: iterable of raw phrases (strings).\n",
    "\n",
    "    Returns:\n",
    "        A list with one cleaned, space-joined string per input phrase, in the\n",
    "        same order. A phrase can end up as an empty string.\n",
    "    \"\"\"\n",
    "    cleaned_onto_lst = []\n",
    "    for doc in nlp.pipe(onto_lst):\n",
    "        person_lemmas = {w.lemma_ for w in doc if w.ent_type_ == 'PERSON'}\n",
    "        lemmas = [\n",
    "            w.lemma_.lower()\n",
    "            for w in doc\n",
    "            if not (w.is_stop or w.is_punct or w.like_num)\n",
    "            and w.text.strip()\n",
    "            and w.lemma_ not in person_lemmas\n",
    "        ]\n",
    "        kept = [\n",
    "            lemma for lemma in lemmas\n",
    "            if len(lemma) > 1 and WORD_PATTERN.fullmatch(lemma) and lemma not in sw_lst\n",
    "        ]\n",
    "        cleaned_onto_lst.append(' '.join(kept))\n",
    "    return cleaned_onto_lst"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "40070313",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sentence extraction\n",
    "SENTENCE_DELIMITERS = re.compile(r'[.,;!?]')\n",
    "\n",
    "\n",
    "def extract_sentences(query):\n",
    "    \"\"\"Split an answer into phrases on the characters . , ; ! and ?\n",
    "\n",
    "    Args:\n",
    "        query: the full answer text; None or an empty string yields an empty list.\n",
    "\n",
    "    Returns:\n",
    "        The stripped, non-empty fragments in their original order.\n",
    "    \"\"\"\n",
    "    if not query:\n",
    "        return []\n",
    "    fragments = (part.strip() for part in SENTENCE_DELIMITERS.split(query))\n",
    "    return [fragment for fragment in fragments if fragment]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1550437b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Query the model and get the predicted topic\n",
    "\n",
    "p_classes = {'speak_to_family': 0,\n",
    "             'stakeholder_interview': 1,\n",
    "             'case_file_review': 2,\n",
    "             'observation': 3,\n",
    "             'no_information_collection': 4}\n",
    "# Values used when the model gives no usable answer for a phrase.\n",
    "FALLBACK_CLASS = p_classes['no_information_collection']\n",
    "FALLBACK_SCORE = 0.75\n",
    "\n",
    "\n",
    "def get_topic_predictions(sentences):\n",
    "    \"\"\"Classify each phrase once and return (class_index, score) pairs.\n",
    "\n",
    "    The model is run a single time per phrase, so the class and the score\n",
    "    always come from the same prediction.\n",
    "\n",
    "    Args:\n",
    "        sentences: iterable of (preprocessed) phrases.\n",
    "\n",
    "    Returns:\n",
    "        One (class_index, score) tuple per phrase, in input order. The class\n",
    "        falls back to FALLBACK_CLASS and the score to FALLBACK_SCORE when the\n",
    "        corresponding value cannot be read from the prediction.\n",
    "    \"\"\"\n",
    "    results = []\n",
    "    for phrase in sentences:\n",
    "        if not phrase or not phrase.strip():\n",
    "            # Nothing survived preprocessing: there is nothing to classify.\n",
    "            results.append((FALLBACK_CLASS, FALLBACK_SCORE))\n",
    "            continue\n",
    "        sentence = Sentence(phrase)\n",
    "        tars.predict(sentence)\n",
    "        try:\n",
    "            topic_ind = p_classes[sentence.tag]\n",
    "        except Exception:\n",
    "            # No label could be read, or the label is not one of p_classes.\n",
    "            topic_ind = FALLBACK_CLASS\n",
    "        try:\n",
    "            score = sentence.score\n",
    "        except Exception:\n",
    "            score = FALLBACK_SCORE\n",
    "        results.append((topic_ind, score))\n",
    "    return results\n",
    "\n",
    "\n",
    "def get_topic(sentences):\n",
    "    \"\"\"Return the predicted class index (a value of p_classes) for each phrase.\"\"\"\n",
    "    return [topic_ind for topic_ind, _ in get_topic_predictions(sentences)]\n",
    "\n",
    "\n",
    "def get_topic_scores(sentences):\n",
    "    \"\"\"Return the prediction score for each phrase (FALLBACK_SCORE when unavailable).\"\"\"\n",
    "    return [score for _, score in get_topic_predictions(sentences)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7b73a6ab",
   "metadata": {},
   "outputs": [],
   "source": [
    "def convert_df(result_df):\n",
    "    \"\"\"Reshape scoring results into the text/prediction layout logged to Argilla.\n",
    "\n",
    "    Args:\n",
    "        result_df: frame with 'Phrase', 'Topic' and 'Score' columns.\n",
    "\n",
    "    Returns:\n",
    "        A new frame with a 'text' column (the phrase) and a 'prediction'\n",
    "        column holding [[topic, score]] for each row. An empty input gives an\n",
    "        empty frame with the same two columns.\n",
    "    \"\"\"\n",
    "    predictions = [\n",
    "        [[topic, score]]\n",
    "        for topic, score in zip(result_df['Topic'], result_df['Score'])\n",
    "    ]\n",
    "    return pd.DataFrame({'text': result_df['Phrase'].tolist(), 'prediction': predictions})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "990acf9f",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_trainer_preds(trainer, records):\n",
    "    \"\"\"Turn predicted Argilla records into a Phrase/Topic/Score frame.\n",
    "\n",
    "    Args:\n",
    "        trainer: unused; kept so that existing calls keep working.\n",
    "        records: records exposing .text and .prediction, where prediction is a\n",
    "            sequence of (label, score) pairs.\n",
    "\n",
    "    Returns:\n",
    "        A frame with one row per record holding its text, its best label and\n",
    "        that label's score.\n",
    "    \"\"\"\n",
    "    rows = []\n",
    "    for record in records:\n",
    "        # Take the highest-scoring pair: the pairs are not guaranteed to be sorted.\n",
    "        topic, score = max(record.prediction, key=lambda pair: pair[1])\n",
    "        rows.append({'Phrase': record.text, 'Topic': topic, 'Score': score})\n",
    "    return pd.DataFrame(rows, columns=['Phrase', 'Topic', 'Score'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "973f46c4",
   "metadata": {},
   "outputs": [],
   "source": [
    "def custom_f1(data: Dict[str, float], title: str):\n",
    "    \"\"\"Plot F1 metrics as two stacked bar charts.\n",
    "\n",
    "    Args:\n",
    "        data: metric name -> value. Entries whose name contains 'macro' feed\n",
    "            the overall chart; entries whose name contains none of 'macro',\n",
    "            'micro' and 'support' feed the per-category chart.\n",
    "        title: figure title.\n",
    "\n",
    "    Returns:\n",
    "        The plotly figure (not displayed here).\n",
    "    \"\"\"\n",
    "    fig = make_subplots(\n",
    "        rows=2,\n",
    "        cols=1,\n",
    "        subplot_titles=['Overall Model Score', 'Model Score By Category'],\n",
    "    )\n",
    "\n",
    "    # Axis labels are derived from the keys so that they always match the values,\n",
    "    # whatever order the metrics arrive in.\n",
    "    macro = {k: v for k, v in data.items() if 'macro' in k}\n",
    "    fig.add_bar(\n",
    "        x=[k.replace('_macro', '').replace('macro_', '') for k in macro],\n",
    "        y=list(macro.values()),\n",
    "        row=1,\n",
    "        col=1,\n",
    "    )\n",
    "\n",
    "    per_label = {\n",
    "        k: v\n",
    "        for k, v in data.items()\n",
    "        if not any(tag in k for tag in ('macro', 'micro', 'support'))\n",
    "    }\n",
    "\n",
    "    # One colour per category, cycling through the palette deterministically.\n",
    "    # NOTE(review): assumes per-category keys look like <label>_<metric> - confirm\n",
    "    # against the metric output.\n",
    "    palette = [str(c) for c in plotly.colors.qualitative.Plotly]\n",
    "    label_colors = {}\n",
    "    bar_colors = []\n",
    "    for key in per_label:\n",
    "        label = key.rsplit('_', 1)[0]\n",
    "        if label not in label_colors:\n",
    "            label_colors[label] = palette[len(label_colors) % len(palette)]\n",
    "        bar_colors.append(label_colors[label])\n",
    "\n",
    "    fig.add_bar(\n",
    "        x=list(per_label.keys()),\n",
    "        y=list(per_label.values()),\n",
    "        row=2,\n",
    "        col=1,\n",
    "        marker_color=bar_colors\n",
    "    )\n",
    "    fig.update_layout(showlegend=False, title_text=title)\n",
    "\n",
    "    return fig"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "02fda761",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Output formatting\n",
    "ind_topic_dict = {\n",
    "    0: 'SPEAK-TO-FAMILY',\n",
    "    1: 'STAKEHOLDER-INTERVIEW',\n",
    "    2: 'FILE-REVIEW',\n",
    "    3: 'OBSERVATION/COMMUNICATION',\n",
    "    4: 'NO-COLLECTED-INFO'\n",
    "}\n",
    "\n",
    "topic_color_dict = {\n",
    "    'SPEAK-TO-FAMILY': '#FFCCCC',\n",
    "    'STAKEHOLDER-INTERVIEW': '#CCFFFF',\n",
    "    'FILE-REVIEW': '#FF69B4',\n",
    "    'OBSERVATION/COMMUNICATION': '#FFFF00',\n",
    "    'NO-COLLECTED-INFO': '#ECECEC'\n",
    "}\n",
    "\n",
    "# Minimum score for a phrase to count as evidence of an approach.\n",
    "passing_score = 0.25\n",
    "\n",
    "\n",
    "def color(df, color):\n",
    "    \"\"\"Return a Styler showing 'Score' as a percentage with an in-cell bar of the given colour.\"\"\"\n",
    "    return df.style.format({'Score': '{:,.2%}'.format}).bar(subset=['Score'], color=color)\n",
    "\n",
    "\n",
    "def annotate_query(highlights, query, topics):\n",
    "    \"\"\"Locate highlighted phrases in the answer for displaCy's manual entity mode.\n",
    "\n",
    "    Phrases are searched literally (never as regular expressions). Each search\n",
    "    starts after the previous match so that a repeated phrase maps to its next\n",
    "    occurrence; if nothing is found there, the whole answer is searched.\n",
    "\n",
    "    Args:\n",
    "        highlights: phrases to locate.\n",
    "        query: the full answer text.\n",
    "        topics: one label per phrase, aligned with highlights.\n",
    "\n",
    "    Returns:\n",
    "        A list of {'start', 'end', 'label'} dicts sorted by start offset.\n",
    "        Phrases that do not occur in the answer are skipped.\n",
    "    \"\"\"\n",
    "    ents = []\n",
    "    cursor = 0\n",
    "    for phrase, topic in zip(highlights, topics):\n",
    "        if not phrase:\n",
    "            continue\n",
    "        start = query.find(phrase, cursor)\n",
    "        if start == -1:\n",
    "            start = query.find(phrase)\n",
    "        if start == -1:\n",
    "            continue\n",
    "        end = start + len(phrase)\n",
    "        ents.append({'start': start, 'end': end, 'label': topic})\n",
    "        cursor = max(cursor, end)\n",
    "    ents.sort(key=lambda ent: ent['start'])\n",
    "    return ents"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "905eaf2a",
   "metadata": {},
   "outputs": [],
   "source": [
    "def path_to_image_html(path):\n",
    "    \"\"\"Render an image path as an img tag (used as a to_html formatter).\"\"\"\n",
    "    # NOTE(review): the markup had been lost and the function returned an empty\n",
    "    # string; the displayed width is a guess - adjust to the design.\n",
    "    return f'<img src=\"{escape(path, quote=True)}\" width=\"40\">'\n",
    "\n",
    "\n",
    "# An approach counts as used when its final score is strictly above this value.\n",
    "final_passing = 0.0\n",
    "\n",
    "\n",
    "def display_final_df(agg_df):\n",
    "    \"\"\"Display a tick/cross table telling which approaches the answer used.\n",
    "\n",
    "    Args:\n",
    "        agg_df: frame with a 'Final_Score' column; rows whose index label is\n",
    "            an approach name are looked up, every other row is ignored.\n",
    "\n",
    "    Approaches found in the index are listed first (tick when their score is\n",
    "    above final_passing, cross otherwise), followed by the approaches that are\n",
    "    absent from the index (always a cross).\n",
    "    \"\"\"\n",
    "    crits = [\n",
    "        'SPEAK-TO-FAMILY',\n",
    "        'STAKEHOLDER-INTERVIEW',\n",
    "        'FILE-REVIEW',\n",
    "        'OBSERVATION/COMMUNICATION'\n",
    "    ]\n",
    "    present = set(agg_df.index.tolist())\n",
    "    found_crits = [crit for crit in crits if crit in present]\n",
    "    missing_crits = [crit for crit in crits if crit not in present]\n",
    "    paths = [\n",
    "        './tick_green.png' if agg_df.loc[crit, 'Final_Score'] > final_passing else './cross_red.png'\n",
    "        for crit in found_crits\n",
    "    ]\n",
    "    paths += ['./cross_red.png'] * len(missing_crits)\n",
    "    df = pd.DataFrame({'Approach': found_crits + missing_crits, 'USED': paths}).set_index('Approach')\n",
    "    # Lift the column-width limit only while rendering so the image tags are not\n",
    "    # truncated and the global pandas options stay untouched.\n",
    "    with pd.option_context('display.max_colwidth', None):\n",
    "        table_html = df.to_html(classes=['align-center'], index=True, escape=False,\n",
    "                                formatters=dict(USED=path_to_image_html))\n",
    "    display(HTML('<div style=\"display: flex; justify-content: center;\">' + table_html + '</div>'))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2c6e9fe7",
   "metadata": {},
   "source": [
    "### Outline the behavioural assessment approaches implemented to develop this PBSP "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c3a1f0e2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Demo with Voila: widgets\n",
    "\n",
    "def make_button(description):\n",
    "    \"\"\"Build one of the large green action buttons; the tooltip repeats the caption.\"\"\"\n",
    "    return widgets.Button(\n",
    "        description=description,\n",
    "        disabled=False,\n",
    "        button_style='success',  # 'success', 'info', 'warning', 'danger' or ''\n",
    "        tooltip=description,\n",
    "        icon='check',\n",
    "        layout={'height': '70px', 'width': '250px'}\n",
    "    )\n",
    "\n",
    "\n",
    "bhvr_label = widgets.Label(value='Please type your answer:')\n",
    "bhvr_text_input = widgets.Textarea(\n",
    "    value='',\n",
    "    placeholder='Type your answer',\n",
    "    description='',\n",
    "    disabled=False,\n",
    "    layout={'height': '300px', 'width': '90%'}\n",
    ")\n",
    "\n",
    "bhvr_nlp_btn = make_button('Score Answer')\n",
    "bhvr_agr_btn = make_button('Validate Data')\n",
    "bhvr_eval_btn = make_button('Evaluate Model')\n",
    "bhvr_trn_btn = make_button('Re-train Model')\n",
    "btn_box = widgets.HBox([bhvr_nlp_btn, bhvr_agr_btn, bhvr_eval_btn, bhvr_trn_btn],\n",
    "                       layout={'width': '100%', 'height': '160%'})\n",
    "bhvr_outt = widgets.Output()\n",
    "bhvr_outt.layout.height = '100%'\n",
    "bhvr_outt.layout.width = '100%'\n",
    "bhvr_box = widgets.VBox([bhvr_text_input, btn_box, bhvr_outt],\n",
    "                        layout={'width': '100%', 'height': '160%'})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5d2b9e71",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Demo with Voila: shared state and helpers used by the button handlers\n",
    "dataset_rg_name = 'pbsp-page2-q1-argilla-ds'\n",
    "trainer_rg_name = 'pbsp-page2-q1-argilla-trn'\n",
    "argilla_df = None   # last scored phrases, set by the Score Answer handler\n",
    "dataset_rg = None   # dataset last logged to Argilla, set by the Validate Data handler\n",
    "annotated = False   # True once a dataset has been logged to Argilla\n",
    "trainer = None      # re-trained model, set by the Re-train Model handler\n",
    "\n",
    "\n",
    "def show_heading(title, level=2):\n",
    "    \"\"\"Display a centred heading of the given level in the current output area.\"\"\"\n",
    "    # NOTE(review): the heading markup had been lost; this is a plain reconstruction.\n",
    "    display(HTML(f'<h{level} style=\"text-align: center;\">{escape(title)}</h{level}>'))\n",
    "\n",
    "\n",
    "def score_sentences(sentences):\n",
    "    \"\"\"Classify the phrases of the current answer.\n",
    "\n",
    "    Uses the re-trained model when one exists, otherwise the TARS classifier\n",
    "    applied to the preprocessed phrases.\n",
    "\n",
    "    Returns:\n",
    "        A frame with 'Phrase', 'Topic' and 'Score' columns.\n",
    "    \"\"\"\n",
    "    if trainer is not None:\n",
    "        # Predict on the phrases just typed, not on a previously logged dataset.\n",
    "        records = trainer.predict(sentences, as_argilla_records=True)\n",
    "        return get_trainer_preds(trainer, records)\n",
    "    predictions = get_topic_predictions(preprocess(sentences))\n",
    "    return pd.DataFrame({\n",
    "        'Phrase': sentences,\n",
    "        'Topic': [ind_topic_dict[topic_ind] for topic_ind, _ in predictions],\n",
    "        'Score': [score for _, score in predictions]\n",
    "    })\n",
    "\n",
    "\n",
    "def build_coverage(passed_df):\n",
    "    \"\"\"Aggregate passing phrases into per-topic totals and percentage shares.\n",
    "\n",
    "    Returns:\n",
    "        A frame with 'Total Score', 'Final_Score' (share of the total, in\n",
    "        percent) and 'Topic' columns, sorted by decreasing share. Topics that\n",
    "        do not occur in passed_df are appended with zero scores.\n",
    "    \"\"\"\n",
    "    agg_df = passed_df.groupby('Topic')['Score'].sum().to_frame()\n",
    "    agg_df.index.name = 'Topic'\n",
    "    agg_df.columns = ['Total Score']\n",
    "    agg_df = agg_df.assign(\n",
    "        Final_Score=lambda x: x['Total Score'] / x['Total Score'].sum() * 100.00\n",
    "    )\n",
    "    agg_df = agg_df.sort_values(by='Final_Score', ascending=False)\n",
    "    agg_df['Topic'] = agg_df.index\n",
    "    seen_topics = set(agg_df['Topic'].tolist())\n",
    "    rem_topics = [topic for topic in topic_color_dict if topic not in seen_topics]\n",
    "    if rem_topics:\n",
    "        rem_agg_df = pd.DataFrame({'Topic': rem_topics, 'Final_Score': 0.0, 'Total Score': 0.0})\n",
    "        agg_df = pd.concat([agg_df, rem_agg_df])\n",
    "    return agg_df\n",
    "\n",
    "\n",
    "def plot_coverage(agg_df):\n",
    "    \"\"\"Draw the per-topic share as a labelled bar chart.\"\"\"\n",
    "    labels = agg_df['Final_Score'].round(1).astype('str') + '%'\n",
    "    ax = agg_df.plot.bar(x='Topic', y='Final_Score', rot=0, figsize=(20, 5), align='center')\n",
    "    for container in ax.containers:\n",
    "        ax.bar_label(container, labels=labels)\n",
    "    ax.yaxis.set_major_formatter(mtick.PercentFormatter())\n",
    "    ax.legend(['Final Score (%)'])\n",
    "    ax.set_xlabel('')\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e84f6a03",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Demo with Voila: button handlers\n",
    "\n",
    "def on_bhvr_button_next(b):\n",
    "    \"\"\"Score Answer: classify the typed answer and render the scoring report.\"\"\"\n",
    "    global argilla_df\n",
    "    with bhvr_outt:\n",
    "        clear_output()\n",
    "        query = bhvr_text_input.value\n",
    "        sentences = extract_sentences(query)\n",
    "        if not sentences:\n",
    "            show_heading('Please type your answer first!', level=3)\n",
    "            return\n",
    "        result_df = score_sentences(sentences)\n",
    "        is_no_info = result_df['Topic'] == 'NO-COLLECTED-INFO'\n",
    "        sub_result_df = result_df[(result_df['Score'] >= passing_score) & ~is_no_info]\n",
    "        sub_2_result_df = result_df[is_no_info]\n",
    "        if len(sub_result_df) == 0:\n",
    "            # Nothing passed the threshold: echo the answer unchanged.\n",
    "            print(query)\n",
    "            return\n",
    "\n",
    "        highlights = sub_result_df['Phrase'].tolist()\n",
    "        highlight_topics = sub_result_df['Topic'].tolist()\n",
    "        ents = annotate_query(highlights, query, highlight_topics)\n",
    "        # Colours are taken from the entities themselves so that a phrase that\n",
    "        # could not be located never shifts the label/colour pairing.\n",
    "        colors = {ent['label']: topic_color_dict[ent['label']] for ent in ents}\n",
    "\n",
    "        show_heading('Highlighting Used Approaches')\n",
    "        ex = [{'text': query, 'ents': ents, 'title': None}]\n",
    "        # jupyter=False makes render() return the markup instead of displaying it\n",
    "        # itself and returning None.\n",
    "        html = displacy.render(ex, style='ent', manual=True, jupyter=False, options={'colors': colors})\n",
    "        display(HTML(html))\n",
    "\n",
    "        show_heading('Used Approach Classifications')\n",
    "        for top, top_color in topic_color_dict.items():\n",
    "            top_result_df = sub_result_df[sub_result_df['Topic'] == top]\n",
    "            if len(top_result_df) > 0:\n",
    "                top_result_df = (\n",
    "                    top_result_df\n",
    "                    .sort_values(by='Score', ascending=False)\n",
    "                    # The Styler needs a unique index: keep the best row per phrase.\n",
    "                    .drop_duplicates(subset='Phrase')\n",
    "                    .set_index('Phrase')[['Score']]\n",
    "                )\n",
    "                show_heading(top, level=3)\n",
    "                display(color(top_result_df, top_color))\n",
    "\n",
    "        agg_df = build_coverage(sub_result_df)\n",
    "        show_heading('Approach Coverage')\n",
    "        plot_coverage(agg_df)\n",
    "        show_heading('Final Approach Scores')\n",
    "        display_final_df(agg_df)\n",
    "\n",
    "        # Keep the phrases without collected information too so that they can be\n",
    "        # reviewed in Argilla.\n",
    "        if len(sub_2_result_df) > 0:\n",
    "            sub_result_df = pd.concat([sub_result_df, sub_2_result_df]).reset_index(drop=True)\n",
    "        argilla_df = sub_result_df.copy()\n",
    "\n",
    "\n",
    "def on_agr_button_next(b):\n",
    "    \"\"\"Validate Data: push the last scored phrases to Argilla for annotation.\"\"\"\n",
    "    global annotated, dataset_rg\n",
    "    with bhvr_outt:\n",
    "        clear_output()\n",
    "        if argilla_df is None:\n",
    "            show_heading('Please score the answer first!', level=3)\n",
    "            return\n",
    "        # convert the dataframe to the structure accepted by argilla\n",
    "        converted_df = convert_df(argilla_df)\n",
    "        # convert pandas dataframe to DatasetForTextClassification\n",
    "        dataset_rg = rg.DatasetForTextClassification.from_pandas(converted_df)\n",
    "        # delete the old DatasetForTextClassification from the Argilla web app if exists\n",
    "        rg.delete(dataset_rg_name, workspace='admin')\n",
    "        # load the new DatasetForTextClassification into the Argilla web app\n",
    "        rg.log(dataset_rg, name=dataset_rg_name, workspace='admin')\n",
    "        # Make sure all classes are present for annotation\n",
    "        rg_settings = rg.TextClassificationSettings(label_schema=list(topic_color_dict.keys()))\n",
    "        rg.configure_dataset(name=dataset_rg_name, workspace='admin', settings=rg_settings)\n",
    "        annotated = True\n",
    "\n",
    "\n",
    "def on_eval_button_next(b):\n",
    "    \"\"\"Evaluate Model: plot the F1 metrics computed on the Argilla dataset.\"\"\"\n",
    "    with bhvr_outt:\n",
    "        clear_output()\n",
    "        if not annotated:\n",
    "            show_heading('Please score the answer and validate the data first!', level=3)\n",
    "            return\n",
    "        data = dict(f1(dataset_rg_name))['data']\n",
    "        display(custom_f1(data, 'Model Evaluation Results'))\n",
    "\n",
    "\n",
    "def on_trn_button_next(b):\n",
    "    \"\"\"Re-train Model: fit a SetFit model on the annotated Argilla dataset.\"\"\"\n",
    "    global trainer\n",
    "    with bhvr_outt:\n",
    "        clear_output()\n",
    "        if not annotated:\n",
    "            show_heading('Please score the answer and validate the data first!', level=3)\n",
    "            return\n",
    "        trainer = ArgillaTrainer(\n",
    "            name=dataset_rg_name,\n",
    "            workspace='admin',\n",
    "            framework='setfit',\n",
    "            train_size=1.0\n",
    "        )\n",
    "        trainer.update_config(\n",
    "            pretrained_model_name_or_path='all-mpnet-base-v2',\n",
    "            force_download=False,\n",
    "            resume_download=False,\n",
    "            proxies=None,\n",
    "            token=None,\n",
    "            cache_dir=None,\n",
    "            local_files_only=False,\n",
    "            num_iterations=10\n",
    "        )\n",
    "        trainer.train(output_dir=trainer_rg_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "76dd8cab",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Demo with Voila: wire the buttons and show the form\n",
    "bhvr_nlp_btn.on_click(on_bhvr_button_next)\n",
    "bhvr_agr_btn.on_click(on_agr_button_next)\n",
    "bhvr_eval_btn.on_click(on_eval_button_next)\n",
    "bhvr_trn_btn.on_click(on_trn_button_next)\n",
    "\n",
    "display(bhvr_label, bhvr_box)"
   ]
  }
 ],
 "metadata": {
  "hide_input": false,
  "kernelspec": {
   "display_name": "Python 3.9 (Argilla Trainer)",
   "language": "python",
   "name": "argilla_trainer"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.16"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": false,
   "sideBar": true,
   "skip_h1_title": true,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {
    "height": "calc(100% - 180px)",
    "left": "10px",
    "top": "150px",
    "width": "258.097px"
   },
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}