{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "f56cc5ad",
   "metadata": {},
   "source": [
    "# NDIS Project - PBSP Scoring - Page 3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a8d844ea",
   "metadata": {
    "hide_input": false
   },
   "outputs": [],
   "source": [
    "import os\n",
    "from ipywidgets import interact\n",
    "import ipywidgets as widgets\n",
    "from IPython.display import display, clear_output, Javascript, HTML, Markdown\n",
    "from qdrant_client import QdrantClient\n",
    "from qdrant_client.http.models import Distance, VectorParams, Batch, Filter, FieldCondition, Range, MatchValue\n",
    "import json\n",
    "import spacy\n",
    "from spacy import displacy\n",
    "import nltk\n",
    "from nltk import sent_tokenize\n",
    "from sklearn.feature_extraction import text\n",
    "from pprint import pprint\n",
    "import re\n",
    "from flair.embeddings import TransformerDocumentEmbeddings\n",
    "from flair.data import Sentence\n",
    "from sentence_transformers import SentenceTransformer, util\n",
    "import pandas as pd\n",
    "import argilla as rg\n",
    "from argilla.metrics.text_classification import f1\n",
    "from typing import Dict\n",
    "from setfit import SetFitModel\n",
    "from tqdm import tqdm\n",
    "import time\n",
    "\n",
    "# NOTE(review): unconditional 30 s pause (the progress bar is disabled).\n",
    "# Presumably this gives the remote services contacted in the next cell time to\n",
    "# start -- confirm, and prefer a readiness check over a fixed sleep.\n",
    "for i in tqdm(range(30), disable=True):\n",
    "    time.sleep(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "96b83a1d",
   "metadata": {},
   "outputs": [],
   "source": [
    "#initializations\n",
    "bhvr_onto_file = 'ontology_page3_bhvr.csv'\n",
    "event_onto_file = 'ontology_page3_event.csv'\n",
    "# Flair document embedder; only used by the Flair-based helpers\n",
    "# (`embeddings`, `get_query_vector`), which this notebook does not call.\n",
    "embedding = TransformerDocumentEmbeddings('distilbert-base-uncased')\n",
    "client = QdrantClient(\n",
    "    host=os.environ[\"QDRANT_API_URL\"],\n",
    "    api_key=os.environ[\"QDRANT_API_KEY\"],\n",
    "    timeout=60,\n",
    "    port=443\n",
    ")\n",
    "collection_name = \"my_collection\"\n",
    "model = SentenceTransformer('multi-qa-MiniLM-L6-cos-v1')\n",
    "vector_dim = 384 #{distilbert-base-uncased: 768, multi-qa-MiniLM-L6-cos-v1:384}\n",
    "sf_trig_model_name = \"setfit-zero-shot-classification-pbsp-p3-trig\"\n",
    "sf_trig_model = SetFitModel.from_pretrained(f\"aammari/{sf_trig_model_name}\")\n",
    "sf_cons_model_name = \"setfit-zero-shot-classification-pbsp-p3-cons\"\n",
    "sf_cons_model = SetFitModel.from_pretrained(f\"aammari/{sf_cons_model_name}\")\n",
    "\n",
    "# download nltk 'punkt' if not available\n",
    "try:\n",
    "    nltk.data.find('tokenizers/punkt')\n",
    "except LookupError:\n",
    "    nltk.download('punkt')\n",
    "\n",
    "# download nltk 'averaged_perceptron_tagger' if not available\n",
    "try:\n",
    "    nltk.data.find('taggers/averaged_perceptron_tagger')\n",
    "except LookupError:\n",
    "    nltk.download('averaged_perceptron_tagger')\n",
    "\n",
    "#argilla\n",
    "rg.init(\n",
    "    api_url=os.environ[\"ARGILLA_API_URL\"],\n",
    "    api_key=os.environ[\"ARGILLA_API_KEY\"]\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "17fe501c",
   "metadata": {
    "hide_input": false
   },
   "outputs": [],
   "source": [
    "#Behaviour ontology to be loaded from a CSV file\n",
    "bhvr_onto_df = pd.read_csv(bhvr_onto_file, header=None).dropna()\n",
    "bhvr_onto_df.columns = ['text']\n",
    "bhvr_onto_lst = bhvr_onto_df['text'].tolist()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7fa6ce86",
   "metadata": {},
   "outputs": [],
   "source": [
    "#Setting event ontology to be loaded from a CSV file\n",
    "event_onto_df = pd.read_csv(event_onto_file, header=None).dropna()\n",
    "event_onto_df.columns = ['text']\n",
    "event_onto_lst = event_onto_df['text'].tolist()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "72c2c6f9",
   "metadata": {
    "hide_input": false
   },
   "outputs": [],
   "source": [
    "#Text Preprocessing\n",
    "try:\n",
    "    nlp = spacy.load('en_core_web_sm')\n",
    "except OSError:\n",
    "    spacy.cli.download('en_core_web_sm')\n",
    "    nlp = spacy.load('en_core_web_sm')\n",
    "sw_lst = text.ENGLISH_STOP_WORDS\n",
    "\n",
    "def preprocess(onto_lst):\n",
    "    \"\"\"Normalise each phrase in `onto_lst` into a space-joined string of lemmas.\n",
    "\n",
    "    For every document, a token is dropped if spaCy flags it as a stop word,\n",
    "    punctuation, number-like or blank, or if its lemma equals the lemma of a\n",
    "    PERSON-entity token in the same document. The remaining lemmas are\n",
    "    lower-cased and kept only if longer than one character, made of `a-z`\n",
    "    and spaces only, and absent from scikit-learn's English stop-word list.\n",
    "\n",
    "    Returns exactly one (possibly empty) string per input document, so the\n",
    "    result stays index-aligned with `onto_lst`.\n",
    "    \"\"\"\n",
    "    cleaned_onto_lst = []\n",
    "    pattern = re.compile(r'^[a-z ]*$')\n",
    "    for document in onto_lst:\n",
    "        doc = nlp(document)\n",
    "        # lemmas of PERSON-entity tokens are excluded everywhere in the doc\n",
    "        person_tokens = {w.lemma_ for w in doc if w.ent_type_ == 'PERSON'}\n",
    "        # named `lemmas` so the sklearn `text` module above is not shadowed\n",
    "        lemmas = [\n",
    "            w.lemma_.lower()\n",
    "            for w in doc\n",
    "            if not (w.is_stop or w.is_punct or w.like_num)\n",
    "            and w.text.strip()\n",
    "            and w.lemma_ not in person_tokens\n",
    "        ]\n",
    "        texts = [t for t in lemmas if len(t) > 1 and pattern.search(t) is not None and t not in sw_lst]\n",
    "        cleaned_onto_lst.append(\" \".join(texts))\n",
    "    return cleaned_onto_lst\n",
    "\n",
    "cl_bhvr_onto_lst = preprocess(bhvr_onto_lst)\n",
    "cl_event_onto_lst = preprocess(event_onto_lst)\n",
    "\n",
    "#pprint(cl_bhvr_onto_lst)\n",
    "#pprint(cl_event_onto_lst)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a1f934eb",
   "metadata": {},
   "outputs": [],
   "source": [
    "#compute document embeddings\n",
    "\n",
    "# distilbert-base-uncased from Flair\n",
    "def embeddings(cl_onto_lst):\n",
    "    \"\"\"Embed each string in `cl_onto_lst` with the Flair `embedding`; returns a list of float lists.\"\"\"\n",
    "    emb_onto_lst = []\n",
    "    for doc in cl_onto_lst:\n",
    "        sentence = Sentence(doc)\n",
    "        embedding.embed(sentence)\n",
    "        emb_onto_lst.append(sentence.embedding.tolist())\n",
    "    return emb_onto_lst\n",
    "\n",
    "# multi-qa-MiniLM-L6-cos-v1 from sentence_transformers\n",
    "def sentence_embeddings(cl_onto_lst):\n",
    "    \"\"\"Embed each string of the list `cl_onto_lst` with the SentenceTransformer `model`.\n",
    "\n",
    "    Returns one plain list of floats per input string, in input order (the\n",
    "    form upserted into Qdrant below). An empty list yields [] without\n",
    "    calling the model.\n",
    "    \"\"\"\n",
    "    if len(cl_onto_lst) == 0:\n",
    "        return []\n",
    "    return [vector.tolist() for vector in model.encode(cl_onto_lst)]\n",
    "\n",
    "# Flair-based alternative, kept for reference (disabled; it also refers to\n",
    "# ontology lists that are not defined on this page):\n",
    "# emb_bhvr_onto_lst = embeddings(cl_bhvr_onto_lst)\n",
    "# emb_fh_onto_lst = embeddings(cl_fh_onto_lst)\n",
    "# emb_rep_onto_lst = embeddings(cl_rep_onto_lst)\n",
    "\n",
    "emb_bhvr_onto_lst = sentence_embeddings(cl_bhvr_onto_lst)\n",
    "emb_event_onto_lst = sentence_embeddings(cl_event_onto_lst)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6302e312",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "#add to qdrant collection\n",
    "def add_to_collection():\n",
    "    \"\"\"(Re)create `collection_name` and upsert both ontologies into it.\n",
    "\n",
    "    Behaviour phrases get ids 1..B and setting-event phrases the ids after\n",
    "    them. Each point's payload stores its `ontology` (used to filter searches)\n",
    "    and the cleaned `phrase`. An existing collection of that name is replaced.\n",
    "    \"\"\"\n",
    "    client.recreate_collection(\n",
    "        collection_name=collection_name,\n",
    "        vectors_config=VectorParams(size=vector_dim, distance=Distance.COSINE),\n",
    "    )\n",
    "    doc_count = len(emb_bhvr_onto_lst) + len(emb_event_onto_lst)\n",
    "    ids = list(range(1, doc_count+1))\n",
    "    payloads = (\n",
    "        [{\"ontology\": \"behaviours\", \"phrase\": x} for x in cl_bhvr_onto_lst]\n",
    "        + [{\"ontology\": \"setting_events\", \"phrase\": y} for y in cl_event_onto_lst]\n",
    "    )\n",
    "    vectors = emb_bhvr_onto_lst + emb_event_onto_lst\n",
    "    client.upsert(\n",
    "        collection_name=f\"{collection_name}\",\n",
    "        points=Batch(\n",
    "            ids=ids,\n",
    "            payloads=payloads,\n",
    "            vectors=vectors\n",
    "        ),\n",
    "    )\n",
    "\n",
    "def count_collection(page_size=256):\n",
    "    \"\"\"Return the number of points stored in `collection_name`.\n",
    "\n",
    "    `client.scroll` returns at most `limit` points per call (10 by default)\n",
    "    together with the offset of the next page, so a single call under-counts\n",
    "    any collection larger than one page. Page through until no next offset\n",
    "    is returned.\n",
    "    \"\"\"\n",
    "    total = 0\n",
    "    next_offset = None\n",
    "    while True:\n",
    "        points, next_offset = client.scroll(\n",
    "            collection_name=f\"{collection_name}\",\n",
    "            limit=page_size,\n",
    "            offset=next_offset,\n",
    "            with_payload=False,\n",
    "        )\n",
    "        total += len(points)\n",
    "        if next_offset is None:\n",
    "            return total\n",
    "\n",
    "add_to_collection()\n",
    "# used as the search `limit`, so a query can return every point of an ontology\n",
    "point_count = count_collection()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dff0ffdf",
   "metadata": {},
   "outputs": [],
   "source": [
    "#print(point_count)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b74861d4",
   "metadata": {},
   "outputs": [],
   "source": [
    "#query_filter=Filter(\n",
    "# must=[ \n",
    "# FieldCondition(\n",
    "# key='ontology',\n",
    "# match=MatchValue(value=\"setting_events\")# Condition based on values of `rand_number` field.\n",
    "# )\n",
    "# ]\n",
    "# )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "40070313",
   "metadata": {},
   "outputs": [],
   "source": [
    "#noun phrase extraction\n",
    "def extract_noun_phrases(text):\n",
    "    \"\"\"Return the text of every NP chunk that the grammar below finds in `text`.\n",
    "\n",
    "    The text is word-tokenised and POS-tagged with NLTK, chunked with\n",
    "    `nltk.RegexpParser`, and each NP subtree's words are joined with spaces.\n",
    "    \"\"\"\n",
    "    # Tokenize the text\n",
    "    tokens = nltk.word_tokenize(text)\n",
    "\n",
    "    # Part-of-speech tag the tokens\n",
    "    tagged_tokens = nltk.pos_tag(tokens)\n",
    "\n",
    "    # Define the noun phrase grammar\n",
    "    # NOTE(review): the tag patterns in these rules look empty (e.g. `{?*+}`); the\n",
    "    # angle-bracketed POS tags may have been lost -- verify against the intended grammar.\n",
    "    grammar = r\"\"\"\n",
    "    NP: {?*+} # noun phrase with optional determiner and adjectives\n",
    "        {+} # proper noun phrase\n",
    "        {?+} # noun phrase with optional possessive pronoun\n",
    "    \"\"\"\n",
    "\n",
    "    # Extract the noun phrases\n",
    "    parser = nltk.RegexpParser(grammar)\n",
    "    tree = parser.parse(tagged_tokens)\n",
    "\n",
    "    # Extract the phrase text from the tree\n",
    "    phrases = []\n",
    "    for subtree in tree.subtrees():\n",
    "        if subtree.label() == \"NP\":\n",
    "            phrase = \" \".join([token[0] for token in subtree.leaves()])\n",
    "            phrases.append(phrase)\n",
    "    return phrases\n",
    "\n",
    "\n",
    "#verb phrase extraction\n",
    "def extract_vbs(data_chunked):\n",
    "    \"\"\"Yield the space-joined words of each element of `data_chunked` that has more than two items.\n",
    "\n",
    "    Tagged words from `nltk.pos_tag` are (word, tag) pairs of length 2, so the\n",
    "    length test keeps only chunks of three or more tokens.\n",
    "    \"\"\"\n",
    "    for tup in data_chunked:\n",
    "        if len(tup) > 2:\n",
    "            yield(str(\" \".join(str(x[0]) for x in tup)))\n",
    "\n",
    "def get_verb_phrases(nltk_query):\n",
    "    \"\"\"Return the 3+-token phrases chunked from `nltk_query` by each CUSTOMCHUNK grammar in `cfgs`.\n",
    "\n",
    "    Each grammar is applied independently, so a phrase is returned once per\n",
    "    grammar that produces it.\n",
    "    \"\"\"\n",
    "    data_tok = nltk.word_tokenize(nltk_query) #tokenisation\n",
    "    data_pos = nltk.pos_tag(data_tok) #POS tagging\n",
    "    # NOTE(review): every entry below is identical, so each phrase is returned once\n",
    "    # per entry; the POS tags in these patterns may have been lost -- verify.\n",
    "    cfgs = [\n",
    "        \"CUSTOMCHUNK: {<.*>{0,3}}\",\n",
    "        \"CUSTOMCHUNK: {<.*>{0,3}}\",\n",
    "        \"CUSTOMCHUNK: {<.*>{0,3}}\",\n",
    "        \"CUSTOMCHUNK: {<.*>{0,3}}\",\n",
    "        \"CUSTOMCHUNK: {<.*>{0,3}}\",\n",
    "        \"CUSTOMCHUNK: {<.*>{0,3}}\",\n",
    "        \"CUSTOMCHUNK: {<.*>{0,3}}\",\n",
    "        \"CUSTOMCHUNK: {<.*>{0,3}}\",\n",
    "        \"CUSTOMCHUNK: {<.*>{0,3}}\",\n",
    "        \"CUSTOMCHUNK: {<.*>{0,3}}\",\n",
    "        \"CUSTOMCHUNK: {<.*>{0,3}}\",\n",
    "        \"CUSTOMCHUNK: {<.*>{0,3}}\",\n",
    "        \"CUSTOMCHUNK: {<.*>{0,3}}\",\n",
    "        \"CUSTOMCHUNK: {<.*>{0,3}}\",\n",
    "        \"CUSTOMCHUNK: {<.*>{0,3}}\",\n",
    "        \"CUSTOMCHUNK: {<.*>{0,3}}\",\n",
    "        \"CUSTOMCHUNK: {<.*>{0,3}}\",\n",
    "        \"CUSTOMCHUNK: {<.*>{0,3}}\",\n",
    "        \"CUSTOMCHUNK: {<.*>{0,3}}\",\n",
    "        \"CUSTOMCHUNK: {<.*>{0,3}}\",\n",
    "        \"CUSTOMCHUNK: {<.*>{0,3}}\",\n",
    "        \"CUSTOMCHUNK: {<.*>{0,3}}\",\n",
    "        \"CUSTOMCHUNK: {<.*>{0,3}}\",\n",
    "        \"CUSTOMCHUNK: {<.*>{0,3}}\",\n",
    "        \"CUSTOMCHUNK: {<.*>{0,3}}\",\n",
    "        \"CUSTOMCHUNK: {<.*>{0,3}}\",\n",
    "        \"CUSTOMCHUNK: {<.*>{0,3}}\",\n",
    "        \"CUSTOMCHUNK: {<.*>{0,3}}\",\n",
    "        \"CUSTOMCHUNK: {<.*>{0,3}}\",\n",
    "        \"CUSTOMCHUNK: {<.*>{0,3}}\",\n",
    "        \"CUSTOMCHUNK: {<.*>{0,3}}\",\n",
    "        \"CUSTOMCHUNK: {<.*>{0,3}}\",\n",
    "        \"CUSTOMCHUNK: {<.*>{0,3}}\",\n",
    "        \"CUSTOMCHUNK: {<.*>{0,3}}\",\n",
    "        \"CUSTOMCHUNK: {<.*>{0,3}}\",\n",
    "        \"CUSTOMCHUNK: {<.*>{0,3}}\",\n",
    "        \"CUSTOMCHUNK: {<.*>{0,3}}\",\n",
    "        \"CUSTOMCHUNK: {<.*>{0,3}}\",\n",
    "        \"CUSTOMCHUNK: {<.*>{0,3}}\",\n",
    "        \"CUSTOMCHUNK: {<.*>{0,3}}\"\n",
    "    ]\n",
    "    vbs = []\n",
    "    for cfg_1 in cfgs:\n",
    "        chunker = nltk.RegexpParser(cfg_1)\n",
    "        data_chunked = chunker.parse(data_pos)\n",
    "        vbs += extract_vbs(data_chunked)\n",
    "    return vbs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a3a96c5f",
   "metadata": {},
   "outputs": [],
   "source": [
    "#text = \"The quick brown fox jumps over the lazy dog.\"\n",
    "#phrases = extract_noun_phrases(text)\n",
    "#cl_phrases = preprocess(phrases)\n",
    "#print(cl_phrases) "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5db008ec",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "#use the get_verb_phrases function to enrich the behaviour ontology\n",
    "#from itertools import chain\n",
    "#maria_file = 'behaviour_score_maria_2.csv'\n",
    "#maria_df = pd.read_csv(maria_file).dropna()\n",
    "#clf_df = maria_df[['Behaviour', 'Behaviour Score']]\n",
    "#one_lst = clf_df[clf_df['Behaviour Score'] == 1]['Behaviour']\n",
    "#list_of_lists = [get_verb_phrases(x) for x in one_lst]\n",
    "#vbs = list(set(list(chain.from_iterable(list_of_lists))))\n",
    "#cl_vbs = preprocess(vbs)\n",
    "#cl_vbs = [x for x in cl_vbs if len(x.split()) > 1]\n",
    "#for cl_vb in cl_vbs:\n",
    "# print(cl_vb)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "acc657fe",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "#use the get_verb_phrases function to enrich the setting events ontology\n",
    "#from itertools import chain\n",
    "#maria_file = 'behaviour_score_maria_2.csv'\n",
    "#maria_df = pd.read_csv(maria_file).dropna()\n",
    "#clf_df = maria_df[['Setting Event']]\n",
    "#one_lst = clf_df['Setting Event'].tolist()\n",
    "#list_of_lists = [get_verb_phrases(x) for x in one_lst]\n",
    "#vbs = list(set(list(chain.from_iterable(list_of_lists))))\n",
    "#cl_vbs = preprocess(vbs)\n",
    "#cl_vbs = list(set([x for x in cl_vbs if len(x.split()) > 1]))\n",
    "#for cl_vb in cl_vbs:\n",
    "# print(cl_vb)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cec885c1",
   "metadata": {},
   "outputs": [],
   "source": [
    "#use the 
extract_noun_phrases function to enrich the setting events ontology\n",
    "#from itertools import chain\n",
    "#maria_file = 'behaviour_score_maria_2.csv'\n",
    "#maria_df = pd.read_csv(maria_file).dropna()\n",
    "#clf_df = maria_df[['Setting Event']]\n",
    "#one_lst = clf_df['Setting Event'].tolist()\n",
    "#list_of_lists = [extract_noun_phrases(x) for x in one_lst]\n",
    "#vbs = list(set(list(chain.from_iterable(list_of_lists))))\n",
    "#cl_vbs = preprocess(vbs)\n",
    "#cl_vbs = list(set([x for x in cl_vbs if len(x.split()) > 1]))\n",
    "#for cl_vb in cl_vbs:\n",
    "# print(cl_vb)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1550437b",
   "metadata": {},
   "outputs": [],
   "source": [
    "#query and get score\n",
    "\n",
    "# distilbert-base-uncased from Flair\n",
    "def get_query_vector(query):\n",
    "    \"\"\"Embed `query` with the Flair `embedding` and return it as a list of floats.\"\"\"\n",
    "    sentence = Sentence(query)\n",
    "    embedding.embed(sentence)\n",
    "    query_vector = sentence.embedding.tolist()\n",
    "    return query_vector\n",
    "\n",
    "# multi-qa-MiniLM-L6-cos-v1 from sentence_transformers\n",
    "def sentence_get_query_vector(query):\n",
    "    \"\"\"Embed `query` with the SentenceTransformer `model` and return the model's raw output.\"\"\"\n",
    "    query_vector = model.encode(query)\n",
    "    return query_vector\n",
    "\n",
    "def search_collection(ontology, query_vector):\n",
    "    \"\"\"Search `collection_name` with `query_vector`, restricted to points whose payload `ontology` matches.\n",
    "\n",
    "    Returns the Qdrant hits (with payloads), at most `point_count` of them.\n",
    "    \"\"\"\n",
    "    query_filter = Filter(\n",
    "        must=[\n",
    "            FieldCondition(\n",
    "                key='ontology',\n",
    "                match=MatchValue(value=ontology)\n",
    "            )\n",
    "        ]\n",
    "    )\n",
    "\n",
    "    hits = client.search(\n",
    "        collection_name=f\"{collection_name}\",\n",
    "        query_vector=query_vector,\n",
    "        query_filter=query_filter,\n",
    "        append_payload=True,\n",
    "        limit=point_count\n",
    "    )\n",
    "    return hits\n",
    "\n",
    "# minimum cosine score for an ontology hit to count as a match\n",
    "semantic_passing_score = 0.50\n",
    "\n",
    "\n",
    "#ontology = 'behaviours'\n",
    "#query = 'punch father face'\n",
    "#query_vector = sentence_get_query_vector(query)\n",
    "#hist = search_collection(ontology, query_vector)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "02fda761",
   "metadata": {},
   "outputs": [],
   "source": [
    "# format output\n",
    "def bhvr_color(df):\n",
    "    \"\"\"Style `df` for display: `Score` as a percentage with a light-green bar.\"\"\"\n",
    "    return df.style.format({'Score': '{:,.2%}'.format}).bar(subset=['Score'], color='#90EE90')\n",
    "\n",
    "def event_color(df):\n",
    "    \"\"\"Style `df` for display: `Score` as a percentage with a purple bar.\"\"\"\n",
    "    return df.style.format({'Score': '{:,.2%}'.format}).bar(subset=['Score'], color='#9370DB')\n",
    "\n",
    "def annotate_query(highlights, query, topic):\n",
    "    \"\"\"Build displaCy manual entity spans, all labelled `topic`, for `highlights` inside `query`.\n",
    "\n",
    "    Each highlight is located by its first literal occurrence in `query`;\n",
    "    highlights that do not occur are skipped. Matching is literal (not regex),\n",
    "    so user text containing characters such as `(`, `?` or `+` cannot raise\n",
    "    `re.error` or match the wrong span. Repeated spans are emitted once and\n",
    "    the result is ordered by position in `query`.\n",
    "    \"\"\"\n",
    "    spans = set()\n",
    "    for h in highlights:\n",
    "        start = query.find(h)\n",
    "        if start != -1:\n",
    "            spans.add((start, start + len(h)))\n",
    "    return [{\"start\": start, \"end\": end, \"label\": topic} for start, end in sorted(spans)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "79b519a6",
   "metadata": {},
   "outputs": [],
   "source": [
    "#setfit trig and cons sentence extraction\n",
    "def extract_sentences(nltk_query):\n",
    "    \"\"\"Split `nltk_query` into sentences with NLTK's `sent_tokenize`.\"\"\"\n",
    "    sentences = sent_tokenize(nltk_query)\n",
    "    return sentences"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dc30f011",
   "metadata": {},
   "outputs": [],
   "source": [
    "def convert_df(result_df):\n",
    "    \"\"\"Reshape a Phrase/Topic/Score frame into the `text`/`prediction` columns logged to Argilla.\n",
    "\n",
    "    Each `prediction` is `[[topic, score]]`, with the score capped at 1.0.\n",
    "    \"\"\"\n",
    "    new_df = pd.DataFrame(columns=['text', 'prediction'])\n",
    "    new_df['text'] = result_df['Phrase']\n",
    "    new_df['prediction'] = result_df.apply(lambda row: [[row['Topic'], min(row['Score'], 1.0)]], axis=1)\n",
    "    return new_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8d4e4e20",
   "metadata": {},
   "outputs": [],
   "source": [
    "def custom_f1(data: Dict[str, float], title: str):\n",
    "    \"\"\"Plot F1 metrics: macro scores in the top chart, per-label scores below.\n",
    "\n",
    "    Keys of `data` containing \"macro\" feed the top chart (assumed to arrive in\n",
    "    precision, recall, f1 order). Keys containing none of \"macro\", \"micro\" or\n",
    "    \"support\" are treated as per-label metrics, three per label. Each label's\n",
    "    bars share a colour drawn at random from the Plotly qualitative palette;\n",
    "    colours are reused when there are more labels than palette entries.\n",
    "    \"\"\"\n",
    "    from plotly.subplots import make_subplots\n",
    "    import plotly.colors\n",
    "    import random\n",
    "\n",
    "    fig = make_subplots(\n",
    "        rows=2,\n",
    "        cols=1,\n",
    "        subplot_titles=[\"Overall Model Score\", \"Model Score By Category\"],\n",
    "    )\n",
    "\n",
    "    x = ['precision', 'recall', 'f1']\n",
    "    macro_data = [v for k, v in data.items() if \"macro\" in k]\n",
    "    fig.add_bar(\n",
    "        x=x,\n",
    "        y=macro_data,\n",
    "        row=1,\n",
    "        col=1,\n",
    "    )\n",
    "    per_label = {\n",
    "        k: v\n",
    "        for k, v in data.items()\n",
    "        if all(key not in k for key in [\"macro\", \"micro\", \"support\"])\n",
    "    }\n",
    "\n",
    "    # one colour per label (3 metrics each); round up so a partial group still gets one\n",
    "    num_labels = -(-len(per_label) // 3)\n",
    "    fixed_colors = [str(color) for color in plotly.colors.qualitative.Plotly]\n",
    "    # random.sample raises if asked for more items than the palette holds\n",
    "    colors = random.sample(fixed_colors, min(num_labels, len(fixed_colors)))\n",
    "\n",
    "    fig.add_bar(\n",
    "        x=list(per_label.keys()),\n",
    "        y=list(per_label.values()),\n",
    "        row=2,\n",
    "        col=1,\n",
    "        marker_color=[colors[(i // 3) % len(colors)] for i in range(len(per_label))]\n",
    "    )\n",
    "    fig.update_layout(showlegend=False, title_text=title)\n",
    "\n",
    "    return fig"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0105e10c",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_null_class_df(sentences, result_df):\n",
    "    \"\"\"Return the `sentences` absent from `result_df['Phrase']` as a frame labelled 'NONE' with score 0.90.\n",
    "\n",
    "    Input order (and any duplicates) of `sentences` is preserved.\n",
    "    \"\"\"\n",
    "    matched = set(result_df['Phrase'].tolist())\n",
    "    null_sents = [x for x in sentences if x not in matched]\n",
    "    topics = ['NONE'] * len(null_sents)\n",
    "    scores = [0.90] * len(null_sents)\n",
    "    null_df = pd.DataFrame({'Phrase': null_sents, 'Topic': topics, 'Score': scores})\n",
    "    return null_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0d5f8d3c",
   "metadata": {},
   "outputs": [],
   "source": [
    "#setfit trig query and get predicted topic\n",
    "\n",
    "def get_sf_trig_topic(sentences):\n",
    "    \"\"\"Return the trigger model's predicted class index for each sentence.\n",
    "\n",
    "    An empty input (e.g. a blank answer box) returns [] without calling the model.\n",
    "    \"\"\"\n",
    "    if len(sentences) == 0:\n",
    "        return []\n",
    "    preds = list(sf_trig_model(sentences))\n",
    "    return preds\n",
    "\n",
    "def get_sf_trig_topic_scores(sentences):\n",
    "    \"\"\"Return, per sentence, the trigger model's highest class probability.\n",
    "\n",
    "    An empty input returns [] without calling the model.\n",
    "    \"\"\"\n",
    "    if len(sentences) == 0:\n",
    "        return []\n",
    "    probas = sf_trig_model.predict_proba(sentences)\n",
    "    return [max(list(x)) for x in probas]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "67bf2154",
   "metadata": {},
   "outputs": [],
   "source": [
    "# setfit trig format output\n",
    "ind_trig_topic_dict = {\n",
    "    0: 'NO TRIGGER',\n",
    "    1: 'TRIGGER',\n",
    "}\n",
    "\n",
    "# NOTE(review): not referenced anywhere in this notebook.\n",
    "highlight_threshold = 0.25\n",
    "# highlighted predictions scoring below this are drawn in amber instead of the topic colour\n",
    "passing_score = 0.50\n",
    "\n",
    "def sf_trig_color(df):\n",
    "    \"\"\"Style `df` for display: `Score` as a percentage with a light-blue bar.\"\"\"\n",
    "    return df.style.format({'Score': '{:,.2%}'.format}).bar(subset=['Score'], color='#ADD8E6')\n",
    "\n",
    "def sf_annotate_query(highlights, query, topics):\n",
    "    \"\"\"Build displaCy manual entity spans, labelling `highlights[i]` with `topics[i]`.\n",
    "\n",
    "    Each highlight is located by its first literal occurrence in `query`;\n",
    "    highlights that do not occur are skipped. Matching is literal (not regex),\n",
    "    so sentences containing characters such as `(`, `?` or `+` cannot raise\n",
    "    `re.error` or match the wrong span. Output order follows `highlights`.\n",
    "    \"\"\"\n",
    "    ents = []\n",
    "    for h, t in zip(highlights, topics):\n",
    "        start = query.find(h)\n",
    "        if start != -1:\n",
    "            ents.append({\"start\": start, \"end\": start + len(h), \"label\": t})\n",
    "    return ents"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a66eaa42",
   "metadata": {},
   "outputs": [],
   "source": [
    "#setfit cons query and get predicted topic\n",
    "\n",
    "def get_sf_cons_topic(sentences):\n",
    "    \"\"\"Return the consequence model's predicted class index for each sentence.\n",
    "\n",
    "    An empty input (e.g. a blank answer box) returns [] without calling the model.\n",
    "    \"\"\"\n",
    "    if len(sentences) == 0:\n",
    "        return []\n",
    "    preds = list(sf_cons_model(sentences))\n",
    "    return preds\n",
    "\n",
    "def get_sf_cons_topic_scores(sentences):\n",
    "    \"\"\"Return, per sentence, the consequence model's highest class probability.\n",
    "\n",
    "    An empty input returns [] without calling the model.\n",
    "    \"\"\"\n",
    "    if len(sentences) == 0:\n",
    "        return []\n",
    "    probas = sf_cons_model.predict_proba(sentences)\n",
    "    return [max(list(x)) for x in probas]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "91774d6f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# setfit cons format output\n",
    "ind_cons_topic_dict = {\n",
    "    0: 'NO CONSEQUENCE',\n",
    "    1: 'CONSEQUENCE',\n",
    "}\n",
    "\n",
    "def sf_cons_color(df):\n",
    "    \"\"\"Style `df` for display: `Score` as a percentage with a coral bar.\"\"\"\n",
    "    return df.style.format({'Score': '{:,.2%}'.format}).bar(subset=['Score'], color='#F08080')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "39ba54ec",
   "metadata": {},
   "outputs": [],
   "source": [
    "def rem_prev_detections(sub_current_df, sub_prev_df):\n",
    "    \"\"\"Return the rows of `sub_current_df` whose `Phrase` does not appear in `sub_prev_df`.\"\"\"\n",
    "    prevs = sub_prev_df['Phrase'].tolist()\n",
    "    cl_sub_current_df = sub_current_df[~sub_current_df['Phrase'].isin(prevs)]\n",
    "    return cl_sub_current_df\n",
    "\n",
    "def path_to_image_html(path):\n",
    "    \"\"\"Formatter for the `Score` column of the final table; receives an image path.\"\"\"\n",
    "    # NOTE(review): this returns an empty string, so `path` is never rendered;\n",
    "    # the image markup appears to be missing here -- verify.\n",
    "    return ''\n",
    "\n",
    "def display_final_df(tags):\n",
    "    \"\"\"Display the final scoring table with a thumbs-up/thumbs-down image path per criterion.\n",
    "\n",
    "    `tags` holds one boolean per criterion, in the order Setting Event,\n",
    "    Triggers, Behaviour, Consequences.\n",
    "    \"\"\"\n",
    "    crits = [\n",
    "        'Setting Event',\n",
    "        'Triggers',\n",
    "        'Behaviour',\n",
    "        'Consequences'\n",
    "    ]\n",
    "    descs = [\n",
    "        'Does the plan identify at least one setting event for the behaviour?',\n",
    "        'Does the plan identify at least one immediate trigger/antecedent for the behaviour?',\n",
    "        'Does the plan identify at least one behaviour?',\n",
    "        'Does the plan identify at least one maintaining consequence for the behaviour?'\n",
    "    ]\n",
    "    paths = ['./thumbs_up.png' if x else './thumbs_down.png' for x in tags]\n",
    "    df = pd.DataFrame({'Criteria': crits, 'Description': descs, 'Score': paths})\n",
    "    df = df.set_index('Criteria')\n",
    "    pd.set_option('display.max_colwidth', None)\n",
    "    # NOTE(review): the HTML wrapper strings around this table span raw line breaks\n",
    "    # and appear to have lost their markup -- verify.\n",
    "    display(HTML('
' + df.to_html(classes=[\"align-center\"], index=True, escape=False ,formatters=dict(Score=path_to_image_html)) + '
'))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2c6e9fe7",
   "metadata": {},
   "source": [
    "### Please complete the following A-B-C chain to demonstrate how the identified triggers are linked to the person’s behaviour, and what happens after the behaviour to reinforce it, and therefore maintain the consequences. Also include setting events."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "76dd8cab",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "#demo with Voila\n",
    "# NOTE(review): several HTML/Markdown string literals below (section titles and\n",
    "# messages) span raw line breaks and appear to have lost their markup -- verify.\n",
    "\n",
    "event_label = widgets.Label(value = r'\\(\\color{purple} {' + 'Setting Events:' + '}\\)')\n",
    "event_text_input = widgets.Textarea(\n",
    "    value='',\n",
    "    placeholder='Type your answer',\n",
    "    description='',\n",
    "    disabled=False,\n",
    "    layout={'height': '100%', 'width': '90%'}\n",
    ")\n",
    "\n",
    "trig_label = widgets.Label(value = r'\\(\\color{blue} {' + 'Triggers:' + '}\\)')\n",
    "trig_text_input = widgets.Textarea(\n",
    "    value='',\n",
    "    placeholder='Type your answer',\n",
    "    description='',\n",
    "    disabled=False,\n",
    "    layout={'height': '100%', 'width': '90%'}\n",
    ")\n",
    "\n",
    "bhvr_label = widgets.Label(value = r'\\(\\color{green} {' + 'Behaviours:' + '}\\)')\n",
    "bhvr_text_input = widgets.Textarea(\n",
    "    value='',\n",
    "    placeholder='Type your answer',\n",
    "    description='',\n",
    "    disabled=False,\n",
    "    layout={'height': '100%', 'width': '90%'}\n",
    ")\n",
    "\n",
    "cons_label = widgets.Label(value = r'\\(\\color{red} {' + 'Consequences:' + '}\\)')\n",
    "cons_text_input = widgets.Textarea(\n",
    "    value='',\n",
    "    placeholder='Type your answer',\n",
    "    description='',\n",
    "    disabled=False,\n",
    "    layout={'height': '100%', 'width': '90%'}\n",
    ")\n",
    "\n",
    "bhvr_nlp_btn = widgets.Button(\n",
    "    description='Score Answer',\n",
    "    disabled=False,\n",
    "    button_style='success', # 'success', 'info', 'warning', 'danger' or ''\n",
    "    tooltip='Score Answer',\n",
    "    icon='check',\n",
    "    layout={'height': '70px', 'width': '250px'}\n",
    ")\n",
    "bhvr_agr_btn = widgets.Button(\n",
    "    description='Validate Data',\n",
    "    disabled=False,\n",
    "    button_style='success', # 'success', 'info', 'warning', 'danger' or ''\n",
    "    tooltip='Validate Data',\n",
    "    icon='check',\n",
    "    layout={'height': '70px', 'width': '250px'}\n",
    ")\n",
    "bhvr_eval_btn = widgets.Button(\n",
    "    description='Evaluate Model',\n",
    "    disabled=False,\n",
    "    button_style='success', # 'success', 'info', 'warning', 'danger' or ''\n",
    "    tooltip='Evaluate Model',\n",
    "    icon='check',\n",
    "    layout={'height': '70px', 'width': '250px'}\n",
    ")\n",
    "btn_box = widgets.HBox([bhvr_nlp_btn, bhvr_agr_btn, bhvr_eval_btn],\n",
    "                       layout={'width': '100%', 'height': '160%'})\n",
    "bhvr_outt = widgets.Output()\n",
    "bhvr_outt.layout.height = '100%'\n",
    "bhvr_outt.layout.width = '100%'\n",
    "\n",
    "event_answer_box = widgets.VBox([event_label, event_text_input],\n",
    "                                layout={'width': '400px', 'height': '200px'})\n",
    "\n",
    "trig_answer_box = widgets.VBox([trig_label, trig_text_input],\n",
    "                               layout={'width': '400px', 'height': '200px'})\n",
    "\n",
    "bhvr_answer_box = widgets.VBox([bhvr_label, bhvr_text_input],\n",
    "                               layout={'width': '400px', 'height': '200px'})\n",
    "\n",
    "cons_answer_box = widgets.VBox([cons_label, cons_text_input],\n",
    "                               layout={'width': '400px', 'height': '200px'})\n",
    "\n",
    "answer_box = widgets.HBox([event_answer_box, trig_answer_box, bhvr_answer_box, cons_answer_box],\n",
    "                          layout={'width': '90%', 'height': '400px'})\n",
    "\n",
    "total_box = widgets.VBox([answer_box, btn_box, bhvr_outt],\n",
    "                         layout={'width': '100%', 'height': '100%'})\n",
    "dataset_rg_name = 'pbsp-page3-abc-argilla-ds'\n",
    "agrilla_df = None\n",
    "annotated = False\n",
    "# NOTE(review): results are appended here on every 'Score Answer' click and never\n",
    "# cleared, so the data sent to Argilla includes all earlier scorings -- confirm intent.\n",
    "sub_2_result_dfs = []\n",
    "\n",
    "def on_bhvr_button_next(b):\n",
    "    \"\"\"'Score Answer' handler: score the four A-B-C answers and render the results.\n",
    "\n",
    "    Behaviours (verb phrases) and setting events (verb and noun phrases of two\n",
    "    or more words) are matched semantically against the Qdrant ontologies;\n",
    "    triggers and consequences are classified sentence by sentence with the\n",
    "    SetFit models. Per-phrase labels are appended to `sub_2_result_dfs`, and\n",
    "    their concatenation is stored in `agrilla_df` for the 'Validate Data' button.\n",
    "    \"\"\"\n",
    "    global agrilla_df\n",
    "    with bhvr_outt:\n",
    "        bhvr_tag = False\n",
    "        event_tag = False\n",
    "        trig_tag = False\n",
    "        cons_tag = False\n",
    "        clear_output()\n",
    "        #semantic behaviour\n",
    "        orig_cl_dict = {x:y for x,y in zip(cl_bhvr_onto_lst, bhvr_onto_lst)}\n",
    "        query = bhvr_text_input.value\n",
    "        vbs = get_verb_phrases(query)\n",
    "        cl_vbs = preprocess(vbs)\n",
    "        emb_vbs = sentence_embeddings(cl_vbs)\n",
    "        highlights = []\n",
    "        bhvr_result_dfs = []\n",
    "        for vb_ind, query_vector in enumerate(emb_vbs):\n",
    "            hist = search_collection('behaviours', query_vector)\n",
    "            hist_dict = [dict(x) for x in hist]\n",
    "            scores = [x['score'] for x in hist_dict]\n",
    "            payloads = [orig_cl_dict[x['payload']['phrase']] for x in hist_dict]\n",
    "            result_df = pd.DataFrame({'Score': scores, 'Glossary': payloads})\n",
    "            result_df = result_df[result_df['Score'] >= semantic_passing_score]\n",
    "            if len(result_df) > 0:\n",
    "                highlights.append(vbs[vb_ind])\n",
    "                result_df = result_df.assign(Phrase=vbs[vb_ind])\n",
    "                result_df = result_df.sort_values(by='Score', ascending=False).reset_index(drop=True)\n",
    "                bhvr_result_dfs.append(result_df)\n",
    "        ents = []\n",
    "        colors = {}\n",
    "        if len(highlights) > 0:\n",
    "            ents = annotate_query(highlights, query, \"BEHAVIOUR\")\n",
    "            for ent in ents:\n",
    "                colors[ent['label']] = '#90EE90'\n",
    "            options = {\"ents\": list(colors), \"colors\": colors}\n",
    "            ex = [{\"text\": query,\n",
    "                   \"ents\": ents,\n",
    "                   \"title\": None}]\n",
    "            if len(ents) > 0:\n",
    "                title = \"Behaviour Phrases\"\n",
    "                display(HTML(f'

{title}

'))\n",
    "                html = displacy.render(ex, style=\"ent\", manual=True, options=options)\n",
    "                display(HTML(html))\n",
    "\n",
    "        if len(bhvr_result_dfs) > 0:\n",
    "            bhvr_tag = True\n",
    "            result_df = pd.concat(bhvr_result_dfs).reset_index(drop = True)\n",
    "            result_df = result_df.sort_values(by='Score', ascending=False).reset_index(drop=True)\n",
    "            sub_2_result_df = result_df.copy()\n",
    "            sub_2_result_df['Topic'] = ['BEHAVIOUR'] * len(result_df)\n",
    "            sub_2_result_df = sub_2_result_df[['Phrase', 'Topic', 'Score']].drop_duplicates().reset_index(drop=True)\n",
    "            # phrases with no ontology match are recorded under the 'NONE' class\n",
    "            null_df = get_null_class_df(vbs, sub_2_result_df)\n",
    "            if len(null_df) > 0:\n",
    "                sub_2_result_df = pd.concat([sub_2_result_df, null_df]).reset_index(drop=True)\n",
    "            sub_2_result_dfs.append(sub_2_result_df)\n",
    "            # keep, per phrase, only the glossary rows carrying that phrase's best score\n",
    "            agg_df = result_df.groupby(result_df.Phrase).max()\n",
    "            agg_df['Phrase'] = agg_df.index\n",
    "            agg_df = agg_df.reset_index(drop=True)\n",
    "            agg_df = agg_df.drop(columns=['Glossary'])\n",
    "            result_df = pd.merge(result_df, agg_df, 'inner', ['Phrase', 'Score'])\n",
    "            result_df = result_df[['Phrase', 'Glossary', 'Score']]\n",
    "            result_df = result_df.set_index('Phrase')\n",
    "            display(bhvr_color(result_df))\n",
    "\n",
    "        #semantic setting events\n",
    "        orig_cl_dict = {x:y for x,y in zip(cl_event_onto_lst, event_onto_lst)}\n",
    "        query = event_text_input.value\n",
    "        vbs = get_verb_phrases(query)\n",
    "        cl_vbs = preprocess(vbs)\n",
    "        nouns = extract_noun_phrases(query)\n",
    "        cl_nouns = preprocess(nouns)\n",
    "        # `sents` and `emb_sents` stay index-aligned: preprocess returns one string per input\n",
    "        sents = vbs+nouns\n",
    "        emb_sents = sentence_embeddings(cl_vbs+cl_nouns)\n",
    "        highlights = []\n",
    "        event_result_dfs = []\n",
    "        for vb_ind, query_vector in enumerate(emb_sents):\n",
    "            # single-word phrases are not matched against the ontology\n",
    "            if len(sents[vb_ind].split()) <= 1:\n",
    "                continue\n",
    "            hist = search_collection('setting_events', query_vector)\n",
    "            hist_dict = [dict(x) for x in hist]\n",
    "            scores = [x['score'] for x in hist_dict]\n",
    "            payloads = [orig_cl_dict[x['payload']['phrase']] for x in hist_dict]\n",
    "            result_df = pd.DataFrame({'Score': scores, 'Glossary': payloads})\n",
    "            result_df = result_df[result_df['Score'] >= semantic_passing_score]\n",
    "            if len(result_df) > 0:\n",
    "                highlights.append(sents[vb_ind])\n",
    "                result_df = result_df.assign(Phrase=sents[vb_ind])\n",
    "                result_df = result_df.sort_values(by='Score', ascending=False).reset_index(drop=True)\n",
    "                event_result_dfs.append(result_df)\n",
    "        event_ents = []\n",
    "        colors = {}\n",
    "        if len(highlights) > 0:\n",
    "            event_ents = annotate_query(highlights, query, \"SETTING EVENT\")\n",
    "            for ent in event_ents:\n",
    "                colors[ent['label']] = '#9370DB'\n",
    "            options = {\"ents\": list(colors), \"colors\": colors}\n",
    "            ex = [{\"text\": query,\n",
    "                   \"ents\": event_ents,\n",
    "                   \"title\": None}]\n",
    "            if len(event_ents) > 0:\n",
    "                title = \"Setting Event Phrases\"\n",
    "                display(HTML(f'

{title}

'))\n",
    "                html = displacy.render(ex, style=\"ent\", manual=True, options=options)\n",
    "                display(HTML(html))\n",
    "\n",
    "        if len(event_result_dfs) > 0:\n",
    "            event_tag = True\n",
    "            result_df = pd.concat(event_result_dfs).reset_index(drop = True)\n",
    "            result_df = result_df.sort_values(by='Score', ascending=False).reset_index(drop=True)\n",
    "            sub_2_result_df = result_df.copy()\n",
    "            sub_2_result_df['Topic'] = ['SETTING EVENT'] * len(result_df)\n",
    "            sub_2_result_df = sub_2_result_df[['Phrase', 'Topic', 'Score']].drop_duplicates().reset_index(drop=True)\n",
    "            # phrases with no ontology match are recorded under the 'NONE' class\n",
    "            null_df = get_null_class_df(sents, sub_2_result_df)\n",
    "            if len(null_df) > 0:\n",
    "                sub_2_result_df = pd.concat([sub_2_result_df, null_df]).reset_index(drop=True)\n",
    "            sub_2_result_dfs.append(sub_2_result_df)\n",
    "            # keep, per phrase, only the glossary rows carrying that phrase's best score\n",
    "            agg_df = result_df.groupby(result_df.Phrase).max()\n",
    "            agg_df['Phrase'] = agg_df.index\n",
    "            agg_df = agg_df.reset_index(drop=True)\n",
    "            agg_df = agg_df.drop(columns=['Glossary'])\n",
    "            result_df = pd.merge(result_df, agg_df, 'inner', ['Phrase', 'Score'])\n",
    "            result_df = result_df[['Phrase', 'Glossary', 'Score']]\n",
    "            result_df = result_df.drop_duplicates()\n",
    "            result_df = result_df.set_index('Phrase')\n",
    "            #display(result_df)\n",
    "            display(event_color(result_df))\n",
    "\n",
    "        #setfit trig\n",
    "        query = trig_text_input.value\n",
    "        sentences = extract_sentences(query)\n",
    "        cl_sentences = preprocess(sentences)\n",
    "        topic_inds = get_sf_trig_topic(cl_sentences)\n",
    "        topics = [ind_trig_topic_dict[i] for i in topic_inds]\n",
    "        scores = get_sf_trig_topic_scores(cl_sentences)\n",
    "        sf_trig_result_df = pd.DataFrame({'Phrase': sentences, 'Topic': topics, 'Score': scores})\n",
    "        sf_trig_sub_result_df = sf_trig_result_df[sf_trig_result_df['Topic'] == 'TRIGGER']\n",
    "        sub_2_result_df = sf_trig_result_df[sf_trig_result_df['Topic'] == 'NO TRIGGER']\n",
    "        sub_2_result_df = pd.concat([sub_2_result_df, sf_trig_sub_result_df]).reset_index(drop=True)\n",
    "        sub_2_result_dfs.append(sub_2_result_df)\n",
    "        sf_trig_ents = []\n",
    "        colors = {}\n",
    "        if len(sf_trig_sub_result_df) > 0:\n",
    "            sf_trig_highlights = sf_trig_sub_result_df['Phrase'].tolist()\n",
    "            sf_trig_highlight_topics = sf_trig_sub_result_df['Topic'].tolist()\n",
    "            sf_trig_highlight_scores = sf_trig_sub_result_df['Score'].tolist()\n",
    "            sf_trig_ents = sf_annotate_query(sf_trig_highlights, query, sf_trig_highlight_topics)\n",
    "            # NOTE(review): displaCy colours are keyed by label, so the last entity\n",
    "            # processed decides the colour of every TRIGGER span.\n",
    "            for ent, hs in zip(sf_trig_ents, sf_trig_highlight_scores):\n",
    "                if hs >= passing_score:\n",
    "                    colors[ent['label']] = '#ADD8E6'\n",
    "                else:\n",
    "                    colors[ent['label']] = '#FFCC66'\n",
    "            options = {\"ents\": list(colors), \"colors\": colors}\n",
    "            ex = [{\"text\": query,\n",
    "                   \"ents\": sf_trig_ents,\n",
    "                   \"title\": None}]\n",
    "            if len(sf_trig_ents) > 0:\n",
    "                title = \"Trigger Phrases\"\n",
    "                display(HTML(f'

{title}

'))\n",
    "                html = displacy.render(ex, style=\"ent\", manual=True, options=options)\n",
    "                display(HTML(html))\n",
    "\n",
    "        if len(sf_trig_sub_result_df) > 0:\n",
    "            trig_tag = True\n",
    "            result_df = sf_trig_sub_result_df.sort_values(by='Score', ascending=False).reset_index(drop=True)\n",
    "            result_df = result_df.set_index('Phrase')\n",
    "            display(sf_trig_color(result_df))\n",
    "\n",
    "        #setfit cons\n",
    "        query = cons_text_input.value\n",
    "        sentences = extract_sentences(query)\n",
    "        cl_sentences = preprocess(sentences)\n",
    "        topic_inds = get_sf_cons_topic(cl_sentences)\n",
    "        topics = [ind_cons_topic_dict[i] for i in topic_inds]\n",
    "        scores = get_sf_cons_topic_scores(cl_sentences)\n",
    "        sf_cons_result_df = pd.DataFrame({'Phrase': sentences, 'Topic': topics, 'Score': scores})\n",
    "        sf_cons_sub_result_df = sf_cons_result_df[sf_cons_result_df['Topic'] == 'CONSEQUENCE']\n",
    "        sub_2_result_df = sf_cons_result_df[sf_cons_result_df['Topic'] == 'NO CONSEQUENCE']\n",
    "        sub_2_result_df = pd.concat([sub_2_result_df, sf_cons_sub_result_df]).reset_index(drop=True)\n",
    "        sub_2_result_dfs.append(sub_2_result_df)\n",
    "        sf_cons_ents = []\n",
    "        colors = {}\n",
    "        if len(sf_cons_sub_result_df) > 0:\n",
    "            sf_cons_highlights = sf_cons_sub_result_df['Phrase'].tolist()\n",
    "            sf_cons_highlight_topics = sf_cons_sub_result_df['Topic'].tolist()\n",
    "            sf_cons_highlight_scores = sf_cons_sub_result_df['Score'].tolist()\n",
    "            sf_cons_ents = sf_annotate_query(sf_cons_highlights, query, sf_cons_highlight_topics)\n",
    "            for ent, hs in zip(sf_cons_ents, sf_cons_highlight_scores):\n",
    "                if hs >= passing_score:\n",
    "                    colors[ent['label']] = '#F08080'\n",
    "                else:\n",
    "                    colors[ent['label']] = '#FFCC66'\n",
    "            options = {\"ents\": list(colors), \"colors\": colors}\n",
    "            ex = [{\"text\": query,\n",
    "                   \"ents\": sf_cons_ents,\n",
    "                   \"title\": None}]\n",
    "            if len(sf_cons_ents) > 0:\n",
    "                title = \"Consequence Phrases\"\n",
    "                display(HTML(f'

{title}

'))\n",
    "                html = displacy.render(ex, style=\"ent\", manual=True, options=options)\n",
    "                display(HTML(html))\n",
    "\n",
    "        if len(sf_cons_sub_result_df) > 0:\n",
    "            cons_tag = True\n",
    "            result_df = sf_cons_sub_result_df.sort_values(by='Score', ascending=False).reset_index(drop=True)\n",
    "            result_df = result_df.set_index('Phrase')\n",
    "            display(sf_cons_color(result_df))\n",
    "\n",
    "        title = \"Final Scores\"\n",
    "        display(HTML(f'

{title}

'))\n",
    "        display_final_df([event_tag, trig_tag, bhvr_tag, cons_tag])\n",
    "        if len(sub_2_result_dfs) > 0:\n",
    "            sub_2_result_df = pd.concat(sub_2_result_dfs).reset_index(drop=True)\n",
    "            agrilla_df = sub_2_result_df.copy()\n",
    "\n",
    "def on_agr_button_next(b):\n",
    "    \"\"\"'Validate Data' handler: replace the Argilla dataset `dataset_rg_name` with `agrilla_df`.\"\"\"\n",
    "    global annotated\n",
    "    with bhvr_outt:\n",
    "        clear_output()\n",
    "        if agrilla_df is not None:\n",
    "            # convert the dataframe to the structure accepted by argilla\n",
    "            converted_df = convert_df(agrilla_df)\n",
    "            # convert pandas dataframe to DatasetForTextClassification\n",
    "            dataset_rg = rg.DatasetForTextClassification.from_pandas(converted_df)\n",
    "            # delete the old DatasetForTextClassification from the Argilla web app if exists\n",
    "            rg.delete(dataset_rg_name)\n",
    "            # load the new DatasetForTextClassification into the Argilla web app\n",
    "            rg.log(dataset_rg, name=dataset_rg_name)\n",
    "            annotated = True\n",
    "        else:\n",
    "            display(Markdown(\"

Please score the answer first!

\"))\n",
    "\n",
    "def on_eval_button_next(b):\n",
    "    \"\"\"'Evaluate Model' handler: plot Argilla's F1 metrics for `dataset_rg_name` once data has been validated.\"\"\"\n",
    "    with bhvr_outt:\n",
    "        clear_output()\n",
    "        if annotated:\n",
    "            data = dict(f1(dataset_rg_name))['data']\n",
    "            display(custom_f1(data, \"Model Evaluation Results\"))\n",
    "        else:\n",
    "            display(Markdown(\"

Please score the answer and validate the data first!

\"))\n",
    "\n",
    "bhvr_nlp_btn.on_click(on_bhvr_button_next)\n",
    "bhvr_agr_btn.on_click(on_agr_button_next)\n",
    "bhvr_eval_btn.on_click(on_eval_button_next)\n",
    "\n",
    "display(total_box)"
   ]
  }
 ],
 "metadata": {
  "hide_input": false,
  "kernelspec": {
   "display_name": "Python 3.9 (Argilla)",
   "language": "python",
   "name": "argilla"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.16"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": false,
   "sideBar": true,
   "skip_h1_title": true,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {
    "height": "calc(100% - 180px)",
    "left": "10px",
    "top": "150px",
    "width": "258.097px"
   },
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}