{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "f56cc5ad",
   "metadata": {},
   "source": [
    "# NDIS Project - OpenAI - PBSP Scoring - Page 2 - Direct / Indirect Data Collection"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a8d844ea",
   "metadata": {
    "hide_input": false
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import openai\n",
    "import re\n",
    "import string\n",
    "from ipywidgets import interact\n",
    "import ipywidgets as widgets\n",
    "from IPython.display import display, clear_output, Javascript, HTML, Markdown\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.ticker as mtick\n",
    "import json\n",
    "import spacy\n",
    "from spacy import displacy\n",
    "from dotenv import load_dotenv\n",
    "import pandas as pd\n",
    "import argilla as rg\n",
    "from argilla.metrics.text_classification import f1\n",
    "import warnings\n",
    "warnings.filterwarnings('ignore')\n",
    "%matplotlib inline\n",
    "pd.set_option('display.max_rows', 500)\n",
    "pd.set_option('display.max_colwidth', 10000)\n",
    "pd.set_option('display.width', 10000)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "96b83a1d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# initializations\n",
    "# Load variables from a local .env file (if present) so the os.environ lookups below work\n",
    "# outside an environment where they were exported beforehand.\n",
    "load_dotenv()\n",
    "\n",
    "# OpenAI client configuration (module-level settings read by openai.ChatCompletion)\n",
    "openai.api_key = os.environ['API_KEY']\n",
    "openai.api_base = os.environ['API_BASE']\n",
    "openai.api_type = os.environ['API_TYPE']\n",
    "openai.api_version = os.environ['API_VERSION']\n",
    "deployment_name = os.environ['DEPLOYMENT_ID']\n",
    "\n",
    "# argilla\n",
    "rg.init(\n",
    "    api_url=os.environ[\"ARGILLA_API_URL\"],\n",
    "    api_key=os.environ[\"ARGILLA_API_KEY\"]\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dee25d82",
   "metadata": {},
   "outputs": [],
   "source": [
    "# sentence extraction\n",
    "def extract_sentences(paragraph):\n",
    "    \"\"\"Split `paragraph` into non-empty, whitespace-trimmed fragments.\n",
    "\n",
    "    Fragments are separated by any of: . ! ? ; : , _ - or a newline.\n",
    "    Used as the fallback segmentation when a model response cannot be parsed.\n",
    "    \"\"\"\n",
    "    symbols = ['\\\\.', '!', '\\\\?', ';', ':', ',', '\\\\_', '\\n', '\\\\-']\n",
    "    pattern = '|'.join(symbols)\n",
    "    return [fragment.strip() for fragment in re.split(pattern, paragraph) if fragment.strip()]"
   ]
  },
{
   "cell_type": "code",
   "execution_count": null,
   "id": "08365143",
   "metadata": {},
   "outputs": [],
   "source": [
    "def filter_dataframe(result_df, paragraph):\n",
    "    \"\"\"Keep only the model phrases that actually occur in `paragraph`.\n",
    "\n",
    "    Matching is case-insensitive and ignores ASCII punctuation as well as curly\n",
    "    quotes/apostrophes, so \"person’s\" in the model output still matches\n",
    "    \"person's\" in the paragraph. Phrases that are empty after this normalisation\n",
    "    are rejected: an empty string is trivially a substring of anything and would\n",
    "    also divide by zero in the character-overlap check. A phrase must additionally\n",
    "    share at least 20% of its distinct characters with the paragraph.\n",
    "\n",
    "    If no phrase survives, the unfiltered `result_df` is used instead.\n",
    "    Duplicate rows are removed in both cases.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    result_df : pd.DataFrame\n",
    "        Parsed model output; must contain a string 'Phrase' column.\n",
    "    paragraph : str\n",
    "        The text the model was asked to classify.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    pd.DataFrame\n",
    "        A new frame; `result_df` itself is not modified.\n",
    "    \"\"\"\n",
    "    strip_punct = str.maketrans(\"\", \"\", string.punctuation + \"‘’“”\")\n",
    "    norm_paragraph = paragraph.lower().translate(strip_punct)\n",
    "    paragraph_chars = set(paragraph.lower())\n",
    "\n",
    "    def is_match(phrase):\n",
    "        norm_phrase = phrase.lower().translate(strip_punct)\n",
    "        if not norm_phrase.strip() or norm_phrase not in norm_paragraph:\n",
    "            return False\n",
    "        phrase_chars = set(phrase.lower())\n",
    "        # fraction of the phrase's distinct characters that also appear in the paragraph\n",
    "        return len(phrase_chars & paragraph_chars) / len(phrase_chars) >= 0.2\n",
    "\n",
    "    mask = result_df['Phrase'].map(is_match).astype(bool)\n",
    "    filtered_df = result_df.loc[mask]\n",
    "    if len(filtered_df) == 0:\n",
    "        filtered_df = result_df\n",
    "    return filtered_df.drop_duplicates()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "02fda761",
   "metadata": {},
   "outputs": [],
   "source": [
    "def process_response(response, query):\n",
    "    \"\"\"Parse the model's numbered-list answer into a phrase/topic/score table.\n",
    "\n",
    "    Lines containing \"Direct data collection:\", \"Indirect data collection:\" or\n",
    "    \"None:\" set the topic ('DIRECT', 'INDIRECT', 'NONE') for the lines that\n",
    "    follow. Every other line is kept only if it has the form\n",
    "    \"<n>. <phrase> (Confidence Score: <float>)\".\n",
    "\n",
    "    The parsed phrases are then checked against `query` with `filter_dataframe`.\n",
    "    If nothing could be parsed, or post-processing fails, the query is split with\n",
    "    `extract_sentences` and every fragment is labelled 'NONE' with score 0.9.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    pd.DataFrame\n",
    "        Columns 'Phrase', 'Topic' and 'Score'.\n",
    "    \"\"\"\n",
    "    rows = []\n",
    "    topic = None\n",
    "    for line in response.strip().split(\"\\n\"):\n",
    "        if \"Direct data collection:\" in line:\n",
    "            topic = \"DIRECT\"\n",
    "        elif \"Indirect data collection:\" in line:\n",
    "            topic = \"INDIRECT\"\n",
    "        elif \"None:\" in line:\n",
    "            topic = \"NONE\"\n",
    "        else:\n",
    "            parts = line.split(\"(Confidence Score:\")\n",
    "            if len(parts) != 2:\n",
    "                continue\n",
    "            try:\n",
    "                score = float(parts[1].strip().replace(\")\", \"\"))\n",
    "            except ValueError:\n",
    "                # e.g. the empty template line \"1. . (Confidence Score: )\"\n",
    "                continue\n",
    "            rows.append({'Phrase': parts[0].strip(), 'Topic': topic, 'Score': score})\n",
    "\n",
    "    try:\n",
    "        if not rows:\n",
    "            raise ValueError('no scored phrases found in the model response')\n",
    "        result_df = pd.DataFrame(rows, columns=['Phrase', 'Topic', 'Score'])\n",
    "        # Remove only the leading list number (\"1. \"); digits followed by a dot\n",
    "        # inside the phrase itself (e.g. \"2.5 hours\") must be preserved.\n",
    "        result_df['Phrase'] = (\n",
    "            result_df['Phrase']\n",
    "            .str.replace(r'^\\s*\\d+\\.\\s*', '', regex=True)\n",
    "            .str.strip()\n",
    "            .str.strip('\"')\n",
    "        )\n",
    "        result_df = filter_dataframe(result_df, query)\n",
    "    except Exception:\n",
    "        # Best-effort fallback so the UI always has something to display/annotate.\n",
    "        sentences = extract_sentences(query)\n",
    "        result_df = pd.DataFrame({\n",
    "            'Phrase': sentences,\n",
    "            'Topic': ['NONE'] * len(sentences),\n",
    "            'Score': [0.9] * len(sentences),\n",
    "        })\n",
    "    return result_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "714fafb4",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_prompt(query):\n",
    "    \"\"\"Build the LLM prompt asking for Direct / Indirect data collection phrases in `query`.\n",
    "\n",
    "    The prompt requests numbered lists under the headers \"Direct data collection:\",\n",
    "    \"Indirect data collection:\" and \"None:\", each item ending with\n",
    "    \"(Confidence Score: <score>)\" - the format parsed by `process_response`.\n",
    "    \"\"\"\n",
    "    prompt = f\"\"\"\n",
    "    The paragraph below is written by a disability practitioner to describe the direct and/or indirect data collection approaches which has been undertaken to prepare the positive behaviour support plan. \n",
    "\n",
    "    Paragraph:\n",
    "    {query}\n",
    "\n",
    "    Requirement:\n",
    "    - Identify the phrases from the paragraph above that represent each of the following data collection categories: \"Direct data collection\", \"Indirect data collection\".\n",
    "\n",
    "    Guidelines:\n",
    "    - \"Direct data collection\": To detect any phrases from the paragraph that might represent direct data collection approaches, look for phrases that mention direct observation of the person with disability by either the practitioner or a relevant stakeholder (e.g., support worker). Example phrases include: \"we conducted direct observations of the person with disability\", \"we implemented a functional assessment that involved direct observation by both myself and the support worker\". Also look for phrases that mention completion of behavioural data collection tools like ABC note cards and scatter plots, completed by the practitioner or relevant stakeholder. Example phrases include: \"I conducted direct observation using ABC note cards to capture antecedents, behaviors, and consequences in the individual's natural environment.\", \"I completed scatter plots to visually represent the frequency, duration, and intensity of the targeted behaviors over time.\".\n",
    "\n",
    "    - \"Indirect data collection\": To detect any phrases from the paragraph that might represent indirect data collection approaches, look for phrases that mention the use of standardised tools completed by the practitioner in consultation with relevant stakeholders who know the person of focus well, such as Contextual Assessment Inventory, Functional Assessment Interview. Example phrases include: \"I conducted a Contextual Assessment Inventory with input from relevant stakeholders\", \"I utilized the Functional Assessment Interview to gather information from individuals who know the person best\". Also look for phrases that mention interviews, phone calls or any other form of communication with relevant stakeholders who know the person of focus well. Example phrases include: \"We conducted several phone interviews with family members who have known the person for several years\", \"We spoke with the person's support workers to gain insight into their daily routines and any challenges they may be experiencing\". Also look for phrases that mention consultation of relevant reports (e.g., previous positive behaviour support plans, incident report, previous assessment reports from health and allied health professionals). Example phrases include: \"We have consulted relevant reports, such as previous positive behavior support plans, incident reports\", \"Our team has reviewed previous assessment reports from health and allied health professionals to gather information about any underlying medical or psychological conditions that may be influencing the individual's behavior.\"\n",
    "\n",
    "    Specifications of a correct answer:\n",
    "    - Please provide a response that closely matches the information in the paragraph and does not deviate significantly from it.\n",
    "    - Provide your answer in numbered lists. \n",
    "    - All the phrases in your answer must be exact substrings in the original paragraph, without changing any characters.\n",
    "    - All the upper case and lower case characters in the phrases in your answer must match the upper case and lower case characters in the original paragraph.\n",
    "    - Start numbering the phrases under each data collection category from number 1. \n",
    "    - Start each list of phrases with these titles: \"Direct data collection\", \"Indirect data collection\".\n",
    "    - For each phrase that belongs to any of the above data collection categories, provide a confidence score that ranges between 0.50 and 1.00, where a score of 0.50 means you are very weakly confident that the phrase belongs to that specific data collection category, whereas a score of 1.00 means you are very strongly confident that the phrase belongs to that specific data collection category.\n",
    "    - Never include any phrase in your answer that does not exist in the paragraph above.\n",
    "    - Include a final numbered list titled \"None:\", which include all the remaining phrases from the paragraph above that do not belong to any of the data collection categories above. Provide a confidence score for each of these phrases as well.\n",
    "\n",
    "    Example answer:\n",
    "\n",
    "    Direct data collection:\n",
    "    1. we conducted direct observations of the person with disability. (Confidence Score: 1.00)\n",
    "    2. we implemented a functional assessment that involved direct observation by both myself and the support worker. (Confidence Score: 0.95)\n",
    "    3. I conducted direct observation using ABC note cards to capture antecedents, behaviors, and consequences in the individual's natural environment. (Confidence Score: 1.00)\n",
    "    4. I completed scatter plots to visually represent the frequency, duration, and intensity of the targeted behaviors over time. (Confidence Score: 0.92)\n",
    "    \n",
    "    Indirect data collection:\n",
    "    1. I conducted a Contextual Assessment Inventory with input from relevant stakeholders. (Confidence Score: 1.00)\n",
    "    2. I utilized the Functional Assessment Interview to gather information from individuals who know the person best. (Confidence Score: 1.00)\n",
    "    3. We conducted several phone interviews with family members who have known the person for several years. (Confidence Score: 0.94)\n",
    "    4. We spoke with the person's support workers to gain insight into their daily routines and any challenges they may be experiencing (Confidence Score: 0.92)\n",
    "    5. We have consulted relevant reports, such as previous positive behavior support plans, incident reports. (Confidence Score: 0.89)\n",
    "    6. Our team has reviewed previous assessment reports from health and allied health professionals to gather information about any underlying medical or psychological conditions that may be influencing the individual's behavior. (Confidence Score: 0.87)\n",
    "    \n",
    "    None:\n",
    "    1. . (Confidence Score: )\n",
    "    2. . (Confidence Score: )\n",
    "    \"\"\"\n",
    "    return prompt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9e23821b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_response_chatgpt(prompt):\n",
    "    \"\"\"Send `prompt` to the configured chat deployment and return the reply text.\n",
    "\n",
    "    Uses the module-level `deployment_name` from the initializations cell and\n",
    "    temperature 0.\n",
    "    \"\"\"\n",
    "    response = openai.ChatCompletion.create(\n",
    "        engine=deployment_name,\n",
    "        messages=[\n",
    "            {\"role\": \"system\", \"content\": \"You are a helpful assistant.\"},\n",
    "            {\"role\": \"user\", \"content\": prompt},\n",
    "        ],\n",
    "        temperature=0,\n",
    "    )\n",
    "    return response[\"choices\"][0][\"message\"][\"content\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "983765bc",
   "metadata": {},
   "outputs": [],
   "source": [
    "def convert_df(result_df):\n",
    "    \"\"\"Convert the phrase table into the text/prediction layout logged to Argilla.\n",
    "\n",
    "    Each row becomes `text` = phrase and `prediction` = [[topic, score]].\n",
    "    Works for an empty `result_df` as well (yields an empty frame).\n",
    "    \"\"\"\n",
    "    return pd.DataFrame({\n",
    "        'text': result_df['Phrase'],\n",
    "        'prediction': [[[topic, score]] for topic, score in zip(result_df['Topic'], result_df['Score'])],\n",
    "    })"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bc69cc81",
   "metadata": {},
   "outputs": [],
   "source": [
    "#query = \"\"\"\n",
    "#In preparing the positive behavior support plan, our team utilized a combination of direct and indirect data collection approaches to gain a comprehensive understanding of the person of focus and their behavior patterns. Direct data collection approaches included conducting direct observations of the individual by both the practitioner and support workers, as well as the completion of behavioral data collection tools such as ABC note cards and scatter plots. Indirect data collection methods included the use of standardized tools completed by the practitioner in consultation with relevant stakeholders who know the individual well, such as the Contextual Assessment Inventory and Functional Assessment Interview. We also conducted interviews and engaged in communication with family members, caregivers, and healthcare professionals to gather additional information about the individual's behavior patterns and needs. 
Additionally, we consulted relevant reports, such as previous positive behavior support plans, incident reports, and assessment reports from health and allied health professionals to gain insight into potential triggers and reinforcement patterns that may be contributing to the individual's challenging behaviors. By utilizing both direct and indirect data collection approaches, we were able to develop a comprehensive positive behavior support plan that is tailored to the unique needs of the individual.\n",
    "#\"\"\"\n",
    "#prompt = get_prompt(query)\n",
    "#response = get_response_chatgpt(prompt)\n",
    "#result_df = process_response(response, query)\n",
    "#result_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "905eaf2a",
   "metadata": {},
   "outputs": [],
   "source": [
    "topic_color_dict = {\n",
    "    'DIRECT': '#90EE90',\n",
    "    'INDIRECT': '#F08080',\n",
    "    'NONE': '#CCCCCC'\n",
    "}\n",
    "\n",
    "\n",
    "def color(df, color):\n",
    "    \"\"\"Style `df` with 'Score' formatted as a percentage and drawn as a bar in `color`.\"\"\"\n",
    "    return df.style.format({'Score': '{:,.2%}'.format}).bar(subset=['Score'], color=color)\n",
    "\n",
    "\n",
    "def show_title(title):\n",
    "    \"\"\"Render a centred section heading in the current output area.\"\"\"\n",
    "    display(HTML(f'<h3 style=\"text-align: center;\">{title}</h3>'))\n",
    "\n",
    "\n",
    "def annotate_query(highlights, query, topics):\n",
    "    \"\"\"Locate each highlighted phrase in `query` and return displaCy entity spans.\n",
    "\n",
    "    Matching is case-insensitive. Every character that `re.escape` escapes\n",
    "    (spaces, punctuation, ...) may be omitted or replaced by any run of non-word\n",
    "    characters, to tolerate small spacing/punctuation differences between the\n",
    "    model output and the original text.\n",
    "\n",
    "    Returns a list of {'start', 'end', 'label'} dicts sorted by start offset.\n",
    "    Zero-width matches are skipped and overlapping matches are dropped (the\n",
    "    earliest, then longest, span wins), because the manual entity renderer\n",
    "    consumes the text left to right.\n",
    "    \"\"\"\n",
    "    ents = []\n",
    "    for h, t in zip(highlights, topics):\n",
    "        pattern = re.escape(h)\n",
    "        pattern = re.sub(r'\\\\(.)', r'[\\1\\\\W]*', pattern)  # optional non-alphanumeric characters\n",
    "        for match in re.finditer(pattern, query, re.IGNORECASE):\n",
    "            if match.end() > match.start():\n",
    "                ents.append({\"start\": match.start(), \"end\": match.end(), \"label\": t})\n",
    "    ents.sort(key=lambda ent: (ent['start'], -ent['end']))\n",
    "    non_overlapping = []\n",
    "    last_end = -1\n",
    "    for ent in ents:\n",
    "        if ent['start'] >= last_end:\n",
    "            non_overlapping.append(ent)\n",
    "            last_end = ent['end']\n",
    "    return non_overlapping\n",
    "\n",
    "\n",
    "def path_to_image_html(path):\n",
    "    \"\"\"Return an <img> tag for `path` (formatter for the 'USED' column).\"\"\"\n",
    "    return f'<img src=\"{path}\" width=\"40\">'\n",
    "\n",
    "\n",
    "passing_score = 0.5  # minimum model confidence for a phrase to count towards a category\n",
    "final_passing = 0.0  # a category is shown as USED when its Final_Score (percent) exceeds this\n",
    "\n",
    "\n",
    "def display_final_df(agg_df):\n",
    "    \"\"\"Show a thumbs-up/thumbs-down table of the DIRECT and INDIRECT categories.\n",
    "\n",
    "    A category is 'used' when it appears in `agg_df.index` with a 'Final_Score'\n",
    "    above `final_passing`; categories missing from the index are shown as not used.\n",
    "    \"\"\"\n",
    "    all_crits = ['DIRECT', 'INDIRECT']\n",
    "    present = set(agg_df.index.tolist())\n",
    "    used = [crit in present and agg_df.loc[crit, 'Final_Score'] > final_passing for crit in all_crits]\n",
    "    df = pd.DataFrame({\n",
    "        'Data Collection Category': all_crits,\n",
    "        'USED': ['./thumbs_up.png' if ok else './thumbs_down.png' for ok in used],\n",
    "    }).set_index('Data Collection Category')\n",
    "    # the <img> markup in 'USED' must not be truncated by to_html\n",
    "    pd.set_option('display.max_colwidth', None)\n",
    "    display(HTML(\n",
    "        '<div style=\"text-align: center;\">'\n",
    "        + df.to_html(classes=[\"align-center\"], index=True, escape=False, formatters=dict(USED=path_to_image_html))\n",
    "        + '</div>'\n",
    "    ))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2c6e9fe7",
   "metadata": {},
   "source": [
    "### Quality Markers:\n",
    "#### Q2c. The plan indicates that at least one direct data collection approach has been undertaken.\n",
    "\n",
    "#### Q2d. The plan indicates that at least one indirect data collection approach has been undertaken."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "76dd8cab",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# demo with Voila\n",
    "\n",
    "bhvr_label = widgets.Label(value='Please type your answer:')\n",
    "bhvr_text_input = widgets.Textarea(\n",
    "    value='',\n",
    "    placeholder='Type your answer',\n",
    "    description='',\n",
    "    disabled=False,\n",
    "    layout={'height': '300px', 'width': '90%'}\n",
    ")\n",
    "\n",
    "bhvr_nlp_btn = widgets.Button(\n",
    "    description='Score Answer',\n",
    "    disabled=False,\n",
    "    button_style='success',  # 'success', 'info', 'warning', 'danger' or ''\n",
    "    tooltip='Score Answer',\n",
    "    icon='check',\n",
    "    layout={'height': '70px', 'width': '250px'}\n",
    ")\n",
    "bhvr_agr_btn = widgets.Button(\n",
    "    description='Validate Data',\n",
    "    disabled=False,\n",
    "    button_style='success',  # 'success', 'info', 'warning', 'danger' or ''\n",
    "    tooltip='Validate Data',\n",
    "    icon='check',\n",
    "    layout={'height': '70px', 'width': '250px'}\n",
    ")\n",
    "bhvr_eval_btn = widgets.Button(\n",
    "    description='Evaluate Model',\n",
    "    disabled=False,\n",
    "    button_style='success',  # 'success', 'info', 'warning', 'danger' or ''\n",
    "    tooltip='Evaluate Model',\n",
    "    icon='check',\n",
    "    layout={'height': '70px', 'width': '250px'}\n",
    ")\n",
    "btn_box = widgets.HBox([bhvr_nlp_btn, bhvr_agr_btn, bhvr_eval_btn],\n",
    "                       layout={'width': '100%', 'height': '160%'})\n",
    "bhvr_outt = widgets.Output()\n",
    "bhvr_outt.layout.height = '100%'\n",
    "bhvr_outt.layout.width = '100%'\n",
    "bhvr_box = widgets.VBox([bhvr_text_input, btn_box, bhvr_outt],\n",
    "                        layout={'width': '100%', 'height': '160%'})\n",
    "dataset_rg_name = 'pbsp-page2-direct-indirect-argilla-ds'\n",
    "agrilla_df = None  # phrases of the last scored answer; uploaded by 'Validate Data'\n",
    "annotated = False  # True once a dataset has been logged to Argilla\n",
    "\n",
    "\n",
    "def on_bhvr_button_next(b):\n",
    "    \"\"\"Score the typed answer and render highlights, per-category tables, coverage chart and verdict.\"\"\"\n",
    "    global agrilla_df\n",
    "    with bhvr_outt:\n",
    "        clear_output()\n",
    "        # Forget the previous answer so 'Validate Data' never uploads stale phrases.\n",
    "        agrilla_df = None\n",
    "        query = bhvr_text_input.value\n",
    "        prompt = get_prompt(query)\n",
    "        response = get_response_chatgpt(prompt)\n",
    "        result_df = process_response(response, query)\n",
    "        sub_result_df = result_df[(result_df['Score'] >= passing_score) & (result_df['Topic'] != 'NONE')]\n",
    "        sub_2_result_df = result_df[result_df['Topic'] == 'NONE']\n",
    "        if len(sub_result_df) > 0:\n",
    "            highlights = sub_result_df['Phrase'].tolist()\n",
    "            highlight_topics = sub_result_df['Topic'].tolist()\n",
    "            ents = annotate_query(highlights, query, highlight_topics)\n",
    "            # Colour each entity by its own label: zipping ents with highlight_topics\n",
    "            # misaligns whenever a phrase matches zero or several times.\n",
    "            colors = {ent['label']: topic_color_dict[ent['label']] for ent in ents}\n",
    "            ex = [{\"text\": query, \"ents\": ents, \"title\": None}]\n",
    "            show_title(\"Data Collection Category Highlights\")\n",
    "            html = displacy.render(ex, style=\"ent\", manual=True, jupyter=True, options={'colors': colors})\n",
    "            display(HTML(html))\n",
    "\n",
    "            show_title(\"Data Collection Category Classifications\")\n",
    "            for top, top_color in topic_color_dict.items():\n",
    "                top_result_df = sub_result_df[sub_result_df['Topic'] == top]\n",
    "                if len(top_result_df) > 0:\n",
    "                    top_result_df = (\n",
    "                        top_result_df.sort_values(by='Score', ascending=False)\n",
    "                        .reset_index(drop=True)\n",
    "                        .set_index('Phrase')[['Score']]\n",
    "                    )\n",
    "                    show_title(top)\n",
    "                    display(color(top_result_df, top_color))\n",
    "\n",
    "            agg_df = sub_result_df.groupby('Topic')['Score'].sum().to_frame()\n",
    "            agg_df.index.name = 'Topic'\n",
    "            agg_df.columns = ['Total Score']\n",
    "            agg_df = agg_df.assign(\n",
    "                Final_Score=lambda x: x['Total Score'] / x['Total Score'].sum() * 100.00\n",
    "            )\n",
    "            agg_df = agg_df.sort_values(by='Final_Score', ascending=False)\n",
    "            show_title(\"Data Collection Category Coverage\")\n",
    "            agg_df['Topic'] = agg_df.index\n",
    "            rem_topics = [x for x in topic_color_dict if x not in agg_df.Topic.tolist()]\n",
    "            if len(rem_topics) > 0:\n",
    "                rem_agg_df = pd.DataFrame({'Topic': rem_topics, 'Final_Score': 0.0, 'Total Score': 0.0})\n",
    "                agg_df = pd.concat([agg_df, rem_agg_df])\n",
    "            labels = agg_df['Final_Score'].round(1).astype('str') + '%'\n",
    "            ax = agg_df.plot.bar(x='Topic', y='Final_Score', rot=0, figsize=(20, 5), align='center')\n",
    "            for container in ax.containers:\n",
    "                ax.bar_label(container, labels=labels)\n",
    "            ax.yaxis.set_major_formatter(mtick.PercentFormatter())\n",
    "            ax.legend([\"Final Score (%)\"])\n",
    "            ax.set_xlabel('')\n",
    "            plt.show()\n",
    "\n",
    "            show_title(\"Final Scores\")\n",
    "            display_final_df(agg_df)\n",
    "            if len(sub_2_result_df) > 0:\n",
    "                sub_result_df = pd.concat([sub_result_df, sub_2_result_df]).reset_index(drop=True)\n",
    "            agrilla_df = sub_result_df.copy()\n",
    "        else:\n",
    "            print(query)\n",
    "\n",
    "\n",
    "def on_agr_button_next(b):\n",
    "    \"\"\"Replace the Argilla dataset with the last scored phrases for manual validation.\"\"\"\n",
    "    global agrilla_df, annotated\n",
    "    with bhvr_outt:\n",
    "        clear_output()\n",
    "        if agrilla_df is not None:\n",
    "            # convert the dataframe to the structure accepted by argilla\n",
    "            converted_df = convert_df(agrilla_df)\n",
    "            # convert pandas dataframe to DatasetForTextClassification\n",
    "            dataset_rg = rg.DatasetForTextClassification.from_pandas(converted_df)\n",
    "            # delete the old DatasetForTextClassification from the Argilla web app if exists\n",
    "            rg.delete(dataset_rg_name, workspace=\"admin\")\n",
    "            # load the new DatasetForTextClassification into the Argilla web app\n",
    "            rg.log(dataset_rg, name=dataset_rg_name, workspace=\"admin\")\n",
    "            # Make sure all classes are present for annotation\n",
    "            rg_settings = rg.TextClassificationSettings(label_schema=list(topic_color_dict.keys()))\n",
    "            rg.configure_dataset(name=dataset_rg_name, workspace=\"admin\", settings=rg_settings)\n",
    "            annotated = True\n",
    "        else:\n",
    "            display(Markdown(\"**Please score the answer first!**\"))\n",
    "\n",
    "\n",
    "def on_eval_button_next(b):\n",
    "    \"\"\"Display Argilla's F1 report for the validated dataset.\"\"\"\n",
    "    with bhvr_outt:\n",
    "        clear_output()\n",
    "        if annotated:\n",
    "            display(f1(dataset_rg_name).visualize())\n",
    "        else:\n",
    "            display(Markdown(\"**Please score the answer and validate the data first!**\"))\n",
    "\n",
    "\n",
    "bhvr_nlp_btn.on_click(on_bhvr_button_next)\n",
    "bhvr_agr_btn.on_click(on_agr_button_next)\n",
    "bhvr_eval_btn.on_click(on_eval_button_next)\n",
    "\n",
    "display(bhvr_label, bhvr_box)"
   ]
  }
 ],
 "metadata": {
  "hide_input": false,
  "kernelspec": {
   "display_name": "Python 3.9 (Argilla)",
   "language": "python",
   "name": "argilla"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.16"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": false,
   "sideBar": true,
   "skip_h1_title": true,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {
    "height": "calc(100% - 180px)",
    "left": "10px",
    "top": "150px",
    "width": "258.097px"
   },
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}