{ "cells": [ { "cell_type": "markdown", "id": "f56cc5ad", "metadata": {}, "source": [ "# NDIS Project - OpenAI - PBSP Scoring - Page 4 - Strategies to address Setting Events, Triggers, Others" ] }, { "cell_type": "code", "execution_count": null, "id": "a8d844ea", "metadata": { "hide_input": false }, "outputs": [], "source": [ "import os\n", "import openai\n", "import re\n", "from ipywidgets import interact\n", "import ipywidgets as widgets\n", "from IPython.display import display, clear_output, Javascript, HTML, Markdown\n", "import matplotlib.pyplot as plt\n", "import matplotlib.ticker as mtick\n", "import json\n", "import spacy\n", "from spacy import displacy\n", "from dotenv import load_dotenv\n", "import pandas as pd\n", "import argilla as rg\n", "from argilla.metrics.text_classification import f1\n", "import warnings\n", "warnings.filterwarnings('ignore')\n", "%matplotlib inline\n", "pd.set_option('display.max_rows', 500)\n", "pd.set_option('display.max_colwidth', 10000)\n", "pd.set_option('display.width', 10000)" ] }, { "cell_type": "code", "execution_count": null, "id": "96b83a1d", "metadata": {}, "outputs": [], "source": [ "#initializations\n", "# load variables from a local .env file into os.environ, if one exists\n", "load_dotenv()\n", "openai.api_key = os.environ['API_KEY']\n", "openai.api_base = os.environ['API_BASE']\n", "openai.api_type = os.environ['API_TYPE']\n", "openai.api_version = os.environ['API_VERSION']\n", "deployment_name = os.environ['DEPLOYMENT_ID']\n", "\n", "#argilla\n", "rg.init(\n", "    api_url=os.environ[\"ARGILLA_API_URL\"],\n", "    api_key=os.environ[\"ARGILLA_API_KEY\"]\n", ")" ] }, { "cell_type": "code", "execution_count": null, "id": "dee25d82", "metadata": {}, "outputs": [], "source": [ "#sentence extraction\n", "def extract_sentences(paragraph):\n", "    # split on punctuation, underscores, hyphens and newlines\n", "    symbols = ['\\\\.', '!', '\\\\?', ';', ':', ',', '\\\\_', '\\n', '\\\\-']\n", "    pattern = '|'.join(symbols)\n", "    sentences = re.split(pattern, paragraph)\n", "    # drop empty fragments and surrounding whitespace\n", "    sentences = [sentence.strip() for sentence in sentences if sentence.strip()]\n", "    return sentences" ] 
}, { "cell_type": "code", "execution_count": null, "id": "08365143", "metadata": {}, "outputs": [], "source": [ "def filter_dataframe(result_df, paragraph):\n", "    # keep only phrases that occur (case-insensitively) in the paragraph\n", "    filtered_df = result_df[result_df['Phrase'].apply(lambda x: x.lower() in paragraph.lower() or \n", "                                                      x.lower().replace(\"’s\",\"s'\") in paragraph.lower())].copy()\n", "    filtered_df['Match_Percentage'] = filtered_df.apply(lambda row: len(set(row['Phrase'].lower()) & set(paragraph.lower())) / len(set(row['Phrase'].lower())), axis=1)\n", "    filtered_df = filtered_df[filtered_df['Match_Percentage'] >= 0.9]\n", "    filtered_df = filtered_df.drop(['Match_Percentage'], axis=1)\n", "    return filtered_df" ] }, { "cell_type": "code", "execution_count": null, "id": "02fda761", "metadata": {}, "outputs": [], "source": [ "def process_response(response, query):\n", "    sentences = []\n", "    topics = []\n", "    scores = []\n", "    lines = response.strip().split(\"\\n\")\n", "    topic = None\n", "    for line in lines:\n", "        # a heading line switches the current topic\n", "        if \"Setting event strategies:\" in line:\n", "            topic = \"SE STRATEGY\"\n", "        elif \"Trigger strategies:\" in line:\n", "            topic = \"TRIGGER STRATEGY\"\n", "        elif \"Other strategies:\" in line:\n", "            topic = \"OTHER STRATEGY\"\n", "        elif \"None:\" in line:\n", "            topic = \"NO STRATEGY\"\n", "        else:\n", "            try:\n", "                parts = line.split(\"(Confidence Score:\")\n", "                if len(parts) == 2:\n", "                    phrase = parts[0].strip()\n", "                    score = float(parts[1].strip().replace(\")\", \"\"))\n", "                    sentences.append(phrase)\n", "                    topics.append(topic)\n", "                    scores.append(score)\n", "            except ValueError:\n", "                # skip lines whose confidence score is not a number\n", "                pass\n", "    result_df = pd.DataFrame({'Phrase': sentences, 'Topic': topics, 'Score': scores})\n", "    try:\n", "        # strip the list numbering (e.g. '1. ') from the start of each phrase\n", "        result_df['Phrase'] = result_df['Phrase'].str.replace(r'^\\s*\\d+\\.\\s*', '', regex=True)\n", "        result_df = filter_dataframe(result_df, query)\n", "    except Exception:\n", "        # fall back to labelling every sentence of the query as NO STRATEGY\n", "        sentences = extract_sentences(query)\n", "        topics = ['NO STRATEGY'] * len(sentences)\n", "        scores = [0.9] * len(sentences)\n", "        result_df = pd.DataFrame({'Phrase': sentences, 'Topic': topics, 'Score': scores})\n", "    return result_df" ] }, { "cell_type": "code", "execution_count": null, "id": "714fafb4", "metadata": {}, "outputs": [], "source": [ "def get_prompt(query):\n", "    prompt = f\"\"\"\n", "    Given the paragraph below, identify the phrases that describe the strategies that address the identified setting events, the phrases that describe the strategies that address the triggers/antecedents, and the phrases that describe other strategies (e.g. specific support strategies, strategies to improve independence, coping or tolerance):\n", "\n", "    Paragraph:\n", "    {query}\n", "\n", "    Guidelines:\n", "    The strategies that address the identified setting events are focused on reactively responding to specific negative interactions that may occur between the person with disability and co-tenants or staff members, and providing additional information and reminders to the person with disability after the situation has been de-escalated and he is open to communication. The purpose of these strategies is to help prevent future negative interactions and provide the person with disability with the necessary support to manage his behavior in challenging situations. The strategies that address the triggers/antecedents are focused on proactively addressing the underlying causes of the person with disability's behavior, such as changes in his daily schedule or sleep disturbances. These strategies aim to create a stable and predictable environment for the person with disability, where he can feel comfortable and secure, and provide him with clear communication and choices throughout his day to help him manage his behavior more effectively. There are other strategies such as specific support strategies as well as strategies to improve independence, coping or tolerance.\n", "\n", "    Requirements:\n", "    - Provide your answer in numbered lists.\n", "    - All the phrases in your answer must be exact substrings of the original paragraph, without changing any characters.\n", "    - All the upper case and lower case characters in the phrases in your answer must match the upper case and lower case characters in the original paragraph.\n", "    - Start numbering the phrases under the setting event strategies from number 1.\n", "    - Start numbering the phrases under the trigger strategies from number 1.\n", "    - Start numbering the phrases under the other strategies from number 1.\n", "    - Start each list of phrases with these titles: \"Setting event strategies:\", \"Trigger strategies:\", \"Other strategies:\"\n", "    - For each phrase that belongs to any of the above groups (Setting event strategies, Trigger strategies, Other strategies), provide a confidence score that ranges between 0.50 and 1.00, where a score of 0.50 means you are very weakly confident that the phrase belongs to that specific group, whereas a score of 1.00 means you are very strongly confident that the phrase belongs to that specific group.\n", "    - Never include any phrase that does not exist in the paragraph above.\n", "    - If there are no phrases that belong to one or more of the groups (Setting event strategies, Trigger strategies, Other strategies), then do not include these groups in your answer.\n", "    - Never change the wording of any phrase in the paragraph above.\n", "    - Include a final numbered list titled \"None:\", which includes all the remaining phrases from the paragraph above that do not belong to any of the groups above. Provide a confidence score for each of these phrases as well.\n", "\n", "    Example answer:\n", "\n", "    Setting event strategies:\n", "    1. <phrase>. (Confidence Score: <confidence score>)\n", "    2. <phrase>. (Confidence Score: <confidence score>)\n", "\n", "    Trigger strategies:\n", "    1. <phrase>. (Confidence Score: <confidence score>)\n", "    2. <phrase>. (Confidence Score: <confidence score>)\n", "\n", "    Other strategies:\n", "    1. <phrase>. (Confidence Score: <confidence score>)\n", "    2. <phrase>. (Confidence Score: <confidence score>)\n", "\n", "    None:\n", "    1. <phrase>. (Confidence Score: <confidence score>)\n", "    2. <phrase>. (Confidence Score: <confidence score>)\n", "    \"\"\"\n", "    return prompt" ] }, { "cell_type": "code", "execution_count": null, "id": "9e23821b", "metadata": {}, "outputs": [], "source": [ "def get_response_chatgpt(prompt):\n", "    response = openai.ChatCompletion.create(\n", "        engine=deployment_name,\n", "        messages=[\n", "            {\"role\": \"system\", \"content\": \"You are a helpful assistant.\"},\n", "            {\"role\": \"user\", \"content\": prompt}\n", "        ],\n", "        temperature=0\n", "    )\n", "    reply = response[\"choices\"][0][\"message\"][\"content\"]\n", "    return reply" ] }, { "cell_type": "code", "execution_count": null, "id": "983765bc", "metadata": {}, "outputs": [], "source": [ "def convert_df(result_df):\n", "    # one row per phrase; prediction is a list of [label, score] pairs\n", "    new_df = pd.DataFrame(columns=['text', 'prediction'])\n", "    new_df['text'] = result_df['Phrase']\n", "    new_df['prediction'] = result_df.apply(lambda row: [[row['Topic'], row['Score']]], axis=1)\n", "    return new_df" ] }, { "cell_type": "code", "execution_count": null, "id": "bc69cc81", "metadata": {}, "outputs": [], "source": [ "#query = \"\"\"\n", "#In preparing the behavior support plan for Taylor, a person with disability, several strategies have been identified to address the identified setting events. To address specific negative interactions, staff will ensure Taylor receives extra information and reminders about his daily activities, and will have alternative options available in case of any changes. Additionally, staff will engage in frequent communication with Taylor to ensure he has clear expectations and choices throughout the day. Furthermore, specific support strategies, such as the use of speech and sign language and clear, concise communication, will be implemented to improve Taylor's independence and coping skills. To enhance his tolerance, staff will gradually introduce new experiences and support him through any challenges he may face. 
Phrases that do not describe any strategy, such as 'best practices' and 'evidence-based approaches,' will be avoided in favor of concrete, actionable strategies tailored to Taylor's unique needs.\n", "#\"\"\"\n", "#prompt = get_prompt(query)\n", "#response = get_response_chatgpt(prompt)\n", "#result_df = process_response(response, query)\n", "#result_df" ] }, { "cell_type": "code", "execution_count": null, "id": "905eaf2a", "metadata": {}, "outputs": [], "source": [ "# keys must match the topic labels produced by process_response\n", "topic_color_dict = {\n", "    'SE STRATEGY': '#90EE90',\n", "    'TRIGGER STRATEGY': '#F08080',\n", "    'OTHER STRATEGY': '#ADD8E6',\n", "    'NO STRATEGY': '#CCCCCC'\n", "}\n", "\n", "def color(df, color):\n", "    return df.style.format({'Score': '{:,.2%}'.format}).bar(subset=['Score'], color=color)\n", "\n", "def annotate_query(highlights, query, topics):\n", "    ents = []\n", "    for h, t in zip(highlights, topics):\n", "        # escape the phrase so punctuation is matched literally; keep the first match only\n", "        match = re.search(re.escape(h), query, re.IGNORECASE)\n", "        if match:\n", "            ents.append({\"start\": match.start(), \"end\": match.end(), \"label\": t})\n", "    return ents\n", "\n", "def path_to_image_html(path):\n", "    # render the image path as an inline image in the HTML table\n", "    return f'<img src=\"{path}\">'\n", "\n", "passing_score = 0.7\n", "final_passing = 0.0\n", "def display_final_df(agg_df):\n", "    crits = [\n", "        'SE STRATEGY',\n", "        'TRIGGER STRATEGY',\n", "        'OTHER STRATEGY'\n", "    ]\n", "    orig_crits = crits\n", "    crits = [x for x in crits if x in agg_df.index.tolist()]\n", "    bools = [agg_df.loc[crit, 'Final_Score'] > final_passing for crit in crits]\n", "    paths = ['./thumbs_up.png' if x else './thumbs_down.png' for x in bools]\n", "    df = pd.DataFrame({'Strategy': crits, 'USED': paths})\n", "    # strategies with no scored phrase get a thumbs down\n", "    rem_crits = [x for x in orig_crits if x not in crits]\n", "    if len(rem_crits) > 0:\n", "        df2 = pd.DataFrame({'Strategy': rem_crits, 'USED': ['./thumbs_down.png'] * len(rem_crits)})\n", "        df = pd.concat([df, df2])\n", "    df = df.set_index('Strategy')\n", "    pd.set_option('display.max_colwidth', None)\n", "    display(HTML(df.to_html(classes=[\"align-center\"], index=True, escape=False, formatters=dict(USED=path_to_image_html))))" ] }, { "cell_type": "markdown", "id": "2c6e9fe7", "metadata": {}, "source": [ "### Strategies to address setting events and triggers/antecedents, and any other strategies" ] }, { "cell_type": "code", "execution_count": null, "id": "76dd8cab", "metadata": { "scrolled": false }, "outputs": [], "source": [ "#demo with Voila\n", "\n", "bhvr_label = widgets.Label(value='Please type your answer:')\n", "bhvr_text_input = widgets.Textarea(\n", "    value='',\n", "    placeholder='Type your answer',\n", "    description='',\n", "    disabled=False,\n", "    layout={'height': '300px', 'width': '90%'}\n", ")\n", "\n", "bhvr_nlp_btn = widgets.Button(\n", "    description='Score Answer',\n", "    disabled=False,\n", "    button_style='success', # 'success', 'info', 'warning', 'danger' or ''\n", "    tooltip='Score Answer',\n", "    icon='check',\n", "    layout={'height': '70px', 'width': '250px'}\n", ")\n", "bhvr_agr_btn = widgets.Button(\n", "    description='Validate Data',\n", "    disabled=False,\n", "    button_style='success', # 'success', 'info', 'warning', 'danger' or ''\n", "    tooltip='Validate Data',\n", "    icon='check',\n", "    layout={'height': '70px', 'width': '250px'}\n", ")\n", "bhvr_eval_btn = widgets.Button(\n", "    description='Evaluate Model',\n", "    disabled=False,\n", "    button_style='success', # 'success', 'info', 'warning', 'danger' or ''\n", "    tooltip='Evaluate Model',\n", "    icon='check',\n", "    layout={'height': '70px', 'width': '250px'}\n", ")\n", "btn_box = widgets.HBox([bhvr_nlp_btn, bhvr_agr_btn, bhvr_eval_btn],\n", "                       layout={'width': '100%', 'height': '160%'})\n", "bhvr_outt = widgets.Output()\n", "bhvr_outt.layout.height = '100%'\n", "bhvr_outt.layout.width = '100%'\n", "bhvr_box = widgets.VBox([bhvr_text_input, btn_box, bhvr_outt],\n", "                        layout={'width': '100%', 'height': '160%'})\n", "dataset_rg_name = 'pbsp-page4-se-trig-other-strategy-argilla-ds'\n", "agrilla_df = None\n", "annotated = False\n", "def on_bhvr_button_next(b):\n", "    global agrilla_df\n", "    with bhvr_outt:\n", "        clear_output()\n", "        query = bhvr_text_input.value\n", "        prompt = get_prompt(query)\n", "        response = get_response_chatgpt(prompt)\n", "        result_df = process_response(response, query)\n", "        sub_result_df = result_df[(result_df['Score'] >= passing_score) & (result_df['Topic'] != 'NO STRATEGY')]\n", "        sub_2_result_df = result_df[result_df['Topic'] == 'NO STRATEGY']\n", "        highlights = []\n", "        if len(sub_result_df) > 0:\n", "            highlights = sub_result_df['Phrase'].tolist()\n", "            highlight_topics = sub_result_df['Topic'].tolist()\n", "            ents = annotate_query(highlights, query, highlight_topics)\n", "            # colour each entity by its own label\n", "            colors = {ent['label']: topic_color_dict[ent['label']] for ent in ents}\n", "\n", "            ex = [{\"text\": query,\n", "                   \"ents\": ents,\n", "                   \"title\": None}]\n", "            title = \"Strategy Highlights\"\n", "            display(HTML(f'<h3>{title}</h3>'))\n", "            # jupyter=False returns the markup as a string instead of displaying it directly\n", "            html = displacy.render(ex, style=\"ent\", manual=True, jupyter=False, options={'colors': colors})\n", "            display(HTML(html))\n", "            title = \"Strategy Classifications\"\n", "            display(HTML(f'<h3>{title}</h3>'))\n", "            for top in topic_color_dict.keys():\n", "                top_result_df = sub_result_df[sub_result_df['Topic'] == top]\n", "                if len(top_result_df) > 0:\n", "                    top_result_df = top_result_df.sort_values(by='Score', ascending=False).reset_index(drop=True)\n", "                    top_result_df = top_result_df.set_index('Phrase')\n", "                    top_result_df = top_result_df[['Score']]\n", "                    display(HTML(f'<h4>{top}</h4>'))\n", "                    display(color(top_result_df, topic_color_dict[top]))\n", "\n", "            # share of the total confidence score contributed by each topic\n", "            agg_df = sub_result_df.groupby('Topic')['Score'].sum()\n", "            agg_df = agg_df.to_frame()\n", "            agg_df.index.name = 'Topic'\n", "            agg_df.columns = ['Total Score']\n", "            agg_df = agg_df.assign(\n", "                Final_Score=lambda x: x['Total Score'] / x['Total Score'].sum() * 100.00\n", "            )\n", "            agg_df = agg_df.sort_values(by='Final_Score', ascending=False)\n", "            title = \"Strategy Coverage\"\n", "            display(HTML(f'<h3>{title}</h3>'))\n", "            agg_df['Topic'] = agg_df.index\n", "            rem_topics = [x for x in list(topic_color_dict.keys()) if x not in agg_df.Topic.tolist()]\n", "            if len(rem_topics) > 0:\n", "                rem_agg_df = pd.DataFrame({'Topic': rem_topics, 'Final_Score': 0.0, 'Total Score': 0.0})\n", "                agg_df = pd.concat([agg_df, rem_agg_df])\n", "            labels = agg_df['Final_Score'].round(1).astype('str') + '%'\n", "            ax = agg_df.plot.bar(x='Topic', y='Final_Score', rot=0, figsize=(20, 5), align='center')\n", "            for container in ax.containers:\n", "                ax.bar_label(container, labels=labels)\n", "            ax.yaxis.set_major_formatter(mtick.PercentFormatter())\n", "            ax.legend([\"Final Score (%)\"])\n", "            ax.set_xlabel('')\n", "            plt.show()\n", "            title = \"Final Scores\"\n", "            display(HTML(f'<h3>{title}</h3>'))\n", "            display_final_df(agg_df)\n", "            if len(sub_2_result_df) > 0:\n", "                sub_result_df = pd.concat([sub_result_df, sub_2_result_df]).reset_index(drop=True)\n", "            agrilla_df = sub_result_df.copy()\n", "        else:\n", "            print(query)\n", "\n", "def on_agr_button_next(b):\n", "    global agrilla_df, annotated\n", "    with bhvr_outt:\n", "        clear_output()\n", "        if agrilla_df is not None:\n", "            # convert the dataframe to the structure accepted by argilla\n", "            converted_df = convert_df(agrilla_df)\n", "            # convert pandas dataframe to DatasetForTextClassification\n", "            dataset_rg = rg.DatasetForTextClassification.from_pandas(converted_df)\n", "            # delete the old DatasetForTextClassification from the Argilla web app if exists\n", "            rg.delete(dataset_rg_name, workspace=\"admin\")\n", "            # load the new DatasetForTextClassification into the Argilla web app\n", "            rg.log(dataset_rg, name=dataset_rg_name, workspace=\"admin\")\n", "            # Make sure all classes are present for annotation\n", "            rg_settings = rg.TextClassificationSettings(label_schema=list(topic_color_dict.keys()))\n", "            rg.configure_dataset(name=dataset_rg_name, workspace=\"admin\", settings=rg_settings)\n", "            annotated = True\n", "        else:\n", "            display(Markdown(\"Please score the answer first!\"))\n", "\n", "def on_eval_button_next(b):\n", "    global annotated\n", "    with bhvr_outt:\n", "        clear_output()\n", "        if annotated:\n", "            display(f1(dataset_rg_name).visualize())\n", "        else:\n", "            display(Markdown(\"Please score the answer and validate the data first!\"))\n", "\n", "bhvr_nlp_btn.on_click(on_bhvr_button_next)\n", "bhvr_agr_btn.on_click(on_agr_button_next)\n", "bhvr_eval_btn.on_click(on_eval_button_next)\n", "\n", "display(bhvr_label, bhvr_box)" ] }, { "cell_type": "code", "execution_count": null, "id": "ed551eba", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "hide_input": false, "kernelspec": { "display_name": "Python 3.9 (Argilla)", "language": "python", "name": "argilla" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.16" }, "toc": { "base_numbering": 1, "nav_menu": {}, "number_sections": false, "sideBar": true, "skip_h1_title": true, "title_cell": "Table of Contents", "title_sidebar": "Contents", "toc_cell": false, "toc_position": { "height": "calc(100% - 180px)", "left": "10px", "top": "150px", "width": "258.097px" }, "toc_section_display": true, "toc_window_display": false } }, "nbformat": 4, "nbformat_minor": 5 }